diff --git a/git_commit_ai/core/cache.py b/git_commit_ai/core/cache.py
new file mode 100644
index 0000000..83065a7
--- /dev/null
+++ b/git_commit_ai/core/cache.py
@@ -0,0 +1,142 @@
+"""Caching mechanism for Git Commit AI."""
+
+import hashlib
+import json
+import logging
+from pathlib import Path
+from typing import Any, Optional
+
+from git_commit_ai.core.config import Config, get_config
+
+logger = logging.getLogger(__name__)
+
+
+class CacheManager:
+    """Manager for caching generated commit messages."""
+
+    def __init__(self, config: Optional[Config] = None):
+        self.config = config or get_config()
+        self._cache_dir = Path(self.config.cache_directory)
+        self._enabled = self.config.cache_enabled
+        self._ttl_hours = self.config.cache_ttl_hours
+        self._ensure_cache_dir()
+
+    def _ensure_cache_dir(self) -> None:
+        self._cache_dir.mkdir(parents=True, exist_ok=True)
+
+    def _generate_cache_key(self, diff: str, **kwargs: Any) -> str:
+        key_parts = {"diff": diff, "conventional": kwargs.get("conventional", False), "model": kwargs.get("model", "")}
+        key_str = json.dumps(key_parts, sort_keys=True)
+        return hashlib.md5(key_str.encode()).hexdigest()
+
+    def _get_cache_path(self, key: str) -> Path:
+        return self._cache_dir / f"{key}.json"
+
+    def get(self, diff: str, **kwargs: Any) -> Optional[list[str]]:
+        if not self._enabled:
+            return None
+
+        key = self._generate_cache_key(diff, **kwargs)
+        cache_path = self._get_cache_path(key)
+
+        if not cache_path.exists():
+            return None
+
+        try:
+            with open(cache_path, 'r') as f:
+                cache_data = json.load(f)
+
+            if self._is_expired(cache_data):
+                cache_path.unlink()
+                return None
+
+            return cache_data.get("messages", [])
+        except (json.JSONDecodeError, IOError) as e:
+            logger.warning(f"Failed to read cache: {e}")
+            return None
+
+    def set(self, diff: str, messages: list[str], **kwargs: Any) -> bool:
+        if not self._enabled:
+            return False
+
+        key = self._generate_cache_key(diff, **kwargs)
+        cache_path = self._get_cache_path(key)
+
+        cache_data = {"key": key, "messages": messages, "created_at": self._get_timestamp(), "expires_at": self._get_expiration()}
+
+        try:
+            with open(cache_path, 'w') as f:
+                json.dump(cache_data, f, indent=2)
+            return True
+        except IOError as e:
+            logger.warning(f"Failed to write cache: {e}")
+            return False
+
+    def _is_expired(self, cache_data: dict) -> bool:
+        if self._ttl_hours <= 0:
+            return False
+
+        expires_at = cache_data.get("expires_at")
+        if not expires_at:
+            return True
+
+        return self._get_timestamp() > expires_at
+
+    def _get_timestamp(self) -> int:
+        return int(__import__('time').time())
+
+    def _get_expiration(self) -> int:
+        return self._get_timestamp() + (self._ttl_hours * 3600)
+
+    def cleanup(self) -> int:
+        if not self._enabled:
+            return 0
+
+        cleaned = 0
+        for cache_file in self._cache_dir.glob("*.json"):
+            try:
+                with open(cache_file, 'r') as f:
+                    cache_data = json.load(f)
+                if self._is_expired(cache_data):
+                    cache_file.unlink()
+                    cleaned += 1
+            except (json.JSONDecodeError, IOError):
+                cache_file.unlink()
+                cleaned += 1
+
+        return cleaned
+
+    def clear(self) -> int:
+        if not self._enabled:
+            return 0
+
+        cleared = 0
+        for cache_file in self._cache_dir.glob("*.json"):
+            cache_file.unlink()
+            cleared += 1
+
+        return cleared
+
+    def get_stats(self) -> dict:
+        stats = {"enabled": self._enabled, "directory": str(self._cache_dir), "entries": 0, "size_bytes": 0, "expired": 0}
+
+        if not self._cache_dir.exists():
+            return stats
+
+        for cache_file in self._cache_dir.glob("*.json"):
+            stats["entries"] += 1
+            stats["size_bytes"] += cache_file.stat().st_size
+
+            try:
+                with open(cache_file, 'r') as f:
+                    cache_data = json.load(f)
+                if self._is_expired(cache_data):
+                    stats["expired"] += 1
+            except (json.JSONDecodeError, IOError):
+                stats["expired"] += 1
+
+        return stats
+
+
+def get_cache_manager(config: Optional[Config] = None) -> CacheManager:
+    return CacheManager(config)