diff --git a/app/src/promptforge/providers/openai.py b/app/src/promptforge/providers/openai.py
new file mode 100644
index 0000000..9d404d5
--- /dev/null
+++ b/app/src/promptforge/providers/openai.py
@@ -0,0 +1,160 @@
+import os
+import time
+from typing import AsyncIterator, Optional
+
+from openai import AsyncOpenAI, APIError, RateLimitError, APIConnectionError
+
+from .base import ProviderBase, ProviderResponse
+from ..core.exceptions import ProviderError
+
+
+class OpenAIProvider(ProviderBase):
+    """OpenAI GPT models provider."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model: str = "gpt-4",
+        temperature: float = 0.7,
+        base_url: Optional[str] = None,
+        **kwargs,
+    ):
+        """Initialize the OpenAI provider."""
+        super().__init__(api_key, model, temperature, **kwargs)
+        self.base_url = base_url
+        self._client: Optional[AsyncOpenAI] = None
+
+    @property
+    def name(self) -> str:
+        return "openai"
+
+    def _get_client(self) -> AsyncOpenAI:
+        """Get or lazily create the OpenAI client."""
+        if self._client is None:
+            api_key = self.api_key or self._get_api_key_from_env()
+            if not api_key:
+                raise ProviderError(
+                    "OpenAI API key not configured. "
+                    "Set the OPENAI_API_KEY env var or pass the api_key parameter."
+                )
+            self._client = AsyncOpenAI(
+                api_key=api_key,
+                base_url=self.base_url,
+            )
+        return self._client
+
+    def _get_api_key_from_env(self) -> Optional[str]:
+        return os.environ.get("OPENAI_API_KEY")
+
+    def _build_messages(
+        self, prompt: str, system_prompt: Optional[str]
+    ) -> list[dict[str, str]]:
+        """Build the chat message list shared by complete() and stream_complete()."""
+        messages: list[dict[str, str]] = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+        return messages
+
+    async def complete(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> ProviderResponse:
+        """Send a completion request to OpenAI."""
+        start_time = time.time()
+
+        try:
+            client = self._get_client()
+            messages = self._build_messages(prompt, system_prompt)
+
+            response = await client.chat.completions.create(
+                model=self.model,
+                messages=messages,  # type: ignore[arg-type]
+                temperature=self.temperature,
+                max_tokens=max_tokens,
+                **kwargs,
+            )
+
+            latency_ms = (time.time() - start_time) * 1000
+
+            return ProviderResponse(
+                content=response.choices[0].message.content or "",
+                model=self.model,
+                provider=self.name,
+                usage={
+                    "prompt_tokens": response.usage.prompt_tokens,  # type: ignore[union-attr]
+                    "completion_tokens": response.usage.completion_tokens,  # type: ignore[union-attr]
+                    "total_tokens": response.usage.total_tokens,  # type: ignore[union-attr]
+                },
+                latency_ms=latency_ms,
+                metadata={
+                    "finish_reason": response.choices[0].finish_reason,
+                },
+            )
+        # RateLimitError and APIConnectionError subclass APIError, so the
+        # specific handlers must come first or they are unreachable.
+        except RateLimitError as e:
+            raise ProviderError(f"OpenAI rate limit exceeded: {e}") from e
+        except APIConnectionError as e:
+            raise ProviderError(f"OpenAI connection error: {e}") from e
+        except APIError as e:
+            raise ProviderError(f"OpenAI API error: {e}") from e
+
+    async def stream_complete(  # type: ignore[override]
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        max_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> AsyncIterator[str]:
+        """Stream a completion from OpenAI."""
+        try:
+            client = self._get_client()
+            messages = self._build_messages(prompt, system_prompt)
+
+            stream = await client.chat.completions.create(
+                model=self.model,
+                messages=messages,  # type: ignore[arg-type]
+                temperature=self.temperature,
+                max_tokens=max_tokens,
+                stream=True,
+                **kwargs,
+            )
+
+            async for chunk in stream:  # type: ignore[union-attr]
+                if chunk.choices[0].delta.content:
+                    yield chunk.choices[0].delta.content
+        except RateLimitError as e:
+            raise ProviderError(f"OpenAI rate limit exceeded: {e}") from e
+        except APIConnectionError as e:
+            raise ProviderError(f"OpenAI connection error: {e}") from e
+        except APIError as e:
+            raise ProviderError(f"OpenAI API error: {e}") from e
+
+    def validate_api_key(self) -> bool:
+        """Check that an API key is configured.
+
+        Only verifies that a key is present and a client can be constructed;
+        it does not make a network call to verify the key with OpenAI.
+        """
+        api_key = self.api_key or os.environ.get("OPENAI_API_KEY")
+        if not api_key:
+            return False
+        try:
+            AsyncOpenAI(api_key=api_key)
+        except Exception:
+            return False
+        return True
+
+    def list_models(self) -> list[str]:
+        """Return the supported OpenAI chat models (static list)."""
+        return [
+            "gpt-4",
+            "gpt-4-turbo",
+            "gpt-4o",
+            "gpt-3.5-turbo",
+            "gpt-3.5-turbo-16k",
+        ]
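
A minimal usage sketch (not part of the diff), assuming the ProviderBase constructor shown above and that ProviderResponse exposes the content and usage fields populated in complete(); the model name and prompts are illustrative only:

    import asyncio

    from promptforge.providers.openai import OpenAIProvider


    async def main() -> None:
        # Falls back to the OPENAI_API_KEY env var when api_key is omitted.
        provider = OpenAIProvider(model="gpt-4o", temperature=0.2)

        # One-shot completion.
        response = await provider.complete(
            "Summarize the change in one sentence.",
            system_prompt="You are a concise assistant.",
            max_tokens=128,
        )
        print(response.content)
        print(response.usage["total_tokens"])

        # Streaming: stream_complete() is an async generator of text chunks.
        async for chunk in provider.stream_complete("Count to five."):
            print(chunk, end="", flush=True)


    asyncio.run(main())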