fix: resolve CI linting and type errors
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
"""Configuration management for PromptForge."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
@@ -8,6 +10,8 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
|
|
||||||
class ProviderConfig(BaseModel):
|
class ProviderConfig(BaseModel):
|
||||||
|
"""Configuration for an LLM provider."""
|
||||||
|
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
model: str = "gpt-4"
|
model: str = "gpt-4"
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
@@ -15,21 +19,29 @@ class ProviderConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class RegistryConfig(BaseModel):
|
class RegistryConfig(BaseModel):
|
||||||
|
"""Configuration for the prompt registry."""
|
||||||
|
|
||||||
local_path: str = "~/.promptforge/registry"
|
local_path: str = "~/.promptforge/registry"
|
||||||
remote_url: str = "https://registry.promptforge.io"
|
remote_url: str = "https://registry.promptforge.io"
|
||||||
|
|
||||||
|
|
||||||
class DefaultsConfig(BaseModel):
|
class DefaultsConfig(BaseModel):
|
||||||
|
"""Default settings for PromptForge."""
|
||||||
|
|
||||||
provider: str = "openai"
|
provider: str = "openai"
|
||||||
output_format: str = "text"
|
output_format: str = "text"
|
||||||
|
|
||||||
|
|
||||||
class ValidationConfig(BaseModel):
|
class ValidationConfig(BaseModel):
|
||||||
|
"""Validation settings."""
|
||||||
|
|
||||||
strict_mode: bool = False
|
strict_mode: bool = False
|
||||||
max_retries: int = 3
|
max_retries: int = 3
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
|
"""Main configuration for PromptForge."""
|
||||||
|
|
||||||
providers: Dict[str, ProviderConfig] = Field(default_factory=dict)
|
providers: Dict[str, ProviderConfig] = Field(default_factory=dict)
|
||||||
registry: RegistryConfig = Field(default_factory=RegistryConfig)
|
registry: RegistryConfig = Field(default_factory=RegistryConfig)
|
||||||
defaults: DefaultsConfig = Field(default_factory=DefaultsConfig)
|
defaults: DefaultsConfig = Field(default_factory=DefaultsConfig)
|
||||||
@@ -37,6 +49,7 @@ class Config(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
def _expand_env_vars(value: Any) -> Any:
|
def _expand_env_vars(value: Any) -> Any:
|
||||||
|
"""Expand environment variables in a value."""
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
if value.startswith("${") and value.endswith("}"):
|
if value.startswith("${") and value.endswith("}"):
|
||||||
env_var = value[2:-1]
|
env_var = value[2:-1]
|
||||||
@@ -45,6 +58,7 @@ def _expand_env_vars(value: Any) -> Any:
|
|||||||
|
|
||||||
|
|
||||||
def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Process configuration dictionary, expanding environment variables."""
|
||||||
processed = {}
|
processed = {}
|
||||||
for key, value in config_dict.items():
|
for key, value in config_dict.items():
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
@@ -55,6 +69,14 @@ def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
def load_config(config_path: Optional[Path] = None) -> Config:
|
def load_config(config_path: Optional[Path] = None) -> Config:
|
||||||
|
"""Load configuration from file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to configuration file. If None, looks in standard locations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Config object with all settings.
|
||||||
|
"""
|
||||||
if config_path is None:
|
if config_path is None:
|
||||||
config_path = Path.cwd() / "configs" / "promptforge.yaml"
|
config_path = Path.cwd() / "configs" / "promptforge.yaml"
|
||||||
|
|
||||||
@@ -70,4 +92,5 @@ def load_config(config_path: Optional[Path] = None) -> Config:
|
|||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def get_config() -> Config:
|
def get_config() -> Config:
|
||||||
return load_config()
|
"""Get cached configuration."""
|
||||||
|
return load_config()
|
||||||
|
|||||||
Reference in New Issue
Block a user