Introduction:
In the rapidly evolving landscape of artificial intelligence and natural language processing, the ability to deploy and serve private language models efficiently is becoming increasingly crucial.
Challenge:
I wanted a small server that was configurable, lightweight, local and private, so that I could easily build other tools and ideas on top of it.
It also had to work in an offline environment, using as few resources as possible for the task at hand.
This article explores in depth how to build a high-performance server for language model inference using Rust, Tokio and the Actix Web framework.
Note: For simplicity, I've stripped away all code related to database request history, telemetry, logging, tests and Docker to keep this as frictionless as possible.
Some structs are deliberately not shown, because the point of the article is to show how to start an inference session over a private network with a dynamically loaded model. If you have ideas for optimising this example further or building on top of it, please share your own experiences.
Although we focus mainly on the binary and its server implementation, the same code could also be packaged as a library.
We'll delve into the intricacies of each component, discussing not just the how, but also the why behind our design choices.
Prerequisites:
Before embarking on this project, ensure you have Rust installed on your system. Familiarity with Rust's syntax and concepts like ownership, borrowing, and async programming will be beneficial. Additionally, a basic understanding of web development principles and RESTful APIs will help you grasp the concepts more easily.
- Project setup and dependencies:
Let's look at the dependencies and why some of them were chosen:
[dependencies]
actix-web = { version = "4", optional = true, features = ["openssl"] }
actix-web-httpauth = { version = "0.8.1" }
actix-cors = { version = "0.7.0" }
tokio = { version = "1", features = ["full"] }
dotenvs = "0.1.0"
serde_json = "1.0.113"
serde = { version = "1.0", features = ["derive"] }
rand = { version = "0.8.5" }
openssl = "0.10.64"
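llm = {version = "0.1.1"}

[features]
# main() below is compiled only when this feature is enabled (cargo run --features server)
server = ["actix-web"]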
- actix-web: Our primary web framework, and the library I have the most experience with. Actix is known for its high performance and low overhead, which makes it an excellent choice for building efficient web services in Rust.
- actix-cors: Cross-Origin Resource Sharing (CORS) is crucial for web applications that might be accessed from different domains. This crate provides CORS middleware for Actix.
- serde: Serialization and deserialization are fundamental operations in web services. Serde is the de facto standard for handling these tasks in Rust.
- llm: This crate provides a unified interface for loading and using Large Language Models. At the time of writing, its only backend is ggml (https://github.com/ggerganov/ggml).
- openssl: For implementing HTTPS, which is essential for securing our API.
- rand: Provides the random number generator that the inference session uses when sampling tokens (see rand::thread_rng() in run_inference_session below).
- Server configuration:
Create a file named .env in the root directory of your project with the following content:
SERVER_ADDRESS=localhost
SERVER_PORT=8080
SERVER_REQUEST_TIMEOUT_IN_SECONDS=10
MACHINE_COMMAND_TIMEOUT_IN_SECONDS=10
DATABASE_URL=""
MAX_CONNECTIONS=10
ALLOWED_ORIGIN="localhost"
MAX_AGE=4600
LLM_MODEL="open_llama_7b-f16.bin"
LLM_MODEL_ARCHITECTURE="llama"
LLM_INFERENCE_MAX_TOKEN_COUNT=400
The values are loaded into a Config struct when the server starts:
use std::env;
#[derive(Debug, Clone)]
pub struct Config {
pub server_address: String,
pub server_port: u16,
pub server_request_timeout: u64,
pub machine_command_timeout: u64,
pub max_connections: u32,
pub database_url: String,
pub allowed_origin: String,
pub max_age: u64,
pub llm_model: String,
pub llm_model_architecture: String,
pub llm_inference_max_token_count: usize,
}
impl Config {
pub fn init() -> Config {
let _ = dotenv::load();
let server_address = env::var("SERVER_ADDRESS")
.expect("SERVER_ADDRESS must be specified")
.parse::<String>()
.unwrap();
let server_port = env::var("SERVER_PORT")
.expect("SERVER_PORT must be specified")
.parse::<u16>()
.unwrap();
let max_connections = env::var("MAX_CONNECTIONS")
.expect("MAX_CONNECTIONS must be specified")
.parse::<u32>()
.unwrap();
let database_url = env::var("DATABASE_URL")
.expect("DATABASE_URL must be specified")
.parse::<String>()
.unwrap();
let allowed_origin = env::var("ALLOWED_ORIGIN").expect("ALLOWED_ORIGIN must be specified");
let max_age = env::var("MAX_AGE")
.expect("MAX_AGE must be specified")
.parse::<u64>()
.unwrap();
let server_request_timeout = env::var("SERVER_REQUEST_TIMEOUT_IN_SECONDS")
.expect("SERVER_REQUEST_TIMEOUT_IN_SECONDS must be specified")
.parse::<u64>()
.unwrap();
let machine_command_timeout = env::var("MACHINE_COMMAND_TIMEOUT_IN_SECONDS")
.expect("MACHINE_COMMAND_TIMEOUT_IN_SECONDS must be specified")
.parse::<u64>()
.unwrap();
let llm_model = env::var("LLM_MODEL")
.expect("LLM_MODEL must be specified")
.parse::<String>()
.unwrap();
let llm_model_architecture = env::var("LLM_MODEL_ARCHITECTURE")
.expect("LLM_MODEL_ARCHITECTURE must be specified")
.parse::<String>()
.unwrap();
let llm_inference_max_token_count = env::var("LLM_INFERENCE_MAX_TOKEN_COUNT")
.expect("LLM_INFERENCE_MAX_TOKEN_COUNT must be specified")
.parse::<usize>()
.unwrap();
Config {
server_address,
server_port,
server_request_timeout,
machine_command_timeout,
max_connections,
database_url, // used for request history
allowed_origin,
max_age,
// llm specific configuration
llm_model,
llm_model_architecture,
llm_inference_max_token_count,
}
}
}
This configuration system is neither complete nor exhaustive, but it is enough for our purposes and allows flexible deployment across different environments. Because every setting comes from an environment variable, we can adjust the server's behaviour without recompiling. This is particularly useful in containerised deployments or when moving between development and production environments.
Note that there are no defaults. If a variable is missing or cannot be parsed, Config::init panics at startup, so misconfiguration fails fast instead of surfacing later.
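The Config::init shown above has no fallback values, so every variable is required. If you would rather have defaults for some settings, a small generic helper can read a variable and fall back to a default only when the variable is absent. This sketch is an assumption and not part of the original code: env_or is a hypothetical name. It deliberately still panics on a value that is present but malformed, so a typo is not silently replaced by a default.

```rust
use std::env;
use std::fmt::Debug;
use std::str::FromStr;

/// Hypothetical helper: parse `key` from the environment, or return `default`
/// when the variable is not set (or is not valid Unicode).
/// A value that is set but fails to parse still panics, so mistakes fail fast.
pub fn env_or<T>(key: &str, default: T) -> T
where
    T: FromStr,
    T::Err: Debug,
{
    match env::var(key) {
        Ok(raw) => raw
            .parse::<T>()
            .unwrap_or_else(|err| panic!("{key} has an invalid value {raw:?}: {err:?}")),
        Err(_) => default,
    }
}
```

Inside Config::init, a line such as let server_port: u16 = env_or("SERVER_PORT", 8080); would then replace the expect/parse/unwrap chain for that setting.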
- Main server structure:
Let's break down the main server structure in more detail:
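use std::path::PathBuf;
// actix-web is optional in Cargo.toml, so these imports (like main) need the server feature.
use actix_cors::Cors;
use actix_web::{http, middleware, web, App, HttpServer};
use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
#[actix_web::main]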
#[cfg(feature = "server")]
async fn main() -> std::io::Result<()> {
let config: Config = Config::init();
let model_path = PathBuf::from(&config.llm_model);
let now = std::time::Instant::now();
let model_architecture = match_model_architecture(&config.llm_model_architecture)
.unwrap_or_else(|| {
panic!(
"Failed to find model architecture {} for model: {}.\n",
config.llm_model_architecture, &config.llm_model
);
});
let model = llm::load_dynamic(
model_architecture,
&model_path,
Default::default(),
llm::load_progress_callback_stdout,
)
.unwrap_or_else(|err| {
panic!(
"Failed to load {} model from {:?}: {}",
config.llm_model, model_path, err
);
});
println!(
"{} model ({}) has been started!\nElapsed: {}ms",
config.llm_model,
config.llm_model_architecture,
now.elapsed().as_millis()
);
println!(
"Starting server at https://{}:{}.\n",
config.server_address, config.server_port
);
let app_state = web::Data::new(AppState {
model,
config: config.clone(),
});
let complete_address = format!("{}:{}", config.server_address, config.server_port);
let mut ssl_builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
ssl_builder
.set_private_key_file("certs/key.pem", SslFiletype::PEM)
.unwrap();
ssl_builder
.set_certificate_chain_file("certs/cert.pem")
.unwrap();
HttpServer::new(move || {
App::new()
.app_data(app_state.clone())
.wrap(middleware::Logger::default())
.wrap(middleware::Logger::new("%a %{User-Agent}i"))
.wrap(middleware::Compress::default())
.wrap(
Cors::default()
.allowed_origin(&config.allowed_origin)
.allowed_methods(vec!["GET", "POST"])
.allowed_headers(vec![
http::header::AUTHORIZATION,
http::header::ACCEPT,
http::header::CONTENT_TYPE,
])
.max_age(config.max_age as usize),
)
.route("/", web::get().to(server_info_handler))
.service(
web::scope("/api")
.route("/generate", web::post().to(generate_handler))
.route("/health", web::get().to(health_handler)),
)
})
.bind_openssl(complete_address, ssl_builder)?
.run()
.await
}
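Key points to note:
a. Logging: We wrap the application in Actix's Logger middleware to help with debugging and monitoring. The middleware writes through the log crate facade, so a logger implementation (for example env_logger) has to be initialised at startup before anything appears. Like the rest of the logging setup, that initialisation is not shown here.
This can be further enhanced with OpenTelemetry traces and Bunyan-formatted logging, but for the sake of simplicity we won't add them for now.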
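b. Model Loading: The match_model_architecture function (not shown) maps the LLM_MODEL_ARCHITECTURE string to one of the architectures supported by the llm crate and returns None for unknown names. llm::load_dynamic then loads the model file into memory. This is a critical step that can involve reading a large file and initialising complex data structures, which is why it happens only once, at startup.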
c. Application State: By wrapping our state in web::Data, we're leveraging Actix's built-in mechanism for sharing state efficiently across handlers.
Internally, Data<T> wraps the value in an Arc<T>, where Arc stands for Atomic Reference Counting. An Arc gives several threads shared ownership of a single value, and its thread-safe reference count decides when that value is freed. Handlers only receive shared references through it, so the state must be Sync. Anything that needs to be mutated would have to sit behind a lock such as a Mutex.
This means that the underlying data is not recreated for each worker thread. Instead, it is shared by all of them, which makes it a good place to keep our model. Every HTTP request handler accesses the same AppState instance, without duplicating the multi-gigabyte model for each request or thread.
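d. SSL Configuration: HTTPS is crucial for securing communication between clients and our server. We use OpenSSL's SslAcceptor with Mozilla's intermediate TLS profile, and load the private key and certificate chain from certs/key.pem and certs/cert.pem, relative to the working directory. For local testing a self-signed certificate is enough, which is why the curl example at the end passes -k.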
e. Server Configuration: We use various middleware:
- Logger: For request logging
- Compress: For response compression
- CORS: To control which origins can access our API
f. Route Definition: We define three main routes:
- /: A simple info endpoint
- /api/generate: Our main inference endpoint
- /api/health: A health check endpoint, useful for container orchestration systems
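Because match_model_architecture is not shown, here is one way it could look. This is a sketch under assumptions. The Architecture enum is a hypothetical stand-in that lets the example compile without the llm crate; in the server, the function would return the crate's own architecture type. The variant list and accepted spellings are illustrative, so check the llm documentation for the architectures your version actually supports.

```rust
/// Hypothetical stand-in for the llm crate's architecture type,
/// so this sketch compiles on its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Architecture {
    Bloom,
    Gpt2,
    GptJ,
    GptNeoX,
    Llama,
}

/// Maps the LLM_MODEL_ARCHITECTURE value to an architecture, ignoring case and
/// surrounding whitespace. Unknown names yield None, which main() turns into a panic.
pub fn match_model_architecture(name: &str) -> Option<Architecture> {
    match name.trim().to_ascii_lowercase().as_str() {
        "bloom" => Some(Architecture::Bloom),
        "gpt2" => Some(Architecture::Gpt2),
        "gptj" | "gpt-j" => Some(Architecture::GptJ),
        "gptneox" | "gpt-neox" => Some(Architecture::GptNeoX),
        "llama" => Some(Architecture::Llama),
        _ => None,
    }
}
```

Returning an Option rather than panicking inside the function leaves the decision to the caller. That matches the unwrap_or_else(|| panic!(...)) in main.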
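To see what web::Data buys us, the sketch below reproduces the same pattern with std::sync::Arc alone: several threads read one shared AppState without copying it. AppState here is a simplified, hypothetical version with a single field. The real one holds the loaded model and the Config, and Actix manages the threads for you.

```rust
use std::sync::Arc;
use std::thread;

/// Simplified, hypothetical AppState: the real one holds the loaded model and Config.
pub struct AppState {
    pub model_name: String,
}

/// Spawns `workers` threads that each read the shared state, similar to how every
/// Actix worker reads the same value through its web::Data<AppState> handle.
pub fn read_from_workers(state: Arc<AppState>, workers: usize) -> Vec<String> {
    let handles: Vec<_> = (0..workers)
        .map(|_| {
            // Cloning the Arc copies a pointer and bumps the reference count;
            // AppState itself is never duplicated.
            let state = Arc::clone(&state);
            thread::spawn(move || state.model_name.clone())
        })
        .collect();
    handles
        .into_iter()
        .map(|handle| handle.join().expect("worker thread panicked"))
        .collect()
}
```

Only reads happen here. If a handler needed to modify shared state, that field would have to be wrapped in a Mutex or RwLock inside AppState.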
- Request Handling and Inference:
Let's dive deeper into how we handle requests and perform inference:
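use std::convert::Infallible;
use std::io::Write; // brings flush() into scope
use llm::{InferenceError, Model};
fn run_inference_session(
config: &Config,
model: &Box<dyn Model>,
prompt: String,
) -> Result<String, InferenceError> {
let mut inference_session = model.start_session(Default::default());
// Accumulates the generated text so it can be returned to the caller.
let mut output = String::new();
let inference_session_result = inference_session.infer::<Infallible>(
model.as_ref(),
// Input:
&mut rand::thread_rng(),
&llm::InferenceRequest {
prompt: (&*prompt).into(),
parameters: Option::from(&llm::InferenceParameters::default()),
play_back_previous_tokens: false,
maximum_token_count: Some(config.llm_inference_max_token_count),
},
// Output:
&mut Default::default(),
|token| {
// Echo each piece of text to the console as it arrives...
print!("{token}");
std::io::stdout().flush().unwrap();
// ...and keep it so the HTTP handler can return it.
output.push_str(token);
Ok(())
},
);
// Discard the inference statistics and return the collected text.
inference_session_result.map(|_| output)
}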
- Security Considerations:
Expanding on our security measures:
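a. HTTPS: By using SSL/TLS, we encrypt all traffic between the client and server. This prevents eavesdropping and, as long as clients verify the server's certificate, man-in-the-middle attacks. Clients that skip verification (such as curl -k) only get encryption.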
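b. CORS: Our CORS configuration controls which web origins a browser is allowed to call our API from. CORS only applies to requests that carry an Origin header, which browsers add automatically and tools such as curl do not. It therefore limits which websites can use the API from a visitor's browser, but it is not a substitute for authentication. Browsers send the full origin (scheme, host and, if non-default, port), so ALLOWED_ORIGIN needs to be given in that form (for example https://example.com). A bare localhost will not match.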
c. Input Validation: Although not shown in the code, it's crucial to implement proper input validation. For example:
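// Minimal error type so the example is self-contained.
#[derive(Debug)]
pub struct ValidationError(pub &'static str);
fn validate_prompt(prompt: &str) -> Result<(), ValidationError> {
// Treat whitespace-only prompts as empty.
if prompt.trim().is_empty() {
return Err(ValidationError("Prompt cannot be empty"));
}
// len() counts bytes rather than characters, which also bounds the request size.
if prompt.len() > 1000 {
return Err(ValidationError("Prompt exceeds maximum length"));
}
Ok(())
}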
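d. Rate Limiting: To prevent abuse, especially since every inference request is expensive, consider implementing rate limiting. actix-web does not ship a rate limiter itself, but third-party middleware crates such as actix-governor provide one.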
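To make the idea concrete, here is a minimal fixed-window limiter. Each client may make at most max_requests requests per window, and further requests are rejected until the window resets. FixedWindowLimiter is a hypothetical type written for this example, not the API of actix-governor or any other crate, and dedicated crates use more refined algorithms. In the server it would live in AppState behind a Mutex, since handlers run on several threads. It would be keyed by something that identifies a client, such as its IP address.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Hypothetical fixed-window rate limiter. Not thread-safe on its own,
/// and entries are never evicted, so a real version would also need cleanup.
pub struct FixedWindowLimiter {
    max_requests: u32,
    window: Duration,
    // client key -> (start of the current window, requests seen in it)
    counters: HashMap<String, (Instant, u32)>,
}

impl FixedWindowLimiter {
    pub fn new(max_requests: u32, window: Duration) -> Self {
        Self { max_requests, window, counters: HashMap::new() }
    }

    /// Returns true if the request is allowed. `now` is a parameter to keep the logic testable.
    pub fn check(&mut self, client: &str, now: Instant) -> bool {
        let entry = self.counters.entry(client.to_string()).or_insert((now, 0));
        if now.duration_since(entry.0) >= self.window {
            *entry = (now, 0); // the old window expired: start a fresh one
        }
        if entry.1 < self.max_requests {
            entry.1 += 1;
            true
        } else {
            false
        }
    }
}
```

A rejected request would typically be answered with HTTP 429 Too Many Requests before any inference work starts.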
- Performance Optimizations:
Let's delve deeper into performance considerations:
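a. Model Caching: Loading large language models can be time-consuming, so the model must not be reloaded for every request. The server above already handles this: the model is loaded once at startup and kept in memory in AppState, where every request shares it for the lifetime of the process.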
b. Batching: If your use case allows, consider implementing request batching to process multiple prompts in a single inference pass.
c. Streaming Responses: For long generations, consider implementing streaming responses to start sending data to the client as soon as it's available.
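use actix_web::{web, web::Bytes, web::Json, HttpResponse, Responder};
use tokio::sync::mpsc;
// tokio-stream is an extra dependency, not listed in the Cargo.toml above.
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
pub async fn generate_stream(
data: web::Data<AppState>,
body: Json<GenerateRequest>,
) -> impl Responder {
// Bounded channel: the producer waits once 100 unread tokens are queued.
let (tx, rx) = mpsc::channel::<String>(100);
// run_inference_session_streaming (not shown) is expected to send each token through tx.
actix_web::rt::spawn(async move {
run_inference_session_streaming(&data.config, &data.model, &body.prompt, tx).await;
});
// tokio's Receiver is not a Stream by itself; ReceiverStream adapts it for .streaming().
let stream = ReceiverStream::new(rx)
.map(|token| Ok::<Bytes, actix_web::Error>(Bytes::from(format!("data: {token}\n\n"))));
HttpResponse::Ok()
.content_type("text/event-stream")
.streaming(stream)
}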
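d. Asynchronous Processing: Rust's async support lets one worker handle many requests concurrently, but only while handlers are waiting on I/O. Inference is CPU-bound and blocking, so calling it directly inside an async handler stalls that worker thread. Offloading it with actix_web::web::block (or tokio::task::spawn_blocking) keeps the server responsive to other requests.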
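The streaming endpoint in point c responds with the text/event-stream content type. Clients of that type, such as a browser's EventSource, expect Server-Sent Events framing: each event is one or more data: lines followed by a blank line. A field ends at a line break, so a token that contains a newline must be split across several data: lines. sse_frame is a hypothetical helper that shows this framing, and it assumes chunks use '\n' line endings.

```rust
/// Hypothetical helper: wraps one chunk of generated text as a Server-Sent Event.
/// Each line of the chunk becomes its own `data:` line; the client's parser
/// joins consecutive data lines back together with '\n'.
pub fn sse_frame(chunk: &str) -> String {
    let mut frame = String::with_capacity(chunk.len() + 8);
    for line in chunk.split('\n') {
        frame.push_str("data: ");
        frame.push_str(line);
        frame.push('\n');
    }
    // A blank line marks the end of the event.
    frame.push('\n');
    frame
}
```

In the streaming handler, each token would then be sent as Bytes::from(sse_frame(&token)).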
- Model sourcing and quantisation:
When building a language model server, one crucial aspect is selecting the right model for your use case. Hugging Face's Model Hub (https://huggingface.co/models) is an excellent resource for finding pre-trained models, including quantised versions that can significantly reduce the computational and memory requirements of your server.
Quantisation is a technique that reduces the precision of the model's weights, for example from 16- or 32-bit floating-point numbers down to 8-bit or even 4- or 5-bit integers. This can dramatically decrease the model size and inference time, often with only a small loss in output quality.
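A back-of-the-envelope calculation shows why this matters for memory. The sketch counts only the raw weights and uses a round figure of 7 billion parameters for illustration. Real quantised formats also store per-block scale factors, and files carry metadata and a vocabulary, so actual files are somewhat larger.

```rust
/// Approximate size of a model's raw weights in GiB at a given precision.
/// Ignores per-block scale factors, metadata and vocabulary, so real files are larger.
pub fn approx_weights_gib(parameters: u64, bits_per_weight: u32) -> f64 {
    let bytes = parameters as f64 * f64::from(bits_per_weight) / 8.0;
    bytes / (1024.0 * 1024.0 * 1024.0)
}
```

For 7 billion parameters this gives roughly 13 GiB at 16 bits, 6.5 GiB at 8 bits and 3.3 GiB at 4 bits. That is the same ballpark as the 12853.47 MB the loader reports for open_llama_7b-f16.bin in the console output below.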
For our server, we experimented with quantised versions of popular models available on Hugging Face.
For this experiment I chose two models. I was limited to the model architectures supported by the llm crate, which provides a unified interface for loading and using Large Language Models (LLMs).
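On disk the two model files look like this (the first is the unquantised 16-bit version of OpenLLaMA 7B, the second a q5_1 quantisation of GPT4All-J):
- open_llama_7b-f16.bin
- gpt4all-j-q5_1.bin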
Running the experiment:
If everything goes well you should be able to see something like this in your console:
...
Loaded tensor 288/291
Loading of model complete
Model size = 12853.47 MB / num tensors = 291
open_llama_7b-f16.bin model (llama) has been started!
Elapsed: 96ms
Starting server at https://localhost:8080.
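Now you can start an inference session by sending your own prompt to the generate endpoint. The -k flag makes curl skip certificate verification, which is needed when the server uses a self-signed certificate: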
curl -k -X POST -H "Content-Type: application/json" -d '{"prompt": "Say hello!"}' https://localhost:8080/api/generate
Conclusion:
Building a private language model server with Rust and Actix offers a powerful combination of performance, safety and expressiveness, and it is quite straightforward.
By carefully considering aspects like configuration management, security, error handling, and performance optimisations, we can create a robust and efficient system capable of serving AI models at scale.
As you continue to develop your server, consider implementing additional features such as model versioning, A/B testing capabilities, and advanced monitoring and logging. Remember that serving AI models often requires handling large amounts of data and computation, so always be prepared to scale your infrastructure as needed.
This expanded guide should provide a solid foundation for building and understanding a language model server using Rust and Actix. As with any complex system, continuous testing, monitoring, and refinement will be key to long-term success.