diff --git a/src/crypto.rs b/src/crypto.rs new file mode 100644 index 0000000..6ecbbb7 --- /dev/null +++ b/src/crypto.rs @@ -0,0 +1,121 @@ +use sodiumoxide::crypto::secretbox; +use sodiumoxide::crypto::pwhash; +use base64::{Engine as _, engine::general_purpose::STANDARD}; +use std::fmt; + +#[derive(Debug)] +pub enum CryptoError { + KeyDerivationFailed, + EncryptionFailed, + DecryptionFailed, + InvalidKey, + InvalidNonce, +} + +impl fmt::Display for CryptoError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CryptoError::KeyDerivationFailed => write!(f, "Failed to derive encryption key from password"), + CryptoError::EncryptionFailed => write!(f, "Failed to encrypt data"), + CryptoError::DecryptionFailed => write!(f, "Failed to decrypt data"), + CryptoError::InvalidKey => write!(f, "Invalid encryption key"), + CryptoError::InvalidNonce => write!(f, "Invalid nonce"), + } + } +} + +impl std::error::Error for CryptoError {} + +pub struct CryptoManager { + key: secretbox::Key, +} + +impl CryptoManager { + pub fn from_password(password: &str, salt: &[u8; pwhash::SALTBYTES]) -> Result { + let mut key = secretbox::Key::from_slice(&[0u8; secretbox::KEYBYTES]) + .ok_or(CryptoError::InvalidKey)?; + + let password_bytes = password.as_bytes(); + + let derived_key = pwhash::pwhash( + password_bytes, + salt, + pwhash::Oscillating::interactive().unwrap(), + secretbox::KEYBYTES, + ).map_err(|_| CryptoError::KeyDerivationFailed)?; + + key.as_mut_slice().copy_from_slice(&derived_key); + + Ok(CryptoManager { key }) + } + + pub fn encrypt(&self, plaintext: &[u8]) -> Result<(Vec, Vec), CryptoError> { + let nonce = secretbox::gen_nonce(); + let ciphertext = secretbox::seal(plaintext, &nonce, &self.key) + .map_err(|_| CryptoError::EncryptionFailed)?; + + Ok((nonce.as_ref().to_vec(), ciphertext)) + } + + pub fn decrypt(&self, nonce: &[u8], ciphertext: &[u8]) -> Result, CryptoError> { + let nonce = secretbox::Nonce::from_slice(nonce) + .ok_or(CryptoError::InvalidNonce)?; + + let plaintext = secretbox::open(ciphertext, &nonce, &self.key) + .map_err(|_| CryptoError::DecryptionFailed)?; + + Ok(plaintext) + } + + pub fn encrypt_base64(&self, plaintext: &str) -> Result { + let (nonce, ciphertext) = self.encrypt(plaintext.as_bytes())?; + + let mut combined = Vec::with_capacity(nonce.len() + ciphertext.len()); + combined.extend_from_slice(&nonce); + combined.extend_from_slice(&ciphertext); + + Ok(STANDARD.encode(combined)) + } + + pub fn decrypt_base64(&self, encrypted_data: &str) -> Result { + let combined = STANDARD.decode(encrypted_data) + .map_err(|_| CryptoError::DecryptionFailed)?; + + if combined.len() < secretbox::NONCEBYTES + secretbox::MACBYTES { + return Err(CryptoError::DecryptionFailed); + } + + let nonce_len = secretbox::NONCEBYTES; + let nonce = &combined[..nonce_len]; + let ciphertext = &combined[nonce_len..]; + + let plaintext = self.decrypt(nonce, ciphertext)?; + + String::from_utf8(plaintext) + .map_err(|_| CryptoError::DecryptionFailed) + } +} + +pub fn generate_salt() -> [u8; pwhash::SALTBYTES] { + let mut salt = [0u8; pwhash::SALTBYTES]; + sodiumoxide::init().unwrap(); + rand::Rng::fill(&mut rand::thread_rng(), &mut salt); + salt +} + +pub fn salt_to_base64(salt: &[u8; pwhash::SALTBYTES]) -> String { + STANDARD.encode(salt) +} + +pub fn salt_from_base64(salt_str: &str) -> Result<[u8; pwhash::SALTBYTES], CryptoError> { + let decoded = STANDARD.decode(salt_str) + .map_err(|_| CryptoError::DecryptionFailed)?; + + if decoded.len() != pwhash::SALTBYTES { + return Err(CryptoError::DecryptionFailed); + } + + let mut salt = [0u8; pwhash::SALTBYTES]; + salt.copy_from_slice(&decoded); + Ok(salt) +}