--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_SERVICE_SHAPE_INFERENCE_RULE_H__
+#define __MOCO_SERVICE_SHAPE_INFERENCE_RULE_H__
+
+#include <loco/Service/ShapeInferenceRule.h>
+
+namespace moco
+{
+
+/**
+ * @brief Shape inference rule for TensorFlow dialect
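+ *
+ * HOW TO USE (as in the shape inference tests below):
+ *
+ *   moco::TFShapeInferenceRule shape_infer;
+ *
+ *   loco::apply(&shape_infer).to(graph); // graph: loco::Graph *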
+ */
+struct TFShapeInferenceRule final : public loco::ShapeInferenceRule
+{
+ bool support(const API &ver) const final;
+ bool recognize(const loco::Dialect *) const final;
+ bool infer(const loco::Node *, loco::NodeShape &) const final;
+ void infer(const Context *, const loco::Node *, Sink *) const final;
+};
+
+} // namespace moco
+
+#endif // __MOCO_SERVICE_SHAPE_INFERENCE_RULE_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_SERVICE_TYPE_INFERENCE_RULE_H__
+#define __MOCO_SERVICE_TYPE_INFERENCE_RULE_H__
+
+#include <loco/Service/TypeInference.h>
+
+namespace moco
+{
+
+/**
+ * @brief Type Inference Rule for TFDialect
+ */
+struct TFTypeInferenceRule final : public loco::TypeInferenceRule
+{
+ bool recognize(const loco::Dialect *) const final;
+ bool infer(const loco::Node *, loco::DataType &) const final;
+};
+
+} // namespace moco
+
+#endif // __MOCO_SERVICE_TYPE_INFERENCE_RULE_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFShapeInferenceHelper.h"
+
+#include <loco/Service/ShapeInference.h>
+
+#include <cassert>
+
+namespace
+{
+
+// TODO Use the code in loco and remove the duplicated broadcast_shape() and related helpers
+/**
+ * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
+ *
+ * HOW TO USE:
+ *
+ * auto expanded_tensor_shape = expand(tensor_shape).to(N);
+ */
+class TensorShapeExpander
+{
+public:
+ TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
+ {
+ // DO NOTHING
+ }
+
+public:
+ loco::TensorShape to(uint32_t output_rank)
+ {
+ auto const &input_shape = _shape;
+ uint32_t const input_rank = input_shape.rank();
+
+ assert(input_rank <= output_rank && "Cannot shrink rank");
+ uint32_t const axis_shift = output_rank - input_rank;
+
+ loco::TensorShape output_shape;
+
+ output_shape.rank(output_rank);
+ for (uint32_t axis = 0; axis < output_rank; ++axis)
+ {
+ output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
+ }
+
+ return output_shape;
+ }
+
+private:
+ const loco::TensorShape _shape;
+};
+
+/**
+ * @brief Expand shapes x and y to the same rank by right-aligning them and filling the leading axes with 1
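+ *
+ * For example, x = [3, 4] and y = [2, 3, 4]: x is expanded to [1, 3, 4]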
+ */
+void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
+{
+ auto x_rank = x.rank();
+ auto y_rank = y.rank();
+
+ if (x_rank == y_rank)
+ return;
+
+ TensorShapeExpander x_exp(x);
+ TensorShapeExpander y_exp(y);
+
+ auto xy_rank = std::max(x_rank, y_rank);
+
+ x = x_rank > y_rank ? x : x_exp.to(xy_rank);
+ y = y_rank > x_rank ? y : y_exp.to(xy_rank);
+}
+
+/**
+ * @brief Return the per-axis broadcast shape of inputs x and y, which must have the same rank
+ */
+loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
+{
+ assert(x.rank() == y.rank());
+
+ auto rank = x.rank();
+
+ loco::TensorShape output_shape;
+
+ output_shape.rank(rank);
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ {
+ assert(x.dim(axis).known() && y.dim(axis).known());
+
+ auto x_dim = x.dim(axis).value();
+ auto y_dim = y.dim(axis).value();
+
+    // each dimension of x and y should be the same, or one of them must be 1 if they differ
+ if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
+ throw std::runtime_error("Cannot produce expand_dimension of two shapes");
+
+ output_shape.dim(axis) = std::max(x_dim, y_dim);
+ }
+
+ return output_shape;
+}
+
+} // namespace
+
+namespace moco
+{
+
+loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
+{
+ auto x_match = x;
+ auto y_match = y;
+
+ expand_rank(x_match, y_match);
+
+ auto output_shape = expand_dimension(x_match, y_match);
+
+ return output_shape;
+}
+
+} // namespace moco
+
+namespace moco
+{
+
+loco::NodeShape node_shape(const loco::Node *node)
+{
+ loco::NodeShape nodeshape; // default domain is Unknown
+
+ if (loco::shape_known(node))
+ {
+ nodeshape = loco::shape_get(node);
+ }
+
+ return nodeshape;
+}
+
+bool node_shape(const loco::Node *node, loco::NodeShape &nodeshape)
+{
+ nodeshape = node_shape(node);
+ return (nodeshape.domain() != loco::Domain::Unknown);
+}
+
+loco::TensorShape as_tensor_shape(const loco::FeatureShape &feature_shape,
+ const TFDataLayout &data_layout)
+{
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(4);
+ if (data_layout == "NHWC")
+ {
+ tensor_shape.dim(0) = feature_shape.count();
+ tensor_shape.dim(1) = feature_shape.height();
+ tensor_shape.dim(2) = feature_shape.width();
+ tensor_shape.dim(3) = feature_shape.depth();
+ }
+ else if (data_layout == "NCHW")
+ {
+ tensor_shape.dim(0) = feature_shape.count();
+ tensor_shape.dim(1) = feature_shape.depth();
+ tensor_shape.dim(2) = feature_shape.height();
+ tensor_shape.dim(3) = feature_shape.width();
+ }
+ else
+ {
+ // TODO support for other data_layout if needed
+ throw std::runtime_error("as_tensor_shape: only supports NHWC or NCHW");
+ }
+
+ return tensor_shape;
+}
+
+loco::FeatureShape as_feature_shape(const loco::NodeShape &nodeshape,
+ const TFDataLayout &data_layout)
+{
+ if (nodeshape.domain() == loco::Domain::Feature)
+ return nodeshape.as<loco::FeatureShape>();
+
+ loco::FeatureShape feature_shape;
+
+ // only convert from tensor to feature
+ if (nodeshape.domain() != loco::Domain::Tensor)
+ {
+ throw std::runtime_error("as_feature_shape: domain is not tensor");
+ }
+
+ loco::TensorShape tensor_shape = nodeshape.as<loco::TensorShape>();
+
+ if (tensor_shape.rank() != 4)
+ {
+ throw std::runtime_error("as_feature_shape: rank is not 4");
+ }
+
+ if (data_layout == "NHWC")
+ {
+ feature_shape.count() = tensor_shape.dim(0);
+ feature_shape.height() = tensor_shape.dim(1);
+ feature_shape.width() = tensor_shape.dim(2);
+ feature_shape.depth() = tensor_shape.dim(3);
+ }
+ else if (data_layout == "NCHW")
+ {
+ feature_shape.count() = tensor_shape.dim(0);
+ feature_shape.depth() = tensor_shape.dim(1);
+ feature_shape.height() = tensor_shape.dim(2);
+ feature_shape.width() = tensor_shape.dim(3);
+ }
+ else
+ {
+ // TODO support for other data_layout if needed
+ throw std::runtime_error("as_feature_shape: only supports NHWC or NCHW");
+ }
+
+ return feature_shape;
+}
+
+} // namespace moco
+
+namespace moco
+{
+
+PlaneShape make_plane_shape(const loco::FeatureShape &feature_shape)
+{
+ PlaneShape plane_shape;
+
+ plane_shape.height = feature_shape.height();
+ plane_shape.width = feature_shape.width();
+
+ return plane_shape;
+}
+
+FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
+{
+ return FeatureShapeUpdater{&feature_shape};
+}
+
+} // namespace moco
+
+namespace
+{
+
+/**
+ * @brief Class to represent TensorFlow "data_format" attr.
+ */
+enum class DataLayout
+{
+ NHWC,
+ NCHW,
+};
+
+DataLayout as_data_layout(const std::string &tf_layout_str)
+{
+ if (tf_layout_str == "NHWC")
+ return DataLayout::NHWC;
+ else if (tf_layout_str == "NCHW")
+ return DataLayout::NCHW;
+ else
+ throw std::runtime_error("unknown data layout");
+}
+
+} // namespace
+
+namespace moco
+{
+
+loco::Stride<2> stride_of(const TFStrides &strides, const TFDataLayout &datalayout)
+{
+ loco::Stride<2> stride;
+
+ auto data_layout = as_data_layout(datalayout);
+ if (data_layout == DataLayout::NHWC)
+ {
+ stride.vertical(strides[1]);
+ stride.horizontal(strides[2]);
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ stride.vertical(strides[2]);
+ stride.horizontal(strides[3]);
+ }
+
+ return stride;
+}
+
+loco::Window<2> window_of(const TFKSize &ksize, const TFDataLayout &datalayout)
+{
+ loco::Window<2> window;
+
+ auto data_layout = as_data_layout(datalayout);
+ if (data_layout == DataLayout::NHWC)
+ {
+ window.vertical(ksize[1]);
+ window.horizontal(ksize[2]);
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ window.vertical(ksize[2]);
+ window.horizontal(ksize[3]);
+ }
+
+ return window;
+}
+
+loco::Window<2> window_of(const loco::TensorShape &shape, const TFDataLayout &datalayout)
+{
+ loco::Window<2> window;
+
+ if (datalayout == "HWIO")
+ {
+ window.vertical(shape.dim(0).value());
+ window.horizontal(shape.dim(1).value());
+ }
+ else if (datalayout == "HWCM")
+ {
+ window.vertical(shape.dim(0).value());
+ window.horizontal(shape.dim(1).value());
+ }
+ else
+ {
+ // TODO add more datalayout supports if needed
+ assert(false);
+ }
+
+ return window;
+}
+
+} // namespace moco
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MOCO_SERVICE_SHAPE_INFERENCE_HELPER_H__
+#define __MOCO_SERVICE_SHAPE_INFERENCE_HELPER_H__
+
+#include "moco/IR/TFNodeDecl.h" // for TFDataLayout
+
+#include <loco/IR/NodeShape.h>
+#include <loco/IR/Padding2D.h>
+#include <loco/IR/Stride.h>
+#include <loco/IR/Window.h>
+
+#include <cassert>
+
+namespace moco
+{
+
+/**
+ * @note Helper that returns the broadcast shape for binary operators whose
+ *       inputs x and y have different shapes
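+ *
+ * @note For example, x = [2, 1, 5] and y = [3, 1] are right-aligned and
+ *       broadcast to [2, 3, 5]; per axis, dimensions must match or one of
+ *       them must be 1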
+ */
+loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y);
+
+} // namespace moco
+
+namespace moco
+{
+
+/**
+ * @brief Return true if the node already has shape inference data, i.e.
+ *        shape inference for the node is done
+ *
+ * @note Will be deprecated in near future
+ */
+bool shape_inference_done(const loco::Node *node);
+
+/**
+ * @note During shape inference, a node may belong to the Canonical, TF or another dialect.
+ *       These helpers provide its shape information as a common loco::NodeShape
+ */
+loco::NodeShape node_shape(const loco::Node *node);
+bool node_shape(const loco::Node *node, loco::NodeShape &nodeshape);
+
+loco::TensorShape as_tensor_shape(const loco::FeatureShape &feature_shape,
+ const TFDataLayout &data_layout);
+
+loco::FeatureShape as_feature_shape(const loco::NodeShape &nodeshape,
+ const TFDataLayout &data_layout);
+
+} // namespace moco
+
+namespace moco
+{
+
+struct PlaneShape
+{
+ loco::Dimension height;
+ loco::Dimension width;
+};
+
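+/**
+ * @brief Helper to update the height/width of a loco::FeatureShape from a PlaneShape
+ *
+ * HOW TO USE:
+ *
+ *   update(output_feature_shape).with(output_plane_shape);
+ */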
+class FeatureShapeUpdater final
+{
+public:
+ FeatureShapeUpdater(loco::FeatureShape *ptr) : _feature_shape_ptr{ptr}
+ {
+ // DO NOTHING
+ }
+
+public:
+ void with(const PlaneShape &plane_shape) const
+ {
+ _feature_shape_ptr->height() = plane_shape.height;
+ _feature_shape_ptr->width() = plane_shape.width;
+ }
+
+private:
+ loco::FeatureShape *_feature_shape_ptr;
+};
+
+PlaneShape make_plane_shape(const loco::FeatureShape &feature_shape);
+
+FeatureShapeUpdater update(loco::FeatureShape &feature_shape);
+
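+/**
+ * @brief Infer the output height/width of pooling- and convolution-like nodes
+ *        from the input PlaneShape, padding, window and stride
+ *
+ * HOW TO USE (as in the TFAvgPool shape inference):
+ *
+ *   PlaneInference infer_plane_shape;
+ *
+ *   infer_plane_shape.padding(node->padding());
+ *   infer_plane_shape.stride(stride_of(node->strides(), node->data_layout()));
+ *   infer_plane_shape.window(window_of(node->ksize(), node->data_layout()));
+ *
+ *   auto output_plane_shape = infer_plane_shape(input_plane_shape);
+ */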
+class PlaneInference
+{
+protected:
+ struct Parameters
+ {
+ PlaneShape input;
+ PlaneShape stride;
+ PlaneShape window;
+ PlaneShape dilation;
+ PlaneShape effective_window;
+ PlaneShape output;
+ };
+
+ void fill(Parameters &p, const PlaneShape &in)
+ {
+ p.input.height = in.height;
+ p.input.width = in.width;
+
+ p.stride.height = _stride.vertical();
+ p.stride.width = _stride.horizontal();
+
+ p.window.height = _window.vertical();
+ p.window.width = _window.horizontal();
+
+ // TODO support dilation
+ p.dilation.height = 1;
+ p.dilation.width = 1;
+
+ p.effective_window.height = p.dilation.height.value() * (p.window.height.value() - 1) + 1;
+ p.effective_window.width = p.dilation.width.value() * (p.window.width.value() - 1) + 1;
+ }
+
+ PlaneShape infer(const Parameters &p, const PlaneShape &)
+ {
+ PlaneShape res;
+
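+    // With integer division, the formulas below are equivalent to:
+    //   VALID: output = floor((input - effective_window) / stride) + 1
+    //   SAME : output = ceil(input / stride)
+    // e.g. input 5, window 3, stride 2 gives 2 for VALID and 3 for SAME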
+ if (_padding == "VALID")
+ {
+ res.height =
+ (p.input.height.value() + p.stride.height.value() - p.effective_window.height.value()) /
+ p.stride.height.value();
+ res.width =
+ (p.input.width.value() + p.stride.width.value() - p.effective_window.width.value()) /
+ p.stride.width.value();
+ }
+ else if (_padding == "SAME")
+ {
+ res.height = (p.input.height.value() + p.stride.height.value() - 1) / p.stride.height.value();
+ res.width = (p.input.width.value() + p.stride.width.value() - 1) / p.stride.width.value();
+ }
+ else
+ assert(false);
+
+ return res;
+ }
+
+public:
+ PlaneShape operator()(const PlaneShape &in)
+ {
+ Parameters p;
+
+ fill(p, in);
+
+ return infer(p, in);
+ }
+
+public:
+ void padding(const TFPadding &value) { _padding = value; }
+ void window(const loco::Window<2> value) { _window = value; }
+ void stride(const loco::Stride<2> value) { _stride = value; }
+
+private:
+ TFPadding _padding;
+ loco::Window<2> _window;
+ loco::Stride<2> _stride;
+};
+
+class Padding2DInference final : public PlaneInference
+{
+public:
+ loco::Padding2D operator()(const PlaneShape &in)
+ {
+ Parameters p;
+
+ fill(p, in);
+
+ auto output = infer(p, in);
+
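+    // Total padding that makes the inferred output size hold:
+    //   pad = (output - 1) * stride + effective_window - input  (clamped at 0)
+    // The pad is then split so that the extra pixel, if any, goes to the
+    // bottom/right side, matching TensorFlow's 'SAME' padding convention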
+ int64_t i_height = (int64_t)(output.height.value() - 1) * (int64_t)p.stride.height.value() +
+ (int64_t)p.effective_window.height.value() - (int64_t)p.input.height.value();
+ int64_t i_width = (int64_t)(output.width.value() - 1) * (int64_t)p.stride.width.value() +
+ (int64_t)p.effective_window.width.value() - (int64_t)p.input.width.value();
+
+ uint32_t pad_height = i_height >= 0 ? (uint32_t)i_height : 0U;
+ uint32_t pad_width = i_width >= 0 ? (uint32_t)i_width : 0U;
+
+ loco::Padding2D padding2d;
+
+ padding2d.top(pad_height / 2);
+ padding2d.bottom(pad_height - padding2d.top());
+ padding2d.left(pad_width / 2);
+ padding2d.right(pad_width - padding2d.left());
+
+ return padding2d;
+ }
+};
+
+} // namespace moco
+
+namespace moco
+{
+
+using TFStrides = std::vector<int64_t>;
+using TFKSize = std::vector<int64_t>;
+
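+// stride_of/window_of extract the vertical/horizontal components from the
+// TensorFlow 'strides'/'ksize' attributes: indices 1 and 2 for "NHWC",
+// indices 2 and 3 for "NCHW". The TensorShape overload reads a filter shape
+// in "HWIO" or "HWCM" layout, where height is dim(0) and width is dim(1).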
+loco::Stride<2> stride_of(const TFStrides &strides, const TFDataLayout &datalayout);
+loco::Window<2> window_of(const TFKSize &ksize, const TFDataLayout &datalayout);
+loco::Window<2> window_of(const loco::TensorShape &shape, const TFDataLayout &datalayout);
+
+} // namespace moco
+
+#endif // __MOCO_SERVICE_SHAPE_INFERENCE_HELPER_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "moco/Service/TFShapeInferenceRule.h"
+
+#include "TFShapeInferenceHelper.h"
+
+#include "moco/IR/TFDialect.h"
+#include "moco/IR/TFNode.h"
+
+#include <loco/IR/NodeShape.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <cassert>
+
+namespace
+{
+
+class ShapeInferenceAlgorithm final : public moco::TFNodeVisitor<loco::NodeShape>
+{
+public:
+ ShapeInferenceAlgorithm(const loco::ShapeInferenceRule::Context *ctx) : _ctx{ctx}
+ {
+ // DO NOTHING
+ }
+
+private:
+ const loco::ShapeInferenceRule::Context *_ctx;
+
+private:
+ bool shape_known(const loco::Node *node) const { return _ctx->known(node); }
+ loco::NodeShape node_shape(const loco::Node *node) const { return _ctx->get(node); }
+
+private:
+ loco::NodeShape binary_node_shape(const moco::TFNode::Node *node)
+ {
+ // This helper works only for binary node.
+ assert(node->arity() == 2);
+
+ auto lhs_shape = node_shape(node->arg(0));
+ auto rhs_shape = node_shape(node->arg(1));
+
+ loco::TensorShape lhs_tensorshape = lhs_shape.as<loco::TensorShape>();
+ loco::TensorShape rhs_tensorshape = rhs_shape.as<loco::TensorShape>();
+ loco::TensorShape sum_tensorshape = moco::broadcast_shape(lhs_tensorshape, rhs_tensorshape);
+
+ loco::NodeShape sum_shape({sum_tensorshape});
+
+ return sum_shape;
+ }
+
+ loco::NodeShape node_shape_with_check(const moco::TFNode::Node *node)
+ {
+ auto nodeshape = node_shape(node);
+ assert(nodeshape.domain() == loco::Domain::Tensor);
+
+ return nodeshape;
+ }
+
+  bool valid_scalar_value(moco::TFConst *node)
+ {
+ auto nodeshape = node_shape(node);
+ if (nodeshape.domain() != loco::Domain::Tensor)
+ {
+ return false;
+ }
+ if (node->dtype() != loco::DataType::S32)
+ {
+ return false;
+ }
+
+ auto tensor_shape = nodeshape.as<loco::TensorShape>();
+ if (!(tensor_shape.rank() == 0 || tensor_shape.rank() == 1))
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+  int32_t scalar_value(moco::TFConst *node)
+ {
+ auto nodeshape = node_shape(node);
+ assert(node->dtype() == loco::DataType::S32);
+
+ auto tensor_shape = nodeshape.as<loco::TensorShape>();
+ assert(tensor_shape.rank() == 0 || tensor_shape.rank() == 1);
+
+ return node->at<loco::DataType::S32>(0);
+ }
+
+public:
+ loco::NodeShape visit(const moco::TFAdd *node) final { return binary_node_shape(node); }
+
+ loco::NodeShape visit(const moco::TFAvgPool *node) final
+ {
+ auto value_shape = node_shape(node->value());
+ assert(value_shape.domain() != loco::Domain::Unknown);
+
+ moco::PlaneInference infer_plane_shape;
+
+ infer_plane_shape.padding(node->padding());
+ infer_plane_shape.stride(moco::stride_of(node->strides(), node->data_layout()));
+ infer_plane_shape.window(moco::window_of(node->ksize(), node->data_layout()));
+
+ auto input_feature_shape = moco::as_feature_shape(value_shape, node->data_layout());
+ auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
+ auto output_feature_shape = input_feature_shape;
+ auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+ moco::update(output_feature_shape).with(output_plane_shape);
+
+ return moco::as_tensor_shape(output_feature_shape, node->data_layout());
+ }
+
+ loco::NodeShape visit(const moco::TFBiasAdd *node) final
+ {
+ return node_shape_with_check(node->value());
+ }
+
+ loco::NodeShape visit(const moco::TFConcatV2 *node) final
+ {
+ // axis shape should be available
+ auto axis_node = node->axis();
+ auto axis_shape = node_shape(axis_node);
+ assert(axis_shape.domain() != loco::Domain::Unknown);
+
+    // check all input shapes; their ranks should all be the same
+ auto value_a = node->values(0);
+ auto value_a_shape = node_shape(value_a);
+ assert(value_a_shape.domain() == loco::Domain::Tensor);
+ auto value_a_tensor_shape = value_a_shape.as<loco::TensorShape>();
+ uint32_t a_rank = value_a_tensor_shape.rank();
+
+ uint32_t num_values = node->num_values();
+ for (uint32_t ni = 1; ni < num_values; ++ni)
+ {
+ auto value_b = node->values(ni);
+ auto value_b_shape = node_shape(value_b);
+ assert(value_b_shape.domain() == loco::Domain::Tensor);
+ auto value_b_tensor_shape = value_b_shape.as<loco::TensorShape>();
+ uint32_t b_rank = value_b_tensor_shape.rank();
+ assert(a_rank == b_rank);
+ }
+
+ int32_t axis_value = 0;
+ bool axis_available = false;
+ {
+      // check whether axis is a TFConst
+ auto tfconst = dynamic_cast<moco::TFConst *>(axis_node);
+ if (tfconst != nullptr)
+ {
+        if (valid_scalar_value(tfconst))
+        {
+          axis_value = scalar_value(tfconst);
+ axis_available = true;
+ }
+ }
+ }
+ assert(axis_available);
+
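+    // A negative axis counts from the end: e.g. axis -1 on rank-4 inputs means axis 3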
+ uint32_t axis_absolute = (axis_value >= 0) ? axis_value : (int32_t)a_rank + axis_value;
+ loco::TensorShape output_tensor_shape = value_a_tensor_shape;
+
+ for (uint32_t index = 0; index < a_rank; ++index)
+ {
+ if (value_a_tensor_shape.dim(index).known())
+ {
+ uint32_t dim = value_a_tensor_shape.dim(index).value();
+ uint32_t dim_acc = dim;
+
+ for (uint32_t ni = 1; ni < num_values; ++ni)
+ {
+ auto value_b = node->values(ni);
+ auto value_b_shape = node_shape(value_b);
+ assert(value_b_shape.domain() == loco::Domain::Tensor);
+ auto value_b_tensor_shape = value_b_shape.as<loco::TensorShape>();
+ assert(value_b_tensor_shape.dim(index).known());
+ if (index == axis_absolute)
+ dim_acc += value_b_tensor_shape.dim(index).value();
+ else
+ assert(dim == value_b_tensor_shape.dim(index).value());
+ }
+ output_tensor_shape.dim(index) = dim_acc;
+ }
+ else
+ output_tensor_shape.dim(index).unset();
+ }
+ return loco::NodeShape(output_tensor_shape);
+ }
+
+ loco::NodeShape visit(const moco::TFConst *node) final
+ {
+ loco::TensorShape output_tensor_shape;
+
+ uint32_t rank = node->rank();
+ output_tensor_shape.rank(rank);
+ for (uint32_t index = 0; index < rank; ++index)
+ {
+ if (node->dim(index).known())
+ output_tensor_shape.dim(index) = node->dim(index).value();
+ else
+ output_tensor_shape.dim(index).unset();
+ }
+
+ return loco::NodeShape(output_tensor_shape);
+ }
+
+ loco::NodeShape visit(const moco::TFConv2D *node) final
+ {
+ auto input_shape = moco::node_shape(node->input());
+ auto ker_shape = moco::node_shape(node->filter());
+ auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWIO
+ auto node_stride = moco::stride_of(node->strides(), node->data_layout());
+ auto node_window = moco::window_of(ker_tensor_shape, "HWIO");
+
+ moco::PlaneInference infer_plane_shape;
+
+ infer_plane_shape.padding(node->padding());
+ infer_plane_shape.stride(node_stride);
+ infer_plane_shape.window(node_window);
+
+ auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
+ auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
+ // output count is from input count, depth is from kernel 'O' which is dim(3)
+ auto output_feature_shape = input_feature_shape;
+ output_feature_shape.depth() = ker_tensor_shape.dim(3).value();
+
+ auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+ moco::update(output_feature_shape).with(output_plane_shape);
+
+ return moco::as_tensor_shape(output_feature_shape, node->data_layout());
+ }
+
+ loco::NodeShape visit(const moco::TFConv2DBackpropInput *node) final
+ {
+    // TFConv2DBackpropInput's first input, named 'input_sizes', holds the shape of this node's
+    // output feature map, so we can obtain the output shape by simply copying it.
+ // TODO Support when 'input_sizes' is not TFConst, or support constant folding
+ auto input_sizes_node = dynamic_cast<moco::TFConst *>(node->input_sizes());
+ assert(input_sizes_node);
+
+    // Let's support S32 for the time being
+ // TODO Support other integer types
+ assert(input_sizes_node->dtype() == loco::DataType::S32);
+ assert(input_sizes_node->size<loco::DataType::S32>() == 4);
+
+ // copy!
+ loco::TensorShape ofm_tensor_shape;
+ ofm_tensor_shape.rank(4);
+ for (uint32_t i = 0; i < 4; ++i)
+ {
+ int32_t dim = input_sizes_node->at<loco::DataType::S32>(i);
+ assert(dim > 0);
+ ofm_tensor_shape.dim(i) = (uint32_t)dim;
+ }
+
+ return loco::NodeShape(ofm_tensor_shape);
+ }
+
+ loco::NodeShape visit(const moco::TFDepthwiseConv2dNative *node) final
+ {
+ auto input_shape = moco::node_shape(node->input()); // NHWC
+ auto ker_shape = moco::node_shape(node->filter());
+ auto ker_tensor_shape = ker_shape.as<loco::TensorShape>(); // in HWCM
+ auto node_stride = moco::stride_of(node->strides(), node->data_layout());
+ auto node_window = moco::window_of(ker_tensor_shape, "HWCM");
+
+ moco::PlaneInference infer_plane_shape;
+
+ infer_plane_shape.padding(node->padding());
+ infer_plane_shape.stride(node_stride);
+ infer_plane_shape.window(node_window);
+
+ auto input_feature_shape = moco::as_feature_shape(input_shape, node->data_layout());
+ auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
+ // output count is from input count, depth is from kernel 'CM' which is dim(2) * dim(3)
+ auto output_feature_shape = input_feature_shape;
+ output_feature_shape.depth() =
+ loco::Dimension(ker_tensor_shape.dim(2).value() * ker_tensor_shape.dim(3).value());
+
+ auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+ moco::update(output_feature_shape).with(output_plane_shape);
+
+ return moco::as_tensor_shape(output_feature_shape, node->data_layout());
+ }
+
+ loco::NodeShape visit(const moco::TFFusedBatchNorm *node) final
+ {
+ return node_shape_with_check(node->input());
+ }
+
+ loco::NodeShape visit(const moco::TFIdentity *node) final
+ {
+ return node_shape_with_check(node->input());
+ }
+
+ loco::NodeShape visit(const moco::TFMaxPool *node) final
+ {
+ auto value_shape = node_shape(node->value());
+ assert(value_shape.domain() != loco::Domain::Unknown);
+
+ moco::PlaneInference infer_plane_shape;
+
+ infer_plane_shape.padding(node->padding());
+ infer_plane_shape.stride(moco::stride_of(node->strides(), node->data_layout()));
+ infer_plane_shape.window(moco::window_of(node->ksize(), node->data_layout()));
+
+ auto input_feature_shape = moco::as_feature_shape(value_shape, node->data_layout());
+ auto input_plane_shape = moco::make_plane_shape(input_feature_shape);
+ auto output_feature_shape = input_feature_shape;
+ auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+ moco::update(output_feature_shape).with(output_plane_shape);
+
+ return moco::as_tensor_shape(output_feature_shape, node->data_layout());
+ }
+
+ loco::NodeShape visit(const moco::TFMean *node) final
+ {
+ auto input_shape = node_shape(node->input());
+ auto reduction_indices = node->reduction_indices();
+
+    // Get constant values if reduction_indices is a TFConst
+ std::vector<int32_t> reduction_values;
+ if (auto tfconst = dynamic_cast<moco::TFConst *>(reduction_indices))
+ {
+ assert(tfconst->dtype() == loco::DataType::S32);
+ auto const_size = tfconst->size<loco::DataType::S32>();
+ for (uint32_t i = 0; i < const_size; ++i)
+ {
+ int32_t axis = tfconst->at<loco::DataType::S32>(i);
+ if (axis < 0)
+ axis += input_shape.as<loco::TensorShape>().rank();
+ reduction_values.push_back(axis);
+ }
+ }
+ else
+ {
+ // we cannot find a valid reduction indices value
+ loco::NodeShape unknown;
+ return unknown;
+ }
+
+ loco::TensorShape output_shape;
+ auto input_tensor_shape = input_shape.as<loco::TensorShape>();
+
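+    // keep_dims keeps the reduced axes with size 1; otherwise they are removed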
+ if (node->keep_dims())
+ {
+ output_shape.rank(input_tensor_shape.rank());
+ for (uint32_t i = 0; i < input_tensor_shape.rank(); ++i)
+ output_shape.dim(i) = input_tensor_shape.dim(i);
+ for (uint32_t i = 0; i < reduction_values.size(); ++i)
+ output_shape.dim(reduction_values.at(i)) = 1;
+ }
+ else
+ {
+ std::vector<bool> check_reduce(input_tensor_shape.rank(), false);
+ for (uint32_t i = 0; i < reduction_values.size(); ++i)
+ check_reduce.at(reduction_values.at(i)) = true;
+
+ uint32_t reduce_cnt = 0;
+ for (uint32_t i = 0; i < check_reduce.size(); ++i)
+ if (check_reduce.at(i))
+ ++reduce_cnt;
+
+ output_shape.rank(input_tensor_shape.rank() - reduce_cnt);
+      for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
+        if (check_reduce.at(i) == false)
+          output_shape.dim(j++) = input_tensor_shape.dim(i);
+ }
+
+ return loco::NodeShape(output_shape);
+ }
+
+ loco::NodeShape visit(const moco::TFMul *node) final { return binary_node_shape(node); }
+
+ loco::NodeShape visit(const moco::TFPad *node) final
+ {
+ auto input_shape = node_shape(node->input());
+ assert(input_shape.domain() == loco::Domain::Tensor);
+
+ auto const_paddings = dynamic_cast<moco::TFConst *>(node->paddings());
+ assert(const_paddings);
+ assert(const_paddings->dtype() == loco::DataType::S32);
+ assert(const_paddings->rank() == 2);
+
+ loco::TensorShape input_tensor_shape = input_shape.as<loco::TensorShape>();
+ loco::TensorShape output_tensor_shape;
+
+ output_tensor_shape.rank(input_tensor_shape.rank());
+ for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
+ {
+ output_tensor_shape.dim(axis) = input_tensor_shape.dim(axis).value() +
+ const_paddings->at<loco::DataType::S32>(axis * 2) +
+ const_paddings->at<loco::DataType::S32>(axis * 2 + 1);
+ }
+
+ return loco::NodeShape{output_tensor_shape};
+ }
+
+ loco::NodeShape visit(const moco::TFRealDiv *node) final { return binary_node_shape(node); }
+
+ loco::NodeShape visit(const moco::TFRelu *node) final
+ {
+ return node_shape_with_check(node->features());
+ }
+
+ loco::NodeShape visit(const moco::TFRelu6 *node) final
+ {
+ return node_shape_with_check(node->features());
+ }
+
+ loco::NodeShape visit(const moco::TFReshape *node) final
+ {
+ loco::NodeShape unknown;
+
+ // For now, we only consider Fixed Reshape, i.e. Reshape with determined
+ // 'shape' input. So here we only support case when 'shape' input of
+ // TFReshape is TFConst. If 'shape' input is not TFConst, another
+ // transform (e.g. constant folding) should be done beforehand to make
+ // it TFConst.
+ // TODO Support dynamic Reshape
+ // Note that 'shape()' here is 'shape' input, not node's shape information
+ auto const_shape_input = dynamic_cast<moco::TFConst *>(node->shape());
+ if (!const_shape_input)
+ {
+      // 'shape' input of TFReshape is not TFConst, so we cannot do shape inference
+ return unknown;
+ }
+
+ // 'Shape' input should be integer tensor of rank 1, e.g. [2, 3, 4] or [3, -1]
+ assert(const_shape_input->dtype() == loco::DataType::S32);
+ assert(const_shape_input->rank() == 1);
+
+ auto shape_rank = const_shape_input->dim(0).value();
+ assert(shape_rank > 0);
+
+ loco::TensorShape output_shape;
+ output_shape.rank(shape_rank);
+ for (uint32_t axis = 0; axis < shape_rank; ++axis)
+ {
+ auto shape_dim = const_shape_input->at<loco::DataType::S32>(axis);
+ if (shape_dim == -1)
+ {
+ // Reshape's new shape has wildcard dimension, i.e. dynamic reshape
+ return unknown;
+ }
+ assert(shape_dim >= 1);
+ output_shape.dim(axis) = shape_dim;
+ }
+
+ // TODO Compare 'tensor' input and validate coherency?
+ // Not sure this is appropriate stage for this task.
+
+ return loco::NodeShape(output_shape);
+ }
+
+ loco::NodeShape visit(const moco::TFRsqrt *node) final
+ {
+ return node_shape_with_check(node->x());
+ }
+
+ loco::NodeShape visit(const moco::TFShape *node) final
+ {
+ auto input_shape = node_shape(node->input());
+ auto input_tensor_shape = input_shape.as<loco::TensorShape>();
+
+ loco::TensorShape output_shape;
+
+    // Note that the input's shape becomes this TFShape node's value, so the
+    // output is a rank-1 tensor whose single dimension equals the input's rank
+ output_shape.rank(1);
+ output_shape.dim(0) = input_tensor_shape.rank();
+
+ return loco::NodeShape(output_shape);
+ }
+
+ loco::NodeShape visit(const moco::TFSoftmax *node) final
+ {
+ return node_shape_with_check(node->logits());
+ }
+
+ loco::NodeShape visit(const moco::TFSqrt *node) final { return node_shape_with_check(node->x()); }
+
+ loco::NodeShape visit(const moco::TFSquaredDifference *node) final
+ {
+ return binary_node_shape(node);
+ }
+
+ loco::NodeShape visit(const moco::TFSqueeze *node) final
+ {
+ auto input_shape = node_shape(node->input());
+
+    // TODO Check whether the input of Squeeze is always a Tensor
+    // Note that as<loco::TensorShape>() has an assertion in it
+ auto input_tensor_shape = input_shape.as<loco::TensorShape>();
+
+ auto squeeze_dims_vec = node->squeeze_dims();
+ std::set<int64_t> squeeze_dims(squeeze_dims_vec.cbegin(), squeeze_dims_vec.cend());
+
+ loco::TensorShape output_shape;
+ uint32_t output_rank = 0;
+
+ if (squeeze_dims.empty())
+ {
+ // Remove all dimensions whose value is 1
+ for (uint32_t axis = 0; axis < input_tensor_shape.rank(); ++axis)
+ {
+ assert(input_tensor_shape.dim(axis).known());
+ auto dim = input_tensor_shape.dim(axis).value();
+ if (dim != 1)
+ {
+ assert(dim > 1);
+ output_shape.rank(++output_rank);
+ output_shape.dim(output_rank - 1) = dim;
+ }
+ }
+ }
+ else
+ {
+ uint32_t input_rank = input_tensor_shape.rank();
+
+ // Sanity check for 'squeeze_dims'
+ auto is_valid_squeeze_dims = [&squeeze_dims, &input_rank]() {
+ if (!(squeeze_dims.size() < input_rank))
+ return false;
+ for (auto squeeze_dim : squeeze_dims)
+ {
+ if (!(squeeze_dim >= -(int64_t)input_rank))
+ return false;
+ if (!(squeeze_dim < (int64_t)input_rank))
+ return false;
+ }
+ return true;
+ };
+
+ if (!is_valid_squeeze_dims())
+ {
+ throw std::runtime_error("Fix shape for TFSqueeze: invalid squeeze dimension");
+ }
+
+ // Resolve negative squeeze dimension
+ std::set<int64_t> resolved_squeeze_dims;
+ for (auto squeeze_dim : squeeze_dims)
+ {
+ if (squeeze_dim < 0)
+ resolved_squeeze_dims.insert(squeeze_dim + (int64_t)input_rank);
+ else
+ resolved_squeeze_dims.insert(squeeze_dim);
+ }
+
+ // Remove squeeze dimensions only
+ for (uint32_t axis = 0; axis < input_rank; ++axis)
+ {
+ assert(input_tensor_shape.dim(axis).known());
+ auto dim = input_tensor_shape.dim(axis).value();
+ if (resolved_squeeze_dims.find((int64_t)axis) == resolved_squeeze_dims.cend())
+ {
+ // Not squeeze dim
+ output_shape.rank(++output_rank);
+ output_shape.dim(output_rank - 1) = dim;
+ }
+ else
+ {
+ // Is squeeze dim
+ assert(dim == 1);
+ // DO NOTHING
+ }
+ }
+ }
+
+ assert(output_shape.rank() > 0);
+
+ return loco::NodeShape(output_shape);
+ }
+
+ loco::NodeShape visit(const moco::TFStopGradient *node) final
+ {
+ return node_shape_with_check(node->input());
+ }
+
+ loco::NodeShape visit(const moco::TFSub *node) final { return binary_node_shape(node); }
+
+ loco::NodeShape visit(const moco::TFTanh *node) final { return node_shape_with_check(node->x()); }
+
+public:
+ loco::NodeShape visit(const moco::TFNode *) final
+ {
+ loco::NodeShape unknown;
+ return unknown;
+ }
+};
+
+} // namespace
+
+namespace
+{
+namespace compat
+{
+
+struct Context final : public loco::ShapeInferenceRule::Context
+{
+ bool known(const loco::Node *node) const final { return loco::shape_known(node); }
+ loco::NodeShape get(const loco::Node *node) const final { return loco::shape_get(node); }
+};
+
+class Sink final : public loco::ShapeInferenceRule::Sink
+{
+public:
+ enum Status
+ {
+ Unknown,
+ Okay,
+ Fail,
+ };
+
+public:
+ const Status &status(void) const { return _status; }
+ const loco::NodeShape &shape(void) const { return _shape; }
+
+public:
+ void okay(const loco::NodeShape &shape) final
+ {
+ _status = Okay;
+ _shape = shape;
+ }
+
+ void fail(void) final
+ {
+    // Notify failure
+ _status = Fail;
+ }
+
+private:
+ Status _status = Unknown;
+ loco::NodeShape _shape;
+};
+
+} // namespace compat
+} // namespace
+
+namespace moco
+{
+
+bool TFShapeInferenceRule::support(const API &api) const
+{
+ return api == API::V1 or api == API::V2;
+}
+
+bool TFShapeInferenceRule::recognize(const loco::Dialect *d) const
+{
+ // handle only TensorFlow dialect
+ return TFDialect::get() == d;
+}
+
+bool TFShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
+{
+ ::compat::Context ctx;
+ ::compat::Sink sink;
+
+ infer(&ctx, node, &sink);
+
+ assert(sink.status() == ::compat::Sink::Okay or sink.status() == ::compat::Sink::Fail);
+
+ if (sink.status() == ::compat::Sink::Fail)
+ {
+ return false;
+ }
+
+ shape = sink.shape();
+
+ return true;
+}
+
+void TFShapeInferenceRule::infer(const Context *ctx, const loco::Node *node, Sink *sink) const
+{
+  assert(node->dialect() == TFDialect::get());
+
+  auto tf_node = dynamic_cast<const TFNode *>(node);
+  assert(tf_node != nullptr);
+
+  ShapeInferenceAlgorithm alg{ctx};
+  auto shape = tf_node->accept(&alg);
+
+ if (shape.domain() == loco::Domain::Unknown)
+ sink->fail();
+ else
+ sink->okay(shape);
+}
+
+} // namespace moco
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "moco/Service/TFShapeInferenceRule.h"
+
+#include "TestHelper.h"
+
+#include "moco/IR/TFNodes.h"
+
+#include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <gtest/gtest.h>
+
+using namespace moco::test;
+
+namespace
+{
+
+moco::TFAvgPool *avgpool_network_simple1331(loco::Graph *graph)
+{
+ auto avgpool_node = graph->nodes()->create<moco::TFAvgPool>();
+
+ avgpool_node->data_layout("NHWC");
+ avgpool_node->ksize({1, 3, 3, 1});
+ avgpool_node->strides({1, 1, 1, 1});
+
+  // Dummy const node as ifm, just to feed a known input shape to TFShapeInferenceRule for TFAvgPool.
+ auto const_node = graph->nodes()->create<moco::TFConst>();
+ {
+ const_node->rank(4);
+ const_node->dim(0).set(1);
+ const_node->dim(1).set(3);
+ const_node->dim(2).set(3);
+ const_node->dim(3).set(1);
+ }
+ avgpool_node->value(const_node);
+
+ setup_output_node(graph, avgpool_node);
+
+ return avgpool_node;
+}
+
+} // namespace
+
+TEST(TFShapeInferenceRule, avgpool_same)
+{
+ moco::TFShapeInferenceRule shape_infer;
+ loco::Graph graph;
+
+ auto avgpool_node = avgpool_network_simple1331(&graph);
+ avgpool_node->padding("SAME");
+
+ bool cont = true;
+ while (cont)
+ {
+ cont = loco::apply(&shape_infer).to(&graph);
+ };
+
+ auto nodeshape = loco::shape_get(avgpool_node);
+ auto tshape = nodeshape.as<loco::TensorShape>();
+ ASSERT_EQ(tshape.rank(), 4);
+ ASSERT_EQ(tshape.dim(0).value(), 1);
+ ASSERT_EQ(tshape.dim(1).value(), 3);
+ ASSERT_EQ(tshape.dim(2).value(), 3);
+ ASSERT_EQ(tshape.dim(3).value(), 1);
+}
+
+TEST(TFShapeInferenceRule, avgpool_valid)
+{
+ moco::TFShapeInferenceRule shape_infer;
+ loco::Graph graph;
+
+ auto avgpool_node = avgpool_network_simple1331(&graph);
+ avgpool_node->padding("VALID");
+
+ bool cont = true;
+ while (cont)
+ {
+ cont = loco::apply(&shape_infer).to(&graph);
+ };
+
+ auto nodeshape = loco::shape_get(avgpool_node);
+ auto tshape = nodeshape.as<loco::TensorShape>();
+ ASSERT_EQ(tshape.rank(), 4);
+ ASSERT_EQ(tshape.dim(0).value(), 1);
+ ASSERT_EQ(tshape.dim(1).value(), 1);
+ ASSERT_EQ(tshape.dim(2).value(), 1);
+ ASSERT_EQ(tshape.dim(3).value(), 1);
+}
+
+namespace
+{
+
+void conv2d_test(const std::array<uint32_t, 4> ifm_shape, const std::array<uint32_t, 4> ker_shape,
+ const std::array<uint32_t, 2> stride_h_w, std::string padding,
+ const std::array<uint32_t, 4> expected_shape)
+{
+ moco::TFShapeInferenceRule shape_infer;
+ loco::Graph graph;
+
+ auto conv2d_node = graph.nodes()->create<moco::TFConv2D>();
+ conv2d_node->data_layout("NHWC");
+ conv2d_node->strides({1, stride_h_w[0], stride_h_w[1], 1});
+ conv2d_node->padding(padding);
+
+ auto ifm_node = graph.nodes()->create<moco::TFConst>();
+ {
+ ifm_node->rank(4);
+ ifm_node->dim(0).set(ifm_shape[0]);
+ ifm_node->dim(1).set(ifm_shape[1]);
+ ifm_node->dim(2).set(ifm_shape[2]);
+ ifm_node->dim(3).set(ifm_shape[3]);
+ }
+
+ auto ker_node = graph.nodes()->create<moco::TFConst>();
+ {
+ ker_node->rank(4);
+ ker_node->dim(0).set(ker_shape[0]);
+ ker_node->dim(1).set(ker_shape[1]);
+ ker_node->dim(2).set(ker_shape[2]);
+ ker_node->dim(3).set(ker_shape[3]);
+ }
+
+ conv2d_node->input(ifm_node);
+ conv2d_node->filter(ker_node);
+
+ setup_output_node(&graph, conv2d_node);
+
+ bool cont = true;
+ while (cont)
+ {
+ cont = loco::apply(&shape_infer).to(&graph);
+ };
+
+ auto nodeshape = loco::shape_get(conv2d_node);
+ auto tshape = nodeshape.as<loco::TensorShape>();
+ ASSERT_EQ(tshape.rank(), 4);
+ ASSERT_EQ(tshape.dim(0).value(), expected_shape[0]);
+ ASSERT_EQ(tshape.dim(1).value(), expected_shape[1]);
+ ASSERT_EQ(tshape.dim(2).value(), expected_shape[2]);
+ ASSERT_EQ(tshape.dim(3).value(), expected_shape[3]);
+}
+
+} // namespace
+
+/*
+ Testing "InceptionV3/InceptionV3/Conv2d_1a_3x3/Conv2D" Conv2D node in Inception_v3:
+ The result shape of this test is generated with the code below:
+
+ ifm = tf.constant(value=1.1, shape=[1, 299, 299, 3])
+ ker = tf.constant(value=1.1, shape=[3, 3, 3, 32])
+
+ out = tf.nn.conv2d(ifm, ker, strides = [1, 2, 2, 1], padding= 'VALID')
+
+ with tf.Session() as sess:
+ res = sess.run(out)
+ print(res.shape)
+ */
+TEST(TFShapeInferenceRule, conv2d_VALID)
+{
+ conv2d_test({1, 299, 299, 3}, // ifm
+ {3, 3, 3, 32}, // ker
+ {2, 2}, // strides
+ "VALID", // padding
+ {1, 149, 149, 32}); // expected shape after FixShape
+}
+
+/*
+ Testing "InceptionV3/InceptionV3/Conv2d_2b_3x3/Conv2D" Conv2D node in Inception_v3:
+ The result shape of this test is generated with the code below:
+
+ ifm = tf.constant(value=1.1, shape=[1, 147, 147, 32])
+ ker = tf.constant(value=1.1, shape=[3, 3, 32, 64])
+
+ out = tf.nn.conv2d(ifm, ker, strides = [1, 1, 1, 1], padding= 'SAME')
+
+ with tf.Session() as sess:
+ res = sess.run(out)
+ print(res.shape)
+ */
+TEST(TFShapeInferenceRule, conv2d_SAME)
+{
+ conv2d_test({1, 147, 147, 32}, // ifm
+ {3, 3, 32, 64}, // ker
+ {1, 1}, // strides
+ "SAME", // padding
+ {1, 147, 147, 64}); // expected shape after FixShape
+}
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "moco/Service/TFTypeInferenceRule.h"
+
+#include "moco/IR/TFDialect.h"
+#include "moco/IR/TFNodeVisitor.h"
+#include "moco/IR/TFNodes.h"
+
+#include "moco/IR/TFNodeImpl.h"
+
+#include <cassert>
+
+namespace
+{
+
+using namespace moco;
+
+struct TypeForwardAlgorithm final : public moco::TFNodeVisitor<loco::DataType>
+{
+ loco::DataType visit(const TFAdd *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFAvgPool *node) { return dtype_get(node->value()); }
+ loco::DataType visit(const TFBiasAdd *node) { return dtype_get(node->value()); }
+ loco::DataType visit(const TFConcatV2 *node) { return dtype_get(node->values(0)); }
+
+ loco::DataType visit(const TFConst *node) { return node->dtype(); }
+
+ loco::DataType visit(const TFConv2D *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFConv2DBackpropInput *node)
+ {
+ return dtype_get(node->out_backprop());
+ }
+ loco::DataType visit(const TFDepthwiseConv2dNative *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFFusedBatchNorm *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFIdentity *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFMaxPool *node) { return dtype_get(node->value()); }
+ loco::DataType visit(const TFMean *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFMul *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFPad *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFRealDiv *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFRelu *node) { return dtype_get(node->features()); }
+ loco::DataType visit(const TFRelu6 *node) { return dtype_get(node->features()); }
+ loco::DataType visit(const TFReshape *node) { return dtype_get(node->tensor()); }
+ loco::DataType visit(const TFRsqrt *node) { return dtype_get(node->x()); }
+
+ loco::DataType visit(const TFShape *node) { return node->dtype(); }
+
+ loco::DataType visit(const TFSoftmax *node) { return dtype_get(node->logits()); }
+ loco::DataType visit(const TFSqrt *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFSquaredDifference *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFSqueeze *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFStopGradient *node) { return dtype_get(node->input()); }
+ loco::DataType visit(const TFSub *node) { return dtype_get(node->x()); }
+ loco::DataType visit(const TFTanh *node) { return dtype_get(node->x()); }
+};
+
+} // namespace
+
+namespace moco
+{
+
+bool TFTypeInferenceRule::recognize(const loco::Dialect *d) const
+{
+ // This rule recognizes only "TFDialect" dialect!
+ return TFDialect::get() == d;
+}
+
+bool TFTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
+{
+ assert(node->dialect() == TFDialect::get());
+
+ TypeForwardAlgorithm alg;
+
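+  // "moco/IR/TFNodes.lst" expands TENSORFLOW_NODE(OPCODE, CLASS) once per TF
+  // dialect node, so the chain of if-statements below dispatches the node to
+  // the matching visit() method above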
+// clang-format off
+#define TENSORFLOW_NODE(OPCODE,CLASS) \
+ if (dynamic_cast<const moco::CLASS *>(node)) \
+ { \
+ auto tfnode = dynamic_cast<const moco::CLASS *>(node); \
+ dtype = tfnode->accept(&alg); \
+ assert(dtype != loco::DataType::Unknown); \
+ return true; \
+ }
+#include "moco/IR/TFNodes.lst"
+#undef TENSORFLOW_NODE
+ // clang-format on
+
+ return false;
+}
+
+} // namespace moco
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TEST_HELPER_H__
+#define __TEST_HELPER_H__
+
+#include <loco.h>
+
+namespace moco
+{
+namespace test
+{
+
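+/**
+ * @brief Return the first node of type T found in the graph, or nullptr if there is none
+ */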
+template <typename T> T *find_first_node_bytype(loco::Graph *g)
+{
+ T *first_node = nullptr;
+ loco::Graph::NodeContext *nodes = g->nodes();
+ uint32_t count = nodes->size();
+
+ for (uint32_t i = 0; i < count; ++i)
+ {
+ first_node = dynamic_cast<T *>(nodes->at(i));
+ if (first_node != nullptr)
+ break;
+ }
+
+ return first_node;
+}
+
+template <typename T> std::vector<T *> find_nodes_bytype(loco::Graph *g)
+{
+ std::vector<T *> find_nodes;
+ loco::Graph::NodeContext *nodes = g->nodes();
+ uint32_t count = nodes->size();
+
+ for (uint32_t i = 0; i < count; ++i)
+ {
+ auto node = dynamic_cast<T *>(nodes->at(i));
+ if (node != nullptr)
+ find_nodes.push_back(node);
+ }
+
+ return find_nodes;
+}
+
+/**
+ * @brief Set up the graph output by appending a loco::Push node after last_node
+ *
+ * @note This is subject to change when loco changes I/O treatment
+ */
+void setup_output_node(loco::Graph *graph, loco::Node *last_node);
+
+} // namespace test
+} // namespace moco
+
+#endif // __TEST_HELPER_H__
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+namespace moco
+{
+namespace test
+{
+
+void setup_output_node(loco::Graph *graph, loco::Node *last_node)
+{
+ // add push as output
+ auto push_node = graph->nodes()->create<loco::Push>();
+ push_node->from(last_node);
+
+ // set the graph output name and node object
+ auto graph_output = graph->outputs()->create();
+ graph_output->name("output");
+ graph_output->dtype(loco::DataType::FLOAT32);
+ loco::link(graph_output, push_node);
+}
+
+} // namespace test
+} // namespace moco