116 lines
4.2 KiB
Python
116 lines
4.2 KiB
Python
import logging
|
|
from collections import Counter
|
|
from datetime import datetime
|
|
from typing import Optional, List, Dict, Any, Tuple
|
|
|
|
from .config import Config
|
|
from .models import Command, Suggestion, Pattern, Project
|
|
from .database import Database
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SuggestionEngine:
    """Suggest previously used commands that extend a typed prefix.

    Candidates come from the command history exposed by ``db`` and are
    scored with weights read from ``config``. A first-word transition table
    is also maintained by :meth:`train`; none of the suggestion methods in
    this class read it.
    """

    # Frequency at (or above) which the frequency score saturates at 1.0.
    _FREQUENCY_SATURATION = 100

    def __init__(
        self,
        config: "Optional[Config]" = None,
        db: "Optional[Database]" = None,
    ):
        """Create an engine.

        Args:
            config: Settings source; a default ``Config()`` is built when
                omitted (or falsy).
            db: History store; a default ``Database()`` is built when
                omitted (or falsy).
        """
        self.config = config or Config()
        self.db = db or Database()
        # Keyed by (previous command's first word or None, this command's
        # first word); values are the second words seen for that pair.
        self._markov_chain: Dict[Tuple[Optional[str], str], List[str]] = {}

    def train(self, project_id: Optional[int] = None) -> None:
        """Rebuild the transition table from up to 10000 stored commands.

        Args:
            project_id: Forwarded to ``db.get_commands`` as its
                ``project_id`` filter.
        """
        commands = self.db.get_commands(project_id=project_id, limit=10000)
        self._build_markov_chain(commands)

    def _build_markov_chain(self, commands: "List[Command]") -> None:
        """Replace the transition table with one built from ``commands``.

        Commands are processed in ascending ``timestamp`` order. Commands
        that contain no words are skipped and do not become the "previous"
        command for the next entry.

        Args:
            commands: Objects exposing ``command`` (str) and ``timestamp``.
        """
        chain: Dict[Tuple[Optional[str], str], List[str]] = {}
        previous: Optional[str] = None
        for cmd in sorted(commands, key=lambda c: c.timestamp):
            words = cmd.command.split()
            if not words:
                continue
            # Register the pair even for one-word commands so the
            # transition itself is recorded with an empty follower list.
            followers = chain.setdefault((previous, words[0]), [])
            if len(words) > 1:
                followers.append(words[1])
            previous = words[0]
        self._markov_chain = chain

    def get_suggestions(
        self,
        prefix: str,
        project_id: Optional[int] = None,
        limit: int = 10,
    ) -> "List[Suggestion]":
        """Return stored commands that start with ``prefix``.

        Candidates are filtered by prefix *before* ranking, so a matching
        command is found even when it is not among the most frequent
        commands overall. Results are ordered by descending frequency and
        only include candidates whose confidence reaches
        ``suggestions.min_confidence`` (default 0.3).

        Args:
            prefix: Text typed so far. Leading whitespace is ignored for
                matching; trailing whitespace is significant (``"git "``
                does not match ``"gitk"``).
            project_id: Forwarded to ``db.get_commands``.
            limit: Maximum number of suggestions; additionally capped by
                ``suggestions.max_suggestions`` when configured.

        Returns:
            A list of ``Suggestion`` objects; empty when ``prefix`` is blank
            or the effective cap is not positive.
        """
        if not prefix.strip():
            return []

        match_prefix = prefix.lstrip()
        cap = min(limit, self.config.get("suggestions.max_suggestions", limit))
        if cap <= 0:
            return []
        min_confidence = self.config.get("suggestions.min_confidence", 0.3)

        recent_commands = self.db.get_commands(project_id=project_id, limit=1000)
        # limit=None: rank every matching command; the loop below stops as
        # soon as enough of them pass the confidence threshold.
        candidates = self._get_frequent_commands(
            None, project_id, prefix=match_prefix
        )

        suggestions = []
        for cmd_str, freq in candidates:
            confidence = self._calculate_confidence(cmd_str, freq, recent_commands)
            if confidence < min_confidence:
                continue
            suggestions.append(
                Suggestion(
                    command=cmd_str,
                    context=prefix,
                    confidence=confidence,
                    frequency=freq,
                )
            )
            if len(suggestions) >= cap:
                break
        return suggestions

    def _get_frequent_commands(
        self,
        limit: Optional[int],
        project_id: Optional[int] = None,
        prefix: str = "",
    ) -> List[Tuple[str, int]]:
        """Count command strings among up to 5000 stored commands.

        Args:
            limit: Number of entries to return; ``None`` returns all.
            project_id: Forwarded to ``db.get_commands``.
            prefix: When non-empty, only commands starting with it are
                counted.

        Returns:
            ``(command, count)`` pairs, most frequent first.
        """
        commands = self.db.get_commands(project_id=project_id, limit=5000)
        counter = Counter(
            cmd.command for cmd in commands if cmd.command.startswith(prefix)
        )
        return counter.most_common(limit)

    def _calculate_confidence(
        self, command: str, frequency: int, recent_commands: "List[Command]"
    ) -> float:
        """Blend recency, frequency and context into one score.

        Weights are read from ``suggestions.recency_weight`` (default 0.3),
        ``suggestions.frequency_weight`` (default 0.4) and
        ``suggestions.context_weight`` (default 0.3).

        Args:
            command: Exact command string being scored.
            frequency: How often the command occurred.
            recent_commands: Recent history; the recency score is the share
                of its entries equal to ``command``.

        Returns:
            The weighted sum of the three component scores.
        """
        recency_weight = self.config.get("suggestions.recency_weight", 0.3)
        frequency_weight = self.config.get("suggestions.frequency_weight", 0.4)
        context_weight = self.config.get("suggestions.context_weight", 0.3)

        recency_score = 0.0
        if recent_commands:
            recent_count = sum(1 for c in recent_commands if c.command == command)
            recency_score = min(recent_count / len(recent_commands), 1.0)

        frequency_score = min(frequency / self._FREQUENCY_SATURATION, 1.0)

        # No context signal is computed here; the term is a constant.
        context_score = 1.0

        return (
            recency_weight * recency_score
            + frequency_weight * frequency_score
            + context_weight * context_score
        )

    def get_autocomplete_candidates(
        self, prefix: str, project_id: Optional[int] = None
    ) -> List[str]:
        """Return the command strings of up to five suggestions for ``prefix``."""
        suggestions = self.get_suggestions(prefix, project_id, limit=5)
        return [s.command for s in suggestions]

    def get_pattern_suggestions(
        self, project_id: Optional[int] = None
    ) -> "List[Pattern]":
        """Return whatever ``db.get_patterns`` yields for ``project_id``."""
        return self.db.get_patterns(project_id=project_id)

    def update_suggestion_usage(self, suggestion: "Suggestion") -> None:
        """Record one more use of ``suggestion`` and hand it to the database.

        Increments ``frequency``, stamps ``last_used`` and passes the object
        to ``db.create_suggestion``.
        """
        suggestion.frequency += 1
        # NOTE(review): utcnow() yields a naive datetime and is deprecated
        # since Python 3.12; kept because stored timestamps are presumably
        # naive as well - confirm before switching to an aware datetime.
        suggestion.last_used = datetime.utcnow()
        # NOTE(review): presumably create_suggestion upserts - verify that
        # this does not insert a duplicate row on every use.
        self.db.create_suggestion(suggestion)