fix: resolve CI linting and type errors
Some checks failed
CI / test (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled

This commit is contained in:
2026-02-04 12:58:20 +00:00
parent eabd05b6c4
commit fa7365ca37

View File

@@ -1,26 +1,25 @@
"""Base provider interface."""
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Dict, Optional
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Dict, List, Optional
@dataclass
class ProviderResponse:
def __init__(
self,
content: str,
model: str,
provider: str,
usage: Optional[Dict[str, Any]] = None,
latency_ms: float = 0.0,
metadata: Optional[Dict[str, Any]] = None,
):
self.content = content
self.model = model
self.provider = provider
self.usage = usage or {}
self.latency_ms = latency_ms
self.metadata = metadata or {}
"""Response from an LLM provider."""
content: str
model: str
provider: str
usage: Dict[str, int] = field(default_factory=dict)
latency_ms: float = 0.0
metadata: Dict[str, Any] = field(default_factory=dict)
class ProviderBase(ABC):
"""Abstract base class for LLM providers."""
def __init__(
self,
api_key: Optional[str] = None,
@@ -28,14 +27,23 @@ class ProviderBase(ABC):
temperature: float = 0.7,
**kwargs,
):
"""Initialize provider.
Args:
api_key: API key for authentication.
model: Model identifier to use.
temperature: Sampling temperature (0.0-1.0).
**kwargs: Additional provider-specific options.
"""
self.api_key = api_key
self.model = model
self.temperature = temperature
self.extra_kwargs = kwargs
self.kwargs = kwargs
@property
@abstractmethod
def name(self) -> str:
"""Provider name identifier."""
pass
@abstractmethod
@@ -46,6 +54,17 @@ class ProviderBase(ABC):
max_tokens: Optional[int] = None,
**kwargs,
) -> ProviderResponse:
"""Send a completion request.
Args:
prompt: User prompt to send.
system_prompt: Optional system instructions.
max_tokens: Maximum tokens in response.
**kwargs: Additional provider-specific parameters.
Returns:
ProviderResponse with the generated content.
"""
pass
@abstractmethod
@@ -56,7 +75,32 @@ class ProviderBase(ABC):
max_tokens: Optional[int] = None,
**kwargs,
) -> AsyncIterator[str]:
"""Stream completions incrementally.
Args:
prompt: User prompt to send.
system_prompt: Optional system instructions.
max_tokens: Maximum tokens in response.
**kwargs: Additional provider-specific parameters.
Yields:
Chunks of generated content.
"""
pass
@abstractmethod
def validate_api_key(self) -> bool:
return True
"""Validate that the API key is configured correctly."""
pass
@abstractmethod
def list_models(self) -> List[str]:
"""List available models for this provider."""
pass
def _get_system_prompt(self, prompt: str) -> Optional[str]:
"""Extract system prompt from prompt if using special syntax."""
if "---" in prompt:
parts = prompt.split("---", 1)
return parts[0].strip()
return None