Skip to content

Commit

Permalink
fixing syst shifts
Browse files Browse the repository at this point in the history
  • Loading branch information
nfaltermann committed Nov 15, 2023
1 parent a804bb7 commit 524565d
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions src/ml.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ namespace sofie {

std::string modelHeaderFile = modelName + std::string(".hxx");

if (!std::filesystem::exists(modelHeaderFile)) {
Logger::get("SOFIEGenerator")
->debug("generating model code...");
model.Generate();
Expand All @@ -171,9 +172,11 @@ namespace sofie {
Logger::get("SOFIEGenerator")
->debug("compiling model code...");
CompileModelForRDF(modelHeaderFile, input_vec.size());

Logger::get("SOFIEGenerator")
->debug("evaluating model code...");
}
else {
Logger::get("SOFIEGenerator")
->debug("model already compiled, skipping");
}

std::string sofie_func_str = "sofie_functor(rdfslot_, ";
sofie_func_str += input_vec[0];
Expand All @@ -199,7 +202,11 @@ ROOT::RDF::RNode KerasEvaluate(ROOT::RDF::RNode df,
Logger::get("KerasEvaluate")
->debug("finished loading model");

auto df2 = df.Define(outputname, SOFIEGenerator(input_vec, model, modelFilePath));
auto eval_func = SOFIEGenerator(input_vec, model, modelFilePath);
Logger::get("KerasEvaluate")
->debug("evaluating model code...");

auto df2 = df.Define(outputname, eval_func);

return df2;

Expand All @@ -223,7 +230,11 @@ ROOT::RDF::RNode PyTorchEvaluate(ROOT::RDF::RNode df,
Logger::get("PyTorchEvaluate")
->debug("finished loading model");

auto df2 = df.Define(outputname, SOFIEGenerator(input_vec, model, modelFilePath));
auto eval_func = SOFIEGenerator(input_vec, model, modelFilePath);
Logger::get("PyTorchEvaluate")
->debug("evaluating model code...");

auto df2 = df.Define(outputname, eval_func);

return df2;

Expand All @@ -245,7 +256,11 @@ ROOT::RDF::RNode PyTorchEvaluate(ROOT::RDF::RNode df,
// Logger::get("ONNXEvaluate")
// ->debug("finished loading model");

// auto df2 = df.Define(outputname, SOFIEGenerator(input_vec, model, modelFilePath));
// auto eval_func = SOFIEGenerator(input_vec, model, modelFilePath);
// Logger::get("ONNXEvaluate")
// ->debug("evaluating model code...");

// auto df2 = df.Define(outputname, eval_func);

// return df2;

Expand Down

0 comments on commit 524565d

Please sign in to comment.