import base64
import contextlib
import os
import tempfile
from pathlib import Path

from cryptography.fernet import Fernet
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC


class CryptoService:
    """Symmetric encryption helper built on Fernet with PBKDF2-HMAC-SHA256 keys.

    A key file stores a random salt. After ``set_password`` it stores the salt
    followed by a Fernet-encrypted verification token. Without a password, the
    Fernet key is derived from that salt and a fixed built-in secret. Anyone
    who can read the key file can therefore rebuild the key, which is why the
    file is written with owner-only permissions.
    """

    ITERATIONS = 480000
    SALT_LENGTH = 32
    KEY_LENGTH = 32
    VERIFICATION_TOKEN = b"snip-verification"

    def __init__(self, key_file: str | None = None):
        """Load the key file, creating it with a fresh salt when necessary.

        Args:
            key_file: Path of the key file. Defaults to the ``SNIP_KEY_FILE``
                environment variable, or ``~/.snip/.key`` if that is unset.
        """
        if key_file is None:
            key_file = os.environ.get("SNIP_KEY_FILE", str(Path.home() / ".snip" / ".key"))
        self.key_file = Path(key_file)
        # The 0o700 mode only applies if the final directory is created here;
        # existing directories are left untouched.
        self.key_file.parent.mkdir(parents=True, exist_ok=True, mode=0o700)
        self._fernet: Fernet | None = None
        self._salt: bytes | None = None
        self._load_or_create_key()

    def _load_or_create_key(self):
        """Initialise the default (built-in-secret) Fernet key from the key file.

        Behaviour depends on the file's contents:
        - Missing, or shorter than SALT_LENGTH bytes: replaced by a newly
          generated salt.
        - Longer, as written by ``set_password``: only the leading
          SALT_LENGTH bytes (the salt) are used.

        NOTE(review): for a password-protected file the key set here is still
        derived from the built-in secret, not the password. ``verify_password``
        must be called to switch to the password-derived key. Confirm this
        is intended.
        """
        try:
            data = self.key_file.read_bytes()
        except FileNotFoundError:
            data = b""
        if len(data) >= self.SALT_LENGTH:
            self._salt = data[: self.SALT_LENGTH]
            self._derive_key(self._salt)
        else:
            # A truncated or corrupt file is replaced. Anything encrypted under
            # the old salt can no longer be decrypted with the default key.
            self._create_key()

    def _read_salt(self) -> bytes:
        """Return up to the first SALT_LENGTH bytes of the key file."""
        with open(self.key_file, "rb") as f:
            return f.read(self.SALT_LENGTH)

    def _write_key_file(self, data: bytes) -> None:
        """Atomically replace the key file with *data*, readable by the owner only.

        Raises:
            OSError: if the temporary file cannot be created, written or renamed.
        """
        # Resolve so that a symlinked key file keeps its link; only the target
        # file is replaced.
        target = self.key_file.resolve()
        # mkstemp creates the file with mode 0o600. Writing to a sibling temp
        # file and renaming it over the target means a crash never leaves a
        # truncated key file behind.
        fd, tmp_name = tempfile.mkstemp(dir=target.parent, prefix=f"{target.name}.", suffix=".tmp")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(data)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_name, target)
        except BaseException:
            # Best-effort removal of the temp file; the original error is re-raised.
            with contextlib.suppress(OSError):
                os.unlink(tmp_name)
            raise

    def _create_key(self):
        """Generate a random salt, derive the default key from it and persist the salt."""
        salt = os.urandom(self.SALT_LENGTH)
        self._salt = salt
        self._derive_key(salt)
        self._write_key_file(salt)

    def _pbkdf2_key(self, salt: bytes, secret: bytes) -> bytes:
        """Return a urlsafe-base64 Fernet key derived from *secret* and *salt*."""
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=self.KEY_LENGTH,
            salt=salt,
            iterations=self.ITERATIONS,
            backend=default_backend(),
        )
        return base64.urlsafe_b64encode(kdf.derive(secret))

    def _generate_key_from_salt(self, salt: bytes) -> bytes:
        """Derive the default (no-password) Fernet key from *salt* and the built-in secret."""
        return self._pbkdf2_key(salt, b"snip-key-derivation")

    def _derive_key(self, salt: bytes):
        """Install the default Fernet key derived from *salt* as the active key."""
        key = self._generate_key_from_salt(salt)
        self._fernet = Fernet(key)

    def _derive_key_with_password(self, salt: bytes, password: str) -> bytes:
        """Derive a Fernet key from *password* (UTF-8 encoded) and *salt*."""
        return self._pbkdf2_key(salt, password.encode())

    def encrypt(self, plaintext: str) -> str:
        """Encrypt *plaintext* with the active key and return the Fernet token as text.

        Raises:
            ValueError: if no key has been initialised.
        """
        if not self._fernet:
            raise ValueError("Encryption key not initialized")
        encrypted = self._fernet.encrypt(plaintext.encode())
        return encrypted.decode()

    def decrypt(self, ciphertext: str) -> str:
        """Decrypt a Fernet token produced by ``encrypt`` with the active key.

        Raises:
            ValueError: if no key has been initialised.
            Errors raised by ``Fernet.decrypt`` propagate unchanged.
        """
        if not self._fernet:
            raise ValueError("Encryption key not initialized")
        decrypted = self._fernet.decrypt(ciphertext.encode())
        return decrypted.decode()

    def encrypt_with_password(self, plaintext: str, password: str) -> str:
        """Encrypt *plaintext* under a key derived from *password* and a fresh salt.

        Returns:
            Standard base64 of ``salt || fernet_token``. This does not use or
            modify the service's active key.
        """
        salt = os.urandom(self.SALT_LENGTH)
        key = self._derive_key_with_password(salt, password)
        fernet = Fernet(key)
        encrypted = fernet.encrypt(plaintext.encode())
        return base64.b64encode(salt + encrypted).decode()

    def decrypt_with_password(self, ciphertext: str, password: str) -> str:
        """Reverse ``encrypt_with_password``.

        Raises:
            ValueError: on any failure (bad base64, wrong password, tampered
                data), chained from the underlying exception.
        """
        try:
            data = base64.b64decode(ciphertext.encode())
            salt = data[: self.SALT_LENGTH]
            encrypted = data[self.SALT_LENGTH :]
            key = self._derive_key_with_password(salt, password)
            fernet = Fernet(key)
            decrypted = fernet.decrypt(encrypted)
            return decrypted.decode()
        except Exception as e:
            raise ValueError("Decryption failed - wrong password?") from e

    def set_password(self, password: str) -> bool:
        """Re-key the service with *password* and persist the salt and verification token.

        The key file is replaced atomically, so a failure leaves the previous
        file intact. Data encrypted with the previous active key is not
        re-encrypted.

        Returns:
            True on success, False if any step failed. On failure the in-memory
            key is left unchanged.
        """
        try:
            salt = os.urandom(self.SALT_LENGTH)
            key = self._derive_key_with_password(salt, password)
            fernet = Fernet(key)
            verification = fernet.encrypt(self.VERIFICATION_TOKEN)
            self._write_key_file(salt + verification)
            self._salt = salt
            self._fernet = fernet
            return True
        except Exception:
            return False

    def verify_password(self, password: str) -> bool:
        """Check *password* against the stored verification token.

        On success the password-derived key becomes the active key.

        Returns:
            True if the token decrypts to VERIFICATION_TOKEN, otherwise False.
            False is also returned when the file is missing, holds no token,
            or cannot be read.
        """
        if not self.key_file.exists():
            return False
        try:
            data = self.key_file.read_bytes()
            if len(data) <= self.SALT_LENGTH:
                # Only a salt is present: no password has been set.
                return False
            salt = data[: self.SALT_LENGTH]
            verification_token = data[self.SALT_LENGTH :]
            key = self._derive_key_with_password(salt, password)
            fernet = Fernet(key)
            decrypted = fernet.decrypt(verification_token)
            if decrypted == self.VERIFICATION_TOKEN:
                self._salt = salt
                self._fernet = fernet
                return True
            return False
        except Exception:
            return False

    def has_key(self) -> bool:
        """Return True if the key file exists on disk."""
        return self.key_file.exists()