[exo-tflite] Broadcast conversion for Div, Sub (#7591)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 19 Sep 2019 01:31:42 +0000 (10:31 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 19 Sep 2019 01:31:42 +0000 (10:31 +0900)
This will enable TensorBroadcastConverter to handle Div and Sub node

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
compiler/exo-tflite/src/Conversion/TensorBroadcastConverter.cpp

index 84f74a2..b852131 100644 (file)
@@ -50,7 +50,14 @@ struct Collector final : public locoex::TFLNodeMutableVisitor<void>
     }
   }
 
-  // TODO ADD TFLDiv
+  void visit(locoex::TFLDiv *node) final
+  {
+    if (auto tbc = input_as_tbc<locoex::TFLDiv>(node))
+    {
+      NodePair pair(tbc, node);
+      candidates.insert(pair);
+    }
+  }
 
   void visit(locoex::TFLMul *node) final
   {
@@ -61,7 +68,14 @@ struct Collector final : public locoex::TFLNodeMutableVisitor<void>
     }
   }
 
-  // TODO ADD TFLSub
+  void visit(locoex::TFLSub *node) final
+  {
+    if (auto tbc = input_as_tbc<locoex::TFLSub>(node))
+    {
+      NodePair pair(tbc, node);
+      candidates.insert(pair);
+    }
+  }
 
   void visit(locoex::TFLNode *) final { return; }
 
@@ -133,13 +147,21 @@ bool TensorBroadcastConverter::run(loco::Graph *graph)
         jump_connection<locoex::TFLAdd>(tensorbroadcast, tfladd);
         changed = true;
       }
-      // TODO ADD TFLDiv
+      else if (auto tfldiv = dynamic_cast<locoex::TFLDiv *>(pair.second))
+      {
+        jump_connection<locoex::TFLDiv>(tensorbroadcast, tfldiv);
+        changed = true;
+      }
       else if (auto tflmul = dynamic_cast<locoex::TFLMul *>(pair.second))
       {
         jump_connection<locoex::TFLMul>(tensorbroadcast, tflmul);
         changed = true;
       }
-      // TODO ADD TFLSub
+      else if (auto tflsub = dynamic_cast<locoex::TFLSub *>(pair.second))
+      {
+        jump_connection<locoex::TFLSub>(tensorbroadcast, tflsub);
+        changed = true;
+      }
       else
       {
         assert(false);