Train, convert and predict with ONNX Runtime

This example demonstrates an end-to-end scenario, from the training of a machine learning model to its use in its converted form.

Train a logistic regression

The first step consists of retrieving the Iris dataset.

from sklearn.datasets import load_iris

iris = load_iris()
X, y = iris.data, iris.target

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y)

Then we fit a model.

from sklearn.linear_model import LogisticRegression

clr = LogisticRegression()
clr.fit(X_train, y_train)
/home/runner/.local/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:444: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
LogisticRegression()


We compute the prediction on the test set and we show the confusion matrix.

from sklearn.metrics import confusion_matrix

pred = clr.predict(X_test)
print(confusion_matrix(y_test, pred))
[[14  0  0]
 [ 0 11  0]
 [ 0  1 12]]

Conversion to ONNX format

We use the sklearn-onnx module to convert the model into ONNX format.

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type)
with open("logreg_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())
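
Optionally, if the onnx package is available, the serialized model can be checked for consistency before loading it with ONNX Runtime. A minimal sketch:

import onnx

# Optional sanity check (assumes the onnx package is installed):
# load the file and validate it with the ONNX checker.
model_proto = onnx.load("logreg_iris.onnx")
onnx.checker.check_model(model_proto)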

We load the model with ONNX Runtime and look at its input and output.

import onnxruntime as rt

sess = rt.InferenceSession("logreg_iris.onnx", providers=rt.get_available_providers())

print("input name='{}' and shape={}".format(sess.get_inputs()[0].name, sess.get_inputs()[0].shape))
print("output name='{}' and shape={}".format(sess.get_outputs()[0].name, sess.get_outputs()[0].shape))
input name='float_input' and shape=[None, 4]
output name='output_label' and shape=[None]
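
The converted classifier actually declares two outputs, the predicted label and the associated probabilities. They can all be enumerated, for example:

for o in sess.get_outputs():
    # Each output exposes its name, shape and element type.
    print(o.name, o.shape, o.type)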

We compute the predictions with ONNX Runtime and compare them with the scikit-learn predictions.

input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name

import numpy

pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
print(confusion_matrix(pred, pred_onx))
[[14  0  0]
 [ 0 12  0]
 [ 0  0 12]]

The predictions are perfectly identical.
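
This can be checked directly, for instance:

# Quick sanity check: both arrays should match element by element.
print(numpy.array_equal(pred, pred_onx))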

Probabilities

Probabilities are needed to compute other relevant metrics such as the ROC Curve. Let’s see how to get them first with scikit-learn.

prob_sklearn = clr.predict_proba(X_test)
print(prob_sklearn[:3])
[[1.90908960e-01 8.02874408e-01 6.21663156e-03]
 [2.81698299e-02 9.12759445e-01 5.90707247e-02]
 [9.67801591e-01 3.21983324e-02 7.64754321e-08]]
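
From these probabilities, a metric such as the area under the ROC curve can be computed with scikit-learn; one possible sketch using the one-vs-rest multiclass setting:

from sklearn.metrics import roc_auc_score

# Illustrative one-vs-rest AUC computed from the predicted class probabilities.
print(roc_auc_score(y_test, prob_sklearn, multi_class="ovr"))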

And then with ONNX Runtime. The probabilities are returned as a list of dictionaries, one per observation, mapping each class index to its probability.

prob_name = sess.get_outputs()[1].name
prob_rt = sess.run([prob_name], {input_name: X_test.astype(numpy.float32)})[0]

import pprint

pprint.pprint(prob_rt[0:3])
[{0: 0.19090914726257324, 1: 0.8028742074966431, 2: 0.0062166303396224976},
 {0: 0.02816985361278057, 1: 0.9127594232559204, 2: 0.059070732444524765},
 {0: 0.9678016304969788, 1: 0.03219832107424736, 2: 7.647536648391906e-08}]
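
Because the converter wraps the classifier's probabilities in a ZipMap operator by default, each row comes back as a dictionary rather than an array. A small sketch to turn the result back into a dense array and compare it with scikit-learn (assuming the Iris class labels 0, 1, 2 as keys):

# Rebuild a dense (n_samples, n_classes) array from the list of dictionaries.
prob_rt_array = numpy.array([[row[c] for c in sorted(row)] for row in prob_rt])
# Any discrepancy should be tiny and come from the float32 computation in ONNX.
print(numpy.abs(prob_rt_array - prob_sklearn).max())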

Let’s benchmark.

from timeit import Timer


def speed(inst, number=10, repeat=20):
    # Time the statement `inst` with timeit: `repeat` measurements of `number` runs each.
    timer = Timer(inst, globals=globals())
    raw = numpy.array(timer.repeat(repeat, number=number))
    # Report the average, minimum and maximum time per single execution.
    ave = raw.sum() / len(raw) / number
    mi, ma = raw.min() / number, raw.max() / number
    print("Average %1.3g min=%1.3g max=%1.3g" % (ave, mi, ma))
    return ave


print("Execution time for clr.predict")
speed("clr.predict(X_test)")

print("Execution time for ONNX Runtime")
speed("sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]")
Execution time for clr.predict
Average 4.44e-05 min=4.24e-05 max=5.4e-05
Execution time for ONNX Runtime
Average 2.21e-05 min=2.16e-05 max=2.71e-05

2.2087500000509407e-05

Let’s benchmark a scenario similar to what a webservice experiences: the model has to do one prediction at a time, as opposed to a batch of predictions.

def loop(X_test, fct, n=None):
    # Call `fct` on a single row at a time, cycling through X_test for n predictions.
    nrow = X_test.shape[0]
    if n is None:
        n = nrow
    for i in range(0, n):
        im = i % nrow
        fct(X_test[im : im + 1])


print("Execution time for clr.predict")
speed("loop(X_test, clr.predict, 100)")


def sess_predict(x):
    return sess.run([label_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for sess_predict")
speed("loop(X_test, sess_predict, 100)")
Execution time for clr.predict
Average 0.00406 min=0.00404 max=0.00414
Execution time for sess_predict
Average 0.00104 min=0.00102 max=0.00107

0.00103513229499967

Let’s do the same for the probabilities.

print("Execution time for predict_proba")
speed("loop(X_test, clr.predict_proba, 100)")


def sess_predict_proba(x):
    return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba, 100)")
Execution time for predict_proba
Average 0.0061 min=0.00608 max=0.00626
Execution time for sess_predict_proba
Average 0.00108 min=0.00107 max=0.0011

0.001081757285000009

This second comparison is fairer, as ONNX Runtime computes both the label and the probabilities on every call in this experiment.
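
In a single ONNX Runtime call, all declared outputs can be retrieved at once by passing None as the list of requested outputs; for example:

# Passing None requests every output: here the label and the probabilities.
label_onx, prob_onx = sess.run(None, {input_name: X_test[:1].astype(numpy.float32)})
print(label_onx, prob_onx)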

Benchmark with RandomForest

We first train and save a model in ONNX format.

from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier()
rf.fit(X_train, y_train)

initial_type = [("float_input", FloatTensorType([1, 4]))]
onx = convert_sklearn(rf, initial_types=initial_type)
with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())

We compare.

sess = rt.InferenceSession("rf_iris.onnx", providers=rt.get_available_providers())


def sess_predict_proba_rf(x):
    return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]


print("Execution time for predict_proba")
speed("loop(X_test, rf.predict_proba, 100)")

print("Execution time for sess_predict_proba")
speed("loop(X_test, sess_predict_proba_rf, 100)")
Execution time for predict_proba
Average 0.674 min=0.672 max=0.678
Execution time for sess_predict_proba
Average 0.0013 min=0.00129 max=0.00133

0.001301839815000676

Let’s see how the comparison evolves with a different number of trees.

measures = []

for n_trees in range(5, 51, 5):
    print(n_trees)
    rf = RandomForestClassifier(n_estimators=n_trees)
    rf.fit(X_train, y_train)
    initial_type = [("float_input", FloatTensorType([1, 4]))]
    onx = convert_sklearn(rf, initial_types=initial_type)
    with open("rf_iris_%d.onnx" % n_trees, "wb") as f:
        f.write(onx.SerializeToString())
    sess = rt.InferenceSession("rf_iris_%d.onnx" % n_trees, providers=rt.get_available_providers())

    def sess_predict_proba_loop(x):
        return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0]

    tsk = speed("loop(X_test, rf.predict_proba, 100)", number=5, repeat=5)
    trt = speed("loop(X_test, sess_predict_proba_loop, 100)", number=5, repeat=5)
    measures.append({"n_trees": n_trees, "sklearn": tsk, "rt": trt})

from pandas import DataFrame

df = DataFrame(measures)
ax = df.plot(x="n_trees", y="sklearn", label="scikit-learn", c="blue", logy=True)
df.plot(x="n_trees", y="rt", label="onnxruntime", ax=ax, c="green", logy=True)
ax.set_xlabel("Number of trees")
ax.set_ylabel("Prediction time (s)")
ax.set_title("Speed comparison between scikit-learn and ONNX Runtime\nFor a random forest on Iris dataset")
ax.legend()
Figure: Speed comparison between scikit-learn and ONNX Runtime for a random forest on the Iris dataset.
5
Average 0.0491 min=0.049 max=0.0491
Average 0.00103 min=0.00102 max=0.00105
10
Average 0.0823 min=0.0822 max=0.0824
Average 0.00104 min=0.00102 max=0.00106
15
Average 0.115 min=0.115 max=0.115
Average 0.00103 min=0.00102 max=0.00106
20
Average 0.148 min=0.148 max=0.148
Average 0.00105 min=0.00104 max=0.00108
25
Average 0.181 min=0.181 max=0.182
Average 0.00107 min=0.00106 max=0.00109
30
Average 0.214 min=0.214 max=0.214
Average 0.00107 min=0.00106 max=0.00109
35
Average 0.247 min=0.247 max=0.248
Average 0.00109 min=0.00108 max=0.00111
40
Average 0.28 min=0.279 max=0.28
Average 0.00109 min=0.00108 max=0.00112
45
Average 0.312 min=0.311 max=0.312
Average 0.00112 min=0.00111 max=0.00115
50
Average 0.345 min=0.344 max=0.345
Average 0.00112 min=0.00111 max=0.00114

<matplotlib.legend.Legend object at 0x7f0fd5bd2a40>

Total running time of the script: ( 3 minutes 8.150 seconds)
