[mlir-reduce] Support parsing operations other than 'builtin.module' as top-level
authorrkayaith <rkayaith@gmail.com>
Mon, 19 Sep 2022 20:59:31 +0000 (16:59 -0400)
committerrkayaith <rkayaith@gmail.com>
Mon, 3 Oct 2022 20:15:53 +0000 (16:15 -0400)
This adds a `--no-implicit-module` option, which disables the insertion
of a top-level `builtin.module` during parsing. Although other ops can
now be parsed as top-level, the actual reduction passes are still
restricted to `builtin.module` as it didn't seem straightforward to
update them.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D134242

mlir/lib/Reducer/ReductionTreePass.cpp
mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
mlir/test/mlir-reduce/invalid.mlir [new file with mode: 0644]

index 152c8fb..a30b6d1 100644 (file)
@@ -217,7 +217,12 @@ void ReductionTreePass::runOnOperation() {
   Operation *topOperation = getOperation();
   while (topOperation->getParentOp() != nullptr)
     topOperation = topOperation->getParentOp();
-  ModuleOp module = cast<ModuleOp>(topOperation);
+  ModuleOp module = dyn_cast<ModuleOp>(topOperation);
+  if (!module) {
+    emitError(getOperation()->getLoc())
+        << "top-level op must be 'builtin.module'";
+    return signalPassFailure();
+  }
 
   SmallVector<Operation *, 8> workList;
   workList.push_back(getOperation());
index 4b5fb67..a3eab15 100644 (file)
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Tools/ParseUtilties.h"
 #include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/ToolOutputFile.h"
 
 using namespace mlir;
 
-// Parse and verify the input MLIR file.
-static LogicalResult loadModule(MLIRContext &context,
-                                OwningOpRef<ModuleOp> &module,
-                                StringRef inputFilename) {
-  module = parseSourceFile<ModuleOp>(inputFilename, &context);
-  if (!module)
-    return failure();
+// Parse and verify the input MLIR file. Returns null on error.
+OwningOpRef<Operation *> loadModule(MLIRContext &context,
+                                    StringRef inputFilename,
+                                    bool insertImplictModule) {
+  // Set up the input file.
+  std::string errorMessage;
+  auto file = openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return nullptr;
+  }
 
-  return success();
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+  return parseSourceFileForTool(sourceMgr, &context, insertImplictModule);
 }
 
 LogicalResult mlir::mlirReduceMain(int argc, char **argv,
@@ -55,6 +63,12 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
       "o", llvm::cl::desc("Output filename for the reduced test case"),
       llvm::cl::init("-"), llvm::cl::cat(mlirReduceCategory));
 
+  static llvm::cl::opt<bool> noImplicitModule{
+      "no-implicit-module",
+      llvm::cl::desc(
+          "Disable implicit addition of a top-level module op during parsing"),
+      llvm::cl::init(false)};
+
   llvm::cl::HideUnrelatedOptions(mlirReduceCategory);
 
   llvm::InitLLVM y(argc, argv);
@@ -76,8 +90,9 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
   if (!output)
     return failure();
 
-  OwningOpRef<ModuleOp> moduleRef;
-  if (failed(loadModule(context, moduleRef, inputFilename)))
+  OwningOpRef<Operation *> opRef =
+      loadModule(context, inputFilename, !noImplicitModule);
+  if (!opRef)
     return failure();
 
   auto errorHandler = [&](const Twine &msg) {
@@ -85,16 +100,16 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
   };
 
   // Reduction pass pipeline.
-  PassManager pm(&context);
+  PassManager pm(&context, opRef.get()->getName().getStringRef());
   if (failed(parser.addToPipeline(pm, errorHandler)))
     return failure();
 
-  OwningOpRef<ModuleOp> m = moduleRef.get().clone();
+  OwningOpRef<Operation *> op = opRef.get()->clone();
 
-  if (failed(pm.run(m.get())))
+  if (failed(pm.run(op.get())))
     return failure();
 
-  m->print(output->os());
+  op.get()->print(output->os());
   output->keep();
 
   return success();
diff --git a/mlir/test/mlir-reduce/invalid.mlir b/mlir/test/mlir-reduce/invalid.mlir
new file mode 100644 (file)
index 0000000..3045cb0
--- /dev/null
@@ -0,0 +1,8 @@
+// UNSUPPORTED: system-windows
+// RUN: not mlir-reduce -opt-reduction-pass --no-implicit-module %s |& FileCheck %s --check-prefix=CHECK-PASS
+// RUN: not mlir-reduce -reduction-tree --no-implicit-module %s |& FileCheck %s --check-prefix=CHECK-TREE
+
+// The reduction passes are currently restricted to 'builtin.module'.
+// CHECK-PASS: error: Can't add pass '{{.+}}' restricted to 'builtin.module' on a PassManager intended to run on 'func.func'
+// CHECK-TREE: error: top-level op must be 'builtin.module'
+func.func private @foo()