nnc: add DOT format support to model dumpers (#1359)
authorVitaliy Cherepanov/AI Tools Lab /SRR/Engineer/삼성전자 <v.cherepanov@samsung.com>
Mon, 10 Sep 2018 15:00:02 +0000 (18:00 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Mon, 10 Sep 2018 15:00:02 +0000 (18:00 +0300)
usage for dumpers:
    --dot - for dot format
    --dump - for simple format

Signed-off-by: Vitaliy Cherepanov <v.cherepanov@samsung.com>
contrib/nnc/examples/caffe_frontend/model_dump.cpp
contrib/nnc/examples/tflite_frontend/sanity_check.cpp

index a2a23d6..a201fa5 100644 (file)
@@ -1,34 +1,63 @@
 #include <iostream>
 
 #include "support/CommandLine.h"
+#include "support/PluginException.h"
 #include "option/Options.h"
 #include "caffe_importer.h"
+#include "core/modelIR/graph.h"
+#include "core/modelIR/ir_dot_dumper.h"
+#include "core/modelIR/ShapeInference.h"
 
 using namespace nncc::contrib;
+using namespace nncc::contrib::clopt;
+using namespace nncc::contrib::core::dumper;
 
-int main(int argc, const char** argv)
-{
-    clopt::CommandLine::getParser()->parseCommandLine(argc, argv);
-    std::string modelName = clopt::inputFile;
-
-    if ( modelName.empty() )
-    {
-        modelName = "mobilenet.caffemodel";
-    }
+enum Format {FormatDot, FormatDump};
 
-    nncc::contrib::frontend::caffe::CaffeImporter importer{modelName};
+static Option<bool> isDumpFormat(optname("--dump"),
+                                 overview("if not setted DOT format will be used"),
+                                 false,
+                                 optional(true));
 
-    bool success = importer.import();
-
-    if (success)
-    {
-        importer.dump();
-        importer.createIR();
-    }
-    else
+int main(int argc, const char **argv)
+{
+  clopt::CommandLine::getParser()->parseCommandLine(argc, argv, false);
+  std::string model = clopt::inputFile;
+
+  nncc::contrib::frontend::caffe::CaffeImporter importer{model};
+
+  if (!importer.import())
+  {
+    std::cout << "Could not load model \"" << model << "\"" << std::endl;
+    return -1;
+  }
+
+  Format format = isDumpFormat ? FormatDump : FormatDot;
+
+  switch (format)
+  {
+  case FormatDump:
+    importer.dump();
+    break;
+  case FormatDot:
+    try
     {
-        std::cout << "Could not load model \"" << modelName << "\"" << std::endl;
+      IrDotDumper dotDumper;
+      ShapeInference inf;
+      auto g = static_cast<Graph *>(importer.createIR());
+      g->accept(&inf);
+      g->accept(&dotDumper);
+
+      dotDumper.writeDot(std::cout);
+    } catch (PluginException &e) {
+      std::cout << "Error: " << e.what() << std::endl;
+      return -1;
     }
+    break;
+  default:
+    std::cout << "Error: Unsuported format" << std::endl;
+    return -1;
+  }
 
-    return 0;
+  return 0;
 }
index 581fe9d..905d20e 100644 (file)
@@ -1,32 +1,62 @@
 #include <iostream>
 
 #include "support/CommandLine.h"
+#include "support/PluginException.h"
 #include "option/Options.h"
 #include "tflite_v3_importer.h"
+#include "core/modelIR/graph.h"
+#include "core/modelIR/ir_dot_dumper.h"
+#include "core/modelIR/ShapeInference.h"
 
 using namespace nncc::contrib;
+using namespace nncc::contrib::clopt;
+using namespace nncc::contrib::core::dumper;
+
+enum Format {FormatDot, FormatDump};
+
+static Option<bool> isDumpFormat(optname("--dump"),
+                                 overview("if not setted DOT format will be used"),
+                                 false,
+                                 optional(true));
 
 int main(int argc, const char **argv)
 {
-  clopt::CommandLine::getParser()->parseCommandLine(argc, argv);
-  std::string modelName = clopt::inputFile;
+  clopt::CommandLine::getParser()->parseCommandLine(argc, argv, false);
+  std::string model = clopt::inputFile;
+
+  nncc::contrib::frontend::tflite::v3::TfliteImporter importer{model};
 
-  if (modelName.empty())
+  if (!importer.import())
   {
-    modelName = "mobilenet_v1.0.tflite";
+    std::cout << "Could not load model \"" << model << "\"" << std::endl;
+    return -1;
   }
 
-  nncc::contrib::frontend::tflite::v3::TfliteImporter importer{modelName};
+  Format format = isDumpFormat ? FormatDump : FormatDot;
 
-  bool success = importer.import();
-
-  if (success)
+  switch (format)
   {
+  case FormatDump:
     importer.dump();
-  }
-  else
-  {
-    std::cout << "Could not load model \"" << modelName << "\"" << std::endl;
+    break;
+  case FormatDot:
+    try
+    {
+      IrDotDumper dotDumper;
+      ShapeInference inf;
+      auto g = static_cast<Graph *>(importer.createIR());
+      g->accept(&inf);
+      g->accept(&dotDumper);
+
+      dotDumper.writeDot(std::cout);
+    } catch (PluginException &e) {
+      std::cout << "Error: " << e.what() << std::endl;
+      return -1;
+    }
+    break;
+  default:
+    std::cout << "Error: Unsuported format" << std::endl;
+    return -1;
   }
 
   return 0;