1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{Clock, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports};
10use mas_storage::{
11 Page, Pagination,
12 pagination::Node,
13 upstream_oauth2::{
14 UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
15 },
16};
17use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
18use rand::RngCore;
19use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
20use sea_query_binder::SqlxBinder;
21use sqlx::{PgConnection, types::Json};
22use tracing::{Instrument, info_span};
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27 DatabaseError, DatabaseInconsistencyError,
28 filter::{Filter, StatementExt},
29 iden::UpstreamOAuthProviders,
30 pagination::QueryBuilderExt,
31 tracing::ExecuteExt,
32};
33
/// An implementation of [`UpstreamOAuthProviderRepository`] backed by a
/// PostgreSQL connection.
pub struct PgUpstreamOAuthProviderRepository<'c> {
    // Mutably borrowed connection on which every query of this repository runs
    conn: &'c mut PgConnection,
}
39
40impl<'c> PgUpstreamOAuthProviderRepository<'c> {
41 pub fn new(conn: &'c mut PgConnection) -> Self {
44 Self { conn }
45 }
46}
47
/// A raw row of the `upstream_oauth_providers` table, as selected by the
/// queries in this module.
///
/// Structured values (scope, algorithms, endpoint overrides, modes) are kept
/// as plain strings here and only parsed when the row is converted into an
/// [`UpstreamOAuthProvider`]. `#[enum_def]` additionally generates a
/// `ProviderLookupIden` enum from the field names, which the dynamically built
/// `list` query uses as column aliases.
#[derive(sqlx::FromRow)]
#[enum_def]
struct ProviderLookup {
    // Primary key; converted to a ULID for the domain model and cursors
    upstream_oauth_provider_id: Uuid,
    issuer: Option<String>,
    human_name: Option<String>,
    brand_name: Option<String>,
    // Parsed into the domain scope type on conversion
    scope: String,
    client_id: String,
    encrypted_client_secret: Option<String>,
    token_endpoint_signing_alg: Option<String>,
    token_endpoint_auth_method: String,
    id_token_signed_response_alg: String,
    fetch_userinfo: bool,
    userinfo_signed_response_alg: Option<String>,
    created_at: DateTime<Utc>,
    // `None` while the provider is enabled
    disabled_at: Option<DateTime<Utc>>,
    // JSON column, decoded directly into the claims import configuration
    claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
    jwks_uri_override: Option<String>,
    authorization_endpoint_override: Option<String>,
    token_endpoint_override: Option<String>,
    userinfo_endpoint_override: Option<String>,
    discovery_mode: String,
    pkce_mode: String,
    response_mode: Option<String>,
    // Optional JSON list of string pairs; a NULL column becomes an empty
    // list on conversion
    additional_parameters: Option<Json<Vec<(String, String)>>>,
    forward_login_hint: bool,
    on_backchannel_logout: String,
}
77
78impl Node<Ulid> for ProviderLookup {
79 fn cursor(&self) -> Ulid {
80 self.upstream_oauth_provider_id.into()
81 }
82}
83
84impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
85 type Error = DatabaseInconsistencyError;
86
87 fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
88 let id = value.upstream_oauth_provider_id.into();
89 let scope = value.scope.parse().map_err(|e| {
90 DatabaseInconsistencyError::on("upstream_oauth_providers")
91 .column("scope")
92 .row(id)
93 .source(e)
94 })?;
95 let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
96 DatabaseInconsistencyError::on("upstream_oauth_providers")
97 .column("token_endpoint_auth_method")
98 .row(id)
99 .source(e)
100 })?;
101 let token_endpoint_signing_alg = value
102 .token_endpoint_signing_alg
103 .map(|x| x.parse())
104 .transpose()
105 .map_err(|e| {
106 DatabaseInconsistencyError::on("upstream_oauth_providers")
107 .column("token_endpoint_signing_alg")
108 .row(id)
109 .source(e)
110 })?;
111 let id_token_signed_response_alg =
112 value.id_token_signed_response_alg.parse().map_err(|e| {
113 DatabaseInconsistencyError::on("upstream_oauth_providers")
114 .column("id_token_signed_response_alg")
115 .row(id)
116 .source(e)
117 })?;
118
119 let userinfo_signed_response_alg = value
120 .userinfo_signed_response_alg
121 .map(|x| x.parse())
122 .transpose()
123 .map_err(|e| {
124 DatabaseInconsistencyError::on("upstream_oauth_providers")
125 .column("userinfo_signed_response_alg")
126 .row(id)
127 .source(e)
128 })?;
129
130 let authorization_endpoint_override = value
131 .authorization_endpoint_override
132 .map(|x| x.parse())
133 .transpose()
134 .map_err(|e| {
135 DatabaseInconsistencyError::on("upstream_oauth_providers")
136 .column("authorization_endpoint_override")
137 .row(id)
138 .source(e)
139 })?;
140
141 let token_endpoint_override = value
142 .token_endpoint_override
143 .map(|x| x.parse())
144 .transpose()
145 .map_err(|e| {
146 DatabaseInconsistencyError::on("upstream_oauth_providers")
147 .column("token_endpoint_override")
148 .row(id)
149 .source(e)
150 })?;
151
152 let userinfo_endpoint_override = value
153 .userinfo_endpoint_override
154 .map(|x| x.parse())
155 .transpose()
156 .map_err(|e| {
157 DatabaseInconsistencyError::on("upstream_oauth_providers")
158 .column("userinfo_endpoint_override")
159 .row(id)
160 .source(e)
161 })?;
162
163 let jwks_uri_override = value
164 .jwks_uri_override
165 .map(|x| x.parse())
166 .transpose()
167 .map_err(|e| {
168 DatabaseInconsistencyError::on("upstream_oauth_providers")
169 .column("jwks_uri_override")
170 .row(id)
171 .source(e)
172 })?;
173
174 let discovery_mode = value.discovery_mode.parse().map_err(|e| {
175 DatabaseInconsistencyError::on("upstream_oauth_providers")
176 .column("discovery_mode")
177 .row(id)
178 .source(e)
179 })?;
180
181 let pkce_mode = value.pkce_mode.parse().map_err(|e| {
182 DatabaseInconsistencyError::on("upstream_oauth_providers")
183 .column("pkce_mode")
184 .row(id)
185 .source(e)
186 })?;
187
188 let response_mode = value
189 .response_mode
190 .map(|x| x.parse())
191 .transpose()
192 .map_err(|e| {
193 DatabaseInconsistencyError::on("upstream_oauth_providers")
194 .column("response_mode")
195 .row(id)
196 .source(e)
197 })?;
198
199 let additional_authorization_parameters = value
200 .additional_parameters
201 .map(|Json(x)| x)
202 .unwrap_or_default();
203
204 let on_backchannel_logout = value.on_backchannel_logout.parse().map_err(|e| {
205 DatabaseInconsistencyError::on("upstream_oauth_providers")
206 .column("on_backchannel_logout")
207 .row(id)
208 .source(e)
209 })?;
210
211 Ok(UpstreamOAuthProvider {
212 id,
213 issuer: value.issuer,
214 human_name: value.human_name,
215 brand_name: value.brand_name,
216 scope,
217 client_id: value.client_id,
218 encrypted_client_secret: value.encrypted_client_secret,
219 token_endpoint_auth_method,
220 token_endpoint_signing_alg,
221 id_token_signed_response_alg,
222 fetch_userinfo: value.fetch_userinfo,
223 userinfo_signed_response_alg,
224 created_at: value.created_at,
225 disabled_at: value.disabled_at,
226 claims_imports: value.claims_imports.0,
227 authorization_endpoint_override,
228 token_endpoint_override,
229 userinfo_endpoint_override,
230 jwks_uri_override,
231 discovery_mode,
232 pkce_mode,
233 response_mode,
234 additional_authorization_parameters,
235 forward_login_hint: value.forward_login_hint,
236 on_backchannel_logout,
237 })
238 }
239}
240
241impl Filter for UpstreamOAuthProviderFilter<'_> {
242 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
243 sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
244 Expr::col((
245 UpstreamOAuthProviders::Table,
246 UpstreamOAuthProviders::DisabledAt,
247 ))
248 .is_null()
249 .eq(enabled)
250 }))
251 }
252}
253
254#[async_trait]
255impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
256 type Error = DatabaseError;
257
258 #[tracing::instrument(
259 name = "db.upstream_oauth_provider.lookup",
260 skip_all,
261 fields(
262 db.query.text,
263 upstream_oauth_provider.id = %id,
264 ),
265 err,
266 )]
267 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
268 let res = sqlx::query_as!(
269 ProviderLookup,
270 r#"
271 SELECT
272 upstream_oauth_provider_id,
273 issuer,
274 human_name,
275 brand_name,
276 scope,
277 client_id,
278 encrypted_client_secret,
279 token_endpoint_signing_alg,
280 token_endpoint_auth_method,
281 id_token_signed_response_alg,
282 fetch_userinfo,
283 userinfo_signed_response_alg,
284 created_at,
285 disabled_at,
286 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
287 jwks_uri_override,
288 authorization_endpoint_override,
289 token_endpoint_override,
290 userinfo_endpoint_override,
291 discovery_mode,
292 pkce_mode,
293 response_mode,
294 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
295 forward_login_hint,
296 on_backchannel_logout
297 FROM upstream_oauth_providers
298 WHERE upstream_oauth_provider_id = $1
299 "#,
300 Uuid::from(id),
301 )
302 .traced()
303 .fetch_optional(&mut *self.conn)
304 .await?;
305
306 let res = res
307 .map(UpstreamOAuthProvider::try_from)
308 .transpose()
309 .map_err(DatabaseError::from)?;
310
311 Ok(res)
312 }
313
314 #[tracing::instrument(
315 name = "db.upstream_oauth_provider.add",
316 skip_all,
317 fields(
318 db.query.text,
319 upstream_oauth_provider.id,
320 upstream_oauth_provider.issuer = params.issuer,
321 upstream_oauth_provider.client_id = %params.client_id,
322 ),
323 err,
324 )]
325 async fn add(
326 &mut self,
327 rng: &mut (dyn RngCore + Send),
328 clock: &dyn Clock,
329 params: UpstreamOAuthProviderParams,
330 ) -> Result<UpstreamOAuthProvider, Self::Error> {
331 let created_at = clock.now();
332 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
333 tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
334
335 sqlx::query!(
336 r#"
337 INSERT INTO upstream_oauth_providers (
338 upstream_oauth_provider_id,
339 issuer,
340 human_name,
341 brand_name,
342 scope,
343 token_endpoint_auth_method,
344 token_endpoint_signing_alg,
345 id_token_signed_response_alg,
346 fetch_userinfo,
347 userinfo_signed_response_alg,
348 client_id,
349 encrypted_client_secret,
350 claims_imports,
351 authorization_endpoint_override,
352 token_endpoint_override,
353 userinfo_endpoint_override,
354 jwks_uri_override,
355 discovery_mode,
356 pkce_mode,
357 response_mode,
358 forward_login_hint,
359 on_backchannel_logout,
360 created_at
361 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
362 $12, $13, $14, $15, $16, $17, $18, $19, $20,
363 $21, $22, $23)
364 "#,
365 Uuid::from(id),
366 params.issuer.as_deref(),
367 params.human_name.as_deref(),
368 params.brand_name.as_deref(),
369 params.scope.to_string(),
370 params.token_endpoint_auth_method.to_string(),
371 params
372 .token_endpoint_signing_alg
373 .as_ref()
374 .map(ToString::to_string),
375 params.id_token_signed_response_alg.to_string(),
376 params.fetch_userinfo,
377 params
378 .userinfo_signed_response_alg
379 .as_ref()
380 .map(ToString::to_string),
381 ¶ms.client_id,
382 params.encrypted_client_secret.as_deref(),
383 Json(¶ms.claims_imports) as _,
384 params
385 .authorization_endpoint_override
386 .as_ref()
387 .map(ToString::to_string),
388 params
389 .token_endpoint_override
390 .as_ref()
391 .map(ToString::to_string),
392 params
393 .userinfo_endpoint_override
394 .as_ref()
395 .map(ToString::to_string),
396 params.jwks_uri_override.as_ref().map(ToString::to_string),
397 params.discovery_mode.as_str(),
398 params.pkce_mode.as_str(),
399 params.response_mode.as_ref().map(ToString::to_string),
400 params.forward_login_hint,
401 params.on_backchannel_logout.as_str(),
402 created_at,
403 )
404 .traced()
405 .execute(&mut *self.conn)
406 .await?;
407
408 Ok(UpstreamOAuthProvider {
409 id,
410 issuer: params.issuer,
411 human_name: params.human_name,
412 brand_name: params.brand_name,
413 scope: params.scope,
414 client_id: params.client_id,
415 encrypted_client_secret: params.encrypted_client_secret,
416 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
417 token_endpoint_auth_method: params.token_endpoint_auth_method,
418 id_token_signed_response_alg: params.id_token_signed_response_alg,
419 fetch_userinfo: params.fetch_userinfo,
420 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
421 created_at,
422 disabled_at: None,
423 claims_imports: params.claims_imports,
424 authorization_endpoint_override: params.authorization_endpoint_override,
425 token_endpoint_override: params.token_endpoint_override,
426 userinfo_endpoint_override: params.userinfo_endpoint_override,
427 jwks_uri_override: params.jwks_uri_override,
428 discovery_mode: params.discovery_mode,
429 pkce_mode: params.pkce_mode,
430 response_mode: params.response_mode,
431 additional_authorization_parameters: params.additional_authorization_parameters,
432 on_backchannel_logout: params.on_backchannel_logout,
433 forward_login_hint: params.forward_login_hint,
434 })
435 }
436
437 #[tracing::instrument(
438 name = "db.upstream_oauth_provider.delete_by_id",
439 skip_all,
440 fields(
441 db.query.text,
442 upstream_oauth_provider.id = %id,
443 ),
444 err,
445 )]
446 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
447 {
450 let span = info_span!(
451 "db.oauth2_client.delete_by_id.authorization_sessions",
452 upstream_oauth_provider.id = %id,
453 { DB_QUERY_TEXT } = tracing::field::Empty,
454 );
455 sqlx::query!(
456 r#"
457 DELETE FROM upstream_oauth_authorization_sessions
458 WHERE upstream_oauth_provider_id = $1
459 "#,
460 Uuid::from(id),
461 )
462 .record(&span)
463 .execute(&mut *self.conn)
464 .instrument(span)
465 .await?;
466 }
467
468 {
471 let span = info_span!(
472 "db.oauth2_client.delete_by_id.links",
473 upstream_oauth_provider.id = %id,
474 { DB_QUERY_TEXT } = tracing::field::Empty,
475 );
476 sqlx::query!(
477 r#"
478 DELETE FROM upstream_oauth_links
479 WHERE upstream_oauth_provider_id = $1
480 "#,
481 Uuid::from(id),
482 )
483 .record(&span)
484 .execute(&mut *self.conn)
485 .instrument(span)
486 .await?;
487 }
488
489 let res = sqlx::query!(
490 r#"
491 DELETE FROM upstream_oauth_providers
492 WHERE upstream_oauth_provider_id = $1
493 "#,
494 Uuid::from(id),
495 )
496 .traced()
497 .execute(&mut *self.conn)
498 .await?;
499
500 DatabaseError::ensure_affected_rows(&res, 1)
501 }
502
503 #[tracing::instrument(
504 name = "db.upstream_oauth_provider.add",
505 skip_all,
506 fields(
507 db.query.text,
508 upstream_oauth_provider.id = %id,
509 upstream_oauth_provider.issuer = params.issuer,
510 upstream_oauth_provider.client_id = %params.client_id,
511 ),
512 err,
513 )]
514 async fn upsert(
515 &mut self,
516 clock: &dyn Clock,
517 id: Ulid,
518 params: UpstreamOAuthProviderParams,
519 ) -> Result<UpstreamOAuthProvider, Self::Error> {
520 let created_at = clock.now();
521
522 let created_at = sqlx::query_scalar!(
523 r#"
524 INSERT INTO upstream_oauth_providers (
525 upstream_oauth_provider_id,
526 issuer,
527 human_name,
528 brand_name,
529 scope,
530 token_endpoint_auth_method,
531 token_endpoint_signing_alg,
532 id_token_signed_response_alg,
533 fetch_userinfo,
534 userinfo_signed_response_alg,
535 client_id,
536 encrypted_client_secret,
537 claims_imports,
538 authorization_endpoint_override,
539 token_endpoint_override,
540 userinfo_endpoint_override,
541 jwks_uri_override,
542 discovery_mode,
543 pkce_mode,
544 response_mode,
545 additional_parameters,
546 forward_login_hint,
547 ui_order,
548 on_backchannel_logout,
549 created_at
550 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
551 $11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
552 $21, $22, $23, $24, $25)
553 ON CONFLICT (upstream_oauth_provider_id)
554 DO UPDATE
555 SET
556 issuer = EXCLUDED.issuer,
557 human_name = EXCLUDED.human_name,
558 brand_name = EXCLUDED.brand_name,
559 scope = EXCLUDED.scope,
560 token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
561 token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
562 id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,
563 fetch_userinfo = EXCLUDED.fetch_userinfo,
564 userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,
565 disabled_at = NULL,
566 client_id = EXCLUDED.client_id,
567 encrypted_client_secret = EXCLUDED.encrypted_client_secret,
568 claims_imports = EXCLUDED.claims_imports,
569 authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
570 token_endpoint_override = EXCLUDED.token_endpoint_override,
571 userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
572 jwks_uri_override = EXCLUDED.jwks_uri_override,
573 discovery_mode = EXCLUDED.discovery_mode,
574 pkce_mode = EXCLUDED.pkce_mode,
575 response_mode = EXCLUDED.response_mode,
576 additional_parameters = EXCLUDED.additional_parameters,
577 forward_login_hint = EXCLUDED.forward_login_hint,
578 ui_order = EXCLUDED.ui_order,
579 on_backchannel_logout = EXCLUDED.on_backchannel_logout
580 RETURNING created_at
581 "#,
582 Uuid::from(id),
583 params.issuer.as_deref(),
584 params.human_name.as_deref(),
585 params.brand_name.as_deref(),
586 params.scope.to_string(),
587 params.token_endpoint_auth_method.to_string(),
588 params
589 .token_endpoint_signing_alg
590 .as_ref()
591 .map(ToString::to_string),
592 params.id_token_signed_response_alg.to_string(),
593 params.fetch_userinfo,
594 params
595 .userinfo_signed_response_alg
596 .as_ref()
597 .map(ToString::to_string),
598 ¶ms.client_id,
599 params.encrypted_client_secret.as_deref(),
600 Json(¶ms.claims_imports) as _,
601 params
602 .authorization_endpoint_override
603 .as_ref()
604 .map(ToString::to_string),
605 params
606 .token_endpoint_override
607 .as_ref()
608 .map(ToString::to_string),
609 params
610 .userinfo_endpoint_override
611 .as_ref()
612 .map(ToString::to_string),
613 params.jwks_uri_override.as_ref().map(ToString::to_string),
614 params.discovery_mode.as_str(),
615 params.pkce_mode.as_str(),
616 params.response_mode.as_ref().map(ToString::to_string),
617 Json(¶ms.additional_authorization_parameters) as _,
618 params.forward_login_hint,
619 params.ui_order,
620 params.on_backchannel_logout.as_str(),
621 created_at,
622 )
623 .traced()
624 .fetch_one(&mut *self.conn)
625 .await?;
626
627 Ok(UpstreamOAuthProvider {
628 id,
629 issuer: params.issuer,
630 human_name: params.human_name,
631 brand_name: params.brand_name,
632 scope: params.scope,
633 client_id: params.client_id,
634 encrypted_client_secret: params.encrypted_client_secret,
635 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
636 token_endpoint_auth_method: params.token_endpoint_auth_method,
637 id_token_signed_response_alg: params.id_token_signed_response_alg,
638 fetch_userinfo: params.fetch_userinfo,
639 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
640 created_at,
641 disabled_at: None,
642 claims_imports: params.claims_imports,
643 authorization_endpoint_override: params.authorization_endpoint_override,
644 token_endpoint_override: params.token_endpoint_override,
645 userinfo_endpoint_override: params.userinfo_endpoint_override,
646 jwks_uri_override: params.jwks_uri_override,
647 discovery_mode: params.discovery_mode,
648 pkce_mode: params.pkce_mode,
649 response_mode: params.response_mode,
650 additional_authorization_parameters: params.additional_authorization_parameters,
651 forward_login_hint: params.forward_login_hint,
652 on_backchannel_logout: params.on_backchannel_logout,
653 })
654 }
655
656 #[tracing::instrument(
657 name = "db.upstream_oauth_provider.disable",
658 skip_all,
659 fields(
660 db.query.text,
661 %upstream_oauth_provider.id,
662 ),
663 err,
664 )]
665 async fn disable(
666 &mut self,
667 clock: &dyn Clock,
668 mut upstream_oauth_provider: UpstreamOAuthProvider,
669 ) -> Result<UpstreamOAuthProvider, Self::Error> {
670 let disabled_at = clock.now();
671 let res = sqlx::query!(
672 r#"
673 UPDATE upstream_oauth_providers
674 SET disabled_at = $2
675 WHERE upstream_oauth_provider_id = $1
676 "#,
677 Uuid::from(upstream_oauth_provider.id),
678 disabled_at,
679 )
680 .traced()
681 .execute(&mut *self.conn)
682 .await?;
683
684 DatabaseError::ensure_affected_rows(&res, 1)?;
685
686 upstream_oauth_provider.disabled_at = Some(disabled_at);
687
688 Ok(upstream_oauth_provider)
689 }
690
691 #[tracing::instrument(
692 name = "db.upstream_oauth_provider.list",
693 skip_all,
694 fields(
695 db.query.text,
696 ),
697 err,
698 )]
699 async fn list(
700 &mut self,
701 filter: UpstreamOAuthProviderFilter<'_>,
702 pagination: Pagination,
703 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
704 let (sql, arguments) = Query::select()
705 .expr_as(
706 Expr::col((
707 UpstreamOAuthProviders::Table,
708 UpstreamOAuthProviders::UpstreamOAuthProviderId,
709 )),
710 ProviderLookupIden::UpstreamOauthProviderId,
711 )
712 .expr_as(
713 Expr::col((
714 UpstreamOAuthProviders::Table,
715 UpstreamOAuthProviders::Issuer,
716 )),
717 ProviderLookupIden::Issuer,
718 )
719 .expr_as(
720 Expr::col((
721 UpstreamOAuthProviders::Table,
722 UpstreamOAuthProviders::HumanName,
723 )),
724 ProviderLookupIden::HumanName,
725 )
726 .expr_as(
727 Expr::col((
728 UpstreamOAuthProviders::Table,
729 UpstreamOAuthProviders::BrandName,
730 )),
731 ProviderLookupIden::BrandName,
732 )
733 .expr_as(
734 Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
735 ProviderLookupIden::Scope,
736 )
737 .expr_as(
738 Expr::col((
739 UpstreamOAuthProviders::Table,
740 UpstreamOAuthProviders::ClientId,
741 )),
742 ProviderLookupIden::ClientId,
743 )
744 .expr_as(
745 Expr::col((
746 UpstreamOAuthProviders::Table,
747 UpstreamOAuthProviders::EncryptedClientSecret,
748 )),
749 ProviderLookupIden::EncryptedClientSecret,
750 )
751 .expr_as(
752 Expr::col((
753 UpstreamOAuthProviders::Table,
754 UpstreamOAuthProviders::TokenEndpointSigningAlg,
755 )),
756 ProviderLookupIden::TokenEndpointSigningAlg,
757 )
758 .expr_as(
759 Expr::col((
760 UpstreamOAuthProviders::Table,
761 UpstreamOAuthProviders::TokenEndpointAuthMethod,
762 )),
763 ProviderLookupIden::TokenEndpointAuthMethod,
764 )
765 .expr_as(
766 Expr::col((
767 UpstreamOAuthProviders::Table,
768 UpstreamOAuthProviders::IdTokenSignedResponseAlg,
769 )),
770 ProviderLookupIden::IdTokenSignedResponseAlg,
771 )
772 .expr_as(
773 Expr::col((
774 UpstreamOAuthProviders::Table,
775 UpstreamOAuthProviders::FetchUserinfo,
776 )),
777 ProviderLookupIden::FetchUserinfo,
778 )
779 .expr_as(
780 Expr::col((
781 UpstreamOAuthProviders::Table,
782 UpstreamOAuthProviders::UserinfoSignedResponseAlg,
783 )),
784 ProviderLookupIden::UserinfoSignedResponseAlg,
785 )
786 .expr_as(
787 Expr::col((
788 UpstreamOAuthProviders::Table,
789 UpstreamOAuthProviders::CreatedAt,
790 )),
791 ProviderLookupIden::CreatedAt,
792 )
793 .expr_as(
794 Expr::col((
795 UpstreamOAuthProviders::Table,
796 UpstreamOAuthProviders::DisabledAt,
797 )),
798 ProviderLookupIden::DisabledAt,
799 )
800 .expr_as(
801 Expr::col((
802 UpstreamOAuthProviders::Table,
803 UpstreamOAuthProviders::ClaimsImports,
804 )),
805 ProviderLookupIden::ClaimsImports,
806 )
807 .expr_as(
808 Expr::col((
809 UpstreamOAuthProviders::Table,
810 UpstreamOAuthProviders::JwksUriOverride,
811 )),
812 ProviderLookupIden::JwksUriOverride,
813 )
814 .expr_as(
815 Expr::col((
816 UpstreamOAuthProviders::Table,
817 UpstreamOAuthProviders::TokenEndpointOverride,
818 )),
819 ProviderLookupIden::TokenEndpointOverride,
820 )
821 .expr_as(
822 Expr::col((
823 UpstreamOAuthProviders::Table,
824 UpstreamOAuthProviders::AuthorizationEndpointOverride,
825 )),
826 ProviderLookupIden::AuthorizationEndpointOverride,
827 )
828 .expr_as(
829 Expr::col((
830 UpstreamOAuthProviders::Table,
831 UpstreamOAuthProviders::UserinfoEndpointOverride,
832 )),
833 ProviderLookupIden::UserinfoEndpointOverride,
834 )
835 .expr_as(
836 Expr::col((
837 UpstreamOAuthProviders::Table,
838 UpstreamOAuthProviders::DiscoveryMode,
839 )),
840 ProviderLookupIden::DiscoveryMode,
841 )
842 .expr_as(
843 Expr::col((
844 UpstreamOAuthProviders::Table,
845 UpstreamOAuthProviders::PkceMode,
846 )),
847 ProviderLookupIden::PkceMode,
848 )
849 .expr_as(
850 Expr::col((
851 UpstreamOAuthProviders::Table,
852 UpstreamOAuthProviders::ResponseMode,
853 )),
854 ProviderLookupIden::ResponseMode,
855 )
856 .expr_as(
857 Expr::col((
858 UpstreamOAuthProviders::Table,
859 UpstreamOAuthProviders::AdditionalParameters,
860 )),
861 ProviderLookupIden::AdditionalParameters,
862 )
863 .expr_as(
864 Expr::col((
865 UpstreamOAuthProviders::Table,
866 UpstreamOAuthProviders::ForwardLoginHint,
867 )),
868 ProviderLookupIden::ForwardLoginHint,
869 )
870 .expr_as(
871 Expr::col((
872 UpstreamOAuthProviders::Table,
873 UpstreamOAuthProviders::OnBackchannelLogout,
874 )),
875 ProviderLookupIden::OnBackchannelLogout,
876 )
877 .from(UpstreamOAuthProviders::Table)
878 .apply_filter(filter)
879 .generate_pagination(
880 (
881 UpstreamOAuthProviders::Table,
882 UpstreamOAuthProviders::UpstreamOAuthProviderId,
883 ),
884 pagination,
885 )
886 .build_sqlx(PostgresQueryBuilder);
887
888 let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
889 .traced()
890 .fetch_all(&mut *self.conn)
891 .await?;
892
893 let page = pagination
894 .process(edges)
895 .try_map(UpstreamOAuthProvider::try_from)?;
896
897 return Ok(page);
898 }
899
900 #[tracing::instrument(
901 name = "db.upstream_oauth_provider.count",
902 skip_all,
903 fields(
904 db.query.text,
905 ),
906 err,
907 )]
908 async fn count(
909 &mut self,
910 filter: UpstreamOAuthProviderFilter<'_>,
911 ) -> Result<usize, Self::Error> {
912 let (sql, arguments) = Query::select()
913 .expr(
914 Expr::col((
915 UpstreamOAuthProviders::Table,
916 UpstreamOAuthProviders::UpstreamOAuthProviderId,
917 ))
918 .count(),
919 )
920 .from(UpstreamOAuthProviders::Table)
921 .apply_filter(filter)
922 .build_sqlx(PostgresQueryBuilder);
923
924 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
925 .traced()
926 .fetch_one(&mut *self.conn)
927 .await?;
928
929 count
930 .try_into()
931 .map_err(DatabaseError::to_invalid_operation)
932 }
933
934 #[tracing::instrument(
935 name = "db.upstream_oauth_provider.all_enabled",
936 skip_all,
937 fields(
938 db.query.text,
939 ),
940 err,
941 )]
942 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
943 let res = sqlx::query_as!(
944 ProviderLookup,
945 r#"
946 SELECT
947 upstream_oauth_provider_id,
948 issuer,
949 human_name,
950 brand_name,
951 scope,
952 client_id,
953 encrypted_client_secret,
954 token_endpoint_signing_alg,
955 token_endpoint_auth_method,
956 id_token_signed_response_alg,
957 fetch_userinfo,
958 userinfo_signed_response_alg,
959 created_at,
960 disabled_at,
961 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
962 jwks_uri_override,
963 authorization_endpoint_override,
964 token_endpoint_override,
965 userinfo_endpoint_override,
966 discovery_mode,
967 pkce_mode,
968 response_mode,
969 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
970 forward_login_hint,
971 on_backchannel_logout
972 FROM upstream_oauth_providers
973 WHERE disabled_at IS NULL
974 ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
975 "#,
976 )
977 .traced()
978 .fetch_all(&mut *self.conn)
979 .await?;
980
981 let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
982 Ok(res?)
983 }
984}