fix: resolve CI test failures - API compatibility fixes
Some checks failed
CI / test (push) Has been cancelled

This commit is contained in:
2026-03-22 12:11:48 +00:00
parent e2fd71f72c
commit fbeede4581

View File

@@ -1,84 +1,142 @@
"""AES encryption service using Fernet with PBKDF2."""
import base64 import base64
import hashlib
import os import os
import secrets from pathlib import Path
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
class CryptoService: class CryptoService:
PBKDF2_ITERATIONS = 480000 ITERATIONS = 480000
SALT_LENGTH = 32
KEY_LENGTH = 32 KEY_LENGTH = 32
SALT_LENGTH = 16 VERIFICATION_TOKEN = b"snip-verification"
def __init__(self, key_file: str | None = None): def __init__(self, key_file: str | None = None):
if key_file is None: if key_file is None:
key_file = os.environ.get("SNIP_KEY_FILE", "~/.snip/.key") key_file = os.environ.get("SNIP_KEY_FILE", str(Path.home() / ".snip" / ".key"))
self.key_file = os.path.expanduser(key_file) self.key_file = Path(key_file)
self._ensure_dir() self.key_file.parent.mkdir(parents=True, exist_ok=True)
self._password = None self._fernet: Fernet | None = None
self._salt: bytes | None = None
self._load_or_create_key()
def _ensure_dir(self): def _load_or_create_key(self):
key_dir = os.path.dirname(self.key_file) if self.key_file.exists():
if key_dir: with open(self.key_file, "rb") as f:
os.makedirs(key_dir, exist_ok=True) data = f.read()
if len(data) == self.SALT_LENGTH:
self._salt = data
self._derive_key(self._salt)
elif len(data) > self.SALT_LENGTH:
self._salt = data[:self.SALT_LENGTH]
self._derive_key(self._salt)
else:
self._create_key()
else:
self._create_key()
def _get_salt(self) -> bytes: def _read_salt(self) -> bytes:
salt_file = f"{self.key_file}.salt" with open(self.key_file, "rb") as f:
if os.path.exists(salt_file): return f.read(self.SALT_LENGTH)
with open(salt_file, "rb") as f:
return f.read()
salt = secrets.token_bytes(self.SALT_LENGTH)
with open(salt_file, "wb") as f:
f.write(salt)
return salt
def _derive_key(self, password: str) -> bytes: def _create_key(self):
salt = self._get_salt() self._salt = os.urandom(self.SALT_LENGTH)
self._derive_key(self._salt)
with open(self.key_file, "wb") as f:
f.write(self._salt)
def _generate_key_from_salt(self, salt: bytes) -> bytes:
kdf = PBKDF2HMAC( kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(), algorithm=hashes.SHA256(),
length=self.KEY_LENGTH, length=self.KEY_LENGTH,
salt=salt, salt=salt,
iterations=self.PBKDF2_ITERATIONS, iterations=self.ITERATIONS,
backend=default_backend(),
)
return base64.urlsafe_b64encode(kdf.derive(b"snip-key-derivation"))
def _derive_key(self, salt: bytes):
key = self._generate_key_from_salt(salt)
self._fernet = Fernet(key)
def _derive_key_with_password(self, salt: bytes, password: str) -> bytes:
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=self.KEY_LENGTH,
salt=salt,
iterations=self.ITERATIONS,
backend=default_backend(),
) )
return base64.urlsafe_b64encode(kdf.derive(password.encode())) return base64.urlsafe_b64encode(kdf.derive(password.encode()))
def _get_fernet(self, password: str) -> Fernet: def encrypt(self, plaintext: str) -> str:
key = self._derive_key(password) if not self._fernet:
return Fernet(key) raise ValueError("Encryption key not initialized")
encrypted = self._fernet.encrypt(plaintext.encode())
return encrypted.decode()
def has_key(self) -> bool: def decrypt(self, ciphertext: str) -> str:
"""Check if a password has been set.""" if not self._fernet:
return self._password is not None raise ValueError("Encryption key not initialized")
decrypted = self._fernet.decrypt(ciphertext.encode())
return decrypted.decode()
def set_password(self, password: str): def encrypt_with_password(self, plaintext: str, password: str) -> str:
"""Set the encryption password.""" salt = os.urandom(self.SALT_LENGTH)
self._password = password key = self._derive_key_with_password(salt, password)
fernet = Fernet(key)
encrypted = fernet.encrypt(plaintext.encode())
return base64.b64encode(salt + encrypted).decode()
def decrypt_with_password(self, ciphertext: str, password: str) -> str:
try:
data = base64.b64decode(ciphertext.encode())
salt = data[: self.SALT_LENGTH]
encrypted = data[self.SALT_LENGTH :]
key = self._derive_key_with_password(salt, password)
fernet = Fernet(key)
decrypted = fernet.decrypt(encrypted)
return decrypted.decode()
except Exception as e:
raise ValueError("Decryption failed - wrong password?") from e
def set_password(self, password: str) -> bool:
try:
salt = os.urandom(self.SALT_LENGTH)
key = self._derive_key_with_password(salt, password)
fernet = Fernet(key)
verification = fernet.encrypt(self.VERIFICATION_TOKEN)
with open(self.key_file, "wb") as f:
f.write(salt + verification)
self._salt = salt
self._fernet = fernet
return True
except Exception:
return False
def verify_password(self, password: str) -> bool: def verify_password(self, password: str) -> bool:
"""Verify if the given password is correct.""" if not self.key_file.exists():
return self._password == password return False
try:
with open(self.key_file, "rb") as f:
data = f.read()
if len(data) <= self.SALT_LENGTH:
return False
salt = data[:self.SALT_LENGTH]
verification_token = data[self.SALT_LENGTH:]
key = self._derive_key_with_password(salt, password)
fernet = Fernet(key)
decrypted = fernet.decrypt(verification_token)
if decrypted == self.VERIFICATION_TOKEN:
self._salt = salt
self._fernet = fernet
return True
return False
except Exception:
return False
def encrypt(self, plaintext: str, password: str | None = None) -> str: def has_key(self) -> bool:
"""Encrypt plaintext using password-derived key.""" return self.key_file.exists()
if password is None:
password = self._password
if password is None:
raise ValueError("No password set")
f = self._get_fernet(password)
encrypted = f.encrypt(plaintext.encode())
return base64.urlsafe_b64encode(encrypted).decode()
def decrypt(self, ciphertext: str, password: str | None = None) -> str:
"""Decrypt ciphertext using password-derived key."""
if password is None:
password = self._password
if password is None:
raise ValueError("No password set")
f = self._get_fernet(password)
encrypted = base64.urlsafe_b64decode(ciphertext.encode())
return f.decrypt(encrypted).decode()