diff --git a/snip/crypto/service.py b/snip/crypto/service.py new file mode 100644 index 0000000..16dae47 --- /dev/null +++ b/snip/crypto/service.py @@ -0,0 +1,61 @@ +"""AES encryption service using Fernet with PBKDF2.""" + +import base64 +import hashlib +import os +import secrets + +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + +class CryptoService: + PBKDF2_ITERATIONS = 480000 + KEY_LENGTH = 32 + SALT_LENGTH = 16 + + def __init__(self, key_file: str | None = None): + if key_file is None: + key_file = os.environ.get("SNIP_KEY_FILE", "~/.snip/.key") + self.key_file = os.path.expanduser(key_file) + self._ensure_dir() + + def _ensure_dir(self): + os.makedirs(os.path.dirname(self.key_file), exist_ok=True) + + def _get_salt(self) -> bytes: + salt_file = f"{self.key_file}.salt" + if os.path.exists(salt_file): + with open(salt_file, "rb") as f: + return f.read() + salt = secrets.token_bytes(self.SALT_LENGTH) + with open(salt_file, "wb") as f: + f.write(salt) + return salt + + def _derive_key(self, password: str) -> bytes: + salt = self._get_salt() + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=self.KEY_LENGTH, + salt=salt, + iterations=self.PBKDF2_ITERATIONS, + ) + return base64.urlsafe_b64encode(kdf.derive(password.encode())) + + def _get_fernet(self, password: str) -> Fernet: + key = self._derive_key(password) + return Fernet(key) + + def encrypt(self, plaintext: str, password: str) -> str: + """Encrypt plaintext using password-derived key.""" + f = self._get_fernet(password) + encrypted = f.encrypt(plaintext.encode()) + return base64.urlsafe_b64encode(encrypted).decode() + + def decrypt(self, ciphertext: str, password: str) -> str: + """Decrypt ciphertext using password-derived key.""" + f = self._get_fernet(password) + encrypted = base64.urlsafe_b64decode(ciphertext.encode()) + return f.decrypt(encrypted).decode()