Files
schema2mock/snip/crypto/service.py

84 lines
2.7 KiB
Python

"""AES encryption service using Fernet with PBKDF2."""
import base64
import hashlib
import os
import secrets
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
class CryptoService:
PBKDF2_ITERATIONS = 480000
KEY_LENGTH = 32
SALT_LENGTH = 16
def __init__(self, key_file: str | None = None):
if key_file is None:
key_file = os.environ.get("SNIP_KEY_FILE", "~/.snip/.key")
self.key_file = os.path.expanduser(key_file)
self._ensure_dir()
self._password = None
def _ensure_dir(self):
key_dir = os.path.dirname(self.key_file)
if key_dir:
os.makedirs(key_dir, exist_ok=True)
def _get_salt(self) -> bytes:
salt_file = f"{self.key_file}.salt"
if os.path.exists(salt_file):
with open(salt_file, "rb") as f:
return f.read()
salt = secrets.token_bytes(self.SALT_LENGTH)
with open(salt_file, "wb") as f:
f.write(salt)
return salt
def _derive_key(self, password: str) -> bytes:
salt = self._get_salt()
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=self.KEY_LENGTH,
salt=salt,
iterations=self.PBKDF2_ITERATIONS,
)
return base64.urlsafe_b64encode(kdf.derive(password.encode()))
def _get_fernet(self, password: str) -> Fernet:
key = self._derive_key(password)
return Fernet(key)
def has_key(self) -> bool:
"""Check if a password has been set."""
return self._password is not None
def set_password(self, password: str):
"""Set the encryption password."""
self._password = password
def verify_password(self, password: str) -> bool:
"""Verify if the given password is correct."""
return self._password == password
def encrypt(self, plaintext: str, password: str | None = None) -> str:
"""Encrypt plaintext using password-derived key."""
if password is None:
password = self._password
if password is None:
raise ValueError("No password set")
f = self._get_fernet(password)
encrypted = f.encrypt(plaintext.encode())
return base64.urlsafe_b64encode(encrypted).decode()
def decrypt(self, ciphertext: str, password: str | None = None) -> str:
"""Decrypt ciphertext using password-derived key."""
if password is None:
password = self._password
if password is None:
raise ValueError("No password set")
f = self._get_fernet(password)
encrypted = base64.urlsafe_b64decode(ciphertext.encode())
return f.decrypt(encrypted).decode()