fix: resolve CI linting and type errors
Some checks failed
CI / test (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled

This commit is contained in:
2026-02-04 12:58:20 +00:00
parent eabd05b6c4
commit fa7365ca37

View File

@@ -1,26 +1,25 @@
"""Base provider interface."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Dict, Optional from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional
@dataclass
class ProviderResponse: class ProviderResponse:
def __init__( """Response from an LLM provider."""
self,
content: str, content: str
model: str, model: str
provider: str, provider: str
usage: Optional[Dict[str, Any]] = None, usage: Dict[str, int] = field(default_factory=dict)
latency_ms: float = 0.0, latency_ms: float = 0.0
metadata: Optional[Dict[str, Any]] = None, metadata: Dict[str, Any] = field(default_factory=dict)
):
self.content = content
self.model = model
self.provider = provider
self.usage = usage or {}
self.latency_ms = latency_ms
self.metadata = metadata or {}
class ProviderBase(ABC): class ProviderBase(ABC):
"""Abstract base class for LLM providers."""
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Optional[str] = None,
@@ -28,14 +27,23 @@ class ProviderBase(ABC):
temperature: float = 0.7, temperature: float = 0.7,
**kwargs, **kwargs,
): ):
"""Initialize provider.
Args:
api_key: API key for authentication.
model: Model identifier to use.
temperature: Sampling temperature (0.0-1.0).
**kwargs: Additional provider-specific options.
"""
self.api_key = api_key self.api_key = api_key
self.model = model self.model = model
self.temperature = temperature self.temperature = temperature
self.extra_kwargs = kwargs self.kwargs = kwargs
@property @property
@abstractmethod @abstractmethod
def name(self) -> str: def name(self) -> str:
"""Provider name identifier."""
pass pass
@abstractmethod @abstractmethod
@@ -46,6 +54,17 @@ class ProviderBase(ABC):
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> ProviderResponse: ) -> ProviderResponse:
"""Send a completion request.
Args:
prompt: User prompt to send.
system_prompt: Optional system instructions.
max_tokens: Maximum tokens in response.
**kwargs: Additional provider-specific parameters.
Returns:
ProviderResponse with the generated content.
"""
pass pass
@abstractmethod @abstractmethod
@@ -56,7 +75,32 @@ class ProviderBase(ABC):
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Stream completions incrementally.
Args:
prompt: User prompt to send.
system_prompt: Optional system instructions.
max_tokens: Maximum tokens in response.
**kwargs: Additional provider-specific parameters.
Yields:
Chunks of generated content.
"""
pass pass
@abstractmethod
def validate_api_key(self) -> bool: def validate_api_key(self) -> bool:
return True """Validate that the API key is configured correctly."""
pass
@abstractmethod
def list_models(self) -> List[str]:
"""List available models for this provider."""
pass
def _get_system_prompt(self, prompt: str) -> Optional[str]:
"""Extract system prompt from prompt if using special syntax."""
if "---" in prompt:
parts = prompt.split("---", 1)
return parts[0].strip()
return None