Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions auth/server/src/api/named/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ pub async fn google_login<I: AuthImpl>(
.context("Google provider not available")
.status_code(StatusCode::UNAUTHORIZED)?;

let (state, uri) =
let (state, nonce, uri) =
provider.get_state_and_login_redirect_url(redirect).await;

session.insert_google_login(&state).await?;
session.insert_google_login(&state, &nonce).await?;

Ok(Redirect::to(&uri))
}
Expand Down Expand Up @@ -91,10 +91,10 @@ pub async fn google_link<I: AuthImpl>(
.context("Google provider not available")
.status_code(StatusCode::UNAUTHORIZED)?;

let (state, uri) =
let (state, nonce, uri) =
provider.get_state_and_login_redirect_url(None).await;

session.insert_google_link(&user_id, &state).await?;
session.insert_google_link(&user_id, &state, &nonce).await?;

info!(
user_id = user.id(),
Expand Down Expand Up @@ -146,7 +146,7 @@ pub async fn google_callback<I: AuthImpl>(
.await;
}

let state = session.retrieve_google_login().await?;
let (state, nonce) = session.retrieve_google_login().await?;

if client_state != state {
return Err(
Expand All @@ -155,8 +155,7 @@ pub async fn google_callback<I: AuthImpl>(
);
}

let token = provider.get_access_token(&code).await?;
let google_user = provider.get_google_user(&token.id_token)?;
let google_user = provider.get_google_user(&code, &nonce).await?;
let google_id = google_user.id;
let avatar_url = google_user.picture;

Expand Down Expand Up @@ -229,7 +228,7 @@ pub async fn google_callback<I: AuthImpl>(
async fn link_google_callback<I: AuthImpl>(
auth: &I,
provider: &GoogleProvider,
(user_id, state): (String, String),
(user_id, state, nonce): (String, String, String),
client_state: String,
code: String,
) -> mogh_error::Result<Redirect> {
Expand All @@ -239,9 +238,7 @@ async fn link_google_callback<I: AuthImpl>(
);
}

let token = provider.get_access_token(&code).await?;

let google_user = provider.get_google_user(&token.id_token)?;
let google_user = provider.get_google_user(&code, &nonce).await?;
let google_id = google_user.id;
let avatar_url = google_user.picture;

Expand Down
129 changes: 67 additions & 62 deletions auth/server/src/provider/named/google.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
use std::sync::OnceLock;

use anyhow::Context;
use jsonwebtoken::dangerous::insecure_decode;
use mogh_auth_client::config::NamedOauthConfig;
use serde::{Deserialize, de::DeserializeOwned};
use openidconnect::{
ClientId, ClientSecret, IssuerUrl, Nonce, RedirectUrl,
core::CoreProviderMetadata,
reqwest as oidc_reqwest,
};
use serde::Deserialize;
use tracing::warn;

use crate::{
provider::named::{STATE_PREFIX_LENGTH, handle_response},
provider::named::STATE_PREFIX_LENGTH,
rand::random_string,
};

Expand All @@ -24,12 +28,10 @@ pub fn google_provider(
}

pub struct GoogleProvider {
http: reqwest::Client,
client_id: String,
client_secret: String,
redirect_uri: String,
scopes: String,
user_agent: String,
}

impl GoogleProvider {
Expand Down Expand Up @@ -70,11 +72,9 @@ impl GoogleProvider {
)
.to_string();
GoogleProvider {
http: Default::default(),
client_id: client_id.clone(),
client_secret: client_secret.clone(),
redirect_uri: format!("{host}{path}/google/callback"),
user_agent: String::from("komodo"),
scopes,
}
.into()
Expand All @@ -83,78 +83,83 @@ impl GoogleProvider {
pub async fn get_state_and_login_redirect_url(
&self,
redirect: Option<String>,
) -> (String, String) {
) -> (String, String, String) {
let state_prefix = random_string(STATE_PREFIX_LENGTH);
let state = match redirect {
Some(redirect) => state_prefix + &redirect,
None => state_prefix,
};
let nonce = Nonce::new(random_string(32));
let redirect_url = format!(
"https://accounts.google.com/o/oauth2/v2/auth?response_type=code&state={state}&client_id={}&redirect_uri={}&scope={}",
self.client_id, self.redirect_uri, self.scopes
"https://accounts.google.com/o/oauth2/v2/auth?response_type=code&state={state}&nonce={}&client_id={}&redirect_uri={}&scope={}",
urlencoding::encode(nonce.secret()),
self.client_id,
self.redirect_uri,
self.scopes
);
(state, redirect_url)
(state, nonce.secret().clone(), redirect_url)
}

pub async fn get_access_token(
pub async fn get_google_user(
&self,
code: &str,
) -> anyhow::Result<AccessTokenResponse> {
self
.post::<_>(
"https://oauth2.googleapis.com/token",
&[
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
("redirect_uri", self.redirect_uri.as_str()),
("code", code),
("grant_type", "authorization_code"),
],
None,
)
.await
.context("failed to get google access token using code")
}

pub fn get_google_user(
&self,
id_token: &str,
nonce: &str,
) -> anyhow::Result<GoogleUser> {
let res = insecure_decode::<GoogleUser>(id_token)
.context("failed to decode google id token")?;
Ok(res.claims)
}
let http_client = oidc_reqwest::ClientBuilder::new()
.redirect(oidc_reqwest::redirect::Policy::none())
.build()
.context("Failed to build HTTP client")?;

async fn post<R: DeserializeOwned>(
&self,
endpoint: &str,
body: &[(&str, &str)],
bearer_token: Option<&str>,
) -> anyhow::Result<R> {
let mut req = self
.http
.post(endpoint)
.form(body)
.header("Accept", "application/json")
.header("User-Agent", &self.user_agent);
let issuer_url =
IssuerUrl::new("https://accounts.google.com".to_string())
.context("Invalid Google issuer URL")?;

if let Some(bearer_token) = bearer_token {
req =
req.header("Authorization", format!("Bearer {bearer_token}"));
}
let provider_metadata =
CoreProviderMetadata::discover_async(issuer_url, &http_client)
.await
.context("Failed to discover Google OpenID configuration")?;

let res = req.send().await.context("Failed to reach Google")?;
let client = openidconnect::core::CoreClient::from_provider_metadata(
provider_metadata,
ClientId::new(self.client_id.clone()),
Some(ClientSecret::new(self.client_secret.clone())),
)
.set_redirect_uri(
RedirectUrl::new(self.redirect_uri.clone())
.context("Invalid Google redirect URI")?,
);

handle_response(res).await
}
}
let token_response = client
.exchange_code(openidconnect::AuthorizationCode::new(
code.to_string(),
))?
.request_async(&http_client)
.await
.context("Failed to exchange Google authorization code")?;

let id_token = token_response
.extra_fields()
.id_token()
.context("Google did not return an ID token")?;

#[derive(Deserialize)]
pub struct AccessTokenResponse {
// pub access_token: String,
pub id_token: String,
// pub scope: String,
// pub token_type: String,
let verifier = client.id_token_verifier();
let claims = id_token
.claims(&verifier, &Nonce::new(nonce.to_string()))
.context("Failed to verify Google ID token")?;

Ok(GoogleUser {
id: claims.subject().as_str().to_string(),
email: claims
.email()
.map(|e| e.as_str().to_string())
.unwrap_or_default(),
picture: claims
.picture()
.and_then(|p| p.get(None))
.map(|p| p.as_str().to_string())
.unwrap_or_default(),
})
}
}

#[derive(Deserialize, Clone)]
Expand Down
10 changes: 6 additions & 4 deletions auth/server/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,11 @@ impl Session {
pub async fn insert_google_login(
&self,
state: &str,
nonce: &str,
) -> mogh_error::Result<()> {
self
.0
.insert(Self::GOOGLE_LOGIN, state)
.insert(Self::GOOGLE_LOGIN, (state, nonce))
.await
.context("Failed to serialize session data")
.map_err(Into::into)
Expand All @@ -134,7 +135,7 @@ impl Session {
/// Returns the CSRF state for validation
pub async fn retrieve_google_login(
&self,
) -> mogh_error::Result<String> {
) -> mogh_error::Result<(String, String)> {
self
.0
.remove(Self::GOOGLE_LOGIN)
Expand Down Expand Up @@ -359,10 +360,11 @@ impl Session {
&self,
user_id: &str,
state: &str,
nonce: &str,
) -> mogh_error::Result<()> {
self
.0
.insert(Self::GOOGLE_LINK, (user_id, state))
.insert(Self::GOOGLE_LINK, (user_id, state, nonce))
.await
.context("Failed to serialize session data")
.map_err(Into::into)
Expand All @@ -371,7 +373,7 @@ impl Session {
/// Returns (user_id, state)
pub async fn retrieve_google_link(
&self,
) -> mogh_error::Result<Option<(String, String)>> {
) -> mogh_error::Result<Option<(String, String, String)>> {
self
.0
.remove(Self::GOOGLE_LINK)
Expand Down