From fa7365ca370d4f542c28c297e1c721b6e8736c96 Mon Sep 17 00:00:00 2001 From: 7000pctAUTO Date: Wed, 4 Feb 2026 12:58:20 +0000 Subject: [PATCH] fix: resolve CI linting and type errors --- src/promptforge/providers/base.py | 80 ++++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 18 deletions(-) diff --git a/src/promptforge/providers/base.py b/src/promptforge/providers/base.py index b36e69f..3e5168c 100644 --- a/src/promptforge/providers/base.py +++ b/src/promptforge/providers/base.py @@ -1,26 +1,25 @@ +"""Base provider interface.""" + from abc import ABC, abstractmethod -from typing import Any, AsyncIterator, Dict, Optional +from dataclasses import dataclass, field +from typing import Any, AsyncIterator, Dict, List, Optional +@dataclass class ProviderResponse: - def __init__( - self, - content: str, - model: str, - provider: str, - usage: Optional[Dict[str, Any]] = None, - latency_ms: float = 0.0, - metadata: Optional[Dict[str, Any]] = None, - ): - self.content = content - self.model = model - self.provider = provider - self.usage = usage or {} - self.latency_ms = latency_ms - self.metadata = metadata or {} + """Response from an LLM provider.""" + + content: str + model: str + provider: str + usage: Dict[str, int] = field(default_factory=dict) + latency_ms: float = 0.0 + metadata: Dict[str, Any] = field(default_factory=dict) class ProviderBase(ABC): + """Abstract base class for LLM providers.""" + def __init__( self, api_key: Optional[str] = None, @@ -28,14 +27,23 @@ class ProviderBase(ABC): temperature: float = 0.7, **kwargs, ): + """Initialize provider. + + Args: + api_key: API key for authentication. + model: Model identifier to use. + temperature: Sampling temperature (0.0-1.0). + **kwargs: Additional provider-specific options. + """ self.api_key = api_key self.model = model self.temperature = temperature - self.extra_kwargs = kwargs + self.kwargs = kwargs @property @abstractmethod def name(self) -> str: + """Provider name identifier.""" pass @abstractmethod @@ -46,6 +54,17 @@ class ProviderBase(ABC): max_tokens: Optional[int] = None, **kwargs, ) -> ProviderResponse: + """Send a completion request. + + Args: + prompt: User prompt to send. + system_prompt: Optional system instructions. + max_tokens: Maximum tokens in response. + **kwargs: Additional provider-specific parameters. + + Returns: + ProviderResponse with the generated content. + """ pass @abstractmethod @@ -56,7 +75,32 @@ class ProviderBase(ABC): max_tokens: Optional[int] = None, **kwargs, ) -> AsyncIterator[str]: + """Stream completions incrementally. + + Args: + prompt: User prompt to send. + system_prompt: Optional system instructions. + max_tokens: Maximum tokens in response. + **kwargs: Additional provider-specific parameters. + + Yields: + Chunks of generated content. + """ pass + @abstractmethod def validate_api_key(self) -> bool: - return True \ No newline at end of file + """Validate that the API key is configured correctly.""" + pass + + @abstractmethod + def list_models(self) -> List[str]: + """List available models for this provider.""" + pass + + def _get_system_prompt(self, prompt: str) -> Optional[str]: + """Extract system prompt from prompt if using special syntax.""" + if "---" in prompt: + parts = prompt.split("---", 1) + return parts[0].strip() + return None