#include "Dialect/TFNodeImpl.h"
#include <moco/Log.h>
+#include <plier/tf/Convert.h>
#include <stdex/Memory.h>
namespace
{
-void set_feature_enc(loco::FeatureEncode *feature_enc, moco::tf::DataLayout data_layout)
+using plier::tf::DataLayout;
+
+void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
{
auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
- if (data_layout == moco::tf::DataLayout::NHWC)
+ if (data_layout == DataLayout::NHWC)
{
enc->perm()->axis(loco::FeatureAxis::Count) = 0;
enc->perm()->axis(loco::FeatureAxis::Height) = 1;
enc->perm()->axis(loco::FeatureAxis::Width) = 2;
enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
}
- else if (data_layout == moco::tf::DataLayout::NCHW)
+ else if (data_layout == DataLayout::NCHW)
{
enc->perm()->axis(loco::FeatureAxis::Count) = 0;
enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
feature_enc->encoder(std::move(enc));
}
-void set_feature_dec(loco::FeatureDecode *feature_dec, moco::tf::DataLayout data_layout)
+void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
{
auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
- if (data_layout == moco::tf::DataLayout::NHWC)
+ if (data_layout == DataLayout::NHWC)
{
dec->perm()->axis(loco::FeatureAxis::Count) = 0;
dec->perm()->axis(loco::FeatureAxis::Height) = 1;
dec->perm()->axis(loco::FeatureAxis::Width) = 2;
dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
}
- else if (data_layout == moco::tf::DataLayout::NCHW)
+ else if (data_layout == DataLayout::NCHW)
{
dec->perm()->axis(loco::FeatureAxis::Count) = 0;
dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
* TFAvgPool is disconnected from other nodes
*/
- auto data_layout = moco::tf::as_DataLayout(node->data_layout());
+ auto data_layout = plier::tf::as_data_layout(node->data_layout());
auto feature_enc = graph->nodes()->create<loco::FeatureEncode>();
auto avgPool2d_node = graph->nodes()->create<loco::AvgPool2D>();