Initial upload: shell-speak CLI tool with natural language to shell command conversion
This commit is contained in:
123
shell_speak/matcher.py
Normal file
123
shell_speak/matcher.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Pattern matching engine for shell commands."""
|
||||
|
||||
import re
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from shell_speak.library import get_loader
|
||||
from shell_speak.models import CommandPattern, CommandMatch
|
||||
from shell_speak.nlp import normalize_text, extract_keywords, calculate_similarity, tokenize
|
||||
|
||||
|
||||
class PatternMatcher:
|
||||
"""Matches natural language queries to command patterns."""
|
||||
|
||||
def __init__(self):
|
||||
self._loader = get_loader()
|
||||
|
||||
def match(self, query: str, tool: Optional[str] = None) -> Optional[CommandMatch]:
|
||||
"""Match a query to the best command pattern."""
|
||||
normalized_query = normalize_text(query)
|
||||
self._loader.load_library(tool)
|
||||
|
||||
corrections = self._loader.get_corrections()
|
||||
correction_key = f"{tool}:{normalized_query}" if tool else normalized_query
|
||||
|
||||
if correction_key in corrections:
|
||||
return CommandMatch(
|
||||
pattern=CommandPattern(
|
||||
name="user_correction",
|
||||
tool=tool or "custom",
|
||||
description="User-provided correction",
|
||||
patterns=[],
|
||||
template=corrections[correction_key],
|
||||
explanation="Custom command from user correction",
|
||||
),
|
||||
confidence=1.0,
|
||||
matched_query=query,
|
||||
command=corrections[correction_key],
|
||||
explanation="This command was learned from your previous correction.",
|
||||
)
|
||||
|
||||
patterns = self._loader.get_patterns()
|
||||
if tool:
|
||||
patterns = [p for p in patterns if p.tool == tool]
|
||||
|
||||
best_match = None
|
||||
best_score = 0.0
|
||||
|
||||
for pattern in patterns:
|
||||
score = self._calculate_match_score(normalized_query, pattern)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
command = self._substitute_template(normalized_query, pattern)
|
||||
if command:
|
||||
best_match = CommandMatch(
|
||||
pattern=pattern,
|
||||
confidence=score,
|
||||
matched_query=query,
|
||||
command=command,
|
||||
explanation=pattern.explanation or self._generate_explanation(pattern, command),
|
||||
)
|
||||
|
||||
return best_match
|
||||
|
||||
def _calculate_match_score(self, query: str, pattern: CommandPattern) -> float:
|
||||
"""Calculate how well a query matches a pattern."""
|
||||
query_keywords = extract_keywords(query)
|
||||
pattern_keywords = set()
|
||||
|
||||
for ptn in pattern.patterns:
|
||||
pattern_keywords.update(extract_keywords(ptn))
|
||||
|
||||
if not pattern_keywords:
|
||||
return 0.0
|
||||
|
||||
keyword_overlap = len(query_keywords & pattern_keywords)
|
||||
keyword_score = keyword_overlap / len(pattern_keywords) if pattern_keywords else 0.0
|
||||
|
||||
best_similarity = 0.0
|
||||
for ptn in pattern.patterns:
|
||||
sim = calculate_similarity(query, ptn)
|
||||
if sim > best_similarity:
|
||||
best_similarity = sim
|
||||
|
||||
combined_score = (keyword_score * 0.6) + (best_similarity * 0.4)
|
||||
return min(combined_score, 1.0)
|
||||
|
||||
def _substitute_template(self, query: str, pattern: CommandPattern) -> Optional[str]:
|
||||
"""Substitute variables in the template based on query."""
|
||||
template = pattern.template
|
||||
|
||||
query_tokens = set(tokenize(query))
|
||||
pattern_tokens = set()
|
||||
for ptn in pattern.patterns:
|
||||
pattern_tokens.update(tokenize(ptn))
|
||||
|
||||
diff_tokens = query_tokens - pattern_tokens
|
||||
|
||||
variables = re.findall(r'\{(\w+)\}', template)
|
||||
var_values = {}
|
||||
|
||||
for var in variables:
|
||||
lower_var = var.lower()
|
||||
matching_tokens = [t for t in diff_tokens if lower_var in t.lower() or t.lower() in lower_var]
|
||||
if matching_tokens:
|
||||
var_values[var] = matching_tokens[0]
|
||||
|
||||
result = template
|
||||
for var, value in var_values.items():
|
||||
result = result.replace(f'{{{var}}}', value)
|
||||
|
||||
if re.search(r'\{\w+\}', result):
|
||||
return None
|
||||
|
||||
return result
|
||||
|
||||
def _generate_explanation(self, pattern: CommandPattern, command: str) -> str:
|
||||
"""Generate an explanation for the command."""
|
||||
return f"{pattern.description}"
|
||||
|
||||
|
||||
def get_matcher() -> PatternMatcher:
|
||||
"""Get the global pattern matcher."""
|
||||
return PatternMatcher()
|
||||
Reference in New Issue
Block a user