[layers] flatten layer consider target_shape
authorseongwoo <mhs4670go@naver.com>
Thu, 12 May 2022 07:59:31 +0000 (16:59 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 18 May 2022 10:36:22 +0000 (19:36 +0900)
This patch makes a flatten layer consider target_shape and use it during exporting.

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: seongwoo <mhs4670go@naver.com>
nntrainer/layers/flatten_layer.cpp
nntrainer/layers/reshape_layer.cpp
nntrainer/layers/reshape_layer.h

index cb11adc..a6fc1f8 100644 (file)
@@ -23,19 +23,24 @@ namespace nntrainer {
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
 void FlattenLayer::finalize(InitLayerContext &context) {
-  ReshapeLayer::setProperty({"target_shape=-1"});
+  const TensorDim &in_dim = context.getInputDimensions()[0];
+
+  std::string target_shape =
+    "target_shape=1:1:" + std::to_string(in_dim.getFeatureLen());
+  ReshapeLayer::setProperty({target_shape});
+
   /** @note the output dimension is in invalid state till finalize of
    * reshape_layer is finished */
   ReshapeLayer::finalize(context);
 
-  const TensorDim &in_dim = context.getInputDimensions()[0];
   if (in_dim.channel() == 1 && in_dim.height() == 1) {
     ml_logw("Warning: the flatten layer is redundant");
   }
 }
 
 void FlattenLayer::setProperty(const std::vector<std::string> &values) {
-  if (!values.empty()) {
+  auto remain_props = loadProperties(values, reshape_props);
+  if (!remain_props.empty()) {
     std::string msg = "[FlattenLayer] Unknown Layer Properties count " +
                       std::to_string(values.size());
     throw exception::not_supported(msg);
@@ -43,6 +48,8 @@ void FlattenLayer::setProperty(const std::vector<std::string> &values) {
 }
 
 void FlattenLayer::exportTo(Exporter &exporter,
-                            const ExportMethods &method) const {}
+                            const ExportMethods &method) const {
+  exporter.saveResult(reshape_props, method, this);
+}
 
 } /* namespace nntrainer */
index e4783ed..a32a387 100644 (file)
@@ -34,8 +34,6 @@ void ReshapeLayer::finalize(InitLayerContext &context) {
       "Reshape layer must be provided with target shape");
   TensorDim out_dim = target_shape.get();
 
-  /** flatten sets the dimension to 1 to indicate to flatten the rest of the
-   * dimensions */
   if ((int)out_dim.getDataLen() == -1) {
     out_dim.height(1);
     out_dim.channel(1);
index c9c155b..cd56cda 100644 (file)
@@ -90,7 +90,7 @@ public:
 
   inline static const std::string type = "reshape";
 
-private:
+protected:
   std::tuple<props::TargetShape>
     reshape_props; /**< reshape properties : target_shape after reshape */
 };