From 7248dd22176527fd02e80585a79843a32805a36f Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 2 Aug 2019 17:22:12 +0900 Subject: [PATCH] [exo-tflite] Use loco shape inference framework (#6144) This commit allows exo-tflite to use the shape inference framework implemented in loco instead of exo-tflite internal shape inference framework. Please note that this feature is turned off by default. Signed-off-by: Jonghyun Park --- compiler/exo-tflite/CMakeLists.txt | 1 + compiler/exo-tflite/src/ExporterUtils.cpp | 30 +++++++++++++++++++++++++ compiler/exo-tflite/src/ExporterUtils.h | 4 ++++ compiler/exo-tflite/src/ShapeInference.cpp | 35 ++++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+) diff --git a/compiler/exo-tflite/CMakeLists.txt b/compiler/exo-tflite/CMakeLists.txt index d36febd..5e23ac7 100644 --- a/compiler/exo-tflite/CMakeLists.txt +++ b/compiler/exo-tflite/CMakeLists.txt @@ -32,6 +32,7 @@ target_include_directories(exo_tflite PRIVATE src) target_link_libraries(exo_tflite PUBLIC exo_tflite_fbs) target_link_libraries(exo_tflite PUBLIC loco) target_link_libraries(exo_tflite PRIVATE stdex) +target_link_libraries(exo_tflite PRIVATE pepper_strcast) # Let's apply nncc common compile options # # NOTE This will enable strict compilation (warnings as error). diff --git a/compiler/exo-tflite/src/ExporterUtils.cpp b/compiler/exo-tflite/src/ExporterUtils.cpp index fc3234e..daa09c9 100644 --- a/compiler/exo-tflite/src/ExporterUtils.cpp +++ b/compiler/exo-tflite/src/ExporterUtils.cpp @@ -16,6 +16,36 @@ #include "ExporterUtils.h" +ShapeDescription to_shape_description(const loco::TensorShape &shape) +{ + ShapeDescription res; + + res._rank_known = true; + + res._dims.resize(shape.rank()); + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + { + // All the dimensions SHOULD be known + assert(shape.dim(axis).known()); + res._dims.at(axis) = shape.dim(axis).value(); + } + + return res; +} + +ShapeDescription to_shape_description(const loco::NodeShape &shape) +{ + switch (shape.domain()) + { + case loco::Domain::Tensor: + return to_shape_description(shape.as()); + default: + break; + } + + throw std::runtime_error{"Not implemented yet"}; +} + uint32_t SerializedModelData::registerBuiltinOpcode(tflite::BuiltinOperator builtin_code) { auto it = _operator_codes.find(OpCode{builtin_code}); diff --git a/compiler/exo-tflite/src/ExporterUtils.h b/compiler/exo-tflite/src/ExporterUtils.h index b69c4e9..e37eb33 100644 --- a/compiler/exo-tflite/src/ExporterUtils.h +++ b/compiler/exo-tflite/src/ExporterUtils.h @@ -21,6 +21,7 @@ #include "loco.h" #include "loco/IR/PermutingCodec.h" +#include "loco/IR/NodeShape.h" #include @@ -47,6 +48,9 @@ struct ShapeDescription bool _rank_known; }; +ShapeDescription to_shape_description(const loco::TensorShape &shape); +ShapeDescription to_shape_description(const loco::NodeShape &shape); + /** * @breif Record the information of T/F Lite SubGraph and its mapping to loco */ diff --git a/compiler/exo-tflite/src/ShapeInference.cpp b/compiler/exo-tflite/src/ShapeInference.cpp index 2891b55..4d37440 100644 --- a/compiler/exo-tflite/src/ShapeInference.cpp +++ b/compiler/exo-tflite/src/ShapeInference.cpp @@ -18,14 +18,36 @@ #include #include +#include +#include +#include #include #include +#include + namespace { +// This Knob is a temporary workaround for incermental migration +// +// TODO Remove this workaround! +struct Knob +{ + Knob() + { + // Off by default + auto s = std::getenv("EXOTFLITE_USE_LOCO_SHAPE_INFERENCE"); + enable_loco_shape_inferene_framework = pepper::safe_strcast(s, 0 /* DEFAULT */) != 0; + } + + bool enable_loco_shape_inferene_framework = false; +}; + +Knob knob; + template ::value, int>::type> T ceil_div(T dividend, T divisor) { @@ -471,6 +493,13 @@ private: void ShapeInference::run(loco::Graph *g) { + if (knob.enable_loco_shape_inferene_framework) + { + loco::CanonicalShapeInferenceRule rule; + loco::apply(&rule).to(g); + return; + } + ShapeContext shape_ctx; ShapeGetter shape_getter{shape_ctx}; @@ -487,6 +516,12 @@ void ShapeInference::run(loco::Graph *g) ShapeDescription ShapeInference::get(loco::Node *node) { + if (knob.enable_loco_shape_inferene_framework) + { + assert(loco::shape_known(node)); + return to_shape_description(loco::shape_get(node)); + } + assert(node->annot() != nullptr); return node->annot()->shape(); } -- 2.7.4