#include "Dialect/TFDialect.h"
#include "Dialect/TFNodes.h"
+#include <plier/tf/Convert.h>
+
+#include <stdex/Memory.h>
+
+#include <loco/IR/Stride.h>
+#include <loco/IR/Padding2D.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <cassert>
+#include <stdexcept>
+
+namespace
+{
+using plier::tf::DataLayout;
+
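+// Configure 'feature_enc' with a PermutingEncoder that maps the given TensorFlow
+// data layout (NHWC or NCHW) onto loco's Feature domain axes.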
+void set_feature_enc(loco::FeatureEncode *feature_enc, DataLayout data_layout)
+{
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ enc->perm()->axis(loco::FeatureAxis::Count) = 0;
+ enc->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ enc->perm()->axis(loco::FeatureAxis::Height) = 2;
+ enc->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+ else
+ throw std::runtime_error("Unsupported data layout");
+
+ feature_enc->encoder(std::move(enc));
+}
+
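+// Configure 'filter_enc' with a PermutingEncoder for the TensorFlow filter layout
+// described below.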
+void set_filter_enc(loco::FilterEncode *filter_enc)
+{
+ auto enc = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
+
+ // In TensorFlow, Conv2DBackpropInput's filter is a 4-D tensor of the following shape:
+ // [filter_height, filter_width, in_channels, out_channels] w.r.t. the forward convolution,
+ // i.e. HWOI or HWNC when in/out are read in the loco (TransposedConv2D) sense
+ enc->perm()->axis(loco::FilterAxis::Height) = 0;
+ enc->perm()->axis(loco::FilterAxis::Width) = 1;
+ enc->perm()->axis(loco::FilterAxis::Count) = 2;
+ enc->perm()->axis(loco::FilterAxis::Depth) = 3;
+
+ filter_enc->encoder(std::move(enc));
+}
+
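+// Configure 'feature_dec' with a PermutingDecoder that maps loco's Feature domain
+// axes back to the given TensorFlow data layout (NHWC or NCHW).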
+void set_feature_dec(loco::FeatureDecode *feature_dec, DataLayout data_layout)
+{
+ auto dec = stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+
+ if (data_layout == DataLayout::NHWC)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ dec->perm()->axis(loco::FeatureAxis::Count) = 0;
+ dec->perm()->axis(loco::FeatureAxis::Depth) = 1;
+ dec->perm()->axis(loco::FeatureAxis::Height) = 2;
+ dec->perm()->axis(loco::FeatureAxis::Width) = 3;
+ }
+ else
+ throw std::runtime_error("Unsupported data layout");
+
+ feature_dec->decoder(std::move(dec));
+}
+
+} // namespace
+
+namespace
+{
+
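+// Extract the 2-D (vertical, horizontal) stride from TensorFlow's 4-D 'strides'
+// attribute according to the data layout.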
+loco::Stride<2> stride_2d_from_4d(const std::vector<int64_t> &strides_4d,
+ const DataLayout data_layout)
+{
+ assert(strides_4d.size() == 4);
+
+ loco::Stride<2> ret;
+ switch (data_layout)
+ {
+ case DataLayout::NHWC:
+ ret.vertical(strides_4d.at(1));
+ ret.horizontal(strides_4d.at(2));
+ break;
+ case DataLayout::NCHW:
+ ret.vertical(strides_4d.at(2));
+ ret.horizontal(strides_4d.at(3));
+ break;
+ default:
+ throw std::runtime_error("Unsupported data layout");
+ }
+ return ret;
+}
+
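+// Vertical/horizontal extent of a single feature-map plane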
+struct PlaneShape
+{
+ loco::Dimension vertical;
+ loco::Dimension horizontal;
+};
+
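+// Infers a loco::Padding2D from the input/output plane shapes, 2-D stride, filter
+// window, and TF padding scheme ("SAME"/"VALID"); the inference itself is still a
+// TODO (see operator() below).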
+class Padding2DInference final
+{
+public:
+ loco::Padding2D operator()(void);
+
+public:
+ PlaneShape &input() { return _input; }
+ PlaneShape &output() { return _output; }
+ loco::Stride<2> &stride() { return _stride; }
+ loco::Window<2> &window() { return _window; }
+ moco::tf::TFPadding &padding() { return _padding; }
+
+private:
+ PlaneShape _input;
+ PlaneShape _output;
+ loco::Stride<2> _stride;
+ loco::Window<2> _window;
+ moco::tf::TFPadding _padding;
+};
+
+loco::Padding2D Padding2DInference::operator()(void)
+{
+ // TODO Implement!
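+ // NOTE One possible implementation (not part of this change): return zero padding
+ //      for "VALID"; for "SAME", compute the total padding per axis as
+ //      max(0, (input - 1) * stride + window - output) and split it between the
+ //      two sides, mirroring TensorFlow's padding arithmetic for Conv2DBackpropInput.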
+ throw std::runtime_error("NYI");
+}
+
+/**
+ * @param[out] ret PlaneShape extracted from 'node' with given 'data_layout'
+ * @param[in] node Node to read the 4-D tensor shape from
+ * @param[in] data_layout Data layout of 'node' (NHWC or NCHW)
+ *
+ * @return true on success
+ */
+bool set_plane_shape(PlaneShape &ret, const loco::Node *node, const DataLayout data_layout)
+{
+ if (!loco::shape_known(node))
+ return false;
+
+ auto tensor_shape = loco::shape_get(node).as<loco::TensorShape>();
+ assert(tensor_shape.rank() == 4);
+
+ switch (data_layout)
+ {
+ case DataLayout::NHWC:
+ ret.vertical = tensor_shape.dim(1).value();
+ ret.horizontal = tensor_shape.dim(2).value();
+ break;
+ case DataLayout::NCHW:
+ ret.vertical = tensor_shape.dim(2).value();
+ ret.horizontal = tensor_shape.dim(3).value();
+ break;
+ default:
+ throw std::runtime_error("Unsupported data layout");
+ }
+
+ return true;
+}
+
+/**
+ * @param[out] ret 2D Window extracted from the HW** filter node
+ * @param[in] filter_node Filter node whose first two dimensions are height and width
+ *
+ * @return true on success
+ */
+bool set_window(loco::Window<2> &ret, const loco::Node *filter_node)
+{
+ if (!loco::shape_known(filter_node))
+ return false;
+
+ auto tensor_shape = loco::shape_get(filter_node).as<loco::TensorShape>();
+ assert(tensor_shape.rank() == 4);
+
+ ret.vertical(tensor_shape.dim(0).value());
+ ret.horizontal(tensor_shape.dim(1).value());
+
+ return true;
+}
+
+} // namespace
+
namespace
{
bool canonicalize_conv2d_backprop_input(loco::Graph *graph,
moco::tf::TFConv2DBackpropInput *conv2d_backprop)
{
- // TODO Implement!
- return false;
+ /**
+ * @note This will replace TFConv2DBackpropInput node with canonical
+ *       FeatureEncode + FilterEncode + TransposedConv2D + FeatureDecode
+ *
+ * Before
+ *           input_sizes ----
+ *                            \
+ *           filter -------- TFConv2DBackpropInput --- output(s)
+ *                            /
+ *           out_backprop ---
+ *
+ * After
+ *           input_sizes ----
+ *                            \
+ *           filter -------- TFConv2DBackpropInput ---
+ *                            /
+ *           out_backprop ---
+ *
+ *           filter ------ FilterEncode ------ TransposedConv2D --- FeatureDecode --- output(s)
+ *                            (as ker)          /
+ *           out_backprop --- FeatureEncode ---
+ *                              (as ifm)
+ */
+
+ auto data_layout = plier::tf::as_data_layout(conv2d_backprop->data_layout());
+
+ // Nodes to replace
+ auto feature_enc = graph->nodes()->create<loco::FeatureEncode>();
+ auto filter_enc = graph->nodes()->create<loco::FilterEncode>();
+ auto tr_conv2d = graph->nodes()->create<loco::TransposedConv2D>();
+ auto feature_dec = graph->nodes()->create<loco::FeatureDecode>();
+
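+ // Configure the encoders and decoder (feature encode/decode follow the node's data layout)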
+ set_feature_enc(feature_enc, data_layout);
+ set_filter_enc(filter_enc);
+ set_feature_dec(feature_dec, data_layout);
+
+ // Attributes for new TransposedConv2D
+ loco::Stride<2> stride;
+ loco::Padding2D pad;
+
+ // Get attributes
+ {
+ stride = stride_2d_from_4d(conv2d_backprop->strides(), data_layout);
+
+ Padding2DInference infer_pad;
+
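+ // The 'input' plane comes from out_backprop and the 'output' plane from this node's
+ // inferred shape, which corresponds to Conv2DBackpropInput's input_sizes.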
+ if (!set_plane_shape(infer_pad.input(), conv2d_backprop->out_backprop(), data_layout))
+ return false;
+ if (!set_plane_shape(infer_pad.output(), conv2d_backprop, data_layout))
+ return false;
+ if (!set_window(infer_pad.window(), conv2d_backprop->filter()))
+ return false;
+ infer_pad.stride() = stride;
+ infer_pad.padding() = conv2d_backprop->padding();
+
+ // Run padding inference
+ pad = infer_pad();
+ }
+
+ // Set attributes
+ tr_conv2d->pad()->top(pad.top());
+ tr_conv2d->pad()->bottom(pad.bottom());
+ tr_conv2d->pad()->left(pad.left());
+ tr_conv2d->pad()->right(pad.right());
+
+ tr_conv2d->stride()->vertical(stride.vertical());
+ tr_conv2d->stride()->horizontal(stride.horizontal());
+
+ // Update graph
+ auto input_node = conv2d_backprop->out_backprop();
+ auto filter_node = conv2d_backprop->filter();
+
+ // Update connections
+ feature_enc->input(input_node);
+ filter_enc->input(filter_node);
+ tr_conv2d->ifm(feature_enc);
+ tr_conv2d->ker(filter_enc);
+ feature_dec->input(tr_conv2d);
+
+ // Replace old conv2d_backprop
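+ // replace() redirects every use of conv2d_backprop to feature_dec; the old node
+ // stays in the graph without users, as shown in the diagram above.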
+ replace(conv2d_backprop).with(feature_dec);
+
+ return true;
}
} // namespace