1
use zeroize::{Zeroize, ZeroizeOnDrop};
2

            
3
use crate::{crypto, file};
4

            
5
/// A key.
6
#[derive(Zeroize, ZeroizeOnDrop)]
7
pub struct Key {
8
    key: Vec<u8>,
9
    #[zeroize(skip)]
10
    strength: Result<(), file::WeakKeyError>,
11
}
12

            
13
impl std::fmt::Debug for Key {
14
2
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15
2
        write!(
16
            f,
17
            "Key {{ key: [REDACTED], strength: {:?} }}",
18
            self.strength
19
        )
20
    }
21
}
22

            
23
impl AsRef<[u8]> for Key {
24
4
    fn as_ref(&self) -> &[u8] {
25
4
        self.key.as_slice()
26
    }
27
}
28

            
29
impl AsMut<[u8]> for Key {
30
5
    fn as_mut(&mut self) -> &mut [u8] {
31
7
        &mut self.key
32
    }
33
}
34

            
35
impl Key {
36
4
    pub const fn new(key: Vec<u8>) -> Self {
37
4
        Self::new_with_strength(key, Err(file::WeakKeyError::StrengthUnknown))
38
    }
39

            
40
4
    pub(crate) const fn check_strength(&self) -> Result<(), file::WeakKeyError> {
41
4
        self.strength
42
    }
43

            
44
5
    pub(crate) const fn new_with_strength(
45
        key: Vec<u8>,
46
        strength: Result<(), file::WeakKeyError>,
47
    ) -> Self {
48
        Self { key, strength }
49
    }
50

            
51
4
    pub fn generate_private_key() -> Result<Self, crypto::Error> {
52
4
        Ok(Self::new(crypto::generate_private_key()?.to_vec()))
53
    }
54

            
55
4
    pub fn generate_public_key(private_key: &Self) -> Result<Self, crypto::Error> {
56
4
        Ok(Self::new(crypto::generate_public_key(private_key)?))
57
    }
58

            
59
4
    pub fn generate_aes_key(
60
        private_key: &Self,
61
        server_public_key: &Self,
62
    ) -> Result<Self, crypto::Error> {
63
4
        Ok(Self::new(
64
8
            crypto::generate_aes_key(private_key, server_public_key)?.to_vec(),
65
        ))
66
    }
67
}
68

            
69
impl From<Key> for zvariant::Value<'static> {
    /// Converts the key into a byte-array `Value`, moving the bytes out of `key`.
    fn from(mut key: Key) -> Self {
        // `ZeroizeOnDrop` gives `Key` a `Drop` impl, so the field cannot be moved out
        // directly. Swap an empty vector in; the emptied `Key` is then dropped as usual.
        // Once moved, the bytes are no longer covered by `Key`'s zeroize-on-drop.
        let bytes: Vec<u8> = std::mem::take(&mut key.key);
        let array = zvariant::Array::from(bytes);
        array.into()
    }
}
76

            
77
impl From<Key> for zvariant::OwnedValue {
    /// Converts the key into an owned byte-array value.
    ///
    /// # Panics
    ///
    /// Panics if `Value::try_into_owned` fails. The value built from a `Key` only
    /// holds a byte array. Presumably this conversion can only fail for values
    /// carrying file descriptors, so a failure here indicates a broken invariant.
    fn from(key: Key) -> Self {
        zvariant::Value::from(key)
            .try_into_owned()
            .expect("converting a byte-array Value into an OwnedValue should not fail")
    }
}
82

            
83
impl TryFrom<zvariant::Value<'_>> for Key {
84
    type Error = zvariant::Error;
85

            
86
4
    fn try_from(value: zvariant::Value<'_>) -> Result<Self, Self::Error> {
87
4
        Ok(Key::new(value.try_into()?))
88
    }
89
}
90

            
91
impl TryFrom<zvariant::OwnedValue> for Key {
92
    type Error = zvariant::Error;
93

            
94
4
    fn try_from(value: zvariant::OwnedValue) -> Result<Self, Self::Error> {
95
4
        Self::try_from(zvariant::Value::from(value))
96
    }
97
}
98

            
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Known-answer test: a fixed private key must yield the expected public key
    /// and, combined with a fixed server public key, the expected AES key.
    #[test]
    fn private_public_pair() {
        let private_key = Key::new(vec![
            41, 20, 63, 236, 246, 132, 109, 70, 172, 121, 45, 66, 129, 21, 247, 91, 96, 217, 56,
            201, 205, 56, 17, 178, 202, 81, 71, 104, 233, 89, 87, 32, 88, 146, 107, 224, 56, 103,
            111, 74, 143, 80, 170, 40, 5, 52, 48, 90, 75, 71, 193, 224, 222, 57, 91, 81, 66, 1, 6,
            88, 137, 66, 102, 207, 55, 95, 67, 92, 140, 227, 242, 153, 185, 195, 89, 236, 146, 242,
            88, 215, 1, 7, 135, 254, 85, 165, 236, 110, 22, 79, 107, 254, 149, 164, 243, 94, 129,
            198, 45, 208, 132, 166, 0, 153, 243, 160, 255, 188, 59, 216, 99, 221, 85, 162, 116,
            210, 160, 117, 201, 39, 179, 123, 107, 8, 242, 139, 207, 250,
        ]);
        let server_public_key = Key::new(vec![
            50, 233, 76, 88, 47, 206, 235, 107, 9, 232, 98, 14, 188, 214, 209, 77, 35, 66, 109,
            119, 24, 191, 120, 90, 242, 198, 240, 115, 200, 66, 51, 180, 8, 164, 89, 9, 229, 31,
            160, 31, 156, 101, 169, 60, 63, 247, 37, 255, 75, 198, 62, 235, 50, 29, 221, 245, 29,
            248, 140, 209, 62, 215, 2, 137, 82, 77, 248, 242, 56, 176, 118, 183, 124, 74, 26, 133,
            188, 47, 31, 141, 232, 194, 92, 18, 69, 3, 56, 153, 42, 9, 143, 81, 197, 159, 200, 197,
            221, 74, 186, 157, 158, 36, 74, 125, 11, 234, 33, 2, 5, 36, 206, 248, 155, 157, 145,
            159, 238, 19, 185, 194, 134, 3, 195, 198, 60, 100, 159, 31,
        ]);

        let expected_public_key = &[
            9, 192, 210, 81, 212, 191, 74, 119, 22, 172, 81, 142, 124, 89, 17, 71, 118, 190, 81,
            71, 49, 149, 200, 204, 14, 47, 111, 165, 119, 103, 216, 102, 111, 93, 242, 64, 73, 224,
            165, 11, 127, 219, 197, 188, 168, 222, 254, 10, 104, 81, 8, 206, 237, 119, 225, 100,
            78, 196, 89, 163, 63, 169, 77, 236, 80, 241, 189, 49, 27, 40, 243, 229, 66, 53, 80, 86,
            44, 213, 87, 186, 68, 55, 216, 56, 236, 51, 229, 44, 174, 18, 87, 141, 85, 71, 185,
            203, 208, 144, 190, 117, 141, 255, 153, 106, 123, 28, 152, 200, 237, 189, 176, 20, 80,
            211, 33, 158, 232, 194, 145, 45, 194, 35, 108, 106, 214, 221, 159, 137,
        ];
        let expected_aes_key = &[
            132, 3, 113, 222, 81, 209, 49, 43, 81, 232, 243, 46, 1, 103, 184, 42,
        ];

        let public_key = Key::generate_public_key(&private_key);
        let aes_key = Key::generate_aes_key(&private_key, &server_public_key);

        assert_eq!(public_key.unwrap().as_ref(), expected_public_key);
        assert_eq!(aes_key.unwrap().as_ref(), expected_aes_key);
    }

    /// The Debug output must mask the key bytes but still show the strength verdict.
    #[test]
    fn key_debug_is_redacted() {
        let key = Key::new(vec![1, 2, 3, 4]);
        let debug_output = format!("{:?}", key);

        assert!(debug_output.contains("key: [REDACTED]"));
        assert!(debug_output.contains("strength:"));
        // The raw bytes themselves must never leak into the output.
        assert!(!debug_output.contains("1, 2, 3, 4"));
    }

    /// Writes through `AsMut` must be visible through `AsRef`.
    #[test]
    fn as_mut_modifies_key_bytes() {
        let mut key = Key::new(vec![0, 0, 0]);
        let bytes: &mut [u8] = key.as_mut();
        bytes[1] = 7;

        let view: &[u8] = key.as_ref();
        assert_eq!(view, [0, 7, 0]);
    }

    /// Key -> Value -> Key must preserve the raw bytes.
    #[test]
    fn value_round_trip_preserves_bytes() {
        let value = zvariant::Value::from(Key::new(vec![5, 6, 7]));
        let key = Key::try_from(value).unwrap();

        let view: &[u8] = key.as_ref();
        assert_eq!(view, [5, 6, 7]);
    }

    /// Key -> OwnedValue -> Key must preserve the raw bytes.
    #[test]
    fn owned_value_round_trip_preserves_bytes() {
        let value = zvariant::OwnedValue::from(Key::new(vec![8, 9]));
        let key = Key::try_from(value).unwrap();

        let view: &[u8] = key.as_ref();
        assert_eq!(view, [8, 9]);
    }
}