1
use openssl::{
2
    bn::BigNum,
3
    dh::Dh,
4
    hash::{Hasher, MessageDigest, hash},
5
    md::Md,
6
    memcmp,
7
    nid::Nid,
8
    pkcs5::pbkdf2_hmac,
9
    pkey::{Id, PKey},
10
    pkey_ctx::PkeyCtx,
11
    rand::rand_bytes,
12
    sign::Signer,
13
    symm::{Cipher, Crypter, Mode},
14
};
15
use zeroize::Zeroizing;
16

            
17
use crate::{Key, Mac, file};
18

            
19
const ENC_ALG: Nid = Nid::AES_128_CBC;
20
const MAC_ALG: Nid = Nid::SHA256;
21

            
22
7
pub fn encrypt(
23
    data: impl AsRef<[u8]>,
24
    key: &Key,
25
    iv: impl AsRef<[u8]>,
26
) -> Result<Vec<u8>, super::Error> {
27
14
    let cipher = Cipher::from_nid(ENC_ALG).unwrap();
28
14
    let mut encryptor = Crypter::new(cipher, Mode::Encrypt, key.as_ref(), Some(iv.as_ref()))
29
        .expect("Invalid key or IV length");
30
7
    encryptor.pad(true);
31

            
32
7
    let mut blob = vec![0; data.as_ref().len() + cipher.block_size()];
33
    // Unwrapping since adding `CIPHER_BLOCK_SIZE` to array is enough space for
34
    // PKCS7
35
14
    let mut encrypted_len = encryptor.update(data.as_ref(), &mut blob)?;
36
7
    encrypted_len += encryptor.finalize(&mut blob[encrypted_len..])?;
37

            
38
7
    blob.truncate(encrypted_len);
39

            
40
7
    Ok(blob)
41
}
42

            
43
8
fn decrypt_with_padding(
44
    blob: impl AsRef<[u8]>,
45
    key: &Key,
46
    iv: impl AsRef<[u8]>,
47
    pad: bool,
48
) -> Result<Zeroizing<Vec<u8>>, super::Error> {
49
16
    let cipher = Cipher::from_nid(ENC_ALG).unwrap();
50
16
    let mut decrypter = Crypter::new(cipher, Mode::Decrypt, key.as_ref(), Some(iv.as_ref()))
51
        .expect("Invalid key or IV length");
52
8
    decrypter.pad(pad);
53

            
54
8
    let mut data = Zeroizing::new(vec![0; blob.as_ref().len() + cipher.block_size()]);
55
16
    let mut decrypted_len = decrypter.update(blob.as_ref(), &mut data)?;
56
8
    decrypted_len += decrypter.finalize(&mut data[decrypted_len..])?;
57

            
58
16
    data.truncate(decrypted_len);
59

            
60
8
    Ok(data)
61
}
62

            
63
6
pub fn decrypt(
64
    blob: impl AsRef<[u8]>,
65
    key: &Key,
66
    iv: impl AsRef<[u8]>,
67
) -> Result<Zeroizing<Vec<u8>>, super::Error> {
68
6
    decrypt_with_padding(blob, key, iv, true)
69
}
70

            
71
2
pub(crate) fn decrypt_no_padding(
72
    blob: impl AsRef<[u8]>,
73
    key: &Key,
74
    iv: impl AsRef<[u8]>,
75
) -> Result<Zeroizing<Vec<u8>>, super::Error> {
76
2
    decrypt_with_padding(blob, key, iv, false)
77
}
78

            
79
2
/// Length in bytes of the IV required by [`ENC_ALG`].
///
/// # Panics
///
/// Panics if OpenSSL does not know `ENC_ALG` or reports no IV for it.
pub(crate) fn iv_len() -> usize {
    Cipher::from_nid(ENC_ALG)
        .and_then(|cipher| cipher.iv_len())
        .unwrap()
}
83

            
84
2
/// Generates random private-key bytes whose length equals [`ENC_ALG`]'s key length.
///
/// The bytes come from OpenSSL's `rand_bytes` and are returned in a [`Zeroizing`] buffer.
pub(crate) fn generate_private_key() -> Result<Zeroizing<Vec<u8>>, super::Error> {
    let key_len = Cipher::from_nid(ENC_ALG).unwrap().key_len();
    let mut private_key = Zeroizing::new(vec![0; key_len]);
    rand_bytes(private_key.as_mut_slice())?;
    Ok(private_key)
}
90

            
91
2
/// Computes the Diffie-Hellman public key for `private_key`.
///
/// `private_key` is read as a big-endian integer. The group is the RFC 2409 1024-bit
/// prime with generator 2. The public key is returned as big-endian bytes.
///
/// # Errors
///
/// Returns an error if OpenSSL fails to build the group parameters or the key pair.
pub(crate) fn generate_public_key(private_key: impl AsRef<[u8]>) -> Result<Vec<u8>, super::Error> {
    let private_key_bn = BigNum::from_slice(private_key.as_ref())?;
    let dh = Dh::from_pqg(
        BigNum::get_rfc2409_prime_1024()?,
        None,
        BigNum::from_u32(2)?,
    )?;
    Ok(dh.set_private_key(private_key_bn)?.public_key().to_vec())
}
100

            
101
2
pub(crate) fn generate_aes_key(
102
    private_key: impl AsRef<[u8]>,
103
    server_public_key: impl AsRef<[u8]>,
104
) -> Result<Zeroizing<Vec<u8>>, super::Error> {
105
4
    let private_key_bn = BigNum::from_slice(private_key.as_ref()).unwrap();
106
4
    let server_public_key_bn = BigNum::from_slice(server_public_key.as_ref()).unwrap();
107
    let dh = Dh::from_pqg(
108
4
        BigNum::get_rfc2409_prime_1024().unwrap(),
109
2
        None,
110
4
        BigNum::from_u32(2).unwrap(),
111
    )?;
112
6
    let mut common_secret_bytes = dh
113
2
        .set_private_key(private_key_bn)?
114
2
        .compute_key(&server_public_key_bn)?;
115

            
116
2
    let mut common_secret_padded = vec![0; 128 - common_secret_bytes.len()];
117
    // inefficient, but ok for now
118
2
    common_secret_padded.append(&mut common_secret_bytes);
119

            
120
    // hkdf
121
    // input_keying_material
122
2
    let ikm = common_secret_padded;
123

            
124
4
    let mut okm = Zeroizing::new(vec![0; 16]);
125
4
    let mut ctx = PkeyCtx::new_id(Id::HKDF)?;
126
4
    ctx.derive_init()?;
127
2
    ctx.set_hkdf_md(Md::sha256())?;
128
2
    ctx.set_hkdf_key(&ikm)?;
129
2
    ctx.derive(Some(okm.as_mut()))
130
        .expect("hkdf expand should never fail");
131
2
    Ok(okm)
132
}
133

            
134
2
/// Generates a random IV of `iv_len()` bytes using OpenSSL's `rand_bytes`.
pub fn generate_iv() -> Result<Vec<u8>, super::Error> {
    let mut iv = vec![0u8; iv_len()];
    rand_bytes(iv.as_mut_slice())?;
    Ok(iv)
}
139

            
140
2
/// Size in bytes of a MAC produced with [`MAC_ALG`].
///
/// # Panics
///
/// Panics if OpenSSL does not know `MAC_ALG` as a message digest.
pub(crate) fn mac_len() -> usize {
    MessageDigest::from_nid(MAC_ALG)
        .map(|digest| digest.size())
        .unwrap()
}
144

            
145
8
pub(crate) fn compute_mac(data: impl AsRef<[u8]>, key: &Key) -> Result<Mac, super::Error> {
146
16
    let md = MessageDigest::from_nid(MAC_ALG).unwrap();
147
8
    let mac_key = PKey::hmac(key.as_ref())?;
148
16
    let mut signer = Signer::new(md, &mac_key)?;
149
16
    signer.update(data.as_ref())?;
150
8
    signer.sign_to_vec().map_err(From::from).map(Mac::new)
151
}
152

            
153
4
pub(crate) fn verify_mac(
154
    data: impl AsRef<[u8]>,
155
    key: &Key,
156
    expected_mac: impl AsRef<[u8]>,
157
) -> Result<bool, super::Error> {
158
4
    Ok(memcmp::eq(
159
8
        compute_mac(&data, key)?.as_slice(),
160
4
        expected_mac.as_ref(),
161
    ))
162
}
163

            
164
2
pub(crate) fn verify_checksum_md5(digest: impl AsRef<[u8]>, content: impl AsRef<[u8]>) -> bool {
165
    memcmp::eq(
166
5
        &hash(MessageDigest::md5(), content.as_ref()).unwrap(),
167
2
        digest.as_ref(),
168
    )
169
}
170

            
171
2
pub(crate) fn derive_key(
172
    secret: impl AsRef<[u8]>,
173
    key_strength: Result<(), file::WeakKeyError>,
174
    salt: impl AsRef<[u8]>,
175
    iteration_count: usize,
176
) -> Result<Key, super::Error> {
177
6
    let cipher = Cipher::from_nid(ENC_ALG).unwrap();
178
4
    let mut key = Key::new_with_strength(vec![0; cipher.block_size()], key_strength);
179

            
180
6
    let md = MessageDigest::from_nid(MAC_ALG).unwrap();
181
    pbkdf2_hmac(
182
4
        secret.as_ref(),
183
2
        salt.as_ref(),
184
        iteration_count,
185
        md,
186
4
        key.as_mut(),
187
    )?;
188

            
189
2
    Ok(key)
190
}
191

            
192
3
pub(crate) fn legacy_derive_key_and_iv(
193
    secret: impl AsRef<[u8]>,
194
    key_strength: Result<(), file::WeakKeyError>,
195
    salt: impl AsRef<[u8]>,
196
    iteration_count: usize,
197
) -> Result<(Key, Vec<u8>), super::Error> {
198
6
    let cipher = Cipher::from_nid(ENC_ALG).unwrap();
199
3
    let mut buffer = vec![0; cipher.key_len() + cipher.iv_len().unwrap()];
200
6
    let mut hasher = Hasher::new(MessageDigest::sha256())?;
201
3
    let mut pos = 0usize;
202

            
203
    loop {
204
6
        hasher.update(secret.as_ref())?;
205
3
        hasher.update(salt.as_ref())?;
206
3
        let mut digest = hasher.finish()?;
207

            
208
6
        for _ in 1..iteration_count {
209
            // We can't pass an instance, the borrow checker
210
            // would complain about digest being dropped at the end of
211
            // for block
212
            #[allow(clippy::needless_borrows_for_generic_args)]
213
6
            hasher.update(&digest)?;
214
3
            digest = hasher.finish()?;
215
        }
216

            
217
3
        let to_read = usize::min(digest.len(), buffer.len() - pos);
218
3
        buffer[pos..].copy_from_slice(&(&*digest)[..to_read]);
219
3
        pos += to_read;
220

            
221
6
        if pos == buffer.len() {
222
            break;
223
        }
224

            
225
        // We can't pass an instance, the borrow checker
226
        // would complain about digest being dropped at the end of
227
        // for block
228
        #[allow(clippy::needless_borrows_for_generic_args)]
229
        hasher.update(&digest)?;
230
    }
231

            
232
6
    let iv = buffer.split_off(cipher.key_len());
233
6
    Ok((Key::new_with_strength(buffer, key_strength), iv))
234
}