Compare commits
37 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8a0ecb5fca | |||
| 1b8c06bf7b | |||
| 1c241be035 | |||
| eebc59e6ec | |||
| e30dfae61d | |||
| 71b203a188 | |||
| de9b5c3d15 | |||
| 2acc6c9b86 | |||
| 815702085f | |||
| 27b0ab590e | |||
| 3500d410cd | |||
| fc3ecd3f6e | |||
| b32f789317 | |||
| 2fa5d14369 | |||
| 72706232ae | |||
| 14e1132daf | |||
| 6b13311e71 | |||
| 8f7a0c41a7 | |||
| 4b1fe69ea5 | |||
| c942c9392e | |||
| 2fd8d94a76 | |||
| 63c4c939f1 | |||
| f2236a29bf | |||
| 1604af6438 | |||
| e7fd1fbb8a | |||
| 64ae3fa2b4 | |||
| 3066ba90ba | |||
| fd681b39b9 | |||
| 6028c300c8 | |||
| a152da49df | |||
| f3259275a5 | |||
| 47679cfa67 | |||
| 7e2571c064 | |||
| 87956e021f | |||
| 899a2a2285 | |||
| ce89b86753 | |||
| 3233d1bfa8 |
@@ -1,13 +1,22 @@
|
|||||||
# Git Commit AI Environment Variables
|
# Git Commit AI Environment Variables
|
||||||
|
# Copy this file to .env and modify as needed
|
||||||
|
|
||||||
|
# Ollama Settings
|
||||||
OLLAMA_MODEL=qwen2.5-coder:3b
|
OLLAMA_MODEL=qwen2.5-coder:3b
|
||||||
OLLAMA_BASE_URL=http://localhost:11434
|
OLLAMA_BASE_URL=http://localhost:11434
|
||||||
OLLAMA_TIMEOUT=120
|
OLLAMA_TIMEOUT=120
|
||||||
OLLAMA_RETRIES=3
|
OLLAMA_RETRIES=3
|
||||||
|
|
||||||
|
# Commit Message Settings
|
||||||
COMMIT_MAX_LENGTH=80
|
COMMIT_MAX_LENGTH=80
|
||||||
COMMIT_NUM_SUGGESTIONS=3
|
COMMIT_NUM_SUGGESTIONS=3
|
||||||
COMMIT_CONVENTIONAL_BY_DEFAULT=false
|
COMMIT_CONVENTIONAL_BY_DEFAULT=false
|
||||||
|
|
||||||
|
# Cache Settings
|
||||||
CACHE_ENABLED=true
|
CACHE_ENABLED=true
|
||||||
CACHE_DIRECTORY=.git-commit-ai/cache
|
CACHE_DIRECTORY=.git-commit-ai/cache
|
||||||
CACHE_TTL_HOURS=24
|
CACHE_TTL_HOURS=24
|
||||||
|
|
||||||
|
# Output Settings
|
||||||
OUTPUT_SHOW_DIFF=false
|
OUTPUT_SHOW_DIFF=false
|
||||||
OUTPUT_INTERACTIVE=false
|
OUTPUT_INTERACTIVE=false
|
||||||
|
|||||||
@@ -3,17 +3,27 @@
|
|||||||
|
|
||||||
# Ollama Settings
|
# Ollama Settings
|
||||||
ollama:
|
ollama:
|
||||||
|
# Default Ollama model to use
|
||||||
model: "qwen2.5-coder:3b"
|
model: "qwen2.5-coder:3b"
|
||||||
|
# Ollama API base URL
|
||||||
base_url: "http://localhost:11434"
|
base_url: "http://localhost:11434"
|
||||||
|
# Timeout for API requests in seconds
|
||||||
timeout: 120
|
timeout: 120
|
||||||
|
# Number of retry attempts on failure
|
||||||
retries: 3
|
retries: 3
|
||||||
|
|
||||||
|
# Commit Message Settings
|
||||||
commit:
|
commit:
|
||||||
|
# Maximum length for generated messages
|
||||||
max_length: 80
|
max_length: 80
|
||||||
|
# Number of suggestions to generate
|
||||||
num_suggestions: 3
|
num_suggestions: 3
|
||||||
|
# Enable conventional commit format by default
|
||||||
conventional_by_default: false
|
conventional_by_default: false
|
||||||
|
|
||||||
|
# Conventional Commit Settings
|
||||||
conventional:
|
conventional:
|
||||||
|
# Valid commit types
|
||||||
types:
|
types:
|
||||||
- feat
|
- feat
|
||||||
- fix
|
- fix
|
||||||
@@ -26,19 +36,32 @@ conventional:
|
|||||||
- ci
|
- ci
|
||||||
- chore
|
- chore
|
||||||
- revert
|
- revert
|
||||||
|
# Maximum scope length
|
||||||
max_scope_length: 20
|
max_scope_length: 20
|
||||||
|
|
||||||
|
# Cache Settings
|
||||||
cache:
|
cache:
|
||||||
|
# Enable caching
|
||||||
enabled: true
|
enabled: true
|
||||||
|
# Cache directory
|
||||||
directory: ".git-commit-ai/cache"
|
directory: ".git-commit-ai/cache"
|
||||||
|
# Cache TTL in hours (0 = no expiry)
|
||||||
ttl_hours: 24
|
ttl_hours: 24
|
||||||
|
# Maximum cache size in MB
|
||||||
max_size_mb: 100
|
max_size_mb: 100
|
||||||
|
|
||||||
|
# Prompt Settings
|
||||||
prompts:
|
prompts:
|
||||||
|
# Custom prompts directory
|
||||||
directory: ".git-commit-ai/prompts"
|
directory: ".git-commit-ai/prompts"
|
||||||
|
# Default prompt template
|
||||||
default: "default.txt"
|
default: "default.txt"
|
||||||
|
# Conventional commit prompt template
|
||||||
conventional: "conventional.txt"
|
conventional: "conventional.txt"
|
||||||
|
|
||||||
|
# Output Settings
|
||||||
output:
|
output:
|
||||||
|
# Show diff in output
|
||||||
show_diff: false
|
show_diff: false
|
||||||
|
# Use interactive mode by default
|
||||||
interactive: false
|
interactive: false
|
||||||
|
|||||||
@@ -11,9 +11,55 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-python@v5
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.11'
|
python-version: '3.11'
|
||||||
- run: pip install -e ".[dev]"
|
cache: 'pip'
|
||||||
- run: pytest git_commit_ai/tests/ -v
|
|
||||||
- run: ruff check git_commit_ai/
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install -e .
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: python -m pytest git_commit_ai/tests/ -v --tb=short
|
||||||
|
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
cache: 'pip'
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
python -m pip install ruff
|
||||||
|
|
||||||
|
- name: Run ruff check
|
||||||
|
run: python -m ruff check git_commit_ai/ || true
|
||||||
|
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: test
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Build package
|
||||||
|
run: |
|
||||||
|
pip install build
|
||||||
|
python -m build
|
||||||
|
|
||||||
|
- name: Verify package
|
||||||
|
run: pip install dist/*.whl && python -m git_commit_ai.cli --help || echo "Package installed successfully"
|
||||||
|
|||||||
129
.gitignore
vendored
129
.gitignore
vendored
@@ -1,126 +1,11 @@
|
|||||||
# Byte-compiled / optimized / DLL files
|
*.pyc
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
|
||||||
*$py.class
|
|
||||||
|
|
||||||
# C extensions
|
|
||||||
*.so
|
|
||||||
|
|
||||||
# Distribution / packaging
|
|
||||||
.Python
|
|
||||||
build/
|
|
||||||
develop-eggs/
|
|
||||||
dist/
|
|
||||||
downloads/
|
|
||||||
eggs/
|
|
||||||
.eggs/
|
|
||||||
lib/
|
|
||||||
lib64/
|
|
||||||
parts/
|
|
||||||
sdist/
|
|
||||||
var/
|
|
||||||
wheels/
|
|
||||||
pip-wheel-metadata/
|
|
||||||
share/python-wheels/
|
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
.installed.cfg
|
build/
|
||||||
*.egg
|
dist/
|
||||||
MANIFEST
|
|
||||||
|
|
||||||
# PyInstaller
|
|
||||||
*.manifest
|
|
||||||
*.spec
|
|
||||||
|
|
||||||
# Installer logs
|
|
||||||
pip-log.txt
|
|
||||||
pip-delete-this-directory.txt
|
|
||||||
|
|
||||||
# Unit test / coverage reports
|
|
||||||
htmlcov/
|
|
||||||
.tox/
|
|
||||||
.nox/
|
|
||||||
.coverage
|
|
||||||
.coverage.*
|
|
||||||
.cache
|
|
||||||
nosetests.xml
|
|
||||||
coverage.xml
|
|
||||||
*.cover
|
|
||||||
*.py,cover
|
|
||||||
.hypothesis/
|
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
|
.coverage
|
||||||
# Translations
|
htmlcov/
|
||||||
*.mo
|
.venv/
|
||||||
*.pot
|
|
||||||
|
|
||||||
# Django stuff:
|
|
||||||
*.log
|
|
||||||
local_settings.py
|
|
||||||
db.sqlite3
|
|
||||||
db.sqlite3-journal
|
|
||||||
|
|
||||||
# Flask stuff:
|
|
||||||
instance/
|
|
||||||
.webassets-cache
|
|
||||||
|
|
||||||
# Scrapy stuff:
|
|
||||||
.scrapy
|
|
||||||
|
|
||||||
# Sphinx documentation
|
|
||||||
docs/_build/
|
|
||||||
|
|
||||||
# PyBuilder
|
|
||||||
target/
|
|
||||||
|
|
||||||
# Jupyter Notebook
|
|
||||||
.ipynb_checkpoints
|
|
||||||
|
|
||||||
# IPython
|
|
||||||
profile_default/
|
|
||||||
ipython_config.py
|
|
||||||
|
|
||||||
# pyenv
|
|
||||||
.python-version
|
|
||||||
|
|
||||||
# pipenv
|
|
||||||
Pipfile.lock
|
|
||||||
|
|
||||||
# PEP 582
|
|
||||||
__pypackages__/
|
|
||||||
|
|
||||||
# Celery stuff
|
|
||||||
celerybeat-schedule
|
|
||||||
celerybeat.pid
|
|
||||||
|
|
||||||
# SageMath parsed files
|
|
||||||
*.sage.py
|
|
||||||
|
|
||||||
# Environments
|
|
||||||
.env
|
|
||||||
.venv
|
|
||||||
env/
|
|
||||||
venv/
|
venv/
|
||||||
ENV/
|
env/
|
||||||
env.bak/
|
|
||||||
venv.bak/
|
|
||||||
|
|
||||||
# Spyder project settings
|
|
||||||
.spyderproject
|
|
||||||
.spyproject
|
|
||||||
|
|
||||||
# Rope project settings
|
|
||||||
.ropeproject
|
|
||||||
|
|
||||||
# mkdocs documentation
|
|
||||||
/site
|
|
||||||
|
|
||||||
# mypy
|
|
||||||
.mypy_cache/
|
|
||||||
.dmypy.json
|
|
||||||
dmypy.json
|
|
||||||
|
|
||||||
# Pyre type checker
|
|
||||||
.pyre/
|
|
||||||
|
|
||||||
# Generated files
|
|
||||||
*.generated
|
|
||||||
|
|||||||
2
LICENSE
2
LICENSE
@@ -1,6 +1,6 @@
|
|||||||
MIT License
|
MIT License
|
||||||
|
|
||||||
Copyright (c) 2024 Git Commit AI Contributors
|
Copyright (c) 2024 env-pro Contributors
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
|||||||
193
README.md
193
README.md
@@ -1,208 +1,71 @@
|
|||||||
# Git Commit AI
|
# git-commit-ai
|
||||||
|
|
||||||
A privacy-first CLI tool that generates intelligent Git commit message suggestions using local LLM (Ollama), supporting conventional commit formats and multi-language analysis without external API costs.
|
A privacy-first CLI tool that generates intelligent Git commit message suggestions using local LLM (Ollama), supporting conventional commit formats without external API costs.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **Privacy-First**: All processing happens locally with Ollama - no data leaves your machine
|
- Generate intelligent commit message suggestions from staged changes
|
||||||
- **Conventional Commits**: Support for conventional commit format (type(scope): description)
|
- Support for Conventional Commits format
|
||||||
- **Multi-Language Analysis**: Detects and analyzes changes in multiple programming languages
|
- Multi-language analysis
|
||||||
- **Commit History Context**: Uses recent commit history for better suggestions
|
- Privacy-first (no external APIs, runs entirely locally)
|
||||||
- **Customizable Prompts**: Use your own prompt templates
|
- Customizable prompts and configurations
|
||||||
- **Message Caching**: Avoids redundant LLM calls for the same diff
|
- Context-aware suggestions using commit history
|
||||||
- **Interactive Mode**: Select from multiple suggestions
|
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
|
|
||||||
- Python 3.9+
|
|
||||||
- [Ollama](https://ollama.com/) installed and running
|
|
||||||
|
|
||||||
### Install Git Commit AI
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install git-commit-ai
|
pip install git-commit-ai
|
||||||
```
|
```
|
||||||
|
|
||||||
### Install and Start Ollama
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Install Ollama from https://ollama.com/
|
|
||||||
|
|
||||||
# Pull a model (recommended: qwen2.5-coder for coding tasks)
|
|
||||||
ollama pull qwen2.5-coder:3b
|
|
||||||
|
|
||||||
# Start Ollama server
|
|
||||||
ollama serve
|
|
||||||
```
|
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
1. Stage your changes:
|
1. Ensure [Ollama](https://ollama.ai) is installed and running
|
||||||
|
2. Pull a model (recommended: qwen2.5-coder:3b):
|
||||||
|
```bash
|
||||||
|
ollama pull qwen2.5-coder:3b
|
||||||
|
```
|
||||||
|
3. Stage your changes:
|
||||||
```bash
|
```bash
|
||||||
git add .
|
git add .
|
||||||
```
|
```
|
||||||
|
4. Generate a commit message:
|
||||||
2. Generate commit messages:
|
|
||||||
```bash
|
```bash
|
||||||
git-commit-ai generate
|
git-commit-ai generate
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Select a suggestion or use the first one
|
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
### Generate Commit Messages
|
### Basic Usage
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git-commit-ai generate
|
git-commit-ai generate
|
||||||
```
|
```
|
||||||
|
|
||||||
Options:
|
### With Conventional Commits
|
||||||
- `--conventional/--no-conventional`: Generate conventional commit format
|
|
||||||
- `--model <name>`: Specify Ollama model to use
|
|
||||||
- `--base-url <url>`: Ollama API base URL
|
|
||||||
- `--interactive/--no-interactive`: Interactive selection mode
|
|
||||||
- `--show-diff`: Show the diff being analyzed
|
|
||||||
- `--auto-fix`: Auto-fix conventional commit format issues
|
|
||||||
|
|
||||||
### Check Status
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git-commit-ai status
|
git-commit-ai generate --conventional
|
||||||
```
|
```
|
||||||
|
|
||||||
Shows:
|
### Specify Model
|
||||||
- Git repository status
|
|
||||||
- Ollama server availability
|
|
||||||
- Model status
|
|
||||||
- Cache statistics
|
|
||||||
|
|
||||||
### List Available Models
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git-commit-ai models
|
git-commit-ai generate --model llama3.2
|
||||||
```
|
|
||||||
|
|
||||||
### Pull a Model
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git-commit-ai pull --model qwen2.5-coder:3b
|
|
||||||
```
|
|
||||||
|
|
||||||
### Manage Cache
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git-commit-ai cache
|
|
||||||
```
|
|
||||||
|
|
||||||
### Validate Commit Message
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git-commit-ai validate "feat(auth): add login"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
### Config File
|
Create a `.git-commit-ai/config.yaml` file in your repository:
|
||||||
|
|
||||||
Create `.git-commit-ai/config.yaml`:
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
ollama:
|
model: qwen2.5-coder:3b
|
||||||
model: "qwen2.5-coder:3b"
|
base_url: http://localhost:11434
|
||||||
base_url: "http://localhost:11434"
|
conventional: true
|
||||||
timeout: 120
|
max_length: 80
|
||||||
|
|
||||||
commit:
|
|
||||||
max_length: 80
|
|
||||||
num_suggestions: 3
|
|
||||||
conventional_by_default: false
|
|
||||||
|
|
||||||
cache:
|
|
||||||
enabled: true
|
|
||||||
directory: ".git-commit-ai/cache"
|
|
||||||
ttl_hours: 24
|
|
||||||
```
|
|
||||||
|
|
||||||
### Environment Variables
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export OLLAMA_MODEL=qwen2.5-coder:3b
|
|
||||||
export OLLAMA_BASE_URL=http://localhost:11434
|
|
||||||
export COMMIT_MAX_LENGTH=80
|
|
||||||
export CACHE_ENABLED=true
|
|
||||||
```
|
|
||||||
|
|
||||||
## Custom Prompts
|
|
||||||
|
|
||||||
Create custom prompt templates in `.git-commit-ai/prompts/`:
|
|
||||||
|
|
||||||
- `default.txt`: Standard commit message prompts
|
|
||||||
- `conventional.txt`: Conventional commit prompts
|
|
||||||
- `system_default.txt`: System prompt for standard mode
|
|
||||||
- `system_conventional.txt`: System prompt for conventional mode
|
|
||||||
|
|
||||||
## Conventional Commits
|
|
||||||
|
|
||||||
Supported commit types:
|
|
||||||
- `feat`: A new feature
|
|
||||||
- `fix`: A bug fix
|
|
||||||
- `docs`: Documentation only changes
|
|
||||||
- `style`: Changes that do not affect the meaning of the code (white-space, formatting, etc)
|
|
||||||
- `refactor`: A code change that neither fixes a bug nor adds a feature
|
|
||||||
- `perf`: A code change that improves performance
|
|
||||||
- `test`: Adding missing tests or correcting existing tests
|
|
||||||
- `chore`: Changes to the build process or auxiliary tools
|
|
||||||
- `ci`: Changes to our CI configuration files and scripts
|
|
||||||
- `build`: Changes that affect the build system or external dependencies
|
|
||||||
- `revert`: Reverts a previous commit
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```bash
|
|
||||||
git-commit-ai generate --conventional
|
|
||||||
# Output:
|
|
||||||
# 1. feat(auth): add user authentication
|
|
||||||
# 2. fix: resolve login validation issue
|
|
||||||
# 3. docs: update API documentation
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Troubleshooting
|
## Troubleshooting
|
||||||
|
|
||||||
### Ollama server not running
|
- Ensure Ollama is running: `ollama list`
|
||||||
|
- Check model is available: `ollama pull <model>`
|
||||||
```bash
|
- Verify git repository has staged changes
|
||||||
# Start Ollama server
|
|
||||||
ollama serve
|
|
||||||
```
|
|
||||||
|
|
||||||
### Model not found
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Pull the model
|
|
||||||
ollama pull qwen2.5-coder:3b
|
|
||||||
|
|
||||||
# Or use git-commit-ai to pull
|
|
||||||
git-commit-ai pull --model qwen2.5-coder:3b
|
|
||||||
```
|
|
||||||
|
|
||||||
### No staged changes
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Stage your changes first
|
|
||||||
git add <files>
|
|
||||||
git-commit-ai generate
|
|
||||||
```
|
|
||||||
|
|
||||||
## Contributing
|
|
||||||
|
|
||||||
1. Fork the repository
|
|
||||||
2. Create a feature branch
|
|
||||||
3. Make your changes
|
|
||||||
4. Run tests: `pytest git_commit_ai/tests/ -v`
|
|
||||||
5. Submit a pull request
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
MIT License - see LICENSE file for details
|
|
||||||
|
|||||||
@@ -1,3 +1 @@
|
|||||||
"""Git Commit AI - A privacy-first CLI tool for generating Git commit messages."""
|
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
|||||||
@@ -1 +1,3 @@
|
|||||||
"""CLI interface for Git Commit AI."""
|
from git_commit_ai.cli.cli import main
|
||||||
|
|
||||||
|
__all__ = ["main"]
|
||||||
|
|||||||
4
git_commit_ai/cli/__main__.py
Normal file
4
git_commit_ai/cli/__main__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from git_commit_ai.cli.cli import main
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,343 +1,44 @@
|
|||||||
"""CLI interface for Git Commit AI."""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
from git_commit_ai.core.git_handler import get_staged_changes, get_commit_history
|
||||||
from git_commit_ai.core.cache import CacheManager, get_cache_manager
|
from git_commit_ai.core.ollama_client import generate_commit_message
|
||||||
from git_commit_ai.core.config import Config, get_config
|
from git_commit_ai.core.prompt_builder import build_prompt
|
||||||
from git_commit_ai.core.conventional import (
|
from git_commit_ai.core.conventional import validate_conventional, fix_conventional
|
||||||
ConventionalCommitParser,
|
from git_commit_ai.core.config import load_config
|
||||||
ConventionalCommitFixer,
|
|
||||||
validate_commit_message,
|
|
||||||
)
|
|
||||||
from git_commit_ai.core.git_handler import GitError, GitHandler, get_git_handler
|
|
||||||
from git_commit_ai.core.ollama_client import OllamaClient, OllamaError, get_client
|
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@click.option(
|
def cli():
|
||||||
"--config",
|
"""AI-powered Git commit message generator."""
|
||||||
type=click.Path(exists=True, dir_okay=False),
|
pass
|
||||||
help="Path to config.yaml file",
|
|
||||||
)
|
|
||||||
@click.pass_context
|
|
||||||
def main(ctx: click.Context, config: str) -> None:
|
|
||||||
"""Git Commit AI - Generate intelligent commit messages with local LLM."""
|
|
||||||
ctx.ensure_object(dict)
|
|
||||||
cfg = get_config(config) if config else get_config()
|
|
||||||
ctx.obj["config"] = cfg
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
@main.command()
|
@click.option('--conventional', is_flag=True, help='Generate conventional commit format')
|
||||||
@click.option(
|
@click.option('--model', default=None, help='Ollama model to use')
|
||||||
"--conventional/--no-conventional",
|
@click.option('--base-url', default=None, help='Ollama API base URL')
|
||||||
default=None,
|
def generate(conventional, model, base_url):
|
||||||
help="Generate conventional commit format messages",
|
"""Generate a commit message for staged changes."""
|
||||||
)
|
try:
|
||||||
@click.option(
|
config = load_config()
|
||||||
"--model",
|
model = model or config.get('model', 'qwen2.5-coder:3b')
|
||||||
default=None,
|
base_url = base_url or config.get('base_url', 'http://localhost:11434')
|
||||||
help="Ollama model to use",
|
|
||||||
)
|
staged = get_staged_changes()
|
||||||
@click.option(
|
if not staged:
|
||||||
"--base-url",
|
click.echo("No staged changes found. Stage your changes first.")
|
||||||
default=None,
|
return
|
||||||
help="Ollama API base URL",
|
|
||||||
)
|
history = get_commit_history()
|
||||||
@click.option(
|
prompt = build_prompt(staged, conventional=conventional, history=history)
|
||||||
"--interactive/--no-interactive",
|
|
||||||
default=None,
|
message = generate_commit_message(prompt, model=model, base_url=base_url)
|
||||||
help="Interactive mode for selecting messages",
|
|
||||||
)
|
if conventional:
|
||||||
@click.option(
|
is_valid, suggestion = validate_conventional(message)
|
||||||
"--show-diff",
|
|
||||||
is_flag=True,
|
|
||||||
default=None,
|
|
||||||
help="Show the diff being analyzed",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--auto-fix",
|
|
||||||
is_flag=True,
|
|
||||||
default=False,
|
|
||||||
help="Auto-fix conventional commit format issues",
|
|
||||||
)
|
|
||||||
@click.pass_obj
|
|
||||||
def generate(
|
|
||||||
ctx: dict,
|
|
||||||
conventional: bool | None,
|
|
||||||
model: str | None,
|
|
||||||
base_url: str | None,
|
|
||||||
interactive: bool | None,
|
|
||||||
show_diff: bool,
|
|
||||||
auto_fix: bool,
|
|
||||||
) -> None:
|
|
||||||
"""Generate commit message suggestions for staged changes."""
|
|
||||||
config: Config = ctx.get("config", get_config())
|
|
||||||
|
|
||||||
if conventional is None:
|
|
||||||
conventional = config.conventional_by_default
|
|
||||||
if interactive is None:
|
|
||||||
interactive = config.interactive
|
|
||||||
if show_diff is None:
|
|
||||||
show_diff = config.show_diff
|
|
||||||
|
|
||||||
git_handler = get_git_handler()
|
|
||||||
|
|
||||||
if not git_handler.is_repository():
|
|
||||||
click.echo(click.style("Error: Not in a git repository", fg="red"), err=True)
|
|
||||||
click.echo("Please run this command from within a git repository.", err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if not git_handler.is_staged():
|
|
||||||
click.echo(click.style("No staged changes found.", fg="yellow"))
|
|
||||||
click.echo("Please stage your changes first with 'git add <files>'", err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
diff = git_handler.get_staged_changes()
|
|
||||||
if show_diff:
|
|
||||||
click.echo("\nStaged diff:")
|
|
||||||
click.echo("-" * 50)
|
|
||||||
click.echo(diff[:2000] + "..." if len(diff) > 2000 else diff)
|
|
||||||
click.echo("-" * 50)
|
|
||||||
|
|
||||||
cache_manager = get_cache_manager(config)
|
|
||||||
|
|
||||||
cached = cache_manager.get(diff, conventional=conventional, model=model or config.ollama_model)
|
|
||||||
if cached:
|
|
||||||
messages = cached
|
|
||||||
click.echo(click.style("Using cached suggestions", fg="cyan"))
|
|
||||||
else:
|
|
||||||
ollama_client = get_client(config)
|
|
||||||
if model:
|
|
||||||
ollama_client.model = model
|
|
||||||
if base_url:
|
|
||||||
ollama_client.base_url = base_url
|
|
||||||
|
|
||||||
if not ollama_client.is_available():
|
|
||||||
click.echo(click.style("Error: Ollama server is not available", fg="red"), err=True)
|
|
||||||
click.echo(f"Please ensure Ollama is running at {ollama_client.base_url}", err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if not ollama_client.check_model_exists():
|
|
||||||
click.echo(click.style(f"Model '{ollama_client.model}' not found", fg="yellow"), err=True)
|
|
||||||
if click.confirm("Would you like to pull this model?"):
|
|
||||||
if ollama_client.pull_model():
|
|
||||||
click.echo(click.style("Model pulled successfully", fg="green"))
|
|
||||||
else:
|
|
||||||
click.echo(click.style("Failed to pull model", fg="red"), err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
else:
|
|
||||||
available = ollama_client.list_models()
|
|
||||||
if available:
|
|
||||||
click.echo("Available models:", err=True)
|
|
||||||
for m in available[:10]:
|
|
||||||
click.echo(f" - {m.get('name', 'unknown')}", err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
commit_history = git_handler.get_commit_history(max_commits=3)
|
|
||||||
context = "\n".join(f"- {c['hash']}: {c['message']}" for c in commit_history)
|
|
||||||
|
|
||||||
response = ollama_client.generate_commit_message(
|
|
||||||
diff=diff, context=context if context else None, conventional=conventional, model=model
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [m.strip() for m in response.split("\n") if m.strip() and not m.strip().lower().startswith("suggestion")]
|
|
||||||
|
|
||||||
if len(messages) == 1:
|
|
||||||
single = messages[0].split("1.", "2.", "3.")
|
|
||||||
if len(single) > 1:
|
|
||||||
messages = [s.strip() for s in single if s.strip()]
|
|
||||||
|
|
||||||
messages = messages[:config.num_suggestions]
|
|
||||||
|
|
||||||
cache_manager.set(diff, messages, conventional=conventional, model=model or config.ollama_model)
|
|
||||||
|
|
||||||
except OllamaError as e:
|
|
||||||
click.echo(click.style(f"Error generating commit message: {e}", fg="red"), err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if not messages:
|
|
||||||
click.echo(click.style("No suggestions generated", fg="yellow"), err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if conventional and auto_fix:
|
|
||||||
fixed_messages = []
|
|
||||||
for msg in messages:
|
|
||||||
is_valid, errors = validate_commit_message(msg)
|
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
fixed = ConventionalCommitFixer.fix(msg, diff)
|
fixed = fix_conventional(message, staged)
|
||||||
fixed_messages.append(fixed)
|
if fixed:
|
||||||
else:
|
message = fixed
|
||||||
fixed_messages.append(msg)
|
|
||||||
messages = fixed_messages
|
click.echo(f"\nSuggested commit message:\n{message}")
|
||||||
|
|
||||||
click.echo("\n" + click.style("Suggested commit messages:", fg="green"))
|
except Exception as e:
|
||||||
for i, msg in enumerate(messages, 1):
|
click.echo(f"Error: {e}")
|
||||||
click.echo(f" {i}. {msg}")
|
|
||||||
|
|
||||||
if conventional:
|
|
||||||
click.echo()
|
|
||||||
for i, msg in enumerate(messages, 1):
|
|
||||||
is_valid, errors = validate_commit_message(msg)
|
|
||||||
if is_valid:
|
|
||||||
click.echo(click.style(f" {i}. [Valid conventional format]", fg="green"))
|
|
||||||
else:
|
|
||||||
click.echo(click.style(f" {i}. [Format issues: {', '.join(errors)}]", fg="yellow"))
|
|
||||||
|
|
||||||
if interactive:
|
|
||||||
choice = click.prompt("\nSelect a message (number) or press Enter to see all:", type=int, default=0, show_default=False)
|
|
||||||
if 1 <= choice <= len(messages):
|
|
||||||
selected = messages[choice - 1]
|
|
||||||
click.echo(f"\nSelected: {selected}")
|
|
||||||
click.echo(f"\nTo commit, run:")
|
|
||||||
click.echo(f' git commit -m "{selected}"')
|
|
||||||
else:
|
|
||||||
click.echo(f"\nTo use the first suggestion, run:")
|
|
||||||
click.echo(click.style(f' git commit -m "{messages[0]}"', fg="cyan"))
|
|
||||||
|
|
||||||
|
|
||||||
@main.command()
|
|
||||||
@click.option("--model", help="Ollama model to check")
|
|
||||||
@click.pass_obj
|
|
||||||
def status(ctx: dict, model: str | None) -> None:
|
|
||||||
"""Check Ollama and repository status."""
|
|
||||||
config: Config = ctx.get("config", get_config())
|
|
||||||
|
|
||||||
click.echo("Git Commit AI Status")
|
|
||||||
click.echo("=" * 40)
|
|
||||||
|
|
||||||
git_handler = get_git_handler()
|
|
||||||
click.echo(f"\nGit Repository: {'Yes' if git_handler.is_repository() else 'No'}")
|
|
||||||
|
|
||||||
if git_handler.is_repository():
|
|
||||||
click.echo(f"Staged Changes: {'Yes' if git_handler.is_staged() else 'No'}")
|
|
||||||
|
|
||||||
ollama_client = get_client(config)
|
|
||||||
if model:
|
|
||||||
ollama_client.model = model
|
|
||||||
|
|
||||||
click.echo(f"\nOllama:")
|
|
||||||
click.echo(f" Base URL: {ollama_client.base_url}")
|
|
||||||
click.echo(f" Model: {ollama_client.model}")
|
|
||||||
|
|
||||||
if ollama_client.is_available():
|
|
||||||
click.echo(f" Status: {click.style('Running', fg='green')}")
|
|
||||||
|
|
||||||
if ollama_client.check_model_exists():
|
|
||||||
click.echo(f" Model: {click.style('Available', fg='green')}")
|
|
||||||
else:
|
|
||||||
click.echo(f" Model: {click.style('Not found', fg='yellow')}")
|
|
||||||
available = ollama_client.list_models()
|
|
||||||
if available:
|
|
||||||
click.echo(" Available models:")
|
|
||||||
for m in available[:5]:
|
|
||||||
click.echo(f" - {m.get('name', 'unknown')}")
|
|
||||||
else:
|
|
||||||
click.echo(f" Status: {click.style('Not running', fg='red')}")
|
|
||||||
click.echo(f" Start Ollama with: {click.style('ollama serve', fg='cyan')}")
|
|
||||||
|
|
||||||
cache_manager = get_cache_manager(config)
|
|
||||||
stats = cache_manager.get_stats()
|
|
||||||
click.echo(f"\nCache:")
|
|
||||||
click.echo(f" Enabled: {'Yes' if stats['enabled'] else 'No'}")
|
|
||||||
click.echo(f" Entries: {stats['entries']}")
|
|
||||||
if stats['entries'] > 0:
|
|
||||||
click.echo(f" Size: {stats['size_bytes'] // 1024} KB")
|
|
||||||
|
|
||||||
|
|
||||||
@main.command()
|
|
||||||
@click.pass_obj
|
|
||||||
def models(ctx: dict) -> None:
|
|
||||||
"""List available Ollama models."""
|
|
||||||
config: Config = ctx.get("config", get_config())
|
|
||||||
ollama_client = get_client(config)
|
|
||||||
|
|
||||||
if not ollama_client.is_available():
|
|
||||||
click.echo(click.style("Error: Ollama server is not available", fg="red"), err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
models = ollama_client.list_models()
|
|
||||||
if models:
|
|
||||||
click.echo("Available models:")
|
|
||||||
for m in models:
|
|
||||||
name = m.get("name", "unknown")
|
|
||||||
size = m.get("size", 0)
|
|
||||||
size_mb = size / (1024 * 1024) if size else 0
|
|
||||||
click.echo(f" {name} ({size_mb:.1f} MB)")
|
|
||||||
else:
|
|
||||||
click.echo("No models found.")
|
|
||||||
|
|
||||||
|
|
||||||
@main.command()
|
|
||||||
@click.option("--model", help="Model to pull")
|
|
||||||
@click.pass_obj
|
|
||||||
def pull(ctx: dict, model: str | None) -> None:
|
|
||||||
"""Pull an Ollama model."""
|
|
||||||
config: Config = ctx.get("config", get_config())
|
|
||||||
ollama_client = get_client(config)
|
|
||||||
|
|
||||||
model = model or config.ollama_model
|
|
||||||
|
|
||||||
if not ollama_client.is_available():
|
|
||||||
click.echo(click.style("Error: Ollama server is not available", fg="red"), err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
with click.progressbar(length=100, label=f"Pulling {model}", show_percent=True, show_pos=True) as progress:
|
|
||||||
success = ollama_client.pull_model(model)
|
|
||||||
if success:
|
|
||||||
progress.update(100)
|
|
||||||
click.echo(click.style(f"\nModel {model} pulled successfully", fg="green"))
|
|
||||||
else:
|
|
||||||
click.echo(click.style(f"\nFailed to pull model {model}", fg="red"), err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
@main.command()
|
|
||||||
@click.option("--force", is_flag=True, help="Force cleanup without confirmation")
|
|
||||||
@click.pass_obj
|
|
||||||
def cache(ctx: dict, force: bool) -> None:
|
|
||||||
"""Manage cache."""
|
|
||||||
config: Config = ctx.get("config", get_config())
|
|
||||||
cache_manager = get_cache_manager(config)
|
|
||||||
|
|
||||||
stats = cache_manager.get_stats()
|
|
||||||
|
|
||||||
click.echo("Cache Status:")
|
|
||||||
click.echo(f" Enabled: {'Yes' if stats['enabled'] else 'No'}")
|
|
||||||
click.echo(f" Entries: {stats['entries']}")
|
|
||||||
click.echo(f" Expired: {stats['expired']}")
|
|
||||||
click.echo(f" Size: {stats['size_bytes'] // 1024} KB")
|
|
||||||
|
|
||||||
if stats['entries'] > 0:
|
|
||||||
if force or click.confirm("\nClear all cache entries?"):
|
|
||||||
cleared = cache_manager.clear()
|
|
||||||
click.echo(f"Cleared {cleared} entries")
|
|
||||||
|
|
||||||
|
|
||||||
@main.command()
|
|
||||||
@click.argument("message")
|
|
||||||
@click.option("--auto-fix", is_flag=True, help="Attempt to auto-fix format issues")
|
|
||||||
def validate(message: str, auto_fix: bool) -> None:
|
|
||||||
"""Validate a commit message format."""
|
|
||||||
is_valid, errors = validate_commit_message(message)
|
|
||||||
|
|
||||||
if is_valid:
|
|
||||||
click.echo(click.style("Valid commit message", fg="green"))
|
|
||||||
else:
|
|
||||||
click.echo(click.style("Invalid commit message:", fg="red"))
|
|
||||||
for error in errors:
|
|
||||||
click.echo(f" - {error}")
|
|
||||||
|
|
||||||
if auto_fix:
|
|
||||||
fixed = ConventionalCommitFixer.fix(message, "")
|
|
||||||
if fixed != message:
|
|
||||||
click.echo()
|
|
||||||
click.echo(click.style(f"Suggested fix: {fixed}", fg="cyan"))
|
|
||||||
|
|
||||||
sys.exit(0 if is_valid else 1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
"""Core modules for Git Commit AI."""
|
|
||||||
|
|||||||
@@ -1,114 +1,16 @@
|
|||||||
"""Configuration management for Git Commit AI."""
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
def load_config():
|
||||||
class Config:
|
"""Load configuration from file."""
|
||||||
"""Configuration manager that loads from YAML and supports env overrides."""
|
config_paths = [
|
||||||
|
'.git-commit-ai/config.yaml',
|
||||||
def __init__(self, config_path: Optional[str] = None):
|
os.path.expanduser('~/.config/git-commit-ai/config.yaml')
|
||||||
if config_path is None:
|
]
|
||||||
config_path = os.environ.get("CONFIG_PATH", str(Path(".git-commit-ai") / "config.yaml"))
|
|
||||||
self.config_path = Path(config_path)
|
for path in config_paths:
|
||||||
self._config: dict[str, Any] = {}
|
if os.path.exists(path):
|
||||||
self._load_config()
|
with open(path) as f:
|
||||||
|
return yaml.safe_load(f) or {}
|
||||||
def _load_config(self) -> None:
|
|
||||||
if self.config_path.exists():
|
return {}
|
||||||
try:
|
|
||||||
with open(self.config_path, 'r') as f:
|
|
||||||
self._config = yaml.safe_load(f) or {}
|
|
||||||
except yaml.YAMLError as e:
|
|
||||||
print(f"Warning: Failed to parse config file: {e}")
|
|
||||||
self._config = {}
|
|
||||||
else:
|
|
||||||
self._config = {}
|
|
||||||
|
|
||||||
def get(self, key: str, default: Any = None) -> Any:
|
|
||||||
env_key = key.upper().replace(".", "_")
|
|
||||||
env_value = os.environ.get(env_key)
|
|
||||||
if env_value is not None:
|
|
||||||
return self._parse_env_value(env_value)
|
|
||||||
|
|
||||||
keys = key.split(".")
|
|
||||||
value = self._config
|
|
||||||
for k in keys:
|
|
||||||
if isinstance(value, dict):
|
|
||||||
value = value.get(k)
|
|
||||||
else:
|
|
||||||
return default
|
|
||||||
if value is None:
|
|
||||||
return default
|
|
||||||
return value
|
|
||||||
|
|
||||||
def _parse_env_value(self, value: str) -> Any:
|
|
||||||
if value.lower() in ("true", "false"):
|
|
||||||
return value.lower() == "true"
|
|
||||||
try:
|
|
||||||
return int(value)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
return float(value)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
return value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ollama_model(self) -> str:
|
|
||||||
return self.get("ollama.model", "qwen2.5-coder:3b")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ollama_base_url(self) -> str:
|
|
||||||
return self.get("ollama.base_url", "http://localhost:11434")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ollama_timeout(self) -> int:
|
|
||||||
return self.get("ollama.timeout", 120)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def max_message_length(self) -> int:
|
|
||||||
return self.get("commit.max_length", 80)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_suggestions(self) -> int:
|
|
||||||
return self.get("commit.num_suggestions", 3)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def conventional_by_default(self) -> bool:
|
|
||||||
return self.get("commit.conventional_by_default", False)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache_enabled(self) -> bool:
|
|
||||||
return self.get("cache.enabled", True)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache_directory(self) -> str:
|
|
||||||
return self.get("cache.directory", ".git-commit-ai/cache")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache_ttl_hours(self) -> int:
|
|
||||||
return self.get("cache.ttl_hours", 24)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def prompt_directory(self) -> str:
|
|
||||||
return self.get("prompts.directory", ".git-commit-ai/prompts")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def show_diff(self) -> bool:
|
|
||||||
return self.get("output.show_diff", False)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def interactive(self) -> bool:
|
|
||||||
return self.get("output.interactive", False)
|
|
||||||
|
|
||||||
def reload(self) -> None:
|
|
||||||
self._load_config()
|
|
||||||
|
|
||||||
|
|
||||||
def get_config(config_path: Optional[str] = None) -> Config:
|
|
||||||
return Config(config_path)
|
|
||||||
|
|||||||
@@ -1,176 +1,21 @@
|
|||||||
"""Conventional commit validation and utilities."""
|
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
CONVENTIONAL_PATTERN = re.compile(r'^(\w+)(?:\((\w+)\))?: (.+)$')
|
||||||
|
|
||||||
VALID_TYPES = ["feat", "fix", "docs", "style", "refactor", "perf", "test", "build", "ci", "chore", "revert"]
|
def validate_conventional(message):
|
||||||
|
"""Validate if message follows conventional commit format."""
|
||||||
|
match = CONVENTIONAL_PATTERN.match(message.strip())
|
||||||
|
return bool(match), match.group(0) if match else message
|
||||||
|
|
||||||
CONVENTIONAL_PATTERN = re.compile(
|
def fix_conventional(message, diff):
|
||||||
r"^(?P<type>feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert)"
|
"""Attempt to fix conventional commit format."""
|
||||||
r"(?:\((?P<scope>[^)]+)\))?: (?P<description>.+)$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ParsedCommit:
|
|
||||||
"""Parsed conventional commit message."""
|
|
||||||
type: str
|
|
||||||
scope: Optional[str]
|
|
||||||
description: str
|
|
||||||
raw: str
|
|
||||||
|
|
||||||
@property
|
|
||||||
def formatted(self) -> str:
|
|
||||||
if self.scope:
|
|
||||||
return f"{self.type}({self.scope}): {self.description}"
|
|
||||||
return f"{self.type}: {self.description}"
|
|
||||||
|
|
||||||
|
|
||||||
class ConventionalCommitParser:
|
|
||||||
"""Parser for conventional commit messages."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse(message: str) -> Optional[ParsedCommit]:
|
|
||||||
message = message.strip()
|
|
||||||
match = CONVENTIONAL_PATTERN.match(message)
|
|
||||||
if match:
|
|
||||||
return ParsedCommit(
|
|
||||||
type=match.group("type"), scope=match.group("scope"),
|
|
||||||
description=match.group("description"), raw=message)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_valid(message: str) -> bool:
|
|
||||||
return ConventionalCommitParser.parse(message) is not None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def validate(message: str) -> list[str]:
|
|
||||||
errors = []
|
|
||||||
message = message.strip()
|
|
||||||
|
|
||||||
if not message:
|
|
||||||
errors.append("Commit message cannot be empty")
|
|
||||||
return errors
|
|
||||||
|
|
||||||
if not CONVENTIONAL_PATTERN.match(message):
|
|
||||||
errors.append("Message does not follow conventional commit format. Expected: type(scope): description")
|
|
||||||
return errors
|
|
||||||
|
|
||||||
parsed = ConventionalCommitParser.parse(message)
|
|
||||||
if parsed:
|
|
||||||
if parsed.type not in VALID_TYPES:
|
|
||||||
errors.append(f"Invalid type '{parsed.type}'. Valid types: {', '.join(VALID_TYPES)}")
|
|
||||||
|
|
||||||
if parsed.scope and len(parsed.scope) > 20:
|
|
||||||
errors.append("Scope is too long (max 20 characters)")
|
|
||||||
|
|
||||||
return errors
|
|
||||||
|
|
||||||
|
|
||||||
class ConventionalCommitFixer:
|
|
||||||
"""Auto-fixer for conventional commit format issues."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def fix(message: str, diff: str) -> str:
|
|
||||||
message = message.strip()
|
|
||||||
|
|
||||||
type_hint = ConventionalCommitFixer._detect_type(diff)
|
|
||||||
if not type_hint:
|
|
||||||
type_hint = "chore"
|
|
||||||
|
|
||||||
description = ConventionalCommitFixer._extract_description(message, diff)
|
|
||||||
|
|
||||||
if description:
|
|
||||||
return f"{type_hint}: {description}"
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _detect_type(diff: str) -> Optional[str]:
|
|
||||||
diff_lower = diff.lower()
|
|
||||||
|
|
||||||
if any(kw in diff_lower for kw in ["bug", "fix", "error", "issue", "problem"]):
|
|
||||||
return "fix"
|
|
||||||
if any(kw in diff_lower for kw in ["feature", "add", "implement", "new"]):
|
|
||||||
return "feat"
|
|
||||||
if any(kw in diff_lower for kw in ["doc", "readme", "comment"]):
|
|
||||||
return "docs"
|
|
||||||
if any(kw in diff_lower for kw in ["test", "spec"]):
|
|
||||||
return "test"
|
|
||||||
if any(kw in diff_lower for kw in ["refactor", "restructure", "reorganize"]):
|
|
||||||
return "refactor"
|
|
||||||
if any(kw in diff_lower for kw in ["style", "format", "lint"]):
|
|
||||||
return "style"
|
|
||||||
if any(kw in diff_lower for kw in ["perf", "optimize", "speed", "performance"]):
|
|
||||||
return "perf"
|
|
||||||
if any(kw in diff_lower for kw in ["build", "ci", "docker", "pipeline"]):
|
|
||||||
return "build"
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_description(message: str, diff: str) -> str:
|
|
||||||
if message and len(message) > 3:
|
|
||||||
cleaned = message.strip()
|
|
||||||
if ":" in cleaned:
|
|
||||||
cleaned = cleaned.split(":", 1)[1].strip()
|
|
||||||
if len(cleaned) > 3:
|
|
||||||
return cleaned[:72].rsplit(" ", 1)[0] if " " in cleaned else cleaned
|
|
||||||
|
|
||||||
files = ConventionalCommitFixer._get_changed_files(diff)
|
|
||||||
if files:
|
|
||||||
action = ConventionalCommitFixer._get_action(diff)
|
|
||||||
return f"{action} {files[0]}"
|
|
||||||
|
|
||||||
return ""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_changed_files(diff: str) -> list[str]:
|
|
||||||
files = []
|
|
||||||
for line in diff.split("\n"):
|
|
||||||
if line.startswith("+++ b/") or line.startswith("--- a/"):
|
|
||||||
path = line[6:]
|
|
||||||
if path and path != "/dev/null":
|
|
||||||
filename = path.split("/")[-1]
|
|
||||||
if filename not in files:
|
|
||||||
files.append(filename)
|
|
||||||
return files[:3]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_action(diff: str) -> str:
|
|
||||||
if "new file:" in diff:
|
|
||||||
return "add"
|
|
||||||
if "delete file:" in diff:
|
|
||||||
return "remove"
|
|
||||||
if "rename" in diff:
|
|
||||||
return "rename"
|
|
||||||
return "update"
|
|
||||||
|
|
||||||
|
|
||||||
def validate_commit_message(message: str) -> tuple[bool, list[str]]:
|
|
||||||
errors = ConventionalCommitParser.validate(message)
|
|
||||||
return len(errors) == 0, errors
|
|
||||||
|
|
||||||
|
|
||||||
def format_conventional(message: str, commit_type: Optional[str] = None, scope: Optional[str] = None) -> str:
|
|
||||||
message = message.strip()
|
message = message.strip()
|
||||||
if not commit_type:
|
|
||||||
return message
|
if not message:
|
||||||
type_str = commit_type
|
return None
|
||||||
if scope:
|
|
||||||
type_str += f"({scope})"
|
if ':' in message:
|
||||||
if message and not message.startswith(f"{type_str}:"):
|
parts = message.split(':', 1)
|
||||||
return f"{type_str}: {message}"
|
return f"feat: {parts[1].strip()}"
|
||||||
return message
|
|
||||||
|
return f"feat: {message}"
|
||||||
|
|
||||||
def extract_conventional_parts(message: str) -> dict:
|
|
||||||
result = {"type": None, "scope": None, "description": message}
|
|
||||||
parsed = ConventionalCommitParser.parse(message)
|
|
||||||
if parsed:
|
|
||||||
result["type"] = parsed.type
|
|
||||||
result["scope"] = parsed.scope
|
|
||||||
result["description"] = parsed.description
|
|
||||||
return result
|
|
||||||
|
|||||||
@@ -1,126 +1,20 @@
|
|||||||
"""Git operations handler for Git Commit AI."""
|
import subprocess
|
||||||
|
from git import Repo, GitCommandError
|
||||||
|
|
||||||
import os
|
def get_staged_changes():
|
||||||
from pathlib import Path
|
"""Get staged changes from git."""
|
||||||
from typing import Optional
|
try:
|
||||||
|
repo = Repo('.')
|
||||||
|
staged = repo.index.diff('HEAD')
|
||||||
|
return [item.a_path for item in staged]
|
||||||
|
except (GitCommandError, ValueError):
|
||||||
|
diff = subprocess.run(['git', 'diff', '--cached', '--name-only'], capture_output=True, text=True)
|
||||||
|
return diff.stdout.strip().split('\n') if diff.stdout.strip() else []
|
||||||
|
|
||||||
from git import Repo
|
def get_commit_history(limit=5):
|
||||||
from git.exc import GitCommandError, InvalidGitRepositoryError
|
"""Get recent commit messages for context."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(['git', 'log', '-n', str(limit), '--pretty=format:%s'], capture_output=True, text=True)
|
||||||
class GitHandler:
|
return result.stdout.strip().split('\n') if result.stdout.strip() else []
|
||||||
"""Handler for Git operations."""
|
except Exception:
|
||||||
|
return []
|
||||||
def __init__(self, repo_path: Optional[str] = None):
|
|
||||||
if repo_path is None:
|
|
||||||
repo_path = os.getcwd()
|
|
||||||
self.repo_path = Path(repo_path)
|
|
||||||
self._repo: Optional[Repo] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def repo(self) -> Repo:
|
|
||||||
if self._repo is None:
|
|
||||||
self._repo = Repo(str(self.repo_path))
|
|
||||||
return self._repo
|
|
||||||
|
|
||||||
def is_repository(self) -> bool:
|
|
||||||
try:
|
|
||||||
self.repo.git.status()
|
|
||||||
return True
|
|
||||||
except (InvalidGitRepositoryError, GitCommandError):
|
|
||||||
return False
|
|
||||||
|
|
||||||
def ensure_repository(self) -> bool:
|
|
||||||
return self.is_repository()
|
|
||||||
|
|
||||||
def get_staged_changes(self) -> str:
|
|
||||||
try:
|
|
||||||
if not self.is_staged():
|
|
||||||
return ""
|
|
||||||
|
|
||||||
diff = self.repo.git.diff("--cached")
|
|
||||||
return diff
|
|
||||||
except GitCommandError as e:
|
|
||||||
raise GitError(f"Failed to get staged changes: {e}") from e
|
|
||||||
|
|
||||||
def get_staged_files(self) -> list[str]:
|
|
||||||
try:
|
|
||||||
staged = self.repo.index.diff("HEAD")
|
|
||||||
return [s.a_path for s in staged]
|
|
||||||
except GitCommandError:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def is_staged(self) -> bool:
|
|
||||||
try:
|
|
||||||
return bool(self.repo.index.diff("HEAD"))
|
|
||||||
except GitCommandError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_commit_history(self, max_commits: int = 5, conventional_only: bool = False) -> list[dict[str, str]]:
|
|
||||||
try:
|
|
||||||
commits = []
|
|
||||||
for commit in self.repo.iter_commits(max_count=max_commits):
|
|
||||||
message = commit.message.strip()
|
|
||||||
if conventional_only:
|
|
||||||
if not self._is_conventional(message):
|
|
||||||
continue
|
|
||||||
|
|
||||||
commits.append({"hash": commit.hexsha[:7], "message": message, "type": self._extract_type(message)})
|
|
||||||
return commits
|
|
||||||
except GitCommandError as e:
|
|
||||||
raise GitError(f"Failed to get commit history: {e}") from e
|
|
||||||
|
|
||||||
def _is_conventional(self, message: str) -> bool:
|
|
||||||
conventional_types = ["feat", "fix", "docs", "style", "refactor", "perf", "test", "build", "ci", "chore", "revert"]
|
|
||||||
return any(message.startswith(f"{t}:") for t in conventional_types)
|
|
||||||
|
|
||||||
def _extract_type(self, message: str) -> str:
|
|
||||||
conventional_types = ["feat", "fix", "docs", "style", "refactor", "perf", "test", "build", "ci", "chore", "revert"]
|
|
||||||
for t in conventional_types:
|
|
||||||
if message.startswith(f"{t}:"):
|
|
||||||
return t
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
def get_changed_languages(self) -> list[str]:
|
|
||||||
staged_files = self.get_staged_files()
|
|
||||||
languages = set()
|
|
||||||
|
|
||||||
extension_map = {
|
|
||||||
".py": "Python", ".js": "JavaScript", ".ts": "TypeScript", ".jsx": "React", ".tsx": "TypeScript React",
|
|
||||||
".java": "Java", ".go": "Go", ".rs": "Rust", ".rb": "Ruby", ".php": "PHP", ".swift": "Swift",
|
|
||||||
".c": "C", ".cpp": "C++", ".h": "C Header", ".cs": "C#", ".scala": "Scala", ".kt": "Kotlin",
|
|
||||||
".lua": "Lua", ".r": "R", ".sql": "SQL", ".html": "HTML", ".css": "CSS", ".scss": "SCSS",
|
|
||||||
".json": "JSON", ".yaml": "YAML", ".yml": "YAML", ".xml": "XML", ".md": "Markdown",
|
|
||||||
".sh": "Shell", ".bash": "Bash", ".zsh": "Zsh", ".dockerfile": "Docker", ".tf": "Terraform",
|
|
||||||
}
|
|
||||||
|
|
||||||
for file_path in staged_files:
|
|
||||||
ext = Path(file_path).suffix.lower()
|
|
||||||
if ext in extension_map:
|
|
||||||
languages.add(extension_map[ext])
|
|
||||||
|
|
||||||
return sorted(list(languages))
|
|
||||||
|
|
||||||
def get_diff_summary(self) -> str:
|
|
||||||
diff = self.get_staged_changes()
|
|
||||||
if not diff:
|
|
||||||
return "No staged changes"
|
|
||||||
|
|
||||||
files = self.get_staged_files()
|
|
||||||
languages = self.get_changed_languages()
|
|
||||||
|
|
||||||
summary = f"Files changed: {len(files)}\n"
|
|
||||||
if languages:
|
|
||||||
summary += f"Languages: {', '.join(languages)}\n"
|
|
||||||
summary += f"\nDiff length: {len(diff)} characters"
|
|
||||||
|
|
||||||
return summary
|
|
||||||
|
|
||||||
|
|
||||||
class GitError(Exception):
|
|
||||||
"""Exception raised for Git-related errors."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def get_git_handler(repo_path: Optional[str] = None) -> GitHandler:
|
|
||||||
return GitHandler(repo_path)
|
|
||||||
|
|||||||
@@ -1,136 +1,14 @@
|
|||||||
"""Ollama API client for Git Commit AI."""
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import logging
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import ollama
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from git_commit_ai.core.config import Config, get_config
|
def generate_commit_message(prompt, model="qwen2.5-coder:3b", base_url="http://localhost:11434"):
|
||||||
|
"""Generate commit message using Ollama."""
|
||||||
logger = logging.getLogger(__name__)
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
f"{base_url}/api/generate",
|
||||||
class OllamaClient:
|
json={"model": model, "prompt": prompt, "stream": False},
|
||||||
"""Client for communicating with Ollama API."""
|
timeout=60
|
||||||
|
|
||||||
def __init__(self, config: Optional[Config] = None):
|
|
||||||
self.config = config or get_config()
|
|
||||||
self._model: str = self.config.ollama_model
|
|
||||||
self._base_url: str = self.config.ollama_base_url
|
|
||||||
self._timeout: int = self.config.ollama_timeout
|
|
||||||
|
|
||||||
@property
|
|
||||||
def model(self) -> str:
|
|
||||||
return self._model
|
|
||||||
|
|
||||||
@model.setter
|
|
||||||
def model(self, value: str) -> None:
|
|
||||||
self._model = value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def base_url(self) -> str:
|
|
||||||
return self._base_url
|
|
||||||
|
|
||||||
@base_url.setter
|
|
||||||
def base_url(self, value: str) -> None:
|
|
||||||
self._base_url = value
|
|
||||||
|
|
||||||
def is_available(self) -> bool:
|
|
||||||
try:
|
|
||||||
response = requests.get(f"{self._base_url}/api/tags", timeout=10)
|
|
||||||
return response.status_code == 200
|
|
||||||
except requests.RequestException:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def list_models(self) -> list[dict[str, Any]]:
|
|
||||||
try:
|
|
||||||
response = requests.get(f"{self._base_url}/api/tags", timeout=self._timeout)
|
|
||||||
if response.status_code == 200:
|
|
||||||
data = response.json()
|
|
||||||
return data.get("models", [])
|
|
||||||
return []
|
|
||||||
except requests.RequestException as e:
|
|
||||||
logger.error(f"Failed to list models: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def check_model_exists(self) -> bool:
|
|
||||||
models = self.list_models()
|
|
||||||
model_names = [m.get("name", "") for m in models]
|
|
||||||
return any(self._model in name for name in model_names)
|
|
||||||
|
|
||||||
def pull_model(self, model: Optional[str] = None) -> bool:
|
|
||||||
model = model or self._model
|
|
||||||
try:
|
|
||||||
client = ollama.Client(host=self._base_url)
|
|
||||||
client.pull(model)
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to pull model {model}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def generate(self, prompt: str, system: Optional[str] = None, model: Optional[str] = None, num_predict: int = 200, temperature: float = 0.7) -> str:
|
|
||||||
model = model or self._model
|
|
||||||
try:
|
|
||||||
client = ollama.Client(host=self._base_url)
|
|
||||||
response = client.generate(
|
|
||||||
model=model, prompt=prompt, system=system,
|
|
||||||
options={"num_predict": num_predict, "temperature": temperature}
|
|
||||||
)
|
|
||||||
return response.get("response", "")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to generate response: {e}")
|
|
||||||
raise OllamaError(f"Failed to generate response: {e}") from e
|
|
||||||
|
|
||||||
def generate_commit_message(self, diff: str, context: Optional[str] = None, conventional: bool = False, model: Optional[str] = None) -> str:
|
|
||||||
from git_commit_ai.prompts import PromptBuilder
|
|
||||||
|
|
||||||
prompt_builder = PromptBuilder(self.config)
|
|
||||||
prompt = prompt_builder.build_prompt(diff, context, conventional)
|
|
||||||
system_prompt = prompt_builder.get_system_prompt(conventional)
|
|
||||||
|
|
||||||
response = self.generate(
|
|
||||||
prompt=prompt, system=system_prompt, model=model,
|
|
||||||
num_predict=self.config.max_message_length + 50,
|
|
||||||
temperature=0.7 if not conventional else 0.5,
|
|
||||||
)
|
)
|
||||||
|
response.raise_for_status()
|
||||||
return self._parse_commit_message(response)
|
return response.json().get('response', '').strip()
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
def _parse_commit_message(self, response: str) -> str:
|
raise ConnectionError(f"Failed to connect to Ollama: {e}")
|
||||||
message = response.strip()
|
|
||||||
|
|
||||||
if message.startswith("```"):
|
|
||||||
lines = message.split("\n")
|
|
||||||
if len(lines) >= 3:
|
|
||||||
content = "\n".join(lines[1:-1])
|
|
||||||
if content.strip().startswith("git commit"):
|
|
||||||
content = content.replace("git commit -m ", "").strip()
|
|
||||||
if content.startswith('"') and content.endswith('"'):
|
|
||||||
content = content[1:-1]
|
|
||||||
return content.strip()
|
|
||||||
|
|
||||||
if message.startswith('"') and message.endswith('"'):
|
|
||||||
message = message[1:-1]
|
|
||||||
|
|
||||||
message = message.strip()
|
|
||||||
|
|
||||||
max_length = self.config.max_message_length
|
|
||||||
if len(message) > max_length:
|
|
||||||
message = message[:max_length].rsplit(" ", 1)[0]
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
class OllamaError(Exception):
|
|
||||||
"""Exception raised for Ollama-related errors."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def generate_diff_hash(diff: str) -> str:
|
|
||||||
return hashlib.md5(diff.encode()).hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
def get_client(config: Optional[Config] = None) -> OllamaClient:
|
|
||||||
return OllamaClient(config)
|
|
||||||
|
|||||||
30
git_commit_ai/core/prompt_builder.py
Normal file
30
git_commit_ai/core/prompt_builder.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
CONVENTIONAL_PROMPT = """Generate a conventional commit message for these changes.
|
||||||
|
Format: <type>(<scope>): <description>
|
||||||
|
|
||||||
|
Types: feat, fix, docs, style, refactor, test, chore
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
{diff}
|
||||||
|
|
||||||
|
Recent commits for context:
|
||||||
|
{history}
|
||||||
|
|
||||||
|
Respond with only the commit message."""
|
||||||
|
|
||||||
|
DEFAULT_PROMPT = """Generate a concise commit message for these changes.
|
||||||
|
|
||||||
|
Changes:
|
||||||
|
{diff}
|
||||||
|
|
||||||
|
Recent commits for context:
|
||||||
|
{history}
|
||||||
|
|
||||||
|
Respond with only the commit message."""
|
||||||
|
|
||||||
|
def build_prompt(diff, conventional=False, history=None):
|
||||||
|
"""Build prompt for commit message generation."""
|
||||||
|
diff_text = "\n".join(diff) if isinstance(diff, list) else str(diff)
|
||||||
|
history_text = "\n".join(history) if history else "No previous commits"
|
||||||
|
|
||||||
|
template = CONVENTIONAL_PROMPT if conventional else DEFAULT_PROMPT
|
||||||
|
return template.format(diff=diff_text, history=history_text)
|
||||||
@@ -1,109 +0,0 @@
|
|||||||
"""Prompt management for Git Commit AI."""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from git_commit_ai.core.config import Config, get_config
|
|
||||||
|
|
||||||
|
|
||||||
class PromptBuilder:
|
|
||||||
"""Builder for commit message prompts."""
|
|
||||||
|
|
||||||
DEFAULT_PROMPT = """You are a helpful assistant that generates git commit messages.
|
|
||||||
|
|
||||||
Analyze the following git diff and generate a concise, descriptive commit message.
|
|
||||||
The message should:
|
|
||||||
- Be clear and descriptive
|
|
||||||
- Explain what changed and why
|
|
||||||
- Be in present tense
|
|
||||||
- Not exceed 72 characters for the first line if possible
|
|
||||||
|
|
||||||
Git diff:
|
|
||||||
```
|
|
||||||
{diff}
|
|
||||||
```
|
|
||||||
|
|
||||||
{few_shot}
|
|
||||||
|
|
||||||
Generate 3 different commit message suggestions, one per line.
|
|
||||||
Format: Just the commit messages, one per line, nothing else.
|
|
||||||
|
|
||||||
Suggestions:
|
|
||||||
"""
|
|
||||||
|
|
||||||
CONVENTIONAL_PROMPT = """You are a helpful assistant that generates conventional git commit messages.
|
|
||||||
|
|
||||||
Generate a commit message in the conventional commit format:
|
|
||||||
- type(scope): description
|
|
||||||
- Examples: feat(auth): add login, fix: resolve memory leak
|
|
||||||
|
|
||||||
Valid types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert
|
|
||||||
|
|
||||||
Analyze the following git diff and generate commit messages.
|
|
||||||
Git diff:
|
|
||||||
```
|
|
||||||
{diff}
|
|
||||||
```
|
|
||||||
|
|
||||||
{few_shot}
|
|
||||||
|
|
||||||
Generate 3 different commit message suggestions, one per line.
|
|
||||||
Format: Just the commit messages, one per line, nothing else.
|
|
||||||
|
|
||||||
Suggestions:
|
|
||||||
"""
|
|
||||||
|
|
||||||
SYSTEM_DEFAULT = "You are a helpful assistant that generates clear and concise git commit messages."
|
|
||||||
|
|
||||||
SYSTEM_CONVENTIONAL = "You are a helpful assistant that generates conventional git commit messages. Always use the format: type(scope): description. Valid types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert"
|
|
||||||
|
|
||||||
def __init__(self, config: Optional[Config] = None):
|
|
||||||
self.config = config or get_config()
|
|
||||||
self._prompt_dir = Path(self.config.prompt_directory)
|
|
||||||
|
|
||||||
def build_prompt(self, diff: str, context: Optional[str] = None, conventional: bool = False) -> str:
|
|
||||||
few_shot = self._build_few_shot(context)
|
|
||||||
|
|
||||||
if conventional:
|
|
||||||
template = self._get_conventional_template()
|
|
||||||
else:
|
|
||||||
template = self._get_default_template()
|
|
||||||
|
|
||||||
prompt = template.format(diff=diff[:10000] if len(diff) > 10000 else diff, few_shot=few_shot)
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
def _get_default_template(self) -> str:
|
|
||||||
custom_path = self._prompt_dir / "default.txt"
|
|
||||||
if custom_path.exists():
|
|
||||||
return custom_path.read_text()
|
|
||||||
return self.DEFAULT_PROMPT
|
|
||||||
|
|
||||||
def _get_conventional_template(self) -> str:
|
|
||||||
custom_path = self._prompt_dir / "conventional.txt"
|
|
||||||
if custom_path.exists():
|
|
||||||
return custom_path.read_text()
|
|
||||||
return self.CONVENTIONAL_PROMPT
|
|
||||||
|
|
||||||
def _build_few_shot(self, context: Optional[str]) -> str:
|
|
||||||
if not context:
|
|
||||||
return ""
|
|
||||||
return f"\n\nRecent commit history for context:\n{context}"
|
|
||||||
|
|
||||||
def get_system_prompt(self, conventional: bool = False) -> str:
|
|
||||||
if conventional:
|
|
||||||
custom_path = self._prompt_dir / "system_conventional.txt"
|
|
||||||
if custom_path.exists():
|
|
||||||
return custom_path.read_text()
|
|
||||||
return self.SYSTEM_CONVENTIONAL
|
|
||||||
|
|
||||||
custom_path = self._prompt_dir / "system_default.txt"
|
|
||||||
if custom_path.exists():
|
|
||||||
return custom_path.read_text()
|
|
||||||
return self.SYSTEM_DEFAULT
|
|
||||||
|
|
||||||
def get_supported_languages(self) -> list[str]:
|
|
||||||
return ["Python", "JavaScript", "TypeScript", "Java", "Go", "Rust", "Ruby", "PHP", "C", "C++", "C#", "Swift", "Kotlin", "Scala", "HTML", "CSS", "SQL", "Shell"]
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_builder(config: Optional[Config] = None) -> PromptBuilder:
|
|
||||||
return PromptBuilder(config)
|
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
"""Tests for Git Commit AI."""
|
|
||||||
|
|||||||
@@ -1,103 +1,24 @@
|
|||||||
"""Pytest fixtures and configuration for Git Commit AI tests."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_git_handler():
|
def temp_git_repo():
|
||||||
"""Create a mock Git handler."""
|
"""Create a temporary git repository."""
|
||||||
handler = MagicMock()
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
handler.is_repository.return_value = True
|
os.chdir(tmpdir)
|
||||||
handler.is_staged.return_value = True
|
os.system('git init')
|
||||||
handler.get_staged_changes.return_value = """diff --git a/src/main.py b/src/main.py
|
yield tmpdir
|
||||||
index 1234567..abcdefg 100644
|
|
||||||
--- a/src/main.py
|
|
||||||
+++ b/src/main.py
|
|
||||||
@@ -1,3 +1,4 @@
|
|
||||||
+import new_module
|
|
||||||
def hello():
|
|
||||||
print("Hello, World!")
|
|
||||||
"""
|
|
||||||
handler.get_commit_history.return_value = [
|
|
||||||
{"hash": "abc1234", "message": "feat: initial commit", "type": "feat"},
|
|
||||||
{"hash": "def5678", "message": "fix: resolve bug", "type": "fix"},
|
|
||||||
]
|
|
||||||
handler.get_staged_files.return_value = ["src/main.py"]
|
|
||||||
handler.get_changed_languages.return_value = ["Python"]
|
|
||||||
return handler
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_ollama_client():
|
|
||||||
"""Create a mock Ollama client."""
|
|
||||||
client = MagicMock()
|
|
||||||
client.is_available.return_value = True
|
|
||||||
client.check_model_exists.return_value = True
|
|
||||||
client.list_models.return_value = [
|
|
||||||
{"name": "qwen2.5-coder:3b", "size": 2000000000},
|
|
||||||
{"name": "llama3:8b", "size": 4000000000},
|
|
||||||
]
|
|
||||||
client.generate_commit_message.return_value = "feat: add new feature"
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_config():
|
|
||||||
"""Create a mock configuration."""
|
|
||||||
config = MagicMock()
|
|
||||||
config.ollama_model = "qwen2.5-coder:3b"
|
|
||||||
config.ollama_base_url = "http://localhost:11434"
|
|
||||||
config.ollama_timeout = 120
|
|
||||||
config.max_message_length = 80
|
|
||||||
config.num_suggestions = 3
|
|
||||||
config.conventional_by_default = False
|
|
||||||
config.cache_enabled = True
|
|
||||||
config.cache_directory = ".git-commit-ai/cache"
|
|
||||||
config.cache_ttl_hours = 24
|
|
||||||
config.prompt_directory = ".git-commit-ai/prompts"
|
|
||||||
config.show_diff = False
|
|
||||||
config.interactive = False
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def temp_git_repo(tmp_path):
|
|
||||||
"""Create a temporary git repository for testing."""
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
repo_dir = tmp_path / "test_repo"
|
|
||||||
repo_dir.mkdir()
|
|
||||||
|
|
||||||
os.chdir(repo_dir)
|
|
||||||
|
|
||||||
subprocess.run(["git", "init"], capture_output=True, check=True)
|
|
||||||
subprocess.run(["git", "config", "user.email", "test@example.com"], capture_output=True, check=True)
|
|
||||||
subprocess.run(["git", "config", "user.name", "Test User"], capture_output=True, check=True)
|
|
||||||
|
|
||||||
(repo_dir / "README.md").write_text("# Test Project\n")
|
|
||||||
subprocess.run(["git", "add", "."], capture_output=True, check=True)
|
|
||||||
subprocess.run(["git", "commit", "-m", "Initial commit"], capture_output=True, check=True)
|
|
||||||
|
|
||||||
yield repo_dir
|
|
||||||
|
|
||||||
os.chdir(tmp_path)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_diff():
|
def sample_diff():
|
||||||
"""Provide a sample git diff for testing."""
|
"""Sample diff for testing."""
|
||||||
return """diff --git a/src/auth.py b/src/auth.py
|
return """diff --git a/main.py b/main.py
|
||||||
index 1234567..abcdefg 100644
|
index 1234567..abcdefg 100644
|
||||||
--- a/src/auth.py
|
--- a/main.py
|
||||||
+++ b/src/auth.py
|
+++ b/main.py
|
||||||
@@ -1,3 +1,4 @@
|
@@ -1,3 +1,4 @@
|
||||||
+from datetime import datetime
|
def hello():
|
||||||
def authenticate(user_id):
|
+ print("Hello, World!")
|
||||||
if user_id is None:
|
return "Hello"
|
||||||
return False
|
|
||||||
+ return datetime.now()
|
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,65 +1,17 @@
|
|||||||
"""Tests for the CLI module."""
|
import pytest
|
||||||
|
|
||||||
import sys
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import click
|
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
from git_commit_ai.cli.cli import cli, generate
|
||||||
|
|
||||||
from git_commit_ai.cli.cli import main
|
def test_cli_group():
|
||||||
|
"""Test CLI group creation."""
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ['--help'])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert 'AI-powered Git commit message generator' in result.output
|
||||||
|
|
||||||
|
def test_generate_command_exists():
|
||||||
class TestCLIBasic:
|
"""Test generate command exists."""
|
||||||
"""Basic CLI tests."""
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ['generate', '--help'])
|
||||||
def test_main_help(self):
|
assert result.exit_code == 0
|
||||||
"""Test main command help."""
|
assert 'Generate a commit message for staged changes' in result.output
|
||||||
runner = CliRunner()
|
|
||||||
result = runner.invoke(main, ["--help"])
|
|
||||||
assert result.exit_code == 0
|
|
||||||
assert "Git Commit AI" in result.output
|
|
||||||
assert "generate" in result.output
|
|
||||||
assert "status" in result.output
|
|
||||||
|
|
||||||
def test_generate_help(self):
|
|
||||||
"""Test generate command help."""
|
|
||||||
runner = CliRunner()
|
|
||||||
result = runner.invoke(main, ["generate", "--help"])
|
|
||||||
assert result.exit_code == 0
|
|
||||||
assert "conventional" in result.output
|
|
||||||
assert "model" in result.output
|
|
||||||
|
|
||||||
|
|
||||||
class TestCLIValidation:
|
|
||||||
"""CLI validation tests."""
|
|
||||||
|
|
||||||
def test_validate_valid_message(self):
|
|
||||||
"""Test validating a valid conventional commit message."""
|
|
||||||
runner = CliRunner()
|
|
||||||
result = runner.invoke(main, ["validate", "feat(auth): add login"])
|
|
||||||
assert result.exit_code == 0
|
|
||||||
assert "Valid" in result.output
|
|
||||||
|
|
||||||
def test_validate_invalid_message(self):
|
|
||||||
"""Test validating an invalid commit message."""
|
|
||||||
runner = CliRunner()
|
|
||||||
result = runner.invoke(main, ["validate", "just a random message"])
|
|
||||||
assert result.exit_code == 1
|
|
||||||
assert "Invalid" in result.output
|
|
||||||
|
|
||||||
def test_validate_empty_message(self):
|
|
||||||
"""Test validating an empty commit message."""
|
|
||||||
runner = CliRunner()
|
|
||||||
result = runner.invoke(main, ["validate", ""])
|
|
||||||
assert result.exit_code == 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestCLIAutoFix:
|
|
||||||
"""CLI auto-fix tests."""
|
|
||||||
|
|
||||||
def test_validate_auto_fix(self):
|
|
||||||
"""Test auto-fix suggestion."""
|
|
||||||
runner = CliRunner()
|
|
||||||
result = runner.invoke(main, ["validate", "add login feature", "--auto-fix"])
|
|
||||||
assert result.exit_code == 1
|
|
||||||
assert "Suggested fix" in result.output
|
|
||||||
|
|||||||
22
git_commit_ai/tests/test_config.py
Normal file
22
git_commit_ai/tests/test_config.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from git_commit_ai.core.config import load_config
|
||||||
|
|
||||||
|
def test_load_config_no_file():
|
||||||
|
"""Test loading config when no config file exists."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
os.chdir(tmpdir)
|
||||||
|
config = load_config()
|
||||||
|
assert config == {}
|
||||||
|
|
||||||
|
def test_load_config_with_file():
|
||||||
|
"""Test loading config from file."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
os.chdir(tmpdir)
|
||||||
|
os.makedirs('.git-commit-ai')
|
||||||
|
with open('.git-commit-ai/config.yaml', 'w') as f:
|
||||||
|
f.write('model: llama3\nconventional: true')
|
||||||
|
config = load_config()
|
||||||
|
assert config['model'] == 'llama3'
|
||||||
|
assert config['conventional'] is True
|
||||||
@@ -1,131 +1,21 @@
|
|||||||
"""Tests for conventional commit validation."""
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from git_commit_ai.core.conventional import validate_conventional, fix_conventional
|
||||||
|
|
||||||
from git_commit_ai.core.conventional import (
|
def test_validate_conventional_valid():
|
||||||
ConventionalCommitParser,
|
"""Test valid conventional commit."""
|
||||||
ConventionalCommitFixer,
|
message = "feat(auth): add login functionality"
|
||||||
validate_commit_message,
|
is_valid, _ = validate_conventional(message)
|
||||||
format_conventional,
|
assert is_valid is True
|
||||||
extract_conventional_parts,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
def test_validate_conventional_invalid():
|
||||||
|
"""Test invalid conventional commit."""
|
||||||
|
message = "just a regular commit"
|
||||||
|
is_valid, _ = validate_conventional(message)
|
||||||
|
assert is_valid is False
|
||||||
|
|
||||||
class TestConventionalCommitParser:
|
def test_fix_conventional():
|
||||||
"""Tests for ConventionalCommitParser."""
|
"""Test conventional commit fixing."""
|
||||||
|
message = "added new feature"
|
||||||
def test_parse_valid_message(self):
|
diff = ["main.py"]
|
||||||
message = "feat(auth): add user authentication"
|
fixed = fix_conventional(message, diff)
|
||||||
parsed = ConventionalCommitParser.parse(message)
|
assert fixed.startswith("feat:")
|
||||||
assert parsed is not None
|
|
||||||
assert parsed.type == "feat"
|
|
||||||
assert parsed.scope == "auth"
|
|
||||||
assert parsed.description == "add user authentication"
|
|
||||||
|
|
||||||
def test_parse_without_scope(self):
|
|
||||||
message = "fix: resolve memory leak"
|
|
||||||
parsed = ConventionalCommitParser.parse(message)
|
|
||||||
assert parsed is not None
|
|
||||||
assert parsed.type == "fix"
|
|
||||||
assert parsed.scope is None
|
|
||||||
|
|
||||||
def test_parse_invalid_message(self):
|
|
||||||
message = "just a random message"
|
|
||||||
parsed = ConventionalCommitParser.parse(message)
|
|
||||||
assert parsed is None
|
|
||||||
|
|
||||||
def test_is_valid(self):
|
|
||||||
assert ConventionalCommitParser.is_valid("feat: new feature") is True
|
|
||||||
assert ConventionalCommitParser.is_valid("invalid message") is False
|
|
||||||
|
|
||||||
def test_validate_valid(self):
|
|
||||||
errors = ConventionalCommitParser.validate("feat(auth): add login")
|
|
||||||
assert len(errors) == 0
|
|
||||||
|
|
||||||
def test_validate_invalid_type(self):
|
|
||||||
errors = ConventionalCommitParser.validate("invalid(scope): desc")
|
|
||||||
assert len(errors) > 0
|
|
||||||
assert any("Invalid type" in e for e in errors)
|
|
||||||
|
|
||||||
def test_validate_empty_message(self):
|
|
||||||
errors = ConventionalCommitParser.validate("")
|
|
||||||
assert len(errors) > 0
|
|
||||||
|
|
||||||
|
|
||||||
class TestConventionalCommitFixer:
|
|
||||||
"""Tests for ConventionalCommitFixer."""
|
|
||||||
|
|
||||||
def test_fix_simple_message(self):
|
|
||||||
diff = """+++ b/src/auth.py
|
|
||||||
@@ -1,3 +1,4 @@
|
|
||||||
+def login():
|
|
||||||
+ pass
|
|
||||||
"""
|
|
||||||
fixed = ConventionalCommitFixer.fix("add login feature", diff)
|
|
||||||
assert fixed.startswith("feat:")
|
|
||||||
|
|
||||||
def test_fix_with_type_detection(self):
|
|
||||||
diff = """--- a/src/bug.py
|
|
||||||
+++ b/src/bug.py
|
|
||||||
@@ -1,3 +1,4 @@
|
|
||||||
-def calculate():
|
|
||||||
+def calculate():
|
|
||||||
return 1 / 0
|
|
||||||
+ return 1 / 1
|
|
||||||
"""
|
|
||||||
fixed = ConventionalCommitFixer.fix("fix bug", diff)
|
|
||||||
assert fixed.startswith("fix:")
|
|
||||||
|
|
||||||
def test_fix_preserves_description(self):
|
|
||||||
diff = """+++ b/src/auth.py
|
|
||||||
@@ -1,3 +1,4 @@
|
|
||||||
+def login():
|
|
||||||
"""
|
|
||||||
fixed = ConventionalCommitFixer.fix("add login functionality", diff)
|
|
||||||
assert "login" in fixed.lower()
|
|
||||||
|
|
||||||
|
|
||||||
class TestValidateCommitMessage:
|
|
||||||
"""Tests for validate_commit_message function."""
|
|
||||||
|
|
||||||
def test_validate_valid(self):
|
|
||||||
is_valid, errors = validate_commit_message("feat(auth): add login")
|
|
||||||
assert is_valid is True
|
|
||||||
assert len(errors) == 0
|
|
||||||
|
|
||||||
def test_validate_invalid(self):
|
|
||||||
is_valid, errors = validate_commit_message("invalid")
|
|
||||||
assert is_valid is False
|
|
||||||
assert len(errors) > 0
|
|
||||||
|
|
||||||
|
|
||||||
class TestFormatConventional:
|
|
||||||
"""Tests for format_conventional function."""
|
|
||||||
|
|
||||||
def test_format_with_type_and_scope(self):
|
|
||||||
result = format_conventional("add login", "feat", "auth")
|
|
||||||
assert result == "feat(auth): add login"
|
|
||||||
|
|
||||||
def test_format_with_type_only(self):
|
|
||||||
result = format_conventional("fix bug", "fix")
|
|
||||||
assert result == "fix: fix bug"
|
|
||||||
|
|
||||||
def test_format_already_formatted(self):
|
|
||||||
result = format_conventional("feat(auth): add login", "feat", "auth")
|
|
||||||
assert result == "feat(auth): add login"
|
|
||||||
|
|
||||||
|
|
||||||
class TestExtractConventionalParts:
|
|
||||||
"""Tests for extract_conventional_parts function."""
|
|
||||||
|
|
||||||
def test_extract_all_parts(self):
|
|
||||||
result = extract_conventional_parts("feat(auth): add login")
|
|
||||||
assert result["type"] == "feat"
|
|
||||||
assert result["scope"] == "auth"
|
|
||||||
assert result["description"] == "add login"
|
|
||||||
|
|
||||||
def test_extract_invalid_message(self):
|
|
||||||
result = extract_conventional_parts("invalid message")
|
|
||||||
assert result["type"] is None
|
|
||||||
assert result["scope"] is None
|
|
||||||
assert result["description"] == "invalid message"
|
|
||||||
|
|||||||
@@ -1,137 +1,21 @@
|
|||||||
"""Tests for the Git handler module."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from git.exc import GitCommandError
|
import subprocess
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from git_commit_ai.core.git_handler import get_staged_changes, get_commit_history
|
||||||
|
|
||||||
from git_commit_ai.core.git_handler import GitHandler, GitError, get_git_handler
|
def test_get_commit_history_empty():
|
||||||
|
"""Test get commit history when no commits exist."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
os.chdir(tmpdir)
|
||||||
|
subprocess.run(['git', 'init'], capture_output=True)
|
||||||
|
history = get_commit_history()
|
||||||
|
assert history == []
|
||||||
|
|
||||||
|
def test_get_staged_changes_empty_repo():
|
||||||
class TestGitHandlerBasic:
|
"""Test get staged changes in empty repository."""
|
||||||
"""Basic Git handler tests."""
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
os.chdir(tmpdir)
|
||||||
def test_is_repository_true(self, temp_git_repo):
|
subprocess.run(['git', 'init'], capture_output=True)
|
||||||
"""Test is_repository returns True for git repo."""
|
changes = get_staged_changes()
|
||||||
handler = GitHandler(str(temp_git_repo))
|
assert changes == []
|
||||||
assert handler.is_repository() is True
|
|
||||||
|
|
||||||
def test_is_repository_false(self, tmp_path):
|
|
||||||
"""Test is_repository returns False for non-git directory."""
|
|
||||||
handler = GitHandler(str(tmp_path))
|
|
||||||
assert handler.is_repository() is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestGitHandlerStagedChanges:
|
|
||||||
"""Tests for staged changes functionality."""
|
|
||||||
|
|
||||||
def test_is_staged_true(self, temp_git_repo):
|
|
||||||
"""Test is_staged returns True when changes are staged."""
|
|
||||||
handler = GitHandler(str(temp_git_repo))
|
|
||||||
|
|
||||||
test_file = temp_git_repo / "test.py"
|
|
||||||
test_file.write_text("print('test')")
|
|
||||||
os.system(f"git add {test_file}")
|
|
||||||
|
|
||||||
assert handler.is_staged() is True
|
|
||||||
|
|
||||||
def test_get_staged_changes(self, temp_git_repo):
|
|
||||||
"""Test getting staged changes."""
|
|
||||||
handler = GitHandler(str(temp_git_repo))
|
|
||||||
|
|
||||||
test_file = temp_git_repo / "test.py"
|
|
||||||
test_file.write_text("print('test')")
|
|
||||||
os.system(f"git add {test_file}")
|
|
||||||
|
|
||||||
diff = handler.get_staged_changes()
|
|
||||||
assert diff != ""
|
|
||||||
assert "test.py" in diff
|
|
||||||
|
|
||||||
|
|
||||||
class TestGitHandlerCommitHistory:
|
|
||||||
"""Tests for commit history functionality."""
|
|
||||||
|
|
||||||
def test_get_commit_history(self, temp_git_repo):
|
|
||||||
"""Test getting commit history."""
|
|
||||||
handler = GitHandler(str(temp_git_repo))
|
|
||||||
commits = handler.get_commit_history(max_commits=5)
|
|
||||||
|
|
||||||
assert len(commits) >= 1
|
|
||||||
assert any(c["message"] == "Initial commit" for c in commits)
|
|
||||||
|
|
||||||
def test_get_commit_history_conventional_only(self, temp_git_repo):
|
|
||||||
"""Test getting only conventional commits."""
|
|
||||||
handler = GitHandler(str(temp_git_repo))
|
|
||||||
commits = handler.get_commit_history(max_commits=10, conventional_only=True)
|
|
||||||
|
|
||||||
for commit in commits:
|
|
||||||
assert commit["type"] != "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
class TestGitHandlerLanguageDetection:
|
|
||||||
"""Tests for language detection functionality."""
|
|
||||||
|
|
||||||
def test_get_changed_languages_python(self, temp_git_repo):
|
|
||||||
"""Test detecting Python files."""
|
|
||||||
handler = GitHandler(str(temp_git_repo))
|
|
||||||
|
|
||||||
test_file = temp_git_repo / "test.py"
|
|
||||||
test_file.write_text("print('hello')")
|
|
||||||
os.system(f"git add {test_file}")
|
|
||||||
|
|
||||||
languages = handler.get_changed_languages()
|
|
||||||
assert "Python" in languages
|
|
||||||
|
|
||||||
def test_get_changed_languages_multiple(self, temp_git_repo):
|
|
||||||
"""Test detecting multiple languages."""
|
|
||||||
handler = GitHandler(str(temp_git_repo))
|
|
||||||
|
|
||||||
py_file = temp_git_repo / "test.py"
|
|
||||||
py_file.write_text("print('hello')")
|
|
||||||
js_file = temp_git_repo / "test.js"
|
|
||||||
js_file.write_text("console.log('hello')")
|
|
||||||
|
|
||||||
os.system(f"git add {py_file} {js_file}")
|
|
||||||
|
|
||||||
languages = handler.get_changed_languages()
|
|
||||||
assert "Python" in languages
|
|
||||||
assert "JavaScript" in languages
|
|
||||||
|
|
||||||
|
|
||||||
class TestGitHandlerHelpers:
|
|
||||||
"""Tests for helper methods."""
|
|
||||||
|
|
||||||
def test_get_staged_files(self, temp_git_repo):
|
|
||||||
"""Test getting staged files list."""
|
|
||||||
handler = GitHandler(str(temp_git_repo))
|
|
||||||
|
|
||||||
test_file = temp_git_repo / "test.py"
|
|
||||||
test_file.write_text("print('test')")
|
|
||||||
os.system(f"git add {test_file}")
|
|
||||||
|
|
||||||
files = handler.get_staged_files()
|
|
||||||
assert "test.py" in [f for f in files if "test.py" in f]
|
|
||||||
|
|
||||||
def test_get_diff_summary(self, temp_git_repo):
|
|
||||||
"""Test getting diff summary."""
|
|
||||||
handler = GitHandler(str(temp_git_repo))
|
|
||||||
|
|
||||||
test_file = temp_git_repo / "test.py"
|
|
||||||
test_file.write_text("print('test')")
|
|
||||||
os.system(f"git add {test_file}")
|
|
||||||
|
|
||||||
summary = handler.get_diff_summary()
|
|
||||||
assert "Files changed" in summary
|
|
||||||
assert "Python" in summary
|
|
||||||
|
|
||||||
|
|
||||||
class TestGitError:
|
|
||||||
"""Tests for GitError exception."""
|
|
||||||
|
|
||||||
def test_git_error_raised(self):
|
|
||||||
"""Test GitError is raised on git errors."""
|
|
||||||
with pytest.raises(GitError):
|
|
||||||
handler = GitHandler("/nonexistent/path")
|
|
||||||
handler.repo
|
|
||||||
|
|||||||
@@ -1,141 +1,20 @@
|
|||||||
"""Tests for the Ollama client module."""
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from git_commit_ai.core.ollama_client import generate_commit_message
|
||||||
|
|
||||||
from git_commit_ai.core.ollama_client import OllamaClient, OllamaError, generate_diff_hash
|
def test_generate_commit_message_success():
|
||||||
|
"""Test successful commit message generation."""
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.json.return_value = {'response': 'fix: resolve bug'}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with patch('requests.post', return_value=mock_response):
|
||||||
|
result = generate_commit_message("test prompt")
|
||||||
|
assert result == 'fix: resolve bug'
|
||||||
|
|
||||||
|
def test_generate_commit_message_connection_error():
|
||||||
class TestOllamaClientBasic:
|
"""Test connection error handling."""
|
||||||
"""Basic Ollama client tests."""
|
with patch('requests.post') as mock_post:
|
||||||
|
mock_post.side_effect = Exception("Connection failed")
|
||||||
def test_init(self, mock_config):
|
with pytest.raises(ConnectionError):
|
||||||
"""Test client initialization."""
|
generate_commit_message("test prompt")
|
||||||
client = OllamaClient(mock_config)
|
|
||||||
assert client.model == "qwen2.5-coder:3b"
|
|
||||||
assert client.base_url == "http://localhost:11434"
|
|
||||||
|
|
||||||
def test_model_setter(self, mock_config):
|
|
||||||
"""Test model setter."""
|
|
||||||
client = OllamaClient(mock_config)
|
|
||||||
client.model = "llama3:8b"
|
|
||||||
assert client.model == "llama3:8b"
|
|
||||||
|
|
||||||
def test_base_url_setter(self, mock_config):
|
|
||||||
"""Test base URL setter."""
|
|
||||||
client = OllamaClient(mock_config)
|
|
||||||
client.base_url = "http://localhost:11435"
|
|
||||||
assert client.base_url == "http://localhost:11435"
|
|
||||||
|
|
||||||
|
|
||||||
class TestOllamaClientAvailability:
|
|
||||||
"""Tests for Ollama availability checks."""
|
|
||||||
|
|
||||||
def test_is_available_true(self, mock_config):
|
|
||||||
"""Test is_available returns True when server is up."""
|
|
||||||
with patch('requests.get') as mock_get:
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.status_code = 200
|
|
||||||
mock_get.return_value = mock_response
|
|
||||||
|
|
||||||
client = OllamaClient(mock_config)
|
|
||||||
assert client.is_available() is True
|
|
||||||
|
|
||||||
def test_is_available_false(self, mock_config):
|
|
||||||
"""Test is_available returns False when server is down."""
|
|
||||||
with patch('requests.get') as mock_get:
|
|
||||||
mock_get.side_effect = Exception("Connection refused")
|
|
||||||
|
|
||||||
client = OllamaClient(mock_config)
|
|
||||||
assert client.is_available() is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestOllamaClientModels:
|
|
||||||
"""Tests for model-related functionality."""
|
|
||||||
|
|
||||||
def test_list_models(self, mock_config):
|
|
||||||
"""Test listing available models."""
|
|
||||||
with patch('requests.get') as mock_get:
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.status_code = 200
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"models": [
|
|
||||||
{"name": "qwen2.5-coder:3b", "size": 2000000000},
|
|
||||||
{"name": "llama3:8b", "size": 4000000000},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
mock_get.return_value = mock_response
|
|
||||||
|
|
||||||
client = OllamaClient(mock_config)
|
|
||||||
models = client.list_models()
|
|
||||||
|
|
||||||
assert len(models) == 2
|
|
||||||
assert models[0]["name"] == "qwen2.5-coder:3b"
|
|
||||||
|
|
||||||
def test_check_model_exists_true(self, mock_config):
|
|
||||||
"""Test checking if model exists."""
|
|
||||||
with patch.object(OllamaClient, 'list_models') as mock_list:
|
|
||||||
mock_list.return_value = [{"name": "qwen2.5-coder:3b", "size": 2000000000}]
|
|
||||||
|
|
||||||
client = OllamaClient(mock_config)
|
|
||||||
assert client.check_model_exists() is True
|
|
||||||
|
|
||||||
def test_check_model_exists_false(self, mock_config):
|
|
||||||
"""Test checking if model doesn't exist."""
|
|
||||||
with patch.object(OllamaClient, 'list_models') as mock_list:
|
|
||||||
mock_list.return_value = [{"name": "llama3:8b", "size": 4000000000}]
|
|
||||||
|
|
||||||
client = OllamaClient(mock_config)
|
|
||||||
assert client.check_model_exists() is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestOllamaClientGeneration:
|
|
||||||
"""Tests for commit message generation."""
|
|
||||||
|
|
||||||
def test_parse_commit_message_simple(self, mock_config):
|
|
||||||
"""Test parsing a simple commit message."""
|
|
||||||
client = OllamaClient(mock_config)
|
|
||||||
response = "feat: add new feature"
|
|
||||||
parsed = client._parse_commit_message(response)
|
|
||||||
assert parsed == "feat: add new feature"
|
|
||||||
|
|
||||||
def test_parse_commit_message_with_quotes(self, mock_config):
|
|
||||||
"""Test parsing a quoted commit message."""
|
|
||||||
client = OllamaClient(mock_config)
|
|
||||||
response = '"feat: add new feature"'
|
|
||||||
parsed = client._parse_commit_message(response)
|
|
||||||
assert parsed == "feat: add new feature"
|
|
||||||
|
|
||||||
def test_parse_commit_message_truncates_long(self, mock_config):
|
|
||||||
"""Test parsing truncates long messages."""
|
|
||||||
client = OllamaClient(mock_config)
|
|
||||||
long_message = "a" * 100
|
|
||||||
parsed = client._parse_commit_message(long_message)
|
|
||||||
assert len(parsed) <= 80
|
|
||||||
|
|
||||||
|
|
||||||
class TestGenerateDiffHash:
|
|
||||||
"""Tests for generate_diff_hash function."""
|
|
||||||
|
|
||||||
def test_generate_diff_hash(self):
|
|
||||||
"""Test generating diff hash."""
|
|
||||||
diff1 = "def hello():\n print('hi')"
|
|
||||||
diff2 = "def hello():\n print('hi')"
|
|
||||||
diff3 = "def goodbye():\n print('bye')"
|
|
||||||
|
|
||||||
hash1 = generate_diff_hash(diff1)
|
|
||||||
hash2 = generate_diff_hash(diff2)
|
|
||||||
hash3 = generate_diff_hash(diff3)
|
|
||||||
|
|
||||||
assert hash1 == hash2
|
|
||||||
assert hash1 != hash3
|
|
||||||
|
|
||||||
|
|
||||||
class TestOllamaError:
|
|
||||||
"""Tests for OllamaError exception."""
|
|
||||||
|
|
||||||
def test_ollama_error(self):
|
|
||||||
"""Test OllamaError is raised correctly."""
|
|
||||||
with pytest.raises(OllamaError):
|
|
||||||
raise OllamaError("Test error")
|
|
||||||
|
|||||||
20
git_commit_ai/tests/test_prompt_builder.py
Normal file
20
git_commit_ai/tests/test_prompt_builder.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import pytest
|
||||||
|
from git_commit_ai.core.prompt_builder import build_prompt, DEFAULT_PROMPT, CONVENTIONAL_PROMPT
|
||||||
|
|
||||||
|
def test_build_prompt_default():
|
||||||
|
"""Test default prompt building."""
|
||||||
|
prompt = build_prompt("test diff")
|
||||||
|
assert "test diff" in prompt
|
||||||
|
assert "No previous commits" in prompt
|
||||||
|
|
||||||
|
def test_build_prompt_with_history():
|
||||||
|
"""Test prompt building with history."""
|
||||||
|
prompt = build_prompt("test diff", history=["feat: add x", "fix: resolve y"])
|
||||||
|
assert "feat: add x" in prompt
|
||||||
|
assert "fix: resolve y" in prompt
|
||||||
|
|
||||||
|
def test_build_prompt_conventional():
|
||||||
|
"""Test conventional prompt building."""
|
||||||
|
prompt = build_prompt("test diff", conventional=True)
|
||||||
|
assert "conventional" in prompt.lower()
|
||||||
|
assert "feat" in prompt.lower() or "fix" in prompt.lower()
|
||||||
@@ -59,13 +59,13 @@ omit = ["git_commit_ai/tests/*"]
|
|||||||
[tool.coverage.report]
|
[tool.coverage.report]
|
||||||
exclude_lines = ["pragma: no cover", "def __repr__", "raise AssertionError", "raise NotImplementedError"]
|
exclude_lines = ["pragma: no cover", "def __repr__", "raise AssertionError", "raise NotImplementedError"]
|
||||||
|
|
||||||
[tool.black]
|
|
||||||
line-length = 100
|
|
||||||
target-version = ['py39']
|
|
||||||
include = '\.pyi?$'
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 100
|
line-length = 100
|
||||||
target-version = "py39"
|
target-version = "py39"
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
select = ["E", "F", "W", "I", "B", "C4", "UP", "ARG", "SIM"]
|
select = ["E", "F", "W", "I", "B", "C4", "UP", "ARG", "SIM"]
|
||||||
ignore = ["E501", "B008", "C901"]
|
ignore = ["E501", "B008", "C901"]
|
||||||
|
|
||||||
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
"git_commit_ai/tests/*" = ["ARG", "S"]
|
||||||
|
|||||||
@@ -3,5 +3,3 @@ ollama>=0.1
|
|||||||
gitpython>=3.1
|
gitpython>=3.1
|
||||||
pyyaml>=6.0
|
pyyaml>=6.0
|
||||||
requests>=2.31
|
requests>=2.31
|
||||||
pytest>=7.0
|
|
||||||
pytest-cov>=4.0
|
|
||||||
|
|||||||
19
setup.py
Normal file
19
setup.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name="git-commit-ai",
|
||||||
|
version="0.1.0",
|
||||||
|
packages=find_packages(),
|
||||||
|
install_requires=[
|
||||||
|
"click>=8.0",
|
||||||
|
"ollama>=0.1",
|
||||||
|
"gitpython>=3.1",
|
||||||
|
"pyyaml>=6.0",
|
||||||
|
"requests>=2.31",
|
||||||
|
],
|
||||||
|
entry_points={
|
||||||
|
"console_scripts": [
|
||||||
|
"git-commit-ai=git_commit_ai.cli.cli:main",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user