From bd4e4d23daabf08a39980455760dba5544e30572 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 26 Jul 2019 14:30:04 +0900 Subject: [PATCH] [moco-tf] Apply plier-tf to AvgPoolCanonicalizer (#5921) This commit applies plier-tf to AvgPoolCanonicalizer. Signed-off-by: Cheongyo Bahk --- .../src/Canonicalization/AvgPoolCanonicalizer.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp index 2d45260..4ffc74d 100644 --- a/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp +++ b/compiler/moco-tf/src/Canonicalization/AvgPoolCanonicalizer.cpp @@ -29,24 +29,27 @@ #include "Dialect/TFNodeImpl.h" #include +#include #include namespace { -void set_feature_enc(loco::FeatureEncode *feature_enc, moco::tf::DataLayout data_layout) +using plier::tf::DataLayout; + +void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout) { auto enc = stdex::make_unique>(); - if (data_layout == moco::tf::DataLayout::NHWC) + if (data_layout == DataLayout::NHWC) { enc->perm()->axis(loco::FeatureAxis::Count) = 0; enc->perm()->axis(loco::FeatureAxis::Height) = 1; enc->perm()->axis(loco::FeatureAxis::Width) = 2; enc->perm()->axis(loco::FeatureAxis::Depth) = 3; } - else if (data_layout == moco::tf::DataLayout::NCHW) + else if (data_layout == DataLayout::NCHW) { enc->perm()->axis(loco::FeatureAxis::Count) = 0; enc->perm()->axis(loco::FeatureAxis::Depth) = 1; @@ -57,18 +60,18 @@ void set_feature_enc(loco::FeatureEncode *feature_enc, moco::tf::DataLayout data feature_enc->encoder(std::move(enc)); } -void set_feature_dec(loco::FeatureDecode *feature_dec, moco::tf::DataLayout data_layout) +void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout) { auto dec = stdex::make_unique>(); - if (data_layout == moco::tf::DataLayout::NHWC) + if (data_layout == DataLayout::NHWC) { dec->perm()->axis(loco::FeatureAxis::Count) = 0; dec->perm()->axis(loco::FeatureAxis::Height) = 1; dec->perm()->axis(loco::FeatureAxis::Width) = 2; dec->perm()->axis(loco::FeatureAxis::Depth) = 3; } - else if (data_layout == moco::tf::DataLayout::NCHW) + else if (data_layout == DataLayout::NCHW) { dec->perm()->axis(loco::FeatureAxis::Count) = 0; dec->perm()->axis(loco::FeatureAxis::Depth) = 1; @@ -101,7 +104,7 @@ bool canonicalize_avgpool2d(loco::Graph *graph, moco::tf::TFAvgPool *node) * TFAvgPool is disconnected from other nodes */ - auto data_layout = moco::tf::as_DataLayout(node->data_layout()); + auto data_layout = plier::tf::as_data_layout(node->data_layout()); auto feature_enc = graph->nodes()->create(); auto avgPool2d_node = graph->nodes()->create(); -- 2.7.4