1
use std::str::FromStr;
2

            
3
use serde::{Deserialize, Serialize};
4
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
5

            
6
#[derive(Default, PartialEq, Eq, Copy, Clone, Debug, zvariant::Type)]
7
#[zvariant(signature = "s")]
8
pub enum ContentType {
9
    Text,
10
    #[default]
11
    Blob,
12
}
13

            
14
impl Serialize for ContentType {
15
6
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
16
    where
17
        S: serde::Serializer,
18
    {
19
12
        self.as_str().serialize(serializer)
20
    }
21
}
22

            
23
impl<'de> Deserialize<'de> for ContentType {
24
7
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
25
    where
26
        D: serde::Deserializer<'de>,
27
    {
28
7
        let s = String::deserialize(deserializer)?;
29
14
        Self::from_str(&s).map_err(serde::de::Error::custom)
30
    }
31
}
32

            
33
impl FromStr for ContentType {
34
    type Err = String;
35

            
36
5
    fn from_str(s: &str) -> Result<Self, Self::Err> {
37
        match s {
38
10
            "text/plain" => Ok(Self::Text),
39
8
            "application/octet-stream" => Ok(Self::Blob),
40
2
            e => Err(format!("Invalid content type: {e}")),
41
        }
42
    }
43
}
44

            
45
impl ContentType {
46
4
    pub const fn as_str(&self) -> &'static str {
47
4
        match self {
48
4
            Self::Text => "text/plain",
49
4
            Self::Blob => "application/octet-stream",
50
        }
51
    }
52
}
53

            
54
/// A wrapper around a combination of (secret, content-type).
55
#[derive(Clone, PartialEq, Eq, Zeroize, ZeroizeOnDrop)]
56
pub enum Secret {
57
    /// Corresponds to [`ContentType::Text`]
58
    Text(String),
59
    /// Corresponds to [`ContentType::Blob`]
60
    Blob(Vec<u8>),
61
}
62

            
63
impl std::fmt::Debug for Secret {
64
2
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65
2
        match self {
66
2
            Self::Text(_) => write!(f, "Secret::Text([REDACTED])"),
67
2
            Self::Blob(_) => write!(f, "Secret::Blob([REDACTED])"),
68
        }
69
    }
70
}
71

            
72
impl Secret {
73
    /// Generate a random secret, used when creating a session collection.
74
6
    pub fn random() -> Result<Self, getrandom::Error> {
75
6
        let mut secret = [0; 64];
76
        // Equivalent of `ring::rand::SecureRandom`
77
6
        getrandom::fill(&mut secret)?;
78

            
79
6
        Ok(Self::blob(secret))
80
    }
81

            
82
    /// Create a text secret, stored with `text/plain` content type.
83
11
    pub fn text(value: impl AsRef<str>) -> Self {
84
24
        Self::Text(value.as_ref().to_owned())
85
    }
86

            
87
    /// Create a blob secret, stored with `application/octet-stream` content
88
    /// type.
89
34
    pub fn blob(value: impl AsRef<[u8]>) -> Self {
90
68
        Self::Blob(value.as_ref().to_owned())
91
    }
92

            
93
4
    pub const fn content_type(&self) -> ContentType {
94
4
        match self {
95
4
            Self::Text(_) => ContentType::Text,
96
4
            Self::Blob(_) => ContentType::Blob,
97
        }
98
    }
99

            
100
5
    pub fn as_bytes(&self) -> &[u8] {
101
14
        match self {
102
5
            Self::Text(text) => text.as_bytes(),
103
4
            Self::Blob(bytes) => bytes.as_ref(),
104
        }
105
    }
106

            
107
6
    pub fn with_content_type(content_type: ContentType, secret: impl AsRef<[u8]>) -> Self {
108
6
        match content_type {
109
12
            ContentType::Text => match String::from_utf8(secret.as_ref().to_owned()) {
110
6
                Ok(text) => Secret::text(text),
111
2
                Err(_e) => {
112
5
                    #[cfg(feature = "tracing")]
113
                    tracing::warn!(
114
                        "Failed to decode secret as UTF-8: {}, falling back to blob",
115
                        _e
116
                    );
117

            
118
2
                    Secret::blob(secret)
119
                }
120
            },
121
8
            _ => Secret::blob(secret),
122
        }
123
    }
124
}
125

            
126
impl From<&[u8]> for Secret {
127
4
    fn from(value: &[u8]) -> Self {
128
4
        Self::blob(value)
129
    }
130
}
131

            
132
impl From<Zeroizing<Vec<u8>>> for Secret {
    /// Takes over the buffer inside `value` as a [`Secret::Blob`] instead of
    /// allocating and copying it.
    ///
    /// The wrapper is left holding an empty `Vec`. The secret bytes are now
    /// owned by the returned `Secret`, which zeroizes them on drop.
    fn from(mut value: Zeroizing<Vec<u8>>) -> Self {
        Self::Blob(std::mem::take(&mut *value))
    }
}
137

            
138
impl From<Vec<u8>> for Secret {
139
4
    fn from(value: Vec<u8>) -> Self {
140
4
        Self::blob(value)
141
    }
142
}
143

            
144
impl From<&Vec<u8>> for Secret {
145
    fn from(value: &Vec<u8>) -> Self {
146
        Self::blob(value)
147
    }
148
}
149

            
150
impl<const N: usize> From<&[u8; N]> for Secret {
151
    fn from(value: &[u8; N]) -> Self {
152
        Self::blob(value)
153
    }
154
}
155

            
156
impl From<String> for Secret {
157
2
    fn from(value: String) -> Self {
158
2
        Self::text(value)
159
    }
160
}
161

            
162
impl From<&str> for Secret {
163
6
    fn from(value: &str) -> Self {
164
7
        Self::text(value)
165
    }
166
}
167

            
168
impl std::ops::Deref for Secret {
169
    type Target = [u8];
170

            
171
5
    fn deref(&self) -> &Self::Target {
172
7
        self.as_bytes()
173
    }
174
}
175

            
176
impl AsRef<[u8]> for Secret {
177
2
    fn as_ref(&self) -> &[u8] {
178
2
        self.as_bytes()
179
    }
180
}
181

            
182
#[cfg(test)]
mod tests {
    use zvariant::{Endian, serialized::Context, to_bytes};

    use super::*;

    #[test]
    fn secret_debug_is_redacted() {
        let text_secret = Secret::text("password");
        let blob_secret = Secret::blob([1, 2, 3]);

        assert_eq!(format!("{:?}", text_secret), "Secret::Text([REDACTED])");
        assert_eq!(format!("{:?}", blob_secret), "Secret::Blob([REDACTED])");
    }

    #[test]
    fn content_type_serialization() {
        let ctxt = Context::new_dbus(Endian::Little, 0);

        // Test Text serialization
        let encoded = to_bytes(ctxt, &ContentType::Text).unwrap();
        let value: String = encoded.deserialize().unwrap().0;
        assert_eq!(value, "text/plain");

        // Test Blob serialization
        let encoded = to_bytes(ctxt, &ContentType::Blob).unwrap();
        let value: String = encoded.deserialize().unwrap().0;
        assert_eq!(value, "application/octet-stream");

        // Test Text deserialization
        let encoded = to_bytes(ctxt, &"text/plain").unwrap();
        let content_type: ContentType = encoded.deserialize().unwrap().0;
        assert_eq!(content_type, ContentType::Text);

        // Test Blob deserialization
        let encoded = to_bytes(ctxt, &"application/octet-stream").unwrap();
        let content_type: ContentType = encoded.deserialize().unwrap().0;
        assert_eq!(content_type, ContentType::Blob);

        // Test invalid content type deserialization
        let encoded = to_bytes(ctxt, &"invalid/type").unwrap();
        let result: Result<(ContentType, _), _> = encoded.deserialize();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Invalid content type")
        );
    }

    #[test]
    fn content_type_from_str() {
        assert_eq!(
            ContentType::from_str("text/plain").unwrap(),
            ContentType::Text
        );
        assert_eq!(
            ContentType::from_str("application/octet-stream").unwrap(),
            ContentType::Blob
        );

        // Test error case
        let result = ContentType::from_str("invalid");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid content type"));
    }

    #[test]
    fn content_type_round_trip_and_default() {
        // `as_str` and `from_str` must stay inverse to each other.
        for content_type in [ContentType::Text, ContentType::Blob] {
            assert_eq!(
                ContentType::from_str(content_type.as_str()),
                Ok(content_type)
            );
        }
        assert_eq!(ContentType::default(), ContentType::Blob);
    }

    #[test]
    fn conversions_match_constructors() {
        // Every byte-like `From` impl must agree with `Secret::blob`.
        let bytes = [1u8, 2, 3];
        let expected_blob = Secret::blob(bytes);
        let blobs = [
            Secret::from(&bytes[..]),
            Secret::from(&bytes),
            Secret::from(bytes.to_vec()),
            Secret::from(&bytes.to_vec()),
            Secret::from(Zeroizing::new(bytes.to_vec())),
        ];
        for secret in &blobs {
            assert_eq!(secret, &expected_blob);
            assert_eq!(secret.content_type(), ContentType::Blob);
            assert_eq!(secret.as_bytes(), &bytes);
        }

        // Every text-like `From` impl must agree with `Secret::text`.
        let expected_text = Secret::text("password");
        let texts = [Secret::from("password"), Secret::from(String::from("password"))];
        for secret in &texts {
            assert_eq!(secret, &expected_text);
            assert_eq!(secret.content_type(), ContentType::Text);
            assert_eq!(secret.as_bytes(), b"password");
        }
    }

    #[test]
    fn invalid_utf8() {
        // Test with invalid UTF-8 bytes
        let invalid_utf8 = vec![0xFF, 0xFE, 0xFD];

        // Should fall back to blob when UTF-8 decoding fails
        let secret = Secret::with_content_type(ContentType::Text, &invalid_utf8);
        assert_eq!(secret.content_type(), ContentType::Blob);
        assert_eq!(&*secret, &[0xFF, 0xFE, 0xFD]);

        // Test with valid UTF-8
        let valid_utf8 = "Hello, World!";
        let secret = Secret::with_content_type(ContentType::Text, valid_utf8.as_bytes());
        assert_eq!(secret.content_type(), ContentType::Text);
        assert_eq!(&*secret, valid_utf8.as_bytes());

        // Test with blob content type
        let data = vec![1, 2, 3, 4];
        let secret = Secret::with_content_type(ContentType::Blob, &data);
        assert_eq!(secret.content_type(), ContentType::Blob);
        assert_eq!(&*secret, &[1, 2, 3, 4]);
    }

    #[test]
    fn random() {
        let secret1 = Secret::random().unwrap();
        let secret2 = Secret::random().unwrap();

        // Random secrets should be blobs
        assert_eq!(secret1.content_type(), ContentType::Blob);
        assert_eq!(secret2.content_type(), ContentType::Blob);

        // Should be 64 bytes
        assert_eq!(secret1.as_bytes().len(), 64);
        assert_eq!(secret2.as_bytes().len(), 64);

        // Should be different
        assert_ne!(secret1.as_bytes(), secret2.as_bytes());
    }
}