Skip to content

Commit ba5ceef

Browse files
authored
Make psk callback async-capable (#751)
* Make psk callback async-capable * Amend examples
1 parent d3b1b31 commit ba5ceef

File tree

7 files changed

+38
-21
lines changed

7 files changed

+38
-21
lines changed

dtls/src/config.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::future::Future;
2+
use std::pin::Pin;
13
use std::sync::Arc;
24

35
use tokio::time::Duration;
@@ -131,7 +133,8 @@ pub(crate) const DEFAULT_MTU: usize = 1200; // bytes
131133

132134
// PSKCallback is called once we have the remote's psk_identity_hint.
133135
// If the remote provided none it will be nil
134-
pub(crate) type PskCallback = Arc<dyn (Fn(&[u8]) -> Result<Vec<u8>>) + Send + Sync>;
136+
pub(crate) type PskCallback =
137+
Arc<dyn (Fn(&[u8]) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>) + Send + Sync>;
135138

136139
// ClientAuthType declares the policy the server will follow for
137140
// TLS Client Authentication.

dtls/src/conn/conn_test.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
use std::future::Future;
2+
use std::pin::Pin;
13
use std::time::SystemTime;
24

35
use rand::Rng;
46
use rustls::pki_types::CertificateDer;
7+
use tokio::time::sleep;
58
use util::conn::conn_pipe::*;
69
use util::KeyingMaterialExporter;
710

@@ -79,24 +82,27 @@ async fn pipe_conn(
7982
Ok((client, sever))
8083
}
8184

82-
fn psk_callback_client(hint: &[u8]) -> Result<Vec<u8>> {
85+
fn psk_callback_client(hint: &[u8]) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
8386
trace!(
8487
"Server's hint: {}",
8588
String::from_utf8(hint.to_vec()).unwrap()
8689
);
87-
Ok(vec![0xAB, 0xC1, 0x23])
90+
Box::pin(async move { Ok(vec![0xAB, 0xC1, 0x23]) })
8891
}
8992

90-
fn psk_callback_server(hint: &[u8]) -> Result<Vec<u8>> {
93+
fn psk_callback_server(hint: &[u8]) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
9194
trace!(
9295
"Client's hint: {}",
9396
String::from_utf8(hint.to_vec()).unwrap()
9497
);
95-
Ok(vec![0xAB, 0xC1, 0x23])
98+
Box::pin(async move {
99+
sleep(Duration::from_millis(1)).await; // Now it's possible to await in the psk callback
100+
Ok(vec![0xAB, 0xC1, 0x23])
101+
})
96102
}
97103

98-
fn psk_callback_hint_fail(_hint: &[u8]) -> Result<Vec<u8>> {
99-
Err(Error::Other(ERR_PSK_REJECTED.to_owned()))
104+
fn psk_callback_hint_fail(_hint: &[u8]) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
105+
Box::pin(async move { Err(Error::Other(ERR_PSK_REJECTED.to_owned())) })
100106
}
101107

102108
async fn create_test_client(
@@ -1617,7 +1623,7 @@ async fn test_cipher_suite_configuration() -> Result<()> {
16171623
assert!(cipher_suite.is_some(), "{name} expected some, but got none");
16181624
if let Some(cs) = &*cipher_suite {
16191625
assert_eq!(cs.id(), want_cs,
1620-
"test_cipher_suite_configuration: Server Selected Bad Cipher Suite '{}': expected({}) actual({})",
1626+
"test_cipher_suite_configuration: Server Selected Bad Cipher Suite '{}': expected({}) actual({})",
16211627
name, want_cs, cs.id());
16221628
}
16231629
}
@@ -1630,8 +1636,8 @@ async fn test_cipher_suite_configuration() -> Result<()> {
16301636
Ok(())
16311637
}
16321638

1633-
fn psk_callback(_b: &[u8]) -> Result<Vec<u8>> {
1634-
Ok(vec![0x00, 0x01, 0x02])
1639+
fn psk_callback(_b: &[u8]) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
1640+
Box::pin(async move { Ok(vec![0x00, 0x01, 0x02]) })
16351641
}
16361642

16371643
#[tokio::test]

dtls/src/flight/flight3.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ impl Flight for Flight3 {
319319
}
320320
};
321321

322-
if let Err((alert, err)) = handle_server_key_exchange(state, cfg, h) {
322+
if let Err((alert, err)) = handle_server_key_exchange(state, cfg, h).await {
323323
return Err((alert, err));
324324
}
325325
}
@@ -411,13 +411,13 @@ impl Flight for Flight3 {
411411
}
412412
}
413413

414-
pub(crate) fn handle_server_key_exchange(
414+
pub(crate) async fn handle_server_key_exchange(
415415
state: &mut State,
416416
cfg: &HandshakeConfig,
417417
h: &HandshakeMessageServerKeyExchange,
418418
) -> Result<(), (Option<Alert>, Option<Error>)> {
419419
if let Some(local_psk_callback) = &cfg.local_psk_callback {
420-
let psk = match local_psk_callback(&h.identity_hint) {
420+
let psk = match local_psk_callback(&h.identity_hint).await {
421421
Ok(psk) => psk,
422422
Err(err) => {
423423
return Err((

dtls/src/flight/flight4.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ impl Flight for Flight4 {
290290

291291
let mut pre_master_secret = vec![];
292292
if let Some(local_psk_callback) = &cfg.local_psk_callback {
293-
let psk = match local_psk_callback(&client_key_exchange.identity_hint) {
293+
let psk = match local_psk_callback(&client_key_exchange.identity_hint).await
294+
{
294295
Ok(psk) => psk,
295296
Err(err) => {
296297
return Err((

dtls/src/flight/flight5.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,8 @@ impl Flight for Flight5 {
273273

274274
// handshakeMessageServerKeyExchange is optional for PSK
275275
if server_key_exchange_data.is_empty() {
276-
if let Err((alert, err)) = handle_server_key_exchange(state, cfg, &server_key_exchange)
276+
if let Err((alert, err)) =
277+
handle_server_key_exchange(state, cfg, &server_key_exchange).await
277278
{
278279
return Err((alert, err));
279280
}

examples/examples/dtls/dial/psk/dial_psk.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,12 @@ async fn main() -> Result<(), Error> {
6262
println!("connecting {server}..");
6363

6464
let config = Config {
65-
psk: Some(Arc::new(|hint: &[u8]| -> Result<Vec<u8>, Error> {
66-
println!("Server's hint: {}", String::from_utf8(hint.to_vec())?);
67-
Ok(vec![0xAB, 0xC1, 0x23])
65+
psk: Some(Arc::new(|hint: &[u8]| {
66+
let hint = hint.to_owned();
67+
Box::pin(async move {
68+
println!("Server's hint: {}", String::from_utf8(hint.to_vec())?);
69+
Ok(vec![0xAB, 0xC1, 0x23])
70+
})
6871
})),
6972
psk_identity_hint: Some("webrtc-rs DTLS Server".as_bytes().to_vec()),
7073
cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8],

examples/examples/dtls/listen/psk/listen_psk.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,12 @@ async fn main() -> Result<(), Error> {
5757
let host = matches.value_of("host").unwrap().to_owned();
5858

5959
let cfg = Config {
60-
psk: Some(Arc::new(|hint: &[u8]| -> Result<Vec<u8>, Error> {
61-
println!("Client's hint: {}", String::from_utf8(hint.to_vec())?);
62-
Ok(vec![0xAB, 0xC1, 0x23])
60+
psk: Some(Arc::new(|hint: &[u8]| {
61+
let hint = hint.to_owned();
62+
Box::pin(async move {
63+
println!("Client's hint: {}", String::from_utf8(hint.to_vec())?);
64+
Ok(vec![0xAB, 0xC1, 0x23])
65+
})
6366
})),
6467
psk_identity_hint: Some("webrtc-rs DTLS Client".as_bytes().to_vec()),
6568
cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8],

0 commit comments

Comments
 (0)