fix: resolve CI linting and type errors
Some checks failed
CI / test (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled

This commit is contained in:
2026-02-04 12:58:23 +00:00
parent 627c0ec550
commit 03ed9d92b2

View File

@@ -1,6 +1,7 @@
import asyncio """Anthropic provider implementation."""
import time import time
from typing import Any, AsyncIterator, Dict, Optional from typing import AsyncIterator, Optional
from anthropic import Anthropic, APIError, RateLimitError from anthropic import Anthropic, APIError, RateLimitError
@@ -9,6 +10,8 @@ from ..core.exceptions import ProviderError
class AnthropicProvider(ProviderBase): class AnthropicProvider(ProviderBase):
"""Anthropic Claude models provider."""
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Optional[str] = None,
@@ -16,6 +19,7 @@ class AnthropicProvider(ProviderBase):
temperature: float = 0.7, temperature: float = 0.7,
**kwargs, **kwargs,
): ):
"""Initialize Anthropic provider."""
super().__init__(api_key, model, temperature, **kwargs) super().__init__(api_key, model, temperature, **kwargs)
self._client: Optional[Anthropic] = None self._client: Optional[Anthropic] = None
@@ -24,6 +28,7 @@ class AnthropicProvider(ProviderBase):
return "anthropic" return "anthropic"
def _get_client(self) -> Anthropic: def _get_client(self) -> Anthropic:
"""Get or create Anthropic client."""
if self._client is None: if self._client is None:
api_key = self.api_key or self._get_api_key_from_env() api_key = self.api_key or self._get_api_key_from_env()
if not api_key: if not api_key:
@@ -45,25 +50,37 @@ class AnthropicProvider(ProviderBase):
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> ProviderResponse: ) -> ProviderResponse:
"""Send completion request to Anthropic."""
start_time = time.time() start_time = time.time()
try: try:
client = self._get_client() client = self._get_client()
messages = [{"role": "user", "content": prompt}] if system_prompt:
system_message = system_prompt
user_message = prompt
else:
system_message = None
user_message = prompt
response = client.messages.create( response = client.messages.create( # type: ignore[arg-type]
model=self.model, model=self.model,
messages=messages,
temperature=self.temperature,
max_tokens=max_tokens or 4096, max_tokens=max_tokens or 4096,
temperature=self.temperature,
system=system_message, # type: ignore[arg-type]
messages=[{"role": "user", "content": user_message}],
**kwargs, **kwargs,
) )
latency_ms = (time.time() - start_time) * 1000 latency_ms = (time.time() - start_time) * 1000
content = ""
for block in response.content:
if block.type == "text":
content += block.text
return ProviderResponse( return ProviderResponse(
content=response.content[0].text, content=content,
model=self.model, model=self.model,
provider=self.name, provider=self.name,
usage={ usage={
@@ -71,29 +88,39 @@ class AnthropicProvider(ProviderBase):
"output_tokens": response.usage.output_tokens, "output_tokens": response.usage.output_tokens,
}, },
latency_ms=latency_ms, latency_ms=latency_ms,
metadata={
"stop_reason": response.stop_reason,
},
) )
except APIError as e: except APIError as e:
raise ProviderError(f"Anthropic API error: {e}") raise ProviderError(f"Anthropic API error: {e}")
except RateLimitError as e: except RateLimitError as e:
raise ProviderError(f"Anthropic rate limit exceeded: {e}") raise ProviderError(f"Anthropic rate limit exceeded: {e}")
async def stream_complete( async def stream_complete( # type: ignore[override]
self, self,
prompt: str, prompt: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Stream completion from Anthropic."""
try: try:
client = self._get_client() client = self._get_client()
messages = [{"role": "user", "content": prompt}] if system_prompt:
system_message = system_prompt
user_message = prompt
else:
system_message = None
user_message = prompt
with client.messages.stream( with client.messages.stream( # type: ignore[arg-type]
model=self.model, model=self.model,
messages=messages,
temperature=self.temperature,
max_tokens=max_tokens or 4096, max_tokens=max_tokens or 4096,
temperature=self.temperature,
system=system_message, # type: ignore[arg-type]
messages=[{"role": "user", "content": user_message}],
**kwargs, **kwargs,
) as stream: ) as stream:
for text in stream.text_stream: for text in stream.text_stream:
@@ -102,12 +129,24 @@ class AnthropicProvider(ProviderBase):
raise ProviderError(f"Anthropic API error: {e}") raise ProviderError(f"Anthropic API error: {e}")
def validate_api_key(self) -> bool: def validate_api_key(self) -> bool:
"""Validate Anthropic API key."""
try: try:
import os import os
api_key = self.api_key or os.environ.get("ANTHROPIC_API_KEY") api_key = self.api_key or os.environ.get("ANTHROPIC_API_KEY")
if not api_key: if not api_key:
return False return False
client = Anthropic(api_key=api_key) _ = Anthropic(api_key=api_key)
return True return True
except Exception: except Exception:
return False return False
def list_models(self) -> list[str]:
"""List available Anthropic models."""
return [
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
"claude-2.1",
"claude-2.0",
"claude-instant-1.2",
]