[moco-tf] Apply plier-tf to AvgPoolCanonicalizer (#5921)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Fri, 26 Jul 2019 05:30:04 +0000 (14:30 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Fri, 26 Jul 2019 05:30:04 +0000 (14:30 +0900)
This commit applies plier-tf to AvgPoolCanonicalizer.

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp

index 2d45260..4ffc74d 100644 (file)
 #include "Dialect/TFNodeImpl.h"
 
 #include <moco/Log.h>
+#include <plier/tf/Convert.h>
 
 #include <stdex/Memory.h>
 
 namespace
 {
 
-void set_feature_enc(loco::FeatureEncode *feature_enc, moco::tf::DataLayout data_layout)
+using plier::tf::DataLayout;
+
+void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
 {
   auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
 
-  if (data_layout == moco::tf::DataLayout::NHWC)
+  if (data_layout == DataLayout::NHWC)
   {
     enc->perm()->axis(loco::FeatureAxis::Count) = 0;
     enc->perm()->axis(loco::FeatureAxis::Height) = 1;
     enc->perm()->axis(loco::FeatureAxis::Width) = 2;
     enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
   }
-  else if (data_layout == moco::tf::DataLayout::NCHW)
+  else if (data_layout == DataLayout::NCHW)
   {
     enc->perm()->axis(loco::FeatureAxis::Count) = 0;
     enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
@@ -57,18 +60,18 @@ void set_feature_enc(loco::FeatureEncode *feature_enc, moco::tf::DataLayout data
   feature_enc->encoder(std::move(enc));
 }
 
-void set_feature_dec(loco::FeatureDecode *feature_dec, moco::tf::DataLayout data_layout)
+void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
 {
   auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
 
-  if (data_layout == moco::tf::DataLayout::NHWC)
+  if (data_layout == DataLayout::NHWC)
   {
     dec->perm()->axis(loco::FeatureAxis::Count) = 0;
     dec->perm()->axis(loco::FeatureAxis::Height) = 1;
     dec->perm()->axis(loco::FeatureAxis::Width) = 2;
     dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
   }
-  else if (data_layout == moco::tf::DataLayout::NCHW)
+  else if (data_layout == DataLayout::NCHW)
   {
     dec->perm()->axis(loco::FeatureAxis::Count) = 0;
     dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
@@ -101,7 +104,7 @@ bool canonicalize_avgpool2d(loco::Graph *graph, moco::tf::TFAvgPool *node)
    *                 TFAvgPool is disconnected from other nodes
    */
 
-  auto data_layout = moco::tf::as_DataLayout(node->data_layout());
+  auto data_layout = plier::tf::as_data_layout(node->data_layout());
 
   auto feature_enc = graph->nodes()->create<loco::FeatureEncode>();
   auto avgPool2d_node = graph->nodes()->create<loco::AvgPool2D>();