mas_storage_pg/oauth2/
session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Clock, Session, SessionState, User};
12use mas_storage::{
13    Page, Pagination,
14    oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
15    pagination::Node,
16};
17use oauth2_types::scope::{Scope, ScopeToken};
18use rand::RngCore;
19use sea_query::{
20    Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
21    extension::postgres::PgExpr,
22};
23use sea_query_binder::SqlxBinder;
24use sqlx::PgConnection;
25use ulid::Ulid;
26use uuid::Uuid;
27
28use crate::{
29    DatabaseError, DatabaseInconsistencyError,
30    filter::{Filter, StatementExt},
31    iden::{OAuth2Clients, OAuth2Sessions, UserSessions},
32    pagination::QueryBuilderExt,
33    tracing::ExecuteExt,
34};
35
36/// An implementation of [`OAuth2SessionRepository`] for a PostgreSQL connection
37pub struct PgOAuth2SessionRepository<'c> {
38    conn: &'c mut PgConnection,
39}
40
41impl<'c> PgOAuth2SessionRepository<'c> {
42    /// Create a new [`PgOAuth2SessionRepository`] from an active PostgreSQL
43    /// connection
44    pub fn new(conn: &'c mut PgConnection) -> Self {
45        Self { conn }
46    }
47}
48
/// One row of the `oauth2_sessions` table, as read by the session queries in
/// this file.
///
/// `sqlx::FromRow` maps result columns onto these fields by name, and
/// `#[enum_def]` generates the `OAuthSessionLookupIden` enum from the same
/// field names; the dynamic query in `list` uses those variants as column
/// aliases. Renaming a field therefore has to be mirrored both in the static
/// SQL of `lookup` and in those aliases.
#[derive(sqlx::FromRow)]
#[enum_def]
struct OAuthSessionLookup {
    // Primary key; converted into the session's `Ulid` ID.
    oauth2_session_id: Uuid,
    // User owning the session, if any (nullable column).
    user_id: Option<Uuid>,
    // Browser session the OAuth 2.0 session was created from, if any.
    user_session_id: Option<Uuid>,
    // Client the session was issued to.
    oauth2_client_id: Uuid,
    // Raw scope tokens; parsed back into a `Scope` when converting to `Session`.
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // `None` while the session is valid, set once it has been finished.
    finished_at: Option<DateTime<Utc>>,
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
    // Optional human-readable label for the session.
    human_name: Option<String>,
}
64
65impl Node<Ulid> for OAuthSessionLookup {
66    fn cursor(&self) -> Ulid {
67        self.oauth2_session_id.into()
68    }
69}
70
71impl TryFrom<OAuthSessionLookup> for Session {
72    type Error = DatabaseInconsistencyError;
73
74    fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
75        let id = Ulid::from(value.oauth2_session_id);
76        let scope: Result<Scope, _> = value
77            .scope_list
78            .iter()
79            .map(|s| s.parse::<ScopeToken>())
80            .collect();
81        let scope = scope.map_err(|e| {
82            DatabaseInconsistencyError::on("oauth2_sessions")
83                .column("scope")
84                .row(id)
85                .source(e)
86        })?;
87
88        let state = match value.finished_at {
89            None => SessionState::Valid,
90            Some(finished_at) => SessionState::Finished { finished_at },
91        };
92
93        Ok(Session {
94            id,
95            state,
96            created_at: value.created_at,
97            client_id: value.oauth2_client_id.into(),
98            user_id: value.user_id.map(Ulid::from),
99            user_session_id: value.user_session_id.map(Ulid::from),
100            scope,
101            user_agent: value.user_agent,
102            last_active_at: value.last_active_at,
103            last_active_ip: value.last_active_ip,
104            human_name: value.human_name,
105        })
106    }
107}
108
impl Filter for OAuth2SessionFilter<'_> {
    /// Translate this filter into a SQL condition over `oauth2_sessions`.
    ///
    /// Each criterion that is set on the filter contributes one predicate, and
    /// all predicates are combined with `AND` (`Condition::all`); unset
    /// criteria are skipped through `add_option`. The `_has_joins` flag is
    /// unused: every column below is qualified with its table name regardless.
    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
        sea_query::Condition::all()
            // Sessions owned by one specific user.
            .add_option(self.user().map(|user| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
            }))
            // Sessions issued to one specific client.
            .add_option(self.client().map(|client| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
                    .eq(Uuid::from(client.id))
            }))
            .add_option(self.client_kind().map(|client_kind| {
                // This builds either a:
                // `WHERE oauth2_client_id = ANY(...)`
                // or a `WHERE oauth2_client_id <> ALL(...)`
                // where the subquery lists the IDs of clients whose
                // `is_static` column is true.
                let static_clients = Query::select()
                    .expr(Expr::col((
                        OAuth2Clients::Table,
                        OAuth2Clients::OAuth2ClientId,
                    )))
                    .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
                    .from(OAuth2Clients::Table)
                    .take();
                if client_kind.is_static() {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
                        .eq(Expr::any(static_clients))
                } else {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
                        .ne(Expr::all(static_clients))
                }
            }))
            // Devices are matched through the scope list: a session matches if
            // its `scope_list` contains either the stable or the unstable scope
            // token derived from the device.
            .add_option(self.device().map(|device| -> SimpleExpr {
                if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
                    Condition::any()
                        .add(
                            Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
                                OAuth2Sessions::Table,
                                OAuth2Sessions::ScopeList,
                            )))),
                        )
                        .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
                            Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
                        )))
                        .into()
                } else {
                    // If the device ID can't be encoded as a scope token, match no rows
                    Expr::val(false).into()
                }
            }))
            // Sessions created from one specific browser session.
            .add_option(self.browser_session().map(|browser_session| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
                    .eq(Uuid::from(browser_session.id))
            }))
            // Sessions whose browser session matches a nested filter, applied
            // to `user_sessions` in an `IN (subquery)`.
            .add_option(self.browser_session_filter().map(|browser_session_filter| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)).in_subquery(
                    Query::select()
                        .expr(Expr::col((
                            UserSessions::Table,
                            UserSessions::UserSessionId,
                        )))
                        .apply_filter(browser_session_filter)
                        .from(UserSessions::Table)
                        .take(),
                )
            }))
            // Active sessions have no `finished_at`; any other state requires one.
            .add_option(self.state().map(|state| {
                if state.is_active() {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
                } else {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
                }
            }))
            // The session's scope list must contain every requested token
            // (array containment).
            .add_option(self.scope().map(|scope| {
                let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
            }))
            // `true`: only sessions tied to a user; `false`: only sessions without one.
            .add_option(self.any_user().map(|any_user| {
                if any_user {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
                } else {
                    Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
                }
            }))
            // Strict bounds on `last_active_at`; rows where it is NULL never
            // satisfy either comparison.
            .add_option(self.last_active_after().map(|last_active_after| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
                    .gt(last_active_after)
            }))
            .add_option(self.last_active_before().map(|last_active_before| {
                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
                    .lt(last_active_before)
            }))
    }
}
201
202#[async_trait]
203impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
204    type Error = DatabaseError;
205
206    #[tracing::instrument(
207        name = "db.oauth2_session.lookup",
208        skip_all,
209        fields(
210            db.query.text,
211            session.id = %id,
212        ),
213        err,
214    )]
215    async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
216        let res = sqlx::query_as!(
217            OAuthSessionLookup,
218            r#"
219                SELECT oauth2_session_id
220                     , user_id
221                     , user_session_id
222                     , oauth2_client_id
223                     , scope_list
224                     , created_at
225                     , finished_at
226                     , user_agent
227                     , last_active_at
228                     , last_active_ip as "last_active_ip: IpAddr"
229                     , human_name
230                FROM oauth2_sessions
231
232                WHERE oauth2_session_id = $1
233            "#,
234            Uuid::from(id),
235        )
236        .traced()
237        .fetch_optional(&mut *self.conn)
238        .await?;
239
240        let Some(session) = res else { return Ok(None) };
241
242        Ok(Some(session.try_into()?))
243    }
244
245    #[tracing::instrument(
246        name = "db.oauth2_session.add",
247        skip_all,
248        fields(
249            db.query.text,
250            %client.id,
251            session.id,
252            session.scope = %scope,
253        ),
254        err,
255    )]
256    async fn add(
257        &mut self,
258        rng: &mut (dyn RngCore + Send),
259        clock: &dyn Clock,
260        client: &Client,
261        user: Option<&User>,
262        user_session: Option<&BrowserSession>,
263        scope: Scope,
264    ) -> Result<Session, Self::Error> {
265        let created_at = clock.now();
266        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
267        tracing::Span::current().record("session.id", tracing::field::display(id));
268
269        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
270
271        sqlx::query!(
272            r#"
273                INSERT INTO oauth2_sessions
274                    ( oauth2_session_id
275                    , user_id
276                    , user_session_id
277                    , oauth2_client_id
278                    , scope_list
279                    , created_at
280                    )
281                VALUES ($1, $2, $3, $4, $5, $6)
282            "#,
283            Uuid::from(id),
284            user.map(|u| Uuid::from(u.id)),
285            user_session.map(|s| Uuid::from(s.id)),
286            Uuid::from(client.id),
287            &scope_list,
288            created_at,
289        )
290        .traced()
291        .execute(&mut *self.conn)
292        .await?;
293
294        Ok(Session {
295            id,
296            state: SessionState::Valid,
297            created_at,
298            user_id: user.map(|u| u.id),
299            user_session_id: user_session.map(|s| s.id),
300            client_id: client.id,
301            scope,
302            user_agent: None,
303            last_active_at: None,
304            last_active_ip: None,
305            human_name: None,
306        })
307    }
308
309    #[tracing::instrument(
310        name = "db.oauth2_session.finish_bulk",
311        skip_all,
312        fields(
313            db.query.text,
314        ),
315        err,
316    )]
317    async fn finish_bulk(
318        &mut self,
319        clock: &dyn Clock,
320        filter: OAuth2SessionFilter<'_>,
321    ) -> Result<usize, Self::Error> {
322        let finished_at = clock.now();
323        let (sql, arguments) = Query::update()
324            .table(OAuth2Sessions::Table)
325            .value(OAuth2Sessions::FinishedAt, finished_at)
326            .apply_filter(filter)
327            .build_sqlx(PostgresQueryBuilder);
328
329        let res = sqlx::query_with(&sql, arguments)
330            .traced()
331            .execute(&mut *self.conn)
332            .await?;
333
334        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
335    }
336
337    #[tracing::instrument(
338        name = "db.oauth2_session.finish",
339        skip_all,
340        fields(
341            db.query.text,
342            %session.id,
343            %session.scope,
344            client.id = %session.client_id,
345        ),
346        err,
347    )]
348    async fn finish(
349        &mut self,
350        clock: &dyn Clock,
351        session: Session,
352    ) -> Result<Session, Self::Error> {
353        let finished_at = clock.now();
354        let res = sqlx::query!(
355            r#"
356                UPDATE oauth2_sessions
357                SET finished_at = $2
358                WHERE oauth2_session_id = $1
359            "#,
360            Uuid::from(session.id),
361            finished_at,
362        )
363        .traced()
364        .execute(&mut *self.conn)
365        .await?;
366
367        DatabaseError::ensure_affected_rows(&res, 1)?;
368
369        session
370            .finish(finished_at)
371            .map_err(DatabaseError::to_invalid_operation)
372    }
373
374    #[tracing::instrument(
375        name = "db.oauth2_session.list",
376        skip_all,
377        fields(
378            db.query.text,
379        ),
380        err,
381    )]
382    async fn list(
383        &mut self,
384        filter: OAuth2SessionFilter<'_>,
385        pagination: Pagination,
386    ) -> Result<Page<Session>, Self::Error> {
387        let (sql, arguments) = Query::select()
388            .expr_as(
389                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
390                OAuthSessionLookupIden::Oauth2SessionId,
391            )
392            .expr_as(
393                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
394                OAuthSessionLookupIden::UserId,
395            )
396            .expr_as(
397                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
398                OAuthSessionLookupIden::UserSessionId,
399            )
400            .expr_as(
401                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
402                OAuthSessionLookupIden::Oauth2ClientId,
403            )
404            .expr_as(
405                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
406                OAuthSessionLookupIden::ScopeList,
407            )
408            .expr_as(
409                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
410                OAuthSessionLookupIden::CreatedAt,
411            )
412            .expr_as(
413                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
414                OAuthSessionLookupIden::FinishedAt,
415            )
416            .expr_as(
417                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
418                OAuthSessionLookupIden::UserAgent,
419            )
420            .expr_as(
421                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
422                OAuthSessionLookupIden::LastActiveAt,
423            )
424            .expr_as(
425                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
426                OAuthSessionLookupIden::LastActiveIp,
427            )
428            .expr_as(
429                Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
430                OAuthSessionLookupIden::HumanName,
431            )
432            .from(OAuth2Sessions::Table)
433            .apply_filter(filter)
434            .generate_pagination(
435                (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
436                pagination,
437            )
438            .build_sqlx(PostgresQueryBuilder);
439
440        let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
441            .traced()
442            .fetch_all(&mut *self.conn)
443            .await?;
444
445        let page = pagination.process(edges).try_map(Session::try_from)?;
446
447        Ok(page)
448    }
449
450    #[tracing::instrument(
451        name = "db.oauth2_session.count",
452        skip_all,
453        fields(
454            db.query.text,
455        ),
456        err,
457    )]
458    async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
459        let (sql, arguments) = Query::select()
460            .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
461            .from(OAuth2Sessions::Table)
462            .apply_filter(filter)
463            .build_sqlx(PostgresQueryBuilder);
464
465        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
466            .traced()
467            .fetch_one(&mut *self.conn)
468            .await?;
469
470        count
471            .try_into()
472            .map_err(DatabaseError::to_invalid_operation)
473    }
474
475    #[tracing::instrument(
476        name = "db.oauth2_session.record_batch_activity",
477        skip_all,
478        fields(
479            db.query.text,
480        ),
481        err,
482    )]
483    async fn record_batch_activity(
484        &mut self,
485        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
486    ) -> Result<(), Self::Error> {
487        // Sort the activity by ID, so that when batching the updates, Postgres
488        // locks the rows in a stable order, preventing deadlocks
489        activities.sort_unstable();
490        let mut ids = Vec::with_capacity(activities.len());
491        let mut last_activities = Vec::with_capacity(activities.len());
492        let mut ips = Vec::with_capacity(activities.len());
493
494        for (id, last_activity, ip) in activities {
495            ids.push(Uuid::from(id));
496            last_activities.push(last_activity);
497            ips.push(ip);
498        }
499
500        let res = sqlx::query!(
501            r#"
502                UPDATE oauth2_sessions
503                SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
504                  , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
505                FROM (
506                    SELECT *
507                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
508                        AS t(oauth2_session_id, last_active_at, last_active_ip)
509                ) AS t
510                WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
511            "#,
512            &ids,
513            &last_activities,
514            &ips as &[Option<IpAddr>],
515        )
516        .traced()
517        .execute(&mut *self.conn)
518        .await?;
519
520        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
521
522        Ok(())
523    }
524
525    #[tracing::instrument(
526        name = "db.oauth2_session.record_user_agent",
527        skip_all,
528        fields(
529            db.query.text,
530            %session.id,
531            %session.scope,
532            client.id = %session.client_id,
533            session.user_agent = user_agent,
534        ),
535        err,
536    )]
537    async fn record_user_agent(
538        &mut self,
539        mut session: Session,
540        user_agent: String,
541    ) -> Result<Session, Self::Error> {
542        let res = sqlx::query!(
543            r#"
544                UPDATE oauth2_sessions
545                SET user_agent = $2
546                WHERE oauth2_session_id = $1
547            "#,
548            Uuid::from(session.id),
549            &*user_agent,
550        )
551        .traced()
552        .execute(&mut *self.conn)
553        .await?;
554
555        session.user_agent = Some(user_agent);
556
557        DatabaseError::ensure_affected_rows(&res, 1)?;
558
559        Ok(session)
560    }
561
562    #[tracing::instrument(
563        name = "repository.oauth2_session.set_human_name",
564        skip(self),
565        fields(
566            client.id = %session.client_id,
567            session.human_name = ?human_name,
568        ),
569        err,
570    )]
571    async fn set_human_name(
572        &mut self,
573        mut session: Session,
574        human_name: Option<String>,
575    ) -> Result<Session, Self::Error> {
576        let res = sqlx::query!(
577            r#"
578                UPDATE oauth2_sessions
579                SET human_name = $2
580                WHERE oauth2_session_id = $1
581            "#,
582            Uuid::from(session.id),
583            human_name.as_deref(),
584        )
585        .traced()
586        .execute(&mut *self.conn)
587        .await?;
588
589        session.human_name = human_name;
590
591        DatabaseError::ensure_affected_rows(&res, 1)?;
592
593        Ok(session)
594    }
595}