1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 Clock, UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState,
11 UpstreamOAuthLink, UpstreamOAuthProvider,
12};
13use mas_storage::{
14 Page, Pagination,
15 pagination::Node,
16 upstream_oauth2::{UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository},
17};
18use rand::RngCore;
19use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
20use sea_query_binder::SqlxBinder;
21use sqlx::PgConnection;
22use ulid::Ulid;
23use uuid::Uuid;
24
25use crate::{
26 DatabaseError, DatabaseInconsistencyError,
27 filter::{Filter, StatementExt},
28 iden::UpstreamOAuthAuthorizationSessions,
29 pagination::QueryBuilderExt,
30 tracing::ExecuteExt,
31};
32
33impl Filter for UpstreamOAuthSessionFilter<'_> {
34 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
35 sea_query::Condition::all()
36 .add_option(self.provider().map(|provider| {
37 Expr::col((
38 UpstreamOAuthAuthorizationSessions::Table,
39 UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
40 ))
41 .eq(Uuid::from(provider.id))
42 }))
43 .add_option(self.sub_claim().map(|sub| {
44 Expr::col((
45 UpstreamOAuthAuthorizationSessions::Table,
46 UpstreamOAuthAuthorizationSessions::IdTokenClaims,
47 ))
48 .cast_json_field("sub")
49 .eq(sub)
50 }))
51 .add_option(self.sid_claim().map(|sid| {
52 Expr::col((
53 UpstreamOAuthAuthorizationSessions::Table,
54 UpstreamOAuthAuthorizationSessions::IdTokenClaims,
55 ))
56 .cast_json_field("sid")
57 .eq(sid)
58 }))
59 }
60}
61
/// An implementation of [`UpstreamOAuthSessionRepository`] backed by a
/// PostgreSQL connection
pub struct PgUpstreamOAuthSessionRepository<'c> {
    /// Connection on which every query of this repository is executed
    conn: &'c mut PgConnection,
}
67
impl<'c> PgUpstreamOAuthSessionRepository<'c> {
    /// Create a new [`PgUpstreamOAuthSessionRepository`] which mutably borrows
    /// the given PostgreSQL connection for as long as the repository lives
    pub fn new(conn: &'c mut PgConnection) -> Self {
        Self { conn }
    }
}
75
/// Raw row of the `upstream_oauth_authorization_sessions` table, as read by
/// the queries of this module.
///
/// `#[enum_def]` generates the companion `SessionLookupIden` enum, whose
/// variants are used as column aliases when the `list` query is built with
/// `sea_query`. Rows are turned into the domain type through the
/// `TryFrom<SessionLookup>` implementation, which derives the session state
/// from which of the nullable columns are set.
#[derive(sqlx::FromRow)]
#[enum_def]
struct SessionLookup {
    // Primary key of the session; also used as the pagination cursor
    upstream_oauth_authorization_session_id: Uuid,
    // Provider the session was started with
    upstream_oauth_provider_id: Uuid,
    // Link the session was completed with, if any
    upstream_oauth_link_id: Option<Uuid>,
    // Stored `state` value; exposed as `state_str` on the domain type
    state: String,
    code_challenge_verifier: Option<String>,
    nonce: Option<String>,
    // Data recorded when the session is completed
    id_token: Option<String>,
    id_token_claims: Option<serde_json::Value>,
    userinfo: Option<serde_json::Value>,
    // Timestamps from which the session state is derived
    created_at: DateTime<Utc>,
    completed_at: Option<DateTime<Utc>>,
    consumed_at: Option<DateTime<Utc>>,
    extra_callback_parameters: Option<serde_json::Value>,
    unlinked_at: Option<DateTime<Utc>>,
}
94
95impl Node<Ulid> for SessionLookup {
96 fn cursor(&self) -> Ulid {
97 self.upstream_oauth_authorization_session_id.into()
98 }
99}
100
101impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
102 type Error = DatabaseInconsistencyError;
103
104 fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
105 let id = value.upstream_oauth_authorization_session_id.into();
106 let state = match (
107 value.upstream_oauth_link_id,
108 value.id_token,
109 value.id_token_claims,
110 value.extra_callback_parameters,
111 value.userinfo,
112 value.completed_at,
113 value.consumed_at,
114 value.unlinked_at,
115 ) {
116 (None, None, None, None, None, None, None, None) => {
117 UpstreamOAuthAuthorizationSessionState::Pending
118 }
119 (
120 Some(link_id),
121 id_token,
122 id_token_claims,
123 extra_callback_parameters,
124 userinfo,
125 Some(completed_at),
126 None,
127 None,
128 ) => UpstreamOAuthAuthorizationSessionState::Completed {
129 completed_at,
130 link_id: link_id.into(),
131 id_token,
132 id_token_claims,
133 extra_callback_parameters,
134 userinfo,
135 },
136 (
137 Some(link_id),
138 id_token,
139 id_token_claims,
140 extra_callback_parameters,
141 userinfo,
142 Some(completed_at),
143 Some(consumed_at),
144 None,
145 ) => UpstreamOAuthAuthorizationSessionState::Consumed {
146 completed_at,
147 link_id: link_id.into(),
148 id_token,
149 id_token_claims,
150 extra_callback_parameters,
151 userinfo,
152 consumed_at,
153 },
154 (
155 _,
156 id_token,
157 id_token_claims,
158 _,
159 _,
160 Some(completed_at),
161 consumed_at,
162 Some(unlinked_at),
163 ) => UpstreamOAuthAuthorizationSessionState::Unlinked {
164 completed_at,
165 id_token,
166 id_token_claims,
167 consumed_at,
168 unlinked_at,
169 },
170 _ => {
171 return Err(DatabaseInconsistencyError::on(
172 "upstream_oauth_authorization_sessions",
173 )
174 .row(id));
175 }
176 };
177
178 Ok(Self {
179 id,
180 provider_id: value.upstream_oauth_provider_id.into(),
181 state_str: value.state,
182 nonce: value.nonce,
183 code_challenge_verifier: value.code_challenge_verifier,
184 created_at: value.created_at,
185 state,
186 })
187 }
188}
189
190#[async_trait]
191impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
192 type Error = DatabaseError;
193
194 #[tracing::instrument(
195 name = "db.upstream_oauth_authorization_session.lookup",
196 skip_all,
197 fields(
198 db.query.text,
199 upstream_oauth_provider.id = %id,
200 ),
201 err,
202 )]
203 async fn lookup(
204 &mut self,
205 id: Ulid,
206 ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
207 let res = sqlx::query_as!(
208 SessionLookup,
209 r#"
210 SELECT
211 upstream_oauth_authorization_session_id,
212 upstream_oauth_provider_id,
213 upstream_oauth_link_id,
214 state,
215 code_challenge_verifier,
216 nonce,
217 id_token,
218 id_token_claims,
219 extra_callback_parameters,
220 userinfo,
221 created_at,
222 completed_at,
223 consumed_at,
224 unlinked_at
225 FROM upstream_oauth_authorization_sessions
226 WHERE upstream_oauth_authorization_session_id = $1
227 "#,
228 Uuid::from(id),
229 )
230 .traced()
231 .fetch_optional(&mut *self.conn)
232 .await?;
233
234 let Some(res) = res else { return Ok(None) };
235
236 Ok(Some(res.try_into()?))
237 }
238
239 #[tracing::instrument(
240 name = "db.upstream_oauth_authorization_session.add",
241 skip_all,
242 fields(
243 db.query.text,
244 %upstream_oauth_provider.id,
245 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
246 %upstream_oauth_provider.client_id,
247 upstream_oauth_authorization_session.id,
248 ),
249 err,
250 )]
251 async fn add(
252 &mut self,
253 rng: &mut (dyn RngCore + Send),
254 clock: &dyn Clock,
255 upstream_oauth_provider: &UpstreamOAuthProvider,
256 state_str: String,
257 code_challenge_verifier: Option<String>,
258 nonce: Option<String>,
259 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
260 let created_at = clock.now();
261 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
262 tracing::Span::current().record(
263 "upstream_oauth_authorization_session.id",
264 tracing::field::display(id),
265 );
266
267 sqlx::query!(
268 r#"
269 INSERT INTO upstream_oauth_authorization_sessions (
270 upstream_oauth_authorization_session_id,
271 upstream_oauth_provider_id,
272 state,
273 code_challenge_verifier,
274 nonce,
275 created_at,
276 completed_at,
277 consumed_at,
278 id_token,
279 userinfo
280 ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
281 "#,
282 Uuid::from(id),
283 Uuid::from(upstream_oauth_provider.id),
284 &state_str,
285 code_challenge_verifier.as_deref(),
286 nonce,
287 created_at,
288 )
289 .traced()
290 .execute(&mut *self.conn)
291 .await?;
292
293 Ok(UpstreamOAuthAuthorizationSession {
294 id,
295 state: UpstreamOAuthAuthorizationSessionState::default(),
296 provider_id: upstream_oauth_provider.id,
297 state_str,
298 code_challenge_verifier,
299 nonce,
300 created_at,
301 })
302 }
303
304 #[tracing::instrument(
305 name = "db.upstream_oauth_authorization_session.complete_with_link",
306 skip_all,
307 fields(
308 db.query.text,
309 %upstream_oauth_authorization_session.id,
310 %upstream_oauth_link.id,
311 ),
312 err,
313 )]
314 async fn complete_with_link(
315 &mut self,
316 clock: &dyn Clock,
317 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
318 upstream_oauth_link: &UpstreamOAuthLink,
319 id_token: Option<String>,
320 id_token_claims: Option<serde_json::Value>,
321 extra_callback_parameters: Option<serde_json::Value>,
322 userinfo: Option<serde_json::Value>,
323 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
324 let completed_at = clock.now();
325
326 sqlx::query!(
327 r#"
328 UPDATE upstream_oauth_authorization_sessions
329 SET upstream_oauth_link_id = $1
330 , completed_at = $2
331 , id_token = $3
332 , id_token_claims = $4
333 , extra_callback_parameters = $5
334 , userinfo = $6
335 WHERE upstream_oauth_authorization_session_id = $7
336 "#,
337 Uuid::from(upstream_oauth_link.id),
338 completed_at,
339 id_token,
340 id_token_claims,
341 extra_callback_parameters,
342 userinfo,
343 Uuid::from(upstream_oauth_authorization_session.id),
344 )
345 .traced()
346 .execute(&mut *self.conn)
347 .await?;
348
349 let upstream_oauth_authorization_session = upstream_oauth_authorization_session
350 .complete(
351 completed_at,
352 upstream_oauth_link,
353 id_token,
354 id_token_claims,
355 extra_callback_parameters,
356 userinfo,
357 )
358 .map_err(DatabaseError::to_invalid_operation)?;
359
360 Ok(upstream_oauth_authorization_session)
361 }
362
363 #[tracing::instrument(
365 name = "db.upstream_oauth_authorization_session.consume",
366 skip_all,
367 fields(
368 db.query.text,
369 %upstream_oauth_authorization_session.id,
370 ),
371 err,
372 )]
373 async fn consume(
374 &mut self,
375 clock: &dyn Clock,
376 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
377 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
378 let consumed_at = clock.now();
379 sqlx::query!(
380 r#"
381 UPDATE upstream_oauth_authorization_sessions
382 SET consumed_at = $1
383 WHERE upstream_oauth_authorization_session_id = $2
384 "#,
385 consumed_at,
386 Uuid::from(upstream_oauth_authorization_session.id),
387 )
388 .traced()
389 .execute(&mut *self.conn)
390 .await?;
391
392 let upstream_oauth_authorization_session = upstream_oauth_authorization_session
393 .consume(consumed_at)
394 .map_err(DatabaseError::to_invalid_operation)?;
395
396 Ok(upstream_oauth_authorization_session)
397 }
398
399 #[tracing::instrument(
400 name = "db.upstream_oauth_authorization_session.list",
401 skip_all,
402 fields(
403 db.query.text,
404 ),
405 err,
406 )]
407 async fn list(
408 &mut self,
409 filter: UpstreamOAuthSessionFilter<'_>,
410 pagination: Pagination,
411 ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error> {
412 let (sql, arguments) = Query::select()
413 .expr_as(
414 Expr::col((
415 UpstreamOAuthAuthorizationSessions::Table,
416 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
417 )),
418 SessionLookupIden::UpstreamOauthAuthorizationSessionId,
419 )
420 .expr_as(
421 Expr::col((
422 UpstreamOAuthAuthorizationSessions::Table,
423 UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
424 )),
425 SessionLookupIden::UpstreamOauthProviderId,
426 )
427 .expr_as(
428 Expr::col((
429 UpstreamOAuthAuthorizationSessions::Table,
430 UpstreamOAuthAuthorizationSessions::UpstreamOAuthLinkId,
431 )),
432 SessionLookupIden::UpstreamOauthLinkId,
433 )
434 .expr_as(
435 Expr::col((
436 UpstreamOAuthAuthorizationSessions::Table,
437 UpstreamOAuthAuthorizationSessions::State,
438 )),
439 SessionLookupIden::State,
440 )
441 .expr_as(
442 Expr::col((
443 UpstreamOAuthAuthorizationSessions::Table,
444 UpstreamOAuthAuthorizationSessions::CodeChallengeVerifier,
445 )),
446 SessionLookupIden::CodeChallengeVerifier,
447 )
448 .expr_as(
449 Expr::col((
450 UpstreamOAuthAuthorizationSessions::Table,
451 UpstreamOAuthAuthorizationSessions::Nonce,
452 )),
453 SessionLookupIden::Nonce,
454 )
455 .expr_as(
456 Expr::col((
457 UpstreamOAuthAuthorizationSessions::Table,
458 UpstreamOAuthAuthorizationSessions::IdToken,
459 )),
460 SessionLookupIden::IdToken,
461 )
462 .expr_as(
463 Expr::col((
464 UpstreamOAuthAuthorizationSessions::Table,
465 UpstreamOAuthAuthorizationSessions::IdTokenClaims,
466 )),
467 SessionLookupIden::IdTokenClaims,
468 )
469 .expr_as(
470 Expr::col((
471 UpstreamOAuthAuthorizationSessions::Table,
472 UpstreamOAuthAuthorizationSessions::ExtraCallbackParameters,
473 )),
474 SessionLookupIden::ExtraCallbackParameters,
475 )
476 .expr_as(
477 Expr::col((
478 UpstreamOAuthAuthorizationSessions::Table,
479 UpstreamOAuthAuthorizationSessions::Userinfo,
480 )),
481 SessionLookupIden::Userinfo,
482 )
483 .expr_as(
484 Expr::col((
485 UpstreamOAuthAuthorizationSessions::Table,
486 UpstreamOAuthAuthorizationSessions::CreatedAt,
487 )),
488 SessionLookupIden::CreatedAt,
489 )
490 .expr_as(
491 Expr::col((
492 UpstreamOAuthAuthorizationSessions::Table,
493 UpstreamOAuthAuthorizationSessions::CompletedAt,
494 )),
495 SessionLookupIden::CompletedAt,
496 )
497 .expr_as(
498 Expr::col((
499 UpstreamOAuthAuthorizationSessions::Table,
500 UpstreamOAuthAuthorizationSessions::ConsumedAt,
501 )),
502 SessionLookupIden::ConsumedAt,
503 )
504 .expr_as(
505 Expr::col((
506 UpstreamOAuthAuthorizationSessions::Table,
507 UpstreamOAuthAuthorizationSessions::UnlinkedAt,
508 )),
509 SessionLookupIden::UnlinkedAt,
510 )
511 .from(UpstreamOAuthAuthorizationSessions::Table)
512 .apply_filter(filter)
513 .generate_pagination(
514 (
515 UpstreamOAuthAuthorizationSessions::Table,
516 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
517 ),
518 pagination,
519 )
520 .build_sqlx(PostgresQueryBuilder);
521
522 let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
523 .traced()
524 .fetch_all(&mut *self.conn)
525 .await?;
526
527 let page = pagination
528 .process(edges)
529 .try_map(UpstreamOAuthAuthorizationSession::try_from)?;
530
531 Ok(page)
532 }
533
534 #[tracing::instrument(
535 name = "db.upstream_oauth_authorization_session.count",
536 skip_all,
537 fields(
538 db.query.text,
539 ),
540 err,
541 )]
542 async fn count(
543 &mut self,
544 filter: UpstreamOAuthSessionFilter<'_>,
545 ) -> Result<usize, Self::Error> {
546 let (sql, arguments) = Query::select()
547 .expr(
548 Expr::col((
549 UpstreamOAuthAuthorizationSessions::Table,
550 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
551 ))
552 .count(),
553 )
554 .from(UpstreamOAuthAuthorizationSessions::Table)
555 .apply_filter(filter)
556 .build_sqlx(PostgresQueryBuilder);
557
558 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
559 .traced()
560 .fetch_one(&mut *self.conn)
561 .await?;
562
563 count
564 .try_into()
565 .map_err(DatabaseError::to_invalid_operation)
566 }
567}