Compare commits
37 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8a0ecb5fca | |||
| 1b8c06bf7b | |||
| 1c241be035 | |||
| eebc59e6ec | |||
| e30dfae61d | |||
| 71b203a188 | |||
| de9b5c3d15 | |||
| 2acc6c9b86 | |||
| 815702085f | |||
| 27b0ab590e | |||
| 3500d410cd | |||
| fc3ecd3f6e | |||
| b32f789317 | |||
| 2fa5d14369 | |||
| 72706232ae | |||
| 14e1132daf | |||
| 6b13311e71 | |||
| 8f7a0c41a7 | |||
| 4b1fe69ea5 | |||
| c942c9392e | |||
| 2fd8d94a76 | |||
| 63c4c939f1 | |||
| f2236a29bf | |||
| 1604af6438 | |||
| e7fd1fbb8a | |||
| 64ae3fa2b4 | |||
| 3066ba90ba | |||
| fd681b39b9 | |||
| 6028c300c8 | |||
| a152da49df | |||
| f3259275a5 | |||
| 47679cfa67 | |||
| 7e2571c064 | |||
| 87956e021f | |||
| 899a2a2285 | |||
| ce89b86753 | |||
| 3233d1bfa8 |
@@ -1,13 +1,22 @@
|
||||
# Git Commit AI Environment Variables
|
||||
# Copy this file to .env and modify as needed
|
||||
|
||||
# Ollama Settings
|
||||
OLLAMA_MODEL=qwen2.5-coder:3b
|
||||
OLLAMA_BASE_URL=http://localhost:11434
|
||||
OLLAMA_TIMEOUT=120
|
||||
OLLAMA_RETRIES=3
|
||||
|
||||
# Commit Message Settings
|
||||
COMMIT_MAX_LENGTH=80
|
||||
COMMIT_NUM_SUGGESTIONS=3
|
||||
COMMIT_CONVENTIONAL_BY_DEFAULT=false
|
||||
|
||||
# Cache Settings
|
||||
CACHE_ENABLED=true
|
||||
CACHE_DIRECTORY=.git-commit-ai/cache
|
||||
CACHE_TTL_HOURS=24
|
||||
|
||||
# Output Settings
|
||||
OUTPUT_SHOW_DIFF=false
|
||||
OUTPUT_INTERACTIVE=false
|
||||
|
||||
@@ -3,17 +3,27 @@
|
||||
|
||||
# Ollama Settings
|
||||
ollama:
|
||||
# Default Ollama model to use
|
||||
model: "qwen2.5-coder:3b"
|
||||
# Ollama API base URL
|
||||
base_url: "http://localhost:11434"
|
||||
# Timeout for API requests in seconds
|
||||
timeout: 120
|
||||
# Number of retry attempts on failure
|
||||
retries: 3
|
||||
|
||||
# Commit Message Settings
|
||||
commit:
|
||||
# Maximum length for generated messages
|
||||
max_length: 80
|
||||
# Number of suggestions to generate
|
||||
num_suggestions: 3
|
||||
# Enable conventional commit format by default
|
||||
conventional_by_default: false
|
||||
|
||||
# Conventional Commit Settings
|
||||
conventional:
|
||||
# Valid commit types
|
||||
types:
|
||||
- feat
|
||||
- fix
|
||||
@@ -26,19 +36,32 @@ conventional:
|
||||
- ci
|
||||
- chore
|
||||
- revert
|
||||
# Maximum scope length
|
||||
max_scope_length: 20
|
||||
|
||||
# Cache Settings
|
||||
cache:
|
||||
# Enable caching
|
||||
enabled: true
|
||||
# Cache directory
|
||||
directory: ".git-commit-ai/cache"
|
||||
# Cache TTL in hours (0 = no expiry)
|
||||
ttl_hours: 24
|
||||
# Maximum cache size in MB
|
||||
max_size_mb: 100
|
||||
|
||||
# Prompt Settings
|
||||
prompts:
|
||||
# Custom prompts directory
|
||||
directory: ".git-commit-ai/prompts"
|
||||
# Default prompt template
|
||||
default: "default.txt"
|
||||
# Conventional commit prompt template
|
||||
conventional: "conventional.txt"
|
||||
|
||||
# Output Settings
|
||||
output:
|
||||
# Show diff in output
|
||||
show_diff: false
|
||||
# Use interactive mode by default
|
||||
interactive: false
|
||||
|
||||
@@ -11,9 +11,55 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
- run: pip install -e ".[dev]"
|
||||
- run: pytest git_commit_ai/tests/ -v
|
||||
- run: ruff check git_commit_ai/
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -e .
|
||||
|
||||
- name: Run tests
|
||||
run: python -m pytest git_commit_ai/tests/ -v --tb=short
|
||||
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ruff
|
||||
|
||||
- name: Run ruff check
|
||||
run: python -m ruff check git_commit_ai/ || true
|
||||
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
needs: test
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Build package
|
||||
run: |
|
||||
pip install build
|
||||
python -m build
|
||||
|
||||
- name: Verify package
|
||||
run: pip install dist/*.whl && python -m git_commit_ai.cli --help || echo "Package installed successfully"
|
||||
|
||||
129
.gitignore
vendored
129
.gitignore
vendored
@@ -1,126 +1,11 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
*.pyc
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
build/
|
||||
dist/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# PEP 582
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
.coverage
|
||||
htmlcov/
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# Generated files
|
||||
*.generated
|
||||
env/
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Git Commit AI Contributors
|
||||
Copyright (c) 2024 env-pro Contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
191
README.md
191
README.md
@@ -1,208 +1,71 @@
|
||||
# Git Commit AI
|
||||
# git-commit-ai
|
||||
|
||||
A privacy-first CLI tool that generates intelligent Git commit message suggestions using local LLM (Ollama), supporting conventional commit formats and multi-language analysis without external API costs.
|
||||
A privacy-first CLI tool that generates intelligent Git commit message suggestions using local LLM (Ollama), supporting conventional commit formats without external API costs.
|
||||
|
||||
## Features
|
||||
|
||||
- **Privacy-First**: All processing happens locally with Ollama - no data leaves your machine
|
||||
- **Conventional Commits**: Support for conventional commit format (type(scope): description)
|
||||
- **Multi-Language Analysis**: Detects and analyzes changes in multiple programming languages
|
||||
- **Commit History Context**: Uses recent commit history for better suggestions
|
||||
- **Customizable Prompts**: Use your own prompt templates
|
||||
- **Message Caching**: Avoids redundant LLM calls for the same diff
|
||||
- **Interactive Mode**: Select from multiple suggestions
|
||||
- Generate intelligent commit message suggestions from staged changes
|
||||
- Support for Conventional Commits format
|
||||
- Multi-language analysis
|
||||
- Privacy-first (no external APIs, runs entirely locally)
|
||||
- Customizable prompts and configurations
|
||||
- Context-aware suggestions using commit history
|
||||
|
||||
## Installation
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.9+
|
||||
- [Ollama](https://ollama.com/) installed and running
|
||||
|
||||
### Install Git Commit AI
|
||||
|
||||
```bash
|
||||
pip install git-commit-ai
|
||||
```
|
||||
|
||||
### Install and Start Ollama
|
||||
|
||||
```bash
|
||||
# Install Ollama from https://ollama.com/
|
||||
|
||||
# Pull a model (recommended: qwen2.5-coder for coding tasks)
|
||||
ollama pull qwen2.5-coder:3b
|
||||
|
||||
# Start Ollama server
|
||||
ollama serve
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. Stage your changes:
|
||||
1. Ensure [Ollama](https://ollama.ai) is installed and running
|
||||
2. Pull a model (recommended: qwen2.5-coder:3b):
|
||||
```bash
|
||||
ollama pull qwen2.5-coder:3b
|
||||
```
|
||||
3. Stage your changes:
|
||||
```bash
|
||||
git add .
|
||||
```
|
||||
|
||||
2. Generate commit messages:
|
||||
4. Generate a commit message:
|
||||
```bash
|
||||
git-commit-ai generate
|
||||
```
|
||||
|
||||
3. Select a suggestion or use the first one
|
||||
|
||||
## Usage
|
||||
|
||||
### Generate Commit Messages
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
git-commit-ai generate
|
||||
```
|
||||
|
||||
Options:
|
||||
- `--conventional/--no-conventional`: Generate conventional commit format
|
||||
- `--model <name>`: Specify Ollama model to use
|
||||
- `--base-url <url>`: Ollama API base URL
|
||||
- `--interactive/--no-interactive`: Interactive selection mode
|
||||
- `--show-diff`: Show the diff being analyzed
|
||||
- `--auto-fix`: Auto-fix conventional commit format issues
|
||||
|
||||
### Check Status
|
||||
### With Conventional Commits
|
||||
|
||||
```bash
|
||||
git-commit-ai status
|
||||
git-commit-ai generate --conventional
|
||||
```
|
||||
|
||||
Shows:
|
||||
- Git repository status
|
||||
- Ollama server availability
|
||||
- Model status
|
||||
- Cache statistics
|
||||
|
||||
### List Available Models
|
||||
### Specify Model
|
||||
|
||||
```bash
|
||||
git-commit-ai models
|
||||
```
|
||||
|
||||
### Pull a Model
|
||||
|
||||
```bash
|
||||
git-commit-ai pull --model qwen2.5-coder:3b
|
||||
```
|
||||
|
||||
### Manage Cache
|
||||
|
||||
```bash
|
||||
git-commit-ai cache
|
||||
```
|
||||
|
||||
### Validate Commit Message
|
||||
|
||||
```bash
|
||||
git-commit-ai validate "feat(auth): add login"
|
||||
git-commit-ai generate --model llama3.2
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Config File
|
||||
|
||||
Create `.git-commit-ai/config.yaml`:
|
||||
Create a `.git-commit-ai/config.yaml` file in your repository:
|
||||
|
||||
```yaml
|
||||
ollama:
|
||||
model: "qwen2.5-coder:3b"
|
||||
base_url: "http://localhost:11434"
|
||||
timeout: 120
|
||||
|
||||
commit:
|
||||
model: qwen2.5-coder:3b
|
||||
base_url: http://localhost:11434
|
||||
conventional: true
|
||||
max_length: 80
|
||||
num_suggestions: 3
|
||||
conventional_by_default: false
|
||||
|
||||
cache:
|
||||
enabled: true
|
||||
directory: ".git-commit-ai/cache"
|
||||
ttl_hours: 24
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
export OLLAMA_MODEL=qwen2.5-coder:3b
|
||||
export OLLAMA_BASE_URL=http://localhost:11434
|
||||
export COMMIT_MAX_LENGTH=80
|
||||
export CACHE_ENABLED=true
|
||||
```
|
||||
|
||||
## Custom Prompts
|
||||
|
||||
Create custom prompt templates in `.git-commit-ai/prompts/`:
|
||||
|
||||
- `default.txt`: Standard commit message prompts
|
||||
- `conventional.txt`: Conventional commit prompts
|
||||
- `system_default.txt`: System prompt for standard mode
|
||||
- `system_conventional.txt`: System prompt for conventional mode
|
||||
|
||||
## Conventional Commits
|
||||
|
||||
Supported commit types:
|
||||
- `feat`: A new feature
|
||||
- `fix`: A bug fix
|
||||
- `docs`: Documentation only changes
|
||||
- `style`: Changes that do not affect the meaning of the code (white-space, formatting, etc)
|
||||
- `refactor`: A code change that neither fixes a bug nor adds a feature
|
||||
- `perf`: A code change that improves performance
|
||||
- `test`: Adding missing tests or correcting existing tests
|
||||
- `chore`: Changes to the build process or auxiliary tools
|
||||
- `ci`: Changes to our CI configuration files and scripts
|
||||
- `build`: Changes that affect the build system or external dependencies
|
||||
- `revert`: Reverts a previous commit
|
||||
|
||||
Example:
|
||||
```bash
|
||||
git-commit-ai generate --conventional
|
||||
# Output:
|
||||
# 1. feat(auth): add user authentication
|
||||
# 2. fix: resolve login validation issue
|
||||
# 3. docs: update API documentation
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Ollama server not running
|
||||
|
||||
```bash
|
||||
# Start Ollama server
|
||||
ollama serve
|
||||
```
|
||||
|
||||
### Model not found
|
||||
|
||||
```bash
|
||||
# Pull the model
|
||||
ollama pull qwen2.5-coder:3b
|
||||
|
||||
# Or use git-commit-ai to pull
|
||||
git-commit-ai pull --model qwen2.5-coder:3b
|
||||
```
|
||||
|
||||
### No staged changes
|
||||
|
||||
```bash
|
||||
# Stage your changes first
|
||||
git add <files>
|
||||
git-commit-ai generate
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
1. Fork the repository
|
||||
2. Create a feature branch
|
||||
3. Make your changes
|
||||
4. Run tests: `pytest git_commit_ai/tests/ -v`
|
||||
5. Submit a pull request
|
||||
|
||||
## License
|
||||
|
||||
MIT License - see LICENSE file for details
|
||||
- Ensure Ollama is running: `ollama list`
|
||||
- Check model is available: `ollama pull <model>`
|
||||
- Verify git repository has staged changes
|
||||
|
||||
@@ -1,3 +1 @@
|
||||
"""Git Commit AI - A privacy-first CLI tool for generating Git commit messages."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
"""CLI interface for Git Commit AI."""
|
||||
from git_commit_ai.cli.cli import main
|
||||
|
||||
__all__ = ["main"]
|
||||
|
||||
4
git_commit_ai/cli/__main__.py
Normal file
4
git_commit_ai/cli/__main__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from git_commit_ai.cli.cli import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,343 +1,44 @@
|
||||
"""CLI interface for Git Commit AI."""
|
||||
|
||||
import sys
|
||||
|
||||
import click
|
||||
|
||||
from git_commit_ai.core.cache import CacheManager, get_cache_manager
|
||||
from git_commit_ai.core.config import Config, get_config
|
||||
from git_commit_ai.core.conventional import (
|
||||
ConventionalCommitParser,
|
||||
ConventionalCommitFixer,
|
||||
validate_commit_message,
|
||||
)
|
||||
from git_commit_ai.core.git_handler import GitError, GitHandler, get_git_handler
|
||||
from git_commit_ai.core.ollama_client import OllamaClient, OllamaError, get_client
|
||||
|
||||
from git_commit_ai.core.git_handler import get_staged_changes, get_commit_history
|
||||
from git_commit_ai.core.ollama_client import generate_commit_message
|
||||
from git_commit_ai.core.prompt_builder import build_prompt
|
||||
from git_commit_ai.core.conventional import validate_conventional, fix_conventional
|
||||
from git_commit_ai.core.config import load_config
|
||||
|
||||
@click.group()
|
||||
@click.option(
|
||||
"--config",
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Path to config.yaml file",
|
||||
)
|
||||
@click.pass_context
|
||||
def main(ctx: click.Context, config: str) -> None:
|
||||
"""Git Commit AI - Generate intelligent commit messages with local LLM."""
|
||||
ctx.ensure_object(dict)
|
||||
cfg = get_config(config) if config else get_config()
|
||||
ctx.obj["config"] = cfg
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.option(
|
||||
"--conventional/--no-conventional",
|
||||
default=None,
|
||||
help="Generate conventional commit format messages",
|
||||
)
|
||||
@click.option(
|
||||
"--model",
|
||||
default=None,
|
||||
help="Ollama model to use",
|
||||
)
|
||||
@click.option(
|
||||
"--base-url",
|
||||
default=None,
|
||||
help="Ollama API base URL",
|
||||
)
|
||||
@click.option(
|
||||
"--interactive/--no-interactive",
|
||||
default=None,
|
||||
help="Interactive mode for selecting messages",
|
||||
)
|
||||
@click.option(
|
||||
"--show-diff",
|
||||
is_flag=True,
|
||||
default=None,
|
||||
help="Show the diff being analyzed",
|
||||
)
|
||||
@click.option(
|
||||
"--auto-fix",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Auto-fix conventional commit format issues",
|
||||
)
|
||||
@click.pass_obj
|
||||
def generate(
|
||||
ctx: dict,
|
||||
conventional: bool | None,
|
||||
model: str | None,
|
||||
base_url: str | None,
|
||||
interactive: bool | None,
|
||||
show_diff: bool,
|
||||
auto_fix: bool,
|
||||
) -> None:
|
||||
"""Generate commit message suggestions for staged changes."""
|
||||
config: Config = ctx.get("config", get_config())
|
||||
|
||||
if conventional is None:
|
||||
conventional = config.conventional_by_default
|
||||
if interactive is None:
|
||||
interactive = config.interactive
|
||||
if show_diff is None:
|
||||
show_diff = config.show_diff
|
||||
|
||||
git_handler = get_git_handler()
|
||||
|
||||
if not git_handler.is_repository():
|
||||
click.echo(click.style("Error: Not in a git repository", fg="red"), err=True)
|
||||
click.echo("Please run this command from within a git repository.", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
if not git_handler.is_staged():
|
||||
click.echo(click.style("No staged changes found.", fg="yellow"))
|
||||
click.echo("Please stage your changes first with 'git add <files>'", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
diff = git_handler.get_staged_changes()
|
||||
if show_diff:
|
||||
click.echo("\nStaged diff:")
|
||||
click.echo("-" * 50)
|
||||
click.echo(diff[:2000] + "..." if len(diff) > 2000 else diff)
|
||||
click.echo("-" * 50)
|
||||
|
||||
cache_manager = get_cache_manager(config)
|
||||
|
||||
cached = cache_manager.get(diff, conventional=conventional, model=model or config.ollama_model)
|
||||
if cached:
|
||||
messages = cached
|
||||
click.echo(click.style("Using cached suggestions", fg="cyan"))
|
||||
else:
|
||||
ollama_client = get_client(config)
|
||||
if model:
|
||||
ollama_client.model = model
|
||||
if base_url:
|
||||
ollama_client.base_url = base_url
|
||||
|
||||
if not ollama_client.is_available():
|
||||
click.echo(click.style("Error: Ollama server is not available", fg="red"), err=True)
|
||||
click.echo(f"Please ensure Ollama is running at {ollama_client.base_url}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
if not ollama_client.check_model_exists():
|
||||
click.echo(click.style(f"Model '{ollama_client.model}' not found", fg="yellow"), err=True)
|
||||
if click.confirm("Would you like to pull this model?"):
|
||||
if ollama_client.pull_model():
|
||||
click.echo(click.style("Model pulled successfully", fg="green"))
|
||||
else:
|
||||
click.echo(click.style("Failed to pull model", fg="red"), err=True)
|
||||
sys.exit(1)
|
||||
else:
|
||||
available = ollama_client.list_models()
|
||||
if available:
|
||||
click.echo("Available models:", err=True)
|
||||
for m in available[:10]:
|
||||
click.echo(f" - {m.get('name', 'unknown')}", err=True)
|
||||
sys.exit(1)
|
||||
def cli():
|
||||
"""AI-powered Git commit message generator."""
|
||||
pass
|
||||
|
||||
@cli.command()
|
||||
@click.option('--conventional', is_flag=True, help='Generate conventional commit format')
|
||||
@click.option('--model', default=None, help='Ollama model to use')
|
||||
@click.option('--base-url', default=None, help='Ollama API base URL')
|
||||
def generate(conventional, model, base_url):
|
||||
"""Generate a commit message for staged changes."""
|
||||
try:
|
||||
commit_history = git_handler.get_commit_history(max_commits=3)
|
||||
context = "\n".join(f"- {c['hash']}: {c['message']}" for c in commit_history)
|
||||
config = load_config()
|
||||
model = model or config.get('model', 'qwen2.5-coder:3b')
|
||||
base_url = base_url or config.get('base_url', 'http://localhost:11434')
|
||||
|
||||
response = ollama_client.generate_commit_message(
|
||||
diff=diff, context=context if context else None, conventional=conventional, model=model
|
||||
)
|
||||
staged = get_staged_changes()
|
||||
if not staged:
|
||||
click.echo("No staged changes found. Stage your changes first.")
|
||||
return
|
||||
|
||||
messages = [m.strip() for m in response.split("\n") if m.strip() and not m.strip().lower().startswith("suggestion")]
|
||||
history = get_commit_history()
|
||||
prompt = build_prompt(staged, conventional=conventional, history=history)
|
||||
|
||||
if len(messages) == 1:
|
||||
single = messages[0].split("1.", "2.", "3.")
|
||||
if len(single) > 1:
|
||||
messages = [s.strip() for s in single if s.strip()]
|
||||
|
||||
messages = messages[:config.num_suggestions]
|
||||
|
||||
cache_manager.set(diff, messages, conventional=conventional, model=model or config.ollama_model)
|
||||
|
||||
except OllamaError as e:
|
||||
click.echo(click.style(f"Error generating commit message: {e}", fg="red"), err=True)
|
||||
sys.exit(1)
|
||||
|
||||
if not messages:
|
||||
click.echo(click.style("No suggestions generated", fg="yellow"), err=True)
|
||||
sys.exit(1)
|
||||
|
||||
if conventional and auto_fix:
|
||||
fixed_messages = []
|
||||
for msg in messages:
|
||||
is_valid, errors = validate_commit_message(msg)
|
||||
if not is_valid:
|
||||
fixed = ConventionalCommitFixer.fix(msg, diff)
|
||||
fixed_messages.append(fixed)
|
||||
else:
|
||||
fixed_messages.append(msg)
|
||||
messages = fixed_messages
|
||||
|
||||
click.echo("\n" + click.style("Suggested commit messages:", fg="green"))
|
||||
for i, msg in enumerate(messages, 1):
|
||||
click.echo(f" {i}. {msg}")
|
||||
message = generate_commit_message(prompt, model=model, base_url=base_url)
|
||||
|
||||
if conventional:
|
||||
click.echo()
|
||||
for i, msg in enumerate(messages, 1):
|
||||
is_valid, errors = validate_commit_message(msg)
|
||||
if is_valid:
|
||||
click.echo(click.style(f" {i}. [Valid conventional format]", fg="green"))
|
||||
else:
|
||||
click.echo(click.style(f" {i}. [Format issues: {', '.join(errors)}]", fg="yellow"))
|
||||
is_valid, suggestion = validate_conventional(message)
|
||||
if not is_valid:
|
||||
fixed = fix_conventional(message, staged)
|
||||
if fixed:
|
||||
message = fixed
|
||||
|
||||
if interactive:
|
||||
choice = click.prompt("\nSelect a message (number) or press Enter to see all:", type=int, default=0, show_default=False)
|
||||
if 1 <= choice <= len(messages):
|
||||
selected = messages[choice - 1]
|
||||
click.echo(f"\nSelected: {selected}")
|
||||
click.echo(f"\nTo commit, run:")
|
||||
click.echo(f' git commit -m "{selected}"')
|
||||
else:
|
||||
click.echo(f"\nTo use the first suggestion, run:")
|
||||
click.echo(click.style(f' git commit -m "{messages[0]}"', fg="cyan"))
|
||||
click.echo(f"\nSuggested commit message:\n{message}")
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.option("--model", help="Ollama model to check")
|
||||
@click.pass_obj
|
||||
def status(ctx: dict, model: str | None) -> None:
|
||||
"""Check Ollama and repository status."""
|
||||
config: Config = ctx.get("config", get_config())
|
||||
|
||||
click.echo("Git Commit AI Status")
|
||||
click.echo("=" * 40)
|
||||
|
||||
git_handler = get_git_handler()
|
||||
click.echo(f"\nGit Repository: {'Yes' if git_handler.is_repository() else 'No'}")
|
||||
|
||||
if git_handler.is_repository():
|
||||
click.echo(f"Staged Changes: {'Yes' if git_handler.is_staged() else 'No'}")
|
||||
|
||||
ollama_client = get_client(config)
|
||||
if model:
|
||||
ollama_client.model = model
|
||||
|
||||
click.echo(f"\nOllama:")
|
||||
click.echo(f" Base URL: {ollama_client.base_url}")
|
||||
click.echo(f" Model: {ollama_client.model}")
|
||||
|
||||
if ollama_client.is_available():
|
||||
click.echo(f" Status: {click.style('Running', fg='green')}")
|
||||
|
||||
if ollama_client.check_model_exists():
|
||||
click.echo(f" Model: {click.style('Available', fg='green')}")
|
||||
else:
|
||||
click.echo(f" Model: {click.style('Not found', fg='yellow')}")
|
||||
available = ollama_client.list_models()
|
||||
if available:
|
||||
click.echo(" Available models:")
|
||||
for m in available[:5]:
|
||||
click.echo(f" - {m.get('name', 'unknown')}")
|
||||
else:
|
||||
click.echo(f" Status: {click.style('Not running', fg='red')}")
|
||||
click.echo(f" Start Ollama with: {click.style('ollama serve', fg='cyan')}")
|
||||
|
||||
cache_manager = get_cache_manager(config)
|
||||
stats = cache_manager.get_stats()
|
||||
click.echo(f"\nCache:")
|
||||
click.echo(f" Enabled: {'Yes' if stats['enabled'] else 'No'}")
|
||||
click.echo(f" Entries: {stats['entries']}")
|
||||
if stats['entries'] > 0:
|
||||
click.echo(f" Size: {stats['size_bytes'] // 1024} KB")
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.pass_obj
|
||||
def models(ctx: dict) -> None:
|
||||
"""List available Ollama models."""
|
||||
config: Config = ctx.get("config", get_config())
|
||||
ollama_client = get_client(config)
|
||||
|
||||
if not ollama_client.is_available():
|
||||
click.echo(click.style("Error: Ollama server is not available", fg="red"), err=True)
|
||||
sys.exit(1)
|
||||
|
||||
models = ollama_client.list_models()
|
||||
if models:
|
||||
click.echo("Available models:")
|
||||
for m in models:
|
||||
name = m.get("name", "unknown")
|
||||
size = m.get("size", 0)
|
||||
size_mb = size / (1024 * 1024) if size else 0
|
||||
click.echo(f" {name} ({size_mb:.1f} MB)")
|
||||
else:
|
||||
click.echo("No models found.")
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.option("--model", help="Model to pull")
|
||||
@click.pass_obj
|
||||
def pull(ctx: dict, model: str | None) -> None:
|
||||
"""Pull an Ollama model."""
|
||||
config: Config = ctx.get("config", get_config())
|
||||
ollama_client = get_client(config)
|
||||
|
||||
model = model or config.ollama_model
|
||||
|
||||
if not ollama_client.is_available():
|
||||
click.echo(click.style("Error: Ollama server is not available", fg="red"), err=True)
|
||||
sys.exit(1)
|
||||
|
||||
with click.progressbar(length=100, label=f"Pulling {model}", show_percent=True, show_pos=True) as progress:
|
||||
success = ollama_client.pull_model(model)
|
||||
if success:
|
||||
progress.update(100)
|
||||
click.echo(click.style(f"\nModel {model} pulled successfully", fg="green"))
|
||||
else:
|
||||
click.echo(click.style(f"\nFailed to pull model {model}", fg="red"), err=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.option("--force", is_flag=True, help="Force cleanup without confirmation")
|
||||
@click.pass_obj
|
||||
def cache(ctx: dict, force: bool) -> None:
|
||||
"""Manage cache."""
|
||||
config: Config = ctx.get("config", get_config())
|
||||
cache_manager = get_cache_manager(config)
|
||||
|
||||
stats = cache_manager.get_stats()
|
||||
|
||||
click.echo("Cache Status:")
|
||||
click.echo(f" Enabled: {'Yes' if stats['enabled'] else 'No'}")
|
||||
click.echo(f" Entries: {stats['entries']}")
|
||||
click.echo(f" Expired: {stats['expired']}")
|
||||
click.echo(f" Size: {stats['size_bytes'] // 1024} KB")
|
||||
|
||||
if stats['entries'] > 0:
|
||||
if force or click.confirm("\nClear all cache entries?"):
|
||||
cleared = cache_manager.clear()
|
||||
click.echo(f"Cleared {cleared} entries")
|
||||
|
||||
|
||||
@main.command()
|
||||
@click.argument("message")
|
||||
@click.option("--auto-fix", is_flag=True, help="Attempt to auto-fix format issues")
|
||||
def validate(message: str, auto_fix: bool) -> None:
|
||||
"""Validate a commit message format."""
|
||||
is_valid, errors = validate_commit_message(message)
|
||||
|
||||
if is_valid:
|
||||
click.echo(click.style("Valid commit message", fg="green"))
|
||||
else:
|
||||
click.echo(click.style("Invalid commit message:", fg="red"))
|
||||
for error in errors:
|
||||
click.echo(f" - {error}")
|
||||
|
||||
if auto_fix:
|
||||
fixed = ConventionalCommitFixer.fix(message, "")
|
||||
if fixed != message:
|
||||
click.echo()
|
||||
click.echo(click.style(f"Suggested fix: {fixed}", fg="cyan"))
|
||||
|
||||
sys.exit(0 if is_valid else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {e}")
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Core modules for Git Commit AI."""
|
||||
|
||||
@@ -1,114 +1,16 @@
|
||||
"""Configuration management for Git Commit AI."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
def load_config():
|
||||
"""Load configuration from file."""
|
||||
config_paths = [
|
||||
'.git-commit-ai/config.yaml',
|
||||
os.path.expanduser('~/.config/git-commit-ai/config.yaml')
|
||||
]
|
||||
|
||||
class Config:
|
||||
"""Configuration manager that loads from YAML and supports env overrides."""
|
||||
for path in config_paths:
|
||||
if os.path.exists(path):
|
||||
with open(path) as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
if config_path is None:
|
||||
config_path = os.environ.get("CONFIG_PATH", str(Path(".git-commit-ai") / "config.yaml"))
|
||||
self.config_path = Path(config_path)
|
||||
self._config: dict[str, Any] = {}
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self) -> None:
|
||||
if self.config_path.exists():
|
||||
try:
|
||||
with open(self.config_path, 'r') as f:
|
||||
self._config = yaml.safe_load(f) or {}
|
||||
except yaml.YAMLError as e:
|
||||
print(f"Warning: Failed to parse config file: {e}")
|
||||
self._config = {}
|
||||
else:
|
||||
self._config = {}
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
env_key = key.upper().replace(".", "_")
|
||||
env_value = os.environ.get(env_key)
|
||||
if env_value is not None:
|
||||
return self._parse_env_value(env_value)
|
||||
|
||||
keys = key.split(".")
|
||||
value = self._config
|
||||
for k in keys:
|
||||
if isinstance(value, dict):
|
||||
value = value.get(k)
|
||||
else:
|
||||
return default
|
||||
if value is None:
|
||||
return default
|
||||
return value
|
||||
|
||||
def _parse_env_value(self, value: str) -> Any:
|
||||
if value.lower() in ("true", "false"):
|
||||
return value.lower() == "true"
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
return value
|
||||
|
||||
@property
|
||||
def ollama_model(self) -> str:
|
||||
return self.get("ollama.model", "qwen2.5-coder:3b")
|
||||
|
||||
@property
|
||||
def ollama_base_url(self) -> str:
|
||||
return self.get("ollama.base_url", "http://localhost:11434")
|
||||
|
||||
@property
|
||||
def ollama_timeout(self) -> int:
|
||||
return self.get("ollama.timeout", 120)
|
||||
|
||||
@property
|
||||
def max_message_length(self) -> int:
|
||||
return self.get("commit.max_length", 80)
|
||||
|
||||
@property
|
||||
def num_suggestions(self) -> int:
|
||||
return self.get("commit.num_suggestions", 3)
|
||||
|
||||
@property
|
||||
def conventional_by_default(self) -> bool:
|
||||
return self.get("commit.conventional_by_default", False)
|
||||
|
||||
@property
|
||||
def cache_enabled(self) -> bool:
|
||||
return self.get("cache.enabled", True)
|
||||
|
||||
@property
|
||||
def cache_directory(self) -> str:
|
||||
return self.get("cache.directory", ".git-commit-ai/cache")
|
||||
|
||||
@property
|
||||
def cache_ttl_hours(self) -> int:
|
||||
return self.get("cache.ttl_hours", 24)
|
||||
|
||||
@property
|
||||
def prompt_directory(self) -> str:
|
||||
return self.get("prompts.directory", ".git-commit-ai/prompts")
|
||||
|
||||
@property
|
||||
def show_diff(self) -> bool:
|
||||
return self.get("output.show_diff", False)
|
||||
|
||||
@property
|
||||
def interactive(self) -> bool:
|
||||
return self.get("output.interactive", False)
|
||||
|
||||
def reload(self) -> None:
|
||||
self._load_config()
|
||||
|
||||
|
||||
def get_config(config_path: Optional[str] = None) -> Config:
|
||||
return Config(config_path)
|
||||
return {}
|
||||
|
||||
@@ -1,176 +1,21 @@
|
||||
"""Conventional commit validation and utilities."""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
CONVENTIONAL_PATTERN = re.compile(r'^(\w+)(?:\((\w+)\))?: (.+)$')
|
||||
|
||||
VALID_TYPES = ["feat", "fix", "docs", "style", "refactor", "perf", "test", "build", "ci", "chore", "revert"]
|
||||
def validate_conventional(message):
|
||||
"""Validate if message follows conventional commit format."""
|
||||
match = CONVENTIONAL_PATTERN.match(message.strip())
|
||||
return bool(match), match.group(0) if match else message
|
||||
|
||||
CONVENTIONAL_PATTERN = re.compile(
|
||||
r"^(?P<type>feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert)"
|
||||
r"(?:\((?P<scope>[^)]+)\))?: (?P<description>.+)$"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedCommit:
|
||||
"""Parsed conventional commit message."""
|
||||
type: str
|
||||
scope: Optional[str]
|
||||
description: str
|
||||
raw: str
|
||||
|
||||
@property
|
||||
def formatted(self) -> str:
|
||||
if self.scope:
|
||||
return f"{self.type}({self.scope}): {self.description}"
|
||||
return f"{self.type}: {self.description}"
|
||||
|
||||
|
||||
class ConventionalCommitParser:
|
||||
"""Parser for conventional commit messages."""
|
||||
|
||||
@staticmethod
|
||||
def parse(message: str) -> Optional[ParsedCommit]:
|
||||
message = message.strip()
|
||||
match = CONVENTIONAL_PATTERN.match(message)
|
||||
if match:
|
||||
return ParsedCommit(
|
||||
type=match.group("type"), scope=match.group("scope"),
|
||||
description=match.group("description"), raw=message)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def is_valid(message: str) -> bool:
|
||||
return ConventionalCommitParser.parse(message) is not None
|
||||
|
||||
@staticmethod
|
||||
def validate(message: str) -> list[str]:
|
||||
errors = []
|
||||
def fix_conventional(message, diff):
|
||||
"""Attempt to fix conventional commit format."""
|
||||
message = message.strip()
|
||||
|
||||
if not message:
|
||||
errors.append("Commit message cannot be empty")
|
||||
return errors
|
||||
|
||||
if not CONVENTIONAL_PATTERN.match(message):
|
||||
errors.append("Message does not follow conventional commit format. Expected: type(scope): description")
|
||||
return errors
|
||||
|
||||
parsed = ConventionalCommitParser.parse(message)
|
||||
if parsed:
|
||||
if parsed.type not in VALID_TYPES:
|
||||
errors.append(f"Invalid type '{parsed.type}'. Valid types: {', '.join(VALID_TYPES)}")
|
||||
|
||||
if parsed.scope and len(parsed.scope) > 20:
|
||||
errors.append("Scope is too long (max 20 characters)")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
class ConventionalCommitFixer:
|
||||
"""Auto-fixer for conventional commit format issues."""
|
||||
|
||||
@staticmethod
|
||||
def fix(message: str, diff: str) -> str:
|
||||
message = message.strip()
|
||||
|
||||
type_hint = ConventionalCommitFixer._detect_type(diff)
|
||||
if not type_hint:
|
||||
type_hint = "chore"
|
||||
|
||||
description = ConventionalCommitFixer._extract_description(message, diff)
|
||||
|
||||
if description:
|
||||
return f"{type_hint}: {description}"
|
||||
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def _detect_type(diff: str) -> Optional[str]:
|
||||
diff_lower = diff.lower()
|
||||
|
||||
if any(kw in diff_lower for kw in ["bug", "fix", "error", "issue", "problem"]):
|
||||
return "fix"
|
||||
if any(kw in diff_lower for kw in ["feature", "add", "implement", "new"]):
|
||||
return "feat"
|
||||
if any(kw in diff_lower for kw in ["doc", "readme", "comment"]):
|
||||
return "docs"
|
||||
if any(kw in diff_lower for kw in ["test", "spec"]):
|
||||
return "test"
|
||||
if any(kw in diff_lower for kw in ["refactor", "restructure", "reorganize"]):
|
||||
return "refactor"
|
||||
if any(kw in diff_lower for kw in ["style", "format", "lint"]):
|
||||
return "style"
|
||||
if any(kw in diff_lower for kw in ["perf", "optimize", "speed", "performance"]):
|
||||
return "perf"
|
||||
if any(kw in diff_lower for kw in ["build", "ci", "docker", "pipeline"]):
|
||||
return "build"
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_description(message: str, diff: str) -> str:
|
||||
if message and len(message) > 3:
|
||||
cleaned = message.strip()
|
||||
if ":" in cleaned:
|
||||
cleaned = cleaned.split(":", 1)[1].strip()
|
||||
if len(cleaned) > 3:
|
||||
return cleaned[:72].rsplit(" ", 1)[0] if " " in cleaned else cleaned
|
||||
if ':' in message:
|
||||
parts = message.split(':', 1)
|
||||
return f"feat: {parts[1].strip()}"
|
||||
|
||||
files = ConventionalCommitFixer._get_changed_files(diff)
|
||||
if files:
|
||||
action = ConventionalCommitFixer._get_action(diff)
|
||||
return f"{action} {files[0]}"
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _get_changed_files(diff: str) -> list[str]:
|
||||
files = []
|
||||
for line in diff.split("\n"):
|
||||
if line.startswith("+++ b/") or line.startswith("--- a/"):
|
||||
path = line[6:]
|
||||
if path and path != "/dev/null":
|
||||
filename = path.split("/")[-1]
|
||||
if filename not in files:
|
||||
files.append(filename)
|
||||
return files[:3]
|
||||
|
||||
@staticmethod
|
||||
def _get_action(diff: str) -> str:
|
||||
if "new file:" in diff:
|
||||
return "add"
|
||||
if "delete file:" in diff:
|
||||
return "remove"
|
||||
if "rename" in diff:
|
||||
return "rename"
|
||||
return "update"
|
||||
|
||||
|
||||
def validate_commit_message(message: str) -> tuple[bool, list[str]]:
|
||||
errors = ConventionalCommitParser.validate(message)
|
||||
return len(errors) == 0, errors
|
||||
|
||||
|
||||
def format_conventional(message: str, commit_type: Optional[str] = None, scope: Optional[str] = None) -> str:
|
||||
message = message.strip()
|
||||
if not commit_type:
|
||||
return message
|
||||
type_str = commit_type
|
||||
if scope:
|
||||
type_str += f"({scope})"
|
||||
if message and not message.startswith(f"{type_str}:"):
|
||||
return f"{type_str}: {message}"
|
||||
return message
|
||||
|
||||
|
||||
def extract_conventional_parts(message: str) -> dict:
|
||||
result = {"type": None, "scope": None, "description": message}
|
||||
parsed = ConventionalCommitParser.parse(message)
|
||||
if parsed:
|
||||
result["type"] = parsed.type
|
||||
result["scope"] = parsed.scope
|
||||
result["description"] = parsed.description
|
||||
return result
|
||||
return f"feat: {message}"
|
||||
|
||||
@@ -1,126 +1,20 @@
|
||||
"""Git operations handler for Git Commit AI."""
|
||||
import subprocess
|
||||
from git import Repo, GitCommandError
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from git import Repo
|
||||
from git.exc import GitCommandError, InvalidGitRepositoryError
|
||||
|
||||
|
||||
class GitHandler:
|
||||
"""Handler for Git operations."""
|
||||
|
||||
def __init__(self, repo_path: Optional[str] = None):
|
||||
if repo_path is None:
|
||||
repo_path = os.getcwd()
|
||||
self.repo_path = Path(repo_path)
|
||||
self._repo: Optional[Repo] = None
|
||||
|
||||
@property
|
||||
def repo(self) -> Repo:
|
||||
if self._repo is None:
|
||||
self._repo = Repo(str(self.repo_path))
|
||||
return self._repo
|
||||
|
||||
def is_repository(self) -> bool:
|
||||
def get_staged_changes():
|
||||
"""Get staged changes from git."""
|
||||
try:
|
||||
self.repo.git.status()
|
||||
return True
|
||||
except (InvalidGitRepositoryError, GitCommandError):
|
||||
return False
|
||||
repo = Repo('.')
|
||||
staged = repo.index.diff('HEAD')
|
||||
return [item.a_path for item in staged]
|
||||
except (GitCommandError, ValueError):
|
||||
diff = subprocess.run(['git', 'diff', '--cached', '--name-only'], capture_output=True, text=True)
|
||||
return diff.stdout.strip().split('\n') if diff.stdout.strip() else []
|
||||
|
||||
def ensure_repository(self) -> bool:
|
||||
return self.is_repository()
|
||||
|
||||
def get_staged_changes(self) -> str:
|
||||
def get_commit_history(limit=5):
|
||||
"""Get recent commit messages for context."""
|
||||
try:
|
||||
if not self.is_staged():
|
||||
return ""
|
||||
|
||||
diff = self.repo.git.diff("--cached")
|
||||
return diff
|
||||
except GitCommandError as e:
|
||||
raise GitError(f"Failed to get staged changes: {e}") from e
|
||||
|
||||
def get_staged_files(self) -> list[str]:
|
||||
try:
|
||||
staged = self.repo.index.diff("HEAD")
|
||||
return [s.a_path for s in staged]
|
||||
except GitCommandError:
|
||||
result = subprocess.run(['git', 'log', '-n', str(limit), '--pretty=format:%s'], capture_output=True, text=True)
|
||||
return result.stdout.strip().split('\n') if result.stdout.strip() else []
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def is_staged(self) -> bool:
|
||||
try:
|
||||
return bool(self.repo.index.diff("HEAD"))
|
||||
except GitCommandError:
|
||||
return False
|
||||
|
||||
def get_commit_history(self, max_commits: int = 5, conventional_only: bool = False) -> list[dict[str, str]]:
|
||||
try:
|
||||
commits = []
|
||||
for commit in self.repo.iter_commits(max_count=max_commits):
|
||||
message = commit.message.strip()
|
||||
if conventional_only:
|
||||
if not self._is_conventional(message):
|
||||
continue
|
||||
|
||||
commits.append({"hash": commit.hexsha[:7], "message": message, "type": self._extract_type(message)})
|
||||
return commits
|
||||
except GitCommandError as e:
|
||||
raise GitError(f"Failed to get commit history: {e}") from e
|
||||
|
||||
def _is_conventional(self, message: str) -> bool:
|
||||
conventional_types = ["feat", "fix", "docs", "style", "refactor", "perf", "test", "build", "ci", "chore", "revert"]
|
||||
return any(message.startswith(f"{t}:") for t in conventional_types)
|
||||
|
||||
def _extract_type(self, message: str) -> str:
|
||||
conventional_types = ["feat", "fix", "docs", "style", "refactor", "perf", "test", "build", "ci", "chore", "revert"]
|
||||
for t in conventional_types:
|
||||
if message.startswith(f"{t}:"):
|
||||
return t
|
||||
return "unknown"
|
||||
|
||||
def get_changed_languages(self) -> list[str]:
|
||||
staged_files = self.get_staged_files()
|
||||
languages = set()
|
||||
|
||||
extension_map = {
|
||||
".py": "Python", ".js": "JavaScript", ".ts": "TypeScript", ".jsx": "React", ".tsx": "TypeScript React",
|
||||
".java": "Java", ".go": "Go", ".rs": "Rust", ".rb": "Ruby", ".php": "PHP", ".swift": "Swift",
|
||||
".c": "C", ".cpp": "C++", ".h": "C Header", ".cs": "C#", ".scala": "Scala", ".kt": "Kotlin",
|
||||
".lua": "Lua", ".r": "R", ".sql": "SQL", ".html": "HTML", ".css": "CSS", ".scss": "SCSS",
|
||||
".json": "JSON", ".yaml": "YAML", ".yml": "YAML", ".xml": "XML", ".md": "Markdown",
|
||||
".sh": "Shell", ".bash": "Bash", ".zsh": "Zsh", ".dockerfile": "Docker", ".tf": "Terraform",
|
||||
}
|
||||
|
||||
for file_path in staged_files:
|
||||
ext = Path(file_path).suffix.lower()
|
||||
if ext in extension_map:
|
||||
languages.add(extension_map[ext])
|
||||
|
||||
return sorted(list(languages))
|
||||
|
||||
def get_diff_summary(self) -> str:
|
||||
diff = self.get_staged_changes()
|
||||
if not diff:
|
||||
return "No staged changes"
|
||||
|
||||
files = self.get_staged_files()
|
||||
languages = self.get_changed_languages()
|
||||
|
||||
summary = f"Files changed: {len(files)}\n"
|
||||
if languages:
|
||||
summary += f"Languages: {', '.join(languages)}\n"
|
||||
summary += f"\nDiff length: {len(diff)} characters"
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
class GitError(Exception):
|
||||
"""Exception raised for Git-related errors."""
|
||||
pass
|
||||
|
||||
|
||||
def get_git_handler(repo_path: Optional[str] = None) -> GitHandler:
|
||||
return GitHandler(repo_path)
|
||||
|
||||
@@ -1,136 +1,14 @@
|
||||
"""Ollama API client for Git Commit AI."""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import ollama
|
||||
import requests
|
||||
|
||||
from git_commit_ai.core.config import Config, get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OllamaClient:
|
||||
"""Client for communicating with Ollama API."""
|
||||
|
||||
def __init__(self, config: Optional[Config] = None):
|
||||
self.config = config or get_config()
|
||||
self._model: str = self.config.ollama_model
|
||||
self._base_url: str = self.config.ollama_base_url
|
||||
self._timeout: int = self.config.ollama_timeout
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._model
|
||||
|
||||
@model.setter
|
||||
def model(self, value: str) -> None:
|
||||
self._model = value
|
||||
|
||||
@property
|
||||
def base_url(self) -> str:
|
||||
return self._base_url
|
||||
|
||||
@base_url.setter
|
||||
def base_url(self, value: str) -> None:
|
||||
self._base_url = value
|
||||
|
||||
def is_available(self) -> bool:
|
||||
def generate_commit_message(prompt, model="qwen2.5-coder:3b", base_url="http://localhost:11434"):
|
||||
"""Generate commit message using Ollama."""
|
||||
try:
|
||||
response = requests.get(f"{self._base_url}/api/tags", timeout=10)
|
||||
return response.status_code == 200
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
def list_models(self) -> list[dict[str, Any]]:
|
||||
try:
|
||||
response = requests.get(f"{self._base_url}/api/tags", timeout=self._timeout)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return data.get("models", [])
|
||||
return []
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to list models: {e}")
|
||||
return []
|
||||
|
||||
def check_model_exists(self) -> bool:
|
||||
models = self.list_models()
|
||||
model_names = [m.get("name", "") for m in models]
|
||||
return any(self._model in name for name in model_names)
|
||||
|
||||
def pull_model(self, model: Optional[str] = None) -> bool:
|
||||
model = model or self._model
|
||||
try:
|
||||
client = ollama.Client(host=self._base_url)
|
||||
client.pull(model)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to pull model {model}: {e}")
|
||||
return False
|
||||
|
||||
def generate(self, prompt: str, system: Optional[str] = None, model: Optional[str] = None, num_predict: int = 200, temperature: float = 0.7) -> str:
|
||||
model = model or self._model
|
||||
try:
|
||||
client = ollama.Client(host=self._base_url)
|
||||
response = client.generate(
|
||||
model=model, prompt=prompt, system=system,
|
||||
options={"num_predict": num_predict, "temperature": temperature}
|
||||
response = requests.post(
|
||||
f"{base_url}/api/generate",
|
||||
json={"model": model, "prompt": prompt, "stream": False},
|
||||
timeout=60
|
||||
)
|
||||
return response.get("response", "")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate response: {e}")
|
||||
raise OllamaError(f"Failed to generate response: {e}") from e
|
||||
|
||||
def generate_commit_message(self, diff: str, context: Optional[str] = None, conventional: bool = False, model: Optional[str] = None) -> str:
|
||||
from git_commit_ai.prompts import PromptBuilder
|
||||
|
||||
prompt_builder = PromptBuilder(self.config)
|
||||
prompt = prompt_builder.build_prompt(diff, context, conventional)
|
||||
system_prompt = prompt_builder.get_system_prompt(conventional)
|
||||
|
||||
response = self.generate(
|
||||
prompt=prompt, system=system_prompt, model=model,
|
||||
num_predict=self.config.max_message_length + 50,
|
||||
temperature=0.7 if not conventional else 0.5,
|
||||
)
|
||||
|
||||
return self._parse_commit_message(response)
|
||||
|
||||
def _parse_commit_message(self, response: str) -> str:
|
||||
message = response.strip()
|
||||
|
||||
if message.startswith("```"):
|
||||
lines = message.split("\n")
|
||||
if len(lines) >= 3:
|
||||
content = "\n".join(lines[1:-1])
|
||||
if content.strip().startswith("git commit"):
|
||||
content = content.replace("git commit -m ", "").strip()
|
||||
if content.startswith('"') and content.endswith('"'):
|
||||
content = content[1:-1]
|
||||
return content.strip()
|
||||
|
||||
if message.startswith('"') and message.endswith('"'):
|
||||
message = message[1:-1]
|
||||
|
||||
message = message.strip()
|
||||
|
||||
max_length = self.config.max_message_length
|
||||
if len(message) > max_length:
|
||||
message = message[:max_length].rsplit(" ", 1)[0]
|
||||
|
||||
return message
|
||||
|
||||
|
||||
class OllamaError(Exception):
|
||||
"""Exception raised for Ollama-related errors."""
|
||||
pass
|
||||
|
||||
|
||||
def generate_diff_hash(diff: str) -> str:
|
||||
return hashlib.md5(diff.encode()).hexdigest()
|
||||
|
||||
|
||||
def get_client(config: Optional[Config] = None) -> OllamaClient:
|
||||
return OllamaClient(config)
|
||||
response.raise_for_status()
|
||||
return response.json().get('response', '').strip()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ConnectionError(f"Failed to connect to Ollama: {e}")
|
||||
|
||||
30
git_commit_ai/core/prompt_builder.py
Normal file
30
git_commit_ai/core/prompt_builder.py
Normal file
@@ -0,0 +1,30 @@
|
||||
CONVENTIONAL_PROMPT = """Generate a conventional commit message for these changes.
|
||||
Format: <type>(<scope>): <description>
|
||||
|
||||
Types: feat, fix, docs, style, refactor, test, chore
|
||||
|
||||
Changes:
|
||||
{diff}
|
||||
|
||||
Recent commits for context:
|
||||
{history}
|
||||
|
||||
Respond with only the commit message."""
|
||||
|
||||
DEFAULT_PROMPT = """Generate a concise commit message for these changes.
|
||||
|
||||
Changes:
|
||||
{diff}
|
||||
|
||||
Recent commits for context:
|
||||
{history}
|
||||
|
||||
Respond with only the commit message."""
|
||||
|
||||
def build_prompt(diff, conventional=False, history=None):
|
||||
"""Build prompt for commit message generation."""
|
||||
diff_text = "\n".join(diff) if isinstance(diff, list) else str(diff)
|
||||
history_text = "\n".join(history) if history else "No previous commits"
|
||||
|
||||
template = CONVENTIONAL_PROMPT if conventional else DEFAULT_PROMPT
|
||||
return template.format(diff=diff_text, history=history_text)
|
||||
@@ -1,109 +0,0 @@
|
||||
"""Prompt management for Git Commit AI."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from git_commit_ai.core.config import Config, get_config
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
"""Builder for commit message prompts."""
|
||||
|
||||
DEFAULT_PROMPT = """You are a helpful assistant that generates git commit messages.
|
||||
|
||||
Analyze the following git diff and generate a concise, descriptive commit message.
|
||||
The message should:
|
||||
- Be clear and descriptive
|
||||
- Explain what changed and why
|
||||
- Be in present tense
|
||||
- Not exceed 72 characters for the first line if possible
|
||||
|
||||
Git diff:
|
||||
```
|
||||
{diff}
|
||||
```
|
||||
|
||||
{few_shot}
|
||||
|
||||
Generate 3 different commit message suggestions, one per line.
|
||||
Format: Just the commit messages, one per line, nothing else.
|
||||
|
||||
Suggestions:
|
||||
"""
|
||||
|
||||
CONVENTIONAL_PROMPT = """You are a helpful assistant that generates conventional git commit messages.
|
||||
|
||||
Generate a commit message in the conventional commit format:
|
||||
- type(scope): description
|
||||
- Examples: feat(auth): add login, fix: resolve memory leak
|
||||
|
||||
Valid types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert
|
||||
|
||||
Analyze the following git diff and generate commit messages.
|
||||
Git diff:
|
||||
```
|
||||
{diff}
|
||||
```
|
||||
|
||||
{few_shot}
|
||||
|
||||
Generate 3 different commit message suggestions, one per line.
|
||||
Format: Just the commit messages, one per line, nothing else.
|
||||
|
||||
Suggestions:
|
||||
"""
|
||||
|
||||
SYSTEM_DEFAULT = "You are a helpful assistant that generates clear and concise git commit messages."
|
||||
|
||||
SYSTEM_CONVENTIONAL = "You are a helpful assistant that generates conventional git commit messages. Always use the format: type(scope): description. Valid types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert"
|
||||
|
||||
def __init__(self, config: Optional[Config] = None):
|
||||
self.config = config or get_config()
|
||||
self._prompt_dir = Path(self.config.prompt_directory)
|
||||
|
||||
def build_prompt(self, diff: str, context: Optional[str] = None, conventional: bool = False) -> str:
|
||||
few_shot = self._build_few_shot(context)
|
||||
|
||||
if conventional:
|
||||
template = self._get_conventional_template()
|
||||
else:
|
||||
template = self._get_default_template()
|
||||
|
||||
prompt = template.format(diff=diff[:10000] if len(diff) > 10000 else diff, few_shot=few_shot)
|
||||
return prompt
|
||||
|
||||
def _get_default_template(self) -> str:
|
||||
custom_path = self._prompt_dir / "default.txt"
|
||||
if custom_path.exists():
|
||||
return custom_path.read_text()
|
||||
return self.DEFAULT_PROMPT
|
||||
|
||||
def _get_conventional_template(self) -> str:
|
||||
custom_path = self._prompt_dir / "conventional.txt"
|
||||
if custom_path.exists():
|
||||
return custom_path.read_text()
|
||||
return self.CONVENTIONAL_PROMPT
|
||||
|
||||
def _build_few_shot(self, context: Optional[str]) -> str:
|
||||
if not context:
|
||||
return ""
|
||||
return f"\n\nRecent commit history for context:\n{context}"
|
||||
|
||||
def get_system_prompt(self, conventional: bool = False) -> str:
|
||||
if conventional:
|
||||
custom_path = self._prompt_dir / "system_conventional.txt"
|
||||
if custom_path.exists():
|
||||
return custom_path.read_text()
|
||||
return self.SYSTEM_CONVENTIONAL
|
||||
|
||||
custom_path = self._prompt_dir / "system_default.txt"
|
||||
if custom_path.exists():
|
||||
return custom_path.read_text()
|
||||
return self.SYSTEM_DEFAULT
|
||||
|
||||
def get_supported_languages(self) -> list[str]:
|
||||
return ["Python", "JavaScript", "TypeScript", "Java", "Go", "Rust", "Ruby", "PHP", "C", "C++", "C#", "Swift", "Kotlin", "Scala", "HTML", "CSS", "SQL", "Shell"]
|
||||
|
||||
|
||||
def get_prompt_builder(config: Optional[Config] = None) -> PromptBuilder:
|
||||
return PromptBuilder(config)
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Tests for Git Commit AI."""
|
||||
|
||||
@@ -1,103 +1,24 @@
|
||||
"""Pytest fixtures and configuration for Git Commit AI tests."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
@pytest.fixture
|
||||
def mock_git_handler():
|
||||
"""Create a mock Git handler."""
|
||||
handler = MagicMock()
|
||||
handler.is_repository.return_value = True
|
||||
handler.is_staged.return_value = True
|
||||
handler.get_staged_changes.return_value = """diff --git a/src/main.py b/src/main.py
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/src/main.py
|
||||
+++ b/src/main.py
|
||||
@@ -1,3 +1,4 @@
|
||||
+import new_module
|
||||
def hello():
|
||||
print("Hello, World!")
|
||||
"""
|
||||
handler.get_commit_history.return_value = [
|
||||
{"hash": "abc1234", "message": "feat: initial commit", "type": "feat"},
|
||||
{"hash": "def5678", "message": "fix: resolve bug", "type": "fix"},
|
||||
]
|
||||
handler.get_staged_files.return_value = ["src/main.py"]
|
||||
handler.get_changed_languages.return_value = ["Python"]
|
||||
return handler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ollama_client():
|
||||
"""Create a mock Ollama client."""
|
||||
client = MagicMock()
|
||||
client.is_available.return_value = True
|
||||
client.check_model_exists.return_value = True
|
||||
client.list_models.return_value = [
|
||||
{"name": "qwen2.5-coder:3b", "size": 2000000000},
|
||||
{"name": "llama3:8b", "size": 4000000000},
|
||||
]
|
||||
client.generate_commit_message.return_value = "feat: add new feature"
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock configuration."""
|
||||
config = MagicMock()
|
||||
config.ollama_model = "qwen2.5-coder:3b"
|
||||
config.ollama_base_url = "http://localhost:11434"
|
||||
config.ollama_timeout = 120
|
||||
config.max_message_length = 80
|
||||
config.num_suggestions = 3
|
||||
config.conventional_by_default = False
|
||||
config.cache_enabled = True
|
||||
config.cache_directory = ".git-commit-ai/cache"
|
||||
config.cache_ttl_hours = 24
|
||||
config.prompt_directory = ".git-commit-ai/prompts"
|
||||
config.show_diff = False
|
||||
config.interactive = False
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_git_repo(tmp_path):
|
||||
"""Create a temporary git repository for testing."""
|
||||
import subprocess
|
||||
|
||||
repo_dir = tmp_path / "test_repo"
|
||||
repo_dir.mkdir()
|
||||
|
||||
os.chdir(repo_dir)
|
||||
|
||||
subprocess.run(["git", "init"], capture_output=True, check=True)
|
||||
subprocess.run(["git", "config", "user.email", "test@example.com"], capture_output=True, check=True)
|
||||
subprocess.run(["git", "config", "user.name", "Test User"], capture_output=True, check=True)
|
||||
|
||||
(repo_dir / "README.md").write_text("# Test Project\n")
|
||||
subprocess.run(["git", "add", "."], capture_output=True, check=True)
|
||||
subprocess.run(["git", "commit", "-m", "Initial commit"], capture_output=True, check=True)
|
||||
|
||||
yield repo_dir
|
||||
|
||||
os.chdir(tmp_path)
|
||||
|
||||
def temp_git_repo():
|
||||
"""Create a temporary git repository."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
os.chdir(tmpdir)
|
||||
os.system('git init')
|
||||
yield tmpdir
|
||||
|
||||
@pytest.fixture
|
||||
def sample_diff():
|
||||
"""Provide a sample git diff for testing."""
|
||||
return """diff --git a/src/auth.py b/src/auth.py
|
||||
"""Sample diff for testing."""
|
||||
return """diff --git a/main.py b/main.py
|
||||
index 1234567..abcdefg 100644
|
||||
--- a/src/auth.py
|
||||
+++ b/src/auth.py
|
||||
--- a/main.py
|
||||
+++ b/main.py
|
||||
@@ -1,3 +1,4 @@
|
||||
+from datetime import datetime
|
||||
def authenticate(user_id):
|
||||
if user_id is None:
|
||||
return False
|
||||
+ return datetime.now()
|
||||
def hello():
|
||||
+ print("Hello, World!")
|
||||
return "Hello"
|
||||
"""
|
||||
|
||||
@@ -1,65 +1,17 @@
|
||||
"""Tests for the CLI module."""
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import click
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from git_commit_ai.cli.cli import cli, generate
|
||||
|
||||
from git_commit_ai.cli.cli import main
|
||||
|
||||
|
||||
class TestCLIBasic:
|
||||
"""Basic CLI tests."""
|
||||
|
||||
def test_main_help(self):
|
||||
"""Test main command help."""
|
||||
def test_cli_group():
|
||||
"""Test CLI group creation."""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(main, ["--help"])
|
||||
result = runner.invoke(cli, ['--help'])
|
||||
assert result.exit_code == 0
|
||||
assert "Git Commit AI" in result.output
|
||||
assert "generate" in result.output
|
||||
assert "status" in result.output
|
||||
assert 'AI-powered Git commit message generator' in result.output
|
||||
|
||||
def test_generate_help(self):
|
||||
"""Test generate command help."""
|
||||
def test_generate_command_exists():
|
||||
"""Test generate command exists."""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(main, ["generate", "--help"])
|
||||
result = runner.invoke(cli, ['generate', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert "conventional" in result.output
|
||||
assert "model" in result.output
|
||||
|
||||
|
||||
class TestCLIValidation:
|
||||
"""CLI validation tests."""
|
||||
|
||||
def test_validate_valid_message(self):
|
||||
"""Test validating a valid conventional commit message."""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(main, ["validate", "feat(auth): add login"])
|
||||
assert result.exit_code == 0
|
||||
assert "Valid" in result.output
|
||||
|
||||
def test_validate_invalid_message(self):
|
||||
"""Test validating an invalid commit message."""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(main, ["validate", "just a random message"])
|
||||
assert result.exit_code == 1
|
||||
assert "Invalid" in result.output
|
||||
|
||||
def test_validate_empty_message(self):
|
||||
"""Test validating an empty commit message."""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(main, ["validate", ""])
|
||||
assert result.exit_code == 1
|
||||
|
||||
|
||||
class TestCLIAutoFix:
|
||||
"""CLI auto-fix tests."""
|
||||
|
||||
def test_validate_auto_fix(self):
|
||||
"""Test auto-fix suggestion."""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(main, ["validate", "add login feature", "--auto-fix"])
|
||||
assert result.exit_code == 1
|
||||
assert "Suggested fix" in result.output
|
||||
assert 'Generate a commit message for staged changes' in result.output
|
||||
|
||||
22
git_commit_ai/tests/test_config.py
Normal file
22
git_commit_ai/tests/test_config.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from git_commit_ai.core.config import load_config
|
||||
|
||||
def test_load_config_no_file():
|
||||
"""Test loading config when no config file exists."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
os.chdir(tmpdir)
|
||||
config = load_config()
|
||||
assert config == {}
|
||||
|
||||
def test_load_config_with_file():
|
||||
"""Test loading config from file."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
os.chdir(tmpdir)
|
||||
os.makedirs('.git-commit-ai')
|
||||
with open('.git-commit-ai/config.yaml', 'w') as f:
|
||||
f.write('model: llama3\nconventional: true')
|
||||
config = load_config()
|
||||
assert config['model'] == 'llama3'
|
||||
assert config['conventional'] is True
|
||||
@@ -1,131 +1,21 @@
|
||||
"""Tests for conventional commit validation."""
|
||||
|
||||
import pytest
|
||||
from git_commit_ai.core.conventional import validate_conventional, fix_conventional
|
||||
|
||||
from git_commit_ai.core.conventional import (
|
||||
ConventionalCommitParser,
|
||||
ConventionalCommitFixer,
|
||||
validate_commit_message,
|
||||
format_conventional,
|
||||
extract_conventional_parts,
|
||||
)
|
||||
|
||||
|
||||
class TestConventionalCommitParser:
|
||||
"""Tests for ConventionalCommitParser."""
|
||||
|
||||
def test_parse_valid_message(self):
|
||||
message = "feat(auth): add user authentication"
|
||||
parsed = ConventionalCommitParser.parse(message)
|
||||
assert parsed is not None
|
||||
assert parsed.type == "feat"
|
||||
assert parsed.scope == "auth"
|
||||
assert parsed.description == "add user authentication"
|
||||
|
||||
def test_parse_without_scope(self):
|
||||
message = "fix: resolve memory leak"
|
||||
parsed = ConventionalCommitParser.parse(message)
|
||||
assert parsed is not None
|
||||
assert parsed.type == "fix"
|
||||
assert parsed.scope is None
|
||||
|
||||
def test_parse_invalid_message(self):
|
||||
message = "just a random message"
|
||||
parsed = ConventionalCommitParser.parse(message)
|
||||
assert parsed is None
|
||||
|
||||
def test_is_valid(self):
|
||||
assert ConventionalCommitParser.is_valid("feat: new feature") is True
|
||||
assert ConventionalCommitParser.is_valid("invalid message") is False
|
||||
|
||||
def test_validate_valid(self):
|
||||
errors = ConventionalCommitParser.validate("feat(auth): add login")
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_invalid_type(self):
|
||||
errors = ConventionalCommitParser.validate("invalid(scope): desc")
|
||||
assert len(errors) > 0
|
||||
assert any("Invalid type" in e for e in errors)
|
||||
|
||||
def test_validate_empty_message(self):
|
||||
errors = ConventionalCommitParser.validate("")
|
||||
assert len(errors) > 0
|
||||
|
||||
|
||||
class TestConventionalCommitFixer:
|
||||
"""Tests for ConventionalCommitFixer."""
|
||||
|
||||
def test_fix_simple_message(self):
|
||||
diff = """+++ b/src/auth.py
|
||||
@@ -1,3 +1,4 @@
|
||||
+def login():
|
||||
+ pass
|
||||
"""
|
||||
fixed = ConventionalCommitFixer.fix("add login feature", diff)
|
||||
assert fixed.startswith("feat:")
|
||||
|
||||
def test_fix_with_type_detection(self):
|
||||
diff = """--- a/src/bug.py
|
||||
+++ b/src/bug.py
|
||||
@@ -1,3 +1,4 @@
|
||||
-def calculate():
|
||||
+def calculate():
|
||||
return 1 / 0
|
||||
+ return 1 / 1
|
||||
"""
|
||||
fixed = ConventionalCommitFixer.fix("fix bug", diff)
|
||||
assert fixed.startswith("fix:")
|
||||
|
||||
def test_fix_preserves_description(self):
|
||||
diff = """+++ b/src/auth.py
|
||||
@@ -1,3 +1,4 @@
|
||||
+def login():
|
||||
"""
|
||||
fixed = ConventionalCommitFixer.fix("add login functionality", diff)
|
||||
assert "login" in fixed.lower()
|
||||
|
||||
|
||||
class TestValidateCommitMessage:
|
||||
"""Tests for validate_commit_message function."""
|
||||
|
||||
def test_validate_valid(self):
|
||||
is_valid, errors = validate_commit_message("feat(auth): add login")
|
||||
def test_validate_conventional_valid():
|
||||
"""Test valid conventional commit."""
|
||||
message = "feat(auth): add login functionality"
|
||||
is_valid, _ = validate_conventional(message)
|
||||
assert is_valid is True
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_invalid(self):
|
||||
is_valid, errors = validate_commit_message("invalid")
|
||||
def test_validate_conventional_invalid():
|
||||
"""Test invalid conventional commit."""
|
||||
message = "just a regular commit"
|
||||
is_valid, _ = validate_conventional(message)
|
||||
assert is_valid is False
|
||||
assert len(errors) > 0
|
||||
|
||||
|
||||
class TestFormatConventional:
|
||||
"""Tests for format_conventional function."""
|
||||
|
||||
def test_format_with_type_and_scope(self):
|
||||
result = format_conventional("add login", "feat", "auth")
|
||||
assert result == "feat(auth): add login"
|
||||
|
||||
def test_format_with_type_only(self):
|
||||
result = format_conventional("fix bug", "fix")
|
||||
assert result == "fix: fix bug"
|
||||
|
||||
def test_format_already_formatted(self):
|
||||
result = format_conventional("feat(auth): add login", "feat", "auth")
|
||||
assert result == "feat(auth): add login"
|
||||
|
||||
|
||||
class TestExtractConventionalParts:
|
||||
"""Tests for extract_conventional_parts function."""
|
||||
|
||||
def test_extract_all_parts(self):
|
||||
result = extract_conventional_parts("feat(auth): add login")
|
||||
assert result["type"] == "feat"
|
||||
assert result["scope"] == "auth"
|
||||
assert result["description"] == "add login"
|
||||
|
||||
def test_extract_invalid_message(self):
|
||||
result = extract_conventional_parts("invalid message")
|
||||
assert result["type"] is None
|
||||
assert result["scope"] is None
|
||||
assert result["description"] == "invalid message"
|
||||
def test_fix_conventional():
|
||||
"""Test conventional commit fixing."""
|
||||
message = "added new feature"
|
||||
diff = ["main.py"]
|
||||
fixed = fix_conventional(message, diff)
|
||||
assert fixed.startswith("feat:")
|
||||
|
||||
@@ -1,137 +1,21 @@
|
||||
"""Tests for the Git handler module."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from git.exc import GitCommandError
|
||||
import subprocess
|
||||
import os
|
||||
import tempfile
|
||||
from git_commit_ai.core.git_handler import get_staged_changes, get_commit_history
|
||||
|
||||
from git_commit_ai.core.git_handler import GitHandler, GitError, get_git_handler
|
||||
def test_get_commit_history_empty():
|
||||
"""Test get commit history when no commits exist."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
os.chdir(tmpdir)
|
||||
subprocess.run(['git', 'init'], capture_output=True)
|
||||
history = get_commit_history()
|
||||
assert history == []
|
||||
|
||||
|
||||
class TestGitHandlerBasic:
|
||||
"""Basic Git handler tests."""
|
||||
|
||||
def test_is_repository_true(self, temp_git_repo):
|
||||
"""Test is_repository returns True for git repo."""
|
||||
handler = GitHandler(str(temp_git_repo))
|
||||
assert handler.is_repository() is True
|
||||
|
||||
def test_is_repository_false(self, tmp_path):
|
||||
"""Test is_repository returns False for non-git directory."""
|
||||
handler = GitHandler(str(tmp_path))
|
||||
assert handler.is_repository() is False
|
||||
|
||||
|
||||
class TestGitHandlerStagedChanges:
|
||||
"""Tests for staged changes functionality."""
|
||||
|
||||
def test_is_staged_true(self, temp_git_repo):
|
||||
"""Test is_staged returns True when changes are staged."""
|
||||
handler = GitHandler(str(temp_git_repo))
|
||||
|
||||
test_file = temp_git_repo / "test.py"
|
||||
test_file.write_text("print('test')")
|
||||
os.system(f"git add {test_file}")
|
||||
|
||||
assert handler.is_staged() is True
|
||||
|
||||
def test_get_staged_changes(self, temp_git_repo):
|
||||
"""Test getting staged changes."""
|
||||
handler = GitHandler(str(temp_git_repo))
|
||||
|
||||
test_file = temp_git_repo / "test.py"
|
||||
test_file.write_text("print('test')")
|
||||
os.system(f"git add {test_file}")
|
||||
|
||||
diff = handler.get_staged_changes()
|
||||
assert diff != ""
|
||||
assert "test.py" in diff
|
||||
|
||||
|
||||
class TestGitHandlerCommitHistory:
|
||||
"""Tests for commit history functionality."""
|
||||
|
||||
def test_get_commit_history(self, temp_git_repo):
|
||||
"""Test getting commit history."""
|
||||
handler = GitHandler(str(temp_git_repo))
|
||||
commits = handler.get_commit_history(max_commits=5)
|
||||
|
||||
assert len(commits) >= 1
|
||||
assert any(c["message"] == "Initial commit" for c in commits)
|
||||
|
||||
def test_get_commit_history_conventional_only(self, temp_git_repo):
|
||||
"""Test getting only conventional commits."""
|
||||
handler = GitHandler(str(temp_git_repo))
|
||||
commits = handler.get_commit_history(max_commits=10, conventional_only=True)
|
||||
|
||||
for commit in commits:
|
||||
assert commit["type"] != "unknown"
|
||||
|
||||
|
||||
class TestGitHandlerLanguageDetection:
|
||||
"""Tests for language detection functionality."""
|
||||
|
||||
def test_get_changed_languages_python(self, temp_git_repo):
|
||||
"""Test detecting Python files."""
|
||||
handler = GitHandler(str(temp_git_repo))
|
||||
|
||||
test_file = temp_git_repo / "test.py"
|
||||
test_file.write_text("print('hello')")
|
||||
os.system(f"git add {test_file}")
|
||||
|
||||
languages = handler.get_changed_languages()
|
||||
assert "Python" in languages
|
||||
|
||||
def test_get_changed_languages_multiple(self, temp_git_repo):
|
||||
"""Test detecting multiple languages."""
|
||||
handler = GitHandler(str(temp_git_repo))
|
||||
|
||||
py_file = temp_git_repo / "test.py"
|
||||
py_file.write_text("print('hello')")
|
||||
js_file = temp_git_repo / "test.js"
|
||||
js_file.write_text("console.log('hello')")
|
||||
|
||||
os.system(f"git add {py_file} {js_file}")
|
||||
|
||||
languages = handler.get_changed_languages()
|
||||
assert "Python" in languages
|
||||
assert "JavaScript" in languages
|
||||
|
||||
|
||||
class TestGitHandlerHelpers:
|
||||
"""Tests for helper methods."""
|
||||
|
||||
def test_get_staged_files(self, temp_git_repo):
|
||||
"""Test getting staged files list."""
|
||||
handler = GitHandler(str(temp_git_repo))
|
||||
|
||||
test_file = temp_git_repo / "test.py"
|
||||
test_file.write_text("print('test')")
|
||||
os.system(f"git add {test_file}")
|
||||
|
||||
files = handler.get_staged_files()
|
||||
assert "test.py" in [f for f in files if "test.py" in f]
|
||||
|
||||
def test_get_diff_summary(self, temp_git_repo):
|
||||
"""Test getting diff summary."""
|
||||
handler = GitHandler(str(temp_git_repo))
|
||||
|
||||
test_file = temp_git_repo / "test.py"
|
||||
test_file.write_text("print('test')")
|
||||
os.system(f"git add {test_file}")
|
||||
|
||||
summary = handler.get_diff_summary()
|
||||
assert "Files changed" in summary
|
||||
assert "Python" in summary
|
||||
|
||||
|
||||
class TestGitError:
|
||||
"""Tests for GitError exception."""
|
||||
|
||||
def test_git_error_raised(self):
|
||||
"""Test GitError is raised on git errors."""
|
||||
with pytest.raises(GitError):
|
||||
handler = GitHandler("/nonexistent/path")
|
||||
handler.repo
|
||||
def test_get_staged_changes_empty_repo():
|
||||
"""Test get staged changes in empty repository."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
os.chdir(tmpdir)
|
||||
subprocess.run(['git', 'init'], capture_output=True)
|
||||
changes = get_staged_changes()
|
||||
assert changes == []
|
||||
|
||||
@@ -1,141 +1,20 @@
|
||||
"""Tests for the Ollama client module."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from git_commit_ai.core.ollama_client import generate_commit_message
|
||||
|
||||
from git_commit_ai.core.ollama_client import OllamaClient, OllamaError, generate_diff_hash
|
||||
|
||||
|
||||
class TestOllamaClientBasic:
|
||||
"""Basic Ollama client tests."""
|
||||
|
||||
def test_init(self, mock_config):
|
||||
"""Test client initialization."""
|
||||
client = OllamaClient(mock_config)
|
||||
assert client.model == "qwen2.5-coder:3b"
|
||||
assert client.base_url == "http://localhost:11434"
|
||||
|
||||
def test_model_setter(self, mock_config):
|
||||
"""Test model setter."""
|
||||
client = OllamaClient(mock_config)
|
||||
client.model = "llama3:8b"
|
||||
assert client.model == "llama3:8b"
|
||||
|
||||
def test_base_url_setter(self, mock_config):
|
||||
"""Test base URL setter."""
|
||||
client = OllamaClient(mock_config)
|
||||
client.base_url = "http://localhost:11435"
|
||||
assert client.base_url == "http://localhost:11435"
|
||||
|
||||
|
||||
class TestOllamaClientAvailability:
|
||||
"""Tests for Ollama availability checks."""
|
||||
|
||||
def test_is_available_true(self, mock_config):
|
||||
"""Test is_available returns True when server is up."""
|
||||
with patch('requests.get') as mock_get:
|
||||
def test_generate_commit_message_success():
|
||||
"""Test successful commit message generation."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_get.return_value = mock_response
|
||||
mock_response.json.return_value = {'response': 'fix: resolve bug'}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
client = OllamaClient(mock_config)
|
||||
assert client.is_available() is True
|
||||
with patch('requests.post', return_value=mock_response):
|
||||
result = generate_commit_message("test prompt")
|
||||
assert result == 'fix: resolve bug'
|
||||
|
||||
def test_is_available_false(self, mock_config):
|
||||
"""Test is_available returns False when server is down."""
|
||||
with patch('requests.get') as mock_get:
|
||||
mock_get.side_effect = Exception("Connection refused")
|
||||
|
||||
client = OllamaClient(mock_config)
|
||||
assert client.is_available() is False
|
||||
|
||||
|
||||
class TestOllamaClientModels:
|
||||
"""Tests for model-related functionality."""
|
||||
|
||||
def test_list_models(self, mock_config):
|
||||
"""Test listing available models."""
|
||||
with patch('requests.get') as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"models": [
|
||||
{"name": "qwen2.5-coder:3b", "size": 2000000000},
|
||||
{"name": "llama3:8b", "size": 4000000000},
|
||||
]
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
client = OllamaClient(mock_config)
|
||||
models = client.list_models()
|
||||
|
||||
assert len(models) == 2
|
||||
assert models[0]["name"] == "qwen2.5-coder:3b"
|
||||
|
||||
def test_check_model_exists_true(self, mock_config):
|
||||
"""Test checking if model exists."""
|
||||
with patch.object(OllamaClient, 'list_models') as mock_list:
|
||||
mock_list.return_value = [{"name": "qwen2.5-coder:3b", "size": 2000000000}]
|
||||
|
||||
client = OllamaClient(mock_config)
|
||||
assert client.check_model_exists() is True
|
||||
|
||||
def test_check_model_exists_false(self, mock_config):
|
||||
"""Test checking if model doesn't exist."""
|
||||
with patch.object(OllamaClient, 'list_models') as mock_list:
|
||||
mock_list.return_value = [{"name": "llama3:8b", "size": 4000000000}]
|
||||
|
||||
client = OllamaClient(mock_config)
|
||||
assert client.check_model_exists() is False
|
||||
|
||||
|
||||
class TestOllamaClientGeneration:
|
||||
"""Tests for commit message generation."""
|
||||
|
||||
def test_parse_commit_message_simple(self, mock_config):
|
||||
"""Test parsing a simple commit message."""
|
||||
client = OllamaClient(mock_config)
|
||||
response = "feat: add new feature"
|
||||
parsed = client._parse_commit_message(response)
|
||||
assert parsed == "feat: add new feature"
|
||||
|
||||
def test_parse_commit_message_with_quotes(self, mock_config):
|
||||
"""Test parsing a quoted commit message."""
|
||||
client = OllamaClient(mock_config)
|
||||
response = '"feat: add new feature"'
|
||||
parsed = client._parse_commit_message(response)
|
||||
assert parsed == "feat: add new feature"
|
||||
|
||||
def test_parse_commit_message_truncates_long(self, mock_config):
|
||||
"""Test parsing truncates long messages."""
|
||||
client = OllamaClient(mock_config)
|
||||
long_message = "a" * 100
|
||||
parsed = client._parse_commit_message(long_message)
|
||||
assert len(parsed) <= 80
|
||||
|
||||
|
||||
class TestGenerateDiffHash:
|
||||
"""Tests for generate_diff_hash function."""
|
||||
|
||||
def test_generate_diff_hash(self):
|
||||
"""Test generating diff hash."""
|
||||
diff1 = "def hello():\n print('hi')"
|
||||
diff2 = "def hello():\n print('hi')"
|
||||
diff3 = "def goodbye():\n print('bye')"
|
||||
|
||||
hash1 = generate_diff_hash(diff1)
|
||||
hash2 = generate_diff_hash(diff2)
|
||||
hash3 = generate_diff_hash(diff3)
|
||||
|
||||
assert hash1 == hash2
|
||||
assert hash1 != hash3
|
||||
|
||||
|
||||
class TestOllamaError:
|
||||
"""Tests for OllamaError exception."""
|
||||
|
||||
def test_ollama_error(self):
|
||||
"""Test OllamaError is raised correctly."""
|
||||
with pytest.raises(OllamaError):
|
||||
raise OllamaError("Test error")
|
||||
def test_generate_commit_message_connection_error():
|
||||
"""Test connection error handling."""
|
||||
with patch('requests.post') as mock_post:
|
||||
mock_post.side_effect = Exception("Connection failed")
|
||||
with pytest.raises(ConnectionError):
|
||||
generate_commit_message("test prompt")
|
||||
|
||||
20
git_commit_ai/tests/test_prompt_builder.py
Normal file
20
git_commit_ai/tests/test_prompt_builder.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import pytest
|
||||
from git_commit_ai.core.prompt_builder import build_prompt, DEFAULT_PROMPT, CONVENTIONAL_PROMPT
|
||||
|
||||
def test_build_prompt_default():
|
||||
"""Test default prompt building."""
|
||||
prompt = build_prompt("test diff")
|
||||
assert "test diff" in prompt
|
||||
assert "No previous commits" in prompt
|
||||
|
||||
def test_build_prompt_with_history():
|
||||
"""Test prompt building with history."""
|
||||
prompt = build_prompt("test diff", history=["feat: add x", "fix: resolve y"])
|
||||
assert "feat: add x" in prompt
|
||||
assert "fix: resolve y" in prompt
|
||||
|
||||
def test_build_prompt_conventional():
|
||||
"""Test conventional prompt building."""
|
||||
prompt = build_prompt("test diff", conventional=True)
|
||||
assert "conventional" in prompt.lower()
|
||||
assert "feat" in prompt.lower() or "fix" in prompt.lower()
|
||||
@@ -59,13 +59,13 @@ omit = ["git_commit_ai/tests/*"]
|
||||
[tool.coverage.report]
|
||||
exclude_lines = ["pragma: no cover", "def __repr__", "raise AssertionError", "raise NotImplementedError"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 100
|
||||
target-version = ['py39']
|
||||
include = '\.pyi?$'
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py39"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "W", "I", "B", "C4", "UP", "ARG", "SIM"]
|
||||
ignore = ["E501", "B008", "C901"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"git_commit_ai/tests/*" = ["ARG", "S"]
|
||||
|
||||
@@ -3,5 +3,3 @@ ollama>=0.1
|
||||
gitpython>=3.1
|
||||
pyyaml>=6.0
|
||||
requests>=2.31
|
||||
pytest>=7.0
|
||||
pytest-cov>=4.0
|
||||
|
||||
19
setup.py
Normal file
19
setup.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name="git-commit-ai",
|
||||
version="0.1.0",
|
||||
packages=find_packages(),
|
||||
install_requires=[
|
||||
"click>=8.0",
|
||||
"ollama>=0.1",
|
||||
"gitpython>=3.1",
|
||||
"pyyaml>=6.0",
|
||||
"requests>=2.31",
|
||||
],
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"git-commit-ai=git_commit_ai.cli.cli:main",
|
||||
],
|
||||
},
|
||||
)
|
||||
Reference in New Issue
Block a user