Building a High-Performance Text Embedding API with Rust, Axum, and ONNX

This content originally appeared on DEV Community and was authored by Mayuresh

Text embeddings are the backbone of modern AI applications, from semantic search to recommendation systems. In this tutorial, we'll build a production-ready embedding API in Rust that supports two models: a lightweight MiniLM model and Google's EmbeddingGemma.

What We'll Build

A REST API with two endpoints:

  • /embed-mini - Fast embeddings using MiniLM (ONNX)
  • /generate-embedding - Embeddings using EmbeddingGemma (Candle)

Prerequisites

  • Rust installed (1.70+)
  • Basic understanding of async Rust
  • Familiarity with REST APIs

Step 1: Set Up Your Project

Create a new Rust project and add dependencies:

cargo new embedding-api
cd embedding-api

Add these dependencies to your Cargo.toml:

[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
candle-core = "0.4"
ort = "1.16"
ndarray = "0.15"
tokenizers = "0.15"
reqwest = "0.11"
anyhow = "1.0"

Step 2: Define Your Data Structures

Create a new file src/embeddings.rs and define the request/response types:

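use serde::{Deserialize, Serialize};

// Request body shared by both endpoints: {"text": "..."}
#[derive(Deserialize)]
pub struct EmbeddingRequest {
    pub text: String,
}

// Successful response: the embedding and its length
#[derive(Serialize)]
pub struct EmbeddingResponse {
    pub embedding: Vec<f32>,
    pub dimension: usize,
}

// Error response body: {"error": "..."}
// Fields and structs are pub so handlers in other modules can use them
#[derive(Serialize)]
pub struct ErrorResponse {
    pub error: String,
}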

Step 3: Build the ONNX Embedder (MiniLM)

The MiniLM model runs on ONNX Runtime through the ort crate (this code targets the ort 1.x API pinned in Cargo.toml):

use ndarray::{Array2, CowArray};
use ort::{Environment, ExecutionProvider, Session, SessionBuilder};
use tokenizers::Tokenizer;
use std::sync::Arc;

pub struct Embedder {
    session: Session,
    tokenizer: Arc<Tokenizer>,
}

impl Embedder {
    pub fn new(model_path: &str, tokenizer_path: &str) -> anyhow::Result<Self> {
        println!("Loading ONNX model from: {}", model_path);

        let environment = Environment::builder()
            .with_name("embedder")
            .with_execution_providers([ExecutionProvider::CPU(Default::default())])
            .build()?
            .into_arc();

        let session = SessionBuilder::new(&environment)?
            .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
            .with_intra_threads(4)?
            .with_model_from_file(model_path)?;

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;

        println!("✅ ONNX model loaded!");

        Ok(Self {
            session,
            tokenizer: Arc::new(tokenizer),
        })
    }
}
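The builder above hard-codes with_intra_threads(4). If you would rather derive the count from the host, here is a minimal sketch using only the standard library. The helper intra_thread_count is hypothetical, and the cap you pass in (for example 8) is an arbitrary assumption, not an ONNX Runtime recommendation.

```rust
use std::num::NonZeroUsize;
use std::thread;

/// Hypothetical helper: choose an intra-op thread count for ONNX Runtime.
/// Falls back to 1 if the platform cannot report its parallelism and
/// never returns more than `max_threads` (treated as at least 1).
pub fn intra_thread_count(max_threads: usize) -> usize {
    let available = thread::available_parallelism()
        .map(NonZeroUsize::get)
        .unwrap_or(1);
    available.min(max_threads.max(1))
}
```

You would pass the result to with_intra_threads, converting it to whatever integer type your ort version expects.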

Step 4: Implement the Embedding Logic

Add the embedding generation method to the Embedder:

impl Embedder {
    pub fn embedd(&self, text: String) -> anyhow::Result<Vec<f32>> {
        // 1. Tokenize the input
        let encoding = self.tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;

        let ids = encoding.get_ids();
        let mask = encoding.get_attention_mask();
        let seq_len = ids.len();

        // 2. Prepare inputs as i64
        let input_ids: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
        let attention_mask: Vec<i64> = mask.iter().map(|&x| x as i64).collect();
        let token_type_ids = vec![0i64; seq_len];

        // 3. Create [1, seq_len] arrays and wrap them as owned, dynamically shaped
        //    CowArrays so they outlive the ORT values that borrow them
        let input_ids_arr =
            CowArray::from(Array2::from_shape_vec((1, seq_len), input_ids)?.into_dyn());
        let attention_mask_arr =
            CowArray::from(Array2::from_shape_vec((1, seq_len), attention_mask)?.into_dyn());
        let token_type_ids_arr =
            CowArray::from(Array2::from_shape_vec((1, seq_len), token_type_ids)?.into_dyn());

        // 4. Convert to ORT values
        let input_ids_val = ort::Value::from_array(self.session.allocator(), &input_ids_arr)?;
        let attention_mask_val =
            ort::Value::from_array(self.session.allocator(), &attention_mask_arr)?;
        let token_type_ids_val =
            ort::Value::from_array(self.session.allocator(), &token_type_ids_arr)?;

        // 5. Run inference (input order must match the model's declared inputs)
        let outputs = self.session.run(vec![
            input_ids_val,
            attention_mask_val,
            token_type_ids_val,
        ])?;

        // 6. Extract the last hidden state, shape [1, seq_len, hidden_size]
        let embeddings_tensor = outputs[0].try_extract::<f32>()?;
        let embeddings = embeddings_tensor.view();
        let hidden_size = embeddings.shape()[2]; // 384 for all-MiniLM-L6-v2

        // 7. Mean pooling over real tokens only (attention mask == 1)
        let mut pooled = vec![0.0f32; hidden_size];
        let mut token_count = 0.0f32;
        for token_idx in 0..seq_len {
            if mask[token_idx] == 0 {
                continue; // skip padding positions
            }
            token_count += 1.0;
            for dim_idx in 0..hidden_size {
                pooled[dim_idx] += embeddings[[0, token_idx, dim_idx]];
            }
        }
        for val in pooled.iter_mut() {
            *val /= token_count.max(1.0);
        }

        // 8. Normalize to unit length (epsilon avoids division by zero)
        let length: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
        let normalized: Vec<f32> = pooled.iter().map(|x| x / length).collect();

        Ok(normalized)
    }
}

Key concepts:

  • Tokenization: Converts text to numerical IDs
  • Mean pooling: Averages token embeddings into a single vector
  • Normalization: Scales each vector to unit length, so cosine similarity between two embeddings reduces to a dot product
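A dependency-free sketch of that comparison, with illustrative function names that are not part of any crate: the general cosine formula, plus the dot-product shortcut that is only valid when both vectors are already unit length, as the outputs of both endpoints are.

```rust
/// Full cosine similarity: dot(a, b) / (|a| * |b|). Returns 0.0 for zero vectors.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "vectors must have the same dimension");
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}

/// Shortcut that equals cosine similarity only when both inputs are L2-normalized.
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```

Skipping the two square roots per comparison is why normalizing once, at embedding time, pays off when you compare many vectors.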

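Step 5: Load the EmbeddingGemma Weights

For the Candle-based approach using Google's EmbeddingGemma, load the tokenizer and the raw safetensors weights:

use candle_core::{Device, Tensor, safetensors};
use std::collections::HashMap;
use tokenizers::Tokenizer;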

pub async fn load_model() -> anyhow::Result<(HashMap<String, Tensor>, Tokenizer, Device)> {
    let device = Device::Cpu;
    let model_path = std::path::Path::new("models/embeddgemma");

    let tokenizer_file = model_path.join("tokenizer.json");
    let tokenizer = Tokenizer::from_file(&tokenizer_file)
        .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;

    let model_file = model_path.join("model.safetensors");
    let tensors = safetensors::load(model_file, &device)?;

    println!("Loaded {:?} tensors", tensors.len());
    Ok((tensors, tokenizer, device))
}
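If load_model succeeds but Step 6 later reports that embed_tokens.weight is missing, the first thing to check is which tensor names the file really contains. In the safetensors format, a file starts with an 8-byte little-endian unsigned integer N, followed by N bytes of JSON that map each tensor name to its dtype, shape and data offsets. The hypothetical helpers below read that header using only the standard library (no tensor data is loaded); you could then parse the JSON with serde_json.

```rust
use std::fs::File;
use std::io::{self, Read};
use std::path::Path;

/// Extract the JSON header from in-memory safetensors bytes:
/// an 8-byte little-endian u64 length, then that many bytes of UTF-8 JSON.
pub fn header_from_bytes(bytes: &[u8]) -> io::Result<String> {
    if bytes.len() < 8 {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "file shorter than 8 bytes"));
    }
    let mut len_bytes = [0u8; 8];
    len_bytes.copy_from_slice(&bytes[..8]);
    let header_len = u64::from_le_bytes(len_bytes) as usize;
    let header = bytes[8..].get(..header_len).ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "header length exceeds file size")
    })?;
    String::from_utf8(header.to_vec()).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}

/// Read only the header from a file on disk, leaving the tensor data unread.
pub fn read_safetensors_header(path: &Path) -> io::Result<String> {
    let mut file = File::open(path)?;
    let mut len_bytes = [0u8; 8];
    file.read_exact(&mut len_bytes)?;
    let header_len = u64::from_le_bytes(len_bytes) as usize;
    let mut header = vec![0u8; header_len];
    file.read_exact(&mut header)?;
    String::from_utf8(header).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}
```

For example, read_safetensors_header(Path::new("models/embeddgemma/model.safetensors")) returns the raw JSON, and the top-level keys are the tensor names.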

Step 6: Create the Embedding Generation Function

Note that this function only averages rows of the token-embedding table (embed_tokens.weight); it never runs EmbeddingGemma's transformer layers, so the result is a static, context-independent average of token embeddings rather than the model's full sentence embedding. Treat it as a lightweight approximation.

pub async fn generate_embedding_internal(
    state: &AppState,
    text: String,
) -> Result<Vec<f32>, String> {
    // Tokenize (add_special_tokens = true)
    let encoding = state.tokenizer
        .encode(text, true)
        .map_err(|e| format!("Tokenization error: {}", e))?;
    let tokens: Vec<u32> = encoding.get_ids().to_vec();
    if tokens.is_empty() {
        return Err("Input produced no tokens".to_string());
    }

    // Get the token-embedding matrix, shape [vocab_size, hidden_size]
    let embed_weights = state.tensors
        .get("embed_tokens.weight")
        .ok_or("embed_tokens.weight not found")?;

    // Look up all token rows in one call: shape [num_tokens, hidden_size]
    let token_tensor = Tensor::new(tokens.as_slice(), &state.device)
        .map_err(|e| format!("Failed to create token tensor: {}", e))?;
    let token_embeds = embed_weights
        .index_select(&token_tensor, 0)
        .map_err(|e| format!("Embedding lookup error: {}", e))?;

    // Mean-pool over the token axis: shape [hidden_size]
    let pooled = token_embeds
        .mean(0)
        .map_err(|e| format!("Pooling error: {}", e))?;

    // Convert to Vec<f32> (fails if the weights are not stored as f32)
    let embedding_vec: Vec<f32> = pooled
        .to_vec1::<f32>()
        .map_err(|e| format!("Tensor conversion error: {}", e))?;

    // Normalize to unit length (epsilon avoids division by zero)
    let length: f32 = embedding_vec.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
    let normalized: Vec<f32> = embedding_vec.iter().map(|x| x / length).collect();

    Ok(normalized)
}
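To make the tensor operations concrete, here is the same lookup-and-average written over a plain row-major Vec with no Candle dependency. It is an illustrative sketch with a hypothetical function name: gathering rows by token id corresponds to index_select on dimension 0, and averaging them corresponds to mean(0).

```rust
/// Average the rows of a row-major [vocab_size, hidden_size] embedding table
/// selected by `token_ids`, then L2-normalize the result.
pub fn lookup_and_pool(
    table: &[f32],
    hidden_size: usize,
    token_ids: &[u32],
) -> Result<Vec<f32>, String> {
    if hidden_size == 0 || table.len() % hidden_size != 0 {
        return Err("table length is not a multiple of hidden_size".to_string());
    }
    if token_ids.is_empty() {
        return Err("no tokens to pool".to_string());
    }
    let vocab_size = table.len() / hidden_size;
    let mut pooled = vec![0.0f32; hidden_size];
    for &id in token_ids {
        let id = id as usize;
        if id >= vocab_size {
            return Err(format!("token id {} out of range (vocab size {})", id, vocab_size));
        }
        // Equivalent of index_select: row `id` of the table
        let row = &table[id * hidden_size..(id + 1) * hidden_size];
        for (p, &v) in pooled.iter_mut().zip(row) {
            *p += v;
        }
    }
    // Equivalent of mean(0): divide by the number of tokens
    let n = token_ids.len() as f32;
    for p in pooled.iter_mut() {
        *p /= n;
    }
    let norm = pooled.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
    Ok(pooled.iter().map(|x| x / norm).collect())
}
```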

Step 7: Create API Handlers

// Use axum's own StatusCode: reqwest 0.11 re-exports an older http version
// whose StatusCode does not implement axum 0.7's IntoResponse
use axum::{Json, extract::State, http::StatusCode, response::IntoResponse};

pub async fn embed_mini(
    State(state): State<AppState>,
    Json(request): Json<EmbeddingRequest>,
) -> impl IntoResponse {
    // Inference is CPU-bound and runs on the async worker here; under heavy
    // load, consider moving it into tokio::task::spawn_blocking
    match state.embedder.embedd(request.text) {
        Ok(embedding) => {
            // Read the length before `embedding` is moved into the response
            let dimension = embedding.len();
            (
                StatusCode::OK,
                Json(EmbeddingResponse { embedding, dimension }),
            ).into_response()
        }
        Err(e) => {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: format!("Embedding generation failed: {}", e),
                }),
            ).into_response()
        }
    }
}

pub async fn generate_embedding(
    State(state): State<AppState>,
    Json(request): Json<EmbeddingRequest>,
) -> impl IntoResponse {
    match generate_embedding_internal(&state, request.text).await {
        Ok(embedding) => {
            let dimension = embedding.len();
            (
                StatusCode::OK,
                Json(EmbeddingResponse { embedding, dimension }),
            ).into_response()
        }
        Err(e) => {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: format!("Embedding generation failed: {}", e),
                }),
            ).into_response()
        }
    }
}

Step 8: Set Up Your Main Application

In src/main.rs:

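use axum::{Router, routing::post};
use candle_core::{Device, Tensor};
use std::collections::HashMap;
use std::sync::Arc;
use tokenizers::Tokenizer;

// Assumes Embedder, load_model, embed_mini and generate_embedding are brought
// into scope from your own modules (e.g. with mod/use declarations).

// axum clones the state for every request, so large fields sit behind Arc
#[derive(Clone)]
pub struct AppState {
    pub embedder: Arc<Embedder>,
    pub tensors: Arc<HashMap<String, Tensor>>,
    pub tokenizer: Arc<Tokenizer>,
    pub device: Device,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Load models
    let embedder = Embedder::new(
        "models/minilm/model.onnx",
        "models/minilm/tokenizer.json"
    )?;

    let (tensors, tokenizer, device) = load_model().await?;

    let state = AppState {
        embedder: Arc::new(embedder),
        tensors: Arc::new(tensors),
        tokenizer: Arc::new(tokenizer),
        device,
    };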

    // Create router
    let app = Router::new()
        .route("/embed-mini", post(embed_mini))
        .route("/generate-embedding", post(generate_embedding))
        .with_state(state);

    // Start server
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    println!("🚀 Server running on http://localhost:3000");

    axum::serve(listener, app).await?;
    Ok(())
}
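Axum's State extractor gives every handler invocation its own clone of the state, which is why the large fields in AppState sit behind Arc: a clone copies a pointer and bumps a reference count instead of duplicating a loaded model. A std-only illustration with hypothetical stand-in types (BigModel and SharedState are not part of the tutorial):

```rust
use std::sync::Arc;

/// Stand-in for an expensive resource such as a loaded model.
pub struct BigModel {
    pub weights: Vec<f32>,
}

/// Cheap-to-clone state: cloning copies the Arc pointer, not the weights.
#[derive(Clone)]
pub struct SharedState {
    pub model: Arc<BigModel>,
}

/// Simulate per-request clones and report how many handles exist meanwhile.
pub fn clone_for_requests(state: &SharedState, requests: usize) -> usize {
    let clones: Vec<SharedState> = (0..requests).map(|_| state.clone()).collect();
    // Every clone points at the same allocation
    assert!(clones.iter().all(|c| Arc::ptr_eq(&c.model, &state.model)));
    Arc::strong_count(&state.model)
} // the clones are dropped here, releasing their references
```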

Step 9: Download Your Models

Create a models directory and download:

MiniLM (ONNX):

mkdir -p models/minilm
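# From Hugging Face (sentence-transformers/all-MiniLM-L6-v2), download onnx/model.onnx
# and tokenizer.json into models/minilm/ (Embedder::new expects models/minilm/model.onnx)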

EmbeddingGemma:

mkdir -p models/embeddgemma
# From Hugging Face (google/embeddinggemma-300m), download tokenizer.json and model.safetensors

Step 10: Test Your API

Start the server:

cargo run --release

Test with curl:

# MiniLM endpoint
curl -X POST http://localhost:3000/embed-mini \
  -H "Content-Type: application/json" \
  -d '{"text": "Hello, world!"}'

# EmbeddingGemma endpoint
curl -X POST http://localhost:3000/generate-embedding \
  -H "Content-Type: application/json" \
  -d '{"text": "Rust is amazing!"}'

Performance Tips

  1. Use release mode for production: cargo build --release
  2. Adjust thread count in ONNX builder based on your CPU
  3. Add caching for frequently requested embeddings
  4. Consider GPU support for larger models via ONNX Runtime's CUDA execution provider
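Tip 3 can begin as a bounded map keyed by the exact input text. The sketch below is one possible design under simple assumptions: EmbeddingCache is a hypothetical type, keys must match exactly, failures are not cached, and eviction just clears the map when it is full. In the server you would also need to share it across handlers, for example behind a Mutex inside AppState, and a production cache would more likely use LRU eviction.

```rust
use std::collections::HashMap;

/// Hypothetical exact-match cache from input text to embedding.
pub struct EmbeddingCache {
    map: HashMap<String, Vec<f32>>,
    capacity: usize,
}

impl EmbeddingCache {
    pub fn new(capacity: usize) -> Self {
        EmbeddingCache { map: HashMap::new(), capacity }
    }

    /// Return the cached embedding, or compute, store and return it.
    /// `compute` stands in for a call such as the embedder; errors are not cached.
    pub fn get_or_compute<F>(&mut self, text: &str, compute: F) -> Result<Vec<f32>, String>
    where
        F: FnOnce(&str) -> Result<Vec<f32>, String>,
    {
        if let Some(hit) = self.map.get(text) {
            return Ok(hit.clone());
        }
        let embedding = compute(text)?;
        if self.map.len() >= self.capacity {
            // Crude eviction: drop everything once full
            self.map.clear();
        }
        if self.capacity > 0 {
            self.map.insert(text.to_string(), embedding.clone());
        }
        Ok(embedding)
    }

    pub fn len(&self) -> usize {
        self.map.len()
    }
}
```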

Next Steps

  • Add batch processing support
  • Implement model caching strategies
  • Add metrics and monitoring
  • Deploy with Docker
  • Add authentication
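For the first item, batch processing, the main new ingredient is padding: every sequence in a batch must have the same length, and the attention mask marks which positions are real so pooling can ignore the rest. A std-only sketch, assuming a pad id of 0 (the [PAD] id in the BERT-style vocabulary MiniLM uses; check your own tokenizer, which can also be configured to pad for you). pad_batch is a hypothetical helper whose flat outputs would be reshaped into [batch, max_len] arrays.

```rust
/// Right-pad token id sequences to a common length.
/// Returns (flattened ids, flattened attention mask, max_len), both row-major
/// with shape [batch, max_len], as i64 to match the ONNX inputs in Step 4.
pub fn pad_batch(batch: &[Vec<u32>], pad_id: u32) -> (Vec<i64>, Vec<i64>, usize) {
    let max_len = batch.iter().map(|seq| seq.len()).max().unwrap_or(0);
    let mut ids = Vec::with_capacity(batch.len() * max_len);
    let mut mask = Vec::with_capacity(batch.len() * max_len);
    for seq in batch {
        for pos in 0..max_len {
            match seq.get(pos) {
                // Real token: keep its id and mark it as attended
                Some(&id) => {
                    ids.push(id as i64);
                    mask.push(1);
                }
                // Past the end of this sequence: pad and mask out
                None => {
                    ids.push(pad_id as i64);
                    mask.push(0);
                }
            }
        }
    }
    (ids, mask, max_len)
}
```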

Conclusion

You've built a production-ready embedding API in Rust! This setup gives you:

  • Fast inference with ONNX Runtime
  • Flexible model support (MiniLM, EmbeddingGemma)
  • Type-safe request handling
  • Easy integration with downstream applications

The normalized embeddings are ready for semantic search, clustering, or any similarity-based tasks.
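As an example, semantic search over normalized vectors comes down to ranking by dot product. A brute-force sketch (top_k is a hypothetical helper; a large corpus would call for an approximate nearest-neighbor index instead):

```rust
/// Return the indices and scores of the `k` corpus vectors most similar to
/// `query`, best first. Assumes every vector is L2-normalized, so the dot
/// product equals cosine similarity.
pub fn top_k(query: &[f32], corpus: &[Vec<f32>], k: usize) -> Vec<(usize, f32)> {
    let mut scored: Vec<(usize, f32)> = corpus
        .iter()
        .enumerate()
        .map(|(i, doc)| (i, query.iter().zip(doc).map(|(q, d)| q * d).sum::<f32>()))
        .collect();
    // Sort by descending score; NaN scores are treated as equal
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.truncate(k);
    scored
}
```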

Connect with me on LinkedIn

Questions? Drop them in the comments below!


APA

Mayuresh | Sciencx (2025-10-18T03:12:41+00:00) Building a High-Performance Text Embedding API with Rust, Axum, and ONNX. Retrieved from https://www.scien.cx/2025/10/18/building-a-high-performance-text-embedding-api-with-rust-axum-and-onnx/
