1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12 BrowserSession, Clock, CompatSession, CompatSessionState, CompatSsoLogin, CompatSsoLoginState,
13 Device, User,
14};
15use mas_storage::{
16 Page, Pagination,
17 compat::{CompatSessionFilter, CompatSessionRepository},
18 pagination::Node,
19};
20use rand::RngCore;
21use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
22use sea_query_binder::SqlxBinder;
23use sqlx::PgConnection;
24use ulid::Ulid;
25use url::Url;
26use uuid::Uuid;
27
28use crate::{
29 DatabaseError, DatabaseInconsistencyError,
30 filter::{Filter, StatementExt, StatementWithJoinsExt},
31 iden::{CompatSessions, CompatSsoLogins, UserSessions},
32 pagination::QueryBuilderExt,
33 tracing::ExecuteExt,
34};
35
36pub struct PgCompatSessionRepository<'c> {
38 conn: &'c mut PgConnection,
39}
40
41impl<'c> PgCompatSessionRepository<'c> {
42 pub fn new(conn: &'c mut PgConnection) -> Self {
45 Self { conn }
46 }
47}
48
49struct CompatSessionLookup {
50 compat_session_id: Uuid,
51 device_id: Option<String>,
52 human_name: Option<String>,
53 user_id: Uuid,
54 user_session_id: Option<Uuid>,
55 created_at: DateTime<Utc>,
56 finished_at: Option<DateTime<Utc>>,
57 is_synapse_admin: bool,
58 user_agent: Option<String>,
59 last_active_at: Option<DateTime<Utc>>,
60 last_active_ip: Option<IpAddr>,
61}
62
63impl Node<Ulid> for CompatSessionLookup {
64 fn cursor(&self) -> Ulid {
65 self.compat_session_id.into()
66 }
67}
68
69impl From<CompatSessionLookup> for CompatSession {
70 fn from(value: CompatSessionLookup) -> Self {
71 let id = value.compat_session_id.into();
72
73 let state = match value.finished_at {
74 None => CompatSessionState::Valid,
75 Some(finished_at) => CompatSessionState::Finished { finished_at },
76 };
77
78 CompatSession {
79 id,
80 state,
81 user_id: value.user_id.into(),
82 user_session_id: value.user_session_id.map(Ulid::from),
83 device: value.device_id.map(Device::from),
84 human_name: value.human_name,
85 created_at: value.created_at,
86 is_synapse_admin: value.is_synapse_admin,
87 user_agent: value.user_agent,
88 last_active_at: value.last_active_at,
89 last_active_ip: value.last_active_ip,
90 }
91 }
92}
93
94#[derive(sqlx::FromRow)]
95#[enum_def]
96struct CompatSessionAndSsoLoginLookup {
97 compat_session_id: Uuid,
98 device_id: Option<String>,
99 human_name: Option<String>,
100 user_id: Uuid,
101 user_session_id: Option<Uuid>,
102 created_at: DateTime<Utc>,
103 finished_at: Option<DateTime<Utc>>,
104 is_synapse_admin: bool,
105 user_agent: Option<String>,
106 last_active_at: Option<DateTime<Utc>>,
107 last_active_ip: Option<IpAddr>,
108 compat_sso_login_id: Option<Uuid>,
109 compat_sso_login_token: Option<String>,
110 compat_sso_login_redirect_uri: Option<String>,
111 compat_sso_login_created_at: Option<DateTime<Utc>>,
112 compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
113 compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
114}
115
116impl Node<Ulid> for CompatSessionAndSsoLoginLookup {
117 fn cursor(&self) -> Ulid {
118 self.compat_session_id.into()
119 }
120}
121
122impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSsoLogin>) {
123 type Error = DatabaseInconsistencyError;
124
125 fn try_from(value: CompatSessionAndSsoLoginLookup) -> Result<Self, Self::Error> {
126 let id = value.compat_session_id.into();
127
128 let state = match value.finished_at {
129 None => CompatSessionState::Valid,
130 Some(finished_at) => CompatSessionState::Finished { finished_at },
131 };
132
133 let session = CompatSession {
134 id,
135 state,
136 user_id: value.user_id.into(),
137 device: value.device_id.map(Device::from),
138 human_name: value.human_name,
139 user_session_id: value.user_session_id.map(Ulid::from),
140 created_at: value.created_at,
141 is_synapse_admin: value.is_synapse_admin,
142 user_agent: value.user_agent,
143 last_active_at: value.last_active_at,
144 last_active_ip: value.last_active_ip,
145 };
146
147 match (
148 value.compat_sso_login_id,
149 value.compat_sso_login_token,
150 value.compat_sso_login_redirect_uri,
151 value.compat_sso_login_created_at,
152 value.compat_sso_login_fulfilled_at,
153 value.compat_sso_login_exchanged_at,
154 ) {
155 (None, None, None, None, None, None) => Ok((session, None)),
156 (
157 Some(id),
158 Some(login_token),
159 Some(redirect_uri),
160 Some(created_at),
161 fulfilled_at,
162 exchanged_at,
163 ) => {
164 let id = id.into();
165 let redirect_uri = Url::parse(&redirect_uri).map_err(|e| {
166 DatabaseInconsistencyError::on("compat_sso_logins")
167 .column("redirect_uri")
168 .row(id)
169 .source(e)
170 })?;
171
172 let state = match (fulfilled_at, exchanged_at) {
173 (Some(fulfilled_at), Some(exchanged_at)) => CompatSsoLoginState::Exchanged {
174 fulfilled_at,
175 exchanged_at,
176 compat_session_id: session.id,
177 },
178 _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
179 };
180
181 let login = CompatSsoLogin {
182 id,
183 redirect_uri,
184 login_token,
185 created_at,
186 state,
187 };
188
189 Ok((session, Some(login)))
190 }
191 _ => Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
192 }
193 }
194}
195
196impl Filter for CompatSessionFilter<'_> {
197 fn generate_condition(&self, has_joins: bool) -> impl sea_query::IntoCondition {
198 sea_query::Condition::all()
199 .add_option(self.user().map(|user| {
200 Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
201 }))
202 .add_option(self.browser_session().map(|browser_session| {
203 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
204 .eq(Uuid::from(browser_session.id))
205 }))
206 .add_option(self.browser_session_filter().map(|browser_session_filter| {
207 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)).in_subquery(
208 Query::select()
209 .expr(Expr::col((
210 UserSessions::Table,
211 UserSessions::UserSessionId,
212 )))
213 .apply_filter(browser_session_filter)
214 .from(UserSessions::Table)
215 .take(),
216 )
217 }))
218 .add_option(self.state().map(|state| {
219 if state.is_active() {
220 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
221 } else {
222 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
223 }
224 }))
225 .add_option(self.auth_type().map(|auth_type| {
226 if has_joins {
229 if auth_type.is_sso_login() {
230 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
231 .is_not_null()
232 } else {
233 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
234 .is_null()
235 }
236 } else {
237 let compat_sso_logins = Query::select()
241 .expr(Expr::col((
242 CompatSsoLogins::Table,
243 CompatSsoLogins::CompatSessionId,
244 )))
245 .from(CompatSsoLogins::Table)
246 .take();
247
248 if auth_type.is_sso_login() {
249 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
250 .eq(Expr::any(compat_sso_logins))
251 } else {
252 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
253 .ne(Expr::all(compat_sso_logins))
254 }
255 }
256 }))
257 .add_option(self.last_active_after().map(|last_active_after| {
258 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
259 .gt(last_active_after)
260 }))
261 .add_option(self.last_active_before().map(|last_active_before| {
262 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
263 .lt(last_active_before)
264 }))
265 .add_option(self.device().map(|device| {
266 Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
267 }))
268 }
269}
270
271#[async_trait]
272impl CompatSessionRepository for PgCompatSessionRepository<'_> {
273 type Error = DatabaseError;
274
275 #[tracing::instrument(
276 name = "db.compat_session.lookup",
277 skip_all,
278 fields(
279 db.query.text,
280 compat_session.id = %id,
281 ),
282 err,
283 )]
284 async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
285 let res = sqlx::query_as!(
286 CompatSessionLookup,
287 r#"
288 SELECT compat_session_id
289 , device_id
290 , human_name
291 , user_id
292 , user_session_id
293 , created_at
294 , finished_at
295 , is_synapse_admin
296 , user_agent
297 , last_active_at
298 , last_active_ip as "last_active_ip: IpAddr"
299 FROM compat_sessions
300 WHERE compat_session_id = $1
301 "#,
302 Uuid::from(id),
303 )
304 .traced()
305 .fetch_optional(&mut *self.conn)
306 .await?;
307
308 let Some(res) = res else { return Ok(None) };
309
310 Ok(Some(res.into()))
311 }
312
313 #[tracing::instrument(
314 name = "db.compat_session.add",
315 skip_all,
316 fields(
317 db.query.text,
318 compat_session.id,
319 %user.id,
320 %user.username,
321 compat_session.device.id = device.as_str(),
322 ),
323 err,
324 )]
325 async fn add(
326 &mut self,
327 rng: &mut (dyn RngCore + Send),
328 clock: &dyn Clock,
329 user: &User,
330 device: Device,
331 browser_session: Option<&BrowserSession>,
332 is_synapse_admin: bool,
333 human_name: Option<String>,
334 ) -> Result<CompatSession, Self::Error> {
335 let created_at = clock.now();
336 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
337 tracing::Span::current().record("compat_session.id", tracing::field::display(id));
338
339 sqlx::query!(
340 r#"
341 INSERT INTO compat_sessions
342 (compat_session_id, user_id, device_id,
343 user_session_id, created_at, is_synapse_admin,
344 human_name)
345 VALUES ($1, $2, $3, $4, $5, $6, $7)
346 "#,
347 Uuid::from(id),
348 Uuid::from(user.id),
349 device.as_str(),
350 browser_session.map(|s| Uuid::from(s.id)),
351 created_at,
352 is_synapse_admin,
353 human_name.as_deref(),
354 )
355 .traced()
356 .execute(&mut *self.conn)
357 .await?;
358
359 Ok(CompatSession {
360 id,
361 state: CompatSessionState::default(),
362 user_id: user.id,
363 device: Some(device),
364 human_name,
365 user_session_id: browser_session.map(|s| s.id),
366 created_at,
367 is_synapse_admin,
368 user_agent: None,
369 last_active_at: None,
370 last_active_ip: None,
371 })
372 }
373
374 #[tracing::instrument(
375 name = "db.compat_session.finish",
376 skip_all,
377 fields(
378 db.query.text,
379 %compat_session.id,
380 user.id = %compat_session.user_id,
381 compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
382 ),
383 err,
384 )]
385 async fn finish(
386 &mut self,
387 clock: &dyn Clock,
388 compat_session: CompatSession,
389 ) -> Result<CompatSession, Self::Error> {
390 let finished_at = clock.now();
391
392 let res = sqlx::query!(
393 r#"
394 UPDATE compat_sessions cs
395 SET finished_at = $2
396 WHERE compat_session_id = $1
397 "#,
398 Uuid::from(compat_session.id),
399 finished_at,
400 )
401 .traced()
402 .execute(&mut *self.conn)
403 .await?;
404
405 DatabaseError::ensure_affected_rows(&res, 1)?;
406
407 let compat_session = compat_session
408 .finish(finished_at)
409 .map_err(DatabaseError::to_invalid_operation)?;
410
411 Ok(compat_session)
412 }
413
414 #[tracing::instrument(
415 name = "db.compat_session.finish_bulk",
416 skip_all,
417 fields(db.query.text),
418 err,
419 )]
420 async fn finish_bulk(
421 &mut self,
422 clock: &dyn Clock,
423 filter: CompatSessionFilter<'_>,
424 ) -> Result<usize, Self::Error> {
425 let finished_at = clock.now();
426 let (sql, arguments) = Query::update()
427 .table(CompatSessions::Table)
428 .value(CompatSessions::FinishedAt, finished_at)
429 .apply_filter(filter)
430 .build_sqlx(PostgresQueryBuilder);
431
432 let res = sqlx::query_with(&sql, arguments)
433 .traced()
434 .execute(&mut *self.conn)
435 .await?;
436
437 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
438 }
439
440 #[tracing::instrument(
441 name = "db.compat_session.list",
442 skip_all,
443 fields(
444 db.query.text,
445 ),
446 err,
447 )]
448 async fn list(
449 &mut self,
450 filter: CompatSessionFilter<'_>,
451 pagination: Pagination,
452 ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> {
453 let (sql, arguments) = Query::select()
454 .expr_as(
455 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
456 CompatSessionAndSsoLoginLookupIden::CompatSessionId,
457 )
458 .expr_as(
459 Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
460 CompatSessionAndSsoLoginLookupIden::DeviceId,
461 )
462 .expr_as(
463 Expr::col((CompatSessions::Table, CompatSessions::HumanName)),
464 CompatSessionAndSsoLoginLookupIden::HumanName,
465 )
466 .expr_as(
467 Expr::col((CompatSessions::Table, CompatSessions::UserId)),
468 CompatSessionAndSsoLoginLookupIden::UserId,
469 )
470 .expr_as(
471 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
472 CompatSessionAndSsoLoginLookupIden::UserSessionId,
473 )
474 .expr_as(
475 Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
476 CompatSessionAndSsoLoginLookupIden::CreatedAt,
477 )
478 .expr_as(
479 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
480 CompatSessionAndSsoLoginLookupIden::FinishedAt,
481 )
482 .expr_as(
483 Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
484 CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin,
485 )
486 .expr_as(
487 Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
488 CompatSessionAndSsoLoginLookupIden::UserAgent,
489 )
490 .expr_as(
491 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
492 CompatSessionAndSsoLoginLookupIden::LastActiveAt,
493 )
494 .expr_as(
495 Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
496 CompatSessionAndSsoLoginLookupIden::LastActiveIp,
497 )
498 .expr_as(
499 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
500 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginId,
501 )
502 .expr_as(
503 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
504 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginToken,
505 )
506 .expr_as(
507 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
508 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginRedirectUri,
509 )
510 .expr_as(
511 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
512 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginCreatedAt,
513 )
514 .expr_as(
515 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
516 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginFulfilledAt,
517 )
518 .expr_as(
519 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
520 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginExchangedAt,
521 )
522 .from(CompatSessions::Table)
523 .left_join(
524 CompatSsoLogins::Table,
525 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
526 .equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
527 )
528 .apply_filter_with_joins(filter)
529 .generate_pagination(
530 (CompatSessions::Table, CompatSessions::CompatSessionId),
531 pagination,
532 )
533 .build_sqlx(PostgresQueryBuilder);
534
535 let edges: Vec<CompatSessionAndSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
536 .traced()
537 .fetch_all(&mut *self.conn)
538 .await?;
539
540 let page = pagination.process(edges).try_map(TryFrom::try_from)?;
541
542 Ok(page)
543 }
544
545 #[tracing::instrument(
546 name = "db.compat_session.count",
547 skip_all,
548 fields(
549 db.query.text,
550 ),
551 err,
552 )]
553 async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error> {
554 let (sql, arguments) = sea_query::Query::select()
555 .expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count())
556 .from(CompatSessions::Table)
557 .apply_filter(filter)
558 .build_sqlx(PostgresQueryBuilder);
559
560 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
561 .traced()
562 .fetch_one(&mut *self.conn)
563 .await?;
564
565 count
566 .try_into()
567 .map_err(DatabaseError::to_invalid_operation)
568 }
569
570 #[tracing::instrument(
571 name = "db.compat_session.record_batch_activity",
572 skip_all,
573 fields(
574 db.query.text,
575 ),
576 err,
577 )]
578 async fn record_batch_activity(
579 &mut self,
580 mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
581 ) -> Result<(), Self::Error> {
582 activities.sort_unstable();
585 let mut ids = Vec::with_capacity(activities.len());
586 let mut last_activities = Vec::with_capacity(activities.len());
587 let mut ips = Vec::with_capacity(activities.len());
588
589 for (id, last_activity, ip) in activities {
590 ids.push(Uuid::from(id));
591 last_activities.push(last_activity);
592 ips.push(ip);
593 }
594
595 let res = sqlx::query!(
596 r#"
597 UPDATE compat_sessions
598 SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)
599 , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)
600 FROM (
601 SELECT *
602 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
603 AS t(compat_session_id, last_active_at, last_active_ip)
604 ) AS t
605 WHERE compat_sessions.compat_session_id = t.compat_session_id
606 "#,
607 &ids,
608 &last_activities,
609 &ips as &[Option<IpAddr>],
610 )
611 .traced()
612 .execute(&mut *self.conn)
613 .await?;
614
615 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
616
617 Ok(())
618 }
619
620 #[tracing::instrument(
621 name = "db.compat_session.record_user_agent",
622 skip_all,
623 fields(
624 db.query.text,
625 %compat_session.id,
626 ),
627 err,
628 )]
629 async fn record_user_agent(
630 &mut self,
631 mut compat_session: CompatSession,
632 user_agent: String,
633 ) -> Result<CompatSession, Self::Error> {
634 let res = sqlx::query!(
635 r#"
636 UPDATE compat_sessions
637 SET user_agent = $2
638 WHERE compat_session_id = $1
639 "#,
640 Uuid::from(compat_session.id),
641 &*user_agent,
642 )
643 .traced()
644 .execute(&mut *self.conn)
645 .await?;
646
647 compat_session.user_agent = Some(user_agent);
648
649 DatabaseError::ensure_affected_rows(&res, 1)?;
650
651 Ok(compat_session)
652 }
653
654 #[tracing::instrument(
655 name = "repository.compat_session.set_human_name",
656 skip(self),
657 fields(
658 compat_session.id = %compat_session.id,
659 compat_session.human_name = ?human_name,
660 ),
661 err,
662 )]
663 async fn set_human_name(
664 &mut self,
665 mut compat_session: CompatSession,
666 human_name: Option<String>,
667 ) -> Result<CompatSession, Self::Error> {
668 let res = sqlx::query!(
669 r#"
670 UPDATE compat_sessions
671 SET human_name = $2
672 WHERE compat_session_id = $1
673 "#,
674 Uuid::from(compat_session.id),
675 human_name.as_deref(),
676 )
677 .traced()
678 .execute(&mut *self.conn)
679 .await?;
680
681 compat_session.human_name = human_name;
682
683 DatabaseError::ensure_affected_rows(&res, 1)?;
684
685 Ok(compat_session)
686 }
687}