mas_axum_utils/
client_authorization.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::collections::HashMap;
8
9use axum::{
10    BoxError, Json,
11    extract::{
12        Form, FromRequest,
13        rejection::{FailedToDeserializeForm, FormRejection},
14    },
15    response::IntoResponse,
16};
17use headers::authorization::{Basic, Bearer, Credentials as _};
18use http::{Request, StatusCode};
19use mas_data_model::{Client, JwksOrJwksUri};
20use mas_http::RequestBuilderExt;
21use mas_iana::oauth::OAuthClientAuthenticationMethod;
22use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
23use mas_keystore::Encrypter;
24use mas_storage::{RepositoryAccess, oauth2::OAuth2ClientRepository};
25use oauth2_types::errors::{ClientError, ClientErrorCode};
26use serde::{Deserialize, de::DeserializeOwned};
27use serde_json::Value;
28use thiserror::Error;
29
30use crate::record_error;
31
/// The `client_assertion_type` value announcing a JWT bearer client assertion
/// (the URN registered by RFC 7523).
const JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
33
34#[derive(Deserialize)]
35struct AuthorizedForm<F = ()> {
36    client_id: Option<String>,
37    client_secret: Option<String>,
38    client_assertion_type: Option<String>,
39    client_assertion: Option<String>,
40
41    #[serde(flatten)]
42    inner: F,
43}
44
45#[derive(Debug, PartialEq, Eq)]
46pub enum Credentials {
47    None {
48        client_id: String,
49    },
50    ClientSecretBasic {
51        client_id: String,
52        client_secret: String,
53    },
54    ClientSecretPost {
55        client_id: String,
56        client_secret: String,
57    },
58    ClientAssertionJwtBearer {
59        client_id: String,
60        jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
61    },
62    BearerToken {
63        token: String,
64    },
65}
66
67impl Credentials {
68    /// Get the `client_id` of the credentials
69    #[must_use]
70    pub fn client_id(&self) -> Option<&str> {
71        match self {
72            Credentials::None { client_id }
73            | Credentials::ClientSecretBasic { client_id, .. }
74            | Credentials::ClientSecretPost { client_id, .. }
75            | Credentials::ClientAssertionJwtBearer { client_id, .. } => Some(client_id),
76            Credentials::BearerToken { .. } => None,
77        }
78    }
79
80    /// Get the bearer token from the credentials.
81    #[must_use]
82    pub fn bearer_token(&self) -> Option<&str> {
83        match self {
84            Credentials::BearerToken { token } => Some(token),
85            _ => None,
86        }
87    }
88
89    /// Fetch the client from the database
90    ///
91    /// # Errors
92    ///
93    /// Returns an error if the client could not be found or if the underlying
94    /// repository errored.
95    pub async fn fetch<E>(
96        &self,
97        repo: &mut impl RepositoryAccess<Error = E>,
98    ) -> Result<Option<Client>, E> {
99        let client_id = match self {
100            Credentials::None { client_id }
101            | Credentials::ClientSecretBasic { client_id, .. }
102            | Credentials::ClientSecretPost { client_id, .. }
103            | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
104            Credentials::BearerToken { .. } => return Ok(None),
105        };
106
107        repo.oauth2_client().find_by_client_id(client_id).await
108    }
109
110    /// Verify credentials presented by the client for authentication
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if the credentials are invalid.
115    #[tracing::instrument(skip_all)]
116    pub async fn verify(
117        &self,
118        http_client: &reqwest::Client,
119        encrypter: &Encrypter,
120        method: &OAuthClientAuthenticationMethod,
121        client: &Client,
122    ) -> Result<(), CredentialsVerificationError> {
123        match (self, method) {
124            (Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}
125
126            (
127                Credentials::ClientSecretPost { client_secret, .. },
128                OAuthClientAuthenticationMethod::ClientSecretPost,
129            )
130            | (
131                Credentials::ClientSecretBasic { client_secret, .. },
132                OAuthClientAuthenticationMethod::ClientSecretBasic,
133            ) => {
134                // Decrypt the client_secret
135                let encrypted_client_secret = client
136                    .encrypted_client_secret
137                    .as_ref()
138                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
139
140                let decrypted_client_secret = encrypter
141                    .decrypt_string(encrypted_client_secret)
142                    .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
143
144                // Check if the client_secret matches
145                if client_secret.as_bytes() != decrypted_client_secret {
146                    return Err(CredentialsVerificationError::ClientSecretMismatch);
147                }
148            }
149
150            (
151                Credentials::ClientAssertionJwtBearer { jwt, .. },
152                OAuthClientAuthenticationMethod::PrivateKeyJwt,
153            ) => {
154                // Get the client JWKS
155                let jwks = client
156                    .jwks
157                    .as_ref()
158                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
159
160                let jwks = fetch_jwks(http_client, jwks)
161                    .await
162                    .map_err(CredentialsVerificationError::JwksFetchFailed)?;
163
164                jwt.verify_with_jwks(&jwks)
165                    .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
166            }
167
168            (
169                Credentials::ClientAssertionJwtBearer { jwt, .. },
170                OAuthClientAuthenticationMethod::ClientSecretJwt,
171            ) => {
172                // Decrypt the client_secret
173                let encrypted_client_secret = client
174                    .encrypted_client_secret
175                    .as_ref()
176                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
177
178                let decrypted_client_secret = encrypter
179                    .decrypt_string(encrypted_client_secret)
180                    .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
181
182                jwt.verify_with_shared_secret(decrypted_client_secret)
183                    .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
184            }
185
186            (_, _) => {
187                return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
188            }
189        }
190        Ok(())
191    }
192}
193
194async fn fetch_jwks(
195    http_client: &reqwest::Client,
196    jwks: &JwksOrJwksUri,
197) -> Result<PublicJsonWebKeySet, BoxError> {
198    let uri = match jwks {
199        JwksOrJwksUri::Jwks(j) => return Ok(j.clone()),
200        JwksOrJwksUri::JwksUri(u) => u,
201    };
202
203    let response = http_client
204        .get(uri.as_str())
205        .send_traced()
206        .await?
207        .error_for_status()?
208        .json()
209        .await?;
210
211    Ok(response)
212}
213
214#[derive(Debug, Error)]
215pub enum CredentialsVerificationError {
216    #[error("failed to decrypt client credentials")]
217    DecryptionError,
218
219    #[error("invalid client configuration")]
220    InvalidClientConfig,
221
222    #[error("client secret did not match")]
223    ClientSecretMismatch,
224
225    #[error("authentication method mismatch")]
226    AuthenticationMethodMismatch,
227
228    #[error("invalid assertion signature")]
229    InvalidAssertionSignature,
230
231    #[error("failed to fetch jwks")]
232    JwksFetchFailed(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
233}
234
235impl CredentialsVerificationError {
236    /// Returns true if the error is an internal error, not caused by the client
237    #[must_use]
238    pub fn is_internal(&self) -> bool {
239        matches!(
240            self,
241            Self::DecryptionError | Self::InvalidClientConfig | Self::JwksFetchFailed(_)
242        )
243    }
244}
245
246#[derive(Debug, PartialEq, Eq)]
247pub struct ClientAuthorization<F = ()> {
248    pub credentials: Credentials,
249    pub form: Option<F>,
250}
251
252impl<F> ClientAuthorization<F> {
253    /// Get the `client_id` from the credentials.
254    #[must_use]
255    pub fn client_id(&self) -> Option<&str> {
256        self.credentials.client_id()
257    }
258}
259
260#[derive(Debug, Error)]
261pub enum ClientAuthorizationError {
262    #[error("Invalid Authorization header")]
263    InvalidHeader,
264
265    #[error("Could not deserialize request body")]
266    BadForm(#[source] FailedToDeserializeForm),
267
268    #[error("client_id in form ({form:?}) does not match credential ({credential:?})")]
269    ClientIdMismatch { credential: String, form: String },
270
271    #[error("Unsupported client_assertion_type: {client_assertion_type}")]
272    UnsupportedClientAssertion { client_assertion_type: String },
273
274    #[error("No credentials were presented")]
275    MissingCredentials,
276
277    #[error("Invalid request")]
278    InvalidRequest,
279
280    #[error("Invalid client_assertion")]
281    InvalidAssertion,
282
283    #[error(transparent)]
284    Internal(Box<dyn std::error::Error>),
285}
286
287impl IntoResponse for ClientAuthorizationError {
288    fn into_response(self) -> axum::response::Response {
289        let sentry_event_id = record_error!(self, Self::Internal(_));
290        match &self {
291            ClientAuthorizationError::InvalidHeader => (
292                StatusCode::BAD_REQUEST,
293                sentry_event_id,
294                Json(ClientError::new(
295                    ClientErrorCode::InvalidRequest,
296                    "Invalid Authorization header",
297                )),
298            ),
299
300            ClientAuthorizationError::BadForm(err) => (
301                StatusCode::BAD_REQUEST,
302                sentry_event_id,
303                Json(
304                    ClientError::from(ClientErrorCode::InvalidRequest)
305                        .with_description(format!("{err}")),
306                ),
307            ),
308
309            ClientAuthorizationError::ClientIdMismatch { .. } => (
310                StatusCode::BAD_REQUEST,
311                sentry_event_id,
312                Json(
313                    ClientError::from(ClientErrorCode::InvalidGrant)
314                        .with_description(format!("{self}")),
315                ),
316            ),
317
318            ClientAuthorizationError::UnsupportedClientAssertion { .. } => (
319                StatusCode::BAD_REQUEST,
320                sentry_event_id,
321                Json(
322                    ClientError::from(ClientErrorCode::InvalidRequest)
323                        .with_description(format!("{self}")),
324                ),
325            ),
326
327            ClientAuthorizationError::MissingCredentials => (
328                StatusCode::BAD_REQUEST,
329                sentry_event_id,
330                Json(ClientError::new(
331                    ClientErrorCode::InvalidRequest,
332                    "No credentials were presented",
333                )),
334            ),
335
336            ClientAuthorizationError::InvalidRequest => (
337                StatusCode::BAD_REQUEST,
338                sentry_event_id,
339                Json(ClientError::from(ClientErrorCode::InvalidRequest)),
340            ),
341
342            ClientAuthorizationError::InvalidAssertion => (
343                StatusCode::BAD_REQUEST,
344                sentry_event_id,
345                Json(ClientError::new(
346                    ClientErrorCode::InvalidRequest,
347                    "Invalid client_assertion",
348                )),
349            ),
350
351            ClientAuthorizationError::Internal(e) => (
352                StatusCode::INTERNAL_SERVER_ERROR,
353                sentry_event_id,
354                Json(
355                    ClientError::from(ClientErrorCode::ServerError)
356                        .with_description(format!("{e}")),
357                ),
358            ),
359        }
360        .into_response()
361    }
362}
363
364impl<S, F> FromRequest<S> for ClientAuthorization<F>
365where
366    F: DeserializeOwned,
367    S: Send + Sync,
368{
369    type Rejection = ClientAuthorizationError;
370
371    #[allow(clippy::too_many_lines)]
372    async fn from_request(
373        req: Request<axum::body::Body>,
374        state: &S,
375    ) -> Result<Self, Self::Rejection> {
376        enum Authorization {
377            Basic(String, String),
378            Bearer(String),
379        }
380
381        // Sadly, the typed-header 'Authorization' doesn't let us check for both
382        // Basic and Bearer at the same time, so we need to parse them manually
383        let authorization = if let Some(header) = req.headers().get(http::header::AUTHORIZATION) {
384            let bytes = header.as_bytes();
385            if bytes.len() >= 6 && bytes[..6].eq_ignore_ascii_case(b"Basic ") {
386                let Some(decoded) = Basic::decode(header) else {
387                    return Err(ClientAuthorizationError::InvalidHeader);
388                };
389
390                Some(Authorization::Basic(
391                    decoded.username().to_owned(),
392                    decoded.password().to_owned(),
393                ))
394            } else if bytes.len() >= 7 && bytes[..7].eq_ignore_ascii_case(b"Bearer ") {
395                let Some(decoded) = Bearer::decode(header) else {
396                    return Err(ClientAuthorizationError::InvalidHeader);
397                };
398
399                Some(Authorization::Bearer(decoded.token().to_owned()))
400            } else {
401                return Err(ClientAuthorizationError::InvalidHeader);
402            }
403        } else {
404            None
405        };
406
407        // Take the form value
408        let (
409            client_id_from_form,
410            client_secret_from_form,
411            client_assertion_type,
412            client_assertion,
413            form,
414        ) = match Form::<AuthorizedForm<F>>::from_request(req, state).await {
415            Ok(Form(form)) => (
416                form.client_id,
417                form.client_secret,
418                form.client_assertion_type,
419                form.client_assertion,
420                Some(form.inner),
421            ),
422            // If it is not a form, continue
423            Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
424            // If the form could not be read, return a Bad Request error
425            Err(FormRejection::FailedToDeserializeForm(err)) => {
426                return Err(ClientAuthorizationError::BadForm(err));
427            }
428            // Other errors (body read twice, byte stream broke) return an internal error
429            Err(e) => return Err(ClientAuthorizationError::Internal(Box::new(e))),
430        };
431
432        // And now, figure out the actual auth method
433        let credentials = match (
434            authorization,
435            client_id_from_form,
436            client_secret_from_form,
437            client_assertion_type,
438            client_assertion,
439        ) {
440            (
441                Some(Authorization::Basic(client_id, client_secret)),
442                client_id_from_form,
443                None,
444                None,
445                None,
446            ) => {
447                if let Some(client_id_from_form) = client_id_from_form {
448                    // If the client_id was in the body, verify it matches with the header
449                    if client_id != client_id_from_form {
450                        return Err(ClientAuthorizationError::ClientIdMismatch {
451                            credential: client_id,
452                            form: client_id_from_form,
453                        });
454                    }
455                }
456
457                Credentials::ClientSecretBasic {
458                    client_id,
459                    client_secret,
460                }
461            }
462
463            (None, Some(client_id), Some(client_secret), None, None) => {
464                // Got both client_id and client_secret from the form
465                Credentials::ClientSecretPost {
466                    client_id,
467                    client_secret,
468                }
469            }
470
471            (None, Some(client_id), None, None, None) => {
472                // Only got a client_id in the form
473                Credentials::None { client_id }
474            }
475
476            (
477                None,
478                client_id_from_form,
479                None,
480                Some(client_assertion_type),
481                Some(client_assertion),
482            ) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
483                // Got a JWT bearer client_assertion
484                let jwt: Jwt<'static, HashMap<String, Value>> = Jwt::try_from(client_assertion)
485                    .map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
486
487                let client_id = if let Some(Value::String(client_id)) = jwt.payload().get("sub") {
488                    client_id.clone()
489                } else {
490                    return Err(ClientAuthorizationError::InvalidAssertion);
491                };
492
493                if let Some(client_id_from_form) = client_id_from_form {
494                    // If the client_id was in the body, verify it matches the one in the JWT
495                    if client_id != client_id_from_form {
496                        return Err(ClientAuthorizationError::ClientIdMismatch {
497                            credential: client_id,
498                            form: client_id_from_form,
499                        });
500                    }
501                }
502
503                Credentials::ClientAssertionJwtBearer {
504                    client_id,
505                    jwt: Box::new(jwt),
506                }
507            }
508
509            (None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
510                // Got another unsupported client_assertion
511                return Err(ClientAuthorizationError::UnsupportedClientAssertion {
512                    client_assertion_type,
513                });
514            }
515
516            (Some(Authorization::Bearer(token)), None, None, None, None) => {
517                // Got a bearer token
518                Credentials::BearerToken { token }
519            }
520
521            (None, None, None, None, None) => {
522                // Special case when there are no credentials anywhere
523                return Err(ClientAuthorizationError::MissingCredentials);
524            }
525
526            _ => {
527                // Every other combination is an invalid request
528                return Err(ClientAuthorizationError::InvalidRequest);
529            }
530        };
531
532        Ok(ClientAuthorization { credentials, form })
533    }
534}
535
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use http::{Method, Request};

    use super::*;

    /// `Authorization` header value encoding `client-id:client-secret`.
    const BASIC_AUTHORIZATION: &str = "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=";

    /// Build a `POST` request with a URL-encoded `body`, optionally carrying an
    /// `Authorization` header.
    fn form_request(authorization: Option<&str>, body: impl Into<String>) -> Request<Body> {
        let mut builder = Request::builder().method(Method::POST).header(
            http::header::CONTENT_TYPE,
            mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
        );
        if let Some(value) = authorization {
            builder = builder.header(http::header::AUTHORIZATION, value);
        }
        let body: String = body.into();
        builder.body(Body::new(body)).unwrap()
    }

    /// Run the extractor, collecting the non-credential form fields as JSON.
    async fn extract(
        req: Request<Body>,
    ) -> Result<ClientAuthorization<serde_json::Value>, ClientAuthorizationError> {
        ClientAuthorization::<serde_json::Value>::from_request(req, &()).await
    }

    /// The non-credential form fields every request in these tests carries.
    fn foo_bar() -> Option<serde_json::Value> {
        Some(serde_json::json!({"foo": "bar"}))
    }

    #[tokio::test]
    async fn none_test() {
        let authz = extract(form_request(None, "client_id=client-id&foo=bar"))
            .await
            .unwrap();

        assert_eq!(
            authz,
            ClientAuthorization {
                credentials: Credentials::None {
                    client_id: "client-id".to_owned(),
                },
                form: foo_bar(),
            }
        );
    }

    #[tokio::test]
    async fn client_secret_basic_test() {
        let expected = || ClientAuthorization {
            credentials: Credentials::ClientSecretBasic {
                client_id: "client-id".to_owned(),
                client_secret: "client-secret".to_owned(),
            },
            form: foo_bar(),
        };

        // Credentials only in the header
        let authz = extract(form_request(Some(BASIC_AUTHORIZATION), "foo=bar"))
            .await
            .unwrap();
        assert_eq!(authz, expected());

        // client_id in both header and body
        let authz = extract(form_request(
            Some(BASIC_AUTHORIZATION),
            "client_id=client-id&foo=bar",
        ))
        .await
        .unwrap();
        assert_eq!(authz, expected());

        // client_id in both header and body mismatch
        let outcome = extract(form_request(
            Some(BASIC_AUTHORIZATION),
            "client_id=mismatch-id&foo=bar",
        ))
        .await;
        assert!(matches!(
            outcome,
            Err(ClientAuthorizationError::ClientIdMismatch { .. }),
        ));

        // Invalid header
        let outcome = extract(form_request(Some("Basic invalid"), "foo=bar")).await;
        assert!(matches!(
            outcome,
            Err(ClientAuthorizationError::InvalidHeader),
        ));
    }

    #[tokio::test]
    async fn client_secret_post_test() {
        let authz = extract(form_request(
            None,
            "client_id=client-id&client_secret=client-secret&foo=bar",
        ))
        .await
        .unwrap();

        assert_eq!(
            authz,
            ClientAuthorization {
                credentials: Credentials::ClientSecretPost {
                    client_id: "client-id".to_owned(),
                    client_secret: "client-secret".to_owned(),
                },
                form: foo_bar(),
            }
        );
    }

    #[tokio::test]
    async fn client_assertion_test() {
        // Signed with client_secret = "client-secret"
        let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";
        let body = format!(
            "client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion={jwt}&foo=bar"
        );

        let authz = extract(form_request(None, body)).await.unwrap();
        assert_eq!(authz.form, foo_bar());

        let Credentials::ClientAssertionJwtBearer { client_id, jwt } = authz.credentials else {
            panic!("expected a JWT client_assertion");
        };

        assert_eq!(client_id, "client-id");
        jwt.verify_with_shared_secret(b"client-secret".to_vec())
            .unwrap();
    }

    #[tokio::test]
    async fn bearer_token_test() {
        let authz = extract(form_request(Some("Bearer token"), "foo=bar"))
            .await
            .unwrap();

        assert_eq!(
            authz,
            ClientAuthorization {
                credentials: Credentials::BearerToken {
                    token: "token".to_owned(),
                },
                form: foo_bar(),
            }
        );
    }
}