From 7043f6071533042b101443f974e76df7bff11717 Mon Sep 17 00:00:00 2001 From: 7000pctAUTO Date: Wed, 4 Feb 2026 12:32:10 +0000 Subject: [PATCH] Add provider implementations (OpenAI, Anthropic, Ollama) --- src/promptforge/providers/openai.py | 128 ++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 src/promptforge/providers/openai.py diff --git a/src/promptforge/providers/openai.py b/src/promptforge/providers/openai.py new file mode 100644 index 0000000..837a110 --- /dev/null +++ b/src/promptforge/providers/openai.py @@ -0,0 +1,128 @@ +import asyncio +import time +from typing import Any, AsyncIterator, Dict, Optional + +from openai import AsyncOpenAI, APIError, RateLimitError, APIConnectionError + +from .base import ProviderBase, ProviderResponse +from ..core.exceptions import ProviderError + + +class OpenAIProvider(ProviderBase): + def __init__( + self, + api_key: Optional[str] = None, + model: str = "gpt-4", + temperature: float = 0.7, + base_url: Optional[str] = None, + **kwargs, + ): + super().__init__(api_key, model, temperature, **kwargs) + self.base_url = base_url + self._client: Optional[AsyncOpenAI] = None + + @property + def name(self) -> str: + return "openai" + + def _get_client(self) -> AsyncOpenAI: + if self._client is None: + api_key = self.api_key or self._get_api_key_from_env() + if not api_key: + raise ProviderError( + "OpenAI API key not configured. " + "Set OPENAI_API_KEY env var or pass api_key parameter." + ) + self._client = AsyncOpenAI(api_key=api_key, base_url=self.base_url) + return self._client + + def _get_api_key_from_env(self) -> Optional[str]: + import os + return os.environ.get("OPENAI_API_KEY") + + async def complete( + self, + prompt: str, + system_prompt: Optional[str] = None, + max_tokens: Optional[int] = None, + **kwargs, + ) -> ProviderResponse: + start_time = time.time() + + try: + client = self._get_client() + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + response = await client.chat.completions.create( + model=self.model, + messages=messages, + temperature=self.temperature, + max_tokens=max_tokens, + **kwargs, + ) + + latency_ms = (time.time() - start_time) * 1000 + + return ProviderResponse( + content=response.choices[0].message.content or "", + model=self.model, + provider=self.name, + usage={ + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + }, + latency_ms=latency_ms, + metadata={"finish_reason": response.choices[0].finish_reason}, + ) + except APIError as e: + raise ProviderError(f"OpenAI API error: {e}") + except RateLimitError as e: + raise ProviderError(f"OpenAI rate limit exceeded: {e}") + except APIConnectionError as e: + raise ProviderError(f"OpenAI connection error: {e}") + + async def stream_complete( + self, + prompt: str, + system_prompt: Optional[str] = None, + max_tokens: Optional[int] = None, + **kwargs, + ) -> AsyncIterator[str]: + try: + client = self._get_client() + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + stream = await client.chat.completions.create( + model=self.model, + messages=messages, + temperature=self.temperature, + max_tokens=max_tokens, + stream=True, + **kwargs, + ) + + async for chunk in stream: + if chunk.choices[0].delta.content: + yield chunk.choices[0].delta.content + except APIError as e: + raise ProviderError(f"OpenAI API error: {e}") + + def validate_api_key(self) -> bool: + try: + import os + api_key = self.api_key or os.environ.get("OPENAI_API_KEY") + if not api_key: + return False + client = AsyncOpenAI(api_key=api_key) + return True + except Exception: + return False \ No newline at end of file