#include "OperationExporter.h"
#include "ExporterUtils.h"
#include "TypeInference.h"
+#include "ShapeInference.h"
using namespace flatbuffers;
using namespace tflite;
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ShapeInference.h"
+
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+
+#include <stdex/Memory.h>
+
+#include <cassert>
+#include <type_traits>
+#include <unordered_map>
+
+namespace
+{
+
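+// Rounded-up integer division, e.g. ceil_div(7u, 2u) == 4; used below to
+// compute output extents under SAME/VALID padding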
+template <typename T, typename If = typename std::enable_if<std::is_integral<T>::value, int>::type>
+T ceil_div(T dividend, T divisor)
+{
+  assert(dividend > 0 && divisor > 0 && "this implementation is for positive numbers only");
+ return (dividend + divisor - 1) / divisor;
+}
+
+/**
+ * @brief Record the (tensor) shape of each loco node
+ */
+struct ShapeContext
+{
+ std::unordered_map<loco::Node *, ShapeDescription> _node_to_shape;
+};
+
+} // namespace
+
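+// ShapeDescription encodes an unknown extent as -1, while loco::Dimension
+// carries an explicit "known" flag; these helpers convert between the two.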
+int32_t decodeShapeDimension(const loco::Dimension &dim)
+{
+ if (!dim.known())
+ return -1;
+ return dim.value();
+}
+
+loco::Dimension encodeShapeDimension(const int32_t &value)
+{
+ if (value == -1)
+ return loco::Dimension();
+ return {static_cast<uint32_t>(value)};
+}
+
+ShapeDescription getOpResultShape(loco::Pull *node, ShapeContext &)
+{
+ ShapeDescription shape;
+ shape._rank_known = true;
+ shape._dims.reserve(node->rank());
+ for (uint32_t i = 0; i < node->rank(); ++i)
+ {
+ shape._dims.push_back(decodeShapeDimension(node->dim(i)));
+ }
+ return shape;
+}
+
+ShapeDescription getOpResultShape(loco::Push *node, ShapeContext &gd)
+{
+ return gd._node_to_shape[node->from()];
+}
+
+ShapeDescription getOpResultShape(loco::ConstGen *node, ShapeContext &)
+{
+ ShapeDescription shape;
+ shape._rank_known = true;
+ shape._dims.reserve(node->rank());
+ for (uint32_t i = 0; i < node->rank(); ++i)
+ {
+ shape._dims.push_back(decodeShapeDimension(node->dim(i)));
+ }
+ return shape;
+}
+
+ShapeDescription getOpResultShape(loco::MaxPool2D *node, ShapeContext &gd)
+{
+ loco::Node *pred = node->ifm();
+ const ShapeDescription &pred_shape = gd._node_to_shape[pred];
+ if (!pred_shape._rank_known)
+ {
+ // return unknown shape
+ return {};
+ }
+ ShapeDescription shape;
+ shape._rank_known = true;
+ shape._dims.resize(4);
+ shape._dims[0] = pred_shape._dims[0];
+ shape._dims[3] = pred_shape._dims[3];
+ tflite::Padding padding = getOpPadding(node->pad());
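+  // Output spatial extents follow the usual TensorFlow padding rules:
+  //   SAME  : out = ceil(in / stride)
+  //   VALID : out = ceil((in - window + 1) / stride)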
+ switch (padding)
+ {
+ case tflite::Padding_SAME:
+ {
+ auto height = static_cast<uint32_t>(pred_shape._dims[1]);
+ auto width = static_cast<uint32_t>(pred_shape._dims[2]);
+
+ int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
+ int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
+
+ shape._dims[1] = pred_shape._dims[1] == -1 ? -1 : proposed_res_height;
+ shape._dims[2] = pred_shape._dims[2] == -1 ? -1 : proposed_res_width;
+ break;
+ }
+ case tflite::Padding_VALID:
+ {
+ auto padded_h = static_cast<uint32_t>(pred_shape._dims[1] - (node->window()->vertical() - 1));
+ auto padded_w = static_cast<uint32_t>(pred_shape._dims[2] - (node->window()->horizontal() - 1));
+
+ int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
+ int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
+
+ shape._dims[1] = pred_shape._dims[1] == -1 ? -1 : proposed_height;
+ shape._dims[2] = pred_shape._dims[2] == -1 ? -1 : proposed_width;
+ break;
+ }
+ default:
+ assert(false && "unknown padding type");
+ }
+ return shape;
+}
+
+ShapeDescription getOpResultShape(loco::AvgPool2D *node, ShapeContext &gd)
+{
+ const ShapeDescription &ifm_shape = gd._node_to_shape[node->ifm()];
+ assert(ifm_shape._rank_known);
+
+ ShapeDescription shape;
+ shape._rank_known = true;
+ shape._dims.resize(4);
+ shape._dims[0] = ifm_shape._dims[0]; // copy batch
+ shape._dims[3] = ifm_shape._dims[3]; // copy channel
+
+ tflite::Padding padding = getOpPadding(node->pad());
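+  // Same SAME/VALID output arithmetic as MaxPool2D above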
+ switch (padding)
+ {
+ case tflite::Padding_SAME:
+ {
+ auto height = static_cast<uint32_t>(ifm_shape._dims[1]);
+ auto width = static_cast<uint32_t>(ifm_shape._dims[2]);
+
+ int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
+ int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
+
+ shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_res_height;
+ shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_res_width;
+ break;
+ }
+ case tflite::Padding_VALID:
+ {
+ auto padded_h = static_cast<uint32_t>(ifm_shape._dims[1] - (node->window()->vertical() - 1));
+ auto padded_w = static_cast<uint32_t>(ifm_shape._dims[2] - (node->window()->horizontal() - 1));
+
+ int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
+ int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
+
+ shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_height;
+ shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_width;
+ break;
+ }
+ default:
+ assert(false && "unknown padding type");
+ }
+ return shape;
+}
+
+ShapeDescription getOpResultShape(loco::Conv2D *node, ShapeContext &gd)
+{
+ loco::Node *ifm = node->ifm();
+ const ShapeDescription &ifm_shape = gd._node_to_shape[ifm];
+ if (!ifm_shape._rank_known)
+ {
+ // return unknown shape
+ return {};
+ }
+
+ auto *ker = dynamic_cast<loco::FilterEncode *>(node->ker());
+ assert(ker);
+ const ShapeDescription &ker_shape = gd._node_to_shape[ker];
+ if (!ker_shape._rank_known)
+ {
+ // return unknown shape
+ return {};
+ }
+
+ ShapeDescription shape;
+ shape._rank_known = true;
+ shape._dims.resize(4);
+ shape._dims[0] = ifm_shape._dims[0];
+ shape._dims[3] = ker_shape._dims[0];
+ tflite::Padding padding = getOpPadding(node->pad());
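+  // Same padding arithmetic as the pooling ops, with the kernel's spatial
+  // extents playing the role of the window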
+ switch (padding)
+ {
+ case tflite::Padding_SAME:
+ {
+ auto height = static_cast<uint32_t>(ifm_shape._dims[1]);
+ auto width = static_cast<uint32_t>(ifm_shape._dims[2]);
+
+ int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
+ int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
+
+ shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_res_height;
+ shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_res_width;
+ break;
+ }
+ case tflite::Padding_VALID:
+ {
+ auto padded_h = static_cast<uint32_t>(ifm_shape._dims[1] - (ker_shape._dims[1] - 1));
+ auto padded_w = static_cast<uint32_t>(ifm_shape._dims[2] - (ker_shape._dims[2] - 1));
+
+ int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
+ int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
+
+ shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_height;
+ shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_width;
+ break;
+ }
+ default:
+ assert(false && "unknown padding type");
+ }
+ return shape;
+}
+
+ShapeDescription getOpResultShape(loco::ReLU *node, ShapeContext &gd)
+{
+ return gd._node_to_shape[node->input()];
+}
+
+ShapeDescription getOpResultShape(loco::FeatureEncode *node, ShapeContext &gd)
+{
+ const ShapeDescription &pred_shape = gd._node_to_shape[node->input()];
+ if (!pred_shape._rank_known)
+ {
+ // return unknown shape
+ return {};
+ }
+ ShapeDescription shape;
+ shape._rank_known = true;
+ loco::TensorShape tensor_shape;
+ uint32_t num_dims = pred_shape._dims.size();
+ tensor_shape.rank(num_dims);
+ for (uint32_t i = 0; i < num_dims; ++i)
+ {
+ tensor_shape.dim(i) = encodeShapeDimension(pred_shape._dims[i]);
+ }
+ loco::FeatureShape feature_shape = node->encoder()->shape(tensor_shape);
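+  // Emit the encoded feature shape in TFLite's NHWC order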
+ shape._dims.resize(4);
+ shape._dims[0] = decodeShapeDimension(feature_shape.count());
+ shape._dims[1] = decodeShapeDimension(feature_shape.height());
+ shape._dims[2] = decodeShapeDimension(feature_shape.width());
+ shape._dims[3] = decodeShapeDimension(feature_shape.depth());
+ return shape;
+}
+
+ShapeDescription getOpResultShape(loco::FeatureDecode *node, ShapeContext &gd)
+{
+ const ShapeDescription &pred_shape = gd._node_to_shape[node->input()];
+ if (!pred_shape._rank_known)
+ {
+ // return unknown shape
+ return {};
+ }
+ ShapeDescription shape;
+ shape._rank_known = true;
+ loco::FeatureShape feature_shape;
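+  // The predecessor's dims are assumed to be in NHWC order (count, height,
+  // width, depth), as produced by FeatureEncode above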
+ feature_shape.count() = encodeShapeDimension(pred_shape._dims[0]);
+ feature_shape.height() = encodeShapeDimension(pred_shape._dims[1]);
+ feature_shape.width() = encodeShapeDimension(pred_shape._dims[2]);
+ feature_shape.depth() = encodeShapeDimension(pred_shape._dims[3]);
+ loco::TensorShape tensor_shape = node->decoder()->shape(feature_shape);
+ shape._dims.resize(4);
+ for (uint32_t i = 0; i < 4; ++i)
+ {
+ shape._dims[i] = decodeShapeDimension(tensor_shape.dim(i));
+ }
+ return shape;
+}
+
+ShapeDescription getOpResultShape(loco::FilterEncode *node, ShapeContext &gd)
+{
+ const ShapeDescription &input_shape = gd._node_to_shape[node->input()];
+ if (!input_shape._rank_known)
+ {
+ // return unknown shape
+ return {};
+ }
+ ShapeDescription shape;
+ shape._rank_known = true;
+ loco::TensorShape tensor_shape;
+ uint32_t num_dims = input_shape._dims.size();
+ tensor_shape.rank(num_dims);
+ for (uint32_t i = 0; i < num_dims; ++i)
+ {
+ tensor_shape.dim(i) = encodeShapeDimension(input_shape._dims[i]);
+ }
+ loco::FilterShape filter_shape = node->encoder()->shape(tensor_shape);
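+  // Emit the encoded filter shape as (count, height, width, depth), i.e.
+  // TFLite's OHWI filter layout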
+ shape._dims.resize(4);
+ shape._dims[0] = decodeShapeDimension(filter_shape.count());
+ shape._dims[1] = decodeShapeDimension(filter_shape.height());
+ shape._dims[2] = decodeShapeDimension(filter_shape.width());
+ shape._dims[3] = decodeShapeDimension(filter_shape.depth());
+ return shape;
+}
+
+ShapeDescription getOpResultShape(loco::TensorConcat *node, ShapeContext &gd)
+{
+ const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
+ if (!lhs_shape._rank_known)
+ {
+ // return unknown shape
+ return {};
+ }
+
+ const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
+ if (!rhs_shape._rank_known)
+ {
+ // return unknown shape
+ return {};
+ }
+
+ ShapeDescription ret;
+
+ assert(lhs_shape._dims.size() == rhs_shape._dims.size());
+ ret._dims.resize(lhs_shape._dims.size());
+
+ uint32_t axis = node->axis();
+
+ for (uint32_t i = 0; i < lhs_shape._dims.size(); ++i)
+ {
+ if (i == axis)
+ {
+      // An unknown (-1) extent on either side makes the concatenated extent unknown
+      ret._dims[i] = (lhs_shape._dims[i] == -1 || rhs_shape._dims[i] == -1)
+                         ? -1
+                         : lhs_shape._dims[i] + rhs_shape._dims[i];
+ }
+ else
+ {
+ assert(lhs_shape._dims[i] == rhs_shape._dims[i]);
+ ret._dims[i] = lhs_shape._dims[i];
+ }
+ }
+ ret._rank_known = true;
+
+ return ret;
+}
+
+ShapeDescription getOpResultShape(loco::BiasEncode *node, ShapeContext &gd)
+{
+ const ShapeDescription &input_shape = gd._node_to_shape[node->input()];
+
+ // Bias should be rank 1
+ assert(input_shape._dims.size() == 1);
+
+ return input_shape;
+}
+
+ShapeDescription getOpResultShape(loco::BiasAdd<loco::Domain::Tensor> *node, ShapeContext &gd)
+{
+ const ShapeDescription &value_shape = gd._node_to_shape[node->value()];
+ const ShapeDescription &bias_shape = gd._node_to_shape[node->bias()];
+
+  // TFLite supports bias add only along the last axis; otherwise broadcasting
+  // is not performed as expected.
+ assert(node->axis() == value_shape._dims.size() - 1);
+
+ // Bias should be rank 1
+ assert(bias_shape._dims.size() == 1);
+
+ // Channel count coherency for proper broadcast
+ assert(bias_shape._dims[0] == value_shape._dims[node->axis()]);
+
+ return value_shape;
+}
+
+// TODO Reduce code duplication
+ShapeDescription getOpResultShape(loco::FeatureBiasAdd *node, ShapeContext &gd)
+{
+ const ShapeDescription &value_shape = gd._node_to_shape[node->value()];
+ const ShapeDescription &bias_shape = gd._node_to_shape[node->bias()];
+
+ // Bias should be rank 1
+ assert(bias_shape._dims.size() == 1);
+
+ // Channel count coherency for proper broadcast
+  // Features in TFLite use the NHWC layout
+ assert(bias_shape._dims[0] == value_shape._dims[3]);
+
+ return value_shape;
+}
+
+namespace
+{
+
+class ShapeAnnotation : public loco::NodeAnnotation
+{
+public:
+ ShapeAnnotation(const ShapeDescription &shape) : _shape{shape}
+ {
+ // DO NOTHING
+ }
+
+public:
+ const ShapeDescription &shape(void) const { return _shape; }
+
+private:
+ ShapeDescription _shape;
+};
+
+class ShapeAnnotator final : public loco::CanonicalNodeMutableVisitor<void>
+{
+public:
+ ShapeAnnotator() = default;
+
+public:
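+  // For each canonical node type below, generate a visit() that infers the
+  // node's result shape, attaches it as an annotation, and records it in the
+  // context for use by successor nodes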
+#define NODE(NAME) \
+ void visit(loco::NAME *node) final \
+ { \
+ auto s = getOpResultShape(node, _ctx); \
+ node->annot(stdex::make_unique<ShapeAnnotation>(s)); \
+ _ctx._node_to_shape[node] = s; \
+ }
+ NODE(ConstGen)
+ NODE(Pull)
+ NODE(Push)
+ NODE(FeatureEncode)
+ NODE(FeatureDecode)
+ NODE(FilterEncode)
+ NODE(MaxPool2D)
+ NODE(AvgPool2D)
+ NODE(Conv2D)
+ NODE(ReLU)
+ NODE(TensorConcat)
+ NODE(BiasEncode)
+ NODE(TensorBiasAdd)
+ NODE(FeatureBiasAdd)
+#undef NODE
+
+private:
+ // TODO Remove this variable
+ ShapeContext _ctx;
+};
+
+} // namespace
+
+void ShapeInference::run(loco::Graph *g)
+{
+ ShapeAnnotator shape_annotator;
+
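+  // Postorder traversal guarantees predecessors are annotated before each node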
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
+ {
+ canonical_node->accept(&shape_annotator);
+ }
+ }
+}
+
+ShapeDescription ShapeInference::get(loco::Node *node)
+{
+ assert(node->annot<ShapeAnnotation>() != nullptr);
+ return node->annot<ShapeAnnotation>()->shape();
+}
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __SHAPE_INFERENCE_H__
+#define __SHAPE_INFERENCE_H__
+
+#include "ExporterUtils.h"
+
+#include <loco/IR/Nodes.h>
+
+/**
+ * @brief Annotate the shape of each node as a node annotation
+ *
+ * HOW TO USE
+ *
+ * ShapeInference::run(g);
+ *
+ * ShapeInference::get(g->nodes()->at(..));
+ */
+struct ShapeInference
+{
+ static void run(loco::Graph *g);
+
+ static ShapeDescription get(loco::Node *node);
+};
+
+#endif // __SHAPE_INFERENCE_H__
#include "TFLExporterImpl.h"
#include "TypeInference.h"
+#include "ShapeInference.h"
#include "TensorExporter.h"
#include "OperationExporter.h"
#include "ExporterUtils.h"
#include "TensorExporter.h"
#include "TypeInference.h"
+#include "ShapeInference.h"
// TODO Fix include style
#include "loco/IR/Algorithm.h"
}
}
-template <typename T, typename If = typename std::enable_if<std::is_integral<T>::value, int>::type>
-T ceil_div(T dividend, T divisor)
-{
- assert(dividend > 0 && divisor > 0 && "this implementations is for positive numbers only");
- return (dividend + divisor - 1) / divisor;
-}
-
/**
* @brief Record the data type of each loco node
*/
assert(node->annot<TypeAnnotation>() != nullptr);
return node->annot<TypeAnnotation>()->type();
}
-
-namespace
-{
-
-/**
- * @brief Record the (tensor) shape of each loco node
- */
-struct ShapeContext
-{
- std::unordered_map<loco::Node *, ShapeDescription> _node_to_shape;
-};
-
-} // namespace
-
-int32_t decodeShapeDimension(const loco::Dimension &dim)
-{
- if (!dim.known())
- return -1;
- return dim.value();
-}
-
-loco::Dimension encodeShapeDimension(const int32_t &value)
-{
- if (value == -1)
- return loco::Dimension();
- return {static_cast<uint32_t>(value)};
-}
-
-ShapeDescription getOpResultShape(loco::Pull *node, ShapeContext &)
-{
- ShapeDescription shape;
- shape._rank_known = true;
- shape._dims.reserve(node->rank());
- for (uint32_t i = 0; i < node->rank(); ++i)
- {
- shape._dims.push_back(decodeShapeDimension(node->dim(i)));
- }
- return shape;
-}
-
-ShapeDescription getOpResultShape(loco::Push *node, ShapeContext &gd)
-{
- return gd._node_to_shape[node->from()];
-}
-
-ShapeDescription getOpResultShape(loco::ConstGen *node, ShapeContext &)
-{
- ShapeDescription shape;
- shape._rank_known = true;
- shape._dims.reserve(node->rank());
- for (uint32_t i = 0; i < node->rank(); ++i)
- {
- shape._dims.push_back(decodeShapeDimension(node->dim(i)));
- }
- return shape;
-}
-
-ShapeDescription getOpResultShape(loco::MaxPool2D *node, ShapeContext &gd)
-{
- loco::Node *pred = node->ifm();
- const ShapeDescription &pred_shape = gd._node_to_shape[pred];
- if (!pred_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
- ShapeDescription shape;
- shape._rank_known = true;
- shape._dims.resize(4);
- shape._dims[0] = pred_shape._dims[0];
- shape._dims[3] = pred_shape._dims[3];
- tflite::Padding padding = getOpPadding(node->pad());
- switch (padding)
- {
- case tflite::Padding_SAME:
- {
- auto height = static_cast<uint32_t>(pred_shape._dims[1]);
- auto width = static_cast<uint32_t>(pred_shape._dims[2]);
-
- int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
- int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
-
- shape._dims[1] = pred_shape._dims[1] == -1 ? -1 : proposed_res_height;
- shape._dims[2] = pred_shape._dims[2] == -1 ? -1 : proposed_res_width;
- break;
- }
- case tflite::Padding_VALID:
- {
- auto padded_h = static_cast<uint32_t>(pred_shape._dims[1] - (node->window()->vertical() - 1));
- auto padded_w = static_cast<uint32_t>(pred_shape._dims[2] - (node->window()->horizontal() - 1));
-
- int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
- int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
-
- shape._dims[1] = pred_shape._dims[1] == -1 ? -1 : proposed_height;
- shape._dims[2] = pred_shape._dims[2] == -1 ? -1 : proposed_width;
- break;
- }
- default:
- assert(false && "unknown padding type");
- }
- return shape;
-}
-
-ShapeDescription getOpResultShape(loco::AvgPool2D *node, ShapeContext &gd)
-{
- const ShapeDescription &ifm_shape = gd._node_to_shape[node->ifm()];
- assert(ifm_shape._rank_known);
-
- ShapeDescription shape;
- shape._rank_known = true;
- shape._dims.resize(4);
- shape._dims[0] = ifm_shape._dims[0]; // copy batch
- shape._dims[3] = ifm_shape._dims[3]; // copy channel
-
- tflite::Padding padding = getOpPadding(node->pad());
- switch (padding)
- {
- case tflite::Padding_SAME:
- {
- auto height = static_cast<uint32_t>(ifm_shape._dims[1]);
- auto width = static_cast<uint32_t>(ifm_shape._dims[2]);
-
- int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
- int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
-
- shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_res_height;
- shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_res_width;
- break;
- }
- case tflite::Padding_VALID:
- {
- auto padded_h = static_cast<uint32_t>(ifm_shape._dims[1] - (node->window()->vertical() - 1));
- auto padded_w = static_cast<uint32_t>(ifm_shape._dims[2] - (node->window()->horizontal() - 1));
-
- int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
- int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
-
- shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_height;
- shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_width;
- break;
- }
- default:
- assert(false && "unknown padding type");
- }
- return shape;
-}
-
-ShapeDescription getOpResultShape(loco::Conv2D *node, ShapeContext &gd)
-{
- loco::Node *ifm = node->ifm();
- const ShapeDescription &ifm_shape = gd._node_to_shape[ifm];
- if (!ifm_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
-
- auto *ker = dynamic_cast<loco::FilterEncode *>(node->ker());
- assert(ker);
- const ShapeDescription &ker_shape = gd._node_to_shape[ker];
- if (!ker_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
-
- ShapeDescription shape;
- shape._rank_known = true;
- shape._dims.resize(4);
- shape._dims[0] = ifm_shape._dims[0];
- shape._dims[3] = ker_shape._dims[0];
- tflite::Padding padding = getOpPadding(node->pad());
- switch (padding)
- {
- case tflite::Padding_SAME:
- {
- auto height = static_cast<uint32_t>(ifm_shape._dims[1]);
- auto width = static_cast<uint32_t>(ifm_shape._dims[2]);
-
- int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
- int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
-
- shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_res_height;
- shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_res_width;
- break;
- }
- case tflite::Padding_VALID:
- {
- auto padded_h = static_cast<uint32_t>(ifm_shape._dims[1] - (ker_shape._dims[1] - 1));
- auto padded_w = static_cast<uint32_t>(ifm_shape._dims[2] - (ker_shape._dims[2] - 1));
-
- int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
- int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
-
- shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_height;
- shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_width;
- break;
- }
- default:
- assert(false && "unknown padding type");
- }
- return shape;
-}
-
-ShapeDescription getOpResultShape(loco::ReLU *node, ShapeContext &gd)
-{
- return gd._node_to_shape[node->input()];
-}
-
-ShapeDescription getOpResultShape(loco::FeatureEncode *node, ShapeContext &gd)
-{
- const ShapeDescription &pred_shape = gd._node_to_shape[node->input()];
- if (!pred_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
- ShapeDescription shape;
- shape._rank_known = true;
- loco::TensorShape tensor_shape;
- uint32_t num_dims = pred_shape._dims.size();
- tensor_shape.rank(num_dims);
- for (uint32_t i = 0; i < num_dims; ++i)
- {
- tensor_shape.dim(i) = encodeShapeDimension(pred_shape._dims[i]);
- }
- loco::FeatureShape feature_shape = node->encoder()->shape(tensor_shape);
- shape._dims.resize(4);
- shape._dims[0] = decodeShapeDimension(feature_shape.count());
- shape._dims[1] = decodeShapeDimension(feature_shape.height());
- shape._dims[2] = decodeShapeDimension(feature_shape.width());
- shape._dims[3] = decodeShapeDimension(feature_shape.depth());
- return shape;
-}
-
-ShapeDescription getOpResultShape(loco::FeatureDecode *node, ShapeContext &gd)
-{
- const ShapeDescription &pred_shape = gd._node_to_shape[node->input()];
- if (!pred_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
- ShapeDescription shape;
- shape._rank_known = true;
- loco::FeatureShape feature_shape;
- feature_shape.count() = encodeShapeDimension(pred_shape._dims[0]);
- feature_shape.height() = encodeShapeDimension(pred_shape._dims[1]);
- feature_shape.width() = encodeShapeDimension(pred_shape._dims[2]);
- feature_shape.depth() = encodeShapeDimension(pred_shape._dims[3]);
- loco::TensorShape tensor_shape = node->decoder()->shape(feature_shape);
- shape._dims.resize(4);
- for (uint32_t i = 0; i < 4; ++i)
- {
- shape._dims[i] = decodeShapeDimension(tensor_shape.dim(i));
- }
- return shape;
-}
-
-ShapeDescription getOpResultShape(loco::FilterEncode *node, ShapeContext &gd)
-{
- const ShapeDescription &input_shape = gd._node_to_shape[node->input()];
- if (!input_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
- ShapeDescription shape;
- shape._rank_known = true;
- loco::TensorShape tensor_shape;
- uint32_t num_dims = input_shape._dims.size();
- tensor_shape.rank(num_dims);
- for (uint32_t i = 0; i < num_dims; ++i)
- {
- tensor_shape.dim(i) = encodeShapeDimension(input_shape._dims[i]);
- }
- loco::FilterShape filter_shape = node->encoder()->shape(tensor_shape);
- shape._dims.resize(4);
- shape._dims[0] = decodeShapeDimension(filter_shape.count());
- shape._dims[1] = decodeShapeDimension(filter_shape.height());
- shape._dims[2] = decodeShapeDimension(filter_shape.width());
- shape._dims[3] = decodeShapeDimension(filter_shape.depth());
- return shape;
-}
-
-ShapeDescription getOpResultShape(loco::TensorConcat *node, ShapeContext &gd)
-{
- const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
- if (!lhs_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
-
- const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
- if (!rhs_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
-
- ShapeDescription ret;
-
- assert(lhs_shape._dims.size() == rhs_shape._dims.size());
- ret._dims.resize(lhs_shape._dims.size());
-
- uint32_t axis = node->axis();
-
- for (uint32_t i = 0; i < lhs_shape._dims.size(); ++i)
- {
- if (i == axis)
- {
- ret._dims[i] = lhs_shape._dims[i] + rhs_shape._dims[i];
- }
- else
- {
- assert(lhs_shape._dims[i] == rhs_shape._dims[i]);
- ret._dims[i] = lhs_shape._dims[i];
- }
- }
- ret._rank_known = true;
-
- return ret;
-}
-
-ShapeDescription getOpResultShape(loco::BiasEncode *node, ShapeContext &gd)
-{
- const ShapeDescription &input_shape = gd._node_to_shape[node->input()];
-
- // Bias should be rank 1
- assert(input_shape._dims.size() == 1);
-
- return input_shape;
-}
-
-ShapeDescription getOpResultShape(loco::BiasAdd<loco::Domain::Tensor> *node, ShapeContext &gd)
-{
- const ShapeDescription &value_shape = gd._node_to_shape[node->value()];
- const ShapeDescription &bias_shape = gd._node_to_shape[node->bias()];
-
- // For TFlite, only supports last bias add axis. Unless, broadcasting is not performed as
- // expected.
- assert(node->axis() == value_shape._dims.size() - 1);
-
- // Bias should be rank 1
- assert(bias_shape._dims.size() == 1);
-
- // Channel count coherency for proper broadcast
- assert(bias_shape._dims[0] == value_shape._dims[node->axis()]);
-
- return value_shape;
-}
-
-// TODO Reduce code duplication
-ShapeDescription getOpResultShape(loco::FeatureBiasAdd *node, ShapeContext &gd)
-{
- const ShapeDescription &value_shape = gd._node_to_shape[node->value()];
- const ShapeDescription &bias_shape = gd._node_to_shape[node->bias()];
-
- // Bias should be rank 1
- assert(bias_shape._dims.size() == 1);
-
- // Channel count coherency for proper broadcast
- // Feature in T/F Lite uses NHWC layout
- assert(bias_shape._dims[0] == value_shape._dims[3]);
-
- return value_shape;
-}
-
-namespace
-{
-
-class ShapeAnnotation : public loco::NodeAnnotation
-{
-public:
- ShapeAnnotation(const ShapeDescription &shape) : _shape{shape}
- {
- // DO NOTHING
- }
-
-public:
- const ShapeDescription &shape(void) const { return _shape; }
-
-private:
- ShapeDescription _shape;
-};
-
-class ShapeAnnotator final : public loco::CanonicalNodeMutableVisitor<void>
-{
-public:
- ShapeAnnotator() = default;
-
-public:
-#define NODE(NAME) \
- void visit(loco::NAME *node) final \
- { \
- auto s = getOpResultShape(node, _ctx); \
- node->annot(stdex::make_unique<ShapeAnnotation>(s)); \
- _ctx._node_to_shape[node] = s; \
- }
- NODE(ConstGen)
- NODE(Pull)
- NODE(Push)
- NODE(FeatureEncode)
- NODE(FeatureDecode)
- NODE(FilterEncode)
- NODE(MaxPool2D)
- NODE(AvgPool2D)
- NODE(Conv2D)
- NODE(ReLU)
- NODE(TensorConcat)
- NODE(BiasEncode)
- NODE(TensorBiasAdd)
- NODE(FeatureBiasAdd)
-#undef NODE
-
-private:
- // TODO Remove this variable
- ShapeContext _ctx;
-};
-
-} // namespace
-
-void ShapeInference::run(loco::Graph *g)
-{
- ShapeAnnotator shape_annotator;
-
- for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
- {
- if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
- {
- canonical_node->accept(&shape_annotator);
- }
- }
-}
-
-ShapeDescription ShapeInference::get(loco::Node *node)
-{
- assert(node->annot<ShapeAnnotation>() != nullptr);
- return node->annot<ShapeAnnotation>()->shape();
-}
static tflite::TensorType get(loco::Node *node);
};
-/**
- * @brief Annotate the shape of each node as a node annotation
- *
- * HOW TO USE
- *
- * ShapeInference::run(g);
- *
- * ShapeInference::get(g->nodes()->at(..));
- */
-struct ShapeInference
-{
- static void run(loco::Graph *g);
-
- static ShapeDescription get(loco::Node *node);
-};
-
#endif // __TYPE_INFERENCE_H__