Porting ML Models to Java Using ONNX

Hello everyone, my name is Evgeniy Munin. I am a Senior ML Engineer in Ad Tech, working on a bidding platform for Web advertising, and the author of the TG channel ML Advertising. Today I will tell you how we use ML models on advertising platforms with a JVM backend.

How do bidding platforms work?

The chain of advertising delivery by request, from advertisers through bidding-platform intermediaries to the user

It all starts when a user opens a site that contains advertising slots. At this moment, the publisher who owns the site logs the user's visit and sends a request to an aggregator, the so-called Prebid. It, in turn, contacts the Supply-Side Platform (SSP), which holds an auction to sell the advertising slot. The auction participants evaluate the request to see whether it meets the criteria of their clients' advertising campaigns and, if so, what it is worth to the advertiser. Each participant then sends back a response with a bid. The winner of the auction places its client's creative on the purchased advertising slot for display.

Bidding platforms process user requests, filter them, and make recommendations in near real time. Often the entire process, from the request to the placement of the creative on the site, must take no more than 150 ms, i.e. roughly the time it takes a user with a good Internet connection to load a web page. To meet this requirement, the platform backends are usually written in JVM languages (in my practice, Java and Scala) or in Rust.

So where is the problem, you might ask. The problem is that the platform, while processing an incoming user request, must apply ML models to it, which, for example, filter fraud, cut off low-quality traffic, or tune monetization. The ML models that are supposed to run on bidding platforms are in most cases written in Python frameworks: Sklearn, PyTorch, etc. (except perhaps SparkML, which has wrappers in both Python and Scala), and these libraries rarely bother with compatibility with other languages.

To resolve this issue, a unified format was developed: Open Neural Network Exchange (abbreviated ONNX), into which ML models from different libraries can be exported and then used on various platforms, including JVM-based ones.

Today we will look at the example of a simple logistic regression model in Sklearn, which predicts the probability of a bid on an advertising platform. To convert it to the ONNX format we will use the skl2onnx library, and to run the model on the JVM we will use ONNX Runtime and its Java binding.

Let's write a logistic regression model

First, we need to write the model pipeline itself. Here we will use sklearn.pipeline, which will include two stages:

  • FeatureHasher to hash the input features as follows: hash_feat = MurMurHash3(feat) % hash_size (see the sketch after this list)

  • LogisticRegression, a model that predicts the likelihood of a bid being placed on an ad request
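
To make the hashing stage concrete, here is a rough sketch of what FeatureHasher does for a single string feature, using sklearn's murmurhash3_32. The token "val10" is just an illustrative value, and the sign handling mimics the default alternate_sign=True (this is an approximation of the internals, not the exact Cython code):

from sklearn.utils import murmurhash3_32

n_features = 2**5  # hash_size: dimensionality of the hashed feature space

# roughly what FeatureHasher does per string token
h = murmurhash3_32("val10", seed=0)
index = abs(h) % n_features      # the bucket this feature lands in
value = 1.0 if h >= 0 else -1.0  # the hash sign is used to reduce collision bias
print(index, value)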

import numpy as np
from sklearn.feature_extraction import FeatureHasher
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

class GammaPipeline(Pipeline):
  def __init__(self):
    super().__init__([
        # hash raw string features into a 2**5-dimensional sparse vector
        ("hasher", FeatureHasher(n_features=2**5, input_type="string", dtype=np.float32)),
        ("logreg",
            LogisticRegression(
                max_iter=10,
                random_state=42,
                penalty="l2",
                solver="lbfgs",
                C=1.0,
                tol=1e-4,
            ),
        ),
    ], verbose=True)

Once we have defined the pipeline, let's create a small sample of data.

import datetime

import pandas as pd

data = {
    "timestamp": [datetime.datetime.now() for _ in range(3)],
    "feature1": ["val10", "val11", "val12"],
    "feature2": ["val20", "val21", "val22"],
    "target": [True, False, False],
}

df = pd.DataFrame(data)

# FeatureHasher with input_type="string" expects rows of raw strings,
# so only the two string features go into x (the timestamp is not a model feature)
x = df.drop(columns=["timestamp", "target"])
y = df["target"]

Let's fit the pipeline and display its structure:

pipeline = GammaPipeline()
pipeline.fit(x.values, y)
Pipeline structure

Serializing the ONNX predictor

Now that the pipeline is trained, we will convert it into the ONNX format. Here we need to define initial_type: the type of the input features and their number per sample. It is also worth disabling zipmap for the logistic regression class labels, since we want the model to output a plain vector of probabilities rather than a list of dictionaries keyed by class label.

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import StringTensorType

# a single input named "input": a batch of string rows, one column per feature
initial_type = [
    ('input', StringTensorType([None, x.shape[1]])),
]

# output probabilities as a plain float matrix instead of a list of dicts
options = {LogisticRegression: {'zipmap': False}}

onnx_model = convert_sklearn(
    pipeline,
    initial_types=initial_type,
    options=options
)

with open(path_data + "models/onnx_log_reg.onnx", "wb") as f:
  f.write(onnx_model.SerializeToString())
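
Before moving to the JVM, it's worth sanity-checking the serialized predictor with the Python binding of ONNX Runtime. A minimal sketch, assuming the onnxruntime package is installed (the row of strings is one of our training samples):

import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession(path_data + "models/onnx_log_reg.onnx")
# feed a batch of string rows under the input name declared in initial_type
labels, probas = sess.run(None, {"input": np.array([["val10", "val20"]])})
print(labels, probas)  # with zipmap disabled, probas is a plain float matrix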

Once the model is written to ONNX, we can visualize it using netron.app.

What the pipeline looks like under the hood. Everything to the left of the LinearClassifier is the feature hashing stage

Reading the ONNX predictor and running it in Java

Now that the model is written in ONNX format, we can deserialize it on the JVM. In this example, I used Java 17.

First, let's create an OrtEnvironment and open a prediction session using the OrtSession class, passing it the path to the .onnx predictor. Once the predictor is loaded, we run inference with session.run. Since the model expects a structure of the form Map<String, OnnxTensor>, the map key must match the input name declared in the pipeline ("input" in our case), and the value must be an OnnxTensor.

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.util.Collections;

public class OnnxModelRunner {
    private OrtSession session;
    private OrtEnvironment environment;

    public OnnxModelRunner(String modelPath) throws OrtException {
        environment = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        session = environment.createSession(modelPath, options);
    }

    public OrtSession.Result runModel(String[][] bidders) throws OrtException {
        // the tensor of string features is fed under the input name "input"
        OnnxTensor inputTensor = OnnxTensor.createTensor(environment, bidders);
        return session.run(Collections.singletonMap("input", inputTensor));
    }
}

After we have defined the OnnxModelRunner class, we can run a prediction. Let's prepare the features, create a modelRunner object, and run inference. Here we also write the extraction of the probabilities from the model output.

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.util.Arrays;
import java.util.Objects;
import java.util.stream.StreamSupport;

public class ApplicationOnnx {

    public static void main(String[] args) throws OrtException {

        String[][] bidders = {
            {"val11", "val21"},
            {"val12", "val22"},
            {"val13", "val23"}
        };

        OnnxModelRunner modelRunner = new OnnxModelRunner("models/onnx_log_reg.onnx");

        OrtSession.Result results = modelRunner.runModel(bidders);

        // keep only the "probabilities" output and unpack it into a float matrix
        StreamSupport.stream(results.spliterator(), false)
                .filter(onnxItem -> Objects.equals(onnxItem.getKey(), "probabilities"))
                .forEach(onnxItem -> {
                    OnnxValue onnxValue = onnxItem.getValue();
                    OnnxTensor tensor = (OnnxTensor) onnxValue;
                    try {
                        float[][] probas = (float[][]) tensor.getValue();
                        System.out.println(
                                "    tensor.getValue(): " + tensor.getValue() +
                                        "\n    probas: " + Arrays.deepToString(probas)
                        );
                    } catch (OrtException e) {
                        throw new RuntimeException(e);
                    }
                });
    }
}

In OrtSession.Result we have a map of two elements:

  • for labels: OnnxTensor(info=TensorInfo(javaType=INT64,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,shape=[3]))

  • for probabilities: OnnxTensor(info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[3, 2]))

Let's filter the map by the key probabilities, selecting only the probabilities. The values have the type OnnxValue, which we first cast to OnnxTensor and then to an array of floats float[][]. The model output looks like this:

Output: probabilities: OnnxTensor(info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[3, 2]))
    tensor.getValue(): [[F@6438a396
    probas: [[0.29334846, 0.70665157], [0.853327, 0.14667302], [0.853327, 0.14667302]]
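
As a sanity check, these numbers can be compared against the original Python pipeline on the same rows. A quick sketch, run in the Python session where pipeline was fitted:

# the same three rows that were fed to the Java runner
sample = [["val11", "val21"], ["val12", "val22"], ["val13", "val23"]]
print(pipeline.predict_proba(sample))  # should match the JVM probas up to float precision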

Finally

We looked at a simple example of how to port a Sklearn pipeline to ONNX, then read it and run predictions in Java. Naturally, depending on the complexity of the tasks, new questions will arise, for example:

  • How to export a pipeline with custom transformations to ONNX, and how to define its shape_calculator and converter

  • What if you want to preserve the types of different input columns instead of casting them all to a single string or float type?

  • How to run ONNX model predictions on a GPU (see the sketch below)
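
For the last question, the Python binding at least shows the entry point: ONNX Runtime lets you request execution providers when creating a session. A minimal sketch, assuming a CUDA-enabled onnxruntime build is installed:

import onnxruntime as rt

# ask for the CUDA provider first, falling back to CPU if it is unavailable
sess = rt.InferenceSession(
    path_data + "models/onnx_log_reg.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
print(sess.get_providers())  # shows which providers were actually attached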

But this example is already enough to solve the basic compatibility problem between ML frameworks and JVM platforms.

If you are interested in the topic of ML models in advertising platforms and how to ship models to production, check out my TG channel ML Advertising!

Thank you for reading to the end!
