mas_storage_pg/compat/
session.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12    BrowserSession, Clock, CompatSession, CompatSessionState, CompatSsoLogin, CompatSsoLoginState,
13    Device, User,
14};
15use mas_storage::{
16    Page, Pagination,
17    compat::{CompatSessionFilter, CompatSessionRepository},
18    pagination::Node,
19};
20use rand::RngCore;
21use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
22use sea_query_binder::SqlxBinder;
23use sqlx::PgConnection;
24use ulid::Ulid;
25use url::Url;
26use uuid::Uuid;
27
28use crate::{
29    DatabaseError, DatabaseInconsistencyError,
30    filter::{Filter, StatementExt, StatementWithJoinsExt},
31    iden::{CompatSessions, CompatSsoLogins, UserSessions},
32    pagination::QueryBuilderExt,
33    tracing::ExecuteExt,
34};
35
36/// An implementation of [`CompatSessionRepository`] for a PostgreSQL connection
37pub struct PgCompatSessionRepository<'c> {
38    conn: &'c mut PgConnection,
39}
40
41impl<'c> PgCompatSessionRepository<'c> {
42    /// Create a new [`PgCompatSessionRepository`] from an active PostgreSQL
43    /// connection
44    pub fn new(conn: &'c mut PgConnection) -> Self {
45        Self { conn }
46    }
47}
48
49struct CompatSessionLookup {
50    compat_session_id: Uuid,
51    device_id: Option<String>,
52    human_name: Option<String>,
53    user_id: Uuid,
54    user_session_id: Option<Uuid>,
55    created_at: DateTime<Utc>,
56    finished_at: Option<DateTime<Utc>>,
57    is_synapse_admin: bool,
58    user_agent: Option<String>,
59    last_active_at: Option<DateTime<Utc>>,
60    last_active_ip: Option<IpAddr>,
61}
62
63impl Node<Ulid> for CompatSessionLookup {
64    fn cursor(&self) -> Ulid {
65        self.compat_session_id.into()
66    }
67}
68
69impl From<CompatSessionLookup> for CompatSession {
70    fn from(value: CompatSessionLookup) -> Self {
71        let id = value.compat_session_id.into();
72
73        let state = match value.finished_at {
74            None => CompatSessionState::Valid,
75            Some(finished_at) => CompatSessionState::Finished { finished_at },
76        };
77
78        CompatSession {
79            id,
80            state,
81            user_id: value.user_id.into(),
82            user_session_id: value.user_session_id.map(Ulid::from),
83            device: value.device_id.map(Device::from),
84            human_name: value.human_name,
85            created_at: value.created_at,
86            is_synapse_admin: value.is_synapse_admin,
87            user_agent: value.user_agent,
88            last_active_at: value.last_active_at,
89            last_active_ip: value.last_active_ip,
90        }
91    }
92}
93
94#[derive(sqlx::FromRow)]
95#[enum_def]
96struct CompatSessionAndSsoLoginLookup {
97    compat_session_id: Uuid,
98    device_id: Option<String>,
99    human_name: Option<String>,
100    user_id: Uuid,
101    user_session_id: Option<Uuid>,
102    created_at: DateTime<Utc>,
103    finished_at: Option<DateTime<Utc>>,
104    is_synapse_admin: bool,
105    user_agent: Option<String>,
106    last_active_at: Option<DateTime<Utc>>,
107    last_active_ip: Option<IpAddr>,
108    compat_sso_login_id: Option<Uuid>,
109    compat_sso_login_token: Option<String>,
110    compat_sso_login_redirect_uri: Option<String>,
111    compat_sso_login_created_at: Option<DateTime<Utc>>,
112    compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
113    compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
114}
115
116impl Node<Ulid> for CompatSessionAndSsoLoginLookup {
117    fn cursor(&self) -> Ulid {
118        self.compat_session_id.into()
119    }
120}
121
122impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSsoLogin>) {
123    type Error = DatabaseInconsistencyError;
124
125    fn try_from(value: CompatSessionAndSsoLoginLookup) -> Result<Self, Self::Error> {
126        let id = value.compat_session_id.into();
127
128        let state = match value.finished_at {
129            None => CompatSessionState::Valid,
130            Some(finished_at) => CompatSessionState::Finished { finished_at },
131        };
132
133        let session = CompatSession {
134            id,
135            state,
136            user_id: value.user_id.into(),
137            device: value.device_id.map(Device::from),
138            human_name: value.human_name,
139            user_session_id: value.user_session_id.map(Ulid::from),
140            created_at: value.created_at,
141            is_synapse_admin: value.is_synapse_admin,
142            user_agent: value.user_agent,
143            last_active_at: value.last_active_at,
144            last_active_ip: value.last_active_ip,
145        };
146
147        match (
148            value.compat_sso_login_id,
149            value.compat_sso_login_token,
150            value.compat_sso_login_redirect_uri,
151            value.compat_sso_login_created_at,
152            value.compat_sso_login_fulfilled_at,
153            value.compat_sso_login_exchanged_at,
154        ) {
155            (None, None, None, None, None, None) => Ok((session, None)),
156            (
157                Some(id),
158                Some(login_token),
159                Some(redirect_uri),
160                Some(created_at),
161                fulfilled_at,
162                exchanged_at,
163            ) => {
164                let id = id.into();
165                let redirect_uri = Url::parse(&redirect_uri).map_err(|e| {
166                    DatabaseInconsistencyError::on("compat_sso_logins")
167                        .column("redirect_uri")
168                        .row(id)
169                        .source(e)
170                })?;
171
172                let state = match (fulfilled_at, exchanged_at) {
173                    (Some(fulfilled_at), Some(exchanged_at)) => CompatSsoLoginState::Exchanged {
174                        fulfilled_at,
175                        exchanged_at,
176                        compat_session_id: session.id,
177                    },
178                    _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
179                };
180
181                let login = CompatSsoLogin {
182                    id,
183                    redirect_uri,
184                    login_token,
185                    created_at,
186                    state,
187                };
188
189                Ok((session, Some(login)))
190            }
191            _ => Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
192        }
193    }
194}
195
196impl Filter for CompatSessionFilter<'_> {
197    fn generate_condition(&self, has_joins: bool) -> impl sea_query::IntoCondition {
198        sea_query::Condition::all()
199            .add_option(self.user().map(|user| {
200                Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
201            }))
202            .add_option(self.browser_session().map(|browser_session| {
203                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
204                    .eq(Uuid::from(browser_session.id))
205            }))
206            .add_option(self.browser_session_filter().map(|browser_session_filter| {
207                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)).in_subquery(
208                    Query::select()
209                        .expr(Expr::col((
210                            UserSessions::Table,
211                            UserSessions::UserSessionId,
212                        )))
213                        .apply_filter(browser_session_filter)
214                        .from(UserSessions::Table)
215                        .take(),
216                )
217            }))
218            .add_option(self.state().map(|state| {
219                if state.is_active() {
220                    Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
221                } else {
222                    Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
223                }
224            }))
225            .add_option(self.auth_type().map(|auth_type| {
226                // In in the SELECT to list sessions, we can rely on the JOINed table, whereas
227                // in other queries we need to do a subquery
228                if has_joins {
229                    if auth_type.is_sso_login() {
230                        Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
231                            .is_not_null()
232                    } else {
233                        Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
234                            .is_null()
235                    }
236                } else {
237                    // This builds either a:
238                    // `WHERE compat_session_id = ANY(...)`
239                    // or a `WHERE compat_session_id <> ALL(...)`
240                    let compat_sso_logins = Query::select()
241                        .expr(Expr::col((
242                            CompatSsoLogins::Table,
243                            CompatSsoLogins::CompatSessionId,
244                        )))
245                        .from(CompatSsoLogins::Table)
246                        .take();
247
248                    if auth_type.is_sso_login() {
249                        Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
250                            .eq(Expr::any(compat_sso_logins))
251                    } else {
252                        Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
253                            .ne(Expr::all(compat_sso_logins))
254                    }
255                }
256            }))
257            .add_option(self.last_active_after().map(|last_active_after| {
258                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
259                    .gt(last_active_after)
260            }))
261            .add_option(self.last_active_before().map(|last_active_before| {
262                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
263                    .lt(last_active_before)
264            }))
265            .add_option(self.device().map(|device| {
266                Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
267            }))
268    }
269}
270
271#[async_trait]
272impl CompatSessionRepository for PgCompatSessionRepository<'_> {
273    type Error = DatabaseError;
274
275    #[tracing::instrument(
276        name = "db.compat_session.lookup",
277        skip_all,
278        fields(
279            db.query.text,
280            compat_session.id = %id,
281        ),
282        err,
283    )]
284    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
285        let res = sqlx::query_as!(
286            CompatSessionLookup,
287            r#"
288                SELECT compat_session_id
289                     , device_id
290                     , human_name
291                     , user_id
292                     , user_session_id
293                     , created_at
294                     , finished_at
295                     , is_synapse_admin
296                     , user_agent
297                     , last_active_at
298                     , last_active_ip as "last_active_ip: IpAddr"
299                FROM compat_sessions
300                WHERE compat_session_id = $1
301            "#,
302            Uuid::from(id),
303        )
304        .traced()
305        .fetch_optional(&mut *self.conn)
306        .await?;
307
308        let Some(res) = res else { return Ok(None) };
309
310        Ok(Some(res.into()))
311    }
312
313    #[tracing::instrument(
314        name = "db.compat_session.add",
315        skip_all,
316        fields(
317            db.query.text,
318            compat_session.id,
319            %user.id,
320            %user.username,
321            compat_session.device.id = device.as_str(),
322        ),
323        err,
324    )]
325    async fn add(
326        &mut self,
327        rng: &mut (dyn RngCore + Send),
328        clock: &dyn Clock,
329        user: &User,
330        device: Device,
331        browser_session: Option<&BrowserSession>,
332        is_synapse_admin: bool,
333        human_name: Option<String>,
334    ) -> Result<CompatSession, Self::Error> {
335        let created_at = clock.now();
336        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
337        tracing::Span::current().record("compat_session.id", tracing::field::display(id));
338
339        sqlx::query!(
340            r#"
341                INSERT INTO compat_sessions
342                    (compat_session_id, user_id, device_id,
343                     user_session_id, created_at, is_synapse_admin,
344                     human_name)
345                VALUES ($1, $2, $3, $4, $5, $6, $7)
346            "#,
347            Uuid::from(id),
348            Uuid::from(user.id),
349            device.as_str(),
350            browser_session.map(|s| Uuid::from(s.id)),
351            created_at,
352            is_synapse_admin,
353            human_name.as_deref(),
354        )
355        .traced()
356        .execute(&mut *self.conn)
357        .await?;
358
359        Ok(CompatSession {
360            id,
361            state: CompatSessionState::default(),
362            user_id: user.id,
363            device: Some(device),
364            human_name,
365            user_session_id: browser_session.map(|s| s.id),
366            created_at,
367            is_synapse_admin,
368            user_agent: None,
369            last_active_at: None,
370            last_active_ip: None,
371        })
372    }
373
374    #[tracing::instrument(
375        name = "db.compat_session.finish",
376        skip_all,
377        fields(
378            db.query.text,
379            %compat_session.id,
380            user.id = %compat_session.user_id,
381            compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
382        ),
383        err,
384    )]
385    async fn finish(
386        &mut self,
387        clock: &dyn Clock,
388        compat_session: CompatSession,
389    ) -> Result<CompatSession, Self::Error> {
390        let finished_at = clock.now();
391
392        let res = sqlx::query!(
393            r#"
394                UPDATE compat_sessions cs
395                SET finished_at = $2
396                WHERE compat_session_id = $1
397            "#,
398            Uuid::from(compat_session.id),
399            finished_at,
400        )
401        .traced()
402        .execute(&mut *self.conn)
403        .await?;
404
405        DatabaseError::ensure_affected_rows(&res, 1)?;
406
407        let compat_session = compat_session
408            .finish(finished_at)
409            .map_err(DatabaseError::to_invalid_operation)?;
410
411        Ok(compat_session)
412    }
413
414    #[tracing::instrument(
415        name = "db.compat_session.finish_bulk",
416        skip_all,
417        fields(db.query.text),
418        err,
419    )]
420    async fn finish_bulk(
421        &mut self,
422        clock: &dyn Clock,
423        filter: CompatSessionFilter<'_>,
424    ) -> Result<usize, Self::Error> {
425        let finished_at = clock.now();
426        let (sql, arguments) = Query::update()
427            .table(CompatSessions::Table)
428            .value(CompatSessions::FinishedAt, finished_at)
429            .apply_filter(filter)
430            .build_sqlx(PostgresQueryBuilder);
431
432        let res = sqlx::query_with(&sql, arguments)
433            .traced()
434            .execute(&mut *self.conn)
435            .await?;
436
437        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
438    }
439
440    #[tracing::instrument(
441        name = "db.compat_session.list",
442        skip_all,
443        fields(
444            db.query.text,
445        ),
446        err,
447    )]
448    async fn list(
449        &mut self,
450        filter: CompatSessionFilter<'_>,
451        pagination: Pagination,
452    ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> {
453        let (sql, arguments) = Query::select()
454            .expr_as(
455                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
456                CompatSessionAndSsoLoginLookupIden::CompatSessionId,
457            )
458            .expr_as(
459                Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
460                CompatSessionAndSsoLoginLookupIden::DeviceId,
461            )
462            .expr_as(
463                Expr::col((CompatSessions::Table, CompatSessions::HumanName)),
464                CompatSessionAndSsoLoginLookupIden::HumanName,
465            )
466            .expr_as(
467                Expr::col((CompatSessions::Table, CompatSessions::UserId)),
468                CompatSessionAndSsoLoginLookupIden::UserId,
469            )
470            .expr_as(
471                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
472                CompatSessionAndSsoLoginLookupIden::UserSessionId,
473            )
474            .expr_as(
475                Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
476                CompatSessionAndSsoLoginLookupIden::CreatedAt,
477            )
478            .expr_as(
479                Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
480                CompatSessionAndSsoLoginLookupIden::FinishedAt,
481            )
482            .expr_as(
483                Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
484                CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin,
485            )
486            .expr_as(
487                Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
488                CompatSessionAndSsoLoginLookupIden::UserAgent,
489            )
490            .expr_as(
491                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
492                CompatSessionAndSsoLoginLookupIden::LastActiveAt,
493            )
494            .expr_as(
495                Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
496                CompatSessionAndSsoLoginLookupIden::LastActiveIp,
497            )
498            .expr_as(
499                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
500                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginId,
501            )
502            .expr_as(
503                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
504                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginToken,
505            )
506            .expr_as(
507                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
508                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginRedirectUri,
509            )
510            .expr_as(
511                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
512                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginCreatedAt,
513            )
514            .expr_as(
515                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
516                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginFulfilledAt,
517            )
518            .expr_as(
519                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
520                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginExchangedAt,
521            )
522            .from(CompatSessions::Table)
523            .left_join(
524                CompatSsoLogins::Table,
525                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
526                    .equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
527            )
528            .apply_filter_with_joins(filter)
529            .generate_pagination(
530                (CompatSessions::Table, CompatSessions::CompatSessionId),
531                pagination,
532            )
533            .build_sqlx(PostgresQueryBuilder);
534
535        let edges: Vec<CompatSessionAndSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
536            .traced()
537            .fetch_all(&mut *self.conn)
538            .await?;
539
540        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
541
542        Ok(page)
543    }
544
545    #[tracing::instrument(
546        name = "db.compat_session.count",
547        skip_all,
548        fields(
549            db.query.text,
550        ),
551        err,
552    )]
553    async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error> {
554        let (sql, arguments) = sea_query::Query::select()
555            .expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count())
556            .from(CompatSessions::Table)
557            .apply_filter(filter)
558            .build_sqlx(PostgresQueryBuilder);
559
560        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
561            .traced()
562            .fetch_one(&mut *self.conn)
563            .await?;
564
565        count
566            .try_into()
567            .map_err(DatabaseError::to_invalid_operation)
568    }
569
570    #[tracing::instrument(
571        name = "db.compat_session.record_batch_activity",
572        skip_all,
573        fields(
574            db.query.text,
575        ),
576        err,
577    )]
578    async fn record_batch_activity(
579        &mut self,
580        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
581    ) -> Result<(), Self::Error> {
582        // Sort the activity by ID, so that when batching the updates, Postgres
583        // locks the rows in a stable order, preventing deadlocks
584        activities.sort_unstable();
585        let mut ids = Vec::with_capacity(activities.len());
586        let mut last_activities = Vec::with_capacity(activities.len());
587        let mut ips = Vec::with_capacity(activities.len());
588
589        for (id, last_activity, ip) in activities {
590            ids.push(Uuid::from(id));
591            last_activities.push(last_activity);
592            ips.push(ip);
593        }
594
595        let res = sqlx::query!(
596            r#"
597                UPDATE compat_sessions
598                SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)
599                  , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)
600                FROM (
601                    SELECT *
602                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
603                        AS t(compat_session_id, last_active_at, last_active_ip)
604                ) AS t
605                WHERE compat_sessions.compat_session_id = t.compat_session_id
606            "#,
607            &ids,
608            &last_activities,
609            &ips as &[Option<IpAddr>],
610        )
611        .traced()
612        .execute(&mut *self.conn)
613        .await?;
614
615        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
616
617        Ok(())
618    }
619
620    #[tracing::instrument(
621        name = "db.compat_session.record_user_agent",
622        skip_all,
623        fields(
624            db.query.text,
625            %compat_session.id,
626        ),
627        err,
628    )]
629    async fn record_user_agent(
630        &mut self,
631        mut compat_session: CompatSession,
632        user_agent: String,
633    ) -> Result<CompatSession, Self::Error> {
634        let res = sqlx::query!(
635            r#"
636            UPDATE compat_sessions
637            SET user_agent = $2
638            WHERE compat_session_id = $1
639        "#,
640            Uuid::from(compat_session.id),
641            &*user_agent,
642        )
643        .traced()
644        .execute(&mut *self.conn)
645        .await?;
646
647        compat_session.user_agent = Some(user_agent);
648
649        DatabaseError::ensure_affected_rows(&res, 1)?;
650
651        Ok(compat_session)
652    }
653
654    #[tracing::instrument(
655        name = "repository.compat_session.set_human_name",
656        skip(self),
657        fields(
658            compat_session.id = %compat_session.id,
659            compat_session.human_name = ?human_name,
660        ),
661        err,
662    )]
663    async fn set_human_name(
664        &mut self,
665        mut compat_session: CompatSession,
666        human_name: Option<String>,
667    ) -> Result<CompatSession, Self::Error> {
668        let res = sqlx::query!(
669            r#"
670            UPDATE compat_sessions
671            SET human_name = $2
672            WHERE compat_session_id = $1
673        "#,
674            Uuid::from(compat_session.id),
675            human_name.as_deref(),
676        )
677        .traced()
678        .execute(&mut *self.conn)
679        .await?;
680
681        compat_session.human_name = human_name;
682
683        DatabaseError::ensure_affected_rows(&res, 1)?;
684
685        Ok(compat_session)
686    }
687}