
Shuttle Launchpad #8: A Web Sockets Chat

We have a small favor to ask of our amazing readers and users. Shuttle Launchpad was founded with the simple idea of making Rust more accessible and approachable to individuals from diverse backgrounds. Our goal is to make Rust development and web development easier for everyone, so we want to reach everyone! We would sincerely appreciate it if you could help us by spreading the word on platforms such as Twitter, Reddit, or among your friends!

Now, without further ado, welcome to the next issue of Shuttle Launchpad! This time we create a web socket based chat server and apply some new learnings like using the Default trait, working with Arc and RwLock, and using serde to serialize and deserialize data. We also see how we can spawn new tasks and run them concurrently!

In the end, with just a few lines of code, we will have created some really good infrastructure for a chat server!

Creating a Chat Server with Web Sockets

First, we need a few dependencies.

$ cargo add axum --features ws
$ cargo add serde --features derive
$ cargo add serde_json
$ cargo add futures

And these are all our imports.

use std::sync::atomic::AtomicUsize;
use std::{collections::HashMap, sync::Arc};

use axum::extract::ws::{Message, WebSocket};
use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse;
use axum::{routing::get, Router};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::{self, UnboundedReceiver};
use tokio::sync::{mpsc::UnboundedSender, RwLock};

We have a few globals. First, we need an AtomicUsize to generate a unique user ID for each new user. We also need a HashMap to store all connected users. The HashMap is wrapped in an Arc and an RwLock to make it thread-safe. The Arc lets all tasks share ownership of the same HashMap, and the RwLock makes sure that many tasks can read the HashMap at the same time, but only one can write to it at a time.

Since a type called Arc<RwLock<HashMap<usize, UnboundedSender<Message>>>> is quite a mouthful, we create a type alias called Users.

static NEXT_USERID: AtomicUsize = AtomicUsize::new(1);

type Users = Arc<RwLock<HashMap<usize, UnboundedSender<Message>>>>;
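To see the sharing pattern in isolation, here is a minimal sketch that uses only the standard library. It assumes std::sync::RwLock and OS threads instead of Tokio's async RwLock and tasks, and the names Registry and register_from_threads are made up for this illustration.

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;

// Hypothetical stand-in for `Users`: maps a user ID to a display name.
type Registry = Arc<RwLock<HashMap<usize, String>>>;

// Spawns `n` threads that each insert one entry, then returns the final size.
fn register_from_threads(registry: &Registry, n: usize) -> usize {
    let handles: Vec<_> = (0..n)
        .map(|id| {
            // Cloning the Arc only bumps a reference count; the map is shared.
            let registry = Arc::clone(registry);
            thread::spawn(move || {
                // The write lock gives this thread exclusive access.
                registry.write().unwrap().insert(id, format!("user-{id}"));
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }

    // Any number of readers may hold the read lock at the same time.
    registry.read().unwrap().len()
}
```

Calling register_from_threads(&Registry::default(), 8) should leave eight entries in the map, no matter in which order the threads run.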

The message we send between users contains a name, a user ID, and the message itself. The user ID is optional because the client doesn't know its own ID and leaves the field out when it sends a message (serde deserializes a missing Option field as None). The server fills in the ID before passing the message on to the connected users.

Since we use serde, we can simply derive Serialize and Deserialize for our struct. With that, our struct becomes compatible with all serialization formats that serde supports, including JSON.

#[derive(Serialize, Deserialize)]
struct Msg {
    name: String,
    uid: Option<usize>,
    message: String,
}

Now we can start with the main function. The first thing we do is set up our state. Since Arc, RwLock, and HashMap all implement the Default trait, we can simply call Users::default() to create a new, empty Users value. How amazing is that?

We create a new router and add a route to it. The route is /ws, and we register the handler with the get function because every web socket connection starts out as an HTTP GET request that is then upgraded. We also add our state to the router so we can access it later.

#[shuttle_runtime::main]
async fn axum() -> shuttle_axum::ShuttleAxum {
    let users = Users::default();

    let router = Router::new()
        .route("/ws", get(ws_handler))
        .with_state(users);

    Ok(router.into())
}
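Users::default() works because Default composes: a wrapper's default value is built from the default of the type it wraps. The following sketch shows the same idea with standard library types only. The Names alias, the ChatConfig struct, and the config_for function are hypothetical and exist only for this example.

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Hypothetical alias, shaped like `Users` but with plain strings as values.
type Names = Arc<RwLock<HashMap<usize, String>>>;

// Deriving Default works when every field implements Default.
#[derive(Default)]
struct ChatConfig {
    room: String,     // defaults to an empty string
    max_users: usize, // defaults to 0
    names: Names,     // defaults to an empty, shared map
}

// Builds a config where only one field deviates from its default.
fn config_for(room: &str) -> ChatConfig {
    ChatConfig {
        room: room.to_string(),
        ..Default::default()
    }
}
```

The struct update syntax (..Default::default()) is handy when a type has many fields and only a few of them need non-default values.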

The ws_handler function is the entry point for all incoming web socket connections. It receives a WebSocketUpgrade struct, which is used to upgrade the incoming request to a web socket connection. We also receive the State struct, which contains the list of connected users.

All we do is call the on_upgrade method on the WebSocketUpgrade struct. This way, our HTTP connection becomes a functioning web socket connection. The method takes a closure that receives a WebSocket struct, which is used to send and receive messages to and from the connected client.

We forward this information to the handle_socket function, which is responsible for handling the web socket connection.

async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Users>)
    -> impl IntoResponse {
    ws.on_upgrade(|socket| handle_socket(socket, state))
}

async fn handle_socket(ws: WebSocket, state: Users) {
    // tbd
}

Our new user needs a few things to get started. First, we need a unique user ID. We use an AtomicUsize to generate a new user ID for each new user. This way, we can safely increase the number even across threads.

The incoming web socket is split into two parts:

  1. A Sender, that sends Axum web socket messages to the connected client.
  2. A Receiver, that receives messages from the connected client.

let my_id = NEXT_USERID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let (mut sender, mut receiver) = ws.split();
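If you want to convince yourself that fetch_add never hands out the same ID twice, here is a small sketch that hammers a counter from several OS threads. The NEXT_ID static and the collect_ids function are hypothetical names for this example, and threads stand in for Tokio tasks.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;

// Hypothetical counter for the example, starting at 1 like NEXT_USERID.
static NEXT_ID: AtomicUsize = AtomicUsize::new(1);

// Hands out `per_thread` IDs on each of `threads` threads and collects them.
fn collect_ids(threads: usize, per_thread: usize) -> Vec<usize> {
    let handles: Vec<_> = (0..threads)
        .map(|_| {
            thread::spawn(move || {
                (0..per_thread)
                    // fetch_add increments atomically and returns the previous value.
                    .map(|_| NEXT_ID.fetch_add(1, Ordering::Relaxed))
                    .collect::<Vec<usize>>()
            })
        })
        .collect();

    let mut ids: Vec<usize> = handles
        .into_iter()
        .flat_map(|handle| handle.join().unwrap())
        .collect();
    ids.sort_unstable();
    ids
}
```

Ordering::Relaxed is enough here because we only need each increment to be atomic. We don't rely on the counter to order any other memory accesses.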

This is for the communication between our app and the connected client, but we also want to make sure that all connected clients can talk to each other. We do this using a channel. A channel is a way to send messages between tasks. We create a channel that can send and receive messages of type Message. The UnboundedSender is used to send messages, and the UnboundedReceiver is used to receive messages.

let (tx, mut rx): (UnboundedSender<Message>, UnboundedReceiver<Message>) =
    mpsc::unbounded_channel();
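If channels are new to you, the following sketch shows the basic idea with the standard library's std::sync::mpsc channel. It is also unbounded, but it blocks a thread instead of awaiting like Tokio's version does. The relay function is a hypothetical name for this example.

```rust
use std::sync::mpsc;
use std::thread;

// Sends each line through a channel from a worker thread and
// collects everything the receiving end sees, in order.
fn relay(lines: Vec<String>) -> Vec<String> {
    let (tx, rx) = mpsc::channel::<String>();

    let producer = thread::spawn(move || {
        for line in lines {
            // send only fails if the receiver has been dropped.
            tx.send(line).unwrap();
        }
        // tx is dropped here, which closes the channel.
    });

    // Iterating the receiver ends once every sender is gone.
    let received: Vec<String> = rx.iter().collect();
    producer.join().unwrap();
    received
}
```

Messages arrive in the order they were sent, and the receiving side learns that it is done when the last sender is dropped. Our chat server relies on the same behavior.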

We need to set up our newly connected user so they can send and receive messages. Those are two tasks that need to run in parallel. To make sure the user can receive messages, we spawn a new task that receives messages from the channel and sends them to the user.

We use tokio::spawn to create a new task that runs concurrently to all the other tasks. The async move block says that what we execute here is async, thus a Future, and thus able to run concurrently. The move keyword moves the captured variables into the new task so that the task owns them. This is necessary because the task may outlive the function that spawned it, so it cannot simply borrow variables from that function.

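tokio::spawn(async move {
    while let Some(msg) = rx.recv().await {
        // Stop forwarding once the client is gone instead of panicking.
        if sender.send(msg).await.is_err() {
            break;
        }
    }
    // Closing can fail if the connection is already gone; that is fine.
    let _ = sender.close().await;
});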
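The role of move is easier to see without any async code around it. This sketch uses std::thread::spawn as an analogy to tokio::spawn. Both require that the spawned code owns everything it uses. The greet_in_background function is a hypothetical name for this example.

```rust
use std::thread;

// Moves `names` into a new thread and returns a greeting built there.
fn greet_in_background(names: Vec<String>) -> String {
    // Without `move`, the closure would only borrow `names`, and the
    // compiler would reject it because the thread may outlive that borrow.
    let handle = thread::spawn(move || format!("Hello, {}!", names.join(" and ")));

    // `names` is no longer usable here; the thread owns it now.
    handle.join().unwrap()
}
```

In our chat server, the same thing happens with rx and sender: after the spawn, they belong to the new task, and handle_socket cannot touch them anymore.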

Now that our user is created and is able to receive messages, we need to add it to the list of connected users. We do this by acquiring a write lock on the Users map and inserting the user ID and the sender half of the channel.

state.write().await.insert(my_id, tx);

Now that the user and message receiving part is done, we work on broadcasting messages to all other connected clients. Every time a message is sent from the user, we want to send it to all other connected clients. We do this by iterating over all connected users and sending the message to each of them.

while let Some(Ok(result)) = receiver.next().await {
    if let Ok(result) = enrich_result(result, my_id) {
        broadcast_msg(result, &state).await;
    }
}

The enrich_result function is used to add the user ID to the message. This way, we can display the user ID in the front-end later on. We parse the message into a Msg struct, add the user ID, and then serialize it back into a string. Thanks to the serde_json crate, this is very easy to do.

fn enrich_result(result: Message, id: usize) ->
    Result<Message, serde_json::Error> {
    match result {
        Message::Text(msg) => {
            let mut msg: Msg = serde_json::from_str(&msg)?;
            msg.uid = Some(id);
            let msg = serde_json::to_string(&msg)?;
            Ok(Message::Text(msg))
        }
        _ => Ok(result),
    }
}

The broadcast_msg function is used to send the message to all connected users. We iterate over all users and send the message to each of them. If the message is not a text message, we ignore it.

async fn broadcast_msg(msg: Message, users: &Users) {
    if let Message::Text(msg) = msg {
        for (&_uid, tx) in users.read().await.iter() {
            // A failed send means that user's receiving task has already
            // ended, so we skip them instead of panicking.
            let _ = tx.send(Message::Text(msg.clone()));
        }
    }
}
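The fan-out logic can also be tried out without any web sockets. The following is a minimal sketch with standard library channels and plain strings. The names Peers, connect, and broadcast are hypothetical. This variant counts successful deliveries and skips peers whose receiving end has gone away.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{self, Receiver, Sender};

// Hypothetical registry: user ID -> sending half of that user's channel.
type Peers = HashMap<usize, Sender<String>>;

// Creates a peer, registers its sender, and hands back its receiver.
fn connect(peers: &mut Peers, id: usize) -> Receiver<String> {
    let (tx, rx) = mpsc::channel();
    peers.insert(id, tx);
    rx
}

// Sends `text` to every peer and returns how many deliveries succeeded.
fn broadcast(peers: &Peers, text: &str) -> usize {
    let mut delivered = 0;
    for tx in peers.values() {
        // send fails only if the peer's receiver has been dropped.
        if tx.send(text.to_string()).is_ok() {
            delivered += 1;
        }
    }
    delivered
}
```

Each peer gets its own copy of the message, which is why the text is cloned for every send.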

So far, we have covered how to send and receive messages inside our application. The first task receives messages over a channel and sends them out to the connected client. The second task receives messages from the client and sends them out to all connected clients using their channels.

The last thing we need to do is to remove the user from the list of connected users when they disconnect. The good thing is that when a user closes their browser, the web socket connection is closed as well. The stream then stops yielding Some(Ok(...)) values, so the while let Some(Ok(result)) = receiver.next().await loop ends, and we can use that to detect when a user disconnects.

Then, we can call the disconnect function as a next step.

disconnect(my_id, &state).await;
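The exit condition of that loop is worth a closer look: the pattern Some(Ok(..)) stops matching at the end of the stream and also at the first error. Here is a small sketch with a plain iterator standing in for the web socket stream. The take_until_failure function is a hypothetical name for this example.

```rust
// Consumes items the way the receive loop does: it stops at the first
// `Err` as well as at the end of the input.
fn take_until_failure(frames: Vec<Result<String, String>>) -> Vec<String> {
    let mut frames = frames.into_iter();
    let mut seen = Vec::new();

    // `next()` yields Option<Result<..>>; the pattern only matches Some(Ok(..)).
    while let Some(Ok(frame)) = frames.next() {
        seen.push(frame);
    }
    seen
}
```

Anything after the first error is never looked at, which is exactly what we want for a broken connection.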

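Disconnecting is very simple. We remove the user from the list of connected users. That way, the sender half of the channel is dropped, and the broadcast_msg function will not send any messages to the disconnected user. Since this was the only sender, the channel closes, rx.recv() returns None, and the task we spawned for this user ends as well.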

async fn disconnect(my_id: usize, users: &Users) {
    users.write().await.remove(&my_id);
}
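To see that removing the map entry really shuts down the receiving side, here is a sketch with standard library channels, where a thread stands in for the spawned forwarding task. The connect_then_disconnect function is a hypothetical name for this example.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{self, Sender};
use std::thread;

// Registers one user, removes it again, and reports what the user's
// forwarding thread received before its channel closed.
fn connect_then_disconnect() -> Vec<String> {
    let mut users: HashMap<usize, Sender<String>> = HashMap::new();
    let (tx, rx) = mpsc::channel();
    users.insert(1, tx);

    // Stand-in for the forwarding task: runs until the channel closes.
    let forwarder = thread::spawn(move || rx.iter().collect::<Vec<String>>());

    users[&1].send("last message".to_string()).unwrap();

    // Removing the entry drops the only sender, which closes the channel
    // and lets the forwarding loop finish.
    users.remove(&1);

    forwarder.join().unwrap()
}
```

If the entry stayed in the map, the join call would wait forever, because the forwarding loop would keep waiting for more messages.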

The full handle_socket function looks like this:

async fn handle_socket(ws: WebSocket, state: Users) {
    // Create the user
    let my_id = NEXT_USERID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let (mut sender, mut receiver) = ws.split();
    let (tx, mut rx): (UnboundedSender<Message>, UnboundedReceiver<Message>) =
        mpsc::unbounded_channel();

    // Receive messages from the channel and send them to the user
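    tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            // Stop forwarding once the client is gone instead of panicking.
            if sender.send(msg).await.is_err() {
                break;
            }
        }
        // Closing can fail if the connection is already gone; that is fine.
        let _ = sender.close().await;
    });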

    // Add the user to the list of connected users
    state.write().await.insert(my_id, tx);

    // Receive messages from the user and broadcast them
    while let Some(Ok(result)) = receiver.next().await {
        if let Ok(result) = enrich_result(result, my_id) {
            broadcast_msg(result, &state).await;
        }
    }

    // Remove the user from the list of connected users
    disconnect(my_id, &state).await;
}

Now that the main server part is done, you need a front-end that your users can interact with. I'll just put the HTML and JavaScript in here.

<!DOCTYPE html>
<html lang="en">

<head>
  <title>Shuttle Web Socket Chat</title>
</head>

<body>
  <div class="chat">
    <h1>!Discord</h1>
    <div id="log"></div>
    <div class="inp">
      <input id="input" type="text" /><button id="btn">Send</button>
    </div>
  </div>
  <script src="/main.js"></script>
</body>

</html>

And main.js:

let log = console.log;

let name = prompt("Enter your name");

const wsUri = ((window.location.protocol == "https:" && "wss://") || "ws://") +
  window.location.host +
  "/ws";
let conn = new WebSocket(wsUri);

log("Connecting...");

conn.onopen = function () {
  log("Connected.");
};

conn.onmessage = function (e) {
  log("Received: " + e.data);
  let msg = JSON.parse(e.data);
  //let str = `${msg.name} (${msg.uid}): ${msg.message}`;
  document.getElementById("log").appendChild(createMsg(msg));
};

conn.onclose = function () {
  log("Disconnected.");
  conn = null;
};

function createMsg(message) {
  const msg = document.createElement("div");
  msg.textContent = message.message;
  msg.classList.add("msg");
  const name = document.createElement("div");
  name.textContent = `${message.name} (${message.uid})`;
  name.classList.add("nom");
  const s = document.createElement("div");
  s.appendChild(name);
  s.appendChild(msg);
  s.classList.add("bubble");
  return s;
}

function send() {
  conn.send(
    JSON.stringify({
      name: name,
      message: document.getElementById("input").value,
    }),
  );
  document.getElementById("input").value = "";
}

document.getElementById("btn")?.addEventListener("click", send);

document.getElementById("input")?.addEventListener("keypress", (e) => {
  if (e.key === "Enter") {
    e.preventDefault();
    send();
  }
});

Try figuring out how to wire them up in your Axum application. A few leads:

  1. Maybe create new named routes and return the content as String.
  2. Try figuring out how to serve static files using tower_http::services::ServeDir and shuttle_static_folder.

And that's all for today!

Time for your feedback!

We want to tailor Shuttle Launchpad to your needs! Give us feedback on the most recent issue and your wishes here.

Join us!

Shuttle has a very active community. Join us on Discord, star us on GitHub, follow us on Twitter, and watch out for video content on YouTube.

If you have any questions regarding Launchpad, join the #launchpad channel on Shuttle's Discord.

Launchpad Examples: Check out all Launchpad Examples on GitHub.

Best Rust Web Frameworks to Use in 2023: A detailed analysis of Rust web frameworks by yours truly.

Semantic Search with Qdrant, OpenAI and Shuttle: A new blog article by yours truly on how to create a semantic search engine that actually works!

Bye!

That's it for today. Get in touch with us and let us know what you want to see!

-- Stefan and your friends from Shuttle