Files
errorfix-cli/errorfix/rules/loader.py
2026-02-01 03:56:22 +00:00

110 lines
3.5 KiB
Python

import json
import os
import sys
from pathlib import Path
from typing import List, Optional, Union

import yaml

from .rule import Rule
from .validator import RuleValidator
class RuleLoader:
    """Load ``Rule`` objects from YAML and JSON rule files.

    A rule file contains either a single rule mapping or a list of them.
    Every raw item is passed to ``self.validator.validate`` and then
    converted with ``Rule.from_dict``.
    """

    # Lower-case file suffixes recognised as rule files, grouped by parser.
    _YAML_SUFFIXES = ('.yaml', '.yml')
    _JSON_SUFFIXES = ('.json',)

    def __init__(self, validator: Optional[RuleValidator] = None):
        """Create a loader.

        Args:
            validator: Validator applied to each raw rule before conversion.
                When ``None``, a default ``RuleValidator()`` is constructed.
        """
        # Compare against None explicitly so a validator object that happens
        # to be falsy (e.g. one defining __len__) is not silently replaced.
        self.validator = validator if validator is not None else RuleValidator()
        # NOTE(review): not read by any method of this class; kept so code
        # that accesses the attribute keeps working.
        self._cache: dict = {}

    def _rules_from_data(self, data) -> List[Rule]:
        """Validate and convert parsed file content into a list of rules.

        ``data`` is whatever the parser returned: a list of rule items, a
        single rule item, or an empty/``None`` value meaning "no rules".
        """
        if not data:
            return []
        items = data if isinstance(data, list) else [data]
        rules = []
        for item in items:
            # validate() runs first; presumably it raises on an invalid rule
            # so from_dict never sees it — confirm against RuleValidator.
            self.validator.validate(item)
            rules.append(Rule.from_dict(item))
        return rules

    def _load_file(self, file_path, parse) -> List[Rule]:
        """Open *file_path* as UTF-8 text, parse it with *parse*, build rules.

        Args:
            file_path: Path of the rule file (``str`` or path-like).
            parse: Callable taking an open text file and returning the
                parsed document (``yaml.safe_load`` or ``json.load``).

        Raises:
            FileNotFoundError: If *file_path* does not exist.
        """
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"Rule file not found: {file_path}")
        with open(path, 'r', encoding='utf-8') as f:
            data = parse(f)
        return self._rules_from_data(data)

    def load_yaml(self, file_path: Union[str, os.PathLike]) -> List[Rule]:
        """Load rules from a YAML file.

        Uses ``yaml.safe_load``, which does not construct arbitrary Python
        objects from tags in the file. Parser and validator errors propagate.

        Raises:
            FileNotFoundError: If *file_path* does not exist.
        """
        return self._load_file(file_path, yaml.safe_load)

    def load_json(self, file_path: Union[str, os.PathLike]) -> List[Rule]:
        """Load rules from a JSON file.

        Parser and validator errors propagate to the caller.

        Raises:
            FileNotFoundError: If *file_path* does not exist.
        """
        return self._load_file(file_path, json.load)

    def load_directory(
        self, directory: Union[str, os.PathLike], recursive: bool = True
    ) -> List[Rule]:
        """Load rules from every YAML/JSON file in *directory*.

        Files are matched by case-insensitive suffix and loaded in a
        deterministic order: all ``.yaml`` files, then ``.yml``, then
        ``.json``, each group sorted by path. A file that fails to load is
        reported on stderr and skipped, so one bad file does not prevent
        the others from loading.

        Args:
            directory: Directory to scan.
            recursive: When true, also scan all sub-directories.

        Raises:
            NotADirectoryError: If *directory* does not exist or is not a
                directory.
        """
        path = Path(directory)
        if not path.is_dir():
            raise NotADirectoryError(f"Directory not found: {directory}")

        # One traversal, bucketed by normalised suffix. Lower-casing keeps
        # discovery and parser dispatch consistent (a glob for '*.yaml' is
        # case-insensitive on Windows, but a suffix comparison is not).
        groups = {suffix: [] for suffix in self._YAML_SUFFIXES + self._JSON_SUFFIXES}
        candidates = path.rglob('*') if recursive else path.glob('*')
        for candidate in candidates:
            bucket = groups.get(candidate.suffix.lower())
            # Skip directories whose names merely end in a rule suffix.
            if bucket is not None and candidate.is_file():
                bucket.append(candidate)

        rules: List[Rule] = []
        for suffix, files in groups.items():
            load = self.load_yaml if suffix in self._YAML_SUFFIXES else self.load_json
            for file_path in sorted(files):
                try:
                    rules.extend(load(str(file_path)))
                except Exception as e:
                    # Deliberate best effort: report on stderr (keeping stdout
                    # clean for the tool's real output) and keep going.
                    print(
                        f"Warning: Failed to load rules from {file_path}: {e}",
                        file=sys.stderr,
                    )
        return rules

    def load_multiple(self, sources: List[Union[str, os.PathLike]]) -> List[Rule]:
        """Load rules from a mix of rule files and directories.

        Files are dispatched by case-insensitive suffix to ``load_yaml`` or
        ``load_json``. Directories are scanned recursively with
        ``load_directory``. Sources that do not exist, or files with an
        unsupported suffix, are reported on stderr and skipped. Errors while
        loading an explicitly named file propagate.
        """
        all_rules: List[Rule] = []
        for raw_source in sources:
            # Accept pathlib.Path and other path-likes as well as str.
            source = os.fspath(raw_source)
            if os.path.isfile(source):
                lowered = source.lower()
                if lowered.endswith(self._YAML_SUFFIXES):
                    all_rules.extend(self.load_yaml(source))
                elif lowered.endswith(self._JSON_SUFFIXES):
                    all_rules.extend(self.load_json(source))
                else:
                    print(
                        f"Warning: Unsupported rule file type, skipping: {source}",
                        file=sys.stderr,
                    )
            elif os.path.isdir(source):
                all_rules.extend(self.load_directory(source))
            else:
                print(
                    f"Warning: Rule source not found, skipping: {source}",
                    file=sys.stderr,
                )
        return all_rules

    def filter_rules(
        self,
        rules: List[Rule],
        language: Optional[str] = None,
        tool: Optional[str] = None,
        tags: Optional[List[str]] = None,
    ) -> List[Rule]:
        """Return the rules matching every given criterion, highest priority first.

        Args:
            rules: Rules to filter; the input list is not modified.
            language: Keep rules whose ``language`` equals this value or is
                ``None``. An empty or ``None`` value disables the filter.
            tool: Keep rules whose ``tool`` equals this value or is ``None``.
                An empty or ``None`` value disables the filter.
            tags: Keep rules carrying at least one of these tags. An empty or
                ``None`` value disables the filter.

        Returns:
            A new list sorted by ``priority`` in descending order. Rules with
            equal priority keep their input order (``sorted`` is stable).
        """
        filtered = rules
        if language:
            filtered = [r for r in filtered if r.language is None or r.language == language]
        if tool:
            filtered = [r for r in filtered if r.tool is None or r.tool == tool]
        if tags:
            # `or ()` guards against a rule whose tags attribute is None
            # (assumption: Rule may leave tags unset — confirm in Rule).
            filtered = [r for r in filtered if any(t in (r.tags or ()) for t in tags)]
        return sorted(filtered, key=lambda r: r.priority, reverse=True)