High-performance NLP pipelines in Rust

The main advantage of rust-bert is that it fits natively into the Rust ecosystem: you get ready-made transformer pipelines for question answering, translation, summarization, text generation and classification without leaving Rust.

Setting up the environment

To use rust-bert you will need Libtorch, the C++ distribution of PyTorch that the library uses for deep learning. You can rely on automatic installation, but manual setup gives you complete control, especially if you need to use the GPU.

Manual installation of Libtorch:

  1. Download Libtorch:

    wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcu124.zip
  2. Unpack the archive:

    unzip libtorch-cxx11-abi-shared-with-deps-2.4.0+cu124.zip -d /path/to/libtorch
  3. Setting environment variables:

    Linux/macOS:

    export LIBTORCH=/path/to/libtorch
    export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH

    Windows (PowerShell):

    $Env:LIBTORCH="C:\path\to\libtorch"
    $Env:Path += ";C:\path\to\libtorch\lib"

If you don’t want to bother with manual configuration, you can let the library download Libtorch automatically via the download-libtorch feature flag. This is convenient for CPU-only builds, but for CUDA you will also need to specify the CUDA version via the TORCH_CUDA_VERSION environment variable.

Setting up the automatic download in Cargo.toml:

[dependencies]
rust-bert = { version = "0.23", features = ["download-libtorch"] }

Then, for a CUDA build, set the CUDA version before building:

export TORCH_CUDA_VERSION=cu124
cargo build

Rust-bert also supports model caching: downloaded models are stored locally in ~/.cache/.rustbert. If you need a different location, set the RUSTBERT_CACHE environment variable.

Key Features

Question answering models

One of the most useful features is extractive question answering: you provide a question and a context, and the model finds the answer and reports its exact span (start and end positions) in the text.

use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};

fn main() -> anyhow::Result<()> {
    // Loads the default extractive QA model (DistilBERT fine-tuned on SQuAD)
    let qa_model = QuestionAnsweringModel::new(Default::default())?;
    let question = String::from("Where does Amy live?");
    let context = String::from("Amy lives in Amsterdam.");
    // Keep the single best answer per input, with a batch size of 32
    let answers = qa_model.predict(&[QaInput { question, context }], 1, 32);
    println!("{:?}", answers);
    Ok(())
}

The model finds “Amsterdam” with a high confidence score.
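
To use the result programmatically rather than Debug-printing it, you can read the individual fields of each answer. Below is a minimal sketch, assuming the returned Answer struct exposes answer, score, start and end fields:

use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};

fn main() -> anyhow::Result<()> {
    let qa_model = QuestionAnsweringModel::new(Default::default())?;
    let question = String::from("Where does Amy live?");
    let context = String::from("Amy lives in Amsterdam.");
    // predict returns one Vec of ranked answers per input
    let answers = qa_model.predict(&[QaInput { question, context }], 1, 32);
    for answer in answers.into_iter().flatten() {
        // Fields assumed: the answer text, its score, and the span boundaries in the context
        println!(
            "{} (score: {:.3}, span: {}..{})",
            answer.answer, answer.score, answer.start, answer.end
        );
    }
    Ok(())
}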

Text translation models

Rust-bert supports translation models such as Marian and M2M100.

use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};

fn main() -> anyhow::Result<()> {
    // Build a model for a single English -> Russian language pair
    let model = TranslationModelBuilder::new()
        .with_source_languages(vec![Language::English])
        .with_target_languages(vec![Language::Russian])
        .create_model()?;

    let input_text = "This is a test sentence.";
    let output = model.translate(&[input_text], None, Language::Russian)?;
    for sentence in output {
        println!("{}", sentence);
    }
    Ok(())
}
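
The builder can also be asked for several target languages at once. A hedged sketch, assuming that when no single Marian pair covers all requested languages the builder falls back to a multilingual model such as M2M100, and the target language is then chosen per translate() call:

use rust_bert::pipelines::translation::{Language, TranslationModelBuilder};

fn main() -> anyhow::Result<()> {
    // Several target languages from one model; the concrete model choice is left to the builder
    let model = TranslationModelBuilder::new()
        .with_source_languages(vec![Language::English])
        .with_target_languages(vec![Language::Russian, Language::German])
        .create_model()?;

    // Pick the target language at call time
    let output = model.translate(&["This is a test sentence."], None, Language::German)?;
    for sentence in output {
        println!("{}", sentence);
    }
    Ok(())
}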

Abstractive summarization

Abstractive summarization reduces long texts to short but meaningful summaries.

use rust_bert::pipelines::summarization::SummarizationModel;

fn main() -> anyhow::Result<()> {
    // Loads the default BART-based summarization model (trained on CNN/DailyMail)
    let summarization_model = SummarizationModel::new(Default::default())?;
    let input = ["Scientists have discovered water in the atmosphere of the planet K2-18b..."];
    let output = summarization_model.summarize(&input);
    println!("{:?}", output);
    Ok(())
}
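
Generation behaviour can be tuned instead of relying on the defaults. A hedged sketch, assuming SummarizationConfig exposes the usual generation fields such as min_length and num_beams:

use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};

fn main() -> anyhow::Result<()> {
    // Assumed fields: min_length and num_beams, mirroring the generic generation config
    let config = SummarizationConfig {
        min_length: 30,
        num_beams: 4,
        ..Default::default()
    };
    let summarization_model = SummarizationModel::new(config)?;
    let input = ["Scientists have discovered water in the atmosphere of the planet K2-18b..."];
    let output = summarization_model.summarize(&input);
    println!("{:?}", output);
    Ok(())
}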

Text generation

The library supports GPT-2 and GPT.

use rust_bert::pipelines::text_generation::TextGenerationModel;

fn main() -> anyhow::Result<()> {
    // Loads the default GPT-2 model
    let model = TextGenerationModel::new(Default::default())?;
    let input_context = "One fine day,";
    // Generate a continuation of the prompt (no extra prefix)
    let output = model.generate(&[input_context], None);
    println!("{:?}", output);
    Ok(())
}
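
Decoding can be switched from the default strategy to sampling. A hedged sketch, assuming TextGenerationConfig exposes the usual generation fields such as do_sample, temperature and num_return_sequences:

use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};

fn main() -> anyhow::Result<()> {
    // Assumed fields mirroring the generic generation config
    let config = TextGenerationConfig {
        do_sample: true,
        temperature: 1.1,
        num_return_sequences: 3,
        ..Default::default()
    };
    let model = TextGenerationModel::new(config)?;
    let output = model.generate(&["One fine day,"], None);
    println!("{:?}", output);
    Ok(())
}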

Zero-shot classification

Zero-shot classification lets a model assign text to any set of candidate labels without additional training for those specific classes.

use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;

fn main() -> anyhow::Result<()> {
    // Loads the default zero-shot model (BART fine-tuned on MNLI)
    let model = ZeroShotClassificationModel::new(Default::default())?;
    let input_sentence = "The weather is sunny today.";
    let candidate_labels = vec!["weather", "sports", "politics"];
    // Score the input against each candidate label, truncating inputs to 128 tokens
    let output = model.predict(&[input_sentence], &candidate_labels, None, 128);
    println!("{:?}", output);
    Ok(())
}
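
For texts that may match several labels at once, there is also a multi-label variant. A hedged sketch, assuming a predict_multilabel method with the same signature as predict that scores every candidate label independently:

use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;

fn main() -> anyhow::Result<()> {
    let model = ZeroShotClassificationModel::new(Default::default())?;
    let input_sentence = "The match was cancelled because of heavy rain.";
    let candidate_labels = vec!["weather", "sports", "politics"];
    // Assumed: each label is scored independently rather than normalised across labels
    let output = model.predict_multilabel(&[input_sentence], &candidate_labels, None, 128);
    println!("{:?}", output);
    Ok(())
}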

Working with custom models and ONNX

If the standard models are not suitable, rust-bert lets you load custom models in PyTorch and ONNX formats. Weights are exported from PyTorch with torch.save, after which they can be converted into a format rust-bert can load.

Example of loading a custom model:

use rust_bert::bert::{BertConfig, BertEmbeddings, BertModel};
use rust_bert::Config;
use tch::nn::VarStore;

fn main() -> anyhow::Result<()> {
    let mut vs = VarStore::new(tch::Device::Cpu);
    // Build the graph from the model's configuration file (path is illustrative)
    let config = BertConfig::from_file("path/to/config.json");
    let model: BertModel<BertEmbeddings> = BertModel::new(&vs.root(), &config);
    // Load weights previously converted from the PyTorch checkpoint to the libtorch-compatible .ot format
    vs.load("path/to/bert_model_weights.ot")?;
    Ok(())
}

Export to ONNX via Hugging Face Optimum:

from optimum.onnxruntime import ORTModelForQuestionAnswering

# Export the Hugging Face checkpoint to ONNX and save it locally
model = ORTModelForQuestionAnswering.from_pretrained("bert-base-uncased", export=True)
model.save_pretrained("path_to_save_onnx")

ONNX allows you to optimize inference on GPU and CPU, making the model faster and more efficient for high-load applications.

You can read more about rust-bert here.

