mas_storage_pg/user/
email.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    BrowserSession, Clock, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode,
11    UserRegistration,
12};
13use mas_storage::{
14    Page, Pagination,
15    pagination::Node,
16    user::{UserEmailFilter, UserEmailRepository},
17};
18use rand::RngCore;
19use sea_query::{Expr, Func, PostgresQueryBuilder, Query, SimpleExpr, enum_def};
20use sea_query_binder::SqlxBinder;
21use sqlx::PgConnection;
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26    DatabaseError,
27    filter::{Filter, StatementExt},
28    iden::UserEmails,
29    pagination::QueryBuilderExt,
30    tracing::ExecuteExt,
31};
32
33/// An implementation of [`UserEmailRepository`] for a PostgreSQL connection
34pub struct PgUserEmailRepository<'c> {
35    conn: &'c mut PgConnection,
36}
37
38impl<'c> PgUserEmailRepository<'c> {
39    /// Create a new [`PgUserEmailRepository`] from an active PostgreSQL
40    /// connection
41    pub fn new(conn: &'c mut PgConnection) -> Self {
42        Self { conn }
43    }
44}
45
46#[derive(Debug, Clone, sqlx::FromRow)]
47#[enum_def]
48struct UserEmailLookup {
49    user_email_id: Uuid,
50    user_id: Uuid,
51    email: String,
52    created_at: DateTime<Utc>,
53}
54
55impl Node<Ulid> for UserEmailLookup {
56    fn cursor(&self) -> Ulid {
57        self.user_email_id.into()
58    }
59}
60
61impl From<UserEmailLookup> for UserEmail {
62    fn from(e: UserEmailLookup) -> UserEmail {
63        UserEmail {
64            id: e.user_email_id.into(),
65            user_id: e.user_id.into(),
66            email: e.email,
67            created_at: e.created_at,
68        }
69    }
70}
71
72struct UserEmailAuthenticationLookup {
73    user_email_authentication_id: Uuid,
74    user_session_id: Option<Uuid>,
75    user_registration_id: Option<Uuid>,
76    email: String,
77    created_at: DateTime<Utc>,
78    completed_at: Option<DateTime<Utc>>,
79}
80
81impl From<UserEmailAuthenticationLookup> for UserEmailAuthentication {
82    fn from(value: UserEmailAuthenticationLookup) -> Self {
83        UserEmailAuthentication {
84            id: value.user_email_authentication_id.into(),
85            user_session_id: value.user_session_id.map(Ulid::from),
86            user_registration_id: value.user_registration_id.map(Ulid::from),
87            email: value.email,
88            created_at: value.created_at,
89            completed_at: value.completed_at,
90        }
91    }
92}
93
94struct UserEmailAuthenticationCodeLookup {
95    user_email_authentication_code_id: Uuid,
96    user_email_authentication_id: Uuid,
97    code: String,
98    created_at: DateTime<Utc>,
99    expires_at: DateTime<Utc>,
100}
101
102impl From<UserEmailAuthenticationCodeLookup> for UserEmailAuthenticationCode {
103    fn from(value: UserEmailAuthenticationCodeLookup) -> Self {
104        UserEmailAuthenticationCode {
105            id: value.user_email_authentication_code_id.into(),
106            user_email_authentication_id: value.user_email_authentication_id.into(),
107            code: value.code,
108            created_at: value.created_at,
109            expires_at: value.expires_at,
110        }
111    }
112}
113
114impl Filter for UserEmailFilter<'_> {
115    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
116        sea_query::Condition::all()
117            .add_option(self.user().map(|user| {
118                Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
119            }))
120            .add_option(self.email().map(|email| {
121                SimpleExpr::from(Func::lower(Expr::col((
122                    UserEmails::Table,
123                    UserEmails::Email,
124                ))))
125                .eq(Func::lower(email))
126            }))
127    }
128}
129
130#[async_trait]
131impl UserEmailRepository for PgUserEmailRepository<'_> {
132    type Error = DatabaseError;
133
134    #[tracing::instrument(
135        name = "db.user_email.lookup",
136        skip_all,
137        fields(
138            db.query.text,
139            user_email.id = %id,
140        ),
141        err,
142    )]
143    async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
144        let res = sqlx::query_as!(
145            UserEmailLookup,
146            r#"
147                SELECT user_email_id
148                     , user_id
149                     , email
150                     , created_at
151                FROM user_emails
152
153                WHERE user_email_id = $1
154            "#,
155            Uuid::from(id),
156        )
157        .traced()
158        .fetch_optional(&mut *self.conn)
159        .await?;
160
161        let Some(user_email) = res else {
162            return Ok(None);
163        };
164
165        Ok(Some(user_email.into()))
166    }
167
168    #[tracing::instrument(
169        name = "db.user_email.find",
170        skip_all,
171        fields(
172            db.query.text,
173            %user.id,
174            user_email.email = email,
175        ),
176        err,
177    )]
178    async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
179        let res = sqlx::query_as!(
180            UserEmailLookup,
181            r#"
182                SELECT user_email_id
183                     , user_id
184                     , email
185                     , created_at
186                FROM user_emails
187
188                WHERE user_id = $1 AND LOWER(email) = LOWER($2)
189            "#,
190            Uuid::from(user.id),
191            email,
192        )
193        .traced()
194        .fetch_optional(&mut *self.conn)
195        .await?;
196
197        let Some(user_email) = res else {
198            return Ok(None);
199        };
200
201        Ok(Some(user_email.into()))
202    }
203
204    #[tracing::instrument(
205        name = "db.user_email.find_by_email",
206        skip_all,
207        fields(
208            db.query.text,
209            user_email.email = email,
210        ),
211        err,
212    )]
213    async fn find_by_email(&mut self, email: &str) -> Result<Option<UserEmail>, Self::Error> {
214        let res = sqlx::query_as!(
215            UserEmailLookup,
216            r#"
217                SELECT user_email_id
218                     , user_id
219                     , email
220                     , created_at
221                FROM user_emails
222                WHERE LOWER(email) = LOWER($1)
223            "#,
224            email,
225        )
226        .traced()
227        .fetch_all(&mut *self.conn)
228        .await?;
229
230        if res.len() != 1 {
231            return Ok(None);
232        }
233
234        let Some(user_email) = res.into_iter().next() else {
235            return Ok(None);
236        };
237
238        Ok(Some(user_email.into()))
239    }
240
241    #[tracing::instrument(
242        name = "db.user_email.all",
243        skip_all,
244        fields(
245            db.query.text,
246            %user.id,
247        ),
248        err,
249    )]
250    async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
251        let res = sqlx::query_as!(
252            UserEmailLookup,
253            r#"
254                SELECT user_email_id
255                     , user_id
256                     , email
257                     , created_at
258                FROM user_emails
259
260                WHERE user_id = $1
261
262                ORDER BY email ASC
263            "#,
264            Uuid::from(user.id),
265        )
266        .traced()
267        .fetch_all(&mut *self.conn)
268        .await?;
269
270        Ok(res.into_iter().map(Into::into).collect())
271    }
272
273    #[tracing::instrument(
274        name = "db.user_email.list",
275        skip_all,
276        fields(
277            db.query.text,
278        ),
279        err,
280    )]
281    async fn list(
282        &mut self,
283        filter: UserEmailFilter<'_>,
284        pagination: Pagination,
285    ) -> Result<Page<UserEmail>, DatabaseError> {
286        let (sql, arguments) = Query::select()
287            .expr_as(
288                Expr::col((UserEmails::Table, UserEmails::UserEmailId)),
289                UserEmailLookupIden::UserEmailId,
290            )
291            .expr_as(
292                Expr::col((UserEmails::Table, UserEmails::UserId)),
293                UserEmailLookupIden::UserId,
294            )
295            .expr_as(
296                Expr::col((UserEmails::Table, UserEmails::Email)),
297                UserEmailLookupIden::Email,
298            )
299            .expr_as(
300                Expr::col((UserEmails::Table, UserEmails::CreatedAt)),
301                UserEmailLookupIden::CreatedAt,
302            )
303            .from(UserEmails::Table)
304            .apply_filter(filter)
305            .generate_pagination((UserEmails::Table, UserEmails::UserEmailId), pagination)
306            .build_sqlx(PostgresQueryBuilder);
307
308        let edges: Vec<UserEmailLookup> = sqlx::query_as_with(&sql, arguments)
309            .traced()
310            .fetch_all(&mut *self.conn)
311            .await?;
312
313        let page = pagination.process(edges).map(UserEmail::from);
314
315        Ok(page)
316    }
317
318    #[tracing::instrument(
319        name = "db.user_email.count",
320        skip_all,
321        fields(
322            db.query.text,
323        ),
324        err,
325    )]
326    async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
327        let (sql, arguments) = Query::select()
328            .expr(Expr::col((UserEmails::Table, UserEmails::UserEmailId)).count())
329            .from(UserEmails::Table)
330            .apply_filter(filter)
331            .build_sqlx(PostgresQueryBuilder);
332
333        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
334            .traced()
335            .fetch_one(&mut *self.conn)
336            .await?;
337
338        count
339            .try_into()
340            .map_err(DatabaseError::to_invalid_operation)
341    }
342
343    #[tracing::instrument(
344        name = "db.user_email.add",
345        skip_all,
346        fields(
347            db.query.text,
348            %user.id,
349            user_email.id,
350            user_email.email = email,
351        ),
352        err,
353    )]
354    async fn add(
355        &mut self,
356        rng: &mut (dyn RngCore + Send),
357        clock: &dyn Clock,
358        user: &User,
359        email: String,
360    ) -> Result<UserEmail, Self::Error> {
361        let created_at = clock.now();
362        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
363        tracing::Span::current().record("user_email.id", tracing::field::display(id));
364
365        sqlx::query!(
366            r#"
367                INSERT INTO user_emails (user_email_id, user_id, email, created_at)
368                VALUES ($1, $2, $3, $4)
369            "#,
370            Uuid::from(id),
371            Uuid::from(user.id),
372            &email,
373            created_at,
374        )
375        .traced()
376        .execute(&mut *self.conn)
377        .await?;
378
379        Ok(UserEmail {
380            id,
381            user_id: user.id,
382            email,
383            created_at,
384        })
385    }
386
387    #[tracing::instrument(
388        name = "db.user_email.remove",
389        skip_all,
390        fields(
391            db.query.text,
392            user.id = %user_email.user_id,
393            %user_email.id,
394            %user_email.email,
395        ),
396        err,
397    )]
398    async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
399        let res = sqlx::query!(
400            r#"
401                DELETE FROM user_emails
402                WHERE user_email_id = $1
403            "#,
404            Uuid::from(user_email.id),
405        )
406        .traced()
407        .execute(&mut *self.conn)
408        .await?;
409
410        DatabaseError::ensure_affected_rows(&res, 1)?;
411
412        Ok(())
413    }
414
415    #[tracing::instrument(
416        name = "db.user_email.remove_bulk",
417        skip_all,
418        fields(
419            db.query.text,
420        ),
421        err,
422    )]
423    async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
424        let (sql, arguments) = Query::delete()
425            .from_table(UserEmails::Table)
426            .apply_filter(filter)
427            .build_sqlx(PostgresQueryBuilder);
428
429        let res = sqlx::query_with(&sql, arguments)
430            .traced()
431            .execute(&mut *self.conn)
432            .await?;
433
434        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
435    }
436
437    #[tracing::instrument(
438        name = "db.user_email.add_authentication_for_session",
439        skip_all,
440        fields(
441            db.query.text,
442            %session.id,
443            user_email_authentication.id,
444            user_email_authentication.email = email,
445        ),
446        err,
447    )]
448    async fn add_authentication_for_session(
449        &mut self,
450        rng: &mut (dyn RngCore + Send),
451        clock: &dyn Clock,
452        email: String,
453        session: &BrowserSession,
454    ) -> Result<UserEmailAuthentication, Self::Error> {
455        let created_at = clock.now();
456        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
457        tracing::Span::current()
458            .record("user_email_authentication.id", tracing::field::display(id));
459
460        sqlx::query!(
461            r#"
462                INSERT INTO user_email_authentications
463                  ( user_email_authentication_id
464                  , user_session_id
465                  , email
466                  , created_at
467                  )
468                VALUES ($1, $2, $3, $4)
469            "#,
470            Uuid::from(id),
471            Uuid::from(session.id),
472            &email,
473            created_at,
474        )
475        .traced()
476        .execute(&mut *self.conn)
477        .await?;
478
479        Ok(UserEmailAuthentication {
480            id,
481            user_session_id: Some(session.id),
482            user_registration_id: None,
483            email,
484            created_at,
485            completed_at: None,
486        })
487    }
488
489    #[tracing::instrument(
490        name = "db.user_email.add_authentication_for_registration",
491        skip_all,
492        fields(
493            db.query.text,
494            %user_registration.id,
495            user_email_authentication.id,
496            user_email_authentication.email = email,
497        ),
498        err,
499    )]
500    async fn add_authentication_for_registration(
501        &mut self,
502        rng: &mut (dyn RngCore + Send),
503        clock: &dyn Clock,
504        email: String,
505        user_registration: &UserRegistration,
506    ) -> Result<UserEmailAuthentication, Self::Error> {
507        let created_at = clock.now();
508        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
509        tracing::Span::current()
510            .record("user_email_authentication.id", tracing::field::display(id));
511
512        sqlx::query!(
513            r#"
514                INSERT INTO user_email_authentications
515                  ( user_email_authentication_id
516                  , user_registration_id
517                  , email
518                  , created_at
519                  )
520                VALUES ($1, $2, $3, $4)
521            "#,
522            Uuid::from(id),
523            Uuid::from(user_registration.id),
524            &email,
525            created_at,
526        )
527        .traced()
528        .execute(&mut *self.conn)
529        .await?;
530
531        Ok(UserEmailAuthentication {
532            id,
533            user_session_id: None,
534            user_registration_id: Some(user_registration.id),
535            email,
536            created_at,
537            completed_at: None,
538        })
539    }
540
541    #[tracing::instrument(
542        name = "db.user_email.add_authentication_code",
543        skip_all,
544        fields(
545            db.query.text,
546            %user_email_authentication.id,
547            %user_email_authentication.email,
548            user_email_authentication_code.id,
549            user_email_authentication_code.code = code,
550        ),
551        err,
552    )]
553    async fn add_authentication_code(
554        &mut self,
555        rng: &mut (dyn RngCore + Send),
556        clock: &dyn Clock,
557        duration: chrono::Duration,
558        user_email_authentication: &UserEmailAuthentication,
559        code: String,
560    ) -> Result<UserEmailAuthenticationCode, Self::Error> {
561        let created_at = clock.now();
562        let expires_at = created_at + duration;
563        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
564        tracing::Span::current().record(
565            "user_email_authentication_code.id",
566            tracing::field::display(id),
567        );
568
569        sqlx::query!(
570            r#"
571                INSERT INTO user_email_authentication_codes
572                  ( user_email_authentication_code_id
573                  , user_email_authentication_id
574                  , code
575                  , created_at
576                  , expires_at
577                  )
578                VALUES ($1, $2, $3, $4, $5)
579            "#,
580            Uuid::from(id),
581            Uuid::from(user_email_authentication.id),
582            &code,
583            created_at,
584            expires_at,
585        )
586        .traced()
587        .execute(&mut *self.conn)
588        .await?;
589
590        Ok(UserEmailAuthenticationCode {
591            id,
592            user_email_authentication_id: user_email_authentication.id,
593            code,
594            created_at,
595            expires_at,
596        })
597    }
598
599    #[tracing::instrument(
600        name = "db.user_email.lookup_authentication",
601        skip_all,
602        fields(
603            db.query.text,
604            user_email_authentication.id = %id,
605        ),
606        err,
607    )]
608    async fn lookup_authentication(
609        &mut self,
610        id: Ulid,
611    ) -> Result<Option<UserEmailAuthentication>, Self::Error> {
612        let res = sqlx::query_as!(
613            UserEmailAuthenticationLookup,
614            r#"
615                SELECT user_email_authentication_id
616                     , user_session_id
617                     , user_registration_id
618                     , email
619                     , created_at
620                     , completed_at
621                FROM user_email_authentications
622                WHERE user_email_authentication_id = $1
623            "#,
624            Uuid::from(id),
625        )
626        .traced()
627        .fetch_optional(&mut *self.conn)
628        .await?;
629
630        Ok(res.map(UserEmailAuthentication::from))
631    }
632
633    #[tracing::instrument(
634        name = "db.user_email.find_authentication_by_code",
635        skip_all,
636        fields(
637            db.query.text,
638            %authentication.id,
639            user_email_authentication_code.code = code,
640        ),
641        err,
642    )]
643    async fn find_authentication_code(
644        &mut self,
645        authentication: &UserEmailAuthentication,
646        code: &str,
647    ) -> Result<Option<UserEmailAuthenticationCode>, Self::Error> {
648        let res = sqlx::query_as!(
649            UserEmailAuthenticationCodeLookup,
650            r#"
651                SELECT user_email_authentication_code_id
652                     , user_email_authentication_id
653                     , code
654                     , created_at
655                     , expires_at
656                FROM user_email_authentication_codes
657                WHERE user_email_authentication_id = $1
658                  AND code = $2
659            "#,
660            Uuid::from(authentication.id),
661            code,
662        )
663        .traced()
664        .fetch_optional(&mut *self.conn)
665        .await?;
666
667        Ok(res.map(UserEmailAuthenticationCode::from))
668    }
669
670    #[tracing::instrument(
671        name = "db.user_email.complete_email_authentication",
672        skip_all,
673        fields(
674            db.query.text,
675            %user_email_authentication.id,
676            %user_email_authentication.email,
677            %user_email_authentication_code.id,
678            %user_email_authentication_code.code,
679        ),
680        err,
681    )]
682    async fn complete_authentication(
683        &mut self,
684        clock: &dyn Clock,
685        mut user_email_authentication: UserEmailAuthentication,
686        user_email_authentication_code: &UserEmailAuthenticationCode,
687    ) -> Result<UserEmailAuthentication, Self::Error> {
688        // We technically don't use the authentication code here (other than
689        // recording it in the span), but this is to make sure the caller has
690        // fetched one before calling this
691        let completed_at = clock.now();
692
693        // We'll assume the caller has checked that completed_at is None, so in case
694        // they haven't, the update will not affect any rows, which will raise
695        // an error
696        let res = sqlx::query!(
697            r#"
698                UPDATE user_email_authentications
699                SET completed_at = $2
700                WHERE user_email_authentication_id = $1
701                  AND completed_at IS NULL
702            "#,
703            Uuid::from(user_email_authentication.id),
704            completed_at,
705        )
706        .traced()
707        .execute(&mut *self.conn)
708        .await?;
709
710        DatabaseError::ensure_affected_rows(&res, 1)?;
711
712        user_email_authentication.completed_at = Some(completed_at);
713        Ok(user_email_authentication)
714    }
715}