This commit is contained in:
@@ -1,256 +1,127 @@
|
||||
"""Search functionality for API documentation."""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""A single search result."""
|
||||
|
||||
endpoint_path: str
|
||||
endpoint_method: str
|
||||
title: str
|
||||
description: str
|
||||
tags: List[str]
|
||||
matched_fields: List[str]
|
||||
path: str
|
||||
method: str
|
||||
operation_id: str | None
|
||||
summary: str | None
|
||||
description: str | None
|
||||
tags: list[str]
|
||||
matched_terms: list[str]
|
||||
score: float
|
||||
snippet: str
|
||||
|
||||
|
||||
@dataclass
class SearchIndex:
    """Full-text search index over an OpenAPI spec's paths and schemas.

    NOTE(review): reconstructed from merge/diff residue that interleaved two
    revisions; both the new mapping-based API (paths/add_path/add_schema/
    add_tag) and the legacy flattened API (endpoints/build/_index) are kept.
    """

    # Raw path -> {method -> operation data} mapping (new-style API).
    paths: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Schema name -> schema definition.
    schemas: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Tag names; list[Any] because one revision stored tag dicts — TODO confirm.
    tags: list[Any] = field(default_factory=list)
    # Legacy flattened endpoint records populated by build().
    endpoints: list[dict[str, Any]] = field(default_factory=list)
    # Inverted index: term -> list of endpoint indices and "schema:<name>" keys.
    _index: dict[str, list[Any]] = field(default_factory=dict)

    def add_path(self, path: str, methods: dict[str, Any]) -> None:
        """Register the method map for *path*."""
        self.paths[path] = methods

    def add_schema(self, name: str, schema: dict[str, Any]) -> None:
        """Register a component schema under *name*."""
        self.schemas[name] = schema

    def add_tag(self, tag: str) -> None:
        """Register *tag*, ignoring duplicates."""
        if tag not in self.tags:
            self.tags.append(tag)

    def build(self, spec_data: dict[str, Any]) -> None:
        """Build the flattened endpoint list and inverted index from *spec_data*."""
        self.endpoints = []
        self.schemas = {}
        self.tags = spec_data.get("tags", [])
        self._index = {}

        http_methods = {"get", "post", "put", "patch", "delete", "options", "head", "trace"}
        for path, methods in spec_data.get("paths", {}).items():
            for method, details in methods.items():
                # Skip non-operation keys such as "parameters" or "summary".
                if method not in http_methods:
                    continue
                # Accept pydantic-style objects as well as plain dicts.
                if hasattr(details, "model_dump"):
                    details = details.model_dump()
                endpoint = {
                    "path": path,
                    "method": method.upper(),
                    "operationId": details.get("operationId"),
                    "summary": details.get("summary", ""),
                    "description": details.get("description", ""),
                    "tags": details.get("tags", []),
                    "parameters": details.get("parameters", []),
                    "requestBody": details.get("requestBody"),
                    "responses": details.get("responses", {}),
                    "deprecated": details.get("deprecated", False),
                }
                self.endpoints.append(endpoint)
                self._index_endpoint(len(self.endpoints) - 1, endpoint)

        # "components" may be present but None in some specs.
        components = spec_data.get("components") or {}
        for schema_name, schema_def in components.get("schemas", {}).items():
            self.schemas[schema_name] = {
                "name": schema_name,
                "description": schema_def.get("description", ""),
                "properties": schema_def.get("properties", {}),
            }
            self._index_schema(schema_name, self.schemas[schema_name])
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
"""Tokenize text into searchable terms."""
|
||||
if not text:
|
||||
return []
|
||||
text = text.lower()
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
tokens = text.split()
|
||||
return [t for t in tokens if len(t) > 1]
|
||||
def create_search_index(spec: dict[str, Any]) -> "SearchIndex":
    """Build a SearchIndex from a raw OpenAPI *spec* dict.

    Collects spec-level tags, a normalized per-path method map, and
    component schemas into a fresh index.
    """
    index = SearchIndex()

    # Spec-level tags may be bare strings or {"name": ...} objects.
    for tag in spec.get("tags", []):
        if isinstance(tag, dict):
            index.add_tag(tag.get("name", ""))
        else:
            index.add_tag(tag)

    for path, path_item in spec.get("paths", {}).items():
        # Accept pydantic-style objects as well as plain dicts.
        if hasattr(path_item, "model_dump"):
            path_item = path_item.model_dump()
        methods = {}
        for method in ["get", "put", "post", "delete", "options", "head", "patch", "trace"]:
            if method in path_item and path_item[method]:
                op = path_item[method]
                # Normalize camelCase spec keys to snake_case internal keys.
                methods[method] = {
                    "summary": op.get("summary"),
                    "description": op.get("description"),
                    "operation_id": op.get("operationId"),
                    "tags": op.get("tags", []),
                    "parameters": op.get("parameters", []),
                    "request_body": op.get("requestBody"),
                    "responses": op.get("responses", {}),
                }
        index.add_path(path, methods)

    # "components" may be present but None in some specs.
    components = spec.get("components") or {}
    for name, schema in components.get("schemas", {}).items():
        index.add_schema(name, schema)
    return index
||||
def _index_endpoint(self, idx: int, endpoint: Dict[str, Any]) -> None:
|
||||
"""Index a single endpoint."""
|
||||
terms = set()
|
||||
|
||||
for token in self._tokenize(endpoint["path"]):
|
||||
terms.add(token)
|
||||
|
||||
for token in self._tokenize(endpoint.get("summary", "")):
|
||||
terms.add(token)
|
||||
|
||||
for token in self._tokenize(endpoint.get("description", "")):
|
||||
terms.add(token)
|
||||
|
||||
for tag in endpoint.get("tags", []):
|
||||
for token in self._tokenize(tag):
|
||||
terms.add(token)
|
||||
|
||||
for token in self._tokenize(endpoint.get("operationId", "")):
|
||||
terms.add(token)
|
||||
|
||||
for term in terms:
|
||||
if term not in self._index:
|
||||
self._index[term] = []
|
||||
self._index[term].append(idx)
|
||||
|
||||
def _index_schema(self, name: str, schema: Dict[str, Any]) -> None:
|
||||
"""Index a schema definition."""
|
||||
terms = {name.lower()}
|
||||
|
||||
for token in self._tokenize(schema.get("description", "")):
|
||||
terms.add(token)
|
||||
|
||||
for prop_name in schema.get("properties", {}):
|
||||
terms.add(prop_name.lower())
|
||||
|
||||
for term in terms:
|
||||
if term not in self._index:
|
||||
self._index[term] = []
|
||||
self._index[term].append(f"schema:{name}")
|
||||
|
||||
def search(self, query: str, limit: int = 20) -> "list[SearchResult]":
    """Search the index for matching endpoints.

    Args:
        query: Search query string.
        limit: Maximum number of results to return.

    Returns:
        List of SearchResult objects sorted by descending relevance.
        Falls back to substring (fuzzy) matching when no indexed term
        matches exactly.
    """
    if not query:
        return []

    query_tokens = self._tokenize(query)
    if not query_tokens:
        return []

    # endpoint index -> (accumulated score, tokens that matched)
    scores: dict = {}
    for token in query_tokens:
        for idx in self._index.get(token, []):
            # Schema entries share the index as "schema:<name>" keys; skip them.
            if isinstance(idx, str) and idx.startswith("schema:"):
                continue
            current_score, matched_fields = scores.get(idx, (0.0, []))
            # Each matching token contributes a flat 1.0 to the score.
            scores[idx] = (current_score + 1.0, matched_fields + [token])

    if not scores:
        return self._fuzzy_search(query_tokens, limit)

    results = []
    for idx, (score, matched_fields) in sorted(
        scores.items(), key=lambda item: -item[1][0]
    )[:limit]:
        endpoint = self.endpoints[idx]
        results.append(SearchResult(
            endpoint_path=endpoint["path"],
            endpoint_method=endpoint["method"],
            title=endpoint.get("summary", endpoint["path"]),
            description=endpoint.get("description", ""),
            tags=endpoint.get("tags", []),
            matched_fields=list(set(matched_fields)),
            score=score,
            snippet=self._create_snippet(endpoint, query_tokens),
        ))
    return results
||||
def _fuzzy_search(self, query_tokens: "list[str]", limit: int) -> "list[SearchResult]":
    """Substring-match fallback used when no indexed term matches exactly.

    NOTE(review): the diff residue had an unrelated ``search_index``
    function body spliced into this loop, and the SearchResult() call
    mixed keyword sets from two revisions; this is the reassembled
    legacy fuzzy path.
    """
    results = []
    query = " ".join(query_tokens).lower()

    for endpoint in self.endpoints:
        haystack = " ".join([
            endpoint["path"],
            endpoint.get("summary", ""),
            endpoint.get("description", ""),
        ]).lower()

        if query in haystack:
            results.append(SearchResult(
                endpoint_path=endpoint["path"],
                endpoint_method=endpoint["method"],
                title=endpoint.get("summary", endpoint["path"]),
                description=endpoint.get("description", ""),
                tags=endpoint.get("tags", []),
                matched_fields=query_tokens,
                # Fixed low score: fuzzy hits rank below any exact-term hit.
                score=0.5,
                snippet=self._create_snippet(endpoint, query_tokens),
            ))

    return sorted(results, key=lambda r: -r.score)[:limit]
||||
def _create_snippet(
|
||||
self, endpoint: Dict[str, Any], query_tokens: List[str]
|
||||
) -> str:
|
||||
"""Create a snippet showing matched terms in context."""
|
||||
description = endpoint.get("description", "") or ""
|
||||
snippet = description[:150]
|
||||
if len(description) > 150:
|
||||
snippet += "..."
|
||||
|
||||
for token in query_tokens:
|
||||
pattern = re.compile(re.escape(token), re.IGNORECASE)
|
||||
snippet = pattern.sub(f"**{token.upper()}**", snippet)
|
||||
|
||||
return snippet
|
||||
|
||||
def search_schemas(self, query: str, limit: int = 10) -> list[dict[str, Any]]:
    """Search registered schemas matching *query*.

    A token hit in the schema name scores 2, in the description 1.
    Returns up to *limit* {"name", "description", "score"} dicts sorted
    by descending score.
    """
    query_tokens = self._tokenize(query)
    results = []

    for name, schema in self.schemas.items():
        name_lower = name.lower()
        match_score = 0
        for token in query_tokens:
            if token in name_lower:
                match_score += 2
            if token in schema.get("description", "").lower():
                match_score += 1

        if match_score > 0:
            results.append({
                "name": name,
                "description": schema.get("description", ""),
                "score": match_score,
            })

    return sorted(results, key=lambda r: -r["score"])[:limit]
|
||||
def get_tag_groups(self) -> dict[str, list[dict[str, Any]]]:
    """Group endpoints by tag.

    Endpoints without tags are grouped under "untagged".  Each entry
    carries the endpoint's path, method, and summary.
    """
    groups: dict[str, list[dict[str, Any]]] = {}

    for endpoint in self.endpoints:
        # Falsy tags (missing key, [], None) fall back to "untagged".
        for tag in endpoint.get("tags", []) or ["untagged"]:
            groups.setdefault(tag, []).append({
                "path": endpoint["path"],
                "method": endpoint["method"],
                "summary": endpoint.get("summary", ""),
            })

    return groups
||||
def search_index(index: "SearchIndex", query: str) -> "list[SearchResult]":
    """Search *index* for endpoints and schemas matching *query*.

    NOTE(review): reassembled from diff residue — the function's def line
    and endpoint loop were spliced into the legacy _fuzzy_search span.
    Scoring per query term: path hit 5, operation_id 4, summary 3,
    description 2, each tag hit 2; schema name 3, schema description 2.
    Results are sorted by descending score.
    """
    query_lower = query.lower()
    query_terms = re.findall(r"\w+", query_lower)
    results = []

    # Endpoint matches.
    for path, methods in index.paths.items():
        for method, op_data in methods.items():
            score = 0.0
            matched_terms = []
            for term in query_terms:
                term_score = 0.0
                if term in path.lower():
                    term_score += 5.0
                summary = op_data.get("summary", "") or ""
                if term in summary.lower():
                    term_score += 3.0
                description = op_data.get("description", "") or ""
                if term in description.lower():
                    term_score += 2.0
                operation_id = op_data.get("operation_id", "") or ""
                if term in operation_id.lower():
                    term_score += 4.0
                for tag in op_data.get("tags", []):
                    if term in tag.lower():
                        term_score += 2.0
                if term_score > 0:
                    score += term_score
                    matched_terms.append(term)
            if score > 0:
                results.append(SearchResult(
                    path=path,
                    method=method.upper(),
                    operation_id=op_data.get("operation_id"),
                    summary=op_data.get("summary"),
                    description=op_data.get("description"),
                    tags=op_data.get("tags", []),
                    matched_terms=matched_terms,
                    score=score,
                ))

    # Schema matches, reported with a synthetic path and method "SCHEMA".
    for schema_name, schema in index.schemas.items():
        score = 0.0
        matched_terms = []
        for term in query_terms:
            term_score = 0.0
            if term in schema_name.lower():
                term_score += 3.0
            schema_desc = schema.get("description", "") or ""
            if term in schema_desc.lower():
                term_score += 2.0
            if term_score > 0:
                score += term_score
                matched_terms.append(term)
        if score > 0:
            results.append(SearchResult(
                path=f"#/components/schemas/{schema_name}",
                method="SCHEMA",
                operation_id=None,
                summary=schema_name,
                description=schema.get("description"),
                tags=[],
                matched_terms=matched_terms,
                score=score,
            ))

    return sorted(results, key=lambda x: x.score, reverse=True)
||||
Reference in New Issue
Block a user