Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / exec / Execution.cc
index 7d5b406..1384c9f 100644 (file)
@@ -16,6 +16,8 @@
 
 #include "exec/Execution.h"
 
+#include "train/TrainableExecutors.h"
+
 #include "util/logging.h"
 
 namespace onert
@@ -151,6 +153,35 @@ void Execution::waitFinish()
 
 bool Execution::isFinished(void) const { return finished; }
 
+#ifdef ONERT_TRAIN
+void Execution::train(uint32_t training_step)
+{
+  auto execs = dynamic_cast<exec::train::TrainableExecutors *>(_executors.get());
+  if (!execs)
+  {
+    throw std::runtime_error{"Supported only TrainableExecutors"};
+  }
+
+  VERBOSE(Execution) << "Start training" << std::endl;
+
+  execs->train(_io_desc, training_step);
+  finished = true;
+
+  VERBOSE(Execution) << "training finished" << std::endl;
+}
+
+float Execution::getLoss(const ir::IOIndex &ind)
+{
+  auto execs = dynamic_cast<exec::train::TrainableExecutors *>(_executors.get());
+  if (!execs)
+  {
+    throw std::runtime_error{"Supported only TrainableExecutors"};
+  }
+
+  return execs->getLoss(ind);
+}
+#endif // ONERT_TRAIN
+
 ir::Shape Execution::getInputShape(ir::IOIndex ind) const
 {
   auto itr = _io_desc.dynamic_input_shapes.find(ind);
@@ -180,5 +211,16 @@ ir::Shape Execution::getOutputShape(ir::IOIndex ind) const
   return output_desc->info.shape();
 }
 
+size_t Execution::getInputTotalSize(ir::IOIndex ind) const
+{
+  // TODO Support dynamic shape
+  return _executors->inputInfo(ind).total_size();
+}
+
+size_t Execution::getOutputTotalSize(ir::IOIndex ind) const
+{
+  return _executors->outputInfo(ind).total_size();
+}
+
 } // namespace exec
 } // namespace onert