"""Search functionality for API documentation."""

import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

# HTTP verbs that identify an operation object inside an OpenAPI path item;
# everything else ("parameters", "$ref", "x-..." extensions) is skipped.
_HTTP_METHODS = ("get", "post", "put", "patch", "delete", "options", "head", "trace")


@dataclass
class SearchResult:
    """A single search result."""

    endpoint_path: str
    endpoint_method: str
    title: str
    description: str
    tags: List[str]
    matched_fields: List[str]
    score: float
    snippet: str


@dataclass
class SearchIndex:
    """Full-text search index for API endpoints and schemas.

    Call :meth:`build` with a parsed OpenAPI document, then query with
    :meth:`search` / :meth:`search_schemas`.
    """

    endpoints: List[Dict[str, Any]] = field(default_factory=list)
    schemas: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    tags: List[Dict[str, str]] = field(default_factory=list)
    # Inverted index for endpoints: term -> positions into ``self.endpoints``.
    _index: Dict[str, List[int]] = field(default_factory=dict)
    # Inverted index for schemas, kept separate so ``_index`` stays a
    # homogeneous term -> List[int] mapping (previously schema postings were
    # stored in ``_index`` as "schema:<name>" strings, violating its type).
    _schema_index: Dict[str, List[str]] = field(default_factory=dict)

    def build(self, spec_data: Dict[str, Any]) -> None:
        """Build the search index from OpenAPI spec data.

        Args:
            spec_data: Parsed OpenAPI document as a dict. Operation objects
                may be pydantic-style models; anything exposing
                ``model_dump()`` is converted to a plain dict first.
        """
        self.endpoints = []
        self.schemas = {}
        self.tags = spec_data.get("tags", [])
        self._index = {}
        self._schema_index = {}

        for path, methods in spec_data.get("paths", {}).items():
            for method, details in methods.items():
                if method not in _HTTP_METHODS:
                    continue
                if hasattr(details, "model_dump"):
                    details = details.model_dump()
                endpoint = {
                    "path": path,
                    "method": method.upper(),
                    "operationId": details.get("operationId"),
                    "summary": details.get("summary", ""),
                    "description": details.get("description", ""),
                    "tags": details.get("tags", []),
                    "parameters": details.get("parameters", []),
                    "requestBody": details.get("requestBody"),
                    "responses": details.get("responses", {}),
                    "deprecated": details.get("deprecated", False),
                }
                self.endpoints.append(endpoint)
                self._index_endpoint(len(self.endpoints) - 1, endpoint)

        components = spec_data.get("components") or {}
        # ``or {}`` also guards an explicit ``"schemas": None`` in the spec.
        for schema_name, schema_def in (components.get("schemas") or {}).items():
            self.schemas[schema_name] = {
                "name": schema_name,
                "description": schema_def.get("description", ""),
                "properties": schema_def.get("properties", {}),
            }
            self._index_schema(schema_name, self.schemas[schema_name])

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize *text* into lowercase terms, dropping single-char tokens."""
        if not text:
            return []
        cleaned = re.sub(r"[^\w\s]", " ", text.lower())
        return [t for t in cleaned.split() if len(t) > 1]

    def _index_endpoint(self, idx: int, endpoint: Dict[str, Any]) -> None:
        """Add one endpoint's searchable terms to the inverted index."""
        terms = set(self._tokenize(endpoint["path"]))
        terms.update(self._tokenize(endpoint.get("summary", "")))
        terms.update(self._tokenize(endpoint.get("description", "")))
        for tag in endpoint.get("tags", []):
            terms.update(self._tokenize(tag))
        terms.update(self._tokenize(endpoint.get("operationId", "")))

        for term in terms:
            self._index.setdefault(term, []).append(idx)

    def _index_schema(self, name: str, schema: Dict[str, Any]) -> None:
        """Add one schema's searchable terms to the schema index."""
        terms = {name.lower()}
        terms.update(self._tokenize(schema.get("description", "")))
        for prop_name in schema.get("properties", {}):
            terms.add(prop_name.lower())

        for term in terms:
            self._schema_index.setdefault(term, []).append(name)

    def search(self, query: str, limit: int = 20) -> List[SearchResult]:
        """Search the index for matching endpoints.

        Args:
            query: Search query string.
            limit: Maximum number of results to return.

        Returns:
            SearchResult objects sorted by descending relevance. Falls back
            to substring matching when no indexed term matches.
        """
        if not query:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        # endpoint index -> (accumulated score, tokens that matched)
        scores: Dict[int, Tuple[float, List[str]]] = {}
        for token in query_tokens:
            for idx in self._index.get(token, ()):
                current_score, matched = scores.get(idx, (0.0, []))
                scores[idx] = (current_score + 1.0, matched + [token])

        if not scores:
            return self._fuzzy_search(query_tokens, limit)

        results = []
        ranked = sorted(scores.items(), key=lambda item: -item[1][0])[:limit]
        for idx, (score, matched_fields) in ranked:
            endpoint = self.endpoints[idx]
            results.append(
                SearchResult(
                    endpoint_path=endpoint["path"],
                    endpoint_method=endpoint["method"],
                    # ``summary`` is always present (build() defaults it to
                    # ""), so use truthiness — not .get's default — to fall
                    # back to the path for blank summaries.
                    title=endpoint.get("summary") or endpoint["path"],
                    description=endpoint.get("description", ""),
                    tags=endpoint.get("tags", []),
                    matched_fields=list(set(matched_fields)),
                    score=score,
                    snippet=self._create_snippet(endpoint, query_tokens),
                )
            )
        return results

    def _fuzzy_search(self, query_tokens: List[str], limit: int) -> List[SearchResult]:
        """Substring fallback used when no indexed term matches the query."""
        results = []
        query = " ".join(query_tokens).lower()

        for endpoint in self.endpoints:
            haystack = " ".join(
                [
                    endpoint["path"],
                    endpoint.get("summary", ""),
                    endpoint.get("description", ""),
                ]
            ).lower()
            if query not in haystack:
                continue
            results.append(
                SearchResult(
                    endpoint_path=endpoint["path"],
                    endpoint_method=endpoint["method"],
                    title=endpoint.get("summary") or endpoint["path"],
                    description=endpoint.get("description", ""),
                    tags=endpoint.get("tags", []),
                    matched_fields=query_tokens,
                    score=0.5,  # flat score: fuzzy hits rank below exact ones
                    snippet=self._create_snippet(endpoint, query_tokens),
                )
            )

        return sorted(results, key=lambda r: -r.score)[:limit]

    def _create_snippet(
        self, endpoint: Dict[str, Any], query_tokens: List[str]
    ) -> str:
        """Create a snippet of the description with matched terms bolded.

        Matched spans keep their original casing (previously they were
        force-uppercased, mangling the displayed text).
        """
        description = endpoint.get("description", "") or ""
        snippet = description[:150]
        if len(description) > 150:
            snippet += "..."

        for token in query_tokens:
            pattern = re.compile(re.escape(token), re.IGNORECASE)
            snippet = pattern.sub(lambda m: f"**{m.group(0)}**", snippet)

        return snippet

    def search_schemas(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Search for schemas matching *query*.

        Returns up to *limit* dicts with ``name``, ``description`` and
        ``score``, sorted by descending score. Name hits score 2 per token,
        description hits 1.
        """
        query_tokens = self._tokenize(query)
        results = []

        for name, schema in self.schemas.items():
            match_score = 0
            name_lower = name.lower()
            description_lower = schema.get("description", "").lower()
            for token in query_tokens:
                if token in name_lower:
                    match_score += 2
                if token in description_lower:
                    match_score += 1
            if match_score > 0:
                results.append(
                    {
                        "name": name,
                        "description": schema.get("description", ""),
                        "score": match_score,
                    }
                )

        return sorted(results, key=lambda r: -r["score"])[:limit]

    def get_tag_groups(self) -> Dict[str, List[Dict[str, Any]]]:
        """Group endpoints by their tags; untagged ones go under "untagged"."""
        groups: Dict[str, List[Dict[str, Any]]] = {}

        for endpoint in self.endpoints:
            for tag in endpoint.get("tags", []) or ["untagged"]:
                groups.setdefault(tag, []).append(
                    {
                        "path": endpoint["path"],
                        "method": endpoint["method"],
                        "summary": endpoint.get("summary", ""),
                    }
                )

        return groups