mas_storage_pg/upstream_oauth2/
session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    Clock, UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState,
11    UpstreamOAuthLink, UpstreamOAuthProvider,
12};
13use mas_storage::{
14    Page, Pagination,
15    pagination::Node,
16    upstream_oauth2::{UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository},
17};
18use rand::RngCore;
19use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
20use sea_query_binder::SqlxBinder;
21use sqlx::PgConnection;
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26    DatabaseError, DatabaseInconsistencyError,
27    filter::{Filter, StatementExt},
28    iden::UpstreamOAuthAuthorizationSessions,
29    pagination::QueryBuilderExt,
30    tracing::ExecuteExt,
31};
32
33impl Filter for UpstreamOAuthSessionFilter<'_> {
34    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
35        sea_query::Condition::all()
36            .add_option(self.provider().map(|provider| {
37                Expr::col((
38                    UpstreamOAuthAuthorizationSessions::Table,
39                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
40                ))
41                .eq(Uuid::from(provider.id))
42            }))
43            .add_option(self.sub_claim().map(|sub| {
44                Expr::col((
45                    UpstreamOAuthAuthorizationSessions::Table,
46                    UpstreamOAuthAuthorizationSessions::IdTokenClaims,
47                ))
48                .cast_json_field("sub")
49                .eq(sub)
50            }))
51            .add_option(self.sid_claim().map(|sid| {
52                Expr::col((
53                    UpstreamOAuthAuthorizationSessions::Table,
54                    UpstreamOAuthAuthorizationSessions::IdTokenClaims,
55                ))
56                .cast_json_field("sid")
57                .eq(sid)
58            }))
59    }
60}
61
62/// An implementation of [`UpstreamOAuthSessionRepository`] for a PostgreSQL
63/// connection
64pub struct PgUpstreamOAuthSessionRepository<'c> {
65    conn: &'c mut PgConnection,
66}
67
68impl<'c> PgUpstreamOAuthSessionRepository<'c> {
69    /// Create a new [`PgUpstreamOAuthSessionRepository`] from an active
70    /// PostgreSQL connection
71    pub fn new(conn: &'c mut PgConnection) -> Self {
72        Self { conn }
73    }
74}
75
76#[derive(sqlx::FromRow)]
77#[enum_def]
78struct SessionLookup {
79    upstream_oauth_authorization_session_id: Uuid,
80    upstream_oauth_provider_id: Uuid,
81    upstream_oauth_link_id: Option<Uuid>,
82    state: String,
83    code_challenge_verifier: Option<String>,
84    nonce: Option<String>,
85    id_token: Option<String>,
86    id_token_claims: Option<serde_json::Value>,
87    userinfo: Option<serde_json::Value>,
88    created_at: DateTime<Utc>,
89    completed_at: Option<DateTime<Utc>>,
90    consumed_at: Option<DateTime<Utc>>,
91    extra_callback_parameters: Option<serde_json::Value>,
92    unlinked_at: Option<DateTime<Utc>>,
93}
94
95impl Node<Ulid> for SessionLookup {
96    fn cursor(&self) -> Ulid {
97        self.upstream_oauth_authorization_session_id.into()
98    }
99}
100
101impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
102    type Error = DatabaseInconsistencyError;
103
104    fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
105        let id = value.upstream_oauth_authorization_session_id.into();
106        let state = match (
107            value.upstream_oauth_link_id,
108            value.id_token,
109            value.id_token_claims,
110            value.extra_callback_parameters,
111            value.userinfo,
112            value.completed_at,
113            value.consumed_at,
114            value.unlinked_at,
115        ) {
116            (None, None, None, None, None, None, None, None) => {
117                UpstreamOAuthAuthorizationSessionState::Pending
118            }
119            (
120                Some(link_id),
121                id_token,
122                id_token_claims,
123                extra_callback_parameters,
124                userinfo,
125                Some(completed_at),
126                None,
127                None,
128            ) => UpstreamOAuthAuthorizationSessionState::Completed {
129                completed_at,
130                link_id: link_id.into(),
131                id_token,
132                id_token_claims,
133                extra_callback_parameters,
134                userinfo,
135            },
136            (
137                Some(link_id),
138                id_token,
139                id_token_claims,
140                extra_callback_parameters,
141                userinfo,
142                Some(completed_at),
143                Some(consumed_at),
144                None,
145            ) => UpstreamOAuthAuthorizationSessionState::Consumed {
146                completed_at,
147                link_id: link_id.into(),
148                id_token,
149                id_token_claims,
150                extra_callback_parameters,
151                userinfo,
152                consumed_at,
153            },
154            (
155                _,
156                id_token,
157                id_token_claims,
158                _,
159                _,
160                Some(completed_at),
161                consumed_at,
162                Some(unlinked_at),
163            ) => UpstreamOAuthAuthorizationSessionState::Unlinked {
164                completed_at,
165                id_token,
166                id_token_claims,
167                consumed_at,
168                unlinked_at,
169            },
170            _ => {
171                return Err(DatabaseInconsistencyError::on(
172                    "upstream_oauth_authorization_sessions",
173                )
174                .row(id));
175            }
176        };
177
178        Ok(Self {
179            id,
180            provider_id: value.upstream_oauth_provider_id.into(),
181            state_str: value.state,
182            nonce: value.nonce,
183            code_challenge_verifier: value.code_challenge_verifier,
184            created_at: value.created_at,
185            state,
186        })
187    }
188}
189
190#[async_trait]
191impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
192    type Error = DatabaseError;
193
194    #[tracing::instrument(
195        name = "db.upstream_oauth_authorization_session.lookup",
196        skip_all,
197        fields(
198            db.query.text,
199            upstream_oauth_provider.id = %id,
200        ),
201        err,
202    )]
203    async fn lookup(
204        &mut self,
205        id: Ulid,
206    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
207        let res = sqlx::query_as!(
208            SessionLookup,
209            r#"
210                SELECT
211                    upstream_oauth_authorization_session_id,
212                    upstream_oauth_provider_id,
213                    upstream_oauth_link_id,
214                    state,
215                    code_challenge_verifier,
216                    nonce,
217                    id_token,
218                    id_token_claims,
219                    extra_callback_parameters,
220                    userinfo,
221                    created_at,
222                    completed_at,
223                    consumed_at,
224                    unlinked_at
225                FROM upstream_oauth_authorization_sessions
226                WHERE upstream_oauth_authorization_session_id = $1
227            "#,
228            Uuid::from(id),
229        )
230        .traced()
231        .fetch_optional(&mut *self.conn)
232        .await?;
233
234        let Some(res) = res else { return Ok(None) };
235
236        Ok(Some(res.try_into()?))
237    }
238
239    #[tracing::instrument(
240        name = "db.upstream_oauth_authorization_session.add",
241        skip_all,
242        fields(
243            db.query.text,
244            %upstream_oauth_provider.id,
245            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
246            %upstream_oauth_provider.client_id,
247            upstream_oauth_authorization_session.id,
248        ),
249        err,
250    )]
251    async fn add(
252        &mut self,
253        rng: &mut (dyn RngCore + Send),
254        clock: &dyn Clock,
255        upstream_oauth_provider: &UpstreamOAuthProvider,
256        state_str: String,
257        code_challenge_verifier: Option<String>,
258        nonce: Option<String>,
259    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
260        let created_at = clock.now();
261        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
262        tracing::Span::current().record(
263            "upstream_oauth_authorization_session.id",
264            tracing::field::display(id),
265        );
266
267        sqlx::query!(
268            r#"
269                INSERT INTO upstream_oauth_authorization_sessions (
270                    upstream_oauth_authorization_session_id,
271                    upstream_oauth_provider_id,
272                    state,
273                    code_challenge_verifier,
274                    nonce,
275                    created_at,
276                    completed_at,
277                    consumed_at,
278                    id_token,
279                    userinfo
280                ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
281            "#,
282            Uuid::from(id),
283            Uuid::from(upstream_oauth_provider.id),
284            &state_str,
285            code_challenge_verifier.as_deref(),
286            nonce,
287            created_at,
288        )
289        .traced()
290        .execute(&mut *self.conn)
291        .await?;
292
293        Ok(UpstreamOAuthAuthorizationSession {
294            id,
295            state: UpstreamOAuthAuthorizationSessionState::default(),
296            provider_id: upstream_oauth_provider.id,
297            state_str,
298            code_challenge_verifier,
299            nonce,
300            created_at,
301        })
302    }
303
304    #[tracing::instrument(
305        name = "db.upstream_oauth_authorization_session.complete_with_link",
306        skip_all,
307        fields(
308            db.query.text,
309            %upstream_oauth_authorization_session.id,
310            %upstream_oauth_link.id,
311        ),
312        err,
313    )]
314    async fn complete_with_link(
315        &mut self,
316        clock: &dyn Clock,
317        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
318        upstream_oauth_link: &UpstreamOAuthLink,
319        id_token: Option<String>,
320        id_token_claims: Option<serde_json::Value>,
321        extra_callback_parameters: Option<serde_json::Value>,
322        userinfo: Option<serde_json::Value>,
323    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
324        let completed_at = clock.now();
325
326        sqlx::query!(
327            r#"
328                UPDATE upstream_oauth_authorization_sessions
329                SET upstream_oauth_link_id = $1
330                  , completed_at = $2
331                  , id_token = $3
332                  , id_token_claims = $4
333                  , extra_callback_parameters = $5
334                  , userinfo = $6
335                WHERE upstream_oauth_authorization_session_id = $7
336            "#,
337            Uuid::from(upstream_oauth_link.id),
338            completed_at,
339            id_token,
340            id_token_claims,
341            extra_callback_parameters,
342            userinfo,
343            Uuid::from(upstream_oauth_authorization_session.id),
344        )
345        .traced()
346        .execute(&mut *self.conn)
347        .await?;
348
349        let upstream_oauth_authorization_session = upstream_oauth_authorization_session
350            .complete(
351                completed_at,
352                upstream_oauth_link,
353                id_token,
354                id_token_claims,
355                extra_callback_parameters,
356                userinfo,
357            )
358            .map_err(DatabaseError::to_invalid_operation)?;
359
360        Ok(upstream_oauth_authorization_session)
361    }
362
363    /// Mark a session as consumed
364    #[tracing::instrument(
365        name = "db.upstream_oauth_authorization_session.consume",
366        skip_all,
367        fields(
368            db.query.text,
369            %upstream_oauth_authorization_session.id,
370        ),
371        err,
372    )]
373    async fn consume(
374        &mut self,
375        clock: &dyn Clock,
376        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
377    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
378        let consumed_at = clock.now();
379        sqlx::query!(
380            r#"
381                UPDATE upstream_oauth_authorization_sessions
382                SET consumed_at = $1
383                WHERE upstream_oauth_authorization_session_id = $2
384            "#,
385            consumed_at,
386            Uuid::from(upstream_oauth_authorization_session.id),
387        )
388        .traced()
389        .execute(&mut *self.conn)
390        .await?;
391
392        let upstream_oauth_authorization_session = upstream_oauth_authorization_session
393            .consume(consumed_at)
394            .map_err(DatabaseError::to_invalid_operation)?;
395
396        Ok(upstream_oauth_authorization_session)
397    }
398
399    #[tracing::instrument(
400        name = "db.upstream_oauth_authorization_session.list",
401        skip_all,
402        fields(
403            db.query.text,
404        ),
405        err,
406    )]
407    async fn list(
408        &mut self,
409        filter: UpstreamOAuthSessionFilter<'_>,
410        pagination: Pagination,
411    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error> {
412        let (sql, arguments) = Query::select()
413            .expr_as(
414                Expr::col((
415                    UpstreamOAuthAuthorizationSessions::Table,
416                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
417                )),
418                SessionLookupIden::UpstreamOauthAuthorizationSessionId,
419            )
420            .expr_as(
421                Expr::col((
422                    UpstreamOAuthAuthorizationSessions::Table,
423                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
424                )),
425                SessionLookupIden::UpstreamOauthProviderId,
426            )
427            .expr_as(
428                Expr::col((
429                    UpstreamOAuthAuthorizationSessions::Table,
430                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthLinkId,
431                )),
432                SessionLookupIden::UpstreamOauthLinkId,
433            )
434            .expr_as(
435                Expr::col((
436                    UpstreamOAuthAuthorizationSessions::Table,
437                    UpstreamOAuthAuthorizationSessions::State,
438                )),
439                SessionLookupIden::State,
440            )
441            .expr_as(
442                Expr::col((
443                    UpstreamOAuthAuthorizationSessions::Table,
444                    UpstreamOAuthAuthorizationSessions::CodeChallengeVerifier,
445                )),
446                SessionLookupIden::CodeChallengeVerifier,
447            )
448            .expr_as(
449                Expr::col((
450                    UpstreamOAuthAuthorizationSessions::Table,
451                    UpstreamOAuthAuthorizationSessions::Nonce,
452                )),
453                SessionLookupIden::Nonce,
454            )
455            .expr_as(
456                Expr::col((
457                    UpstreamOAuthAuthorizationSessions::Table,
458                    UpstreamOAuthAuthorizationSessions::IdToken,
459                )),
460                SessionLookupIden::IdToken,
461            )
462            .expr_as(
463                Expr::col((
464                    UpstreamOAuthAuthorizationSessions::Table,
465                    UpstreamOAuthAuthorizationSessions::IdTokenClaims,
466                )),
467                SessionLookupIden::IdTokenClaims,
468            )
469            .expr_as(
470                Expr::col((
471                    UpstreamOAuthAuthorizationSessions::Table,
472                    UpstreamOAuthAuthorizationSessions::ExtraCallbackParameters,
473                )),
474                SessionLookupIden::ExtraCallbackParameters,
475            )
476            .expr_as(
477                Expr::col((
478                    UpstreamOAuthAuthorizationSessions::Table,
479                    UpstreamOAuthAuthorizationSessions::Userinfo,
480                )),
481                SessionLookupIden::Userinfo,
482            )
483            .expr_as(
484                Expr::col((
485                    UpstreamOAuthAuthorizationSessions::Table,
486                    UpstreamOAuthAuthorizationSessions::CreatedAt,
487                )),
488                SessionLookupIden::CreatedAt,
489            )
490            .expr_as(
491                Expr::col((
492                    UpstreamOAuthAuthorizationSessions::Table,
493                    UpstreamOAuthAuthorizationSessions::CompletedAt,
494                )),
495                SessionLookupIden::CompletedAt,
496            )
497            .expr_as(
498                Expr::col((
499                    UpstreamOAuthAuthorizationSessions::Table,
500                    UpstreamOAuthAuthorizationSessions::ConsumedAt,
501                )),
502                SessionLookupIden::ConsumedAt,
503            )
504            .expr_as(
505                Expr::col((
506                    UpstreamOAuthAuthorizationSessions::Table,
507                    UpstreamOAuthAuthorizationSessions::UnlinkedAt,
508                )),
509                SessionLookupIden::UnlinkedAt,
510            )
511            .from(UpstreamOAuthAuthorizationSessions::Table)
512            .apply_filter(filter)
513            .generate_pagination(
514                (
515                    UpstreamOAuthAuthorizationSessions::Table,
516                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
517                ),
518                pagination,
519            )
520            .build_sqlx(PostgresQueryBuilder);
521
522        let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
523            .traced()
524            .fetch_all(&mut *self.conn)
525            .await?;
526
527        let page = pagination
528            .process(edges)
529            .try_map(UpstreamOAuthAuthorizationSession::try_from)?;
530
531        Ok(page)
532    }
533
534    #[tracing::instrument(
535        name = "db.upstream_oauth_authorization_session.count",
536        skip_all,
537        fields(
538            db.query.text,
539        ),
540        err,
541    )]
542    async fn count(
543        &mut self,
544        filter: UpstreamOAuthSessionFilter<'_>,
545    ) -> Result<usize, Self::Error> {
546        let (sql, arguments) = Query::select()
547            .expr(
548                Expr::col((
549                    UpstreamOAuthAuthorizationSessions::Table,
550                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
551                ))
552                .count(),
553            )
554            .from(UpstreamOAuthAuthorizationSessions::Table)
555            .apply_filter(filter)
556            .build_sqlx(PostgresQueryBuilder);
557
558        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
559            .traced()
560            .fetch_one(&mut *self.conn)
561            .await?;
562
563        count
564            .try_into()
565            .map_err(DatabaseError::to_invalid_operation)
566    }
567}