[exo-tflite] Use loco shape inference framework (#6144)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Fri, 2 Aug 2019 08:22:12 +0000 (17:22 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 2 Aug 2019 08:22:12 +0000 (17:22 +0900)
This commit allows exo-tflite to use the shape inference framework
implemented in loco instead of exo-tflite internal shape inference
framework.

Please note that this feature is turned off by default.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/exo-tflite/CMakeLists.txt
compiler/exo-tflite/src/ExporterUtils.cpp
compiler/exo-tflite/src/ExporterUtils.h
compiler/exo-tflite/src/ShapeInference.cpp

index d36febd..5e23ac7 100644 (file)
@@ -32,6 +32,7 @@ target_include_directories(exo_tflite PRIVATE src)
 target_link_libraries(exo_tflite PUBLIC exo_tflite_fbs)
 target_link_libraries(exo_tflite PUBLIC loco)
 target_link_libraries(exo_tflite PRIVATE stdex)
+target_link_libraries(exo_tflite PRIVATE pepper_strcast)
 # Let's apply nncc common compile options
 #
 # NOTE This will enable strict compilation (warnings as error).
index fc3234e..daa09c9 100644 (file)
 
 #include "ExporterUtils.h"
 
+ShapeDescription to_shape_description(const loco::TensorShape &shape)
+{
+  ShapeDescription res;
+
+  res._rank_known = true;
+
+  res._dims.resize(shape.rank());
+  for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+  {
+    // All the dimensions SHOULD be known
+    assert(shape.dim(axis).known());
+    res._dims.at(axis) = shape.dim(axis).value();
+  }
+
+  return res;
+}
+
+ShapeDescription to_shape_description(const loco::NodeShape &shape)
+{
+  switch (shape.domain())
+  {
+    case loco::Domain::Tensor:
+      return to_shape_description(shape.as<loco::TensorShape>());
+    default:
+      break;
+  }
+
+  throw std::runtime_error{"Not implemented yet"};
+}
+
 uint32_t SerializedModelData::registerBuiltinOpcode(tflite::BuiltinOperator builtin_code)
 {
   auto it = _operator_codes.find(OpCode{builtin_code});
index b69c4e9..e37eb33 100644 (file)
@@ -21,6 +21,7 @@
 #include "loco.h"
 
 #include "loco/IR/PermutingCodec.h"
+#include "loco/IR/NodeShape.h"
 
 #include <unordered_map>
 
@@ -47,6 +48,9 @@ struct ShapeDescription
   bool _rank_known;
 };
 
+ShapeDescription to_shape_description(const loco::TensorShape &shape);
+ShapeDescription to_shape_description(const loco::NodeShape &shape);
+
 /**
  * @breif Record the information of T/F Lite SubGraph and its mapping to loco
  */
index 2891b55..4d37440 100644 (file)
 
 #include <loco/IR/CanonicalNode.h>
 #include <loco/IR/CanonicalNodeVisitor.h>
+#include <loco/Service/ShapeInference.h>
+#include <loco/Service/CanonicalShapeInferenceRule.h>
 
+#include <pepper/strcast.h>
 #include <stdex/Memory.h>
 
 #include <type_traits>
 
+#include <cstdlib>
+
 namespace
 {
 
+// This Knob is a temporary workaround for incermental migration
+//
+// TODO Remove this workaround!
+struct Knob
+{
+  Knob()
+  {
+    // Off by default
+    auto s = std::getenv("EXOTFLITE_USE_LOCO_SHAPE_INFERENCE");
+    enable_loco_shape_inferene_framework = pepper::safe_strcast<int>(s, 0 /* DEFAULT */) != 0;
+  }
+
+  bool enable_loco_shape_inferene_framework = false;
+};
+
+Knob knob;
+
 template <typename T, typename If = typename std::enable_if<std::is_integral<T>::value, int>::type>
 T ceil_div(T dividend, T divisor)
 {
@@ -471,6 +493,13 @@ private:
 
 void ShapeInference::run(loco::Graph *g)
 {
+  if (knob.enable_loco_shape_inferene_framework)
+  {
+    loco::CanonicalShapeInferenceRule rule;
+    loco::apply(&rule).to(g);
+    return;
+  }
+
   ShapeContext shape_ctx;
   ShapeGetter shape_getter{shape_ctx};
 
@@ -487,6 +516,12 @@ void ShapeInference::run(loco::Graph *g)
 
 ShapeDescription ShapeInference::get(loco::Node *node)
 {
+  if (knob.enable_loco_shape_inferene_framework)
+  {
+    assert(loco::shape_known(node));
+    return to_shape_description(loco::shape_get(node));
+  }
+
   assert(node->annot<ShapeAnnotation>() != nullptr);
   return node->annot<ShapeAnnotation>()->shape();
 }