"""Security vulnerability detection rules."""

import re
from pathlib import Path
from typing import Optional
import tree_sitter

from src.analyzers.base import (
    Analyzer,
    Finding,
    FindingCategory,
    SeverityLevel,
)


def _node_text(node: tree_sitter.Node) -> str:
    """Return the source text covered by *node*, or ``""`` if it has none."""
    text = getattr(node, "text", None)
    if text is None:
        return ""
    if isinstance(text, bytes):
        # Never let undecodable bytes in the analyzed file abort the analysis.
        return text.decode("utf-8", errors="replace")
    return str(text)


class _CallInspectionMixin:
    """Tree-sitter helpers shared by the call-based security analyzers.

    Node type names used here ("call", "argument_list", "comment", ...) are
    those of the tree-sitter Python grammar.
    """

    # Children of an ``argument_list`` that are not arguments themselves.
    _NON_ARGUMENT_TYPES = frozenset({"(", ")", ",", "comment"})

    def _get_calls(self, node: tree_sitter.Node) -> list[tree_sitter.Node]:
        """Return every ``call`` node under *node* (inclusive), in pre-order."""
        calls = []
        # Explicit stack instead of recursion: deeply nested expressions must
        # not run into Python's recursion limit.
        pending = [node]
        while pending:
            current = pending.pop()
            if getattr(current, "type", None) == "call":
                calls.append(current)
            children = getattr(current, "children", None) or []
            # Reversed so that children are popped in source order.
            pending.extend(reversed(children))
        return calls

    def _get_function_name(self, call: tree_sitter.Node) -> str:
        """Return the callee text of *call*, e.g. ``"cursor.execute"``.

        Returns ``""`` when the node has no children.
        """
        children = getattr(call, "children", None)
        return _node_text(children[0]) if children else ""

    def _get_argument_nodes(self, call: tree_sitter.Node) -> list[tree_sitter.Node]:
        """Return the argument nodes of *call*, without punctuation/comments."""
        children = getattr(call, "children", None) or []
        # children[0] is the callee; the arguments follow it.
        for child in children[1:]:
            child_type = getattr(child, "type", None)
            if child_type == "argument_list":
                return [
                    arg
                    for arg in child.children
                    if arg.type not in self._NON_ARGUMENT_TYPES
                ]
            if child_type == "generator_expression":
                # ``f(x for x in xs)``: the generator is the sole argument.
                return [child]
        return []

    def _get_arguments(self, call: tree_sitter.Node) -> list[str]:
        """Return the source text of each argument passed to *call*."""
        return [_node_text(arg) for arg in self._get_argument_nodes(call)]

    def _get_line_number(self, node: tree_sitter.Node, source_code: str) -> int:
        """Return the 1-based line on which *node* starts."""
        point = getattr(node, "start_point", None)
        if point is not None:
            # start_point is (row, column) with a 0-based row.
            return point[0] + 1
        # Fallback for node-like objects that only expose a byte offset.
        # NOTE(review): assumes the tree was parsed from UTF-8 bytes - confirm.
        start_byte = getattr(node, "start_byte", 0)
        return source_code.encode("utf-8")[:start_byte].count(b"\n") + 1

    def _get_column(self, node: tree_sitter.Node) -> int:
        """Return the start column of *node* as reported by tree-sitter."""
        point = getattr(node, "start_point", None)
        if point is not None:
            return point[1]
        return getattr(node, "start_column", 0)

    def _make_finding(
        self,
        call: tree_sitter.Node,
        file_path: Path,
        source_code: str,
        message: str,
        suggestion: str,
    ) -> Finding:
        """Build a ``Finding`` for *call* using this rule's metadata."""
        return Finding(
            rule_id=self.rule_id(),
            rule_name=self.rule_name(),
            severity=self.severity(),
            category=self.category(),
            message=message,
            suggestion=suggestion,
            file_path=file_path,
            line_number=self._get_line_number(call, source_code),
            column=self._get_column(call),
            node=call,
        )


class SQLInjectionAnalyzer(_CallInspectionMixin, Analyzer):
    """Detect SQL injection vulnerabilities.

    A call is reported when its method name is one of ``DANGEROUS_FUNCTIONS``
    and one of its arguments both mentions a SQL keyword and is assembled
    dynamically (``+`` / ``%`` operator, f-string interpolation or
    ``str.format``).  Constant query strings, including ones that carry
    ``%s`` placeholders for driver-side parameter binding, are not reported.
    """

    DANGEROUS_FUNCTIONS = {
        "execute": "cursor.execute",
        "executemany": "cursor.executemany",
        "raw_query": "db.raw_query",
        "query": "db.query",
    }

    SQL_KEYWORDS = {
        "select", "insert", "update", "delete", "drop", "create", "alter",
        "union", "and", "or", "where", "from", "table", "database",
    }

    def rule_id(self) -> str:
        return "security.sql_injection"

    def rule_name(self) -> str:
        return "SQL Injection Detection"

    def severity(self) -> SeverityLevel:
        return SeverityLevel.CRITICAL

    def category(self) -> FindingCategory:
        return FindingCategory.SECURITY

    def analyze(
        self, source_code: str, file_path: Path, tree: tree_sitter.Tree
    ) -> list[Finding]:
        """Return one finding per database call that builds its SQL dynamically."""
        findings = []
        for call in self._get_calls(tree.root_node):
            # The callee is usually an attribute access ("cursor.execute");
            # only its last segment names the method.
            method_name = self._get_function_name(call).rsplit(".", 1)[-1].strip()
            if method_name not in self.DANGEROUS_FUNCTIONS:
                continue
            arguments = self._get_argument_nodes(call)
            # any() stops at the first hit, so a call is reported only once.
            if any(self._is_dangerous_sql(arg, source_code) for arg in arguments):
                findings.append(
                    self._make_finding(
                        call,
                        file_path,
                        source_code,
                        message="Potential SQL injection detected",
                        suggestion="Use parameterized queries instead of string formatting",
                    )
                )
        return findings

    def _is_dangerous_sql(self, arg: tree_sitter.Node, source_code: str) -> bool:
        """Return True if argument node *arg* looks like dynamically built SQL.

        *source_code* is accepted for signature compatibility and is unused.
        """
        children = getattr(arg, "children", None)
        if getattr(arg, "type", None) == "keyword_argument" and children:
            # ``name=value``: only the value expression matters.
            arg = children[-1]
        return self._has_sql_keyword(_node_text(arg)) and self._builds_string_dynamically(arg)

    def _has_sql_keyword(self, text: str) -> bool:
        """Return True if *text* contains a SQL keyword as a whole word."""
        # Whole-word matching keeps "or" from matching inside "cursor".
        words = set(re.findall(r"\w+", text.lower()))
        return not words.isdisjoint(self.SQL_KEYWORDS)

    def _builds_string_dynamically(self, node: tree_sitter.Node) -> bool:
        """Return True if expression *node* concatenates or formats a string."""
        node_type = getattr(node, "type", None)
        children = getattr(node, "children", None) or []
        if node_type == "binary_operator":
            # "..." + value   or   "...%s" % value
            return any(child.type in ("+", "%") for child in children)
        if node_type == "string":
            # An f-string that interpolates at least one expression.
            return any(child.type == "interpolation" for child in children)
        if node_type == "call":
            # "...{}".format(value)
            return self._get_function_name(node).endswith(".format")
        if node_type in ("concatenated_string", "parenthesized_expression"):
            return any(self._builds_string_dynamically(child) for child in children)
        return False


class EvalUsageAnalyzer(_CallInspectionMixin, Analyzer):
    """Detect eval() and exec() usage."""

    DANGEROUS_FUNCTIONS = {"eval", "exec", "execfile", "compile"}

    def rule_id(self) -> str:
        return "security.eval_usage"

    def rule_name(self) -> str:
        return "Eval/Exec Usage Detection"

    def severity(self) -> SeverityLevel:
        return SeverityLevel.CRITICAL

    def category(self) -> FindingCategory:
        return FindingCategory.SECURITY

    def analyze(
        self, source_code: str, file_path: Path, tree: tree_sitter.Tree
    ) -> list[Finding]:
        """Return one finding per call to a function in ``DANGEROUS_FUNCTIONS``."""
        findings = []
        for call in self._get_calls(tree.root_node):
            func_name = self._get_function_name(call)
            # Exact match only, so that e.g. ``re.compile`` is not reported;
            # the explicit ``builtins.`` spelling is treated like the bare name.
            if func_name.removeprefix("builtins.") not in self.DANGEROUS_FUNCTIONS:
                continue
            findings.append(
                self._make_finding(
                    call,
                    file_path,
                    source_code,
                    message=f"Dangerous {func_name}() call detected",
                    suggestion="Avoid using eval/exec as they can execute arbitrary code",
                )
            )
        return findings


class PathTraversalAnalyzer(_CallInspectionMixin, Analyzer):
    """Detect path traversal vulnerabilities."""

    DANGEROUS_PATTERNS = [
        r"\.\./",
        r"\.\.\\",
        r"join\s*\(\s*[\"'].*\.\.[\"']",
    ]

    VULNERABLE_FUNCTIONS = {
        "open": "open()",
        "file": "file()",
        "os.path.join": "os.path.join()",
        "Path": "Path()",
    }

    def rule_id(self) -> str:
        return "security.path_traversal"

    def rule_name(self) -> str:
        return "Path Traversal Detection"

    def severity(self) -> SeverityLevel:
        return SeverityLevel.HIGH

    def category(self) -> FindingCategory:
        return FindingCategory.SECURITY

    def analyze(
        self, source_code: str, file_path: Path, tree: tree_sitter.Tree
    ) -> list[Finding]:
        """Return one finding per file-path call whose text contains ``..``."""
        findings = []
        for call in self._get_calls(tree.root_node):
            if self._get_function_name(call) not in self.VULNERABLE_FUNCTIONS:
                continue
            # Match against the whole call (callee + arguments): the
            # ``join(...)`` pattern needs to see both parts together.
            if self._contains_path_traversal(_node_text(call)):
                findings.append(
                    self._make_finding(
                        call,
                        file_path,
                        source_code,
                        message="Potential path traversal detected",
                        suggestion="Validate and sanitize file paths, use os.path.abspath()",
                    )
                )
        return findings

    def _contains_path_traversal(self, arg: str) -> bool:
        """Return True if *arg* matches any pattern in ``DANGEROUS_PATTERNS``."""
        return any(re.search(pattern, arg) for pattern in self.DANGEROUS_PATTERNS)