feat: add suggestion engine for smart autocomplete
This commit is contained in:
115
cli_memory/suggestions.py
Normal file
115
cli_memory/suggestions.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import logging
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
|
||||
from .config import Config
|
||||
from .models import Command, Suggestion, Pattern, Project
|
||||
from .database import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuggestionEngine:
|
||||
def __init__(self, config: Optional[Config] = None, db: Optional[Database] = None):
|
||||
self.config = config or Config()
|
||||
self.db = db or Database()
|
||||
self._markov_chain: Dict[str, List[str]] = {}
|
||||
|
||||
def train(self, project_id: Optional[int] = None) -> None:
|
||||
commands = self.db.get_commands(project_id=project_id, limit=10000)
|
||||
self._build_markov_chain(commands)
|
||||
|
||||
def _build_markov_chain(self, commands: List[Command]) -> None:
|
||||
self._markov_chain = {}
|
||||
prev_cmd = None
|
||||
for cmd in sorted(commands, key=lambda c: c.timestamp):
|
||||
cmd_words = cmd.command.split()
|
||||
if len(cmd_words) >= 1:
|
||||
key = (prev_cmd, cmd_words[0]) if prev_cmd else (None, cmd_words[0])
|
||||
if key not in self._markov_chain:
|
||||
self._markov_chain[key] = []
|
||||
if len(cmd_words) > 1:
|
||||
self._markov_chain[key].append(cmd_words[1])
|
||||
prev_cmd = cmd_words[0]
|
||||
|
||||
def get_suggestions(
|
||||
self,
|
||||
prefix: str,
|
||||
project_id: Optional[int] = None,
|
||||
limit: int = 10,
|
||||
) -> List[Suggestion]:
|
||||
prefix_words = prefix.strip().split()
|
||||
if not prefix_words:
|
||||
return []
|
||||
|
||||
suggestions = []
|
||||
max_suggestions = self.config.get("suggestions.max_suggestions", limit)
|
||||
|
||||
recent_commands = self.db.get_commands(project_id=project_id, limit=1000)
|
||||
frequent = self._get_frequent_commands(limit, project_id)
|
||||
|
||||
for cmd_str, freq in frequent:
|
||||
if cmd_str.startswith(prefix):
|
||||
confidence = self._calculate_confidence(cmd_str, freq, recent_commands)
|
||||
if confidence >= self.config.get("suggestions.min_confidence", 0.3):
|
||||
suggestions.append(
|
||||
Suggestion(
|
||||
command=cmd_str,
|
||||
context=prefix,
|
||||
confidence=confidence,
|
||||
frequency=freq,
|
||||
)
|
||||
)
|
||||
if len(suggestions) >= max_suggestions:
|
||||
break
|
||||
|
||||
return suggestions
|
||||
|
||||
def _get_frequent_commands(
|
||||
self, limit: int, project_id: Optional[int] = None
|
||||
) -> List[Tuple[str, int]]:
|
||||
commands = self.db.get_commands(project_id=project_id, limit=5000)
|
||||
counter = Counter()
|
||||
for cmd in commands:
|
||||
counter[cmd.command] += 1
|
||||
return counter.most_common(limit)
|
||||
|
||||
def _calculate_confidence(
|
||||
self, command: str, frequency: int, recent_commands: List[Command]
|
||||
) -> float:
|
||||
recency_weight = self.config.get("suggestions.recency_weight", 0.3)
|
||||
frequency_weight = self.config.get("suggestions.frequency_weight", 0.4)
|
||||
context_weight = self.config.get("suggestions.context_weight", 0.3)
|
||||
|
||||
recency_score = 0.0
|
||||
if recent_commands:
|
||||
recent_count = sum(1 for c in recent_commands if c.command == command)
|
||||
recency_score = min(recent_count / max(len(recent_commands), 1), 1.0)
|
||||
|
||||
max_freq = 100
|
||||
frequency_score = min(frequency / max_freq, 1.0)
|
||||
|
||||
context_score = 1.0
|
||||
|
||||
return (
|
||||
recency_weight * recency_score
|
||||
+ frequency_weight * frequency_score
|
||||
+ context_weight * context_score
|
||||
)
|
||||
|
||||
def get_autocomplete_candidates(
|
||||
self, prefix: str, project_id: Optional[int] = None
|
||||
) -> List[str]:
|
||||
suggestions = self.get_suggestions(prefix, project_id, limit=5)
|
||||
return [s.command for s in suggestions]
|
||||
|
||||
def get_pattern_suggestions(
|
||||
self, project_id: Optional[int] = None
|
||||
) -> List[Pattern]:
|
||||
return self.db.get_patterns(project_id=project_id)
|
||||
|
||||
def update_suggestion_usage(self, suggestion: Suggestion) -> None:
|
||||
suggestion.frequency += 1
|
||||
suggestion.last_used = datetime.utcnow()
|
||||
self.db.create_suggestion(suggestion)
|
||||
Reference in New Issue
Block a user