[exo/tflite] Extract ShapeInference (#4255)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 15 Jul 2019 04:47:11 +0000 (13:47 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 15 Jul 2019 04:47:11 +0000 (13:47 +0900)
This commit introduces ShapeInference.h and ShapeInference.cpp and
extracts all the declarations and implementations related with
ShapeInference from TypeInference.h & TypeInference.cpp.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/exo-tflite/src/OperationExporter.cpp
contrib/exo-tflite/src/ShapeInference.cpp [new file with mode: 0644]
contrib/exo-tflite/src/ShapeInference.h [new file with mode: 0644]
contrib/exo-tflite/src/TFLExporterImpl.cpp
contrib/exo-tflite/src/TensorExporter.cpp
contrib/exo-tflite/src/TypeInference.cpp
contrib/exo-tflite/src/TypeInference.h

index 700cf63..5707e55 100644 (file)
@@ -17,6 +17,7 @@
 #include "OperationExporter.h"
 #include "ExporterUtils.h"
 #include "TypeInference.h"
+#include "ShapeInference.h"
 
 using namespace flatbuffers;
 using namespace tflite;
diff --git a/contrib/exo-tflite/src/ShapeInference.cpp b/contrib/exo-tflite/src/ShapeInference.cpp
new file mode 100644 (file)
index 0000000..9574c9e
--- /dev/null
@@ -0,0 +1,473 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ShapeInference.h"
+
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+
+#include <stdex/Memory.h>
+
+#include <type_traits>
+
+namespace
+{
+
+template <typename T, typename If = typename std::enable_if<std::is_integral<T>::value, int>::type>
+T ceil_div(T dividend, T divisor)
+{
+  assert(dividend > 0 && divisor > 0 && "this implementations is for positive numbers only");
+  return (dividend + divisor - 1) / divisor;
+}
+
+/**
+ * @brief Record the (tensor) shape of each loco node
+ */
+struct ShapeContext
+{
+  std::unordered_map<loco::Node *, ShapeDescription> _node_to_shape;
+};
+
+} // namespace
+
+int32_t decodeShapeDimension(const loco::Dimension &dim)
+{
+  if (!dim.known())
+    return -1;
+  return dim.value();
+}
+
+loco::Dimension encodeShapeDimension(const int32_t &value)
+{
+  if (value == -1)
+    return loco::Dimension();
+  return {static_cast<uint32_t>(value)};
+}
+
+ShapeDescription getOpResultShape(loco::Pull *node, ShapeContext &)
+{
+  ShapeDescription shape;
+  shape._rank_known = true;
+  shape._dims.reserve(node->rank());
+  for (uint32_t i = 0; i < node->rank(); ++i)
+  {
+    shape._dims.push_back(decodeShapeDimension(node->dim(i)));
+  }
+  return shape;
+}
+
+ShapeDescription getOpResultShape(loco::Push *node, ShapeContext &gd)
+{
+  return gd._node_to_shape[node->from()];
+}
+
+ShapeDescription getOpResultShape(loco::ConstGen *node, ShapeContext &)
+{
+  ShapeDescription shape;
+  shape._rank_known = true;
+  shape._dims.reserve(node->rank());
+  for (uint32_t i = 0; i < node->rank(); ++i)
+  {
+    shape._dims.push_back(decodeShapeDimension(node->dim(i)));
+  }
+  return shape;
+}
+
+ShapeDescription getOpResultShape(loco::MaxPool2D *node, ShapeContext &gd)
+{
+  loco::Node *pred = node->ifm();
+  const ShapeDescription &pred_shape = gd._node_to_shape[pred];
+  if (!pred_shape._rank_known)
+  {
+    // return unknown shape
+    return {};
+  }
+  ShapeDescription shape;
+  shape._rank_known = true;
+  shape._dims.resize(4);
+  shape._dims[0] = pred_shape._dims[0];
+  shape._dims[3] = pred_shape._dims[3];
+  tflite::Padding padding = getOpPadding(node->pad());
+  switch (padding)
+  {
+  case tflite::Padding_SAME:
+  {
+    auto height = static_cast<uint32_t>(pred_shape._dims[1]);
+    auto width = static_cast<uint32_t>(pred_shape._dims[2]);
+
+    int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
+    int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
+
+    shape._dims[1] = pred_shape._dims[1] == -1 ? -1 : proposed_res_height;
+    shape._dims[2] = pred_shape._dims[2] == -1 ? -1 : proposed_res_width;
+    break;
+  }
+  case tflite::Padding_VALID:
+  {
+    auto padded_h = static_cast<uint32_t>(pred_shape._dims[1] - (node->window()->vertical() - 1));
+    auto padded_w = static_cast<uint32_t>(pred_shape._dims[2] - (node->window()->horizontal() - 1));
+
+    int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
+    int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
+
+    shape._dims[1] = pred_shape._dims[1] == -1 ? -1 : proposed_height;
+    shape._dims[2] = pred_shape._dims[2] == -1 ? -1 : proposed_width;
+    break;
+  }
+  default:
+    assert(false && "unknown padding type");
+  }
+  return shape;
+}
+
+ShapeDescription getOpResultShape(loco::AvgPool2D *node, ShapeContext &gd)
+{
+  const ShapeDescription &ifm_shape = gd._node_to_shape[node->ifm()];
+  assert(ifm_shape._rank_known);
+
+  ShapeDescription shape;
+  shape._rank_known = true;
+  shape._dims.resize(4);
+  shape._dims[0] = ifm_shape._dims[0]; // copy batch
+  shape._dims[3] = ifm_shape._dims[3]; // copy channel
+
+  tflite::Padding padding = getOpPadding(node->pad());
+  switch (padding)
+  {
+  case tflite::Padding_SAME:
+  {
+    auto height = static_cast<uint32_t>(ifm_shape._dims[1]);
+    auto width = static_cast<uint32_t>(ifm_shape._dims[2]);
+
+    int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
+    int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
+
+    shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_res_height;
+    shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_res_width;
+    break;
+  }
+  case tflite::Padding_VALID:
+  {
+    auto padded_h = static_cast<uint32_t>(ifm_shape._dims[1] - (node->window()->vertical() - 1));
+    auto padded_w = static_cast<uint32_t>(ifm_shape._dims[2] - (node->window()->horizontal() - 1));
+
+    int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
+    int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
+
+    shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_height;
+    shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_width;
+    break;
+  }
+  default:
+    assert(false && "unknown padding type");
+  }
+  return shape;
+}
+
+ShapeDescription getOpResultShape(loco::Conv2D *node, ShapeContext &gd)
+{
+  loco::Node *ifm = node->ifm();
+  const ShapeDescription &ifm_shape = gd._node_to_shape[ifm];
+  if (!ifm_shape._rank_known)
+  {
+    // return unknown shape
+    return {};
+  }
+
+  auto *ker = dynamic_cast<loco::FilterEncode *>(node->ker());
+  assert(ker);
+  const ShapeDescription &ker_shape = gd._node_to_shape[ker];
+  if (!ker_shape._rank_known)
+  {
+    // return unknown shape
+    return {};
+  }
+
+  ShapeDescription shape;
+  shape._rank_known = true;
+  shape._dims.resize(4);
+  shape._dims[0] = ifm_shape._dims[0];
+  shape._dims[3] = ker_shape._dims[0];
+  tflite::Padding padding = getOpPadding(node->pad());
+  switch (padding)
+  {
+  case tflite::Padding_SAME:
+  {
+    auto height = static_cast<uint32_t>(ifm_shape._dims[1]);
+    auto width = static_cast<uint32_t>(ifm_shape._dims[2]);
+
+    int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
+    int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
+
+    shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_res_height;
+    shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_res_width;
+    break;
+  }
+  case tflite::Padding_VALID:
+  {
+    auto padded_h = static_cast<uint32_t>(ifm_shape._dims[1] - (ker_shape._dims[1] - 1));
+    auto padded_w = static_cast<uint32_t>(ifm_shape._dims[2] - (ker_shape._dims[2] - 1));
+
+    int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
+    int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
+
+    shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_height;
+    shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_width;
+    break;
+  }
+  default:
+    assert(false && "unknown padding type");
+  }
+  return shape;
+}
+
+ShapeDescription getOpResultShape(loco::ReLU *node, ShapeContext &gd)
+{
+  return gd._node_to_shape[node->input()];
+}
+
+ShapeDescription getOpResultShape(loco::FeatureEncode *node, ShapeContext &gd)
+{
+  const ShapeDescription &pred_shape = gd._node_to_shape[node->input()];
+  if (!pred_shape._rank_known)
+  {
+    // return unknown shape
+    return {};
+  }
+  ShapeDescription shape;
+  shape._rank_known = true;
+  loco::TensorShape tensor_shape;
+  uint32_t num_dims = pred_shape._dims.size();
+  tensor_shape.rank(num_dims);
+  for (uint32_t i = 0; i < num_dims; ++i)
+  {
+    tensor_shape.dim(i) = encodeShapeDimension(pred_shape._dims[i]);
+  }
+  loco::FeatureShape feature_shape = node->encoder()->shape(tensor_shape);
+  shape._dims.resize(4);
+  shape._dims[0] = decodeShapeDimension(feature_shape.count());
+  shape._dims[1] = decodeShapeDimension(feature_shape.height());
+  shape._dims[2] = decodeShapeDimension(feature_shape.width());
+  shape._dims[3] = decodeShapeDimension(feature_shape.depth());
+  return shape;
+}
+
+ShapeDescription getOpResultShape(loco::FeatureDecode *node, ShapeContext &gd)
+{
+  const ShapeDescription &pred_shape = gd._node_to_shape[node->input()];
+  if (!pred_shape._rank_known)
+  {
+    // return unknown shape
+    return {};
+  }
+  ShapeDescription shape;
+  shape._rank_known = true;
+  loco::FeatureShape feature_shape;
+  feature_shape.count() = encodeShapeDimension(pred_shape._dims[0]);
+  feature_shape.height() = encodeShapeDimension(pred_shape._dims[1]);
+  feature_shape.width() = encodeShapeDimension(pred_shape._dims[2]);
+  feature_shape.depth() = encodeShapeDimension(pred_shape._dims[3]);
+  loco::TensorShape tensor_shape = node->decoder()->shape(feature_shape);
+  shape._dims.resize(4);
+  for (uint32_t i = 0; i < 4; ++i)
+  {
+    shape._dims[i] = decodeShapeDimension(tensor_shape.dim(i));
+  }
+  return shape;
+}
+
+ShapeDescription getOpResultShape(loco::FilterEncode *node, ShapeContext &gd)
+{
+  const ShapeDescription &input_shape = gd._node_to_shape[node->input()];
+  if (!input_shape._rank_known)
+  {
+    // return unknown shape
+    return {};
+  }
+  ShapeDescription shape;
+  shape._rank_known = true;
+  loco::TensorShape tensor_shape;
+  uint32_t num_dims = input_shape._dims.size();
+  tensor_shape.rank(num_dims);
+  for (uint32_t i = 0; i < num_dims; ++i)
+  {
+    tensor_shape.dim(i) = encodeShapeDimension(input_shape._dims[i]);
+  }
+  loco::FilterShape filter_shape = node->encoder()->shape(tensor_shape);
+  shape._dims.resize(4);
+  shape._dims[0] = decodeShapeDimension(filter_shape.count());
+  shape._dims[1] = decodeShapeDimension(filter_shape.height());
+  shape._dims[2] = decodeShapeDimension(filter_shape.width());
+  shape._dims[3] = decodeShapeDimension(filter_shape.depth());
+  return shape;
+}
+
+ShapeDescription getOpResultShape(loco::TensorConcat *node, ShapeContext &gd)
+{
+  const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
+  if (!lhs_shape._rank_known)
+  {
+    // return unknown shape
+    return {};
+  }
+
+  const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
+  if (!rhs_shape._rank_known)
+  {
+    // return unknown shape
+    return {};
+  }
+
+  ShapeDescription ret;
+
+  assert(lhs_shape._dims.size() == rhs_shape._dims.size());
+  ret._dims.resize(lhs_shape._dims.size());
+
+  uint32_t axis = node->axis();
+
+  for (uint32_t i = 0; i < lhs_shape._dims.size(); ++i)
+  {
+    if (i == axis)
+    {
+      ret._dims[i] = lhs_shape._dims[i] + rhs_shape._dims[i];
+    }
+    else
+    {
+      assert(lhs_shape._dims[i] == rhs_shape._dims[i]);
+      ret._dims[i] = lhs_shape._dims[i];
+    }
+  }
+  ret._rank_known = true;
+
+  return ret;
+}
+
+ShapeDescription getOpResultShape(loco::BiasEncode *node, ShapeContext &gd)
+{
+  const ShapeDescription &input_shape = gd._node_to_shape[node->input()];
+
+  // Bias should be rank 1
+  assert(input_shape._dims.size() == 1);
+
+  return input_shape;
+}
+
+ShapeDescription getOpResultShape(loco::BiasAdd<loco::Domain::Tensor> *node, ShapeContext &gd)
+{
+  const ShapeDescription &value_shape = gd._node_to_shape[node->value()];
+  const ShapeDescription &bias_shape = gd._node_to_shape[node->bias()];
+
+  // For TFlite, only supports last bias add axis. Unless, broadcasting is not performed as
+  // expected.
+  assert(node->axis() == value_shape._dims.size() - 1);
+
+  // Bias should be rank 1
+  assert(bias_shape._dims.size() == 1);
+
+  // Channel count coherency for proper broadcast
+  assert(bias_shape._dims[0] == value_shape._dims[node->axis()]);
+
+  return value_shape;
+}
+
+// TODO Reduce code duplication
+ShapeDescription getOpResultShape(loco::FeatureBiasAdd *node, ShapeContext &gd)
+{
+  const ShapeDescription &value_shape = gd._node_to_shape[node->value()];
+  const ShapeDescription &bias_shape = gd._node_to_shape[node->bias()];
+
+  // Bias should be rank 1
+  assert(bias_shape._dims.size() == 1);
+
+  // Channel count coherency for proper broadcast
+  // Feature in T/F Lite uses NHWC layout
+  assert(bias_shape._dims[0] == value_shape._dims[3]);
+
+  return value_shape;
+}
+
+namespace
+{
+
+class ShapeAnnotation : public loco::NodeAnnotation
+{
+public:
+  ShapeAnnotation(const ShapeDescription &shape) : _shape{shape}
+  {
+    // DO NOTHING
+  }
+
+public:
+  const ShapeDescription &shape(void) const { return _shape; }
+
+private:
+  ShapeDescription _shape;
+};
+
+class ShapeAnnotator final : public loco::CanonicalNodeMutableVisitor<void>
+{
+public:
+  ShapeAnnotator() = default;
+
+public:
+#define NODE(NAME)                                       \
+  void visit(loco::NAME *node) final                     \
+  {                                                      \
+    auto s = getOpResultShape(node, _ctx);               \
+    node->annot(stdex::make_unique<ShapeAnnotation>(s)); \
+    _ctx._node_to_shape[node] = s;                       \
+  }
+  NODE(ConstGen)
+  NODE(Pull)
+  NODE(Push)
+  NODE(FeatureEncode)
+  NODE(FeatureDecode)
+  NODE(FilterEncode)
+  NODE(MaxPool2D)
+  NODE(AvgPool2D)
+  NODE(Conv2D)
+  NODE(ReLU)
+  NODE(TensorConcat)
+  NODE(BiasEncode)
+  NODE(TensorBiasAdd)
+  NODE(FeatureBiasAdd)
+#undef NODE
+
+private:
+  // TODO Remove this variable
+  ShapeContext _ctx;
+};
+
+} // namespace
+
+void ShapeInference::run(loco::Graph *g)
+{
+  ShapeAnnotator shape_annotator;
+
+  for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+  {
+    if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
+    {
+      canonical_node->accept(&shape_annotator);
+    }
+  }
+}
+
+ShapeDescription ShapeInference::get(loco::Node *node)
+{
+  assert(node->annot<ShapeAnnotation>() != nullptr);
+  return node->annot<ShapeAnnotation>()->shape();
+}
diff --git a/contrib/exo-tflite/src/ShapeInference.h b/contrib/exo-tflite/src/ShapeInference.h
new file mode 100644 (file)
index 0000000..86c0c59
--- /dev/null
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __SHAPE_INFERENCE_H__
+#define __SHAPE_INFERENCE_H__
+
+#include "ExporterUtils.h"
+
+#include <loco/IR/Nodes.h>
+
+/**
+ * @brief Annotate the shape of each node as a node annotation
+ *
+ * HOW TO USE
+ *
+ *   ShapeInference::run(g);
+ *
+ *   ShapeInference::get(g->nodes()->at(..));
+ */
+struct ShapeInference
+{
+  static void run(loco::Graph *g);
+
+  static ShapeDescription get(loco::Node *node);
+};
+
+#endif // __SHAPE_INFERENCE_H__
index 63fd151..ad7b6f9 100644 (file)
@@ -17,6 +17,7 @@
 #include "TFLExporterImpl.h"
 
 #include "TypeInference.h"
+#include "ShapeInference.h"
 #include "TensorExporter.h"
 #include "OperationExporter.h"
 #include "ExporterUtils.h"
index 7d0e3c9..70223b8 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "TensorExporter.h"
 #include "TypeInference.h"
+#include "ShapeInference.h"
 
 // TODO Fix include style
 #include "loco/IR/Algorithm.h"
index fdf6034..48ac2eb 100644 (file)
@@ -55,13 +55,6 @@ tflite::TensorType translateLocoTypeToTFLite(loco::DataType dtype)
   }
 }
 
-template <typename T, typename If = typename std::enable_if<std::is_integral<T>::value, int>::type>
-T ceil_div(T dividend, T divisor)
-{
-  assert(dividend > 0 && divisor > 0 && "this implementations is for positive numbers only");
-  return (dividend + divisor - 1) / divisor;
-}
-
 /**
  * @brief Record the data type of each loco node
  */
@@ -222,445 +215,3 @@ tflite::TensorType TypeInference::get(loco::Node *node)
   assert(node->annot<TypeAnnotation>() != nullptr);
   return node->annot<TypeAnnotation>()->type();
 }
-
-namespace
-{
-
-/**
- * @brief Record the (tensor) shape of each loco node
- */
-struct ShapeContext
-{
-  std::unordered_map<loco::Node *, ShapeDescription> _node_to_shape;
-};
-
-} // namespace
-
-int32_t decodeShapeDimension(const loco::Dimension &dim)
-{
-  if (!dim.known())
-    return -1;
-  return dim.value();
-}
-
-loco::Dimension encodeShapeDimension(const int32_t &value)
-{
-  if (value == -1)
-    return loco::Dimension();
-  return {static_cast<uint32_t>(value)};
-}
-
-ShapeDescription getOpResultShape(loco::Pull *node, ShapeContext &)
-{
-  ShapeDescription shape;
-  shape._rank_known = true;
-  shape._dims.reserve(node->rank());
-  for (uint32_t i = 0; i < node->rank(); ++i)
-  {
-    shape._dims.push_back(decodeShapeDimension(node->dim(i)));
-  }
-  return shape;
-}
-
-ShapeDescription getOpResultShape(loco::Push *node, ShapeContext &gd)
-{
-  return gd._node_to_shape[node->from()];
-}
-
-ShapeDescription getOpResultShape(loco::ConstGen *node, ShapeContext &)
-{
-  ShapeDescription shape;
-  shape._rank_known = true;
-  shape._dims.reserve(node->rank());
-  for (uint32_t i = 0; i < node->rank(); ++i)
-  {
-    shape._dims.push_back(decodeShapeDimension(node->dim(i)));
-  }
-  return shape;
-}
-
-ShapeDescription getOpResultShape(loco::MaxPool2D *node, ShapeContext &gd)
-{
-  loco::Node *pred = node->ifm();
-  const ShapeDescription &pred_shape = gd._node_to_shape[pred];
-  if (!pred_shape._rank_known)
-  {
-    // return unknown shape
-    return {};
-  }
-  ShapeDescription shape;
-  shape._rank_known = true;
-  shape._dims.resize(4);
-  shape._dims[0] = pred_shape._dims[0];
-  shape._dims[3] = pred_shape._dims[3];
-  tflite::Padding padding = getOpPadding(node->pad());
-  switch (padding)
-  {
-  case tflite::Padding_SAME:
-  {
-    auto height = static_cast<uint32_t>(pred_shape._dims[1]);
-    auto width = static_cast<uint32_t>(pred_shape._dims[2]);
-
-    int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
-    int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
-
-    shape._dims[1] = pred_shape._dims[1] == -1 ? -1 : proposed_res_height;
-    shape._dims[2] = pred_shape._dims[2] == -1 ? -1 : proposed_res_width;
-    break;
-  }
-  case tflite::Padding_VALID:
-  {
-    auto padded_h = static_cast<uint32_t>(pred_shape._dims[1] - (node->window()->vertical() - 1));
-    auto padded_w = static_cast<uint32_t>(pred_shape._dims[2] - (node->window()->horizontal() - 1));
-
-    int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
-    int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
-
-    shape._dims[1] = pred_shape._dims[1] == -1 ? -1 : proposed_height;
-    shape._dims[2] = pred_shape._dims[2] == -1 ? -1 : proposed_width;
-    break;
-  }
-  default:
-    assert(false && "unknown padding type");
-  }
-  return shape;
-}
-
-ShapeDescription getOpResultShape(loco::AvgPool2D *node, ShapeContext &gd)
-{
-  const ShapeDescription &ifm_shape = gd._node_to_shape[node->ifm()];
-  assert(ifm_shape._rank_known);
-
-  ShapeDescription shape;
-  shape._rank_known = true;
-  shape._dims.resize(4);
-  shape._dims[0] = ifm_shape._dims[0]; // copy batch
-  shape._dims[3] = ifm_shape._dims[3]; // copy channel
-
-  tflite::Padding padding = getOpPadding(node->pad());
-  switch (padding)
-  {
-  case tflite::Padding_SAME:
-  {
-    auto height = static_cast<uint32_t>(ifm_shape._dims[1]);
-    auto width = static_cast<uint32_t>(ifm_shape._dims[2]);
-
-    int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
-    int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
-
-    shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_res_height;
-    shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_res_width;
-    break;
-  }
-  case tflite::Padding_VALID:
-  {
-    auto padded_h = static_cast<uint32_t>(ifm_shape._dims[1] - (node->window()->vertical() - 1));
-    auto padded_w = static_cast<uint32_t>(ifm_shape._dims[2] - (node->window()->horizontal() - 1));
-
-    int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
-    int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
-
-    shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_height;
-    shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_width;
-    break;
-  }
-  default:
-    assert(false && "unknown padding type");
-  }
-  return shape;
-}
-
-ShapeDescription getOpResultShape(loco::Conv2D *node, ShapeContext &gd)
-{
-  loco::Node *ifm = node->ifm();
-  const ShapeDescription &ifm_shape = gd._node_to_shape[ifm];
-  if (!ifm_shape._rank_known)
-  {
-    // return unknown shape
-    return {};
-  }
-
-  auto *ker = dynamic_cast<loco::FilterEncode *>(node->ker());
-  assert(ker);
-  const ShapeDescription &ker_shape = gd._node_to_shape[ker];
-  if (!ker_shape._rank_known)
-  {
-    // return unknown shape
-    return {};
-  }
-
-  ShapeDescription shape;
-  shape._rank_known = true;
-  shape._dims.resize(4);
-  shape._dims[0] = ifm_shape._dims[0];
-  shape._dims[3] = ker_shape._dims[0];
-  tflite::Padding padding = getOpPadding(node->pad());
-  switch (padding)
-  {
-  case tflite::Padding_SAME:
-  {
-    auto height = static_cast<uint32_t>(ifm_shape._dims[1]);
-    auto width = static_cast<uint32_t>(ifm_shape._dims[2]);
-
-    int32_t proposed_res_height = ceil_div(height, node->stride()->vertical());
-    int32_t proposed_res_width = ceil_div(width, node->stride()->horizontal());
-
-    shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_res_height;
-    shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_res_width;
-    break;
-  }
-  case tflite::Padding_VALID:
-  {
-    auto padded_h = static_cast<uint32_t>(ifm_shape._dims[1] - (ker_shape._dims[1] - 1));
-    auto padded_w = static_cast<uint32_t>(ifm_shape._dims[2] - (ker_shape._dims[2] - 1));
-
-    int32_t proposed_height = ceil_div(padded_h, node->stride()->vertical());
-    int32_t proposed_width = ceil_div(padded_w, node->stride()->horizontal());
-
-    shape._dims[1] = ifm_shape._dims[1] == -1 ? -1 : proposed_height;
-    shape._dims[2] = ifm_shape._dims[2] == -1 ? -1 : proposed_width;
-    break;
-  }
-  default:
-    assert(false && "unknown padding type");
-  }
-  return shape;
-}
-
-ShapeDescription getOpResultShape(loco::ReLU *node, ShapeContext &gd)
-{
-  return gd._node_to_shape[node->input()];
-}
-
-ShapeDescription getOpResultShape(loco::FeatureEncode *node, ShapeContext &gd)
-{
-  const ShapeDescription &pred_shape = gd._node_to_shape[node->input()];
-  if (!pred_shape._rank_known)
-  {
-    // return unknown shape
-    return {};
-  }
-  ShapeDescription shape;
-  shape._rank_known = true;
-  loco::TensorShape tensor_shape;
-  uint32_t num_dims = pred_shape._dims.size();
-  tensor_shape.rank(num_dims);
-  for (uint32_t i = 0; i < num_dims; ++i)
-  {
-    tensor_shape.dim(i) = encodeShapeDimension(pred_shape._dims[i]);
-  }
-  loco::FeatureShape feature_shape = node->encoder()->shape(tensor_shape);
-  shape._dims.resize(4);
-  shape._dims[0] = decodeShapeDimension(feature_shape.count());
-  shape._dims[1] = decodeShapeDimension(feature_shape.height());
-  shape._dims[2] = decodeShapeDimension(feature_shape.width());
-  shape._dims[3] = decodeShapeDimension(feature_shape.depth());
-  return shape;
-}
-
-ShapeDescription getOpResultShape(loco::FeatureDecode *node, ShapeContext &gd)
-{
-  const ShapeDescription &pred_shape = gd._node_to_shape[node->input()];
-  if (!pred_shape._rank_known)
-  {
-    // return unknown shape
-    return {};
-  }
-  ShapeDescription shape;
-  shape._rank_known = true;
-  loco::FeatureShape feature_shape;
-  feature_shape.count() = encodeShapeDimension(pred_shape._dims[0]);
-  feature_shape.height() = encodeShapeDimension(pred_shape._dims[1]);
-  feature_shape.width() = encodeShapeDimension(pred_shape._dims[2]);
-  feature_shape.depth() = encodeShapeDimension(pred_shape._dims[3]);
-  loco::TensorShape tensor_shape = node->decoder()->shape(feature_shape);
-  shape._dims.resize(4);
-  for (uint32_t i = 0; i < 4; ++i)
-  {
-    shape._dims[i] = decodeShapeDimension(tensor_shape.dim(i));
-  }
-  return shape;
-}
-
-ShapeDescription getOpResultShape(loco::FilterEncode *node, ShapeContext &gd)
-{
-  const ShapeDescription &input_shape = gd._node_to_shape[node->input()];
-  if (!input_shape._rank_known)
-  {
-    // return unknown shape
-    return {};
-  }
-  ShapeDescription shape;
-  shape._rank_known = true;
-  loco::TensorShape tensor_shape;
-  uint32_t num_dims = input_shape._dims.size();
-  tensor_shape.rank(num_dims);
-  for (uint32_t i = 0; i < num_dims; ++i)
-  {
-    tensor_shape.dim(i) = encodeShapeDimension(input_shape._dims[i]);
-  }
-  loco::FilterShape filter_shape = node->encoder()->shape(tensor_shape);
-  shape._dims.resize(4);
-  shape._dims[0] = decodeShapeDimension(filter_shape.count());
-  shape._dims[1] = decodeShapeDimension(filter_shape.height());
-  shape._dims[2] = decodeShapeDimension(filter_shape.width());
-  shape._dims[3] = decodeShapeDimension(filter_shape.depth());
-  return shape;
-}
-
-ShapeDescription getOpResultShape(loco::TensorConcat *node, ShapeContext &gd)
-{
-  const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
-  if (!lhs_shape._rank_known)
-  {
-    // return unknown shape
-    return {};
-  }
-
-  const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
-  if (!rhs_shape._rank_known)
-  {
-    // return unknown shape
-    return {};
-  }
-
-  ShapeDescription ret;
-
-  assert(lhs_shape._dims.size() == rhs_shape._dims.size());
-  ret._dims.resize(lhs_shape._dims.size());
-
-  uint32_t axis = node->axis();
-
-  for (uint32_t i = 0; i < lhs_shape._dims.size(); ++i)
-  {
-    if (i == axis)
-    {
-      ret._dims[i] = lhs_shape._dims[i] + rhs_shape._dims[i];
-    }
-    else
-    {
-      assert(lhs_shape._dims[i] == rhs_shape._dims[i]);
-      ret._dims[i] = lhs_shape._dims[i];
-    }
-  }
-  ret._rank_known = true;
-
-  return ret;
-}
-
-ShapeDescription getOpResultShape(loco::BiasEncode *node, ShapeContext &gd)
-{
-  const ShapeDescription &input_shape = gd._node_to_shape[node->input()];
-
-  // Bias should be rank 1
-  assert(input_shape._dims.size() == 1);
-
-  return input_shape;
-}
-
-ShapeDescription getOpResultShape(loco::BiasAdd<loco::Domain::Tensor> *node, ShapeContext &gd)
-{
-  const ShapeDescription &value_shape = gd._node_to_shape[node->value()];
-  const ShapeDescription &bias_shape = gd._node_to_shape[node->bias()];
-
-  // For TFlite, only supports last bias add axis. Unless, broadcasting is not performed as
-  // expected.
-  assert(node->axis() == value_shape._dims.size() - 1);
-
-  // Bias should be rank 1
-  assert(bias_shape._dims.size() == 1);
-
-  // Channel count coherency for proper broadcast
-  assert(bias_shape._dims[0] == value_shape._dims[node->axis()]);
-
-  return value_shape;
-}
-
-// TODO Reduce code duplication
-ShapeDescription getOpResultShape(loco::FeatureBiasAdd *node, ShapeContext &gd)
-{
-  const ShapeDescription &value_shape = gd._node_to_shape[node->value()];
-  const ShapeDescription &bias_shape = gd._node_to_shape[node->bias()];
-
-  // Bias should be rank 1
-  assert(bias_shape._dims.size() == 1);
-
-  // Channel count coherency for proper broadcast
-  // Feature in T/F Lite uses NHWC layout
-  assert(bias_shape._dims[0] == value_shape._dims[3]);
-
-  return value_shape;
-}
-
-namespace
-{
-
-class ShapeAnnotation : public loco::NodeAnnotation
-{
-public:
-  ShapeAnnotation(const ShapeDescription &shape) : _shape{shape}
-  {
-    // DO NOTHING
-  }
-
-public:
-  const ShapeDescription &shape(void) const { return _shape; }
-
-private:
-  ShapeDescription _shape;
-};
-
-class ShapeAnnotator final : public loco::CanonicalNodeMutableVisitor<void>
-{
-public:
-  ShapeAnnotator() = default;
-
-public:
-#define NODE(NAME)                                       \
-  void visit(loco::NAME *node) final                     \
-  {                                                      \
-    auto s = getOpResultShape(node, _ctx);               \
-    node->annot(stdex::make_unique<ShapeAnnotation>(s)); \
-    _ctx._node_to_shape[node] = s;                       \
-  }
-  NODE(ConstGen)
-  NODE(Pull)
-  NODE(Push)
-  NODE(FeatureEncode)
-  NODE(FeatureDecode)
-  NODE(FilterEncode)
-  NODE(MaxPool2D)
-  NODE(AvgPool2D)
-  NODE(Conv2D)
-  NODE(ReLU)
-  NODE(TensorConcat)
-  NODE(BiasEncode)
-  NODE(TensorBiasAdd)
-  NODE(FeatureBiasAdd)
-#undef NODE
-
-private:
-  // TODO Remove this variable
-  ShapeContext _ctx;
-};
-
-} // namespace
-
-void ShapeInference::run(loco::Graph *g)
-{
-  ShapeAnnotator shape_annotator;
-
-  for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
-  {
-    if (auto canonical_node = dynamic_cast<loco::CanonicalNode *>(node))
-    {
-      canonical_node->accept(&shape_annotator);
-    }
-  }
-}
-
-ShapeDescription ShapeInference::get(loco::Node *node)
-{
-  assert(node->annot<ShapeAnnotation>() != nullptr);
-  return node->annot<ShapeAnnotation>()->shape();
-}
index 43d75cb..848549e 100644 (file)
@@ -38,20 +38,4 @@ struct TypeInference
   static tflite::TensorType get(loco::Node *node);
 };
 
-/**
- * @brief Annotate the shape of each node as a node annotation
- *
- * HOW TO USE
- *
- *   ShapeInference::run(g);
- *
- *   ShapeInference::get(g->nodes()->at(..));
- */
-struct ShapeInference
-{
-  static void run(loco::Graph *g);
-
-  static ShapeDescription get(loco::Node *node);
-};
-
 #endif // __TYPE_INFERENCE_H__