}
/**
+ * @brief Checks that list of includes contains all and only desired headers
+ * @param artifact_headers List of headers stored in ArtifaceModule
+ * @param expected_headers Reference set of desired headers
+ * @param message Message to print in case of check failure
+ */
+void checkHeadersSetsEqual(const list<string> &artifact_headers,
+ const set<string>& expected_headers,
+ const char* message) {
+ set<string> artifact_set(artifact_headers.begin(), artifact_headers.end());
+ ASSERT_EQ(artifact_set, expected_headers) << message;
+}
+
+/**
* @brief Check that artifact DOM has all needed includes
* @param m Root module of DOM
*/
void checkDomIncludes(const ArtifactModule& m) {
- // TODO
+ // check system includes, like '#include <vector>'
+ checkHeadersSetsEqual(m.headerSysIncludes(), {"fstream"}, "header includes diverged");
+
+ checkHeadersSetsEqual(m.sourceIncludes(), {}, "source includes diverged");
+
+ // check ordinary includes, like '#include "artifact_data.h"'
+ checkHeadersSetsEqual(m.headerIncludes(), {"arm_compute/core/Types.h",
+ "arm_compute/runtime/BlobLifetimeManager.h",
+ "arm_compute/runtime/CL/CLBufferAllocator.h",
+ "arm_compute/runtime/CL/CLFunctions.h",
+ "arm_compute/runtime/CL/CLScheduler.h",
+ "arm_compute/runtime/MemoryManagerOnDemand.h",
+ "arm_compute/runtime/PoolManager.h"},
+ "system header includes diverged");
+
+ checkHeadersSetsEqual(m.sourceSysIncludes(), {}, "system source includes diverged");
}
/**
checkArtifactClass(*cls, layers, tensors);
}
+/**
+ * @brief Creates TensorVariant with specified shape
+ * @param shape Desired shape of TV
+ * @return TensorVariant with specified shape
+ */
TensorVariant createTensorVariant(const Shape& shape) {
size_t data_size = shape.numElements();
float* data = new float[data_size];
const ArtifactModule& m = dom_gen.generate(&g);
checkDomStructure(m, {}, {});
-
- ArtifactGeneratorCppCode code_gen(std::cerr);
- m.accept(&code_gen);
- ArtifactGeneratorCppDecl decl_gen(std::cerr);
- m.accept(&decl_gen);
-
}