Skip to content

Commit 010a236

Browse files
committed
fix: allow aud claim to be a string or a sequence of strings
1 parent d698e7e commit 010a236

File tree

3 files changed

+162
-2
lines changed

3 files changed

+162
-2
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@ uuid = { version = "0.8.2", features = ["serde"] }
2525
[dev-dependencies]
2626
actix-rt = "1.1.1"
2727
env_logger = "0.8.2"
28+
serde_json = "1.0.61"
2829
uuid = { version = "0.8.2", features = ["serde", "v4"] }

src/lib.rs

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,10 @@ pub struct Claims {
194194
/// Issuer
195195
pub iss: String,
196196
/// Audience
197-
pub aud: Option<String>,
197+
///
198+
/// _This can be extracted from either a JSON string or a JSON sequence of strings._
199+
#[serde(default, deserialize_with = "deserialize_optional_string_or_strings")]
200+
pub aud: Option<Vec<String>>,
198201
/// Issuance date
199202
#[serde(with = "ts_seconds")]
200203
pub iat: DateTime<Utc>,
@@ -204,6 +207,24 @@ pub struct Claims {
204207
pub azp: String,
205208
}
206209

210+
fn deserialize_optional_string_or_strings<'de, D>(de: D) -> Result<Option<Vec<String>>, D::Error>
211+
where
212+
D: ::serde::Deserializer<'de>,
213+
{
214+
#[derive(Deserialize)]
215+
#[serde(untagged)]
216+
enum StringOrVec {
217+
String(String),
218+
Vec(Vec<String>),
219+
}
220+
221+
Option::<StringOrVec>::deserialize(de).map(|string_or_vec| match string_or_vec {
222+
Some(StringOrVec::String(string)) => Some(vec![string]),
223+
Some(StringOrVec::Vec(vec)) => Some(vec),
224+
None => None,
225+
})
226+
}
227+
207228
impl Default for Claims {
208229
fn default() -> Self {
209230
use chrono::Duration;
@@ -215,7 +236,7 @@ impl Default for Claims {
215236
realm_access: None,
216237
resource_access: None,
217238
iss: env!("CARGO_PKG_NAME").to_owned(),
218-
aud: Some("account".to_owned()),
239+
aud: Some(vec!["account".to_owned()]),
219240
iat: Utc::now(),
220241
jti: Uuid::from_u128_le(22685491128062564230891640495451214097),
221242
azp: "".to_owned(),
@@ -392,3 +413,81 @@ where
392413
}
393414
}
394415
}
416+
417+
#[cfg(test)]
418+
mod tests {
419+
use super::*;
420+
use serde_json::{from_value, json};
421+
422+
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
423+
struct StringOrVec {
424+
field: u8,
425+
#[serde(default, deserialize_with = "deserialize_optional_string_or_strings")]
426+
string_or_vec: Option<Vec<String>>,
427+
}
428+
429+
#[test]
430+
fn deserialize_string_or_vec_when_vec() {
431+
let input = json!({
432+
"field": 1,
433+
"string_or_vec": ["1", "2"],
434+
});
435+
let output = from_value::<StringOrVec>(input);
436+
assert_eq!(
437+
output.ok(),
438+
Some(StringOrVec {
439+
field: 1,
440+
string_or_vec: Some(vec!["1".to_owned(), "2".to_owned()]),
441+
})
442+
)
443+
}
444+
445+
#[test]
446+
fn deserialize_string_or_vec_when_string() {
447+
let input = json!({
448+
"field": 1,
449+
"string_or_vec": "1",
450+
});
451+
let output = from_value::<StringOrVec>(input);
452+
assert_eq!(
453+
output.ok(),
454+
Some(StringOrVec {
455+
field: 1,
456+
string_or_vec: Some(vec!["1".to_owned()]),
457+
})
458+
)
459+
}
460+
461+
#[test]
462+
fn deserialize_string_or_vec_when_none() {
463+
let input = json!({
464+
"field": 1,
465+
});
466+
let output = from_value::<StringOrVec>(input);
467+
dbg!(&output);
468+
assert_eq!(
469+
output.ok(),
470+
Some(StringOrVec {
471+
field: 1,
472+
string_or_vec: None,
473+
})
474+
)
475+
}
476+
477+
#[test]
478+
fn deserialize_string_or_vec_when_null() {
479+
let input = json!({
480+
"field": 1,
481+
"string_or_vec": null,
482+
});
483+
let output = from_value::<StringOrVec>(input);
484+
dbg!(&output);
485+
assert_eq!(
486+
output.ok(),
487+
Some(StringOrVec {
488+
field: 1,
489+
string_or_vec: None,
490+
})
491+
)
492+
}
493+
}

tests/middleware.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use actix_web::web::{Bytes, ReqData};
88
use actix_web::{test, web, App, HttpResponse, Responder};
99
use actix_web_middleware_keycloak_auth::{Access, Claims, KeycloakAuth, Role};
1010
use jsonwebtoken::{encode, Algorithm, DecodingKey, EncodingKey, Header};
11+
use serde_json::json;
1112
use std::collections::HashMap;
1213
use std::iter::FromIterator;
1314
use uuid::Uuid;
@@ -424,3 +425,62 @@ async fn valid_jwt_roles() {
424425
let body = test::read_body(resp).await;
425426
assert_eq!(body, Bytes::from(user_id.to_string()));
426427
}
428+
429+
#[actix_rt::test]
430+
async fn from_raw_claims_single_aud_as_string() {
431+
init_logger();
432+
433+
let keycloak_auth = KeycloakAuth {
434+
detailed_responses: true,
435+
keycloak_oid_public_key: DecodingKey::from_rsa_pem(KEYCLOAK_PK.as_bytes()).unwrap(),
436+
required_roles: vec![Role::Client {
437+
client: "client1".to_owned(),
438+
role: "test1".to_owned(),
439+
}],
440+
};
441+
let mut app = test::init_service(
442+
App::new()
443+
.service(
444+
web::scope("/private")
445+
.wrap(keycloak_auth)
446+
.route("", web::get().to(private)),
447+
)
448+
.service(web::resource("/").to(hello_world)),
449+
)
450+
.await;
451+
452+
let user_id = Uuid::new_v4();
453+
let default = Claims::default();
454+
let claims = json!({
455+
"sub": user_id,
456+
"resource_access": {
457+
"client1": {
458+
"roles": ["test1"],
459+
},
460+
"client2": {
461+
"roles": ["test2"],
462+
},
463+
},
464+
// Defaults
465+
"exp": default.exp.timestamp(),
466+
"iss": default.iss,
467+
"aud": "some-aud",
468+
"iat": default.iat.timestamp(),
469+
"jti": default.jti,
470+
"azp": default.azp,
471+
});
472+
let jwt = encode(
473+
&Header::new(Algorithm::RS256),
474+
&claims,
475+
&EncodingKey::from_rsa_pem(KEYCLOAK_KEY.as_bytes()).unwrap(),
476+
)
477+
.unwrap();
478+
let req = test::TestRequest::with_uri("/private")
479+
.header("Authorization", format!("Bearer {}", &jwt))
480+
.to_request();
481+
let resp = test::call_service(&mut app, req).await;
482+
483+
assert!(resp.status().is_success());
484+
let body = test::read_body(resp).await;
485+
assert_eq!(body, Bytes::from(user_id.to_string()));
486+
}

0 commit comments

Comments
 (0)