1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 BrowserSession, Clock, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode,
11 UserRegistration,
12};
13use mas_storage::{
14 Page, Pagination,
15 pagination::Node,
16 user::{UserEmailFilter, UserEmailRepository},
17};
18use rand::RngCore;
19use sea_query::{Expr, Func, PostgresQueryBuilder, Query, SimpleExpr, enum_def};
20use sea_query_binder::SqlxBinder;
21use sqlx::PgConnection;
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26 DatabaseError,
27 filter::{Filter, StatementExt},
28 iden::UserEmails,
29 pagination::QueryBuilderExt,
30 tracing::ExecuteExt,
31};
32
33pub struct PgUserEmailRepository<'c> {
35 conn: &'c mut PgConnection,
36}
37
38impl<'c> PgUserEmailRepository<'c> {
39 pub fn new(conn: &'c mut PgConnection) -> Self {
42 Self { conn }
43 }
44}
45
/// Row shape returned by the `SELECT` queries on `user_emails`.
///
/// The field names must match the selected column names: `sqlx::query_as!` and
/// `sqlx::FromRow` map columns onto fields by name. `#[enum_def]` also derives
/// the `UserEmailLookupIden` identifiers, which the repository's `list` query
/// uses as column aliases.
#[derive(Debug, Clone, sqlx::FromRow)]
#[enum_def]
struct UserEmailLookup {
    user_email_id: Uuid,
    user_id: Uuid,
    email: String,
    created_at: DateTime<Utc>,
}
54
55impl Node<Ulid> for UserEmailLookup {
56 fn cursor(&self) -> Ulid {
57 self.user_email_id.into()
58 }
59}
60
61impl From<UserEmailLookup> for UserEmail {
62 fn from(e: UserEmailLookup) -> UserEmail {
63 UserEmail {
64 id: e.user_email_id.into(),
65 user_id: e.user_id.into(),
66 email: e.email,
67 created_at: e.created_at,
68 }
69 }
70}
71
/// Row shape returned by the `SELECT` query on `user_email_authentications`.
///
/// The field names must match the selected column names, which
/// `sqlx::query_as!` maps onto fields by name. `user_session_id` and
/// `user_registration_id` are nullable. The two insert methods in this file
/// each set only one of them.
struct UserEmailAuthenticationLookup {
    user_email_authentication_id: Uuid,
    user_session_id: Option<Uuid>,
    user_registration_id: Option<Uuid>,
    email: String,
    created_at: DateTime<Utc>,
    // `None` until the authentication has been completed.
    completed_at: Option<DateTime<Utc>>,
}
80
81impl From<UserEmailAuthenticationLookup> for UserEmailAuthentication {
82 fn from(value: UserEmailAuthenticationLookup) -> Self {
83 UserEmailAuthentication {
84 id: value.user_email_authentication_id.into(),
85 user_session_id: value.user_session_id.map(Ulid::from),
86 user_registration_id: value.user_registration_id.map(Ulid::from),
87 email: value.email,
88 created_at: value.created_at,
89 completed_at: value.completed_at,
90 }
91 }
92}
93
/// Row shape returned by the `SELECT` query on
/// `user_email_authentication_codes`.
///
/// The field names must match the selected column names, which
/// `sqlx::query_as!` maps onto fields by name.
struct UserEmailAuthenticationCodeLookup {
    user_email_authentication_code_id: Uuid,
    // The authentication this code belongs to.
    user_email_authentication_id: Uuid,
    code: String,
    created_at: DateTime<Utc>,
    expires_at: DateTime<Utc>,
}
101
102impl From<UserEmailAuthenticationCodeLookup> for UserEmailAuthenticationCode {
103 fn from(value: UserEmailAuthenticationCodeLookup) -> Self {
104 UserEmailAuthenticationCode {
105 id: value.user_email_authentication_code_id.into(),
106 user_email_authentication_id: value.user_email_authentication_id.into(),
107 code: value.code,
108 created_at: value.created_at,
109 expires_at: value.expires_at,
110 }
111 }
112}
113
114impl Filter for UserEmailFilter<'_> {
115 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
116 sea_query::Condition::all()
117 .add_option(self.user().map(|user| {
118 Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
119 }))
120 .add_option(self.email().map(|email| {
121 SimpleExpr::from(Func::lower(Expr::col((
122 UserEmails::Table,
123 UserEmails::Email,
124 ))))
125 .eq(Func::lower(email))
126 }))
127 }
128}
129
130#[async_trait]
131impl UserEmailRepository for PgUserEmailRepository<'_> {
132 type Error = DatabaseError;
133
134 #[tracing::instrument(
135 name = "db.user_email.lookup",
136 skip_all,
137 fields(
138 db.query.text,
139 user_email.id = %id,
140 ),
141 err,
142 )]
143 async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
144 let res = sqlx::query_as!(
145 UserEmailLookup,
146 r#"
147 SELECT user_email_id
148 , user_id
149 , email
150 , created_at
151 FROM user_emails
152
153 WHERE user_email_id = $1
154 "#,
155 Uuid::from(id),
156 )
157 .traced()
158 .fetch_optional(&mut *self.conn)
159 .await?;
160
161 let Some(user_email) = res else {
162 return Ok(None);
163 };
164
165 Ok(Some(user_email.into()))
166 }
167
168 #[tracing::instrument(
169 name = "db.user_email.find",
170 skip_all,
171 fields(
172 db.query.text,
173 %user.id,
174 user_email.email = email,
175 ),
176 err,
177 )]
178 async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
179 let res = sqlx::query_as!(
180 UserEmailLookup,
181 r#"
182 SELECT user_email_id
183 , user_id
184 , email
185 , created_at
186 FROM user_emails
187
188 WHERE user_id = $1 AND LOWER(email) = LOWER($2)
189 "#,
190 Uuid::from(user.id),
191 email,
192 )
193 .traced()
194 .fetch_optional(&mut *self.conn)
195 .await?;
196
197 let Some(user_email) = res else {
198 return Ok(None);
199 };
200
201 Ok(Some(user_email.into()))
202 }
203
204 #[tracing::instrument(
205 name = "db.user_email.find_by_email",
206 skip_all,
207 fields(
208 db.query.text,
209 user_email.email = email,
210 ),
211 err,
212 )]
213 async fn find_by_email(&mut self, email: &str) -> Result<Option<UserEmail>, Self::Error> {
214 let res = sqlx::query_as!(
215 UserEmailLookup,
216 r#"
217 SELECT user_email_id
218 , user_id
219 , email
220 , created_at
221 FROM user_emails
222 WHERE LOWER(email) = LOWER($1)
223 "#,
224 email,
225 )
226 .traced()
227 .fetch_all(&mut *self.conn)
228 .await?;
229
230 if res.len() != 1 {
231 return Ok(None);
232 }
233
234 let Some(user_email) = res.into_iter().next() else {
235 return Ok(None);
236 };
237
238 Ok(Some(user_email.into()))
239 }
240
241 #[tracing::instrument(
242 name = "db.user_email.all",
243 skip_all,
244 fields(
245 db.query.text,
246 %user.id,
247 ),
248 err,
249 )]
250 async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
251 let res = sqlx::query_as!(
252 UserEmailLookup,
253 r#"
254 SELECT user_email_id
255 , user_id
256 , email
257 , created_at
258 FROM user_emails
259
260 WHERE user_id = $1
261
262 ORDER BY email ASC
263 "#,
264 Uuid::from(user.id),
265 )
266 .traced()
267 .fetch_all(&mut *self.conn)
268 .await?;
269
270 Ok(res.into_iter().map(Into::into).collect())
271 }
272
273 #[tracing::instrument(
274 name = "db.user_email.list",
275 skip_all,
276 fields(
277 db.query.text,
278 ),
279 err,
280 )]
281 async fn list(
282 &mut self,
283 filter: UserEmailFilter<'_>,
284 pagination: Pagination,
285 ) -> Result<Page<UserEmail>, DatabaseError> {
286 let (sql, arguments) = Query::select()
287 .expr_as(
288 Expr::col((UserEmails::Table, UserEmails::UserEmailId)),
289 UserEmailLookupIden::UserEmailId,
290 )
291 .expr_as(
292 Expr::col((UserEmails::Table, UserEmails::UserId)),
293 UserEmailLookupIden::UserId,
294 )
295 .expr_as(
296 Expr::col((UserEmails::Table, UserEmails::Email)),
297 UserEmailLookupIden::Email,
298 )
299 .expr_as(
300 Expr::col((UserEmails::Table, UserEmails::CreatedAt)),
301 UserEmailLookupIden::CreatedAt,
302 )
303 .from(UserEmails::Table)
304 .apply_filter(filter)
305 .generate_pagination((UserEmails::Table, UserEmails::UserEmailId), pagination)
306 .build_sqlx(PostgresQueryBuilder);
307
308 let edges: Vec<UserEmailLookup> = sqlx::query_as_with(&sql, arguments)
309 .traced()
310 .fetch_all(&mut *self.conn)
311 .await?;
312
313 let page = pagination.process(edges).map(UserEmail::from);
314
315 Ok(page)
316 }
317
318 #[tracing::instrument(
319 name = "db.user_email.count",
320 skip_all,
321 fields(
322 db.query.text,
323 ),
324 err,
325 )]
326 async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
327 let (sql, arguments) = Query::select()
328 .expr(Expr::col((UserEmails::Table, UserEmails::UserEmailId)).count())
329 .from(UserEmails::Table)
330 .apply_filter(filter)
331 .build_sqlx(PostgresQueryBuilder);
332
333 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
334 .traced()
335 .fetch_one(&mut *self.conn)
336 .await?;
337
338 count
339 .try_into()
340 .map_err(DatabaseError::to_invalid_operation)
341 }
342
343 #[tracing::instrument(
344 name = "db.user_email.add",
345 skip_all,
346 fields(
347 db.query.text,
348 %user.id,
349 user_email.id,
350 user_email.email = email,
351 ),
352 err,
353 )]
354 async fn add(
355 &mut self,
356 rng: &mut (dyn RngCore + Send),
357 clock: &dyn Clock,
358 user: &User,
359 email: String,
360 ) -> Result<UserEmail, Self::Error> {
361 let created_at = clock.now();
362 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
363 tracing::Span::current().record("user_email.id", tracing::field::display(id));
364
365 sqlx::query!(
366 r#"
367 INSERT INTO user_emails (user_email_id, user_id, email, created_at)
368 VALUES ($1, $2, $3, $4)
369 "#,
370 Uuid::from(id),
371 Uuid::from(user.id),
372 &email,
373 created_at,
374 )
375 .traced()
376 .execute(&mut *self.conn)
377 .await?;
378
379 Ok(UserEmail {
380 id,
381 user_id: user.id,
382 email,
383 created_at,
384 })
385 }
386
387 #[tracing::instrument(
388 name = "db.user_email.remove",
389 skip_all,
390 fields(
391 db.query.text,
392 user.id = %user_email.user_id,
393 %user_email.id,
394 %user_email.email,
395 ),
396 err,
397 )]
398 async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
399 let res = sqlx::query!(
400 r#"
401 DELETE FROM user_emails
402 WHERE user_email_id = $1
403 "#,
404 Uuid::from(user_email.id),
405 )
406 .traced()
407 .execute(&mut *self.conn)
408 .await?;
409
410 DatabaseError::ensure_affected_rows(&res, 1)?;
411
412 Ok(())
413 }
414
415 #[tracing::instrument(
416 name = "db.user_email.remove_bulk",
417 skip_all,
418 fields(
419 db.query.text,
420 ),
421 err,
422 )]
423 async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
424 let (sql, arguments) = Query::delete()
425 .from_table(UserEmails::Table)
426 .apply_filter(filter)
427 .build_sqlx(PostgresQueryBuilder);
428
429 let res = sqlx::query_with(&sql, arguments)
430 .traced()
431 .execute(&mut *self.conn)
432 .await?;
433
434 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
435 }
436
437 #[tracing::instrument(
438 name = "db.user_email.add_authentication_for_session",
439 skip_all,
440 fields(
441 db.query.text,
442 %session.id,
443 user_email_authentication.id,
444 user_email_authentication.email = email,
445 ),
446 err,
447 )]
448 async fn add_authentication_for_session(
449 &mut self,
450 rng: &mut (dyn RngCore + Send),
451 clock: &dyn Clock,
452 email: String,
453 session: &BrowserSession,
454 ) -> Result<UserEmailAuthentication, Self::Error> {
455 let created_at = clock.now();
456 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
457 tracing::Span::current()
458 .record("user_email_authentication.id", tracing::field::display(id));
459
460 sqlx::query!(
461 r#"
462 INSERT INTO user_email_authentications
463 ( user_email_authentication_id
464 , user_session_id
465 , email
466 , created_at
467 )
468 VALUES ($1, $2, $3, $4)
469 "#,
470 Uuid::from(id),
471 Uuid::from(session.id),
472 &email,
473 created_at,
474 )
475 .traced()
476 .execute(&mut *self.conn)
477 .await?;
478
479 Ok(UserEmailAuthentication {
480 id,
481 user_session_id: Some(session.id),
482 user_registration_id: None,
483 email,
484 created_at,
485 completed_at: None,
486 })
487 }
488
489 #[tracing::instrument(
490 name = "db.user_email.add_authentication_for_registration",
491 skip_all,
492 fields(
493 db.query.text,
494 %user_registration.id,
495 user_email_authentication.id,
496 user_email_authentication.email = email,
497 ),
498 err,
499 )]
500 async fn add_authentication_for_registration(
501 &mut self,
502 rng: &mut (dyn RngCore + Send),
503 clock: &dyn Clock,
504 email: String,
505 user_registration: &UserRegistration,
506 ) -> Result<UserEmailAuthentication, Self::Error> {
507 let created_at = clock.now();
508 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
509 tracing::Span::current()
510 .record("user_email_authentication.id", tracing::field::display(id));
511
512 sqlx::query!(
513 r#"
514 INSERT INTO user_email_authentications
515 ( user_email_authentication_id
516 , user_registration_id
517 , email
518 , created_at
519 )
520 VALUES ($1, $2, $3, $4)
521 "#,
522 Uuid::from(id),
523 Uuid::from(user_registration.id),
524 &email,
525 created_at,
526 )
527 .traced()
528 .execute(&mut *self.conn)
529 .await?;
530
531 Ok(UserEmailAuthentication {
532 id,
533 user_session_id: None,
534 user_registration_id: Some(user_registration.id),
535 email,
536 created_at,
537 completed_at: None,
538 })
539 }
540
541 #[tracing::instrument(
542 name = "db.user_email.add_authentication_code",
543 skip_all,
544 fields(
545 db.query.text,
546 %user_email_authentication.id,
547 %user_email_authentication.email,
548 user_email_authentication_code.id,
549 user_email_authentication_code.code = code,
550 ),
551 err,
552 )]
553 async fn add_authentication_code(
554 &mut self,
555 rng: &mut (dyn RngCore + Send),
556 clock: &dyn Clock,
557 duration: chrono::Duration,
558 user_email_authentication: &UserEmailAuthentication,
559 code: String,
560 ) -> Result<UserEmailAuthenticationCode, Self::Error> {
561 let created_at = clock.now();
562 let expires_at = created_at + duration;
563 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
564 tracing::Span::current().record(
565 "user_email_authentication_code.id",
566 tracing::field::display(id),
567 );
568
569 sqlx::query!(
570 r#"
571 INSERT INTO user_email_authentication_codes
572 ( user_email_authentication_code_id
573 , user_email_authentication_id
574 , code
575 , created_at
576 , expires_at
577 )
578 VALUES ($1, $2, $3, $4, $5)
579 "#,
580 Uuid::from(id),
581 Uuid::from(user_email_authentication.id),
582 &code,
583 created_at,
584 expires_at,
585 )
586 .traced()
587 .execute(&mut *self.conn)
588 .await?;
589
590 Ok(UserEmailAuthenticationCode {
591 id,
592 user_email_authentication_id: user_email_authentication.id,
593 code,
594 created_at,
595 expires_at,
596 })
597 }
598
599 #[tracing::instrument(
600 name = "db.user_email.lookup_authentication",
601 skip_all,
602 fields(
603 db.query.text,
604 user_email_authentication.id = %id,
605 ),
606 err,
607 )]
608 async fn lookup_authentication(
609 &mut self,
610 id: Ulid,
611 ) -> Result<Option<UserEmailAuthentication>, Self::Error> {
612 let res = sqlx::query_as!(
613 UserEmailAuthenticationLookup,
614 r#"
615 SELECT user_email_authentication_id
616 , user_session_id
617 , user_registration_id
618 , email
619 , created_at
620 , completed_at
621 FROM user_email_authentications
622 WHERE user_email_authentication_id = $1
623 "#,
624 Uuid::from(id),
625 )
626 .traced()
627 .fetch_optional(&mut *self.conn)
628 .await?;
629
630 Ok(res.map(UserEmailAuthentication::from))
631 }
632
633 #[tracing::instrument(
634 name = "db.user_email.find_authentication_by_code",
635 skip_all,
636 fields(
637 db.query.text,
638 %authentication.id,
639 user_email_authentication_code.code = code,
640 ),
641 err,
642 )]
643 async fn find_authentication_code(
644 &mut self,
645 authentication: &UserEmailAuthentication,
646 code: &str,
647 ) -> Result<Option<UserEmailAuthenticationCode>, Self::Error> {
648 let res = sqlx::query_as!(
649 UserEmailAuthenticationCodeLookup,
650 r#"
651 SELECT user_email_authentication_code_id
652 , user_email_authentication_id
653 , code
654 , created_at
655 , expires_at
656 FROM user_email_authentication_codes
657 WHERE user_email_authentication_id = $1
658 AND code = $2
659 "#,
660 Uuid::from(authentication.id),
661 code,
662 )
663 .traced()
664 .fetch_optional(&mut *self.conn)
665 .await?;
666
667 Ok(res.map(UserEmailAuthenticationCode::from))
668 }
669
670 #[tracing::instrument(
671 name = "db.user_email.complete_email_authentication",
672 skip_all,
673 fields(
674 db.query.text,
675 %user_email_authentication.id,
676 %user_email_authentication.email,
677 %user_email_authentication_code.id,
678 %user_email_authentication_code.code,
679 ),
680 err,
681 )]
682 async fn complete_authentication(
683 &mut self,
684 clock: &dyn Clock,
685 mut user_email_authentication: UserEmailAuthentication,
686 user_email_authentication_code: &UserEmailAuthenticationCode,
687 ) -> Result<UserEmailAuthentication, Self::Error> {
688 let completed_at = clock.now();
692
693 let res = sqlx::query!(
697 r#"
698 UPDATE user_email_authentications
699 SET completed_at = $2
700 WHERE user_email_authentication_id = $1
701 AND completed_at IS NULL
702 "#,
703 Uuid::from(user_email_authentication.id),
704 completed_at,
705 )
706 .traced()
707 .execute(&mut *self.conn)
708 .await?;
709
710 DatabaseError::ensure_affected_rows(&res, 1)?;
711
712 user_email_authentication.completed_at = Some(completed_at);
713 Ok(user_email_authentication)
714 }
715}