[exo] introducing DomainConverter (#8745)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Tue, 5 Nov 2019 00:14:19 +0000 (09:14 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Tue, 5 Nov 2019 00:14:19 +0000 (09:14 +0900)
This adds DomainConverter and InputHandler, which are helpers used while converting a canonical node in Feature domain into a TFL node (Tensor domain).

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo/src/GraphBlock.h

index 158e6d2..2bce015 100644 (file)
 #ifndef __GRAPH_BLOCK_H__
 #define __GRAPH_BLOCK_H__
 
+#include "Check.h"
+
 #include <loco.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <functional>
 
 namespace exo
 {
@@ -55,6 +60,128 @@ enum class DepthwiseFilterLayout
 template <DepthwiseFilterLayout T>
 loco::DepthwiseFilterDecode *make_dw_filter_decode(loco::Node *input_for_decode);
 
+} // exo
+
+//
+// DomainConverter
+//
+
+/**
+ * Some canonical nodes can have input of various loco::Domain, e.g., loco::Domain::Tensor,
+ * loco::Domain::Feature, etc. However, TFL node accepts only loco::Domain::Tensor.
+ * So, When converting such canonical node to TFL node and input(s) of a canonical node are not
+ * loco::Domain::Tensor, additional nodes need to be inserted.
+ *
+ * The following two classes helps this insertion.
+ *
+ * For example, in case of loco::Relu conversion,
+ *
+ * Before:
+ *
+ *    A (output: feature) -- loco::ReLU --- B (input:feature)
+ *
+ * After:
+ *
+ *    A -- loco::FeatureDecode -- locoex::TFLRelu -- loco::FeatureEncode --- B
+ *
+ *                  loco::ReLU (dead node)
+ */
+
+namespace exo
+{
+
+/**
+ * @brief Handles input(s) while converting a canonical node to TFL node(s).
+ *        This class informs DomainConverter how to handle inputs of a specific canonical node.
+ */
+template <class CanonicalT, class TFLT> class InputHandler
+{
+public:
+  /**
+   * @brief Assign origin's inputs to replacer's inputs.
+   *        (This is called when origin belongs in Tensor domain.)
+   */
+  virtual void handover(CanonicalT *origin, TFLT *replacer) = 0;
+
+  /**
+   * @brief Returns the list of inputs that needs to have FeatureDecode as its input.
+   *        (This is called when origin belongs in Feature domain.)
+   */
+  virtual std::vector<loco::Node *> getInputsToConvert(CanonicalT *origin) = 0;
+
+  /// @brief Set the inputs of replacer to new_inputs
+  virtual void set(TFLT *replacer, std::vector<loco::Node *> new_inputs) = 0;
+
+  /// @brief Set the inputs to nullptr
+  virtual void nullify(CanonicalT *origin) = 0;
+};
+
+/**
+ * @brief Class to handle domain conversion while converting a canonical node to TFL node(s)
+ */
+template <class CanonicalT, class TFLT> class DomainConverter
+{
+public:
+  template <FeatureLayout FeatureLayoutT>
+  TFLT *convert(CanonicalT *origin, InputHandler<CanonicalT, TFLT> &input_handler);
+};
+
+/**
+ * @brief Performs domain conversion
+ *
+ * 1. if origin belong to loco::Domain::Tensor, and replace origin to a TFL node.
+ * 2. if origin belong to loco::Domain::Feature, insert loco::FeatureDecode for input(s) and
+ *    insert loco::FeatureEncode for output. Then replace origin to a TFL node.
+ *
+ * @return new TFL node; nullptr if shape of origin cannot be known
+ */
+template <class CanonicalT, class TFLT>
+template <FeatureLayout FeatureLayoutT>
+TFLT *DomainConverter<CanonicalT, TFLT>::convert(CanonicalT *origin,
+                                                 InputHandler<CanonicalT, TFLT> &input_handler)
+{
+  static_assert(FeatureLayoutT == FeatureLayout::NHWC);
+
+  if (!loco::shape_known(origin))
+  {
+    return nullptr;
+  }
+
+  auto tfl_node = origin->graph()->nodes()->template create<TFLT>();
+
+  // when the input is Tensor, just replace canonical node to TFL node.
+  if (loco::shape_get(origin).domain() == loco::Domain::Tensor)
+  {
+    input_handler.handover(origin, tfl_node);
+
+    loco::replace(origin).with(tfl_node);
+    input_handler.nullify(origin);
+
+    return tfl_node;
+  }
+  else if (loco::shape_get(origin).domain() == loco::Domain::Feature)
+  {
+    std::vector<loco::Node *> feature_decodes;
+
+    for (auto input : input_handler.getInputsToConvert(origin))
+    {
+      auto dec = make_feature_decode<FeatureLayoutT>(input);
+      feature_decodes.emplace_back(dec);
+    }
+
+    input_handler.set(tfl_node, feature_decodes);
+
+    auto enc = make_feature_encode<FeatureLayoutT>(tfl_node);
+
+    loco::replace(origin).with(enc);
+    input_handler.nullify(origin);
+
+    return tfl_node;
+  }
+  else
+    EXO_THROW("Not yet supported loco::Domain");
+}
+
 } // namespace exo
 
 #endif //__GRAPH_BLOCK_H__