Add rules module: security and antipattern detection rules
Some checks failed
CI / test (push) Has been cancelled
CI / build (push) Has been cancelled

This commit is contained in:
2026-01-29 23:09:44 +00:00
parent a95d298ad8
commit d32059b3fc

289
src/rules/security.py Normal file
View File

@@ -0,0 +1,289 @@
"""Security vulnerability detection rules."""
import re
from pathlib import Path
from typing import Optional
import tree_sitter
from src.analyzers.base import (
Analyzer,
Finding,
FindingCategory,
SeverityLevel,
)
class SQLInjectionAnalyzer(Analyzer):
"""Detect SQL injection vulnerabilities."""
DANGEROUS_FUNCTIONS = {
"execute": "cursor.execute",
"executemany": "cursor.executemany",
"raw_query": "db.raw_query",
"query": "db.query",
}
SQL_KEYWORDS = {
"select", "insert", "update", "delete", "drop", "create", "alter",
"union", "and", "or", "where", "from", "table", "database",
}
def rule_id(self) -> str:
return "security.sql_injection"
def rule_name(self) -> str:
return "SQL Injection Detection"
def severity(self) -> SeverityLevel:
return SeverityLevel.CRITICAL
def category(self) -> FindingCategory:
return FindingCategory.SECURITY
def analyze(
self, source_code: str, file_path: Path, tree: tree_sitter.Tree
) -> list[Finding]:
findings = []
calls = self._get_calls(tree.root_node)
for call in calls:
func_name = self._get_function_name(call)
if func_name in self.DANGEROUS_FUNCTIONS:
args = self._get_arguments(call)
for i, arg in enumerate(args):
if self._is_dangerous_sql(arg, source_code):
line = self._get_line_number(call, source_code)
findings.append(
Finding(
rule_id=self.rule_id(),
rule_name=self.rule_name(),
severity=self.severity(),
category=self.category(),
message="Potential SQL injection detected",
suggestion="Use parameterized queries instead of string formatting",
file_path=file_path,
line_number=line,
column=self._get_column(call),
node=call,
)
)
return findings
def _get_calls(self, node: tree_sitter.Node) -> list[tree_sitter.Node]:
calls = []
if hasattr(node, "type") and node.type == "call":
calls.append(node)
if hasattr(node, "children"):
for child in node.children:
calls.extend(self._get_calls(child))
return calls
def _get_function_name(self, call: tree_sitter.Node) -> str:
if hasattr(call, "children") and len(call.children) > 0:
func = call.children[0]
if hasattr(func, "text"):
return func.text.decode() if isinstance(func.text, bytes) else str(func.text)
return ""
def _get_arguments(self, call: tree_sitter.Node) -> list[str]:
args = []
if hasattr(call, "children"):
for child in call.children:
if hasattr(child, "text"):
text = child.text
args.append(text.decode() if isinstance(text, bytes) else str(text))
return args
def _is_dangerous_sql(self, arg: str, source_code: str) -> bool:
lower_arg = arg.lower()
has_concatenation = "+" in arg or "%" in arg or "f\"" in arg or "'" in arg
has_sql_keyword = any(kw in lower_arg for kw in self.SQL_KEYWORDS)
return has_concatenation and has_sql_keyword
def _get_line_number(self, node: tree_sitter.Node, source_code: str) -> int:
lines = source_code.split("\n")
start_byte = node.start_byte if hasattr(node, "start_byte") else 0
line = 1
pos = 0
for line_num, line_text in enumerate(lines, 1):
if pos + len(line_text) >= start_byte:
return line_num
pos += len(line_text) + 1
return 1
def _get_column(self, node: tree_sitter.Node) -> int:
return node.start_column if hasattr(node, "start_column") else 0
class EvalUsageAnalyzer(Analyzer):
"""Detect eval() and exec() usage."""
DANGEROUS_FUNCTIONS = {"eval", "exec", "execfile", "compile"}
def rule_id(self) -> str:
return "security.eval_usage"
def rule_name(self) -> str:
return "Eval/Exec Usage Detection"
def severity(self) -> SeverityLevel:
return SeverityLevel.CRITICAL
def category(self) -> FindingCategory:
return FindingCategory.SECURITY
def analyze(
self, source_code: str, file_path: Path, tree: tree_sitter.Tree
) -> list[Finding]:
findings = []
calls = self._get_calls(tree.root_node)
for call in calls:
func_name = self._get_function_name(call)
if func_name in self.DANGEROUS_FUNCTIONS:
line = self._get_line_number(call, source_code)
findings.append(
Finding(
rule_id=self.rule_id(),
rule_name=self.rule_name(),
severity=self.severity(),
category=self.category(),
message=f"Dangerous {func_name}() call detected",
suggestion="Avoid using eval/exec as they can execute arbitrary code",
file_path=file_path,
line_number=line,
column=self._get_column(call),
node=call,
)
)
return findings
def _get_calls(self, node: tree_sitter.Node) -> list[tree_sitter.Node]:
calls = []
if hasattr(node, "type") and node.type == "call":
calls.append(node)
if hasattr(node, "children"):
for child in node.children:
calls.extend(self._get_calls(child))
return calls
def _get_function_name(self, call: tree_sitter.Node) -> str:
if hasattr(call, "children") and len(call.children) > 0:
func = call.children[0]
if hasattr(func, "text"):
text = func.text
return text.decode() if isinstance(text, bytes) else str(text)
return ""
def _get_line_number(self, node: tree_sitter.Node, source_code: str) -> int:
lines = source_code.split("\n")
start_byte = node.start_byte if hasattr(node, "start_byte") else 0
pos = 0
for line_num, line_text in enumerate(lines, 1):
if pos + len(line_text) >= start_byte:
return line_num
pos += len(line_text) + 1
return 1
def _get_column(self, node: tree_sitter.Node) -> int:
return node.start_column if hasattr(node, "start_column") else 0
class PathTraversalAnalyzer(Analyzer):
"""Detect path traversal vulnerabilities."""
DANGEROUS_PATTERNS = [
r"\.\./",
r"\.\.\\",
r"join\s*\(\s*[\"'].*\.\.[\"']",
]
VULNERABLE_FUNCTIONS = {
"open": "open()",
"file": "file()",
"os.path.join": "os.path.join()",
"Path": "Path()",
}
def rule_id(self) -> str:
return "security.path_traversal"
def rule_name(self) -> str:
return "Path Traversal Detection"
def severity(self) -> SeverityLevel:
return SeverityLevel.HIGH
def category(self) -> FindingCategory:
return FindingCategory.SECURITY
def analyze(
self, source_code: str, file_path: Path, tree: tree_sitter.Tree
) -> list[Finding]:
findings = []
calls = self._get_calls(tree.root_node)
for call in calls:
func_name = self._get_function_name(call)
if func_name in self.VULNERABLE_FUNCTIONS:
args = self._get_arguments(call)
for arg in args:
if self._contains_path_traversal(arg):
line = self._get_line_number(call, source_code)
findings.append(
Finding(
rule_id=self.rule_id(),
rule_name=self.rule_name(),
severity=self.severity(),
category=self.category(),
message="Potential path traversal detected",
suggestion="Validate and sanitize file paths, use os.path.abspath()",
file_path=file_path,
line_number=line,
column=self._get_column(call),
node=call,
)
)
return findings
def _get_calls(self, node: tree_sitter.Node) -> list[tree_sitter.Node]:
calls = []
if hasattr(node, "type") and node.type == "call":
calls.append(node)
if hasattr(node, "children"):
for child in node.children:
calls.extend(self._get_calls(child))
return calls
def _get_function_name(self, call: tree_sitter.Node) -> str:
if hasattr(call, "children") and len(call.children) > 0:
func = call.children[0]
if hasattr(func, "text"):
text = func.text
return text.decode() if isinstance(text, bytes) else str(text)
return ""
def _get_arguments(self, call: tree_sitter.Node) -> list[str]:
args = []
if hasattr(call, "children"):
for child in call.children:
if hasattr(child, "text"):
text = child.text
args.append(text.decode() if isinstance(text, bytes) else str(text))
return args
def _contains_path_traversal(self, arg: str) -> bool:
return any(re.search(pattern, arg) for pattern in self.DANGEROUS_PATTERNS)
def _get_line_number(self, node: tree_sitter.Node, source_code: str) -> int:
lines = source_code.split("\n")
start_byte = node.start_byte if hasattr(node, "start_byte") else 0
pos = 0
for line_num, line_text in enumerate(lines, 1):
if pos + len(line_text) >= start_byte:
return line_num
pos += len(line_text) + 1
return 1
def _get_column(self, node: tree_sitter.Node) -> int:
return node.start_column if hasattr(node, "start_column") else 0