222 lines
5.6 KiB
Python
222 lines
5.6 KiB
Python
"""Base API client with common functionality."""
|
|
|
|
import os
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Iterator
|
|
from typing import Any
|
|
|
|
|
|
class APIClientError(Exception):
    """Root exception type for failures reported by the API clients.

    More specific client errors (rate limiting, authentication) derive
    from this class, so catching it covers them as well.
    """
|
|
|
|
|
|
class RateLimitError(APIClientError):
    """Signals that the remote API refused a request because of rate limiting."""
|
|
|
|
|
|
class AuthenticationError(APIClientError):
    """Signals that the remote API rejected the supplied credentials."""
|
|
|
|
|
|
class BaseAPIClient(ABC):
|
|
"""Abstract base class for API clients."""
|
|
|
|
def __init__(self, repo: str):
|
|
"""Initialize the API client.
|
|
|
|
Args:
|
|
repo: Repository identifier (owner/repo format).
|
|
"""
|
|
self.repo = repo
|
|
self.owner, self.repo_name = self._parse_repo(repo)
|
|
self.max_retries = 3
|
|
self.retry_delay = 1.0
|
|
|
|
def _parse_repo(self, repo: str) -> tuple[str, str]:
|
|
"""Parse repository identifier into owner and name.
|
|
|
|
Args:
|
|
repo: Repository identifier (owner/repo or just repo name).
|
|
|
|
Returns:
|
|
Tuple of (owner, repo_name).
|
|
"""
|
|
if "/" in repo:
|
|
parts = repo.split("/")
|
|
return parts[0], parts[-1]
|
|
return "", repo
|
|
|
|
def _get_token(self, token_env: str) -> str | None:
|
|
"""Get API token from environment variable.
|
|
|
|
Args:
|
|
token_env: Environment variable name for the token.
|
|
|
|
Returns:
|
|
Token string or None if not found.
|
|
"""
|
|
return os.getenv(token_env)
|
|
|
|
@abstractmethod
|
|
def get_current_user(self) -> str:
|
|
"""Get the authenticated username.
|
|
|
|
Returns:
|
|
Username string.
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def get_repository(self) -> dict:
|
|
"""Get repository information.
|
|
|
|
Returns:
|
|
Repository data dictionary.
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def get_pull_requests(self, state: str = "open") -> list[dict]:
|
|
"""Get pull requests.
|
|
|
|
Args:
|
|
state: PR state (open, closed, all).
|
|
|
|
Returns:
|
|
List of PR data dictionaries.
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def get_issues(self, state: str = "open") -> list[dict]:
|
|
"""Get issues.
|
|
|
|
Args:
|
|
state: Issue state (open, closed, all).
|
|
|
|
Returns:
|
|
List of issue data dictionaries.
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def get_workflows(self) -> list[dict]:
|
|
"""Get workflows.
|
|
|
|
Returns:
|
|
List of workflow data dictionaries.
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def get_workflow_runs(self, limit: int = 10) -> list[dict]:
|
|
"""Get workflow runs.
|
|
|
|
Args:
|
|
limit: Maximum number of runs to return.
|
|
|
|
Returns:
|
|
List of workflow run data dictionaries.
|
|
"""
|
|
pass
|
|
|
|
def _check_response_status(self, response: Any) -> None:
|
|
"""Check response status and raise appropriate errors.
|
|
|
|
Args:
|
|
response: Response object to check.
|
|
|
|
Raises:
|
|
AuthenticationError: On 401 status.
|
|
RateLimitError: On 403 with rate limit message.
|
|
APIClientError: On 404 or other client errors.
|
|
"""
|
|
if response.status_code == 401:
|
|
raise AuthenticationError("Authentication failed. Check your API token.")
|
|
|
|
if response.status_code == 403:
|
|
if "rate limit" in response.text.lower():
|
|
raise RateLimitError("API rate limit exceeded.")
|
|
|
|
if response.status_code == 404:
|
|
raise APIClientError("Resource not found")
|
|
|
|
response.raise_for_status()
|
|
|
|
def _request_with_retry(
|
|
self,
|
|
method: str,
|
|
url: str,
|
|
**kwargs,
|
|
) -> dict:
|
|
"""Make an API request with retry logic.
|
|
|
|
Args:
|
|
method: HTTP method.
|
|
url: Request URL.
|
|
**kwargs: Additional request arguments.
|
|
|
|
Returns:
|
|
Response JSON data.
|
|
"""
|
|
last_error = None
|
|
|
|
for attempt in range(self.max_retries):
|
|
try:
|
|
response = self._make_request(method, url, **kwargs)
|
|
self._check_response_status(response)
|
|
return response.json()
|
|
|
|
except RateLimitError:
|
|
wait_time = self.retry_delay * (2 ** attempt)
|
|
time.sleep(wait_time)
|
|
except AuthenticationError:
|
|
raise
|
|
except APIClientError:
|
|
if attempt < self.max_retries - 1:
|
|
wait_time = self.retry_delay * (2 ** attempt)
|
|
time.sleep(wait_time)
|
|
else:
|
|
raise
|
|
except Exception as e:
|
|
last_error = e
|
|
if attempt < self.max_retries - 1:
|
|
wait_time = self.retry_delay * (2 ** attempt)
|
|
time.sleep(wait_time)
|
|
|
|
raise APIClientError(f"Request failed after {self.max_retries} attempts: {last_error}")
|
|
|
|
@abstractmethod
|
|
def _make_request(self, method: str, url: str, **kwargs) -> Any:
|
|
"""Make an HTTP request.
|
|
|
|
Args:
|
|
method: HTTP method.
|
|
url: Request URL.
|
|
**kwargs: Additional request arguments.
|
|
|
|
Returns:
|
|
Response object.
|
|
"""
|
|
pass
|
|
|
|
@abstractmethod
|
|
def _paginate(self, url: str, **kwargs) -> Iterator[dict]:
|
|
"""Iterate over paginated API results.
|
|
|
|
Args:
|
|
url: API endpoint URL.
|
|
**kwargs: Additional request arguments.
|
|
|
|
Yields:
|
|
Resource dictionaries from each page.
|
|
"""
|
|
pass
|