From 1bc506e2a912887be0a34d9936e8978fad3c6b7e Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 17 Jul 2019 12:48:46 +0900 Subject: [PATCH] [exo/tflite] Visitor-based Type Inference (#4296) * [exo/tflite] Visitor-based Type Inference This commit revises the implemenation of TypeInference module using loco visitor interface. Signed-off-by: Jonghyun Park * Fix subobject-linkage warning * Explain why there is type context update --- contrib/exo-tflite/src/TypeInference.cpp | 120 +++++++++++++++---------------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/contrib/exo-tflite/src/TypeInference.cpp b/contrib/exo-tflite/src/TypeInference.cpp index 48ac2eb..4e1318e 100644 --- a/contrib/exo-tflite/src/TypeInference.cpp +++ b/contrib/exo-tflite/src/TypeInference.cpp @@ -63,59 +63,84 @@ struct TypeContext std::unordered_map _node_to_type; }; -} // namespace +class TypeGetter final : public loco::CanonicalNodeMutableVisitor +{ +public: + TypeGetter(TypeContext &ctx) : gd{ctx} + { + // DO NOTHING + } + +public: +#define NODE(NAME) tflite::TensorType visit(loco::NAME *node) final; + NODE(ConstGen) + NODE(Pull) + NODE(Push) + NODE(FeatureEncode) + NODE(FeatureDecode) + NODE(FilterEncode) + NODE(MaxPool2D) + NODE(AvgPool2D) + NODE(Conv2D) + NODE(ReLU) + NODE(TensorConcat) + NODE(BiasEncode) + NODE(TensorBiasAdd) + NODE(FeatureBiasAdd) +#undef NODE + // TODO Add Here from the next Op + +private: + // TODO Find a better name + template tflite::TensorType get(loco::BiasAdd *); + +private: + // TODO Rename field variable + TypeContext &gd; +}; -tflite::TensorType getOpResultType(loco::ConstGen *node, TypeContext &) +tflite::TensorType TypeGetter::visit(loco::ConstGen *node) { return translateLocoTypeToTFLite(node->dtype()); } -tflite::TensorType getOpResultType(loco::Pull *node, TypeContext &) +tflite::TensorType TypeGetter::visit(loco::Pull *node) { return translateLocoTypeToTFLite(node->dtype()); } -tflite::TensorType getOpResultType(loco::Push *node, TypeContext &) -{ - return TypeInference::get(node->from()); -} +tflite::TensorType TypeGetter::visit(loco::Push *node) { return TypeInference::get(node->from()); } -tflite::TensorType getOpResultType(loco::ReLU *node, TypeContext &gd) -{ - return gd._node_to_type[node->input()]; -} +tflite::TensorType TypeGetter::visit(loco::ReLU *node) { return gd._node_to_type[node->input()]; } -tflite::TensorType getOpResultType(loco::MaxPool2D *node, TypeContext &gd) +tflite::TensorType TypeGetter::visit(loco::MaxPool2D *node) { return gd._node_to_type[node->ifm()]; } -tflite::TensorType getOpResultType(loco::AvgPool2D *node, TypeContext &gd) +tflite::TensorType TypeGetter::visit(loco::AvgPool2D *node) { return gd._node_to_type[node->ifm()]; } -tflite::TensorType getOpResultType(loco::Conv2D *node, TypeContext &gd) -{ - return gd._node_to_type[node->ifm()]; -} +tflite::TensorType TypeGetter::visit(loco::Conv2D *node) { return gd._node_to_type[node->ifm()]; } -tflite::TensorType getOpResultType(loco::FeatureEncode *node, TypeContext &gd) +tflite::TensorType TypeGetter::visit(loco::FeatureEncode *node) { return gd._node_to_type[node->input()]; } -tflite::TensorType getOpResultType(loco::FeatureDecode *node, TypeContext &gd) +tflite::TensorType TypeGetter::visit(loco::FeatureDecode *node) { return gd._node_to_type[node->input()]; } -tflite::TensorType getOpResultType(loco::FilterEncode *node, TypeContext &gd) +tflite::TensorType TypeGetter::visit(loco::FilterEncode *node) { return gd._node_to_type[node->input()]; } -tflite::TensorType getOpResultType(loco::TensorConcat *node, TypeContext &gd) +tflite::TensorType TypeGetter::visit(loco::TensorConcat *node) { tflite::TensorType lhs_type = gd._node_to_type[node->lhs()]; tflite::TensorType rhs_type = gd._node_to_type[node->rhs()]; @@ -126,13 +151,15 @@ tflite::TensorType getOpResultType(loco::TensorConcat *node, TypeContext &gd) return lhs_type; } -tflite::TensorType getOpResultType(loco::BiasEncode *node, TypeContext &gd) +tflite::TensorType TypeGetter::visit(loco::BiasEncode *node) { return gd._node_to_type[node->input()]; } -template -tflite::TensorType getOpResultType(loco::BiasAdd *node, TypeContext &gd) +tflite::TensorType TypeGetter::visit(loco::TensorBiasAdd *node) { return get(node); } +tflite::TensorType TypeGetter::visit(loco::FeatureBiasAdd *node) { return get(node); } + +template tflite::TensorType TypeGetter::get(loco::BiasAdd *node) { tflite::TensorType value_type = gd._node_to_type[node->value()]; tflite::TensorType bias_type = gd._node_to_type[node->bias()]; @@ -143,6 +170,8 @@ tflite::TensorType getOpResultType(loco::BiasAdd *node, TypeContext &gd) return value_type; } +} // namespace + namespace { @@ -161,51 +190,22 @@ private: tflite::TensorType _type; }; -class TypeAnnotator final : public loco::CanonicalNodeMutableVisitor -{ -public: - TypeAnnotator() = default; - -public: -#define NODE(NAME) \ - void visit(loco::NAME *node) final \ - { \ - auto t = getOpResultType(node, _ctx); \ - node->annot(stdex::make_unique(t)); \ - _ctx._node_to_type[node] = t; \ - } - NODE(ConstGen) - NODE(Pull) - NODE(Push) - NODE(FeatureEncode) - NODE(FeatureDecode) - NODE(FilterEncode) - NODE(MaxPool2D) - NODE(AvgPool2D) - NODE(Conv2D) - NODE(ReLU) - NODE(TensorConcat) - NODE(BiasEncode) - NODE(TensorBiasAdd) - NODE(FeatureBiasAdd) -#undef NODE - -private: - // TODO Remove this variable - TypeContext _ctx; -}; - } // namespace void TypeInference::run(loco::Graph *g) { - TypeAnnotator type_annotator; + TypeContext ctx; + TypeGetter type_getter{ctx}; for (auto node : loco::postorder_traversal(loco::output_nodes(g))) { if (auto canonical_node = dynamic_cast(node)) { - canonical_node->accept(&type_annotator); + auto tflite_type = canonical_node->accept(&type_getter); + node->annot(stdex::make_unique(tflite_type)); + + // Update type context to allow type getter to retrieve the type of node's arguments + ctx._node_to_type[node] = tflite_type; } } } -- 2.7.4