From 524565d8f8b22629c58e5c33a7766766e14b42d0 Mon Sep 17 00:00:00 2001 From: Nils Faltermann Date: Wed, 15 Nov 2023 14:56:29 +0100 Subject: [PATCH] fixing syst shifts --- src/ml.cxx | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/ml.cxx b/src/ml.cxx index 22d69675..d3d844f4 100644 --- a/src/ml.cxx +++ b/src/ml.cxx @@ -156,6 +156,7 @@ namespace sofie { std::string modelHeaderFile = modelName + std::string(".hxx"); + if (!std::filesystem::exists(modelHeaderFile)) { Logger::get("SOFIEGenerator") ->debug("generating model code..."); model.Generate(); @@ -171,9 +172,11 @@ namespace sofie { Logger::get("SOFIEGenerator") ->debug("compiling model code..."); CompileModelForRDF(modelHeaderFile, input_vec.size()); - - Logger::get("SOFIEGenerator") - ->debug("evaluating model code..."); + } + else { + Logger::get("SOFIEGenerator") + ->debug("model already compiled, skipping"); + } std::string sofie_func_str = "sofie_functor(rdfslot_, "; sofie_func_str += input_vec[0]; @@ -199,7 +202,11 @@ ROOT::RDF::RNode KerasEvaluate(ROOT::RDF::RNode df, Logger::get("KerasEvaluate") ->debug("finished loading model"); - auto df2 = df.Define(outputname, SOFIEGenerator(input_vec, model, modelFilePath)); + auto eval_func = SOFIEGenerator(input_vec, model, modelFilePath); + Logger::get("KerasEvaluate") + ->debug("evaluating model code..."); + + auto df2 = df.Define(outputname, eval_func); return df2; @@ -223,7 +230,11 @@ ROOT::RDF::RNode PyTorchEvaluate(ROOT::RDF::RNode df, Logger::get("PyTorchEvaluate") ->debug("finished loading model"); - auto df2 = df.Define(outputname, SOFIEGenerator(input_vec, model, modelFilePath)); + auto eval_func = SOFIEGenerator(input_vec, model, modelFilePath); + Logger::get("PyTorchEvaluate") + ->debug("evaluating model code..."); + + auto df2 = df.Define(outputname, eval_func); return df2; @@ -245,7 +256,11 @@ ROOT::RDF::RNode PyTorchEvaluate(ROOT::RDF::RNode df, // Logger::get("ONNXEvaluate") // ->debug("finished loading model"); -// auto df2 = df.Define(outputname, SOFIEGenerator(input_vec, model, modelFilePath)); +// auto eval_func = SOFIEGenerator(input_vec, model, modelFilePath); +// Logger::get("ONNXEvaluate") +// ->debug("evaluating model code..."); + +// auto df2 = df.Define(outputname, eval_func); // return df2;