mas_handlers/admin/v1/user_registration_tokens/add.rs

1// Copyright 2025 The Matrix.org Foundation C.I.C.
2//
3// SPDX-License-Identifier: AGPL-3.0-only
4// Please see LICENSE in the repository root for full details.
5
6use aide::{NoApi, OperationIo, transform::TransformOperation};
7use axum::{Json, response::IntoResponse};
8use chrono::{DateTime, Utc};
9use hyper::StatusCode;
10use mas_axum_utils::record_error;
11use mas_storage::BoxRng;
12use rand::distributions::{Alphanumeric, DistString};
13use schemars::JsonSchema;
14use serde::Deserialize;
15
16use crate::{
17    admin::{
18        call_context::CallContext,
19        model::UserRegistrationToken,
20        response::{ErrorResponse, SingleResponse},
21    },
22    impl_from_error_for_route,
23};
24
25#[derive(Debug, thiserror::Error, OperationIo)]
26#[aide(output_with = "Json<ErrorResponse>")]
27pub enum RouteError {
28    #[error("A registration token with the same token already exists")]
29    Conflict(mas_data_model::UserRegistrationToken),
30
31    #[error(transparent)]
32    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
33}
34
35impl_from_error_for_route!(mas_storage::RepositoryError);
36
37impl IntoResponse for RouteError {
38    fn into_response(self) -> axum::response::Response {
39        let error = ErrorResponse::from_error(&self);
40        let sentry_event_id = record_error!(self, Self::Internal(_));
41        let status = match self {
42            Self::Conflict(_) => StatusCode::CONFLICT,
43            Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
44        };
45        (status, sentry_event_id, Json(error)).into_response()
46    }
47}
48
// NOTE: the `///` comments in this struct are presumably picked up by the
// `JsonSchema` derive as the schema title and field descriptions, so editing
// them changes the published API documentation, not just rustdoc. The
// container-level serde `rename` sets the type name (presumably used as the
// schema name); it does not rename any JSON field.
/// # JSON payload for the `POST /api/admin/v1/user-registration-tokens`
#[derive(Deserialize, JsonSchema)]
#[serde(rename = "AddUserRegistrationTokenRequest")]
pub struct Request {
    /// The token string. If not provided, a random token will be generated.
    // NOTE(review): the handler stores an explicitly supplied empty or
    // whitespace-only string as-is; confirm whether that should be rejected.
    token: Option<String>,

    /// Maximum number of times this token can be used. If not provided, the
    /// token can be used an unlimited number of times.
    // `u32`, so negative values fail deserialization. NOTE(review): `0` is
    // accepted; whether that yields an immediately unusable token depends on
    // the model's validity rules — confirm.
    usage_limit: Option<u32>,

    /// When the token expires. If not provided, the token never expires.
    // Passed through to storage unchanged by the handler; a timestamp in the
    // past is not rejected here.
    expires_at: Option<DateTime<Utc>>,
}
63
64pub fn doc(operation: TransformOperation) -> TransformOperation {
65    operation
66        .id("addUserRegistrationToken")
67        .summary("Create a new user registration token")
68        .tag("user-registration-token")
69        .response_with::<201, Json<SingleResponse<UserRegistrationToken>>, _>(|t| {
70            let [sample, ..] = UserRegistrationToken::samples();
71            let response = SingleResponse::new_canonical(sample);
72            t.description("A new user registration token was created")
73                .example(response)
74        })
75}
76
77#[tracing::instrument(name = "handler.admin.v1.user_registration_tokens.post", skip_all)]
78pub async fn handler(
79    CallContext {
80        mut repo, clock, ..
81    }: CallContext,
82    NoApi(mut rng): NoApi<BoxRng>,
83    Json(params): Json<Request>,
84) -> Result<(StatusCode, Json<SingleResponse<UserRegistrationToken>>), RouteError> {
85    // Generate a random token if none was provided
86    let token = params
87        .token
88        .unwrap_or_else(|| Alphanumeric.sample_string(&mut rng, 12));
89
90    // See if we have an existing token with the same token
91    let existing_token = repo.user_registration_token().find_by_token(&token).await?;
92    if let Some(existing_token) = existing_token {
93        return Err(RouteError::Conflict(existing_token));
94    }
95
96    let registration_token = repo
97        .user_registration_token()
98        .add(
99            &mut rng,
100            &clock,
101            token,
102            params.usage_limit,
103            params.expires_at,
104        )
105        .await?;
106
107    repo.save().await?;
108
109    Ok((
110        StatusCode::CREATED,
111        Json(SingleResponse::new_canonical(UserRegistrationToken::new(
112            registration_token,
113            clock.now(),
114        ))),
115    ))
116}
117
// Integration tests for `POST /api/admin/v1/user-registration-tokens`.
//
// The ULIDs and timestamps in the inline snapshots are stable, presumably
// because `TestState` uses a deterministic clock and RNG — confirm in
// `crate::test_utils`.
//
// NOTE(review): the resource `type` in the snapshots mixes separators
// (`user-registration_token`). It comes from the admin model, not from this
// handler; check whether that is intentional.
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use insta::assert_json_snapshot;
    use sqlx::PgPool;

    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

    /// An explicit token string and usage limit are stored and echoed back
    /// with `201 Created`; the omitted `expires_at` comes back as `null`.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_create(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        // Admin-scoped access token used to authenticate the API call.
        let token = state.token_with_scope("urn:mas:admin").await;

        let request = Request::post("/api/admin/v1/user-registration-tokens")
            .bearer(&token)
            .json(serde_json::json!({
                "token": "test_token_123",
                "usage_limit": 5,
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let body: serde_json::Value = response.json();

        assert_json_snapshot!(body, @r#"
        {
          "data": {
            "type": "user-registration_token",
            "id": "01FSHN9AG0MZAA6S4AF7CTV32E",
            "attributes": {
              "token": "test_token_123",
              "valid": true,
              "usage_limit": 5,
              "times_used": 0,
              "created_at": "2022-01-16T14:40:00Z",
              "last_used_at": null,
              "expires_at": null,
              "revoked_at": null
            },
            "links": {
              "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
            }
          },
          "links": {
            "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
          }
        }
        "#);
    }

    /// Omitting `token` makes the handler generate a random 12-character
    /// alphanumeric token.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_create_auto_token(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;

        let request = Request::post("/api/admin/v1/user-registration-tokens")
            .bearer(&token)
            .json(serde_json::json!({
                "usage_limit": 1
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);

        let body: serde_json::Value = response.json();

        // The generated token value and the id both depend on the order in
        // which the handler draws from the RNG; the snapshot changes if that
        // order changes.
        assert_json_snapshot!(body, @r#"
        {
          "data": {
            "type": "user-registration_token",
            "id": "01FSHN9AG0QMGC989M0XSFVF2X",
            "attributes": {
              "token": "42oTpLoieH5I",
              "valid": true,
              "usage_limit": 1,
              "times_used": 0,
              "created_at": "2022-01-16T14:40:00Z",
              "last_used_at": null,
              "expires_at": null,
              "revoked_at": null
            },
            "links": {
              "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0QMGC989M0XSFVF2X"
            }
          },
          "links": {
            "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0QMGC989M0XSFVF2X"
          }
        }
        "#);
    }

    /// The first request with a given token string succeeds; repeating it
    /// is rejected with `409 Conflict`.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_create_conflict(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;

        let request = Request::post("/api/admin/v1/user-registration-tokens")
            .bearer(&token)
            .json(serde_json::json!({
                "token": "test_token_123",
                "usage_limit": 5
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);

        let body: serde_json::Value = response.json();

        assert_json_snapshot!(body, @r#"
        {
          "data": {
            "type": "user-registration_token",
            "id": "01FSHN9AG0MZAA6S4AF7CTV32E",
            "attributes": {
              "token": "test_token_123",
              "valid": true,
              "usage_limit": 5,
              "times_used": 0,
              "created_at": "2022-01-16T14:40:00Z",
              "last_used_at": null,
              "expires_at": null,
              "revoked_at": null
            },
            "links": {
              "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
            }
          },
          "links": {
            "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
          }
        }
        "#);

        // Same token string again: the handler's pre-insert lookup finds the
        // stored token and answers 409.
        let request = Request::post("/api/admin/v1/user-registration-tokens")
            .bearer(&token)
            .json(serde_json::json!({
                "token": "test_token_123",
                "usage_limit": 5
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::CONFLICT);
    }
}