[nnc] Add check for include generation correctness in acl unit tests (#2722)
authorEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Fri, 11 Jan 2019 20:56:53 +0000 (23:56 +0300)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 11 Jan 2019 20:56:53 +0000 (23:56 +0300)
- Add checks for correct include DOM generation
- Fix system headers setter in ArtifactModule

Signed-off-by: Efimov Alexander <a.efimov@samsung.com>
contrib/nnc/passes/acl_soft_backend/ArtifactModel.h
contrib/nnc/unittests/acl_backend/DOMToText.cpp
contrib/nnc/unittests/acl_backend/MIRToDOM.cpp

index eae9d84..8329128 100644 (file)
@@ -700,7 +700,7 @@ public:
 
   void addHeaderInclude(const std::string& name) { _headerIncludes.push_back(name); }
   void addSourceInclude(const std::string& name) { _sourceIncludes.push_back(name); }
-  void addHeaderSysInclude(const std::string& name) { _headerIncludes.push_back(name); }
+  void addHeaderSysInclude(const std::string& name) { _headerSysIncludes.push_back(name); }
   void addSourceSysInclude(const std::string& name) { _sourceSysIncludes.push_back(name); }
 
   const std::string& name() const { return _name; }
index d071a01..cda455b 100644 (file)
@@ -478,7 +478,7 @@ TEST(acl_backend_dom_to_text, ArtifactModule) {
   ASSERT_EQ(code_out.str(), ref_data);
 
   // test header code generation
-  const char* ref_decl_data = "#include \"foo.h\"\n#include \"vector\"\n\nclass Class {\npublic:\n  Class();\n\nprivate:\n};\n";
+  const char* ref_decl_data = "#include <vector>\n\n#include \"foo.h\"\n\nclass Class {\npublic:\n  Class();\n\nprivate:\n};\n";
   m.accept(&decl_gen);
 
   ASSERT_EQ(decl_out.str(), ref_decl_data);
index 0e18453..4d7c26e 100644 (file)
@@ -86,11 +86,39 @@ void fillGraph(Graph& g, const OpConstructor& op_constr, const vector<Shape>& in
 }
 
 /**
+ * @brief Checks that list of includes contains all and only desired headers
+ * @param artifact_headers List of headers stored in ArtifaceModule
+ * @param expected_headers Reference set of desired headers
+ * @param message Message to print in case of check failure
+ */
+void checkHeadersSetsEqual(const list<string> &artifact_headers,
+                           const set<string>& expected_headers,
+                           const char* message) {
+  set<string> artifact_set(artifact_headers.begin(), artifact_headers.end());
+  ASSERT_EQ(artifact_set, expected_headers) << message;
+}
+
+/**
  * @brief Check that artifact DOM has all needed includes
  * @param m Root module of DOM
  */
 void checkDomIncludes(const ArtifactModule& m) {
-  // TODO
+  // check system includes, like '#include  <vector>'
+  checkHeadersSetsEqual(m.headerSysIncludes(), {"fstream"}, "header includes diverged");
+
+  checkHeadersSetsEqual(m.sourceIncludes(), {}, "source includes diverged");
+
+  // check ordinary includes, like '#include  "artifact_data.h"'
+  checkHeadersSetsEqual(m.headerIncludes(), {"arm_compute/core/Types.h",
+                                             "arm_compute/runtime/BlobLifetimeManager.h",
+                                             "arm_compute/runtime/CL/CLBufferAllocator.h",
+                                             "arm_compute/runtime/CL/CLFunctions.h",
+                                             "arm_compute/runtime/CL/CLScheduler.h",
+                                             "arm_compute/runtime/MemoryManagerOnDemand.h",
+                                             "arm_compute/runtime/PoolManager.h"},
+                        "system header includes diverged");
+
+  checkHeadersSetsEqual(m.sourceSysIncludes(), {}, "system source includes diverged");
 }
 
 /**
@@ -159,6 +187,11 @@ void checkDomStructure(const ArtifactModule& m,
   checkArtifactClass(*cls, layers, tensors);
 }
 
+/**
+ * @brief Creates TensorVariant with specified shape
+ * @param shape Desired shape of TV
+ * @return TensorVariant with specified shape
+ */
 TensorVariant createTensorVariant(const Shape& shape) {
   size_t data_size = shape.numElements();
   float* data = new float[data_size];
@@ -521,10 +554,4 @@ TEST(acl_backend_mir_to_dom, transpose) {
   const ArtifactModule& m = dom_gen.generate(&g);
 
   checkDomStructure(m, {}, {});
-
-  ArtifactGeneratorCppCode code_gen(std::cerr);
-  m.accept(&code_gen);
-  ArtifactGeneratorCppDecl decl_gen(std::cerr);
-  m.accept(&decl_gen);
-
 }