Dex
Create a REST API in Rust that you wouldn't be (too?) ashamed of deploying to production

Points that I will be covering in this blog post:

  • Application Architecture
  • Application Configuration
  • Application State (Diesel ORM & PostgreSQL)
  • REST API (Axum Server)
  • Calling external API (Reqwest HTTP Client)
  • Tracing (Open Telemetry & Jaeger)
  • Graceful Shutdown (Handling SIGTERM)

The source code is available in this repository on GitHub: todoservice

Why Rust?

Glad you asked! Rust is a modern, statically typed language that compiles to native code. It needs neither a garbage collector nor a virtual machine to run, so it is very fast and has a small runtime memory footprint compared to garbage-collected languages. The Rust compiler performs static analysis to guarantee memory safety at runtime; safe Rust also rules out null pointer dereferences. If you have ever deployed code to production, you know the value of a language that can catch these errors at compile time.
Now that this is out of the way, let's dive deep into our todoservice.
Now that this is out of the way, let's dive deep into our todoservice.

Application Architecture

The todoservice exposes 6 REST endpoints:

  • Create Todo item
  • Delete Todo item
  • Get Todo item by ID
  • Mark Todo as completed
  • Get all Todo items
  • Create a random Todo

The Todos are stored in a PostgreSQL database, in a todos table:

CREATE TABLE todos (
  id SERIAL PRIMARY KEY,
  title VARCHAR NOT NULL,
  body TEXT NOT NULL,
  completed BOOLEAN NOT NULL DEFAULT FALSE
);
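Diesel mirrors this table in a generated src/schema.rs. The exact output of the Diesel CLI may differ slightly, but for the table above the expected shape is:

```rust
// src/schema.rs, generated by Diesel from the migration above.
// SERIAL maps to Int4, VARCHAR to Varchar, TEXT to Text, BOOLEAN to Bool.
diesel::table! {
    todos (id) {
        id -> Int4,
        title -> Varchar,
        body -> Text,
        completed -> Bool,
    }
}
```

This is the schema that the `#[diesel(table_name = crate::schema::todos)]` attribute shown later refers to.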

Application Configuration

The application's default, or base, configuration is stored in a JSON file:

config/default.json

{
  "environment": "development",
  "server": {
    "host": "0.0.0.0",
    "port": 8080
  },
  "database": {
    "host": "localhost",
    "port": 5432,
    "name": "tododb",
    "user": "todouser",
    "password": "todopassword"
  },
  "logger": {
    "level": "DEBUG"
  },
  "tracing": {
    "host": "http://localhost:4317"
  },
  "service": {
    "name": "todoservice"
  }
}

The environment-specific overrides are stored in a separate file. For example:

config/production.json

{
  "environment": "production",
  "logger": {
    "level": "INFO"
  }
}

First we load default.json; then we use the ENV environment variable to pick the environment-specific config file and overlay it on top of the base config. The result of the merge is the configuration the application uses. Any configuration item in the files can also be overridden through environment variables.

src/config.rs

impl Configurations {
    pub fn new() -> Result<Self, ConfigError> {
        let env = env::var("ENV").unwrap_or_else(|_| "development".into());

        let mut builder = Config::builder()
            .add_source(File::with_name("config/default"))
            .add_source(File::with_name(&format!("config/{env}")).required(false))
            .add_source(File::with_name("config/local").required(false))
            .add_source(Environment::default().separator("__"));

        // Allow to override settings from environment variables
        if let Ok(port) = env::var("PORT") {
            builder = builder.set_override("server.port", port)?;
        }
        if let Ok(log_level) = env::var("LOG_LEVEL") {
            builder = builder.set_override("logger.level", log_level)?;
        }

        builder
            .build()?
            // Deserialize (and thus freeze) the entire configuration.
            .try_deserialize()
    }
}

In the code above I show how the configuration files are loaded and merged; you can merge as many config files as you want, with various rules. I also show how PORT and LOG_LEVEL are overridden if they are defined as environment variables.
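Conceptually, the merge is just "later sources win": each source is overlaid on top of the previous one, key by key. A minimal standard-library sketch of that idea, flattening the nested JSON into hypothetical dotted keys:

```rust
use std::collections::HashMap;

// Sketch of layered configuration: each layer (defaults, environment file,
// env-var overrides) is applied in order, and a later layer simply replaces
// any key it shares with an earlier one.
fn merge_layers(layers: &[HashMap<&str, &str>]) -> HashMap<String, String> {
    let mut merged = HashMap::new();
    for layer in layers {
        for (key, value) in layer {
            merged.insert(key.to_string(), value.to_string());
        }
    }
    merged
}

fn main() {
    let defaults = HashMap::from([("logger.level", "DEBUG"), ("server.port", "8080")]);
    let production = HashMap::from([("logger.level", "INFO")]);
    let env_overrides = HashMap::from([("server.port", "9090")]);

    let merged = merge_layers(&[defaults, production, env_overrides]);
    // production.json wins over default.json; the env var wins over both.
    assert_eq!(merged["logger.level"], "INFO");
    assert_eq!(merged["server.port"], "9090");
}
```

The `config` crate implements a richer version of this (nested keys, typed deserialization), but the override order is the same: sources added later to the builder take precedence.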

Application State (Diesel ORM & PostgreSQL)

I'm using Diesel and PostgreSQL to store the application state in the database. Diesel makes it easy to handle database migrations, and it generates code for our CRUD (Create, Read, Update, Delete) use case. It also provides a connection pool, which makes access to the DB more efficient under heavy load in a multithreaded environment.

src/database.rs

pub struct AppState {
    pub pool: Pool<ConnectionManager<PgConnection>>,
}

pub fn get_connection_pool(config: &Configurations) -> AppState {
    let url = get_database_url(config);
    let manager = ConnectionManager::<PgConnection>::new(url);

    let pool = Pool::builder()
        .test_on_check_out(true)
        .build(manager)
        .expect("Could not build connection pool");

    AppState { pool }
}

I encapsulate the connection pool in the AppState struct and call get_connection_pool(&config) to create the pool. In the router I then attach the state with Router::new()...with_state(Arc::new(state)). Note that I wrap the state in an Arc before passing it to the router. Arc stands for Atomically Reference Counted: it is a thread-safe reference-counting pointer, which allows us to safely share the db pool across the multi-threaded HTTP server that we are going to create in a bit.
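To see why the Arc matters on its own, here is a minimal, self-contained illustration (nothing Axum-specific): several threads read the same value without copying it, and the allocation is freed once the last handle is dropped.

```rust
use std::sync::Arc;
use std::thread;

fn main() {
    // One shared value, standing in for the connection pool in the real app.
    let shared = Arc::new(vec!["conn-1", "conn-2", "conn-3"]);

    let handles: Vec<_> = (0..4)
        .map(|_| {
            // Each thread gets its own cheap, atomically counted handle.
            let shared = Arc::clone(&shared);
            thread::spawn(move || shared.len())
        })
        .collect();

    for handle in handles {
        // Every thread saw the same underlying data.
        assert_eq!(handle.join().unwrap(), 3);
    }
}
```

Axum does essentially the same thing internally: each request handler task gets a clone of the `Arc<AppState>` via the `State` extractor.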

REST API (Axum Server)

Axum is a modular, multithreaded, and modern web server. In this application we take advantage of its ergonomics and ease of use, its support for asynchronous processing and multithreading, graceful shutdown, and tracing.

Creating the server and listening on the port:

src/main.rs

    let app_state = database::get_connection_pool(&config);
    let app = app::create_app(app_state);

    let address: SocketAddr = format!("{}:{}", config.server.host, config.server.port)
        .parse()
        .expect("Unable to parse socket address");

    axum::Server::bind(&address)
        .serve(app.into_make_service())
        .with_graceful_shutdown(...)
        .await
        .expect("Failed to start server");
}

Creating the routes and attaching the application state and tracing layer to the routes:

src/app.rs

pub fn create_app(state: AppState) -> Router {
    Router::new()
        .route("/todo/:todo_id", get(get_todo))
        .route("/todo/:todo_id", delete(delete_todo))
        .route("/todo/:todo_id", put(complete_todo))
        .route("/todo", post(create_todo))
        .route("/todo", get(get_all_todos))
        .route("/todo/random", post(create_random_todo))
        .with_state(Arc::new(state))
        .layer(TraceLayer::new_for_http())
}

Let's dive into one of the routes to see how it all works together. In this case let's look at .route("/todo", post(create_todo)), which accepts a todo item as a JSON request body and inserts it into the PostgreSQL database using Diesel.

First we define the NewTodo struct, which we use to deserialize the JSON body, and then use the struct to insert the record into the db.
src/models.rs

#[derive(serde::Deserialize, Insertable, Debug)]
#[diesel(table_name = crate::schema::todos)]
pub struct NewTodo {
    pub title: String,
    pub body: String,
}

src/app.rs


#[instrument]
async fn create_todo(
    State(state): State<Arc<AppState>>,
    Json(new_todo): Json<NewTodo>,
) -> Result<Json<Todo>, (StatusCode, String)> {
    let mut conn = state.pool.get().map_err(internal_error)?;

    info!("Creating Todo {:?}", &new_todo);

    let res = diesel::insert_into(todos::table)
        .values(&new_todo)
        .returning(Todo::as_returning())
        .get_result(&mut conn)
        .map_err(internal_error)?;

    Ok(Json(res))
}

In the create_todo function, async marks the function as asynchronous: calling it returns a future that must be driven to completion (with .await in Rust). The function takes two arguments, which the Axum server extracts and passes in: the application state, i.e. the db connection pool we use to get a connection to interact with the db, and the JSON body deserialized into the NewTodo struct. After we get the connection, we use it to insert the record into the db.
The Insertable derive on the NewTodo struct makes it possible to use the struct with Diesel's API. The serde::Deserialize derive makes it possible for Axum to automatically deserialize the JSON body into the struct and pass it as an argument. Returning Ok(Json(res)) is equivalent to returning HTTP 200 with the Todo struct serialized as JSON, as hinted by the function's return type, Result<Json<Todo>, (StatusCode, String)>.

Calling external API (Reqwest HTTP Client)

The endpoint that generates a random Todo calls an external API to fetch a random activity, which we use to create the todo item.
We use reqwest to make an async HTTP REST call to the external API and then store the result as a new todo item in the db.
Here I use reqwest::get for the external call. I have to mention that, as it stands, this is not efficient: each call initializes a new HTTP client before doing the actual request. I intentionally left it this way to show the time the client initialization takes when we get to tracing in the next section. I'll leave it to the reader to create a reqwest Client once when the application initializes, pass it to the router (for example inside the state object), and reuse it for the HTTP calls instead.

src/app.rs

#[instrument]
async fn create_random_todo(
    State(state): State<Arc<AppState>>,
) -> Result<Json<Todo>, (StatusCode, String)> {
    let random_activity: Activity = reqwest::get("https://www.boredapi.com/api/activity")
        .await
        .map_err(internal_error)?
        .json()
        .await
        .map_err(internal_error)?;

    info!("Got: {:?}", random_activity);

    let new_todo = NewTodo {
        title: random_activity.activity,
        body: random_activity.activity_type,
    };

    let mut conn = state.pool.get().map_err(internal_error)?;

    info!("Creating random Todo {:?}", &new_todo);

    let res = diesel::insert_into(todos::table)
        .values(&new_todo)
        .returning(Todo::as_returning())
        .get_result(&mut conn)
        .map_err(internal_error)?;

    Ok(Json(res))
}
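One possible shape for the improvement mentioned above (a sketch, not the repository's code: the HttpState struct and field names are my own): build one reqwest::Client at startup and share it, since the client keeps an internal connection pool and is designed to be reused.

```rust
use std::sync::Arc;

// Hypothetical state holding a single reqwest::Client built at startup.
// (The Diesel pool from the real AppState is omitted to keep the sketch short.)
struct HttpState {
    client: reqwest::Client,
}

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Built once and shared; cloning the Arc (or the Client itself, which is
    // internally reference-counted) is cheap.
    let state = Arc::new(HttpState {
        client: reqwest::Client::new(),
    });

    // In the Axum handler you would take the client out of State(state)
    // and call it instead of the one-shot reqwest::get:
    let activity = state
        .client
        .get("https://www.boredapi.com/api/activity")
        .send()
        .await?
        .text()
        .await?;
    println!("{activity}");
    Ok(())
}
```

With this in place, the per-call client initialization disappears from the trace, and keep-alive connections to the external API can be reused across requests.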

Tracing (Open Telemetry & Jaeger)

This is the fun part. I'm using OpenTelemetry for both logging and tracing, and I will show how we can trace functions in Jaeger UI.

src/main.rs

fn init_tracer(config: &Configurations) -> Result<opentelemetry_sdk::trace::Tracer, TraceError> {
    opentelemetry_otlp::new_pipeline()
        .tracing()
        .with_exporter(
            opentelemetry_otlp::new_exporter()
                .tonic()
                .with_endpoint(config.tracing.host.clone()),
        )
        .with_trace_config(
            sdktrace::config().with_resource(Resource::new(vec![KeyValue::new(
                "service.name",
                config.service.name.clone(),
            )])),
        )
        .install_batch(runtime::Tokio)
}

#[tokio::main]
async fn main() {
...
// initialize tracing
    let tracer = init_tracer(&config).expect("Failed to initialize tracer.");
    let fmt_layer = tracing_subscriber::fmt::layer();
    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::from(&config.logger.level))
        .with(fmt_layer)
        .with(tracing_opentelemetry::layer().with_tracer(tracer))
        .init();
...
}

To instrument a function, we simply annotate it with #[instrument]. This instructs OpenTelemetry to create a new span whenever the function is entered. To create events, or instrumentation points, we can simply use the regular logging macros, like info!, error!, etc.

We create the tracer and give it Jaeger's host and port to send the traces to, and we attach a logging layer so the application logs also show up in the console. Lastly we install the tracer in batch mode, which sends traces out in batches and is more efficient.

[Jaeger trace: Create Todo]

[Jaeger trace: Create Random Todo]

See in the Create Random Todo trace that it took more than 208ms to get the response from the outgoing request. This can be drastically improved by moving the HTTP client initialization out of this function, as described above.

Graceful Shutdown (Handling SIGTERM)

Last but not least, we want to handle graceful shutdown.
We do this in two steps.

First we set up the signal-handling channel; when we receive an OS signal we are interested in, we send a notification to the receiving end of the channel.
src/shutdown.rs

pub fn register() -> Receiver<()> {
    let signals = Signals::new([SIGHUP, SIGTERM, SIGINT, SIGQUIT]).unwrap();
    signals.handle();
    let (tx, rx): (Sender<()>, Receiver<()>) = oneshot::channel();
    tokio::spawn(handle_signals(signals, tx));
    rx
}

async fn handle_signals(mut signals: Signals, tx: Sender<()>) {
    while let Some(signal) = signals.next().await {
        match signal {
            SIGHUP => {
                // Reload configuration, reopen the log file...etc
            }
            SIGTERM | SIGINT | SIGQUIT => {
                // Gracefully shut down
                let _ = tx.send(());
                return;
            }
            _ => unreachable!(),
        }
    }
}

Second, we wait on the receiving end of the channel. When the signal is received, we shut down everything.


let rx = shutdown::register();

...
.with_graceful_shutdown(async {
    rx.await.ok(); // This will block until a shutdown signal is received
    info!("Handling graceful shutdown");
    info!("Close resources, drain and shutdown event handler... etc");
    shutdown_tracer_provider();
})
...

Thank you for reading, and I hope you enjoyed it!
