From: 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Thu, 11 Jul 2019 06:56:11 +0000 (+0900) Subject: [moco/tf] TFConv2D to Conv2D Canonicalizer (#4202) X-Git-Tag: nncc_backup~116 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f4bbc86763ef842a9992875272015fa96f31d8e1;p=platform%2Fcore%2Fml%2Fnnfw.git [moco/tf] TFConv2D to Conv2D Canonicalizer (#4202) This will implement TFConv2D to Conv2D Canonicalizer Signed-off-by: SaeHie Park --- diff --git a/contrib/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp b/contrib/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp index 61da9f9..13f9c56 100644 --- a/contrib/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp +++ b/contrib/moco-tf/src/Canonicalization/Conv2DCanonicalizer.cpp @@ -16,19 +16,151 @@ #include "Conv2DCanonicalizer.h" +#include "Annotations/PadData.h" +#include "Annotations/StrideData.h" + +#include "Knob.h" + #include "Dialect/TFDialect.h" #include "Dialect/TFNodes.h" #include "Dialect/TFNodeVisitor.h" #include "Dialect/TFNodeImpl.h" +#include + +#include + namespace { +void set_feature_enc(loco::FeatureEncode *feature_enc, moco::tf::DataLayout data_layout) +{ + auto enc = stdex::make_unique>(); + + if (data_layout == moco::tf::DataLayout::NHWC) + { + enc->perm()->axis(loco::FeatureAxis::Count) = 0; + enc->perm()->axis(loco::FeatureAxis::Height) = 1; + enc->perm()->axis(loco::FeatureAxis::Width) = 2; + enc->perm()->axis(loco::FeatureAxis::Depth) = 3; + } + else if (data_layout == moco::tf::DataLayout::NCHW) + { + enc->perm()->axis(loco::FeatureAxis::Count) = 0; + enc->perm()->axis(loco::FeatureAxis::Depth) = 1; + enc->perm()->axis(loco::FeatureAxis::Height) = 2; + enc->perm()->axis(loco::FeatureAxis::Width) = 3; + } + + feature_enc->encoder(std::move(enc)); +} + +void set_filter_enc(loco::FilterEncode *filter_enc, moco::tf::DataLayout data_layout) +{ + auto enc = stdex::make_unique>(); + + // In TensorFlow, conv2d filter is a 4-D tensor of following shape: + // [filter_height, filter_width, in_channels, out_channels] -> HWIO (HWCN) + enc->perm()->axis(loco::FilterAxis::Height) = 0; + enc->perm()->axis(loco::FilterAxis::Width) = 1; + enc->perm()->axis(loco::FilterAxis::Depth) = 2; + enc->perm()->axis(loco::FilterAxis::Count) = 3; + + filter_enc->encoder(std::move(enc)); +} + +void set_feature_dec(loco::FeatureDecode *feature_dec, moco::tf::DataLayout data_layout) +{ + auto dec = stdex::make_unique>(); + + if (data_layout == moco::tf::DataLayout::NHWC) + { + dec->perm()->axis(loco::FeatureAxis::Count) = 0; + dec->perm()->axis(loco::FeatureAxis::Height) = 1; + dec->perm()->axis(loco::FeatureAxis::Width) = 2; + dec->perm()->axis(loco::FeatureAxis::Depth) = 3; + } + else if (data_layout == moco::tf::DataLayout::NCHW) + { + dec->perm()->axis(loco::FeatureAxis::Count) = 0; + dec->perm()->axis(loco::FeatureAxis::Depth) = 1; + dec->perm()->axis(loco::FeatureAxis::Height) = 2; + dec->perm()->axis(loco::FeatureAxis::Width) = 3; + } + + feature_dec->decoder(std::move(dec)); +} + bool canonicalize_conv2d(loco::Graph *graph, moco::tf::TFConv2D *node) { - std::runtime_error("NYI canonicalize_conv2d"); + if (!moco::tf::get()) + return false; + + LOGGER(l); + + /** + * @note This will replace TFCon2D node with Canonical FeatureEncode + + * FilterEncode + Conv2D + FeatureDecode + * + * Before + * A -- TFConv2D - C + * B -/ + * + * After + * - TFConv2D - + * A -- FeatureEncode - Conv2D - FeatureDecode - C + * B -- FilterEncode -/ + * + * Where + * A : ifm of TFConv2D + * B : ker of TFConv2D + * C : a node that uses TFConv2D as an input + * TFConv2D is disconnected from other nodes + */ + + auto data_layout = moco::tf::as_DataLayout(node->data_layout()); + + auto feature_enc = graph->nodes()->create(); + auto filter_enc = graph->nodes()->create(); + auto conv2d = graph->nodes()->create(); + auto feature_dec = graph->nodes()->create(); + + set_feature_enc(feature_enc, data_layout); + set_filter_enc(filter_enc, data_layout); + set_feature_dec(feature_dec, data_layout); + + // Set Conv2D attributes from TFConv2D + auto pad_data = node->annot(); + assert(pad_data != nullptr); + + conv2d->pad()->top(pad_data->pad()->top()); + conv2d->pad()->bottom(pad_data->pad()->bottom()); + conv2d->pad()->left(pad_data->pad()->left()); + conv2d->pad()->right(pad_data->pad()->right()); + + auto stride_data = node->annot(); + assert(stride_data != nullptr); + + conv2d->stride()->vertical(stride_data->stride()->vertical()); + conv2d->stride()->horizontal(stride_data->stride()->horizontal()); + + // update graph + auto node_A = node->ifm(); + auto node_B = node->ker(); + + // update connections + feature_enc->input(node_A); + filter_enc->input(node_B); + conv2d->ifm(feature_enc); + conv2d->ker(filter_enc); + feature_dec->input(conv2d); + + // replace and disconnect old node + replace(node).with(feature_dec); + node->ifm(nullptr); + node->ker(nullptr); - return false; + return true; } } // namespace