mas_storage_pg/user/
session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12    Authentication, AuthenticationMethod, BrowserSession, Clock, Password,
13    UpstreamOAuthAuthorizationSession, User,
14};
15use mas_storage::{
16    Page, Pagination,
17    pagination::Node,
18    user::{BrowserSessionFilter, BrowserSessionRepository},
19};
20use rand::RngCore;
21use sea_query::{Expr, PostgresQueryBuilder, Query};
22use sea_query_binder::SqlxBinder;
23use sqlx::PgConnection;
24use ulid::Ulid;
25use uuid::Uuid;
26
27use crate::{
28    DatabaseError, DatabaseInconsistencyError,
29    filter::StatementExt,
30    iden::{UpstreamOAuthAuthorizationSessions, UserSessionAuthentications, UserSessions, Users},
31    pagination::QueryBuilderExt,
32    tracing::ExecuteExt,
33};
34
35/// An implementation of [`BrowserSessionRepository`] for a PostgreSQL
36/// connection
37pub struct PgBrowserSessionRepository<'c> {
38    conn: &'c mut PgConnection,
39}
40
41impl<'c> PgBrowserSessionRepository<'c> {
42    /// Create a new [`PgBrowserSessionRepository`] from an active PostgreSQL
43    /// connection
44    pub fn new(conn: &'c mut PgConnection) -> Self {
45        Self { conn }
46    }
47}
48
49#[allow(clippy::struct_field_names)]
50#[derive(sqlx::FromRow)]
51#[sea_query::enum_def]
52struct SessionLookup {
53    user_session_id: Uuid,
54    user_session_created_at: DateTime<Utc>,
55    user_session_finished_at: Option<DateTime<Utc>>,
56    user_session_user_agent: Option<String>,
57    user_session_last_active_at: Option<DateTime<Utc>>,
58    user_session_last_active_ip: Option<IpAddr>,
59    user_id: Uuid,
60    user_username: String,
61    user_created_at: DateTime<Utc>,
62    user_locked_at: Option<DateTime<Utc>>,
63    user_deactivated_at: Option<DateTime<Utc>>,
64    user_can_request_admin: bool,
65    user_is_guest: bool,
66}
67
68impl Node<Ulid> for SessionLookup {
69    fn cursor(&self) -> Ulid {
70        self.user_id.into()
71    }
72}
73
74impl TryFrom<SessionLookup> for BrowserSession {
75    type Error = DatabaseInconsistencyError;
76
77    fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
78        let id = Ulid::from(value.user_id);
79        let user = User {
80            id,
81            username: value.user_username,
82            sub: id.to_string(),
83            created_at: value.user_created_at,
84            locked_at: value.user_locked_at,
85            deactivated_at: value.user_deactivated_at,
86            can_request_admin: value.user_can_request_admin,
87            is_guest: value.user_is_guest,
88        };
89
90        Ok(BrowserSession {
91            id: value.user_session_id.into(),
92            user,
93            created_at: value.user_session_created_at,
94            finished_at: value.user_session_finished_at,
95            user_agent: value.user_session_user_agent,
96            last_active_at: value.user_session_last_active_at,
97            last_active_ip: value.user_session_last_active_ip,
98        })
99    }
100}
101
102struct AuthenticationLookup {
103    user_session_authentication_id: Uuid,
104    created_at: DateTime<Utc>,
105    user_password_id: Option<Uuid>,
106    upstream_oauth_authorization_session_id: Option<Uuid>,
107}
108
109impl TryFrom<AuthenticationLookup> for Authentication {
110    type Error = DatabaseInconsistencyError;
111
112    fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
113        let id = Ulid::from(value.user_session_authentication_id);
114        let authentication_method = match (
115            value.user_password_id.map(Into::into),
116            value
117                .upstream_oauth_authorization_session_id
118                .map(Into::into),
119        ) {
120            (Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
121            (None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
122                upstream_oauth2_session_id,
123            },
124            (None, None) => AuthenticationMethod::Unknown,
125            _ => {
126                return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
127            }
128        };
129
130        Ok(Authentication {
131            id,
132            created_at: value.created_at,
133            authentication_method,
134        })
135    }
136}
137
138impl crate::filter::Filter for BrowserSessionFilter<'_> {
139    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
140        sea_query::Condition::all()
141            .add_option(self.user().map(|user| {
142                Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id))
143            }))
144            .add_option(self.state().map(|state| {
145                if state.is_active() {
146                    Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
147                } else {
148                    Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
149                }
150            }))
151            .add_option(self.last_active_after().map(|last_active_after| {
152                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).gt(last_active_after)
153            }))
154            .add_option(self.last_active_before().map(|last_active_before| {
155                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
156            }))
157            .add_option(self.authenticated_by_upstream_sessions().map(|filter| {
158                // For filtering by upstream sessions, we need to hop over the
159                // `user_session_authentications` table
160                let join_expr = Expr::col((
161                    UserSessionAuthentications::Table,
162                    UserSessionAuthentications::UpstreamOAuthAuthorizationSessionId,
163                ))
164                .eq(Expr::col((
165                    UpstreamOAuthAuthorizationSessions::Table,
166                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
167                )));
168
169                Expr::col((UserSessions::Table, UserSessions::UserSessionId)).in_subquery(
170                    Query::select()
171                        .expr(Expr::col((
172                            UserSessionAuthentications::Table,
173                            UserSessionAuthentications::UserSessionId,
174                        )))
175                        .from(UserSessionAuthentications::Table)
176                        .inner_join(UpstreamOAuthAuthorizationSessions::Table, join_expr)
177                        .apply_filter(filter)
178                        .take(),
179                )
180            }))
181    }
182}
183
184#[async_trait]
185impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
186    type Error = DatabaseError;
187
188    #[tracing::instrument(
189        name = "db.browser_session.lookup",
190        skip_all,
191        fields(
192            db.query.text,
193            user_session.id = %id,
194        ),
195        err,
196    )]
197    async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
198        let res = sqlx::query_as!(
199            SessionLookup,
200            r#"
201                SELECT s.user_session_id
202                     , s.created_at            AS "user_session_created_at"
203                     , s.finished_at           AS "user_session_finished_at"
204                     , s.user_agent            AS "user_session_user_agent"
205                     , s.last_active_at        AS "user_session_last_active_at"
206                     , s.last_active_ip        AS "user_session_last_active_ip: IpAddr"
207                     , u.user_id
208                     , u.username              AS "user_username"
209                     , u.created_at            AS "user_created_at"
210                     , u.locked_at             AS "user_locked_at"
211                     , u.deactivated_at        AS "user_deactivated_at"
212                     , u.can_request_admin     AS "user_can_request_admin"
213                     , u.is_guest              AS "user_is_guest"
214                FROM user_sessions s
215                INNER JOIN users u
216                    USING (user_id)
217                WHERE s.user_session_id = $1
218            "#,
219            Uuid::from(id),
220        )
221        .traced()
222        .fetch_optional(&mut *self.conn)
223        .await?;
224
225        let Some(res) = res else { return Ok(None) };
226
227        Ok(Some(res.try_into()?))
228    }
229
230    #[tracing::instrument(
231        name = "db.browser_session.add",
232        skip_all,
233        fields(
234            db.query.text,
235            %user.id,
236            user_session.id,
237        ),
238        err,
239    )]
240    async fn add(
241        &mut self,
242        rng: &mut (dyn RngCore + Send),
243        clock: &dyn Clock,
244        user: &User,
245        user_agent: Option<String>,
246    ) -> Result<BrowserSession, Self::Error> {
247        let created_at = clock.now();
248        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
249        tracing::Span::current().record("user_session.id", tracing::field::display(id));
250
251        sqlx::query!(
252            r#"
253                INSERT INTO user_sessions (user_session_id, user_id, created_at, user_agent)
254                VALUES ($1, $2, $3, $4)
255            "#,
256            Uuid::from(id),
257            Uuid::from(user.id),
258            created_at,
259            user_agent.as_deref(),
260        )
261        .traced()
262        .execute(&mut *self.conn)
263        .await?;
264
265        let session = BrowserSession {
266            id,
267            // XXX
268            user: user.clone(),
269            created_at,
270            finished_at: None,
271            user_agent,
272            last_active_at: None,
273            last_active_ip: None,
274        };
275
276        Ok(session)
277    }
278
279    #[tracing::instrument(
280        name = "db.browser_session.finish",
281        skip_all,
282        fields(
283            db.query.text,
284            %user_session.id,
285        ),
286        err,
287    )]
288    async fn finish(
289        &mut self,
290        clock: &dyn Clock,
291        mut user_session: BrowserSession,
292    ) -> Result<BrowserSession, Self::Error> {
293        let finished_at = clock.now();
294        let res = sqlx::query!(
295            r#"
296                UPDATE user_sessions
297                SET finished_at = $1
298                WHERE user_session_id = $2
299            "#,
300            finished_at,
301            Uuid::from(user_session.id),
302        )
303        .traced()
304        .execute(&mut *self.conn)
305        .await?;
306
307        user_session.finished_at = Some(finished_at);
308
309        DatabaseError::ensure_affected_rows(&res, 1)?;
310
311        Ok(user_session)
312    }
313
314    #[tracing::instrument(
315        name = "db.browser_session.finish_bulk",
316        skip_all,
317        fields(
318            db.query.text,
319        ),
320        err,
321    )]
322    async fn finish_bulk(
323        &mut self,
324        clock: &dyn Clock,
325        filter: BrowserSessionFilter<'_>,
326    ) -> Result<usize, Self::Error> {
327        let finished_at = clock.now();
328        let (sql, arguments) = sea_query::Query::update()
329            .table(UserSessions::Table)
330            .value(UserSessions::FinishedAt, finished_at)
331            .apply_filter(filter)
332            .build_sqlx(PostgresQueryBuilder);
333
334        let res = sqlx::query_with(&sql, arguments)
335            .traced()
336            .execute(&mut *self.conn)
337            .await?;
338
339        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
340    }
341
342    #[tracing::instrument(
343        name = "db.browser_session.list",
344        skip_all,
345        fields(
346            db.query.text,
347        ),
348        err,
349    )]
350    async fn list(
351        &mut self,
352        filter: BrowserSessionFilter<'_>,
353        pagination: Pagination,
354    ) -> Result<Page<BrowserSession>, Self::Error> {
355        let (sql, arguments) = sea_query::Query::select()
356            .expr_as(
357                Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
358                SessionLookupIden::UserSessionId,
359            )
360            .expr_as(
361                Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
362                SessionLookupIden::UserSessionCreatedAt,
363            )
364            .expr_as(
365                Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
366                SessionLookupIden::UserSessionFinishedAt,
367            )
368            .expr_as(
369                Expr::col((UserSessions::Table, UserSessions::UserAgent)),
370                SessionLookupIden::UserSessionUserAgent,
371            )
372            .expr_as(
373                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)),
374                SessionLookupIden::UserSessionLastActiveAt,
375            )
376            .expr_as(
377                Expr::col((UserSessions::Table, UserSessions::LastActiveIp)),
378                SessionLookupIden::UserSessionLastActiveIp,
379            )
380            .expr_as(
381                Expr::col((Users::Table, Users::UserId)),
382                SessionLookupIden::UserId,
383            )
384            .expr_as(
385                Expr::col((Users::Table, Users::Username)),
386                SessionLookupIden::UserUsername,
387            )
388            .expr_as(
389                Expr::col((Users::Table, Users::CreatedAt)),
390                SessionLookupIden::UserCreatedAt,
391            )
392            .expr_as(
393                Expr::col((Users::Table, Users::LockedAt)),
394                SessionLookupIden::UserLockedAt,
395            )
396            .expr_as(
397                Expr::col((Users::Table, Users::DeactivatedAt)),
398                SessionLookupIden::UserDeactivatedAt,
399            )
400            .expr_as(
401                Expr::col((Users::Table, Users::CanRequestAdmin)),
402                SessionLookupIden::UserCanRequestAdmin,
403            )
404            .expr_as(
405                Expr::col((Users::Table, Users::IsGuest)),
406                SessionLookupIden::UserIsGuest,
407            )
408            .from(UserSessions::Table)
409            .inner_join(
410                Users::Table,
411                Expr::col((UserSessions::Table, UserSessions::UserId))
412                    .equals((Users::Table, Users::UserId)),
413            )
414            .apply_filter(filter)
415            .generate_pagination(
416                (UserSessions::Table, UserSessions::UserSessionId),
417                pagination,
418            )
419            .build_sqlx(PostgresQueryBuilder);
420
421        let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
422            .traced()
423            .fetch_all(&mut *self.conn)
424            .await?;
425
426        let page = pagination
427            .process(edges)
428            .try_map(BrowserSession::try_from)?;
429
430        Ok(page)
431    }
432
433    #[tracing::instrument(
434        name = "db.browser_session.count",
435        skip_all,
436        fields(
437            db.query.text,
438        ),
439        err,
440    )]
441    async fn count(&mut self, filter: BrowserSessionFilter<'_>) -> Result<usize, Self::Error> {
442        let (sql, arguments) = sea_query::Query::select()
443            .expr(Expr::col((UserSessions::Table, UserSessions::UserSessionId)).count())
444            .from(UserSessions::Table)
445            .apply_filter(filter)
446            .build_sqlx(PostgresQueryBuilder);
447
448        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
449            .traced()
450            .fetch_one(&mut *self.conn)
451            .await?;
452
453        count
454            .try_into()
455            .map_err(DatabaseError::to_invalid_operation)
456    }
457
458    #[tracing::instrument(
459        name = "db.browser_session.authenticate_with_password",
460        skip_all,
461        fields(
462            db.query.text,
463            %user_session.id,
464            %user_password.id,
465            user_session_authentication.id,
466        ),
467        err,
468    )]
469    async fn authenticate_with_password(
470        &mut self,
471        rng: &mut (dyn RngCore + Send),
472        clock: &dyn Clock,
473        user_session: &BrowserSession,
474        user_password: &Password,
475    ) -> Result<Authentication, Self::Error> {
476        let created_at = clock.now();
477        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
478        tracing::Span::current().record(
479            "user_session_authentication.id",
480            tracing::field::display(id),
481        );
482
483        sqlx::query!(
484            r#"
485                INSERT INTO user_session_authentications
486                    (user_session_authentication_id, user_session_id, created_at, user_password_id)
487                VALUES ($1, $2, $3, $4)
488            "#,
489            Uuid::from(id),
490            Uuid::from(user_session.id),
491            created_at,
492            Uuid::from(user_password.id),
493        )
494        .traced()
495        .execute(&mut *self.conn)
496        .await?;
497
498        Ok(Authentication {
499            id,
500            created_at,
501            authentication_method: AuthenticationMethod::Password {
502                user_password_id: user_password.id,
503            },
504        })
505    }
506
507    #[tracing::instrument(
508        name = "db.browser_session.authenticate_with_upstream",
509        skip_all,
510        fields(
511            db.query.text,
512            %user_session.id,
513            %upstream_oauth_session.id,
514            user_session_authentication.id,
515        ),
516        err,
517    )]
518    async fn authenticate_with_upstream(
519        &mut self,
520        rng: &mut (dyn RngCore + Send),
521        clock: &dyn Clock,
522        user_session: &BrowserSession,
523        upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
524    ) -> Result<Authentication, Self::Error> {
525        let created_at = clock.now();
526        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
527        tracing::Span::current().record(
528            "user_session_authentication.id",
529            tracing::field::display(id),
530        );
531
532        sqlx::query!(
533            r#"
534                INSERT INTO user_session_authentications
535                    (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
536                VALUES ($1, $2, $3, $4)
537            "#,
538            Uuid::from(id),
539            Uuid::from(user_session.id),
540            created_at,
541            Uuid::from(upstream_oauth_session.id),
542        )
543        .traced()
544        .execute(&mut *self.conn)
545        .await?;
546
547        Ok(Authentication {
548            id,
549            created_at,
550            authentication_method: AuthenticationMethod::UpstreamOAuth2 {
551                upstream_oauth2_session_id: upstream_oauth_session.id,
552            },
553        })
554    }
555
556    #[tracing::instrument(
557        name = "db.browser_session.get_last_authentication",
558        skip_all,
559        fields(
560            db.query.text,
561            %user_session.id,
562        ),
563        err,
564    )]
565    async fn get_last_authentication(
566        &mut self,
567        user_session: &BrowserSession,
568    ) -> Result<Option<Authentication>, Self::Error> {
569        let authentication = sqlx::query_as!(
570            AuthenticationLookup,
571            r#"
572                SELECT user_session_authentication_id
573                     , created_at
574                     , user_password_id
575                     , upstream_oauth_authorization_session_id
576                FROM user_session_authentications
577                WHERE user_session_id = $1
578                ORDER BY created_at DESC
579                LIMIT 1
580            "#,
581            Uuid::from(user_session.id),
582        )
583        .traced()
584        .fetch_optional(&mut *self.conn)
585        .await?;
586
587        let Some(authentication) = authentication else {
588            return Ok(None);
589        };
590
591        let authentication = Authentication::try_from(authentication)?;
592        Ok(Some(authentication))
593    }
594
595    #[tracing::instrument(
596        name = "db.browser_session.record_batch_activity",
597        skip_all,
598        fields(
599            db.query.text,
600        ),
601        err,
602    )]
603    async fn record_batch_activity(
604        &mut self,
605        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
606    ) -> Result<(), Self::Error> {
607        // Sort the activity by ID, so that when batching the updates, Postgres
608        // locks the rows in a stable order, preventing deadlocks
609        activities.sort_unstable();
610        let mut ids = Vec::with_capacity(activities.len());
611        let mut last_activities = Vec::with_capacity(activities.len());
612        let mut ips = Vec::with_capacity(activities.len());
613
614        for (id, last_activity, ip) in activities {
615            ids.push(Uuid::from(id));
616            last_activities.push(last_activity);
617            ips.push(ip);
618        }
619
620        let res = sqlx::query!(
621            r#"
622                UPDATE user_sessions
623                SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)
624                  , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
625                FROM (
626                    SELECT *
627                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
628                        AS t(user_session_id, last_active_at, last_active_ip)
629                ) AS t
630                WHERE user_sessions.user_session_id = t.user_session_id
631            "#,
632            &ids,
633            &last_activities,
634            &ips as &[Option<IpAddr>],
635        )
636        .traced()
637        .execute(&mut *self.conn)
638        .await?;
639
640        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
641
642        Ok(())
643    }
644}