diff --git a/src/promptforge/providers/openai.py b/src/promptforge/providers/openai.py index 837a110..1686f86 100644 --- a/src/promptforge/providers/openai.py +++ b/src/promptforge/providers/openai.py @@ -1,6 +1,7 @@ -import asyncio +"""OpenAI provider implementation.""" + import time -from typing import Any, AsyncIterator, Dict, Optional +from typing import AsyncIterator, Optional from openai import AsyncOpenAI, APIError, RateLimitError, APIConnectionError @@ -9,6 +10,8 @@ from ..core.exceptions import ProviderError class OpenAIProvider(ProviderBase): + """OpenAI GPT models provider.""" + def __init__( self, api_key: Optional[str] = None, @@ -17,6 +20,7 @@ class OpenAIProvider(ProviderBase): base_url: Optional[str] = None, **kwargs, ): + """Initialize OpenAI provider.""" super().__init__(api_key, model, temperature, **kwargs) self.base_url = base_url self._client: Optional[AsyncOpenAI] = None @@ -26,6 +30,7 @@ class OpenAIProvider(ProviderBase): return "openai" def _get_client(self) -> AsyncOpenAI: + """Get or create OpenAI client.""" if self._client is None: api_key = self.api_key or self._get_api_key_from_env() if not api_key: @@ -33,7 +38,10 @@ class OpenAIProvider(ProviderBase): "OpenAI API key not configured. " "Set OPENAI_API_KEY env var or pass api_key parameter." ) - self._client = AsyncOpenAI(api_key=api_key, base_url=self.base_url) + self._client = AsyncOpenAI( + api_key=api_key, + base_url=self.base_url, + ) return self._client def _get_api_key_from_env(self) -> Optional[str]: @@ -47,6 +55,7 @@ class OpenAIProvider(ProviderBase): max_tokens: Optional[int] = None, **kwargs, ) -> ProviderResponse: + """Send completion request to OpenAI.""" start_time = time.time() try: @@ -57,9 +66,9 @@ class OpenAIProvider(ProviderBase): messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) - response = await client.chat.completions.create( + response = await client.chat.completions.create( # type: ignore[arg-type] model=self.model, - messages=messages, + messages=messages, # type: ignore[arg-type] temperature=self.temperature, max_tokens=max_tokens, **kwargs, @@ -72,12 +81,14 @@ class OpenAIProvider(ProviderBase): model=self.model, provider=self.name, usage={ - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, + "prompt_tokens": response.usage.prompt_tokens, # type: ignore[union-attr] + "completion_tokens": response.usage.completion_tokens, # type: ignore[union-attr] + "total_tokens": response.usage.total_tokens, # type: ignore[union-attr] }, latency_ms=latency_ms, - metadata={"finish_reason": response.choices[0].finish_reason}, + metadata={ + "finish_reason": response.choices[0].finish_reason, + }, ) except APIError as e: raise ProviderError(f"OpenAI API error: {e}") @@ -86,13 +97,14 @@ class OpenAIProvider(ProviderBase): except APIConnectionError as e: raise ProviderError(f"OpenAI connection error: {e}") - async def stream_complete( + async def stream_complete( # type: ignore[override] self, prompt: str, system_prompt: Optional[str] = None, max_tokens: Optional[int] = None, **kwargs, ) -> AsyncIterator[str]: + """Stream completion from OpenAI.""" try: client = self._get_client() @@ -101,28 +113,39 @@ class OpenAIProvider(ProviderBase): messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) - stream = await client.chat.completions.create( + stream = await client.chat.completions.create( # type: ignore[arg-type] model=self.model, - messages=messages, + messages=messages, # type: ignore[arg-type] temperature=self.temperature, max_tokens=max_tokens, stream=True, **kwargs, ) - async for chunk in stream: + async for chunk in stream: # type: ignore[union-attr] if chunk.choices[0].delta.content: yield chunk.choices[0].delta.content except APIError as e: raise ProviderError(f"OpenAI API error: {e}") def validate_api_key(self) -> bool: + """Validate OpenAI API key.""" try: import os api_key = self.api_key or os.environ.get("OPENAI_API_KEY") if not api_key: return False - client = AsyncOpenAI(api_key=api_key) + _ = AsyncOpenAI(api_key=api_key) return True except Exception: - return False \ No newline at end of file + return False + + def list_models(self) -> list[str]: + """List available OpenAI models.""" + return [ + "gpt-4", + "gpt-4-turbo", + "gpt-4o", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + ]