[locomotiv] Introduce Eltwise Binary Node Helper (#7461)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 16 Sep 2019 07:49:37 +0000 (16:49 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 16 Sep 2019 07:49:37 +0000 (16:49 +0900)
* [locomotiv] Introduce Eltwise Binary Node Helper

This commit introduces a shared helper method for for element-wise
binary nodes, and rewrites the "execute" method for these nodes using
this helper.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Add a blank line

compiler/locomotiv/src/Node/EltwiseAdd.cpp
compiler/locomotiv/src/Node/EltwiseDiv.cpp
compiler/locomotiv/src/Node/EltwiseMul.cpp
compiler/locomotiv/src/Node/EltwiseSub.cpp
compiler/locomotiv/src/NodeExecution.cpp
compiler/locomotiv/src/NodeExecution.h

index 6af1982..e9fb8e1 100644 (file)
@@ -37,6 +37,7 @@ namespace locomotiv
 
 void NodeExecution::execute(loco::EltwiseAdd *eltwise_add)
 {
+#if 0
   auto lhs_data = annot_data(eltwise_add->lhs());
   auto rhs_data = annot_data(eltwise_add->rhs());
 
@@ -75,6 +76,16 @@ void NodeExecution::execute(loco::EltwiseAdd *eltwise_add)
   erase_annot_data(eltwise_add);
   annot_data(eltwise_add, std::move(eltwise_add_data));
   annot_domain(eltwise_add, annot_domain(eltwise_add->lhs()));
+#endif
+
+  struct Func final : public BinaryFunc
+  {
+    float apply(float lhs, float rhs) const { return lhs + rhs; }
+  };
+
+  Func f;
+
+  eltwise_binary(eltwise_add, f);
 }
 
 } // namespace locomotiv
index 599e015..d081955 100644 (file)
@@ -37,6 +37,7 @@ namespace locomotiv
 
 void NodeExecution::execute(loco::EltwiseDiv *eltwise_div)
 {
+#if 0
   auto lhs_data = annot_data(eltwise_div->lhs());
   auto rhs_data = annot_data(eltwise_div->rhs());
 
@@ -76,6 +77,16 @@ void NodeExecution::execute(loco::EltwiseDiv *eltwise_div)
   erase_annot_data(eltwise_div);
   annot_data(eltwise_div, std::move(eltwise_div_data));
   annot_domain(eltwise_div, annot_domain(eltwise_div->lhs()));
+#endif
+
+  struct Func final : public BinaryFunc
+  {
+    float apply(float lhs, float rhs) const { return lhs / rhs; }
+  };
+
+  Func f;
+
+  eltwise_binary(eltwise_div, f);
 }
 
 } // namespace locomotiv
index 516225b..98be1cd 100644 (file)
@@ -37,6 +37,7 @@ namespace locomotiv
 
 void NodeExecution::execute(loco::EltwiseMul *eltwise_mul)
 {
+#if 0
   auto lhs_data = annot_data(eltwise_mul->lhs());
   auto rhs_data = annot_data(eltwise_mul->rhs());
 
@@ -75,6 +76,16 @@ void NodeExecution::execute(loco::EltwiseMul *eltwise_mul)
   erase_annot_data(eltwise_mul);
   annot_data(eltwise_mul, std::move(eltwise_mul_data));
   annot_domain(eltwise_mul, annot_domain(eltwise_mul->lhs()));
+#endif
+
+  struct Func final : public BinaryFunc
+  {
+    float apply(float lhs, float rhs) const { return lhs * rhs; }
+  };
+
+  Func f;
+
+  eltwise_binary(eltwise_mul, f);
 }
 
 } // namespace locomotiv
index d771774..826471a 100644 (file)
@@ -37,6 +37,7 @@ namespace locomotiv
 
 void NodeExecution::execute(loco::EltwiseSub *eltwise_sub)
 {
+#if 0
   auto lhs_data = annot_data(eltwise_sub->lhs());
   auto rhs_data = annot_data(eltwise_sub->rhs());
 
@@ -75,6 +76,16 @@ void NodeExecution::execute(loco::EltwiseSub *eltwise_sub)
   erase_annot_data(eltwise_sub);
   annot_data(eltwise_sub, std::move(eltwise_sub_data));
   annot_domain(eltwise_sub, annot_domain(eltwise_sub->lhs()));
+#endif
+
+  struct Func final : public BinaryFunc
+  {
+    float apply(float lhs, float rhs) const { return lhs - rhs; }
+  };
+
+  Func f;
+
+  eltwise_binary(eltwise_sub, f);
 }
 
 } // namespace locomotiv
index fe2a0d3..d88c03f 100644 (file)
@@ -40,6 +40,16 @@ namespace locomotiv
 float UnaryFunc::apply(float) const { throw std::runtime_error{"F32 is not supported yet"}; }
 int32_t UnaryFunc::apply(int32_t) const { throw std::runtime_error{"S32 is not supported yet"}; }
 
+float BinaryFunc::apply(float, float) const
+{
+  throw std::runtime_error{"F32 is not supported yet"};
+}
+
+int32_t BinaryFunc::apply(int32_t, int32_t) const
+{
+  throw std::runtime_error{"S32 is not supported yet"};
+}
+
 // TODO Use visitor pattern of loco when available
 void NodeExecution::run(loco::Node *node)
 {
@@ -101,4 +111,48 @@ void NodeExecution::eltwise_unary(loco::Node *node, const UnaryFunc &f)
   annot_domain(output_node, output_domain);
 }
 
+void NodeExecution::eltwise_binary(loco::Node *node, const BinaryFunc &f)
+{
+  auto lhs_node = node->arg(0);
+  auto rhs_node = node->arg(1);
+  auto lhs_data = annot_data(lhs_node);
+  auto rhs_data = annot_data(rhs_node);
+
+  validate(lhs_data && rhs_data, "Input not ready");
+  validate(annot_domain(lhs_node) == annot_domain(rhs_node), "Wrong input domain");
+  validate(lhs_data->dtype() == rhs_data->dtype(), "Wrong input type");
+  validate(*lhs_data->shape() == *rhs_data->shape(), "Wrong input shape");
+
+  auto out_node = node;
+  std::unique_ptr<NodeData> out_data = nullptr;
+
+  switch (lhs_data->dtype())
+  {
+    case loco::DataType::FLOAT32:
+    {
+      auto lhs_bufptr = lhs_data->as_f32_bufptr();
+      auto rhs_bufptr = rhs_data->as_f32_bufptr();
+      auto out_bufptr = make_buffer<float, LexicalLayout>(*lhs_data->shape());
+
+      auto *shape = lhs_data->shape();
+
+      for (IndexEnumerator e{*shape}; e.valid(); e.advance())
+      {
+        const auto &index = e.current();
+        out_bufptr.at(index) = f.apply(lhs_bufptr->at(index), rhs_bufptr->at(index));
+      }
+
+      out_data = make_data(out_bufptr);
+      break;
+    }
+    default:
+      throw std::runtime_error("NYI for this DataType");
+  }
+
+  assert(out_data != nullptr);
+  erase_annot_data(out_node);
+  annot_data(out_node, std::move(out_data));
+  annot_domain(out_node, annot_domain(lhs_node));
+}
+
 } // namespace locomotiv
index d72c2b8..363188d 100644 (file)
@@ -30,6 +30,15 @@ struct UnaryFunc
   virtual int32_t apply(int32_t) const;
 };
 
+// Q. How to support mixed precision binary operators?
+struct BinaryFunc
+{
+  virtual ~BinaryFunc() = default;
+
+  virtual float apply(float, float) const;
+  virtual int32_t apply(int32_t, int32_t) const;
+};
+
 /**
  * @brief Helper class for Session, responsible to process one node calculation.
  */
@@ -66,6 +75,7 @@ private:
 #undef NODE
 
   void eltwise_unary(loco::Node *node, const UnaryFunc &f);
+  void eltwise_binary(loco::Node *node, const BinaryFunc &f);
 };
 
 } // namespace locomotiv