[layer] Support reshape layer
author Parichay Kapoor <pk.kapoor@samsung.com>
Tue, 19 Oct 2021 07:10:44 +0000 (16:10 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Wed, 20 Oct 2021 12:17:45 +0000 (21:17 +0900)
This patch provides support for reshape layer and basic unittests.
The flatten layer is also updated to use reshape layer internally.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
13 files changed:
api/ccapi/include/layer.h
jni/Android.mk
nntrainer/app_context.cpp
nntrainer/layers/common_properties.cpp
nntrainer/layers/common_properties.h
nntrainer/layers/flatten_layer.cpp
nntrainer/layers/flatten_layer.h
nntrainer/layers/layer_node.cpp
nntrainer/layers/meson.build
nntrainer/layers/reshape_layer.cpp [new file with mode: 0644]
nntrainer/layers/reshape_layer.h [new file with mode: 0644]
test/unittest/layers/meson.build
test/unittest/layers/unittest_layers_reshape.cpp [new file with mode: 0644]

index 4f965be..625ee07 100644 (file)
@@ -67,6 +67,7 @@ enum LayerType {
   LAYER_BACKBONE_TFLITE,                   /**< Backbone using TFLite */
   LAYER_ATTENTION,                         /**< Attention Layer type */
   LAYER_CONV1D,                            /**< Convolution 1D Layer type */
+  LAYER_RESHAPE,                           /**< Reshape Layer type */
   LAYER_LOSS_MSE = 500,             /**< Mean Squared Error Loss Layer type */
   LAYER_LOSS_CROSS_ENTROPY_SIGMOID, /**< Cross Entropy with Sigmoid Loss Layer
                                        type */
@@ -245,6 +246,14 @@ Flatten(const std::vector<std::string> &properties = {}) {
 }
 
 /**
+ * @brief Helper function to create reshape layer
+ */
+inline std::unique_ptr<Layer>
+Reshape(const std::vector<std::string> &properties = {}) {
+  return createLayer(LayerType::LAYER_RESHAPE, properties);
+}
+
+/**
  * @brief Helper function to create addition layer
  */
 inline std::unique_ptr<Layer>
index 8f70da6..e2b41a2 100644 (file)
@@ -160,6 +160,7 @@ NNTRAINER_SRCS := $(NNTRAINER_ROOT)/nntrainer/models/neuralnet.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/pooling2d_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/activation_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/flatten_layer.cpp \
+                  $(NNTRAINER_ROOT)/nntrainer/layers/reshape_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/addition_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/attention_layer.cpp \
                   $(NNTRAINER_ROOT)/nntrainer/layers/concat_layer.cpp \
index 659810e..4fdd2b4 100644 (file)
@@ -224,6 +224,8 @@ static void add_default_object(AppContext &ac) {
                      Pooling2DLayer::type, LayerType::LAYER_POOLING2D);
   ac.registerFactory(nntrainer::createLayer<FlattenLayer>, FlattenLayer::type,
                      LayerType::LAYER_FLATTEN);
+  ac.registerFactory(nntrainer::createLayer<ReshapeLayer>, ReshapeLayer::type,
+                     LayerType::LAYER_RESHAPE);
   ac.registerFactory(nntrainer::createLayer<ActivationLayer>,
                      ActivationLayer::type, LayerType::LAYER_ACTIVATION);
   ac.registerFactory(nntrainer::createLayer<AdditionLayer>, AdditionLayer::type,
index 4eeb577..8f793a6 100644 (file)
@@ -13,6 +13,7 @@
 #include <common_properties.h>
 
 #include <nntrainer_error.h>
+#include <nntrainer_log.h>
 #include <tensor_dim.h>
 
 #include <regex>
@@ -303,6 +304,19 @@ bool WeightRegularizer::isValid(
 }
 
 FlipDirection::FlipDirection(FlipDirectionInfo::Enum value) { set(value); }
+
+void GenericShape::set(const TensorDim &value) {
+  TensorDim ret = value;
+  ret.setDynDimFlag(0b1000);
+  if (ret.batch() != 1) {
+    ml_logw("Batch size set with dimension %u is ignored."
+            "Use batchsize property for the model to update batchsize.",
+            ret.batch());
+    ret.batch(1);
+  }
+  Property<TensorDim>::set(ret);
+}
+
 } // namespace props
 
 static const std::vector<std::pair<char, std::string>>
index 1fa5c94..9eaaba4 100644 (file)
@@ -960,6 +960,42 @@ public:
   using prop_tag = uint_prop_tag; /**< property type */
 };
 
+/**
+ * @brief generic shape property which saves a single tensor shape
+ * (practically, std::array<GenericShape> is used)
+ *
+ * @note the batch dimension is ignored by this property. Setting of batch must
+ * be done with the model.
+ *
+ */
+class GenericShape : public Property<TensorDim> {
+
+public:
+  static constexpr const char *key =
+    "generic_shape";                   /**< unique key to access */
+  using prop_tag = dimension_prop_tag; /**< property type */
+
+  /**
+   * @brief Input shape setter
+   *
+   * @param value value to set
+   */
+  void set(const TensorDim &value) override;
+};
+
+/**
+ * @brief target shape property which saves a single tensor shape
+ * (practically, std::array<TargetShape> is used)
+ *
+ */
+class TargetShape : public GenericShape {
+
+public:
+  static constexpr const char *key =
+    "target_shape";                    /**< unique key to access */
+  using prop_tag = dimension_prop_tag; /**< property type */
+};
+
 } // namespace props
 } // namespace nntrainer
 
index 43c1bbc..9122c91 100644 (file)
@@ -21,36 +21,15 @@ namespace nntrainer {
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
 void FlattenLayer::finalize(InitLayerContext &context) {
-  if (context.getNumInputs() != 1) {
-    throw std::invalid_argument("input_shape keyword is only for one input");
-  }
+  ReshapeLayer::setProperty({"target_shape=-1"});
+  /** @note the output dimension is in invalid state till finalize of
+   * reshape_layer is finished */
+  ReshapeLayer::finalize(context);
 
-  TensorDim out_dim;
   const TensorDim &in_dim = context.getInputDimensions()[0];
   if (in_dim.channel() == 1 && in_dim.height() == 1) {
     ml_logw("Warning: the flatten layer is redundant");
   }
-
-  out_dim.batch(in_dim.batch());
-  out_dim.channel(1);
-  out_dim.height(1);
-  out_dim.width(in_dim.getFeatureLen());
-
-  context.setOutputDimensions({out_dim});
-}
-
-void FlattenLayer::forwarding(RunLayerContext &context, bool training) {
-  if (!context.executeInPlace()) {
-    context.getOutput(SINGLE_INOUT_IDX)
-      .copyData(context.getInput(SINGLE_INOUT_IDX));
-  }
-}
-
-void FlattenLayer::calcDerivative(RunLayerContext &context) {
-  if (!context.executeInPlace()) {
-    context.getOutgoingDerivative(SINGLE_INOUT_IDX)
-      .copyData(context.getIncomingDerivative(SINGLE_INOUT_IDX));
-  }
 }
 
 void FlattenLayer::setProperty(const std::vector<std::string> &values) {
index bc6a5c3..9331802 100644 (file)
@@ -15,7 +15,7 @@
 #define __FLATTEN_LAYER_H__
 #ifdef __cplusplus
 
-#include <layer_devel.h>
+#include <reshape_layer.h>
 
 namespace nntrainer {
 
@@ -23,12 +23,12 @@ namespace nntrainer {
  * @class   Flatten Layer
  * @brief   Flatten Layer
  */
-class FlattenLayer : public Layer {
+class FlattenLayer : public ReshapeLayer {
 public:
   /**
    * @brief     Constructor of Flatten Layer
    */
-  FlattenLayer() : Layer() {}
+  FlattenLayer() : ReshapeLayer() {}
 
   /**
    * @brief     Destructor of Flatten Layer
@@ -53,37 +53,11 @@ public:
   void finalize(InitLayerContext &context) override;
 
   /**
-   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
-   */
-  void forwarding(RunLayerContext &context, bool training) override;
-
-  /**
-   * @copydoc Layer::calcDerivative(RunLayerContext &context)
-   */
-  void calcDerivative(RunLayerContext &context) override;
-
-  /**
    * @copydoc Layer::setProperty(const std::vector<std::string> &values)
    */
   void setProperty(const std::vector<std::string> &values) override;
 
   /**
-   * @copydoc bool supportBackwarding() const
-   */
-  bool supportBackwarding() const override { return true; };
-
-  /**
-   * @copydoc Layer::supportInPlace()
-   */
-  bool supportInPlace() const override { return true; }
-
-  /**
-   * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
-   */
-  void exportTo(Exporter &exporter,
-                const ExportMethods &method) const override {}
-
-  /**
    * @copydoc Layer::getType()
    */
   const std::string getType() const override { return FlattenLayer::type; };
index db066f5..91ee497 100644 (file)
@@ -90,28 +90,11 @@ public:
  * (practically, std::array<InputShape> is used)
  *
  */
-class InputShape : public Property<TensorDim> {
+class InputShape : public GenericShape {
 
 public:
   static constexpr const char *key = "input_shape"; /**< unique key to access */
   using prop_tag = dimension_prop_tag;              /**< property type */
-
-  /**
-   * @brief Input shape setter
-   *
-   * @param value value to set
-   */
-  void set(const TensorDim &value) override {
-    TensorDim ret = value;
-    ret.setDynDimFlag(0b1000);
-    if (ret.batch() != 1) {
-      ml_logw("Batch size set with input dimension %u is ignored."
-              "Use batchsize property for the model to update batchsize.",
-              ret.batch());
-      ret.batch(1);
-    }
-    Property<TensorDim>::set(ret);
-  }
 };
 
 /**
index 68e43bc..5f27908 100644 (file)
@@ -30,7 +30,8 @@ layer_sources = [
   'gru.cpp',
   'dropout.cpp',
   'centroid_knn.cpp',
-  'layer_context.cpp'
+  'layer_context.cpp',
+  'reshape_layer.cpp'
 ]
 
 layer_headers = [
diff --git a/nntrainer/layers/reshape_layer.cpp b/nntrainer/layers/reshape_layer.cpp
new file mode 100644 (file)
index 0000000..86e8ce5
--- /dev/null
@@ -0,0 +1,72 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file   reshape_layer.cpp
+ * @date   19 October 2021
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is Reshape Layer Class for Neural Network
+ *
+ * @todo Update reshape to work in-place properly.
+ */
+
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+#include <reshape_layer.h>
+
+namespace nntrainer {
+
+static constexpr size_t SINGLE_INOUT_IDX = 0;
+
+void ReshapeLayer::finalize(InitLayerContext &context) {
+  if (context.getNumInputs() != 1) {
+    throw std::invalid_argument("Reshape only supports 1 input for now");
+  }
+
+  const TensorDim &in_dim = context.getInputDimensions()[0];
+
+  auto &target_shape = std::get<props::TargetShape>(reshape_props);
+  if (target_shape.empty())
+    throw std::invalid_argument(
+      "Reshape layer must be provided with target shape");
+  TensorDim out_dim = target_shape.get();
+
+  /** flatten sets the dimension to 1 to indicate to flatten the rest of the
+   * dimensions */
+  if ((int)out_dim.getDataLen() == -1) {
+    out_dim.height(1);
+    out_dim.channel(1);
+    out_dim.width(in_dim.getFeatureLen());
+  }
+
+  out_dim.batch(in_dim.batch());
+
+  context.setOutputDimensions({out_dim});
+}
+
+void ReshapeLayer::forwarding(RunLayerContext &context, bool training) {
+  if (!context.executeInPlace()) {
+    context.getOutput(SINGLE_INOUT_IDX)
+      .copyData(context.getInput(SINGLE_INOUT_IDX));
+  }
+}
+
+void ReshapeLayer::calcDerivative(RunLayerContext &context) {
+  if (!context.executeInPlace()) {
+    context.getOutgoingDerivative(SINGLE_INOUT_IDX)
+      .copyData(context.getIncomingDerivative(SINGLE_INOUT_IDX));
+  }
+}
+
+void ReshapeLayer::setProperty(const std::vector<std::string> &values) {
+  auto remain_props = loadProperties(values, reshape_props);
+  if (!remain_props.empty()) {
+    std::string msg = "[ReshapeLayer] Unknown Layer Properties count " +
+                      std::to_string(remain_props.size());
+    throw exception::not_supported(msg);
+  }
+}
+} /* namespace nntrainer */
diff --git a/nntrainer/layers/reshape_layer.h b/nntrainer/layers/reshape_layer.h
new file mode 100644 (file)
index 0000000..0248f6b
--- /dev/null
@@ -0,0 +1,102 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file   reshape_layer.h
+ * @date   19 October 2021
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is Reshape Layer Class for Neural Network
+ *
+ */
+
+#ifndef __RESHAPE_LAYER_H__
+#define __RESHAPE_LAYER_H__
+#ifdef __cplusplus
+
+#include <common_properties.h>
+#include <layer_devel.h>
+
+namespace nntrainer {
+
+/**
+ * @class   Reshape Layer
+ * @brief   Reshape Layer
+ */
+class ReshapeLayer : public Layer {
+public:
+  /**
+   * @brief     Constructor of Reshape Layer
+   */
+  ReshapeLayer() : Layer() {}
+
+  /**
+   * @brief     Destructor of Reshape Layer
+   */
+  ~ReshapeLayer() = default;
+
+  /**
+   *  @brief  Move constructor of ReshapeLayer.
+   *  @param[in] ReshapeLayer &&
+   */
+  ReshapeLayer(ReshapeLayer &&rhs) noexcept = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @param[in] rhs ReshapeLayer to be moved.
+   */
+  ReshapeLayer &operator=(ReshapeLayer &&rhs) = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(InitLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
+   */
+  void forwarding(RunLayerContext &context, bool training) override;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(RunLayerContext &context) override;
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) override;
+
+  /**
+   * @copydoc bool supportBackwarding() const
+   */
+  bool supportBackwarding() const override { return true; };
+
+  /**
+   * @copydoc Layer::supportInPlace()
+   */
+  bool supportInPlace() const override { return true; }
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method)
+   */
+  void exportTo(Exporter &exporter,
+                const ExportMethods &method) const override {}
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const override { return ReshapeLayer::type; };
+
+  inline static const std::string type = "reshape";
+
+private:
+  std::tuple<props::TargetShape>
+    reshape_props; /**< reshape properties : target_shape after reshape */
+};
+
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __RESHAPE_LAYER_H__ */
index 2b0d2c3..353d1a3 100644 (file)
@@ -52,6 +52,7 @@ test_target = [
   'unittest_layers_permute.cpp',
   'unittest_layers_attention.cpp',
   'unittest_layers_dropout.cpp',
+  'unittest_layers_reshape.cpp',
 ]
 
 if get_option('enable-tflite-backbone')
diff --git a/test/unittest/layers/unittest_layers_reshape.cpp b/test/unittest/layers/unittest_layers_reshape.cpp
new file mode 100644 (file)
index 0000000..d59763b
--- /dev/null
@@ -0,0 +1,24 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
+ *
+ * @file unittest_layers_reshape.cpp
+ * @date 19 October 2021
+ * @brief Reshape Layer Test
+ * @see        https://github.com/nnstreamer/nntrainer
+ * @author Parichay Kapoor <pk.kapoor@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include <layers_common_tests.h>
+#include <reshape_layer.h>
+
+auto semantic_reshape = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::ReshapeLayer>,
+  nntrainer::ReshapeLayer::type, {"target_shape=-1"}, 0, false, 1);
+
+INSTANTIATE_TEST_CASE_P(Reshape, LayerSemantics,
+                        ::testing::Values(semantic_reshape));