During serialization do a walk of ops in module to find spv.module.
authorMahesh Ravishankar <ravishankarm@google.com>
Fri, 6 Dec 2019 22:26:34 +0000 (14:26 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Dec 2019 22:27:03 +0000 (14:27 -0800)
During lowering, spv.module might be within other modules (for example
gpu kernel module). Walk the module op to find spirv module to
serialize.

PiperOrigin-RevId: 284262550

mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp

index ffd62bd..655f559 100644 (file)
@@ -86,15 +86,16 @@ LogicalResult serializeModule(ModuleOp module, llvm::raw_ostream &output) {
 
   SmallVector<uint32_t, 0> binary;
 
-  auto spirvModules = module.getOps<spirv::ModuleOp>();
+  SmallVector<spirv::ModuleOp, 1> spirvModules;
+  module.walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); });
 
-  if (spirvModules.begin() == spirvModules.end())
+  if (spirvModules.empty())
     return module.emitError("found no 'spv.module' op");
 
-  if (std::next(spirvModules.begin()) != spirvModules.end())
+  if (spirvModules.size() != 1)
     return module.emitError("found more than one 'spv.module' op");
 
-  if (failed(spirv::serialize(*spirvModules.begin(), binary)))
+  if (failed(spirv::serialize(spirvModules[0], binary)))
     return failure();
 
   output.write(reinterpret_cast<char *>(binary.data()),