From 6a5588fba307562a372a702f60125e18369de8a1 Mon Sep 17 00:00:00 2001
From: Maarten Van Veghel
Date: Tue, 9 Dec 2025 16:30:38 +0100
Subject: [PATCH] add test for dynamic model input example like for FT or
 isolation
---
 .../python/Hlt2Conf/lines/test/spruce_test.py | 60 ++++++++++++++++++-
 ...est_ml_selections_onnxruntime_line_test.py |  2 +
 ...est_ml_selections_pytorch_model_to_onnx.py | 38 ++++++++++++
 3 files changed, 99 insertions(+), 1 deletion(-)

diff --git a/Hlt/Hlt2Conf/python/Hlt2Conf/lines/test/spruce_test.py b/Hlt/Hlt2Conf/python/Hlt2Conf/lines/test/spruce_test.py
index fb7e97ee984..be007682577 100644
--- a/Hlt/Hlt2Conf/python/Hlt2Conf/lines/test/spruce_test.py
+++ b/Hlt/Hlt2Conf/python/Hlt2Conf/lines/test/spruce_test.py
@@ -19,7 +19,12 @@ from Functors.math import in_range
 from GaudiKernel.SystemOfUnits import GeV, MeV, mm
 from GaudiKernel.SystemOfUnits import micrometer as um
 from Moore.config import Hlt2Line, SpruceLine
-from PyConf.Algorithms import MLServiceAlg, MLServiceAlg_Relation1D, ParticleTaggerAlg
+from PyConf.Algorithms import (
+    MLServiceAlg,
+    MLServiceAlg_2DTo1D_Relation1D,
+    MLServiceAlg_Relation1D,
+    ParticleTaggerAlg,
+)
 from RecoConf.algorithms_thor import ParticleCombiner, ParticleFilter
 from RecoConf.event_filters import require_gec, require_pvs
 from RecoConf.reconstruction_objects import make_pvs, upfront_reconstruction
@@ -460,6 +465,8 @@ def filter_v0s_mlservice_relations(v0s, pions, fileloc):
     """
     Version where you have a more complicated setup (than just one (main) Particle)
     of input needed for your input features. Examples are flavour tagging or isolation.
+    In this case the model is evaluated for every relation individually.
+    For inference on a whole set of relations at once, see below.
     """
     relations = ParticleTaggerAlg(
         Input=v0s,
         TaggingContainer=pions,
     ).OutputRelations
@@ -491,6 +498,47 @@ def filter_v0s_mlservice_relations(v0s, pions, fileloc):
     )


+def filter_v0s_mlservice_dynamic_input_with_relations(v0s, pions, fileloc):
+    """
+    Version where you have a more complicated setup (than just one (main) Particle)
+    of input needed for your input features. Examples are flavour tagging or isolation.
+    Unlike the version above, the model is not run once per relation in the input table.
+    Instead, each Particle (FROM) can be related to an unknown number of Particles (TO),
+    every TO object contributes a fixed set of features, and the model infers a single
+    output (e.g. a classifier score) per FROM object from all of its relations.
+    """
+    relations = ParticleTaggerAlg(
+        Input=v0s,
+        TaggingContainer=pions,
+    ).OutputRelations
+    # Relation table is FROM (v0), TO (pions), so the first argument is FORWARDARG0
+    mva_inputs = F.GATHER(
+        Inputs=[
+            F.PT @ F.FORWARDARG0,
+            F.OWNPVIPCHI2 @ F.FORWARDARG0,
+            F.CHI2DOF @ F.FORWARDARG0,
+            F.MAXSDOCA @ F.FORWARDARG0,
+        ]
+    )
+    mva_alg = MLServiceAlg_2DTo1D_Relation1D(
+        MLService="ONNXRuntimeSvc",
+        InputFrom=v0s,
+        InputRelations=relations,
+        MVAInput=mva_inputs,
+        ModelFile=fileloc,
+    )
+    mva_val = (
+        F.VALUE_OR(0)
+        @ F.ELEMENT_AT_INDEX(Index=0)
+        @ F.MAP_TO_RELATED(Relations=mva_alg.OutputRelations)
+    )
+    return ParticleFilter(
+        v0s,
+        Cut=F.FILTER(mva_val > 0.5),
+        name="ParticleFilter_ONNX_MLServiceAlg_2DTo1D_Relation1D",
+    )
+
+
 def make_v0_ml_filter_test_line(
     name, filter_v0s, model_file="file://onnxruntime_test.onnx", prescale=1
 ):
@@ -537,3 +585,13 @@ def Test_v0s_with_ML_ONNXRuntime_MLServiceAlg_Relations_line(
     return make_v0_ml_filter_test_line(
         name=name, filter_v0s=filter_v0s_mlservice_relations
     )
+
+
+def Test_v0s_with_ML_ONNXRuntime_MLServiceAlg_dynamic_input_with_Relations_line(
+    name="SpruceTest_Ks2PiPi_filter_ONNX_with_MLServiceAlg_dynamic_input_with_relations",
+):
+    return make_v0_ml_filter_test_line(
+        name=name,
+        filter_v0s=filter_v0s_mlservice_dynamic_input_with_relations,
+        model_file="file://onnxruntime_dynamic_axes_test.onnx",
+    )
diff --git a/Hlt/Hlt2Conf/tests/options/test_ml_selections_onnxruntime_line_test.py b/Hlt/Hlt2Conf/tests/options/test_ml_selections_onnxruntime_line_test.py
index 05b90f1dfb7..543c5cb76b4 100644
--- a/Hlt/Hlt2Conf/tests/options/test_ml_selections_onnxruntime_line_test.py
+++ b/Hlt/Hlt2Conf/tests/options/test_ml_selections_onnxruntime_line_test.py
@@ -13,6 +13,7 @@ per example testing as Sprucing lines on topo{2,3} persistreco hlt2 output (use
 """

 from Hlt2Conf.lines.test.spruce_test import (
+    Test_v0s_with_ML_ONNXRuntime_MLServiceAlg_dynamic_input_with_Relations_line,
     Test_v0s_with_ML_ONNXRuntime_MLServiceAlg_line,
     Test_v0s_with_ML_ONNXRuntime_MLServiceAlg_Relations_line,
     Test_v0s_with_ML_ONNXRuntime_MVA_functor_line,
@@ -25,6 +26,7 @@ def make_lines():
         Test_v0s_with_ML_ONNXRuntime_MVA_functor_line(),
         Test_v0s_with_ML_ONNXRuntime_MLServiceAlg_line(),
         Test_v0s_with_ML_ONNXRuntime_MLServiceAlg_Relations_line(),
+        Test_v0s_with_ML_ONNXRuntime_MLServiceAlg_dynamic_input_with_Relations_line(),
     ]


diff --git a/Hlt/Hlt2Conf/tests/options/test_ml_selections_pytorch_model_to_onnx.py b/Hlt/Hlt2Conf/tests/options/test_ml_selections_pytorch_model_to_onnx.py
index e048d7505a2..24d7e0572a1 100644
--- a/Hlt/Hlt2Conf/tests/options/test_ml_selections_pytorch_model_to_onnx.py
+++ b/Hlt/Hlt2Conf/tests/options/test_ml_selections_pytorch_model_to_onnx.py
@@ -78,6 +78,26 @@ for epoch in range(n_epochs):
 predicted_classes = (preds > 0.5).float()
 accuracy = (predicted_classes == y_train).float().mean()

+
+# dynamic version of the dummy classifier: one score per set of objects
+class SimpleObjectSetClassifier(nn.Module):
+    def __init__(self, base_model):
+        super().__init__()
+        self.base_model = base_model
+
+    def forward(self, x):
+        # x: [B, N, 4] -> batch of B candidates, each with N objects of 4 features
+        B, N, F = x.shape
+        x = x.view(B * N, F)
+        # evaluate the dummy per-object classifier for a dynamic number of objects
+        logits = self.base_model(x)
+        logits = logits.view(B, N, 1)
+        # the final output is the average of the per-object scores
+        return logits.mean(dim=1)
+
+
+dynamic_model = SimpleObjectSetClassifier(model)
+
 # Export to ONNX
 model.eval()
 dummy_input = torch.randn(1, 4)  # Batch size 1, 4 input features
@@ -94,3 +114,21 @@ torch.onnx.export(
     output_names=["output"],
     dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
 )
+
+dynamic_model.eval()
+dummy_dynamic_input = torch.randn(1, 2, 4)  # Batch size 1, 2 objects, 4 input features
+onnx_dynamic_file = "onnxruntime_dynamic_axes_test.onnx"
+torch.onnx.export(
+    dynamic_model,
+    dummy_dynamic_input,
+    onnx_dynamic_file,
+    export_params=True,
+    opset_version=11,
+    do_constant_folding=True,
+    input_names=["input"],
+    output_names=["output"],
+    dynamic_axes={
+        "input": {0: "batch_size", 1: "objects_size"},
+        "output": {0: "batch_size"},
+    },
+)
--
GitLab
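
Note (not part of the patch): a minimal sketch of how the exported dynamic-axes model could be sanity-checked standalone with the onnxruntime Python API before running the Moore/Sprucing test. The file name and the "input"/"output" tensor names follow the export above; the object multiplicities looped over here are purely illustrative.

    import numpy as np
    import onnxruntime as ort

    # load the dynamic-axes model written by test_ml_selections_pytorch_model_to_onnx.py
    session = ort.InferenceSession("onnxruntime_dynamic_axes_test.onnx")

    # the second input axis ("objects_size") is dynamic, so any number of related
    # TO objects should be accepted by the same model
    for n_objects in (1, 2, 5):
        features = np.random.rand(1, n_objects, 4).astype(np.float32)
        (scores,) = session.run(["output"], {"input": features})
        # one averaged score per FROM candidate, independent of the object count
        assert scores.shape == (1, 1)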