[moco/tf] Check canonicalization (#4257)
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 15 Jul 2019 07:05:44 +0000 (16:05 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 15 Jul 2019 07:05:44 +0000 (16:05 +0900)
* [moco/tf] Check canonicalization

This will add a check routine if canonicalization has really done its task

Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* extract helper function

contrib/moco-tf/src/Canonicalizer.cpp

index 5938f3d..bf51ce2 100644 (file)
 #include "Canonicalization/BiasAddCanonicalizer.h"
 #include "Canonicalization/Conv2DCanonicalizer.h"
 
+#include "Dialect/TFDialect.h"
+#include "Dialect/TFNodes.h"
+
 #include <stdex/Memory.h>
 
+namespace
+{
+
+/**
+ * @brief Return true if graph has TFDialect nodes
+ */
+bool has_tf_nodes(loco::Graph *g)
+{
+  auto active_nodes = loco::active_nodes(loco::output_nodes(g));
+  for (auto node : active_nodes)
+  {
+    if (node->dialect() == moco::tf::TFDialect::get())
+    {
+      if (moco::tf::get<moco::tf::Knob::CanonicalizeBiasAdd>())
+      {
+        auto tfnode = dynamic_cast<moco::tf::TFBiasAdd *>(node);
+        if (tfnode != nullptr)
+          return true;
+      }
+      if (moco::tf::get<moco::tf::Knob::CanonicalizeConv2D>())
+      {
+        auto tfnode = dynamic_cast<moco::tf::TFConv2D *>(node);
+        if (tfnode != nullptr)
+          return true;
+      }
+    }
+  }
+  return false;
+}
+
+} // namespace
+
 namespace moco
 {
 namespace tf
@@ -41,6 +76,9 @@ void Canonicalizer::canonicalize(loco::Graph *g) const
 
   moco::tf::PhaseRunner<moco::tf::PhaseStrategy::Saturate> phase_runner{g};
   phase_runner.run(phase);
+
+  // Assert if graph has TF dialect nodes
+  assert(!has_tf_nodes(g));
 }
 
 } // namespace tf