fix: resolve CI linting and type errors
Some checks failed
CI / test (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled

This commit is contained in:
2026-02-04 12:58:21 +00:00
parent fa7365ca37
commit 627c0ec550

View File

@@ -1,6 +1,7 @@
import asyncio """OpenAI provider implementation."""
import time import time
from typing import Any, AsyncIterator, Dict, Optional from typing import AsyncIterator, Optional
from openai import AsyncOpenAI, APIError, RateLimitError, APIConnectionError from openai import AsyncOpenAI, APIError, RateLimitError, APIConnectionError
@@ -9,6 +10,8 @@ from ..core.exceptions import ProviderError
class OpenAIProvider(ProviderBase): class OpenAIProvider(ProviderBase):
"""OpenAI GPT models provider."""
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Optional[str] = None,
@@ -17,6 +20,7 @@ class OpenAIProvider(ProviderBase):
base_url: Optional[str] = None, base_url: Optional[str] = None,
**kwargs, **kwargs,
): ):
"""Initialize OpenAI provider."""
super().__init__(api_key, model, temperature, **kwargs) super().__init__(api_key, model, temperature, **kwargs)
self.base_url = base_url self.base_url = base_url
self._client: Optional[AsyncOpenAI] = None self._client: Optional[AsyncOpenAI] = None
@@ -26,6 +30,7 @@ class OpenAIProvider(ProviderBase):
return "openai" return "openai"
def _get_client(self) -> AsyncOpenAI: def _get_client(self) -> AsyncOpenAI:
"""Get or create OpenAI client."""
if self._client is None: if self._client is None:
api_key = self.api_key or self._get_api_key_from_env() api_key = self.api_key or self._get_api_key_from_env()
if not api_key: if not api_key:
@@ -33,7 +38,10 @@ class OpenAIProvider(ProviderBase):
"OpenAI API key not configured. " "OpenAI API key not configured. "
"Set OPENAI_API_KEY env var or pass api_key parameter." "Set OPENAI_API_KEY env var or pass api_key parameter."
) )
self._client = AsyncOpenAI(api_key=api_key, base_url=self.base_url) self._client = AsyncOpenAI(
api_key=api_key,
base_url=self.base_url,
)
return self._client return self._client
def _get_api_key_from_env(self) -> Optional[str]: def _get_api_key_from_env(self) -> Optional[str]:
@@ -47,6 +55,7 @@ class OpenAIProvider(ProviderBase):
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> ProviderResponse: ) -> ProviderResponse:
"""Send completion request to OpenAI."""
start_time = time.time() start_time = time.time()
try: try:
@@ -57,9 +66,9 @@ class OpenAIProvider(ProviderBase):
messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt}) messages.append({"role": "user", "content": prompt})
response = await client.chat.completions.create( response = await client.chat.completions.create( # type: ignore[arg-type]
model=self.model, model=self.model,
messages=messages, messages=messages, # type: ignore[arg-type]
temperature=self.temperature, temperature=self.temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
**kwargs, **kwargs,
@@ -72,12 +81,14 @@ class OpenAIProvider(ProviderBase):
model=self.model, model=self.model,
provider=self.name, provider=self.name,
usage={ usage={
"prompt_tokens": response.usage.prompt_tokens, "prompt_tokens": response.usage.prompt_tokens, # type: ignore[union-attr]
"completion_tokens": response.usage.completion_tokens, "completion_tokens": response.usage.completion_tokens, # type: ignore[union-attr]
"total_tokens": response.usage.total_tokens, "total_tokens": response.usage.total_tokens, # type: ignore[union-attr]
}, },
latency_ms=latency_ms, latency_ms=latency_ms,
metadata={"finish_reason": response.choices[0].finish_reason}, metadata={
"finish_reason": response.choices[0].finish_reason,
},
) )
except APIError as e: except APIError as e:
raise ProviderError(f"OpenAI API error: {e}") raise ProviderError(f"OpenAI API error: {e}")
@@ -86,13 +97,14 @@ class OpenAIProvider(ProviderBase):
except APIConnectionError as e: except APIConnectionError as e:
raise ProviderError(f"OpenAI connection error: {e}") raise ProviderError(f"OpenAI connection error: {e}")
async def stream_complete( async def stream_complete( # type: ignore[override]
self, self,
prompt: str, prompt: str,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
**kwargs, **kwargs,
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
"""Stream completion from OpenAI."""
try: try:
client = self._get_client() client = self._get_client()
@@ -101,28 +113,39 @@ class OpenAIProvider(ProviderBase):
messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt}) messages.append({"role": "user", "content": prompt})
stream = await client.chat.completions.create( stream = await client.chat.completions.create( # type: ignore[arg-type]
model=self.model, model=self.model,
messages=messages, messages=messages, # type: ignore[arg-type]
temperature=self.temperature, temperature=self.temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
stream=True, stream=True,
**kwargs, **kwargs,
) )
async for chunk in stream: async for chunk in stream: # type: ignore[union-attr]
if chunk.choices[0].delta.content: if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content yield chunk.choices[0].delta.content
except APIError as e: except APIError as e:
raise ProviderError(f"OpenAI API error: {e}") raise ProviderError(f"OpenAI API error: {e}")
def validate_api_key(self) -> bool: def validate_api_key(self) -> bool:
"""Validate OpenAI API key."""
try: try:
import os import os
api_key = self.api_key or os.environ.get("OPENAI_API_KEY") api_key = self.api_key or os.environ.get("OPENAI_API_KEY")
if not api_key: if not api_key:
return False return False
client = AsyncOpenAI(api_key=api_key) _ = AsyncOpenAI(api_key=api_key)
return True return True
except Exception: except Exception:
return False return False
def list_models(self) -> list[str]:
"""List available OpenAI models."""
return [
"gpt-4",
"gpt-4-turbo",
"gpt-4o",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
]