diff --git a/src/promptforge/providers/anthropic.py b/src/promptforge/providers/anthropic.py new file mode 100644 index 0000000..0979381 --- /dev/null +++ b/src/promptforge/providers/anthropic.py @@ -0,0 +1,113 @@ +import asyncio +import time +from typing import Any, AsyncIterator, Dict, Optional + +from anthropic import Anthropic, APIError, RateLimitError + +from .base import ProviderBase, ProviderResponse +from ..core.exceptions import ProviderError + + +class AnthropicProvider(ProviderBase): + def __init__( + self, + api_key: Optional[str] = None, + model: str = "claude-3-sonnet-20240229", + temperature: float = 0.7, + **kwargs, + ): + super().__init__(api_key, model, temperature, **kwargs) + self._client: Optional[Anthropic] = None + + @property + def name(self) -> str: + return "anthropic" + + def _get_client(self) -> Anthropic: + if self._client is None: + api_key = self.api_key or self._get_api_key_from_env() + if not api_key: + raise ProviderError( + "Anthropic API key not configured. " + "Set ANTHROPIC_API_KEY env var or pass api_key parameter." + ) + self._client = Anthropic(api_key=api_key) + return self._client + + def _get_api_key_from_env(self) -> Optional[str]: + import os + return os.environ.get("ANTHROPIC_API_KEY") + + async def complete( + self, + prompt: str, + system_prompt: Optional[str] = None, + max_tokens: Optional[int] = None, + **kwargs, + ) -> ProviderResponse: + start_time = time.time() + + try: + client = self._get_client() + + messages = [{"role": "user", "content": prompt}] + + response = client.messages.create( + model=self.model, + messages=messages, + temperature=self.temperature, + max_tokens=max_tokens or 4096, + **kwargs, + ) + + latency_ms = (time.time() - start_time) * 1000 + + return ProviderResponse( + content=response.content[0].text, + model=self.model, + provider=self.name, + usage={ + "input_tokens": response.usage.input_tokens, + "output_tokens": response.usage.output_tokens, + }, + latency_ms=latency_ms, + ) + except APIError as e: + raise ProviderError(f"Anthropic API error: {e}") + except RateLimitError as e: + raise ProviderError(f"Anthropic rate limit exceeded: {e}") + + async def stream_complete( + self, + prompt: str, + system_prompt: Optional[str] = None, + max_tokens: Optional[int] = None, + **kwargs, + ) -> AsyncIterator[str]: + try: + client = self._get_client() + + messages = [{"role": "user", "content": prompt}] + + with client.messages.stream( + model=self.model, + messages=messages, + temperature=self.temperature, + max_tokens=max_tokens or 4096, + **kwargs, + ) as stream: + for text in stream.text_stream: + yield text + except APIError as e: + raise ProviderError(f"Anthropic API error: {e}") + + def validate_api_key(self) -> bool: + try: + import os + api_key = self.api_key or os.environ.get("ANTHROPIC_API_KEY") + if not api_key: + return False + client = Anthropic(api_key=api_key) + return True + except Exception: + return False \ No newline at end of file