Initial upload: shell-speak CLI tool with natural language to shell command conversion
This commit is contained in:
123
shell_speak/matcher.py
Normal file
123
shell_speak/matcher.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""Pattern matching engine for shell commands."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from shell_speak.library import get_loader
|
||||||
|
from shell_speak.models import CommandPattern, CommandMatch
|
||||||
|
from shell_speak.nlp import normalize_text, extract_keywords, calculate_similarity, tokenize
|
||||||
|
|
||||||
|
|
||||||
|
class PatternMatcher:
|
||||||
|
"""Matches natural language queries to command patterns."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._loader = get_loader()
|
||||||
|
|
||||||
|
def match(self, query: str, tool: Optional[str] = None) -> Optional[CommandMatch]:
|
||||||
|
"""Match a query to the best command pattern."""
|
||||||
|
normalized_query = normalize_text(query)
|
||||||
|
self._loader.load_library(tool)
|
||||||
|
|
||||||
|
corrections = self._loader.get_corrections()
|
||||||
|
correction_key = f"{tool}:{normalized_query}" if tool else normalized_query
|
||||||
|
|
||||||
|
if correction_key in corrections:
|
||||||
|
return CommandMatch(
|
||||||
|
pattern=CommandPattern(
|
||||||
|
name="user_correction",
|
||||||
|
tool=tool or "custom",
|
||||||
|
description="User-provided correction",
|
||||||
|
patterns=[],
|
||||||
|
template=corrections[correction_key],
|
||||||
|
explanation="Custom command from user correction",
|
||||||
|
),
|
||||||
|
confidence=1.0,
|
||||||
|
matched_query=query,
|
||||||
|
command=corrections[correction_key],
|
||||||
|
explanation="This command was learned from your previous correction.",
|
||||||
|
)
|
||||||
|
|
||||||
|
patterns = self._loader.get_patterns()
|
||||||
|
if tool:
|
||||||
|
patterns = [p for p in patterns if p.tool == tool]
|
||||||
|
|
||||||
|
best_match = None
|
||||||
|
best_score = 0.0
|
||||||
|
|
||||||
|
for pattern in patterns:
|
||||||
|
score = self._calculate_match_score(normalized_query, pattern)
|
||||||
|
if score > best_score:
|
||||||
|
best_score = score
|
||||||
|
command = self._substitute_template(normalized_query, pattern)
|
||||||
|
if command:
|
||||||
|
best_match = CommandMatch(
|
||||||
|
pattern=pattern,
|
||||||
|
confidence=score,
|
||||||
|
matched_query=query,
|
||||||
|
command=command,
|
||||||
|
explanation=pattern.explanation or self._generate_explanation(pattern, command),
|
||||||
|
)
|
||||||
|
|
||||||
|
return best_match
|
||||||
|
|
||||||
|
def _calculate_match_score(self, query: str, pattern: CommandPattern) -> float:
|
||||||
|
"""Calculate how well a query matches a pattern."""
|
||||||
|
query_keywords = extract_keywords(query)
|
||||||
|
pattern_keywords = set()
|
||||||
|
|
||||||
|
for ptn in pattern.patterns:
|
||||||
|
pattern_keywords.update(extract_keywords(ptn))
|
||||||
|
|
||||||
|
if not pattern_keywords:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
keyword_overlap = len(query_keywords & pattern_keywords)
|
||||||
|
keyword_score = keyword_overlap / len(pattern_keywords) if pattern_keywords else 0.0
|
||||||
|
|
||||||
|
best_similarity = 0.0
|
||||||
|
for ptn in pattern.patterns:
|
||||||
|
sim = calculate_similarity(query, ptn)
|
||||||
|
if sim > best_similarity:
|
||||||
|
best_similarity = sim
|
||||||
|
|
||||||
|
combined_score = (keyword_score * 0.6) + (best_similarity * 0.4)
|
||||||
|
return min(combined_score, 1.0)
|
||||||
|
|
||||||
|
def _substitute_template(self, query: str, pattern: CommandPattern) -> Optional[str]:
|
||||||
|
"""Substitute variables in the template based on query."""
|
||||||
|
template = pattern.template
|
||||||
|
|
||||||
|
query_tokens = set(tokenize(query))
|
||||||
|
pattern_tokens = set()
|
||||||
|
for ptn in pattern.patterns:
|
||||||
|
pattern_tokens.update(tokenize(ptn))
|
||||||
|
|
||||||
|
diff_tokens = query_tokens - pattern_tokens
|
||||||
|
|
||||||
|
variables = re.findall(r'\{(\w+)\}', template)
|
||||||
|
var_values = {}
|
||||||
|
|
||||||
|
for var in variables:
|
||||||
|
lower_var = var.lower()
|
||||||
|
matching_tokens = [t for t in diff_tokens if lower_var in t.lower() or t.lower() in lower_var]
|
||||||
|
if matching_tokens:
|
||||||
|
var_values[var] = matching_tokens[0]
|
||||||
|
|
||||||
|
result = template
|
||||||
|
for var, value in var_values.items():
|
||||||
|
result = result.replace(f'{{{var}}}', value)
|
||||||
|
|
||||||
|
if re.search(r'\{\w+\}', result):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _generate_explanation(self, pattern: CommandPattern, command: str) -> str:
|
||||||
|
"""Generate an explanation for the command."""
|
||||||
|
return f"{pattern.description}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_matcher() -> PatternMatcher:
|
||||||
|
"""Get the global pattern matcher."""
|
||||||
|
return PatternMatcher()
|
||||||
Reference in New Issue
Block a user