381 lines
13 KiB
Python
381 lines
13 KiB
Python
import random
|
|
import string
|
|
from typing import Optional
|
|
from .parser import parse_regex, RegexNode, NodeType
|
|
|
|
|
|
class TestCaseGenerator:
    """Generates matching and non-matching test cases for regex patterns.

    Patterns are parsed with ``parse_regex`` for the configured flavor. When
    parsing or AST-based generation raises, the generator falls back to
    simple text heuristics that scan the pattern source directly.
    """

    def __init__(self, flavor: str = "pcre"):
        """Remember the regex flavor passed to the parser and the matcher."""
        self.flavor = flavor

    def generate_matching(
        self,
        pattern: str,
        count: int = 5,
        max_length: int = 50
    ) -> list[str]:
        """Generate strings that match the pattern.

        Args:
            pattern: Regex source text.
            count: Number of strings to produce.
            max_length: Upper bound on the length of each generated string.

        Returns:
            A list of candidate matching strings.
        """
        try:
            ast = parse_regex(pattern, self.flavor)
            return self._generate_matching_from_ast(ast, count, max_length)
        except Exception:
            # Deliberate best-effort: any parser/generator failure degrades to
            # the heuristic scanner. The heuristic historically capped its
            # output at 20 characters; keep that cap but also honor a smaller
            # caller-supplied max_length.
            return self._generate_fallback_matching(
                pattern, count, min(max_length, 20)
            )

    def _generate_matching_from_ast(
        self,
        node: RegexNode,
        count: int,
        max_length: int
    ) -> list[str]:
        """Generate ``count`` matching strings from a parsed AST node."""
        if node.node_type == NodeType.SEQUENCE:
            return self._generate_sequence(node.children, count, max_length)
        return [pattern_to_string(node, max_length) for _ in range(count)]

    def _generate_sequence(
        self,
        children: list[RegexNode],
        count: int,
        max_length: int
    ) -> list[str]:
        """Generate ``count`` strings by concatenating one sample per child.

        Generation for a string stops as soon as the accumulated length
        reaches ``max_length``; each child is given the remaining budget.
        """
        results = []
        for _ in range(count):
            parts = []
            used = 0  # running length, avoids re-joining parts for every child
            for child in children:
                if used >= max_length:
                    break
                part = generate_from_node(child, max_length - used)
                if part is None:
                    part = ""
                parts.append(part)
                used += len(part)
            results.append("".join(parts))
        return results

    def _generate_fallback_matching(
        self,
        pattern: str,
        count: int,
        max_length: int = 20
    ) -> list[str]:
        """Fallback matching generation using simple heuristics.

        Args:
            pattern: Regex source text, scanned character by character.
            count: Number of strings to produce.
            max_length: Each result is truncated to this many characters
                (defaults to the historical cap of 20).

        Returns:
            ``count`` strings; a scan that yields nothing is replaced by the
            placeholder ``"test"``.
        """
        results = []
        for _ in range(count):
            sample = self._sample_from_pattern_text(pattern, max_length)
            if not sample:
                sample = "test"
            results.append(sample[:max_length])
        return results

    @staticmethod
    def _sample_for_escape(escaped: str) -> str:
        """Return sample text for the character that follows a backslash."""
        if escaped == 'd':
            return random.choice(string.digits)
        if escaped in ('D', 'w', 'S'):
            return random.choice(string.ascii_letters)
        if escaped == 'W':
            return random.choice(' !@#$%^&*()')
        if escaped == 's':
            return " "
        if escaped in ('b', 'B', 'A', 'Z'):
            # Zero-width assertions consume no input.
            return ""
        control = {'n': '\n', 't': '\t', 'r': '\r'}
        # Anything else is an escaped literal such as \. or \(.
        return control.get(escaped, escaped)

    @staticmethod
    def _sample_from_pattern_text(pattern: str, max_length: int) -> str:
        """Build one sample by scanning the pattern text left to right.

        This is a heuristic, not a regex engine: alternation and groups are
        flattened, ``*``/``+``/``?`` keep a single occurrence, ``{n}`` and
        ``{n,m}`` repeat only the most recently emitted piece ``n`` times,
        and negated character classes are not interpreted.
        """
        parts = []
        class_chars = []
        in_class = False
        i = 0
        length = len(pattern)
        while i < length:
            char = pattern[i]
            i += 1
            if char == '\\':
                if i >= length:
                    # Dangling backslash at the end of the pattern: ignore it.
                    break
                sample = TestCaseGenerator._sample_for_escape(pattern[i])
                i += 1  # consume the escaped character as well
                if not sample:
                    continue
                if in_class:
                    class_chars.append(sample)
                else:
                    parts.append(sample)
            elif in_class:
                # Inside [...] every unescaped character except ']' is a
                # literal member; '-' between members is skipped so that
                # only range endpoints are kept.
                if char == ']':
                    in_class = False
                    if class_chars:
                        parts.append(random.choice(class_chars))
                elif not (char == '-' and class_chars):
                    class_chars.append(char)
            elif char == '[':
                in_class = True
                class_chars = []
            elif char == '.':
                parts.append(random.choice(string.ascii_letters))
            elif char == '{':
                close = pattern.find('}', i)
                minimum = pattern[i:close].split(',')[0] if close != -1 else ''
                if minimum.isascii() and minimum.isdigit() and parts:
                    # Counted quantifier: repeat the previous piece, capped
                    # so a huge count cannot build a huge string.
                    parts[-1] *= min(int(minimum), max(max_length, 0))
                    i = close + 1
                # Otherwise the brace is dropped like other metacharacters.
            elif char not in '*+?()|^$}]':
                parts.append(char)
        return "".join(parts)

    def generate_non_matching(
        self,
        pattern: str,
        count: int = 5,
        max_length: int = 50
    ) -> list[str]:
        """Generate strings that do NOT match the pattern.

        Args:
            pattern: Regex source text.
            count: Maximum number of strings to produce.
            max_length: Length budget used when deriving candidates from
                matching strings.

        Returns:
            A list of candidate non-matching strings.
        """
        try:
            ast = parse_regex(pattern, self.flavor)
            return self._generate_non_matching_from_ast(pattern, ast, count, max_length)
        except Exception:
            # Deliberate best-effort: fall back to canned candidates.
            return self._generate_fallback_non_matching(pattern, count)

    def _generate_non_matching_from_ast(
        self,
        pattern: str,
        node: RegexNode,
        count: int,
        max_length: int
    ) -> list[str]:
        """Generate non-matching strings from AST.

        Candidates are produced by mutating generated matching strings
        (substituting or inserting one character) and are kept only when
        ``matches_pattern`` reports that they do not match. Remaining slots
        are filled from the canned fallback candidates, also verified, so the
        result may contain fewer than ``count`` strings.
        """
        # A pattern whose root is a bare anchor yields a single canned value,
        # independent of ``count``.
        if node.node_type == NodeType.ANCHOR_START:
            return ["prefix_test"]
        if node.node_type == NodeType.ANCHOR_END:
            return ["test_suffix"]
        if node.node_type == NodeType.START_OF_STRING:
            return ["prefix_test"]
        if node.node_type == NodeType.END_OF_STRING:
            return ["test_suffix"]

        results = set()
        base_matching = self._generate_matching_from_ast(node, 10, max_length)

        for matching in base_matching:
            if len(results) >= count:
                break
            if not matching:
                continue

            # Mutation 1: substitute one character.
            pos = random.randint(0, len(matching) - 1)
            original = matching[pos]
            replacement = get_replacement_char(original)
            if replacement != original:
                non_match = matching[:pos] + replacement + matching[pos + 1:]
                if not matches_pattern(pattern, non_match, self.flavor):
                    results.add(non_match)

            # Mutation 2: insert a character of a different class.
            if len(results) < count:
                pos = random.randint(0, len(matching))
                neighbour = matching[pos - 1] if pos > 0 else 'a'
                char_to_add = get_opposite_char_class(neighbour)
                non_match = matching[:pos] + char_to_add + matching[pos:]
                if not matches_pattern(pattern, non_match, self.flavor):
                    results.add(non_match)

        if len(results) < count:
            bases = (
                self._generate_fallback_non_matching(pattern, count)
                or ["does_not_match_123"]
            )
            # Bounded retries: a pattern that matches everything has no
            # non-matching string, so never loop forever.
            for attempt in range(max(10, count * 10)):
                if len(results) >= count:
                    break
                base = bases[attempt % len(bases)]
                suffixed = base + str(random.randint(100, 999))
                for candidate in (suffixed, base):
                    if not matches_pattern(pattern, candidate, self.flavor):
                        results.add(candidate)
                        break

        return list(results)[:count]

    def _generate_fallback_non_matching(
        self,
        pattern: str,
        count: int
    ) -> list[str]:
        """Fallback non-matching generation.

        Builds a list of canned candidates (extended according to simple
        textual features of the pattern) and returns the first ``count`` of
        those that ``re`` does not find a match in. If the pattern does not
        compile, or every candidate matches, the unfiltered candidates are
        returned instead.
        """
        import re

        results = ["does_not_match", "completely_different", "!@#$%^&*()", "", "xyz123"]

        if pattern.startswith('^'):
            results.append("prefix_" + results[0])
        if pattern.endswith('$'):
            results.append(results[0] + "_suffix")
        if '\\d' in pattern or '[0-9]' in pattern:
            results.append("abc_def")
        if '\\w' in pattern:
            results.append("!@#$%^&*")
        if '\\s' in pattern:
            results.append("nospacehere")

        dot_count = pattern.count('.')
        if dot_count > 0:
            results.append("x" * (dot_count + 1))

        try:
            compiled = re.compile(pattern)
            filtered_results = [r for r in results if compiled.search(r) is None]
            if filtered_results:
                return filtered_results[:count]
        except re.error:
            # Pattern is not valid for Python's engine; keep the raw list.
            pass

        return results[:count]
|
|
|
|
|
|
def generate_from_node(node: RegexNode, max_length: int) -> Optional[str]:
    """Generate a string from a single node.

    Args:
        node: AST node to sample from; dispatch is on ``node.node_type``.
        max_length: Length budget applied to literals, quantified output and
            sequences.

    Returns:
        A sample string for the node, ``""`` for lookaround assertions, or
        ``None`` when nothing can be generated (empty literal, group or
        quantifier without children, empty sequence/branch, unknown type).
    """
    kind = node.node_type

    if kind == NodeType.LITERAL:
        return node.value[:max_length] if node.value else None

    if kind == NodeType.ESCAPED_CHAR:
        return node.value if node.value else None

    if kind == NodeType.DOT:
        return random.choice(string.ascii_letters)

    if kind == NodeType.NEGATIVE_SET:
        excluded = []
        for start, end in node.ranges:
            excluded.extend(chr(code) for code in range(ord(start), ord(end) + 1))
        excluded.extend(node.characters)
        # Prefer a letter; if every letter is excluded, try digits,
        # punctuation and space before giving up.
        for pool in (string.ascii_letters, string.digits + string.punctuation + " "):
            available = [c for c in pool if c not in excluded]
            if available:
                return random.choice(available)
        return "!"  # last resort when every candidate is excluded

    if kind == NodeType.POSITIVE_SET:
        if node.ranges:
            start, end = node.ranges[0]
            return chr(random.randint(ord(start), ord(end)))
        if node.characters:
            return random.choice(node.characters)
        return "a"

    # Shorthand classes: the negated forms must produce a character from
    # OUTSIDE the class, otherwise the generated string cannot match.
    if kind == NodeType.DIGIT:
        return random.choice(string.digits)
    if kind == NodeType.NON_DIGIT:
        return random.choice(string.ascii_letters)

    if kind == NodeType.WORD_CHAR:
        return random.choice(string.ascii_letters)
    if kind == NodeType.NON_WORD_CHAR:
        return random.choice('!@#$%^&*()')

    if kind == NodeType.WHITESPACE:
        return " "
    if kind == NodeType.NON_WHITESPACE:
        return random.choice(string.ascii_letters)

    if kind == NodeType.QUANTIFIER:
        if not node.children:
            return None
        child = node.children[0]

        min_count = node.min_count if node.min_count else 0
        if node.max_count and node.max_count != float('inf'):
            max_count = min(node.max_count, 3)
        else:
            max_count = 3  # missing/unbounded upper limit: keep samples short
        max_count = max(min_count, max_count)

        if min_count == 0 and max_count == 0:
            repeat = 0
        elif min_count == 0:
            # Prefer at least one occurrence so the sample is not empty.
            repeat = random.randint(1, max_count)
        else:
            repeat = random.randint(min_count, max_count)

        # Sample the child once per repetition so that e.g. a repeated digit
        # class yields varied digits; stop once the budget is exhausted.
        pieces = []
        used = 0
        for _ in range(min(repeat, max(max_length, 0))):
            if used >= max_length:
                break
            piece = generate_from_node(child, max_length)
            if piece is None:
                piece = "x"
            pieces.append(piece)
            used += len(piece)
        return "".join(pieces)[:max_length]

    if kind in (
        NodeType.CAPTURING_GROUP,
        NodeType.NON_CAPTURING_GROUP,
        NodeType.NAMED_GROUP,
    ):
        # All group flavors generate whatever their first child generates.
        if node.children:
            return generate_from_node(node.children[0], max_length)
        return None

    if kind in (NodeType.LOOKAHEAD, NodeType.NEGATIVE_LOOKAHEAD):
        return ""

    if kind in (NodeType.LOOKBEHIND, NodeType.NEGATIVE_LOOKBEHIND):
        return ""

    if kind == NodeType.SEQUENCE:
        result = ""
        for child in node.children:
            if len(result) >= max_length:
                break
            part = generate_from_node(child, max_length - len(result))
            if part:
                result += part
        return result if result else None

    if kind == NodeType.BRANCH:
        if node.children:
            choices = []
            for child in node.children:
                part = generate_from_node(child, max_length)
                if part:
                    choices.append(part)
            if choices:
                return random.choice(choices)
        return None

    return None
|
|
|
|
|
|
def pattern_to_string(node: RegexNode, max_length: int) -> str:
    """Return a sample string for ``node``, or ``"test"`` if none is produced.

    Delegates to ``generate_from_node``; a ``None`` or empty result is
    replaced by the placeholder so callers always receive non-empty text.
    """
    generated = generate_from_node(node, max_length)
    if not generated:
        return "test"
    return generated
|
|
|
|
|
|
def get_replacement_char(original: str) -> str:
    """Get a replacement character different from the original.

    Args:
        original: The single character being replaced.

    Returns:
        A single character: another digit for a digit, an ASCII letter that
        differs case-insensitively for a letter, a tab or newline for a
        space, and ``'x'`` for anything else.
    """
    if original.isdigit():
        return random.choice([c for c in string.digits if c != original])
    if original.isalpha():
        return random.choice(
            [c for c in string.ascii_letters if c.lower() != original.lower()]
        )
    if original == ' ':
        # Real whitespace characters, not the two-character text "\t"/"\n":
        # the caller substitutes exactly one position in the string.
        return random.choice(['\t', '\n'])
    return 'x'
|
|
|
|
|
|
def get_opposite_char_class(char: str) -> str:
    """Return a character belonging to a different class than ``char``.

    Digits map to a random ASCII letter and letters to a random digit; a
    space maps to ``'x'`` and every other input to ``'1'``.
    """
    if char.isdigit():
        pool = string.ascii_letters
    elif char.isalpha():
        pool = string.digits
    else:
        # Neither digit nor letter: fixed answers, no randomness involved.
        return 'x' if char == ' ' else '1'
    return random.choice(pool)
|
|
|
|
|
|
def matches_pattern(pattern: str, text: str, flavor: str) -> bool:
    """Report whether ``pattern`` is found anywhere in ``text``.

    The ``"javascript"`` and ``"pcre"`` flavors are searched with
    ``re.MULTILINE``; every other flavor uses no flags. A pattern that
    ``re`` cannot compile is reported as not matching.
    """
    import re

    multiline_flavors = ("javascript", "pcre")
    flags = re.MULTILINE if flavor in multiline_flavors else 0
    try:
        found = re.compile(pattern, flags).search(text)
    except re.error:
        return False
    return found is not None
|
|
|
|
|
|
def generate_test_cases(
    pattern: str,
    flavor: str = "pcre",
    matching_count: int = 5,
    non_matching_count: int = 5
) -> dict:
    """Build matching and non-matching samples for ``pattern`` in one call.

    Returns a dict with the keys ``"pattern"``, ``"flavor"``, ``"matching"``
    and ``"non_matching"``; the last two hold the lists produced by a
    ``TestCaseGenerator`` created for ``flavor``.
    """
    generator = TestCaseGenerator(flavor)

    # Matching samples are generated before non-matching ones.
    matching = generator.generate_matching(pattern, matching_count)
    non_matching = generator.generate_non_matching(pattern, non_matching_count)

    return {
        "pattern": pattern,
        "flavor": flavor,
        "matching": matching,
        "non_matching": non_matching,
    }
|