diff --git a/tests/test_scanners.py b/tests/test_scanners.py new file mode 100644 index 0000000..86ca7d9 --- /dev/null +++ b/tests/test_scanners.py @@ -0,0 +1,181 @@ +"""Tests for scanner modules.""" + +import pytest +from pathlib import Path + +from src.core.models import Issue, IssueCategory, SeverityLevel +from src.scanners.bandit_scanner import BanditScanner +from src.scanners.ruff_scanner import RuffScanner +from src.scanners.tree_sitter_scanner import TreeSitterScanner + + +class TestBanditScanner: + """Tests for BanditScanner.""" + + def test_scan_content_returns_list(self, sample_python_code): + """Test that scan_content returns a list of issues.""" + scanner = BanditScanner() + issues = scanner.scan_content(sample_python_code, "/test/file.py") + assert isinstance(issues, list) + + def test_scan_content_handles_empty_code(self): + """Test that scan_content handles empty code.""" + scanner = BanditScanner() + issues = scanner.scan_content("", "/test/file.py") + assert isinstance(issues, list) + + def test_scan_content_handles_invalid_code(self): + """Test that scan_content handles invalid code gracefully.""" + scanner = BanditScanner() + issues = scanner.scan_content("not valid python", "/test/file.py") + assert isinstance(issues, list) + + def test_scan_file_nonexistent(self): + """Test that scan_file handles nonexistent files.""" + scanner = BanditScanner() + issues = scanner.scan_file("/nonexistent/file.py") + assert issues == [] + + def test_get_plugin_info(self): + """Test that get_plugin_info returns valid info.""" + scanner = BanditScanner() + info = scanner.get_plugin_info() + assert "name" in info + assert "version" in info + assert info["name"] == "bandit" + + +class TestRuffScanner: + """Tests for RuffScanner.""" + + def test_scan_content_returns_list(self, clean_python_code): + """Test that scan_content returns a list of issues.""" + scanner = RuffScanner() + issues = scanner.scan_content(clean_python_code, "/test/file.py", "python") + assert isinstance(issues, list) + + def test_scan_content_handles_empty_code(self): + """Test that scan_content handles empty code.""" + scanner = RuffScanner() + issues = scanner.scan_content("", "/test/file.py", "python") + assert isinstance(issues, list) + + def test_scan_content_handles_unparseable_code(self): + """Test that scan_content handles unparseable code.""" + scanner = RuffScanner() + issues = scanner.scan_content("def (", "/test/file.py", "python") + assert isinstance(issues, list) + + def test_detect_language_python(self): + """Test language detection for Python files.""" + scanner = RuffScanner() + assert scanner._detect_language("/test/file.py") == "python" + + def test_detect_language_javascript(self): + """Test language detection for JavaScript files.""" + scanner = RuffScanner() + assert scanner._detect_language("/test/file.js") == "javascript" + + def test_detect_language_typescript(self): + """Test language detection for TypeScript files.""" + scanner = RuffScanner() + assert scanner._detect_language("/test/file.ts") == "typescript" + + def test_get_plugin_info(self): + """Test that get_plugin_info returns valid info.""" + scanner = RuffScanner() + info = scanner.get_plugin_info() + assert "name" in info + assert info["name"] == "ruff" + + +class TestTreeSitterScanner: + """Tests for TreeSitterScanner.""" + + def test_scan_content_returns_list(self, sample_python_code): + """Test that scan_content returns a list of issues.""" + scanner = TreeSitterScanner() + issues = scanner.scan_content(sample_python_code, "/test/file.py", "python") + assert isinstance(issues, list) + + def test_scan_python_finds_credentials(self): + """Test that scanner finds hardcoded credentials in Python.""" + code = 'api_key = "sk-1234567890abcdefghijklmnop"' + scanner = TreeSitterScanner() + issues = scanner.scan_content(code, "/test/file.py", "python") + + credential_issues = [i for i in issues if i.category == IssueCategory.SECURITY] + assert len(credential_issues) > 0 + + def test_scan_python_finds_mutable_defaults(self): + """Test that scanner finds mutable default arguments.""" + code = "def func(items=[]):\n pass" + scanner = TreeSitterScanner() + issues = scanner.scan_content(code, "/test/file.py", "python") + + anti_pattern_issues = [i for i in issues if i.category == IssueCategory.ANTI_PATTERN] + assert len(anti_pattern_issues) > 0 + + def test_scan_python_finds_bare_except(self): + """Test that scanner finds bare except clauses.""" + code = "try:\n pass\nexcept:\n pass" + scanner = TreeSitterScanner() + issues = scanner.scan_content(code, "/test/file.py", "python") + + error_handling_issues = [i for i in issues if i.category == IssueCategory.ERROR_HANDLING] + assert len(error_handling_issues) > 0 + + def test_scan_js_finds_credentials(self): + """Test that scanner finds credentials in JavaScript.""" + code = 'const apiKey = "sk_live_1234567890abcdefghijklmnopqrstuvwxyz";' + scanner = TreeSitterScanner() + issues = scanner.scan_content(code, "/test/file.js", "javascript") + + security_issues = [i for i in issues if i.category == IssueCategory.SECURITY] + assert len(security_issues) > 0 + + def test_scan_sql_injection_patterns(self): + """Test that scanner finds SQL injection patterns.""" + code = 'result = db.execute(f"SELECT * FROM users WHERE name = {user_input}")' + scanner = TreeSitterScanner() + issues = scanner.scan_content(code, "/test/file.py", "python") + + security_issues = [i for i in issues if i.category == IssueCategory.SECURITY] + assert len(security_issues) > 0 + + def test_scan_command_injection_patterns(self): + """Test that scanner finds command injection patterns.""" + code = 'os.system(f"echo {user_input}")' + scanner = TreeSitterScanner() + issues = scanner.scan_content(code, "/test/file.py", "python") + + security_issues = [i for i in issues if i.category == IssueCategory.SECURITY] + assert len(security_issues) > 0 + + def test_false_positive_filtering(self): + """Test that scanner filters false positives.""" + code = '# TODO: Replace "sk-test-EXAMPLE" with real key' + scanner = TreeSitterScanner() + issues = scanner.scan_content(code, "/test/file.py", "python") + + assert len(issues) == 0 + + def test_get_plugin_info(self): + """Test that get_plugin_info returns valid info.""" + scanner = TreeSitterScanner() + info = scanner.get_plugin_info() + assert "name" in info + assert info["name"] == "tree-sitter" + + def test_is_comment_detection(self): + """Test comment detection.""" + scanner = TreeSitterScanner() + assert scanner._is_comment("# This is a comment") is True + assert scanner._is_comment("code = 1") is False + + def test_language_detection(self): + """Test language detection from file path.""" + scanner = TreeSitterScanner() + assert scanner._detect_language("/test/file.py") == "python" + assert scanner._detect_language("/test/file.js") == "javascript" + assert scanner._detect_language("/test/file.ts") == "typescript"