[ RNNT LOSS ] Add skeleton code for rnnt loss
authorjijoong.moon <jijoong.moon@samsung.com>
Thu, 22 Jul 2021 06:35:30 +0000 (15:35 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 27 Jul 2021 07:01:58 +0000 (16:01 +0900)
This commit includes:
  . Skeleton Code for RNNT Loss as a Custom Loss Layer
  . Semantic Unit Test

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
Applications/Custom/LayerPlugin/layer_plugin_rnnt_loss_test.cpp [new file with mode: 0644]
Applications/Custom/LayerPlugin/meson.build
Applications/Custom/meson.build
Applications/Custom/rnnt_loss.cpp [new file with mode: 0644]
Applications/Custom/rnnt_loss.h [new file with mode: 0644]

diff --git a/Applications/Custom/LayerPlugin/layer_plugin_rnnt_loss_test.cpp b/Applications/Custom/LayerPlugin/layer_plugin_rnnt_loss_test.cpp
new file mode 100644 (file)
index 0000000..ce6c1b9
--- /dev/null
@@ -0,0 +1,31 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Jijoong Moon <jijoong.moon@samsung.com>
+ *
+ * @file   layer_plugin_rnnt_loss_test.cpp
+ * @date   22 July 2021
+ * @brief  This file contains the execution part of rnnt loss layer in
+ * LayerPlugin example
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Jijoong Moon <jijoong.moon@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include <layer_plugin_common_test.h>
+#include <layers_common_tests.h>
+#include <rnnt_loss.h>
+
/// Registers the shared-object loading checks: the plugin common test opens
/// "librnnt_loss_layer.so" and looks the layer up by its registered type
/// string "rnnt_loss" (must match RNNTLossLayer::type).
INSTANTIATE_TEST_CASE_P(
  RNNTLossLayer, LayerPluginCommonTest,
  ::testing::Values(std::make_tuple("librnnt_loss_layer.so", "rnnt_loss")));

/// Semantics-test parameters for the RNNT loss layer: factory callback,
/// type string, no extra properties ({}), then two suite-specific options
/// (0, false) — see LayerSemanticsParamType for their exact meaning.
auto semantic_rnnt =
  LayerSemanticsParamType(nntrainer::createLayer<custom::RNNTLossLayer>,
                          custom::RNNTLossLayer::type, {}, 0, false);

/// Runs the generic layer-semantics suite with the parameters above.
INSTANTIATE_TEST_CASE_P(RNNTLossLayer, LayerSemantics,
                        ::testing::Values(semantic_rnnt));
index 9ac80aa..34bb2c0 100644 (file)
@@ -18,6 +18,16 @@ mae_loss_layer  = shared_library('mae_loss_layer',
   cpp_args: '-DPLUGGABLE'
 )
 
+# build command for librnnt_loss_layer.so
+rnnt_loss_layer  = shared_library('rnnt_loss_layer',
+  layer_example_rnnt_src,
+  dependencies: [nntrainer_dep, nntrainer_ccapi_dep],
+  include_directories: layer_example_inc,
+  install: true,
+  install_dir: application_install_dir,
+  cpp_args: '-DPLUGGABLE'
+)
+
 pow_layer_dep = declare_dependency(
   link_with: pow_layer,
   include_directories: layer_example_inc
@@ -28,11 +38,17 @@ mae_loss_layer_dep = declare_dependency(
   include_directories: layer_example_inc
   )
 
+rnnt_loss_layer_dep = declare_dependency(
+  link_with: rnnt_loss_layer,
+  include_directories: layer_example_inc
+  )
+
 if get_option('enable-test')
   test_target = [
     'layer_plugin_common_test.cpp',
     'layer_plugin_pow_test.cpp',
     'layer_plugin_mae_loss_test.cpp',
+    'layer_plugin_rnnt_loss_test.cpp'
   ]
 
   exe = executable(
@@ -44,6 +60,7 @@ if get_option('enable-test')
     nntrainer_testutil_dep,
     pow_layer_dep,
     mae_loss_layer_dep,
+    rnnt_loss_layer_dep,    
     nntrainer_layer_common_standalone_tests_dep
   ],
   install: get_option('enable-test'),
index 2db636b..4aa6f8c 100644 (file)
@@ -1,6 +1,7 @@
 
 layer_example_pow_src = files('pow.cpp')
 layer_example_mae_src = files('mae_loss.cpp')
+layer_example_rnnt_src = files('rnnt_loss.cpp')
 
 layer_example_inc = include_directories('./')
 
diff --git a/Applications/Custom/rnnt_loss.cpp b/Applications/Custom/rnnt_loss.cpp
new file mode 100644 (file)
index 0000000..89bb64c
--- /dev/null
@@ -0,0 +1,73 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Jijoong Moon <jijoong.moon@samsung.com>
+ *
+ * @file   rnnt_loss.cpp
+ * @date   22 July 2021
+ * @brief  This file contains the rnnt loss
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Jijoong Moon <jijoong.moon@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
#include "rnnt_loss.h"

#include <cmath>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <tensor.h>
+
// Small float guard, presumably reserved for numeric stabilization
// (log/division) in the upcoming loss math; unused in this skeleton —
// TODO(review): confirm once forwarding is implemented.
constexpr const float EPSILON_ = 1e-7;
+namespace custom {
+
+void RNNTLossLayer::finalize(nntrainer::InitLayerContext &context) {
+  context.setOutputDimensions(context.getInputDimensions());
+  // NYI
+}
+
+void RNNTLossLayer::setProperty(const std::vector<std::string> &values) {
+  if (!values.empty()) {
+    std::string msg = "[RNNTLossLayer] Unknown Layer Properties count " +
+                      std::to_string(values.size());
+    throw std::invalid_argument(msg);
+  }
+}
+
+void RNNTLossLayer::forwarding(nntrainer::RunLayerContext &context,
+                               bool training) {
+  // nntrainer::Tensor &predicted = context.getInput(SINGLE_INOUT_IDX);
+  // nntrainer::Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+  // NYI
+}
+
+void RNNTLossLayer::calcDerivative(nntrainer::RunLayerContext &context) {
+  // nntrainer::Tensor &predicted = context.getInput(SINGLE_INOUT_IDX);
+  // nntrainer::Tensor &label = context.getLabel(SINGLE_INOUT_IDX);
+
+  // nntrainer::Tensor &deriv = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
+  // NYI
+}
+
#ifdef PLUGGABLE

/**
 * @brief Factory entry used by the plugin loader to allocate the layer.
 * @return heap-allocated RNNTLossLayer; ownership transfers to the caller
 */
nntrainer::Layer *create_rnnt_loss_layer() {
  auto *plugin_layer = new RNNTLossLayer();
  std::cout << "rnnt loss layer created\n";
  return plugin_layer;
}

/**
 * @brief Counterpart of create_rnnt_loss_layer(); releases the given layer.
 * @note NOTE(review): "destory" is a misspelling of "destroy", but the name
 *       presumably matches the sibling custom-layer examples and has external
 *       linkage, so it is kept; consider a tree-wide rename.
 */
void destory_rnnt_loss_layer(nntrainer::Layer *layer) {
  std::cout << "rnnt loss layer destroyed\n";
  delete layer;
}

/**
 * @note ml_train_layer_pluggable defines the entry point for nntrainer to
 * register a plugin layer
 */
extern "C" {
nntrainer::LayerPluggable ml_train_layer_pluggable{create_rnnt_loss_layer,
                                                   destory_rnnt_loss_layer};
}

#endif
+} // namespace custom
diff --git a/Applications/Custom/rnnt_loss.h b/Applications/Custom/rnnt_loss.h
new file mode 100644 (file)
index 0000000..b693ad6
--- /dev/null
@@ -0,0 +1,87 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Jijoong Moon <jijoong.moon@samsung.com>
+ *
+ * @file   rnnt_loss.h
+ * @date   22 July 2021
+ * @brief  This file contains the rnnt loss
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Jijoong Moon <jijoong.moon@samsung.com>
+ * @bug    No known bugs except for NYI items
+ *
+ */
+#ifndef __RNNT_LOSS_LAYER_H__
+#define __RNNT_LOSS_LAYER_H__
+#include <string>
+
+#include <layer_context.h>
+#include <layer_devel.h>
+#include <node_exporter.h>
+
+namespace custom {
+
+/**
+ * @brief A rnnt loss layer which calculates rnnt loss
+ *
+ */
+class RNNTLossLayer final : public nntrainer::Layer {
+public:
+  /**
+   * @brief Construct a new RNNT Loss Layer object
+   *
+   */
+  RNNTLossLayer() : Layer() {}
+
+  /**
+   * @brief Destroy the RNNT Loss Layer object
+   *
+   */
+  ~RNNTLossLayer() = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(nntrainer::InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(nntrainer::RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(nntrainer::RunLayerContext &context) override;
+
+  /**
+   * @copydoc bool supportBackwarding() const
+   */
+  bool supportBackwarding() const override { return true; };
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
+   */
+  void exportTo(nntrainer::Exporter &exporter,
+                const nntrainer::ExportMethods &method) const override {}
+
+  /**
+   * @copydoc Layer::requireLabel()
+   */
+  bool requireLabel() const { return true; }
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override { return RNNTLossLayer::type; }
+
+  inline static const std::string type = "rnnt_loss";
+};
+
+} // namespace custom
+
+#endif /* __RNNT_LOSS_LAYER_H__ */