fix: resolve CI linting and type errors
This commit is contained in:
@@ -1,26 +1,25 @@
|
|||||||
|
"""Base provider interface."""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, AsyncIterator, Dict, Optional
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class ProviderResponse:
|
class ProviderResponse:
|
||||||
def __init__(
|
"""Response from an LLM provider."""
|
||||||
self,
|
|
||||||
content: str,
|
content: str
|
||||||
model: str,
|
model: str
|
||||||
provider: str,
|
provider: str
|
||||||
usage: Optional[Dict[str, Any]] = None,
|
usage: Dict[str, int] = field(default_factory=dict)
|
||||||
latency_ms: float = 0.0,
|
latency_ms: float = 0.0
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
):
|
|
||||||
self.content = content
|
|
||||||
self.model = model
|
|
||||||
self.provider = provider
|
|
||||||
self.usage = usage or {}
|
|
||||||
self.latency_ms = latency_ms
|
|
||||||
self.metadata = metadata or {}
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderBase(ABC):
|
class ProviderBase(ABC):
|
||||||
|
"""Abstract base class for LLM providers."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
@@ -28,14 +27,23 @@ class ProviderBase(ABC):
|
|||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
"""Initialize provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API key for authentication.
|
||||||
|
model: Model identifier to use.
|
||||||
|
temperature: Sampling temperature (0.0-1.0).
|
||||||
|
**kwargs: Additional provider-specific options.
|
||||||
|
"""
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.model = model
|
self.model = model
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.extra_kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
|
"""Provider name identifier."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -46,6 +54,17 @@ class ProviderBase(ABC):
|
|||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ProviderResponse:
|
) -> ProviderResponse:
|
||||||
|
"""Send a completion request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt to send.
|
||||||
|
system_prompt: Optional system instructions.
|
||||||
|
max_tokens: Maximum tokens in response.
|
||||||
|
**kwargs: Additional provider-specific parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ProviderResponse with the generated content.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -56,7 +75,32 @@ class ProviderBase(ABC):
|
|||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
|
"""Stream completions incrementally.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt to send.
|
||||||
|
system_prompt: Optional system instructions.
|
||||||
|
max_tokens: Maximum tokens in response.
|
||||||
|
**kwargs: Additional provider-specific parameters.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Chunks of generated content.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def validate_api_key(self) -> bool:
|
def validate_api_key(self) -> bool:
|
||||||
return True
|
"""Validate that the API key is configured correctly."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_models(self) -> List[str]:
|
||||||
|
"""List available models for this provider."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_system_prompt(self, prompt: str) -> Optional[str]:
|
||||||
|
"""Extract system prompt from prompt if using special syntax."""
|
||||||
|
if "---" in prompt:
|
||||||
|
parts = prompt.split("---", 1)
|
||||||
|
return parts[0].strip()
|
||||||
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user