Use TranslateFromMLIRRegistration for SPIRV roundtrip (NFC)
authorMehdi Amini <joker.eph@gmail.com>
Sun, 23 Aug 2020 00:40:16 +0000 (00:40 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 23 Aug 2020 00:40:50 +0000 (00:40 +0000)
This is aligning it with the other "translation" which operates on a MLIR input.

mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp

index 42b458d..d0aa277 100644 (file)
@@ -115,23 +115,17 @@ void registerToSPIRVTranslation() {
 // Round-trip registration
 //===----------------------------------------------------------------------===//
 
-static LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
-                                     bool emitDebugInfo, raw_ostream &output,
-                                     MLIRContext *context) {
-  // Parse an MLIR module from the source manager.
-  auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context));
-  if (!srcModule)
-    return failure();
-
+static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
+                                     raw_ostream &output) {
   SmallVector<uint32_t, 0> binary;
-
-  auto spirvModules = srcModule->getOps<spirv::ModuleOp>();
+  MLIRContext *context = srcModule.getContext();
+  auto spirvModules = srcModule.getOps<spirv::ModuleOp>();
 
   if (spirvModules.begin() == spirvModules.end())
-    return srcModule->emitError("found no 'spv.module' op");
+    return srcModule.emitError("found no 'spv.module' op");
 
   if (std::next(spirvModules.begin()) != spirvModules.end())
-    return srcModule->emitError("found more than one 'spv.module' op");
+    return srcModule.emitError("found more than one 'spv.module' op");
 
   if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo)))
     return failure();
@@ -152,21 +146,16 @@ static LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr,
 
 namespace mlir {
 void registerTestRoundtripSPIRV() {
-  TranslateRegistration roundtrip(
-      "test-spirv-roundtrip", [](llvm::SourceMgr &sourceMgr,
-                                 raw_ostream &output, MLIRContext *context) {
-        return roundTripModule(sourceMgr, /*emitDebugInfo=*/false, output,
-                               context);
+  TranslateFromMLIRRegistration roundtrip(
+      "test-spirv-roundtrip", [](ModuleOp module, raw_ostream &output) {
+        return roundTripModule(module, /*emitDebugInfo=*/false, output);
       });
 }
 
 void registerTestRoundtripDebugSPIRV() {
-  TranslateRegistration roundtrip(
-      "test-spirv-roundtrip-debug",
-      [](llvm::SourceMgr &sourceMgr, raw_ostream &output,
-         MLIRContext *context) {
-        return roundTripModule(sourceMgr, /*emitDebugInfo=*/true, output,
-                               context);
+  TranslateFromMLIRRegistration roundtrip(
+      "test-spirv-roundtrip-debug", [](ModuleOp module, raw_ostream &output) {
+        return roundTripModule(module, /*emitDebugInfo=*/true, output);
       });
 }
 } // namespace mlir