#include <loco/Service/ShapeInference.h>
#include <loco/Service/CanonicalShapeInferenceRule.h>
-#include <pepper/strcast.h>
-#include <stdex/Memory.h>
-
-#include <type_traits>
-
-#include <cstdlib>
-
-namespace
-{
-
-// This Knob is a temporary workaround for incermental migration
-//
-// TODO Remove this workaround!
-struct Knob
-{
- Knob()
- {
- // On by default
- auto s = std::getenv("EXOTFLITE_USE_LOCO_SHAPE_INFERENCE");
- enable_loco_shape_inferene_framework = pepper::safe_strcast<int>(s, 1 /* DEFAULT */) != 0;
- }
-
- bool enable_loco_shape_inferene_framework = false;
-};
-
-Knob knob;
-
-template <typename T, typename If = typename std::enable_if<std::is_integral<T>::value, int>::type>
-T ceil_div(T dividend, T divisor)
-{
- assert(dividend > 0 && divisor > 0 && "this implementations is for positive numbers only");
- return ((dividend - 1) / divisor) + 1;
-}
-
-/**
- * @brief Record the (tensor) shape of each loco node
- */
-struct ShapeContext
-{
- std::unordered_map<loco::Node *, ShapeDescription> _node_to_shape;
-};
-
-int32_t decodeShapeDimension(const loco::Dimension &dim)
-{
- if (!dim.known())
- return -1;
- return dim.value();
-}
-
-loco::Dimension encodeShapeDimension(const int32_t &value)
-{
- if (value == -1)
- return loco::Dimension();
- return {static_cast<uint32_t>(value)};
-}
-
-class ShapeGetter final : public loco::CanonicalNodeMutableVisitor<ShapeDescription>
-{
-public:
- ShapeGetter(ShapeContext &ctx) : gd{ctx}
- {
- // DO NOTHING
- }
-
-public:
-#define NODE(NAME) ShapeDescription visit(loco::NAME *node) final;
- NODE(ConstGen)
- NODE(Pull)
- NODE(Push)
- NODE(FeatureEncode)
- NODE(FeatureDecode)
- NODE(FilterEncode)
- NODE(MaxPool2D)
- NODE(AvgPool2D)
- NODE(Conv2D)
- NODE(ReLU)
- NODE(TensorConcat)
- NODE(BiasEncode)
- NODE(TensorBiasAdd)
- NODE(FeatureBiasAdd)
- NODE(EltwiseAdd)
- NODE(EltwiseMul)
- NODE(EltwiseSub)
- NODE(EltwiseDiv)
-#undef NODE
- // TODO Put all the visit method implementations inside this class declaration
- ShapeDescription visit(loco::ReLU6 *node) { return gd._node_to_shape[node->input()]; }
-
-private:
- ShapeContext &gd;
-};
-
-ShapeDescription ShapeGetter::visit(loco::Pull *node)
-{
- ShapeDescription shape;
- shape._rank_known = true;
- shape._dims.reserve(node->rank());
- for (uint32_t i = 0; i < node->rank(); ++i)
- {
- shape._dims.push_back(decodeShapeDimension(node->dim(i)));
- }
- return shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::Push *node) { return gd._node_to_shape[node->from()]; }
-
-ShapeDescription ShapeGetter::visit(loco::ConstGen *node)
-{
- ShapeDescription shape;
- shape._rank_known = true;
- shape._dims.reserve(node->rank());
- for (uint32_t i = 0; i < node->rank(); ++i)
- {
- shape._dims.push_back(decodeShapeDimension(node->dim(i)));
- }
- return shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::MaxPool2D *node)
-{
- loco::Node *pred = node->ifm();
- const ShapeDescription &pred_shape = gd._node_to_shape[pred];
- if (!pred_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
- ShapeDescription shape;
- shape._rank_known = true;
- shape._dims.resize(4);
- shape._dims[0] = pred_shape._dims[0];
- shape._dims[3] = pred_shape._dims[3];
- tflite::Padding padding = getOpPadding(node->pad());
- switch (padding)
- {
- case tflite::Padding_SAME:
- {
- auto height = static_cast<uint32_t>(pred_shape._dims[1]);
- auto width = static_cast<uint32_t>(pred_shape._dims[2]);
-
- int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
- int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
-
- shape._dims[1] = pred_shape._dims[1] == -1 ? -1 : proposed_res_height;
- shape._dims[2] = pred_shape._dims[2] == -1 ? -1 : proposed_res_width;
- break;
- }
- case tflite::Padding_VALID:
- {
- auto padded_h = static_cast<uint32_t>(pred_shape._dims[1] - (node->window()->vertical() - 1));
- auto padded_w =
- static_cast<uint32_t>(pred_shape._dims[2] - (node->window()->horizontal() - 1));
-
- int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
- int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
-
- shape._dims[1] = pred_shape._dims[1] == -1 ? -1 : proposed_height;
- shape._dims[2] = pred_shape._dims[2] == -1 ? -1 : proposed_width;
- break;
- }
- default:
- assert(false && "unknown padding type");
- }
- return shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::AvgPool2D *node)
-{
- const ShapeDescription &ifm_shape = gd._node_to_shape[node->ifm()];
- assert(ifm_shape._rank_known);
-
- ShapeDescription shape;
- shape._rank_known = true;
- shape._dims.resize(4);
- shape._dims[0] = ifm_shape._dims[0]; // copy batch
- shape._dims[3] = ifm_shape._dims[3]; // copy channel
-
- tflite::Padding padding = getOpPadding(node->pad());
- switch (padding)
- {
- case tflite::Padding_SAME:
- {
- auto height = static_cast<uint32_t>(ifm_shape._dims[1]);
- auto width = static_cast<uint32_t>(ifm_shape._dims[2]);
-
- int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
- int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
-
- shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_res_height;
- shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_res_width;
- break;
- }
- case tflite::Padding_VALID:
- {
- auto padded_h = static_cast<uint32_t>(ifm_shape._dims[1] - (node->window()->vertical() - 1));
- auto padded_w =
- static_cast<uint32_t>(ifm_shape._dims[2] - (node->window()->horizontal() - 1));
-
- int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
- int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
-
- shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_height;
- shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_width;
- break;
- }
- default:
- assert(false && "unknown padding type");
- }
- return shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::Conv2D *node)
-{
- loco::Node *ifm = node->ifm();
- const ShapeDescription &ifm_shape = gd._node_to_shape[ifm];
- if (!ifm_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
-
- auto *ker = dynamic_cast<loco::FilterEncode *>(node->ker());
- assert(ker);
- const ShapeDescription &ker_shape = gd._node_to_shape[ker];
- if (!ker_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
-
- ShapeDescription shape;
- shape._rank_known = true;
- shape._dims.resize(4);
- shape._dims[0] = ifm_shape._dims[0];
- shape._dims[3] = ker_shape._dims[0];
- tflite::Padding padding = getOpPadding(node->pad());
- switch (padding)
- {
- case tflite::Padding_SAME:
- {
- auto height = static_cast<uint32_t>(ifm_shape._dims[1]);
- auto width = static_cast<uint32_t>(ifm_shape._dims[2]);
-
- int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
- int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
-
- shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_res_height;
- shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_res_width;
- break;
- }
- case tflite::Padding_VALID:
- {
- auto padded_h = static_cast<uint32_t>(ifm_shape._dims[1] - (ker_shape._dims[1] - 1));
- auto padded_w = static_cast<uint32_t>(ifm_shape._dims[2] - (ker_shape._dims[2] - 1));
-
- int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
- int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
-
- shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_height;
- shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_width;
- break;
- }
- default:
- assert(false && "unknown padding type");
- }
- return shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::ReLU *node) { return gd._node_to_shape[node->input()]; }
-
-ShapeDescription ShapeGetter::visit(loco::FeatureEncode *node)
-{
- const ShapeDescription &pred_shape = gd._node_to_shape[node->input()];
- if (!pred_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
- ShapeDescription shape;
- shape._rank_known = true;
- loco::TensorShape tensor_shape;
- uint32_t num_dims = pred_shape._dims.size();
- tensor_shape.rank(num_dims);
- for (uint32_t i = 0; i < num_dims; ++i)
- {
- tensor_shape.dim(i) = encodeShapeDimension(pred_shape._dims[i]);
- }
- loco::FeatureShape feature_shape = node->encoder()->shape(tensor_shape);
- shape._dims.resize(4);
- shape._dims[0] = decodeShapeDimension(feature_shape.count());
- shape._dims[1] = decodeShapeDimension(feature_shape.height());
- shape._dims[2] = decodeShapeDimension(feature_shape.width());
- shape._dims[3] = decodeShapeDimension(feature_shape.depth());
- return shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::FeatureDecode *node)
-{
- const ShapeDescription &pred_shape = gd._node_to_shape[node->input()];
- if (!pred_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
- ShapeDescription shape;
- shape._rank_known = true;
- loco::FeatureShape feature_shape;
- feature_shape.count() = encodeShapeDimension(pred_shape._dims[0]);
- feature_shape.height() = encodeShapeDimension(pred_shape._dims[1]);
- feature_shape.width() = encodeShapeDimension(pred_shape._dims[2]);
- feature_shape.depth() = encodeShapeDimension(pred_shape._dims[3]);
- loco::TensorShape tensor_shape = node->decoder()->shape(feature_shape);
- shape._dims.resize(4);
- for (uint32_t i = 0; i < 4; ++i)
- {
- shape._dims[i] = decodeShapeDimension(tensor_shape.dim(i));
- }
- return shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::FilterEncode *node)
-{
- const ShapeDescription &input_shape = gd._node_to_shape[node->input()];
- if (!input_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
- ShapeDescription shape;
- shape._rank_known = true;
- loco::TensorShape tensor_shape;
- uint32_t num_dims = input_shape._dims.size();
- tensor_shape.rank(num_dims);
- for (uint32_t i = 0; i < num_dims; ++i)
- {
- tensor_shape.dim(i) = encodeShapeDimension(input_shape._dims[i]);
- }
- loco::FilterShape filter_shape = node->encoder()->shape(tensor_shape);
- shape._dims.resize(4);
- shape._dims[0] = decodeShapeDimension(filter_shape.count());
- shape._dims[1] = decodeShapeDimension(filter_shape.height());
- shape._dims[2] = decodeShapeDimension(filter_shape.width());
- shape._dims[3] = decodeShapeDimension(filter_shape.depth());
- return shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::TensorConcat *node)
-{
- const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
- if (!lhs_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
-
- const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
- if (!rhs_shape._rank_known)
- {
- // return unknown shape
- return {};
- }
-
- ShapeDescription ret;
-
- assert(lhs_shape._dims.size() == rhs_shape._dims.size());
- ret._dims.resize(lhs_shape._dims.size());
-
- uint32_t axis = node->axis();
-
- for (uint32_t i = 0; i < lhs_shape._dims.size(); ++i)
- {
- if (i == axis)
- {
- ret._dims[i] = lhs_shape._dims[i] + rhs_shape._dims[i];
- }
- else
- {
- assert(lhs_shape._dims[i] == rhs_shape._dims[i]);
- ret._dims[i] = lhs_shape._dims[i];
- }
- }
- ret._rank_known = true;
-
- return ret;
-}
-
-ShapeDescription ShapeGetter::visit(loco::BiasEncode *node)
-{
- const ShapeDescription &input_shape = gd._node_to_shape[node->input()];
-
- // Bias should be rank 1
- assert(input_shape._dims.size() == 1);
-
- return input_shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::BiasAdd<loco::Domain::Tensor> *node)
-{
- const ShapeDescription &value_shape = gd._node_to_shape[node->value()];
- const ShapeDescription &bias_shape = gd._node_to_shape[node->bias()];
-
- // For TFlite, only supports last bias add axis. Unless, broadcasting is not performed as
- // expected.
- assert(node->axis() == value_shape._dims.size() - 1);
-
- // Bias should be rank 1
- assert(bias_shape._dims.size() == 1);
-
- // Channel count coherency for proper broadcast
- assert(bias_shape._dims[0] == value_shape._dims[node->axis()]);
-
- return value_shape;
-}
-
-// TODO Reduce code duplication
-ShapeDescription ShapeGetter::visit(loco::FeatureBiasAdd *node)
-{
- const ShapeDescription &value_shape = gd._node_to_shape[node->value()];
- const ShapeDescription &bias_shape = gd._node_to_shape[node->bias()];
-
- // Bias should be rank 1
- assert(bias_shape._dims.size() == 1);
-
- // Channel count coherency for proper broadcast
- // Feature in T/F Lite uses NHWC layout
- assert(bias_shape._dims[0] == value_shape._dims[3]);
-
- return value_shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::EltwiseAdd *node)
-{
- const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
- const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
-
- assert(lhs_shape._dims == rhs_shape._dims);
-
- return lhs_shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::EltwiseMul *node)
-{
- const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
- const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
-
- assert(lhs_shape._dims == rhs_shape._dims);
-
- return lhs_shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::EltwiseSub *node)
-{
- const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
- const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
-
- assert(lhs_shape._dims == rhs_shape._dims);
-
- return lhs_shape;
-}
-
-ShapeDescription ShapeGetter::visit(loco::EltwiseDiv *node)
-{
- const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
- const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
-
- assert(lhs_shape._dims == rhs_shape._dims);
-
- return lhs_shape;
-}
-} // namespace
-
-namespace
-{
-
-class ShapeAnnotation : public loco::NodeAnnotation
-{
-public:
- ShapeAnnotation(const ShapeDescription &shape) : _shape{shape}
- {
- // DO NOTHING
- }
-
-public:
- const ShapeDescription &shape(void) const { return _shape; }
-
-private:
- ShapeDescription _shape;
-};
-
-} // namespace
-
void ShapeInference::run(loco::Graph *g)
{
- if (knob.enable_loco_shape_inferene_framework)
+ // TODO Adjust indentation level
{
loco::CanonicalShapeInferenceRule rule;
loco::apply(&rule).to(g);
return;
}
-
- ShapeContext shape_ctx;
- ShapeGetter shape_getter{shape_ctx};
-
- for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
- {
- if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
- {
- auto shape = canonical_node->accept(&shape_getter);
- node->annot(stdex::make_unique<ShapeAnnotation>(shape));
- shape_ctx._node_to_shape[node] = shape;
- }
- }
}
ShapeDescription ShapeInference::get(loco::Node *node)
{
- if (knob.enable_loco_shape_inferene_framework)
+ // TODO Adjust indentation level
{
assert(loco::shape_known(node));
return to_shape_description(loco::shape_get(node));
}
-
- assert(node->annot<ShapeAnnotation>() != nullptr);
- return node->annot<ShapeAnnotation>()->shape();
}