"""AES encryption service using Fernet with PBKDF2.""" import base64 import hashlib import os import secrets from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC class CryptoService: PBKDF2_ITERATIONS = 480000 KEY_LENGTH = 32 SALT_LENGTH = 16 def __init__(self, key_file: str | None = None): if key_file is None: key_file = os.environ.get("SNIP_KEY_FILE", "~/.snip/.key") self.key_file = os.path.expanduser(key_file) self._ensure_dir() def _ensure_dir(self): os.makedirs(os.path.dirname(self.key_file), exist_ok=True) def _get_salt(self) -> bytes: salt_file = f"{self.key_file}.salt" if os.path.exists(salt_file): with open(salt_file, "rb") as f: return f.read() salt = secrets.token_bytes(self.SALT_LENGTH) with open(salt_file, "wb") as f: f.write(salt) return salt def _derive_key(self, password: str) -> bytes: salt = self._get_salt() kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=self.KEY_LENGTH, salt=salt, iterations=self.PBKDF2_ITERATIONS, ) return base64.urlsafe_b64encode(kdf.derive(password.encode())) def _get_fernet(self, password: str) -> Fernet: key = self._derive_key(password) return Fernet(key) def encrypt(self, plaintext: str, password: str) -> str: """Encrypt plaintext using password-derived key.""" f = self._get_fernet(password) encrypted = f.encrypt(plaintext.encode()) return base64.urlsafe_b64encode(encrypted).decode() def decrypt(self, ciphertext: str, password: str) -> str: """Decrypt ciphertext using password-derived key.""" f = self._get_fernet(password) encrypted = base64.urlsafe_b64decode(ciphertext.encode()) return f.decrypt(encrypted).decode()