How to Implement OAuth in Rust

In this post, we'll learn how to implement OAuth 2.0 in Rust by writing a backend service that signs users in with Google OAuth and calls Google's OpenID Connect ("OIDC") service to retrieve a user's email. We'll first use the oauth2 library to authorise our users, keeping them signed in with database-backed sessions and a private cookie jar. Then we'll authenticate requests to protected routes from that cookie, using an Axum extractor and middleware.

The final code for the repository can be found here.

Set Up

Before we get started, you'll want the following:

  • A project in Google Cloud Console (you can get started here - it's free!)

We'll want to also install sqlx-cli, which we can do by running the following command:

cargo install sqlx-cli

Once we've created our project, we'll run sqlx migrate add schema to create our initial schema file, which you'll find in the migrations folder. The file starts out empty apart from a comment prompting you to add your migration. We'll add the following:

CREATE TABLE IF NOT EXISTS users (
    id SERIAL PRIMARY KEY,
    email VARCHAR(255) NOT NULL UNIQUE,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    last_updated TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE IF NOT EXISTS sessions (
    id SERIAL PRIMARY KEY,
    user_id INT NOT NULL UNIQUE,
    session_id VARCHAR NOT NULL,
    expires_at TIMESTAMP WITH TIME ZONE NOT NULL,
    FOREIGN KEY (user_id) REFERENCES users(id)
);

When we run our app, we'll call the sqlx::migrate! macro, which runs any pending migrations and records each applied migration in a table so it won't try to run the same migration again.

Getting Started

To get started, you'll want to create a new project by running the following:

cargo shuttle init

We'll want to pick "axum" as the framework. For the purposes of the project we will refer to the project name as "oauth-rust".

Looking to deploy? Make sure you enable initialising your project on the Shuttle servers!

Next we'll want to install our dependencies - copy the script below to install everything in one go:

cargo add axum --features multipart,macros
cargo add axum-extra --features cookie-private
cargo add chrono --features clock
cargo add shuttle-shared-db --features postgres,sqlx
cargo add sqlx --features runtime-tokio-rustls,macros,chrono
cargo add tower-http --features cors,fs
cargo add reqwest --features json
cargo add serde --features derive
cargo add anyhow tracing oauth2 shuttle-secrets thiserror time

Next you'll want to create a Secrets.toml file in the root of your backend that holds all of our secret variables - you'll want to make sure you have at least the following, in the following format:

GOOGLE_OAUTH_CLIENT_ID = "Your key here"
GOOGLE_OAUTH_CLIENT_SECRET = "Your key here"

Then we'll want to get started on setting up our main entrypoint function! We can get it set up like so:

// src/main.rs
use axum::extract::FromRef;
use axum::{routing::get, Router};
use axum_extra::extract::cookie::Key;
use reqwest::Client as ReqwestClient;
use shuttle_secrets::SecretStore;
use sqlx::PgPool;

#[derive(Clone)]
pub struct AppState {
    db: PgPool,
    ctx: ReqwestClient,
    key: Key
}

// implementing FromRef is required here so we can extract substate in Axum
// read more here: https://docs.rs/axum/latest/axum/extract/trait.FromRef.html
impl FromRef<AppState> for Key {
    fn from_ref(state: &AppState) -> Self {
        state.key.clone()
    }
}

async fn hello_world() -> &'static str {
    "Hello world!"
}

#[shuttle_runtime::main]
async fn axum(
    #[shuttle_shared_db::Postgres] db: PgPool,
    #[shuttle_secrets::Secrets] secrets: SecretStore,
) -> shuttle_axum::ShuttleAxum {
    sqlx::migrate!().run(&db).await.expect("Failed migrations :(");

    // Getting secrets from our SecretStore - safe to unwrap as they're required for the app to work
    let oauth_id = secrets.get("GOOGLE_OAUTH_CLIENT_ID").unwrap();
    let oauth_secret = secrets.get("GOOGLE_OAUTH_CLIENT_SECRET").unwrap();

    let ctx = ReqwestClient::new();

    let state = AppState {
        db,
        ctx,
        key: Key::generate()
    };

    let router = Router::new().route("/", get(hello_world));

    // More info about this below - we will build an OAuth client that can interface with any OAuth service
    // depending on the URLs we pass into it - read more here: https://docs.rs/oauth2/latest/oauth2/struct.Client.html#method.new
    let client = build_oauth_client(oauth_id, oauth_secret);

    Ok(router.into())
}

Before we go any further, we should set up our error handling type so that we can propagate errors up the call stack instead of trying to either unwrap everything or manually handle every single error.

// src/routes/errors.rs
use thiserror::Error;

#[derive(Debug, Error)]
pub enum ApiError {
    #[error("SQL error: {0}")]
    SQL(#[from] sqlx::Error),
    #[error("HTTP request error: {0}")]
    Request(#[from] reqwest::Error),
    #[error("OAuth token error: {0}")]
    TokenError(
        #[from]
        oauth2::RequestTokenError<
            oauth2::reqwest::Error<reqwest::Error>,
            oauth2::StandardErrorResponse<oauth2::basic::BasicErrorResponseType>,
        >,
    ),
    #[error("You're not authorized!")]
    Unauthorized,
    #[error("Attempted to get a non-none value but found none")]
    OptionError,
    #[error("Attempted to parse a number to an integer but errored out: {0}")]
    ParseIntError(#[from] std::num::TryFromIntError),
    #[error("Encountered an error trying to convert an infallible value: {0}")]
    FromRequestPartsError(#[from] std::convert::Infallible),
}

Here, note that the #[from] attribute generates a From<T> implementation for the wrapped error type, which is what lets the ? operator convert that error into our enum automatically. The #[error("...")] attribute defines the error message, and {0} includes the original error in it.
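To make the two attributes more concrete, here is a hand-written sketch of roughly what they give us, using only the standard library. MiniError and secs_to_i64 are hypothetical names used for illustration; the real thiserror derive generates more than this (for example the error source), so treat it as an approximation rather than the macro's exact output.

```rust
use std::convert::TryFrom;
use std::fmt;
use std::num::TryFromIntError;

// Hypothetical, cut-down stand-in for ApiError with a single wrapped variant.
#[derive(Debug)]
enum MiniError {
    ParseInt(TryFromIntError),
}

// Roughly what #[error("...: {0}")] provides: a Display impl that embeds the inner error.
impl fmt::Display for MiniError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MiniError::ParseInt(e) => write!(
                f,
                "Attempted to parse a number to an integer but errored out: {e}"
            ),
        }
    }
}

impl std::error::Error for MiniError {}

// Roughly what #[from] provides: a From impl, which the `?` operator uses to convert.
impl From<TryFromIntError> for MiniError {
    fn from(e: TryFromIntError) -> Self {
        MiniError::ParseInt(e)
    }
}

// `?` turns the TryFromIntError into a MiniError through the From impl above.
fn secs_to_i64(secs: u64) -> Result<i64, MiniError> {
    let secs = i64::try_from(secs)?;
    Ok(secs)
}
```

This is the same conversion we rely on later when turning the token lifetime into an i64.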

To make our error type compatible with Axum, we need to implement the IntoResponse trait. We can do this like so:

// src/routes/errors.rs
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
};

impl IntoResponse for ApiError {
    fn into_response(self) -> Response {
        let response = match self {
            Self::SQL(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
            Self::Request(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
            Self::TokenError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
            Self::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized!".to_string()),
            Self::OptionError => (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Attempted to get a non-none value but found none".to_string(),
            ),
            Self::ParseIntError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
            Self::FromRequestPartsError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
        };

        response.into_response()
    }
}

But how do I use OAuth?

First, we will need to write a function to create an oauth2::basic::BasicClient. This client can take any OAuth authorization endpoint URL and token endpoint URL (as long as they're both from the same OAuth service). We pass in the Google OAuth secrets that we created earlier, as you'll see below. Our redirect URL should be an endpoint that we create on our side, so that when the user is successfully authorised they get sent back to our application with a code. We can then exchange that code for a token that lets the user stay authenticated.

// src/main.rs
use oauth2::{
    basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl,
};

fn build_oauth_client(client_id: String, client_secret: String) -> BasicClient {
    // In prod, http://localhost:8000 would get replaced by whatever your production URL is
    let redirect_url = "http://localhost:8000/api/auth/google_callback".to_string();

    // If you're not using Google OAuth, you can use whatever the relevant auth/token URL is for your given OAuth service
    let auth_url = AuthUrl::new("https://accounts.google.com/o/oauth2/v2/auth".to_string())
        .expect("Invalid authorization endpoint URL");
    let token_url = TokenUrl::new("https://www.googleapis.com/oauth2/v3/token".to_string())
        .expect("Invalid token endpoint URL");

    BasicClient::new(
        ClientId::new(client_id),
        Some(ClientSecret::new(client_secret)),
        auth_url,
        Some(token_url),
    )
    .set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
}

Now that we've created our BasicClient, we can use it anywhere we wish! Before we set up our OAuth callback route though, let's first examine how OAuth works. First, we need to build the link that sends users from our backend to Google's OAuth consent screen. Here we have a premade route that has the OAuth client ID inserted for you already (click here to find out more about customising your OAuth URL):

// src/main.rs
use axum::{response::Html, Extension};

async fn homepage(
    Extension(oauth_id): Extension<String>
) -> Html<String> {
    Html(format!("<p>Welcome!</p>
    
    <a href=\"https://accounts.google.com/o/oauth2/v2/auth?scope=openid%20profile%20email&client_id={oauth_id}&response_type=code&redirect_uri=http://localhost:8000/api/auth/google_callback\">
    Click here to sign into Google!
     </a>"))
    
}

To get this route to work, you'll want to create a Router that layers an axum::Extension, then nest it onto your main router:

// Use the oauth_id from earlier in your main function

let homepage_router = Router::new()
   .route("/", get(homepage))
   .layer(Extension(oauth_id));
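The link in the homepage route is written by hand, with the redirect URI left unescaped. As a minimal sketch of how the same URL could be assembled with every parameter percent-encoded, the helper below uses only the standard library. The function names and the extra state parameter are assumptions for illustration and are not part of the original code; state is the OAuth 2.0 parameter commonly used to protect the callback against CSRF, and the callback would need to verify it. The oauth2 crate can also generate this URL for you (via authorize_url with a CsrfToken), which is likely the better choice in a real application.

```rust
// Percent-encode everything except the RFC 3986 "unreserved" characters.
fn percent_encode(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for byte in input.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                out.push(byte as char)
            }
            _ => out.push_str(&format!("%{byte:02X}")),
        }
    }
    out
}

// Hypothetical helper: builds the same link as the homepage route, plus a `state` value.
fn google_auth_url(client_id: &str, redirect_uri: &str, state: &str) -> String {
    let params = [
        ("scope", "openid profile email"),
        ("client_id", client_id),
        ("response_type", "code"),
        ("redirect_uri", redirect_uri),
        ("state", state),
    ];
    let query: Vec<String> = params
        .iter()
        .map(|(key, value)| format!("{key}={}", percent_encode(value)))
        .collect();
    format!(
        "https://accounts.google.com/o/oauth2/v2/auth?{}",
        query.join("&")
    )
}
```

The redirect URI passed here has to match the one registered in the Google Cloud Console and the one given to the BasicClient.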

Once the user allows our application to use their credentials, Google redirects their browser to our chosen OAuth redirect URI (as seen in the homepage route), which arrives at our service as a GET request with some URI query parameters. Although several parameters are returned, we only need the code value given back by Google so we can exchange it for an access token. We can make a struct to extract the query parameters:

use serde::Deserialize;
#[derive(Debug, Deserialize)]
pub struct AuthRequest {
    code: String
}
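To illustrate what the Query extractor does with the callback URL, here is a rough standard-library sketch that pulls the code parameter out of a raw query string. The function names are hypothetical and this is not how Axum implements it (Axum deserializes through serde); it only shows the shape of the data. Note that if the user declines the consent screen, the callback carries an error parameter instead of code, in which case Query<AuthRequest> fails to deserialize and the request is rejected before our handler runs.

```rust
// Decode %XX escapes and '+' (space) in a query-string component.
fn percent_decode(input: &str) -> String {
    let bytes = input.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            // `get` returns None when fewer than two characters follow the '%'.
            if let Some(hex) = input.get(i + 1..i + 3) {
                if hex.bytes().all(|b| b.is_ascii_hexdigit()) {
                    if let Ok(byte) = u8::from_str_radix(hex, 16) {
                        out.push(byte);
                        i += 3;
                        continue;
                    }
                }
            }
        }
        out.push(if bytes[i] == b'+' { b' ' } else { bytes[i] });
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}

// Hypothetical helper: return the decoded `code` parameter, if present.
fn extract_code(query: &str) -> Option<String> {
    query.split('&').find_map(|pair| {
        let (key, value) = pair.split_once('=')?;
        (key == "code").then(|| percent_decode(value))
    })
}
```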

Then we need to exchange the code for a token by using our BasicClient we made earlier:

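// src/routes/oauth.rs
use axum::extract::{Query, State};
use axum::response::IntoResponse;
use axum::Extension;
use axum_extra::extract::cookie::PrivateCookieJar;
// TokenResponse must be in scope to call access_token() and expires_in() on the token
use oauth2::{basic::BasicClient, reqwest::async_http_client, AuthorizationCode, TokenResponse};
use crate::routes::errors::ApiError;
use crate::AppState;
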
pub async fn google_callback(
    State(state): State<AppState>,
    jar: PrivateCookieJar,
    Query(query): Query<AuthRequest>,
    Extension(oauth_client): Extension<BasicClient>,
) -> Result<impl IntoResponse, ApiError> {
    let token = oauth_client
        .exchange_code(AuthorizationCode::new(query.code))
        .request_async(async_http_client)
        .await?;

    // .. rest of the function
}

The token response returned by exchanging the code holds all of the information from Google's response and has methods to get the fields that we need (the access token itself, the lifetime of the access token, etc.).

Because we requested OpenID scopes earlier (openid, profile and email, in the scope parameter of the authorization URL), we can now call Google's OpenID Connect userinfo endpoint using the access token that was given to us. Thankfully, this is pretty simple to do without the oauth2 crate: we only need a simple Reqwest client with bearer auth to get the user profile data, like so:

#[derive(Deserialize, sqlx::FromRow, Clone)]
pub struct UserProfile {
    email: String
}

// Note that in the full code, the reqwest client is already created in the main function
// and passed to the AppState. Rather than initializing a Reqwest client with a connection pool
// for every request, we share it in the router state.
let profile = state.ctx.get("https://openidconnect.googleapis.com/v1/userinfo")
    .bearer_auth(token.access_token().secret().to_owned())
    .send().await?;

let profile = profile.json::<UserProfile>().await?;

As you can see, we needed to create a struct for the response type. This particular OIDC endpoint returns much more than just an email, which you can see if you call .text().await instead of converting the response to JSON. For our purposes, however, we only need the email for verification. Serde ignores unknown fields (unless deny_unknown_fields is enabled), so deserializing into this smaller struct is safe to do.

Using OAuth with Axum Extensions

Now that we've got our access token, all we need to do is store it somewhere our service can access it. We can do this with SQLx and the PrivateCookieJar type from axum_extra, which encrypts and signs its cookies using the Key in our application state. Let's have a look at what the code would look like:

// Local and Duration come from chrono; TimeDuration is assumed here to be time::Duration
// (for example `use time::Duration as TimeDuration;`), which is the type Cookie::max_age expects
let Some(secs) = token.expires_in() else {
    return Err(ApiError::OptionError);
};

let secs: i64 = secs.as_secs().try_into()?;

let max_age = Local::now().naive_local() + Duration::try_seconds(secs).unwrap();

let cookie = Cookie::build(("sid", token.access_token().secret().to_owned()))
    .domain(".app.localhost")
    .path("/")
    .secure(true)
    .http_only(true)
    .max_age(TimeDuration::seconds(secs));

sqlx::query("INSERT INTO users (email) VALUES ($1) ON CONFLICT (email) DO NOTHING")
    .bind(profile.email.clone())
    .execute(&state.db)
    .await?;

sqlx::query(
    "INSERT INTO sessions (user_id, session_id, expires_at) VALUES (
    (SELECT ID FROM USERS WHERE email = $1 LIMIT 1),
    $2, $3)
    ON CONFLICT (user_id) DO UPDATE SET
    session_id = excluded.session_id,
    expires_at = excluded.expires_at",
)
.bind(profile.email)
.bind(token.access_token().secret().to_owned())
.bind(max_age)
.execute(&state.db)
.await?;
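The expiry arithmetic above can be looked at in isolation. The following is a small sketch, using only the standard library rather than chrono, of turning the optional token lifetime into an absolute expiry timestamp. The helper name and the choice of Unix seconds are assumptions for illustration; the handler in this post stores a chrono timestamp instead.

```rust
use std::convert::TryFrom;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

// Hypothetical helper: absolute expiry as Unix seconds, given "now" and the token lifetime.
// Returns None if the provider sent no lifetime or the arithmetic would overflow.
fn expires_at_unix(now: SystemTime, expires_in: Option<Duration>) -> Option<i64> {
    let lifetime = expires_in?;
    let expiry = now.checked_add(lifetime)?;
    let since_epoch = expiry.duration_since(UNIX_EPOCH).ok()?;
    i64::try_from(since_epoch.as_secs()).ok()
}
```

Passing the current time in as a parameter, instead of reading the clock inside the function, keeps the calculation easy to test.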

Now that the user and session are stored, we want to return the updated cookie jar together with a redirect, so that the cookie is set in the response:

    Ok((
        jar.add(cookie),
        Redirect::to("/protected")
    ))

Of course, our "protected" route doesn't actually exist yet - we'll create it in a moment. But first of all, let's see what the final OAuth callback handler looks like:

pub async fn google_callback(
    State(state): State<AppState>,
    jar: PrivateCookieJar,
    Query(query): Query<AuthRequest>,
    Extension(oauth_client): Extension<BasicClient>,
) -> Result<impl IntoResponse, ApiError> {
    let token = oauth_client
        .exchange_code(AuthorizationCode::new(query.code))
        .request_async(async_http_client)
        .await?;

    let profile = state
        .ctx
        .get("https://openidconnect.googleapis.com/v1/userinfo")
        .bearer_auth(token.access_token().secret().to_owned())
        .send()
        .await?;

    let profile = profile.json::<UserProfile>().await?;

    let Some(secs) = token.expires_in() else {
        return Err(ApiError::OptionError);
    };

    let secs: i64 = secs.as_secs().try_into()?;

    let max_age = Local::now().naive_local() + Duration::try_seconds(secs).unwrap();

    let cookie = Cookie::build(("sid", token.access_token().secret().to_owned()))
        .domain(".app.localhost")
        .path("/")
        .secure(true)
        .http_only(true)
        .max_age(TimeDuration::seconds(secs));

    sqlx::query("INSERT INTO users (email) VALUES ($1) ON CONFLICT (email) DO NOTHING")
        .bind(profile.email.clone())
        .execute(&state.db)
        .await?;

    sqlx::query(
        "INSERT INTO sessions (user_id, session_id, expires_at) VALUES (
        (SELECT ID FROM USERS WHERE email = $1 LIMIT 1),
         $2, $3)
        ON CONFLICT (user_id) DO UPDATE SET
        session_id = excluded.session_id,
        expires_at = excluded.expires_at",
    )
    .bind(profile.email)
    .bind(token.access_token().secret().to_owned())
    .bind(max_age)
    .execute(&state.db)
    .await?;

    Ok((jar.add(cookie), Redirect::to("/protected")))
}

To be able to authenticate users more easily, we will implement FromRequest for UserProfile. This lets a handler take a UserProfile as an argument: the extractor reads the sid cookie from the request, looks up the matching session in the database, and returns the user profile of the person who is signed in.

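// FromRequestParts must be in scope to call PrivateCookieJar::from_request_parts
use axum::extract::{FromRequest, FromRequestParts, Request};

#[axum::async_trait]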
impl FromRequest<AppState> for UserProfile {
    type Rejection = ApiError;
    async fn from_request(req: Request, state: &AppState) -> Result<Self, Self::Rejection> {
        let state = state.to_owned();
        let (mut parts, _body) = req.into_parts();
        let cookiejar: PrivateCookieJar =
            PrivateCookieJar::from_request_parts(&mut parts, &state).await?;

        let Some(cookie) = cookiejar.get("sid").map(|cookie| cookie.value().to_owned()) else {
            return Err(ApiError::Unauthorized);
        };

        let res = sqlx::query_as::<_, UserProfile>(
            "SELECT
        users.email
        FROM sessions
        LEFT JOIN USERS ON sessions.user_id = users.id
        WHERE sessions.session_id = $1
        LIMIT 1",
        )
        .bind(cookie)
        .fetch_one(&state.db)
        .await?;

        Ok(Self { email: res.email })
    }
}
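Two details of the extractor are worth noting: an unknown session ID makes fetch_one return an error, which surfaces as ApiError::SQL (a 500) rather than Unauthorized, and the query never looks at expires_at. The sketch below models the intended decision logic in memory, with an expiry check added. The Session struct, AuthError enum and HashMap are hypothetical stand-ins for the sessions table, not part of the original code. In the real extractor, the equivalent change would probably be a condition on expires_at in the SQL plus fetch_optional, but that is a suggestion rather than the author's implementation.

```rust
use std::collections::HashMap;

// Hypothetical in-memory stand-in for a sessions row joined with its user.
struct Session {
    email: String,
    expires_at: i64, // Unix seconds
}

#[derive(Debug, PartialEq)]
enum AuthError {
    Unauthorized,
}

// Mirrors the extractor's steps: read the "sid" cookie, find the session,
// and reject the request if the session is missing or has expired.
fn authenticate(
    sid_cookie: Option<&str>,
    sessions: &HashMap<String, Session>,
    now: i64,
) -> Result<String, AuthError> {
    let sid = sid_cookie.ok_or(AuthError::Unauthorized)?;
    let session = sessions.get(sid).ok_or(AuthError::Unauthorized)?;
    if session.expires_at <= now {
        return Err(AuthError::Unauthorized);
    }
    Ok(session.email.clone())
}
```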

Now we just need to add the protected route!

pub async fn protected(profile: UserProfile) -> impl IntoResponse {
    (StatusCode::OK, profile.email)
}

Now that we've filled out everything we need, we can come back to the main entrypoint function and fill back in all of our routes so that we can use them:

#[shuttle_runtime::main]
async fn axum(
    #[shuttle_shared_db::Postgres] db: PgPool,
    #[shuttle_secrets::Secrets] secrets: SecretStore,
) -> shuttle_axum::ShuttleAxum {
    sqlx::migrate!().run(&db).await.expect("Failed migrations :(");

    let oauth_id = secrets.get("GOOGLE_OAUTH_CLIENT_ID").unwrap();
    let oauth_secret = secrets.get("GOOGLE_OAUTH_CLIENT_SECRET").unwrap();

    let ctx = ReqwestClient::new();

    let state = AppState {
        db,
        ctx,
        key: Key::generate()
    };

    let oauth_client = build_oauth_client(oauth_id.clone(), oauth_secret);

    let router = init_router(state, oauth_client, oauth_id);

    Ok(router.into())
}

fn init_router(state: AppState, oauth_client: BasicClient, oauth_id: String) -> Router {
    let auth_router = Router::new()
        .route("/auth/google_callback", get(oauth::google_callback));

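    // Note: oauth::check_authenticated is a middleware function that isn't shown in this post.
    // The UserProfile extractor used by oauth::protected already rejects requests that lack a valid session.
    let protected_router = Router::new()
        .route("/", get(oauth::protected))
        .route_layer(middleware::from_fn_with_state(state.clone(), oauth::check_authenticated));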

    let homepage_router = Router::new()
        .route("/", get(homepage))
        .layer(Extension(oauth_id));

    Router::new()
        .nest("/api", auth_router)
        .nest("/protected", protected_router)
        .nest("/", homepage_router)
        .layer(Extension(oauth_client))
        .with_state(state)
}

Deploying to Production

Once we're done implementing OAuth, all you need to do is use cargo shuttle deploy (with --allow-dirty if you're working on a dirty Git branch) and it'll work!

If you didn't initialise your project on Shuttle servers before, make sure your project container is started. You can do so by running cargo shuttle project start.

Finishing Up

Thanks for reading! I hope you enjoyed this guide to implementing OAuth in Rust and leveraging the oauth2 library for Rust auth.

Some extra ideas if you'd like to extend this article:

  • Silent token rotation
  • Add more functionality so users don't have to go through the whole OAuth process every single time
  • Try implementing refresh tokens (make sure they're implemented securely!)