[mlir][NFC] Update the Builtin dialect to use "Both" accessors
authorRiver Riddle <riddleriver@gmail.com>
Tue, 8 Mar 2022 03:13:02 +0000 (19:13 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 8 Mar 2022 20:25:32 +0000 (12:25 -0800)
Differential Revision: https://reviews.llvm.org/D121189

14 files changed:
flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
mlir/include/mlir/IR/BuiltinDialect.td
mlir/include/mlir/IR/BuiltinOps.td
mlir/include/mlir/IR/FunctionInterfaces.h
mlir/include/mlir/IR/OpBase.td
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/test/lib/IR/TestPrintInvalid.cpp
mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp

index 4c2ff9a..a32dbee 100644 (file)
@@ -67,12 +67,9 @@ public:
   matchAndRewrite(mlir::FuncOp op,
                   mlir::PatternRewriter &rewriter) const override {
     rewriter.startRootUpdate(op);
-    auto result = fir::NameUniquer::deconstruct(op.sym_name());
-    if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
-      auto newName = mangleExternalName(result);
-      op.sym_nameAttr(rewriter.getStringAttr(newName));
-      SymbolTable::setSymbolName(op, newName);
-    }
+    auto result = fir::NameUniquer::deconstruct(op.getSymName());
+    if (fir::NameUniquer::isExternalFacingUniquedName(result))
+      op.setSymNameAttr(rewriter.getStringAttr(mangleExternalName(result)));
     rewriter.finalizeRootUpdate(op);
     return success();
   }
@@ -165,7 +162,7 @@ void ExternalNameConversionPass::runOnOperation() {
   });
 
   target.addDynamicallyLegalOp<mlir::FuncOp>([](mlir::FuncOp op) {
-    return !fir::NameUniquer::needExternalNameMangling(op.sym_name());
+    return !fir::NameUniquer::needExternalNameMangling(op.getSymName());
   });
 
   target.addDynamicallyLegalOp<fir::GlobalOp>([](fir::GlobalOp op) {
index a3d0f0c..838a551 100644 (file)
@@ -34,6 +34,7 @@ def Builtin_Dialect : Dialect {
 
   public:
   }];
+  let emitAccessorPrefix = kEmitAccessorPrefix_Both;
 }
 
 #endif // BUILTIN_BASE
index c447c79..866e594 100644 (file)
@@ -76,7 +76,7 @@ def FuncOp : Builtin_Op<"func", [
   }];
 
   let arguments = (ins SymbolNameAttr:$sym_name,
-                       TypeAttr:$type,
+                       TypeAttrOf<FunctionType>:$type,
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 
@@ -110,12 +110,6 @@ def FuncOp : Builtin_Op<"func", [
     /// compatible.
     void cloneInto(FuncOp dest, BlockAndValueMapping &mapper);
 
-    /// Returns the type of this function.
-    /// FIXME: We should drive this via the ODS `type` param.
-    FunctionType getType() { 
-      return getTypeAttr().getValue().cast<FunctionType>();
-    }
-
     //===------------------------------------------------------------------===//
     // CallableOpInterface
     //===------------------------------------------------------------------===//
@@ -144,7 +138,7 @@ def FuncOp : Builtin_Op<"func", [
     LogicalResult verifyType() {
       auto type = getTypeAttr().getValue();
       if (!type.isa<FunctionType>())
-        return emitOpError("requires '" + getTypeAttrName() +
+        return emitOpError("requires '" + FunctionOpInterface::getTypeAttrName() +
                            "' attribute of function type");
       return success();
     }
@@ -188,16 +182,16 @@ def ModuleOp : Builtin_Op<"module", [
 
   let arguments = (ins OptionalAttr<SymbolNameAttr>:$sym_name,
                        OptionalAttr<StrAttr>:$sym_visibility);
-  let regions = (region SizedRegion<1>:$body);
+  let regions = (region SizedRegion<1>:$bodyRegion);
 
-  let assemblyFormat = "($sym_name^)? attr-dict-with-keyword $body";
+  let assemblyFormat = "($sym_name^)? attr-dict-with-keyword $bodyRegion";
   let builders = [OpBuilder<(ins CArg<"Optional<StringRef>", "{}">:$name)>];
   let extraClassDeclaration = [{
     /// Construct a module from the given location with an optional name.
     static ModuleOp create(Location loc, Optional<StringRef> name = llvm::None);
 
     /// Return the name of this module if present.
-    Optional<StringRef> getName() { return sym_name(); }
+    Optional<StringRef> getName() { return getSymName(); }
 
     //===------------------------------------------------------------------===//
     // SymbolOpInterface Methods
index f4ce672..e916457 100644 (file)
@@ -208,7 +208,7 @@ template <typename ConcreteOp>
 LogicalResult verifyTrait(ConcreteOp op) {
   if (!op.getTypeAttr())
     return op.emitOpError("requires a type attribute '")
-           << ConcreteOp::getTypeAttrName() << '\'';
+           << function_interface_impl::getTypeAttrName() << '\'';
 
   if (failed(op.verifyType()))
     return failure();
index 8a72766..3310569 100644 (file)
@@ -1268,6 +1268,11 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> {
   let constBuilderCall = "::mlir::TypeAttr::get($0)";
 }
 
+class TypeAttrOf<Type ty>
+   : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.description> {
+  let constBuilderCall = "::mlir::TypeAttr::get($0)";
+}
+
 // The mere presence of unit attributes has a meaning.  Therefore, unit
 // attributes are always treated as optional and accessors to them return
 // "true" if the attribute is present and "false" otherwise.
index 9381ea1..a2e8c2c 100644 (file)
@@ -492,18 +492,18 @@ struct UnrealizedConversionCastOpLowering
   matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Type> convertedTypes;
-    if (succeeded(typeConverter->convertTypes(op.outputs().getTypes(),
+    if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
                                               convertedTypes)) &&
-        convertedTypes == adaptor.inputs().getTypes()) {
-      rewriter.replaceOp(op, adaptor.inputs());
+        convertedTypes == adaptor.getInputs().getTypes()) {
+      rewriter.replaceOp(op, adaptor.getInputs());
       return success();
     }
 
     convertedTypes.clear();
-    if (succeeded(typeConverter->convertTypes(adaptor.inputs().getTypes(),
+    if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
                                               convertedTypes)) &&
-        convertedTypes == op.outputs().getType()) {
-      rewriter.replaceOp(op, adaptor.inputs());
+        convertedTypes == op.getOutputs().getType()) {
+      rewriter.replaceOp(op, adaptor.getInputs());
       return success();
     }
     return failure();
index 3e75908..5196817 100644 (file)
@@ -37,15 +37,15 @@ struct UnrealizedConversionCastPassthrough
     auto users = op->getUsers();
     if (!llvm::all_of(users, [&](Operation *user) {
           if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
-            return other.getResultTypes() == op.inputs().getTypes() &&
-                   other.inputs() == op.outputs();
+            return other.getResultTypes() == op.getInputs().getTypes() &&
+                   other.getInputs() == op.getOutputs();
           return false;
         })) {
       return rewriter.notifyMatchFailure(op, "live unrealized conversion cast");
     }
 
     for (Operation *user : users)
-      rewriter.replaceOp(user, op.inputs());
+      rewriter.replaceOp(user, op.getInputs());
 
     rewriter.eraseOp(op);
     return success();
index 62ddf2f..f6629c4 100644 (file)
@@ -463,8 +463,7 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
 
   ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();
 
-  ArrayRef<Type> computeFuncInputTypes =
-      computeFunc.func.type().cast<FunctionType>().getInputs();
+  ArrayRef<Type> computeFuncInputTypes = computeFunc.func.getType().getInputs();
 
   // Compared to the parallel compute function async dispatch function takes
   // additional !async.group argument. Also instead of a single `blockIndex` it
@@ -541,7 +540,7 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
       operands[1] = midIndex;
       operands[2] = end;
 
-      executeBuilder.create<func::CallOp>(executeLoc, func.sym_name(),
+      executeBuilder.create<func::CallOp>(executeLoc, func.getSymName(),
                                           func.getCallableResults(), operands);
       executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
     };
@@ -562,7 +561,7 @@ static FuncOp createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
   SmallVector<Value> computeFuncOperands = {blockStart};
   computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());
 
-  b.create<func::CallOp>(computeFunc.func.sym_name(),
+  b.create<func::CallOp>(computeFunc.func.getSymName(),
                          computeFunc.func.getCallableResults(),
                          computeFuncOperands);
   b.create<func::ReturnOp>(ValueRange());
@@ -609,7 +608,7 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
     SmallVector<Value> operands = {c0, blockSize};
     appendBlockComputeOperands(operands);
 
-    b.create<func::CallOp>(parallelComputeFunction.func.sym_name(),
+    b.create<func::CallOp>(parallelComputeFunction.func.getSymName(),
                            parallelComputeFunction.func.getCallableResults(),
                            operands);
     b.create<scf::YieldOp>();
@@ -628,7 +627,7 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
     SmallVector<Value> operands = {group, c0, blockCount, blockSize};
     appendBlockComputeOperands(operands);
 
-    b.create<func::CallOp>(asyncDispatchFunction.sym_name(),
+    b.create<func::CallOp>(asyncDispatchFunction.getSymName(),
                            asyncDispatchFunction.getCallableResults(),
                            operands);
 
@@ -687,7 +686,7 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
     // Call parallel compute function inside the async.execute region.
     auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                   Location executeLoc, ValueRange executeArgs) {
-      executeBuilder.create<func::CallOp>(executeLoc, compute.sym_name(),
+      executeBuilder.create<func::CallOp>(executeLoc, compute.getSymName(),
                                           compute.getCallableResults(),
                                           computeFuncOperands(iv));
       executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
@@ -704,7 +703,7 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
   b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder);
 
   // Call parallel compute function for the first block in the caller thread.
-  b.create<func::CallOp>(compute.sym_name(), compute.getCallableResults(),
+  b.create<func::CallOp>(compute.getSymName(), compute.getCallableResults(),
                          computeFuncOperands(c0));
 
   // Wait for the completion of all async compute operations.
index 3f0d089..2a88a68 100644 (file)
@@ -180,7 +180,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
   // `async.await` op lowering will create resume blocks for async
   // continuations, and will conditionally branch to cleanup or suspend blocks.
 
-  for (Block &block : func.body().getBlocks()) {
+  for (Block &block : func.getBody().getBlocks()) {
     if (&block == entryBlock || &block == cleanupBlock ||
         &block == suspendBlock)
       continue;
@@ -677,7 +677,7 @@ funcsToCoroutines(ModuleOp module,
     // this dict between the passes is ugly.
     if (isAllowedToBlock(func) ||
         outlinedFunctions.find(func) == outlinedFunctions.end()) {
-      for (Operation &op : func.body().getOps()) {
+      for (Operation &op : func.getBody().getOps()) {
         if (dyn_cast<AwaitOp>(op) || dyn_cast<AwaitAllOp>(op)) {
           funcWorklist.push_back(func);
           break;
index 5daa764..a8abea2 100644 (file)
@@ -149,7 +149,7 @@ getFuncOpAnalysisState(const BufferizationState &state, FuncOp funcOp) {
 /// Return nullptr if there is no such unique ReturnOp.
 static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
   func::ReturnOp returnOp;
-  for (Block &b : funcOp.body()) {
+  for (Block &b : funcOp.getBody()) {
     if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
       if (returnOp)
         return nullptr;
@@ -460,7 +460,7 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
   // 3. Rewrite the bbArgs.
   // Iterate on the original `numArgs` and replace them in order.
   // This guarantees the argument order still matches after the rewrite.
-  Block &frontBlock = funcOp.body().front();
+  Block &frontBlock = funcOp.getBody().front();
   unsigned numArgs = frontBlock.getNumArguments();
   for (unsigned idx = 0; idx < numArgs; ++idx) {
     auto bbArg = frontBlock.getArgument(0);
@@ -527,7 +527,7 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   // For each FuncOp, the number of CallOpInterface it contains.
   DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](FuncOp funcOp) -> WalkResult {
-    if (!funcOp.body().empty()) {
+    if (!funcOp.getBody().empty()) {
       func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
       if (!returnOp)
         return funcOp->emitError()
@@ -624,7 +624,7 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
       argumentTypes.push_back(desiredMemrefType);
 
       // If funcOp's body is not empty, change the bbArg type and propagate.
-      if (!funcOp.body().empty()) {
+      if (!funcOp.getBody().empty()) {
         BlockArgument bbArg = funcOp.getArgument(argNumber);
         bbArg.setType(desiredMemrefType);
         OpBuilder b(bbArg.getContext());
@@ -886,7 +886,7 @@ struct CallOpInterface
 
     // 4. Create the new CallOp.
     Operation *newCallOp = rewriter.create<func::CallOp>(
-        callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
+        callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
     newCallOp->setAttrs(callOp->getAttrs());
     // Get replacement values for non-tensor / non-equivalent results.
     for (unsigned i = 0; i < replacementValues.size(); ++i) {
@@ -1009,7 +1009,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
   // Analyze ops.
   for (FuncOp funcOp : moduleState.orderedFuncOps) {
     // No body => no analysis.
-    if (funcOp.body().empty())
+    if (funcOp.getBody().empty())
       continue;
 
     // Now analyzing function.
@@ -1037,7 +1037,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
   // Bufferize function bodies.
   for (FuncOp funcOp : moduleState.orderedFuncOps) {
     // No body => no analysis.
-    if (funcOp.body().empty())
+    if (funcOp.getBody().empty())
       continue;
 
     if (failed(bufferizeOp(funcOp, state)))
index 67531ba..15def75 100644 (file)
@@ -283,7 +283,7 @@ public:
 
     IRRewriter rewriter(func.getContext());
 
-    propagateShapesInRegion(func.body());
+    propagateShapesInRegion(func.getBody());
 
     // Insert UnrealizedConversionCasts to guarantee ReturnOp agress with
     // the FuncOp type.
index af94c5e..fbfa61c 100644 (file)
@@ -101,7 +101,8 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                    ArrayRef<DictionaryAttr> argAttrs) {
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
-  state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+  state.addAttribute(function_interface_impl::getTypeAttrName(),
+                     TypeAttr::get(type));
   state.attributes.append(attrs.begin(), attrs.end());
   state.addRegion();
 
@@ -287,8 +288,8 @@ LogicalResult ModuleOp::verify() {
 LogicalResult
 UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands,
                                  SmallVectorImpl<OpFoldResult> &foldResults) {
-  OperandRange operands = inputs();
-  ResultRange results = outputs();
+  OperandRange operands = getInputs();
+  ResultRange results = getOutputs();
 
   if (operands.getType() == results.getType()) {
     foldResults.append(operands.begin(), operands.end());
index f62758f..9fec736 100644 (file)
@@ -30,7 +30,7 @@ struct TestPrintInvalidPass
 
   void runOnOperation() override {
     Location loc = getOperation().getLoc();
-    OpBuilder builder(getOperation().body());
+    OpBuilder builder(getOperation().getBodyRegion());
     auto funcOp = builder.create<FuncOp>(
         loc, "test", FunctionType::get(getOperation().getContext(), {}, {}));
     funcOp.addEntryBlock();
index 6f84273..28e5700 100644 (file)
@@ -41,7 +41,7 @@ protected:
 
   // Create ValueShapeRange on the arith.addi operation.
   ValueShapeRange addiRange() {
-    auto &fnBody = mapFn.body();
+    auto &fnBody = mapFn.getBody();
     return std::next(fnBody.front().begin())->getOperands();
   }