From b2d91cdcd94998feea8d79ba528db5fe6300e2b6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 5 Nov 2019 09:14:19 +0900 Subject: [PATCH] [exo] introducing DomainConverter (#8745) This adds DomainConverter and InputHandler, which are helpers used while converting a canonical node in Feature domain into a TFL node (Tensor domain). Signed-off-by: Hyun Sik Yoon --- compiler/exo/src/GraphBlock.h | 127 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/compiler/exo/src/GraphBlock.h b/compiler/exo/src/GraphBlock.h index 158e6d2..2bce015 100644 --- a/compiler/exo/src/GraphBlock.h +++ b/compiler/exo/src/GraphBlock.h @@ -17,7 +17,12 @@ #ifndef __GRAPH_BLOCK_H__ #define __GRAPH_BLOCK_H__ +#include "Check.h" + #include +#include + +#include namespace exo { @@ -55,6 +60,128 @@ enum class DepthwiseFilterLayout template loco::DepthwiseFilterDecode *make_dw_filter_decode(loco::Node *input_for_decode); +} // exo + +// +// DomainConverter +// + +/** + * Some canonical nodes can have input of various loco::Domain, e.g., loco::Domain::Tensor, + * loco::Domain::Feature, etc. However, TFL node accepts only loco::Domain::Tensor. + * So, When converting such canonical node to TFL node and input(s) of a canonical node are not + * loco::Domain::Tensor, additional nodes need to be inserted. + * + * The following two classes helps this insertion. + * + * For example, in case of loco::Relu conversion, + * + * Before: + * + * A (output: feature) -- loco::ReLU --- B (input:feature) + * + * After: + * + * A -- loco::FeatureDecode -- locoex::TFLRelu -- loco::FeatureEncode --- B + * + * loco::ReLU (dead node) + */ + +namespace exo +{ + +/** + * @brief Handles input(s) while converting a canonical node to TFL node(s). + * This class informs DomainConverter how to handle inputs of a specific canonical node. + */ +template class InputHandler +{ +public: + /** + * @brief Assign origin's inputs to replacer's inputs. + * (This is called when origin belongs in Tensor domain.) + */ + virtual void handover(CanonicalT *origin, TFLT *replacer) = 0; + + /** + * @brief Returns the list of inputs that needs to have FeatureDecode as its input. + * (This is called when origin belongs in Feature domain.) + */ + virtual std::vector getInputsToConvert(CanonicalT *origin) = 0; + + /// @brief Set the inputs of replacer to new_inputs + virtual void set(TFLT *replacer, std::vector new_inputs) = 0; + + /// @brief Set the inputs to nullptr + virtual void nullify(CanonicalT *origin) = 0; +}; + +/** + * @brief Class to handle domain conversion while converting a canonical node to TFL node(s) + */ +template class DomainConverter +{ +public: + template + TFLT *convert(CanonicalT *origin, InputHandler &input_handler); +}; + +/** + * @brief Performs domain conversion + * + * 1. if origin belong to loco::Domain::Tensor, and replace origin to a TFL node. + * 2. if origin belong to loco::Domain::Feature, insert loco::FeatureDecode for input(s) and + * insert loco::FeatureEncode for output. Then replace origin to a TFL node. + * + * @return new TFL node; nullptr if shape of origin cannot be known + */ +template +template +TFLT *DomainConverter::convert(CanonicalT *origin, + InputHandler &input_handler) +{ + static_assert(FeatureLayoutT == FeatureLayout::NHWC); + + if (!loco::shape_known(origin)) + { + return nullptr; + } + + auto tfl_node = origin->graph()->nodes()->template create(); + + // when the input is Tensor, just replace canonical node to TFL node. + if (loco::shape_get(origin).domain() == loco::Domain::Tensor) + { + input_handler.handover(origin, tfl_node); + + loco::replace(origin).with(tfl_node); + input_handler.nullify(origin); + + return tfl_node; + } + else if (loco::shape_get(origin).domain() == loco::Domain::Feature) + { + std::vector feature_decodes; + + for (auto input : input_handler.getInputsToConvert(origin)) + { + auto dec = make_feature_decode(input); + feature_decodes.emplace_back(dec); + } + + input_handler.set(tfl_node, feature_decodes); + + auto enc = make_feature_encode(tfl_node); + + loco::replace(origin).with(enc); + input_handler.nullify(origin); + + return tfl_node; + } + else + EXO_THROW("Not yet supported loco::Domain"); +} + } // namespace exo #endif //__GRAPH_BLOCK_H__ -- 2.7.4