1
use std::sync::Arc;
2

            
3
use formatx::formatx;
4
use gettextrs::gettext;
5
use oo7::{Key, ashpd::WindowIdentifierType, dbus::ServiceError};
6
use serde::{Deserialize, Serialize};
7
use tokio::sync::OnceCell;
8
use zbus::zvariant::{self, ObjectPath, Optional, OwnedObjectPath, Type, Value, as_value};
9

            
10
use super::secret_exchange;
11
use crate::{
12
    error::custom_service_error,
13
    prompt::{Prompt, PromptRole},
14
    service::Service,
15
};
16

            
17
/// Custom serde module to handle GCR's double-Value wrapping bug
18
///
19
/// See: https://gitlab.gnome.org/GNOME/gcr/-/merge_requests/169
20
mod double_value_optional {
21
    use super::*;
22

            
23
6
    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
24
    where
25
        D: serde::Deserializer<'de>,
26
        T: TryFrom<Value<'de>> + zvariant::Type,
27
        T::Error: std::fmt::Display,
28
    {
29
6
        let outer_value = Value::deserialize(deserializer)?;
30

            
31
        // Try to downcast to check if it's double-wrapped
32
12
        let value_to_deserialize = match outer_value.downcast_ref::<Value>() {
33
12
            Ok(_) => outer_value.downcast::<Value>().map_err(|e| {
34
                serde::de::Error::custom(format!("Failed to unwrap double-wrapped Value: {e}"))
35
            })?,
36
            Err(_) => outer_value,
37
        };
38

            
39
6
        match T::try_from(value_to_deserialize) {
40
6
            Ok(val) => Ok(Some(val)),
41
            Err(_) => Ok(None),
42
        }
43
    }
44
}
45

            
46
#[derive(Debug, Serialize, Deserialize, Type, Default)]
47
#[zvariant(signature = "dict")]
48
#[serde(rename_all = "kebab-case")]
49
// GcrPrompt properties <https://gitlab.gnome.org/GNOME/gcr/-/blob/main/gcr/gcr-prompt.c#L95>
50
pub struct Properties {
51
    #[serde(
52
        serialize_with = "as_value::optional::serialize",
53
        deserialize_with = "double_value_optional::deserialize",
54
        skip_serializing_if = "Option::is_none",
55
        default
56
    )]
57
    title: Option<String>,
58
    #[serde(
59
        serialize_with = "as_value::optional::serialize",
60
        deserialize_with = "double_value_optional::deserialize",
61
        skip_serializing_if = "Option::is_none",
62
        default
63
    )]
64
    message: Option<String>,
65
    #[serde(
66
        serialize_with = "as_value::optional::serialize",
67
        deserialize_with = "double_value_optional::deserialize",
68
        skip_serializing_if = "Option::is_none",
69
        default
70
    )]
71
    description: Option<String>,
72
    #[serde(
73
        serialize_with = "as_value::optional::serialize",
74
        deserialize_with = "double_value_optional::deserialize",
75
        skip_serializing_if = "Option::is_none",
76
        default
77
    )]
78
    warning: Option<String>,
79
    #[serde(
80
        serialize_with = "as_value::optional::serialize",
81
        deserialize_with = "double_value_optional::deserialize",
82
        skip_serializing_if = "Option::is_none",
83
        default
84
    )]
85
    password_new: Option<bool>,
86
    #[serde(
87
        serialize_with = "as_value::optional::serialize",
88
        deserialize_with = "double_value_optional::deserialize",
89
        skip_serializing_if = "Option::is_none",
90
        default
91
    )]
92
    password_strength: Option<i32>,
93
    #[serde(
94
        serialize_with = "as_value::optional::serialize",
95
        deserialize_with = "double_value_optional::deserialize",
96
        skip_serializing_if = "Option::is_none",
97
        default
98
    )]
99
    choice_label: Option<String>,
100
    #[serde(
101
        serialize_with = "as_value::optional::serialize",
102
        deserialize_with = "double_value_optional::deserialize",
103
        skip_serializing_if = "Option::is_none",
104
        default
105
    )]
106
    choice_chosen: Option<bool>,
107
    #[serde(
108
        with = "as_value::optional",
109
        skip_serializing_if = "Option::is_none",
110
        default
111
    )]
112
    caller_window: Option<WindowIdentifierType>,
113
    #[serde(
114
        serialize_with = "as_value::optional::serialize",
115
        deserialize_with = "double_value_optional::deserialize",
116
        skip_serializing_if = "Option::is_none",
117
        default
118
    )]
119
    continue_label: Option<String>,
120
    #[serde(
121
        serialize_with = "as_value::optional::serialize",
122
        deserialize_with = "double_value_optional::deserialize",
123
        skip_serializing_if = "Option::is_none",
124
        default
125
    )]
126
    cancel_label: Option<String>,
127
}
128

            
129
impl Properties {
130
2
    fn for_unlock(
131
        keyring: &str,
132
        warning: Option<&str>,
133
        window_id: Option<&WindowIdentifierType>,
134
    ) -> Self {
135
        Self {
136
2
            title: Some(gettext("Unlock Keyring")),
137
4
            message: Some(gettext("Authentication required")),
138
2
            description: Some(
139
                formatx!(
140
                    gettext("An application wants access to the keyring '{}', but it is locked",),
141
                    keyring,
142
                )
143
                .expect("Wrong format in translatable string"),
144
            ),
145
2
            warning: warning.map(ToOwned::to_owned),
146
            password_new: None,
147
            password_strength: None,
148
            choice_label: None,
149
            choice_chosen: None,
150
2
            caller_window: window_id.map(ToOwned::to_owned),
151
4
            continue_label: Some(gettext("Unlock")),
152
4
            cancel_label: Some(gettext("Cancel")),
153
        }
154
    }
155

            
156
2
    fn for_create_collection(label: &str, window_id: Option<&WindowIdentifierType>) -> Self {
157
        Self {
158
2
            title: Some(gettext("New Keyring Password")),
159
4
            message: Some(gettext("Choose password for new keyring")),
160
2
            description: Some(
161
                formatx!(
162
                    gettext("An application wants to create a new keyring called '{}'. Choose the password you want to use for it."),
163
                    &label
164
                )
165
                .expect("Wrong format in translatable string")
166
            ),
167
            warning: None,
168
            password_new: Some(true),
169
            password_strength: None,
170
            choice_label: None,
171
            choice_chosen: None,
172
2
            caller_window: window_id.map(ToOwned::to_owned),
173
4
            continue_label: Some(gettext("Create")),
174
4
            cancel_label: Some(gettext("Cancel")),
175
        }
176
    }
177
}
178

            
179
#[derive(Deserialize, Serialize, Debug, Type)]
180
#[serde(rename_all = "lowercase")]
181
#[zvariant(signature = "s")]
182
pub enum Reply {
183
    No,
184
    Yes,
185
}
186

            
187
impl zvariant::NoneValue for Reply {
188
    type NoneType = String;
189

            
190
2
    fn null_value() -> Self::NoneType {
191
2
        String::new()
192
    }
193
}
194

            
195
impl TryFrom<String> for Reply {
196
    type Error = String;
197

            
198
2
    fn try_from(value: String) -> Result<Self, Self::Error> {
199
4
        match value.as_str() {
200
4
            "no" => Ok(Reply::No),
201
6
            "yes" => Ok(Reply::Yes),
202
            _ => Err("Invalid value".to_string()),
203
        }
204
    }
205
}
206

            
207
#[derive(Deserialize, Serialize, Debug, Type, PartialEq, Eq, PartialOrd, Ord)]
208
#[serde(rename_all = "lowercase")]
209
#[zvariant(signature = "s")]
210
pub enum PromptType {
211
    Confirm,
212
    Password,
213
}
214

            
215
#[zbus::proxy(
216
    default_service = "org.gnome.keyring.SystemPrompter",
217
    interface = "org.gnome.keyring.internal.Prompter",
218
    default_path = "/org/gnome/keyring/Prompter",
219
    gen_blocking = false
220
)]
221
pub trait Prompter {
222
    fn begin_prompting(&self, callback: &ObjectPath<'_>) -> Result<(), ServiceError>;
223

            
224
    fn perform_prompt(
225
        &self,
226
        callback: &ObjectPath<'_>,
227
        type_: PromptType,
228
        properties: Properties,
229
        exchange: &str,
230
    ) -> Result<(), ServiceError>;
231

            
232
    fn stop_prompting(&self, callback: &ObjectPath<'_>) -> Result<(), ServiceError>;
233
}
234

            
235
#[derive(Debug, Clone)]
236
pub struct PrompterCallback {
237
    window_id: Option<WindowIdentifierType>,
238
    private_key: Arc<Key>,
239
    public_key: Arc<Key>,
240
    exchange: OnceCell<String>,
241
    service: Service,
242
    prompt_path: OwnedObjectPath,
243
    path: OwnedObjectPath,
244
}
245

            
246
#[zbus::interface(name = "org.gnome.keyring.internal.Prompter.Callback")]
247
impl PrompterCallback {
248
2
    pub async fn prompt_ready(
249
        &self,
250
        reply: Optional<Reply>,
251
        _properties: Properties,
252
        exchange: &str,
253
    ) -> Result<(), ServiceError> {
254
2
        let prompt_path = &self.prompt_path;
255
4
        let Some(prompt) = self.service.prompt(prompt_path).await else {
256
4
            return Err(ServiceError::NoSuchObject(format!(
257
                "Prompt '{prompt_path}' does not exist."
258
            )));
259
        };
260

            
261
4
        match *reply {
262
            // First PromptReady call
263
2
            None => {
264
2
                self.prompter_init(&prompt).await?;
265
            }
266
            // Second PromptReady call with final exchange
267
2
            Some(Reply::Yes) => {
268
6
                self.prompter_done(&prompt, exchange).await?;
269
            }
270
            // Dismissed prompt
271
2
            Some(Reply::No) => {
272
8
                self.prompter_dismissed(prompt.path().clone().into())
273
6
                    .await?;
274
            }
275
        };
276
2
        Ok(())
277
    }
278

            
279
8
    async fn prompt_done(&self) -> Result<(), ServiceError> {
280
        // This is only does check if the prompt is tracked on Service
281
2
        let path = &self.prompt_path;
282
4
        if self.service.prompt(path).await.is_some() {
283
8
            self.service
284
                .object_server()
285
2
                .remove::<Prompt, _>(path)
286
6
                .await?;
287
2
            self.service.remove_prompt(path).await;
288
        }
289
8
        self.service
290
            .object_server()
291
2
            .remove::<Self, _>(&self.path)
292
6
            .await?;
293

            
294
2
        Ok(())
295
    }
296
}
297

            
298
impl PrompterCallback {
299
2
    pub async fn new(
300
        window_id: Option<WindowIdentifierType>,
301
        service: Service,
302
        prompt_path: OwnedObjectPath,
303
    ) -> Result<Self, oo7::crypto::Error> {
304
4
        let index = service.prompt_index().await;
305
4
        let private_key = Arc::new(Key::generate_private_key()?);
306
4
        let public_key = Arc::new(crate::gnome::crypto::generate_public_key(&private_key)?);
307
2
        Ok(Self {
308
2
            window_id,
309
2
            public_key,
310
2
            private_key,
311
2
            exchange: Default::default(),
312
4
            path: OwnedObjectPath::try_from(format!("/org/gnome/keyring/Prompt/p{index}")).unwrap(),
313
2
            service,
314
2
            prompt_path,
315
        })
316
    }
317

            
318
2
    pub fn path(&self) -> &ObjectPath<'_> {
319
2
        &self.path
320
    }
321

            
322
8
    async fn prompter_init(&self, prompt: &Prompt) -> Result<(), ServiceError> {
323
4
        let connection = self.service.connection();
324
2
        let exchange = secret_exchange::begin(&self.public_key);
325
2
        self.exchange.set(exchange).unwrap();
326

            
327
2
        let label = prompt.label();
328
4
        let (properties, prompt_type) = match prompt.role() {
329
2
            PromptRole::Unlock => (
330
2
                Properties::for_unlock(label, None, self.window_id.as_ref()),
331
                PromptType::Password,
332
            ),
333
2
            PromptRole::CreateCollection => (
334
4
                Properties::for_create_collection(label, self.window_id.as_ref()),
335
                PromptType::Password,
336
            ),
337
        };
338

            
339
4
        let prompter = PrompterProxy::new(connection).await?;
340
4
        let path = self.path.clone();
341
4
        let exchange = self.exchange.get().unwrap().clone();
342
6
        tokio::spawn(async move {
343
8
            prompter
344
6
                .perform_prompt(&path, prompt_type, properties, &exchange)
345
10
                .await
346
        });
347
2
        Ok(())
348
    }
349

            
350
10
    async fn prompter_done(&self, prompt: &Prompt, exchange: &str) -> Result<(), ServiceError> {
351
4
        let prompter = PrompterProxy::new(self.service.connection()).await?;
352
4
        let aes_key = secret_exchange::handshake(&self.private_key, exchange).map_err(|err| {
353
            custom_service_error(&format!(
354
                "Failed to generate AES key for SecretExchange {err}."
355
            ))
356
        })?;
357

            
358
4
        let Some(secret) = secret_exchange::retrieve(exchange, &aes_key) else {
359
            return Err(custom_service_error(
360
                "Failed to retrieve keyring secret from SecretExchange.",
361
            ));
362
        };
363

            
364
        // Handle each role differently based on what validation/preparation is needed
365
4
        match prompt.role() {
366
            PromptRole::Unlock => {
367
2
                if prompt.on_unlock_collection(secret).await? {
368
2
                    let path = self.path.clone();
369
6
                    tokio::spawn(async move { prompter.stop_prompting(&path).await });
370
                } else {
371
                    let properties = Properties::for_unlock(
372
2
                        prompt.label(),
373
                        Some("The unlock password was incorrect"),
374
2
                        self.window_id.as_ref(),
375
                    );
376
2
                    let server_exchange = self
377
                        .exchange
378
                        .get()
379
                        .expect("Exchange cannot be empty at this stage")
380
                        .clone();
381
2
                    let path = self.path.clone();
382

            
383
6
                    tokio::spawn(async move {
384
8
                        prompter
385
2
                            .perform_prompt(
386
2
                                &path,
387
                                PromptType::Password,
388
2
                                properties,
389
2
                                &server_exchange,
390
                            )
391
10
                            .await
392
                    });
393
                }
394
            }
395
            PromptRole::CreateCollection => {
396
6
                prompt.on_create_collection(secret).await?;
397

            
398
2
                let path = self.path.clone();
399
6
                tokio::spawn(async move { prompter.stop_prompting(&path).await });
400
            }
401
        }
402
2
        Ok(())
403
    }
404

            
405
8
    async fn prompter_dismissed(&self, prompt_path: OwnedObjectPath) -> Result<(), ServiceError> {
406
4
        let path = self.path.clone();
407
4
        let prompter = PrompterProxy::new(self.service.connection()).await?;
408

            
409
8
        tokio::spawn(async move { prompter.stop_prompting(&path).await });
410
4
        let signal_emitter = self.service.signal_emitter(prompt_path)?;
411
4
        let result = zvariant::Value::new::<Vec<OwnedObjectPath>>(vec![])
412
            .try_into_owned()
413
            .unwrap();
414

            
415
6
        tokio::spawn(async move { Prompt::completed(&signal_emitter, true, result).await });
416
2
        Ok(())
417
    }
418
}
419

            
420
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use zvariant::{serialized::Context, to_bytes};

    use super::*;

    /// Encodes `value` as little-endian D-Bus data and decodes it as
    /// [`Properties`]. Panics with `serialize_msg` if encoding fails.
    fn decode_as_properties<T: Serialize + Type>(value: &T, serialize_msg: &str) -> Properties {
        let ctxt = Context::new_dbus(zvariant::LE, 0);
        let encoded = to_bytes(ctxt, value).expect(serialize_msg);
        encoded.deserialize().unwrap().0
    }

    #[test]
    fn properties_serialization_roundtrip() {
        let original = Properties {
            title: Some("Test Title".to_string()),
            message: Some("Test Message".to_string()),
            ..Default::default()
        };

        let decoded = decode_as_properties(&original, "Failed to serialize");

        assert_eq!(decoded.title, Some("Test Title".to_string()));
        assert_eq!(decoded.message, Some("Test Message".to_string()));
    }

    #[test]
    fn deserialize_properties() {
        // Double-wrapped, as buggy GCR sends it: Value<Value<String>>.
        let double_wrapped: HashMap<String, Value> = HashMap::from([
            (
                "title".to_string(),
                Value::new(Value::new("Unlock Keyring")),
            ),
            (
                "message".to_string(),
                Value::new(Value::new("Authentication required")),
            ),
        ]);
        let props = decode_as_properties(&double_wrapped, "Failed to serialize test data");
        assert_eq!(props.title, Some("Unlock Keyring".to_string()));
        assert_eq!(props.message, Some("Authentication required".to_string()));

        // Single-wrapped, the correct format: Value<String>.
        let single_wrapped: HashMap<String, Value> = HashMap::from([
            ("title".to_string(), Value::new("Unlock Keyring")),
            ("message".to_string(), Value::new("Authentication required")),
        ]);
        let props = decode_as_properties(&single_wrapped, "Failed to serialize test data");
        assert_eq!(props.title, Some("Unlock Keyring".to_string()));
        assert_eq!(props.message, Some("Authentication required".to_string()));

        // `None` fields are skipped on the wire and come back as `None`.
        let partial = Properties {
            title: None,
            message: Some("Test".to_string()),
            ..Default::default()
        };
        let decoded = decode_as_properties(&partial, "Failed to serialize");
        assert_eq!(decoded.title, None);
        assert_eq!(decoded.message, Some("Test".to_string()));

        // Non-string property types survive the round trip too.
        let typed = Properties {
            password_new: Some(true),
            password_strength: Some(42),
            choice_chosen: Some(false),
            ..Default::default()
        };
        let decoded = decode_as_properties(&typed, "Failed to serialize");
        assert_eq!(decoded.password_new, Some(true));
        assert_eq!(decoded.password_strength, Some(42));
        assert_eq!(decoded.choice_chosen, Some(false));
    }
}