diff --git a/snip/crypto/service.py b/snip/crypto/service.py index 16dae47..605f85e 100644 --- a/snip/crypto/service.py +++ b/snip/crypto/service.py @@ -20,9 +20,12 @@ class CryptoService: key_file = os.environ.get("SNIP_KEY_FILE", "~/.snip/.key") self.key_file = os.path.expanduser(key_file) self._ensure_dir() + self._password = None def _ensure_dir(self): - os.makedirs(os.path.dirname(self.key_file), exist_ok=True) + key_dir = os.path.dirname(self.key_file) + if key_dir: + os.makedirs(key_dir, exist_ok=True) def _get_salt(self) -> bytes: salt_file = f"{self.key_file}.salt" @@ -48,14 +51,34 @@ class CryptoService: key = self._derive_key(password) return Fernet(key) - def encrypt(self, plaintext: str, password: str) -> str: + def has_key(self) -> bool: + """Check if a password has been set.""" + return self._password is not None + + def set_password(self, password: str): + """Set the encryption password.""" + self._password = password + + def verify_password(self, password: str) -> bool: + """Verify if the given password is correct.""" + return self._password == password + + def encrypt(self, plaintext: str, password: str | None = None) -> str: """Encrypt plaintext using password-derived key.""" + if password is None: + password = self._password + if password is None: + raise ValueError("No password set") f = self._get_fernet(password) encrypted = f.encrypt(plaintext.encode()) return base64.urlsafe_b64encode(encrypted).decode() - def decrypt(self, ciphertext: str, password: str) -> str: + def decrypt(self, ciphertext: str, password: str | None = None) -> str: """Decrypt ciphertext using password-derived key.""" + if password is None: + password = self._password + if password is None: + raise ValueError("No password set") f = self._get_fernet(password) encrypted = base64.urlsafe_b64decode(ciphertext.encode()) - return f.decrypt(encrypted).decode() + return f.decrypt(encrypted).decode() \ No newline at end of file