[exo] Fuse Instance Norm logic (#9103)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Fri, 22 Nov 2019 07:45:20 +0000 (16:45 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Fri, 22 Nov 2019 07:45:20 +0000 (16:45 +0900)
This commit implements logic for Instance Norm fusion.

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
compiler/exo/src/Pass/FuseInstanceNormPass.cpp

index e79d7d9..0d3428c 100644 (file)
@@ -330,10 +330,53 @@ bool InstanceNormPattern::matched()
   return true;
 }
 
+/**
+ * Instance norm pattern would be fused like following diagram:
+ *
+ *    [In] --------------------------- CircleInstanceNorm --- [Out]
+ *                                     / /
+ *    const_as_gamma --- TFLReshape --- /
+ *                                     /
+ *    const_as_beta ---- TFLReshape ---
+ *
+ * Note
+ *  - 'const_as_gamma' and 'const_as_beta' are from original graph
+ *  - Value of 'const_as_epsilon' would be copied to CircleInstanceNorm's attribute
+ *  - TFLReshape is added as CircleInstanceNorm only accept 1D tensor
+ *  - 'TFLConst --- TFLReshape' is expected to be fused in constant folding for Reshape
+ */
 void fuse_instance_norm(const InstanceNormPattern &p)
 {
   assert(p.matched());
-  // TODO implement
+
+  auto graph = p.add_as_terminal->graph();
+
+  // Make reshape for gamma & beta
+  auto reshape_gamma = graph->nodes()->create<locoex::TFLReshape>();
+  auto reshape_beta = graph->nodes()->create<locoex::TFLReshape>();
+  {
+    auto ifm_shape = loco::shape_get(p.ifm).as<loco::TensorShape>();
+    uint32_t ifm_channel_depth = ifm_shape.dim(3).value();
+
+    int32_t new_shape[1] = {static_cast<int32_t>(ifm_channel_depth)};
+
+    reshape_gamma->tensor(p.const_as_gamma);
+    reshape_beta->tensor(p.const_as_beta);
+
+    locoex::set_new_shape(reshape_gamma, new_shape, 1);
+    locoex::set_new_shape(reshape_beta, new_shape, 1);
+  }
+
+  // Make Instance Norm to replace
+  auto instance_norm = graph->nodes()->create<locoex::CircleInstanceNorm>();
+  instance_norm->input(p.ifm);
+  instance_norm->gamma(reshape_gamma);
+  instance_norm->beta(reshape_beta);
+  float epsilon = p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
+  instance_norm->epsilon(epsilon);
+  instance_norm->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction());
+
+  replace(p.add_as_terminal).with(instance_norm);
 }
 
 } // namespace