Revert "Revert "Reorder MLIRContext location in BuiltinAttributes.h""
authorTres Popp <tpopp@google.com>
Mon, 8 Feb 2021 08:44:03 +0000 (09:44 +0100)
committerTres Popp <tpopp@google.com>
Mon, 8 Feb 2021 09:39:58 +0000 (10:39 +0100)
This reverts commit 511dd4f4383b1c2873beac4dbea2df302f1f9d0c along with
a couple fixes.

Original message:
Now the context is the first, rather than the last input.

This better matches the rest of the infrastructure and makes
it easier to move these types to being declaratively specified.

Phabricator: https://reviews.llvm.org/D96111

34 files changed:
debuginfo-tests/llvm-prettyprinters/gdb/mlir-support.cpp
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Lower/FIRBuilder.cpp
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/FunctionSupport.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/SymbolInterfaces.td
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/StructsGen.cpp
mlir/unittests/TableGen/StructsGenTest.cpp

index 629ef16..2633e4b 100644 (file)
@@ -34,8 +34,8 @@ mlir::Attribute UnitAttr = mlir::UnitAttr::get(&Context);
 mlir::Attribute FloatAttr = mlir::FloatAttr::get(FloatType, 1.0);
 mlir::Attribute IntegerAttr = mlir::IntegerAttr::get(IntegerType, 10);
 mlir::Attribute TypeAttr = mlir::TypeAttr::get(IndexType);
-mlir::Attribute ArrayAttr = mlir::ArrayAttr::get({UnitAttr}, &Context);
-mlir::Attribute StringAttr = mlir::StringAttr::get("foo", &Context);
+mlir::Attribute ArrayAttr = mlir::ArrayAttr::get(&Context, {UnitAttr});
+mlir::Attribute StringAttr = mlir::StringAttr::get(&Context, "foo");
 mlir::Attribute ElementsAttr = mlir::DenseElementsAttr::get(
     VectorType.cast<mlir::ShapedType>(), llvm::ArrayRef<float>{2.0f, 3.0f});
 
index 8f3670b..cde5372 100644 (file)
@@ -267,27 +267,27 @@ class fir_AllocatableOp<string mnemonic, list<OpTrait> traits = []> :
     static constexpr llvm::StringRef inType() { return "in_type"; }
     static constexpr llvm::StringRef lenpName() { return "len_param_count"; }
     mlir::Type getAllocatedType();
-    
+
     bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; }
-    
+
     unsigned numLenParams() {
       if (auto val = (*this)->getAttrOfType<mlir::IntegerAttr>(lenpName()))
         return val.getInt();
       return 0;
     }
-    
+
     operand_range getLenParams() {
       return {operand_begin(), operand_begin() + numLenParams()};
     }
-    
+
     unsigned numShapeOperands() {
       return operand_end() - operand_begin() + numLenParams();
     }
-    
+
     operand_range getShapeOperands() {
       return {operand_begin() + numLenParams(), operand_end()};
     }
-    
+
     static mlir::Type getRefTy(mlir::Type ty);
 
     /// Get the input type of the allocation
@@ -1131,7 +1131,7 @@ def fir_EmboxCharOp : fir_Op<"emboxchar", [NoSideEffect]> {
   }];
 
   let arguments = (ins AnyReferenceLike:$memref, AnyIntegerLike:$len);
-  
+
   let results = (outs fir_BoxCharType);
 
   let assemblyFormat = [{
@@ -1563,7 +1563,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
     p.printFunctionalType((*this)->getOperandTypes(),
         (*this)->getResultTypes());
   }];
-  
+
   let verifier = [{
     auto refTy = ref().getType();
     if (fir::isa_ref_type(refTy)) {
@@ -1598,7 +1598,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     OpBuilderDAG<(ins "Type":$type, "ValueRange":$operands,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
-             
+
   let extraClassDeclaration = [{
     static constexpr llvm::StringRef baseType() { return "base_type"; }
     mlir::Type getBaseType();
@@ -1686,7 +1686,7 @@ def fir_FieldIndexOp : fir_OneResultOp<"field_index", [NoSideEffect]> {
 
   let printer = [{
     p << getOperationName() << ' '
-      << (*this)->getAttrOfType<mlir::StringAttr>(fieldAttrName()).getValue() 
+      << (*this)->getAttrOfType<mlir::StringAttr>(fieldAttrName()).getValue()
       << ", " << (*this)->getAttr(typeAttrName());
     if (getNumOperands()) {
       p << '(';
@@ -2007,7 +2007,7 @@ def fir_IterWhileOp : region_Op<"iterate_while",
       CArg<"ValueRange", "llvm::None">:$iterArgs,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
   ];
-  
+
   let extraClassDeclaration = [{
     mlir::Block *getBody() { return &region().front(); }
     mlir::Value getIterateVar() { return getBody()->getArgument(1); }
@@ -2276,11 +2276,11 @@ def fir_ConstfOp : fir_Op<"constf", [NoSideEffect]> {
   }];
 
   let arguments = (ins FirRealAttr:$constant);
-  
+
   let results = (outs fir_RealType:$res);
 
   let assemblyFormat = "`(` $constant `)` attr-dict `:` type($res)";
-  
+
   let verifier = [{
     if (!getType().isa<fir::RealType>())
       return emitOpError("must be a !fir.real type");
@@ -2357,7 +2357,7 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> {
   }];
 
   let results = (outs fir_ComplexType);
-  
+
   let parser = [{
     fir::RealAttr realp;
     fir::RealAttr imagp;
@@ -2455,7 +2455,7 @@ def fir_CmpcOp : fir_Op<"cmpc",
 
 def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {
   let summary = "convert a symbol to an SSA value";
-  
+
   let description = [{
     Convert a symbol (a function or global reference) to an SSA-value to be
     used in other Operations.
@@ -2474,7 +2474,7 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {
 
 def fir_ConvertOp : fir_OneResultOp<"convert", [NoSideEffect]> {
   let summary = "encapsulates all Fortran scalar type conversions";
-  
+
   let description = [{
     Generalized type conversion. Convert the ssa value from type T to type U.
     Not all pairs of types have conversions. When types T and U are the same
@@ -2705,7 +2705,7 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
     mlir::Type resultType() {
       return fir::AllocaOp::wrapResultType(getType());
     }
-    
+
     /// Return the initializer attribute if it exists, or a null attribute.
     Attribute getValueOrNull() { return initVal().getValueOr(Attribute()); }
 
@@ -2728,9 +2728,9 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
     }
 
     mlir::FlatSymbolRefAttr getSymbol() {
-      return mlir::FlatSymbolRefAttr::get(
+      return mlir::FlatSymbolRefAttr::get(getContext(),
           (*this)->getAttrOfType<mlir::StringAttr>(
-              mlir::SymbolTable::getSymbolAttrName()).getValue(), getContext());
+              mlir::SymbolTable::getSymbolAttrName()).getValue());
     }
   }];
 }
@@ -2772,7 +2772,7 @@ def fir_GlobalLenOp : fir_Op<"global_len", []> {
   }];
 
   let printer = [{
-    p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName()) 
+    p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName())
       << ", " << (*this)->getAttr(intAttrName());
   }];
 
index 3f470d6..0a8473b 100644 (file)
@@ -173,7 +173,7 @@ mlir::Value Fortran::lower::FirOpBuilder::createConvert(mlir::Location loc,
 
 fir::StringLitOp Fortran::lower::FirOpBuilder::createStringLit(
     mlir::Location loc, mlir::Type eleTy, llvm::StringRef data) {
-  auto strAttr = mlir::StringAttr::get(data, getContext());
+  auto strAttr = mlir::StringAttr::get(getContext(), data);
   auto valTag = mlir::Identifier::get(fir::StringLitOp::value(), getContext());
   mlir::NamedAttribute dataAttr(valTag, strAttr);
   auto sizeTag = mlir::Identifier::get(fir::StringLitOp::size(), getContext());
index 3883ce2..8523a83 100644 (file)
@@ -107,7 +107,7 @@ private:
                                              ModuleOp module) {
     auto *context = module.getContext();
     if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
-      return SymbolRefAttr::get("printf", context);
+      return SymbolRefAttr::get(context, "printf");
 
     // Create a function declaration for printf, the signature is:
     //   * `i32 (i8*, ...)`
@@ -120,7 +120,7 @@ private:
     PatternRewriter::InsertionGuard insertGuard(rewriter);
     rewriter.setInsertionPointToStart(module.getBody());
     rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
-    return SymbolRefAttr::get("printf", context);
+    return SymbolRefAttr::get(context, "printf");
   }
 
   /// Return a value representing an access into a global string with the given
index 3883ce2..8523a83 100644 (file)
@@ -107,7 +107,7 @@ private:
                                              ModuleOp module) {
     auto *context = module.getContext();
     if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
-      return SymbolRefAttr::get("printf", context);
+      return SymbolRefAttr::get(context, "printf");
 
     // Create a function declaration for printf, the signature is:
     //   * `i32 (i8*, ...)`
@@ -120,7 +120,7 @@ private:
     PatternRewriter::InsertionGuard insertGuard(rewriter);
     rewriter.setInsertionPointToStart(module.getBody());
     rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
-    return SymbolRefAttr::get("printf", context);
+    return SymbolRefAttr::get(context, "printf");
   }
 
   /// Return a value representing an access into a global string with the given
index 794417e..b903c09 100644 (file)
@@ -31,7 +31,7 @@ inline bool isRowMajorMatmul(ArrayAttr indexingMaps) {
   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
-  auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
+  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
   return indexingMaps == maps;
 }
 
@@ -42,7 +42,7 @@ inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) {
   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
-  auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
+  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
   return indexingMaps == maps;
 }
 
index 34e7e8c..571c912 100644 (file)
@@ -69,7 +69,7 @@ public:
   using Base::Base;
   using ValueType = ArrayRef<Attribute>;
 
-  static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
+  static ArrayAttr get(MLIRContext *context, ArrayRef<Attribute> value);
 
   ArrayRef<Attribute> getValue() const;
   Attribute operator[](unsigned idx) const;
@@ -126,8 +126,8 @@ public:
   /// attributes. This method assumes that the provided list is unordered. If
   /// the caller can guarantee that the attributes are ordered by name,
   /// getWithSorted should be used instead.
-  static DictionaryAttr get(ArrayRef<NamedAttribute> value,
-                            MLIRContext *context);
+  static DictionaryAttr get(MLIRContext *context,
+                            ArrayRef<NamedAttribute> value);
 
   /// Construct a dictionary with an array of values that is known to already be
   /// sorted by name and uniqued.
@@ -250,7 +250,7 @@ public:
   using Attribute::Attribute;
   using ValueType = bool;
 
-  static BoolAttr get(bool value, MLIRContext *context);
+  static BoolAttr get(MLIRContext *context, bool value);
 
   /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to
   /// avoid bringing in all of IntegerAttrs methods.
@@ -292,8 +292,8 @@ public:
   using Base::Base;
 
   /// Get or create a new OpaqueAttr with the provided dialect and string data.
-  static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type,
-                        MLIRContext *context);
+  static OpaqueAttr get(MLIRContext *context, Identifier dialect,
+                        StringRef attrData, Type type);
 
   /// Get or create a new OpaqueAttr with the provided dialect and string data.
   /// If the given identifier is not a valid namespace for a dialect, then a
@@ -325,7 +325,7 @@ public:
   using ValueType = StringRef;
 
   /// Get an instance of a StringAttr with the given string.
-  static StringAttr get(StringRef bytes, MLIRContext *context);
+  static StringAttr get(MLIRContext *context, StringRef bytes);
 
   /// Get an instance of a StringAttr with the given string and Type.
   static StringAttr get(StringRef bytes, Type type);
@@ -348,13 +348,12 @@ public:
   using Base::Base;
 
   /// Construct a symbol reference for the given value name.
-  static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx);
+  static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value);
 
   /// Construct a symbol reference for the given value name, and a set of nested
   /// references that are further resolve to a nested symbol.
-  static SymbolRefAttr get(StringRef value,
-                           ArrayRef<FlatSymbolRefAttr> references,
-                           MLIRContext *ctx);
+  static SymbolRefAttr get(MLIRContext *ctx, StringRef value,
+                           ArrayRef<FlatSymbolRefAttr> references);
 
   /// Returns the name of the top level symbol reference, i.e. the root of the
   /// reference path.
@@ -377,8 +376,8 @@ public:
   using ValueType = StringRef;
 
   /// Construct a symbol reference for the given value name.
-  static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) {
-    return SymbolRefAttr::get(value, ctx);
+  static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) {
+    return SymbolRefAttr::get(ctx, value);
   }
 
   /// Returns the name of the held symbol reference.
index be8a689..c2eec87 100644 (file)
@@ -569,7 +569,7 @@ void FunctionLike<ConcreteType>::setArgAttrs(
   if (attributes.empty())
     return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
   Operation *op = this->getOperation();
-  op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
+  op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes));
 }
 
 template <typename ConcreteType>
@@ -646,7 +646,7 @@ void FunctionLike<ConcreteType>::setResultAttrs(
   if (attributes.empty())
     return (void)this->getOperation()->removeAttr(nameOut);
   Operation *op = this->getOperation();
-  op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
+  op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes));
 }
 
 template <typename ConcreteType>
index 45b9c49..70cd55d 100644 (file)
@@ -315,7 +315,7 @@ public:
     attrs = newAttrs;
   }
   void setAttrs(ArrayRef<NamedAttribute> newAttrs) {
-    setAttrs(DictionaryAttr::get(newAttrs, getContext()));
+    setAttrs(DictionaryAttr::get(getContext(), newAttrs));
   }
 
   /// Return the specified attribute if present, null otherwise.
index c5f252e..a7b1fd8 100644 (file)
@@ -44,7 +44,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
       /*defaultImplementation=*/[{
         this->getOperation()->setAttr(
             mlir::SymbolTable::getSymbolAttrName(),
-            StringAttr::get(name, this->getOperation()->getContext()));
+            StringAttr::get(this->getOperation()->getContext(), name));
       }]
     >,
     InterfaceMethod<"Gets the visibility of this symbol.",
index 90ed9cb..9e61e3a 100644 (file)
@@ -42,9 +42,9 @@ bool mlirAttributeIsAArray(MlirAttribute attr) {
 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
                                MlirAttribute const *elements) {
   SmallVector<Attribute, 8> attrs;
-  return wrap(ArrayAttr::get(
-      unwrapList(static_cast<size_t>(numElements), elements, attrs),
-      unwrap(ctx)));
+  return wrap(
+      ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
+                                             elements, attrs)));
 }
 
 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
@@ -71,7 +71,7 @@ MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
     attributes.emplace_back(
         Identifier::get(unwrap(elements[i].name), unwrap(ctx)),
         unwrap(elements[i].attribute));
-  return wrap(DictionaryAttr::get(attributes, unwrap(ctx)));
+  return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
 }
 
 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
@@ -137,7 +137,7 @@ bool mlirAttributeIsABool(MlirAttribute attr) {
 }
 
 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
-  return wrap(BoolAttr::get(value, unwrap(ctx)));
+  return wrap(BoolAttr::get(unwrap(ctx), value));
 }
 
 bool mlirBoolAttrGetValue(MlirAttribute attr) {
@@ -163,9 +163,9 @@ bool mlirAttributeIsAOpaque(MlirAttribute attr) {
 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
                                 intptr_t dataLength, const char *data,
                                 MlirType type) {
-  return wrap(
-      OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
-                      StringRef(data, dataLength), unwrap(type), unwrap(ctx)));
+  return wrap(OpaqueAttr::get(
+      unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
+      StringRef(data, dataLength), unwrap(type)));
 }
 
 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
@@ -185,7 +185,7 @@ bool mlirAttributeIsAString(MlirAttribute attr) {
 }
 
 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
-  return wrap(StringAttr::get(unwrap(str), unwrap(ctx)));
+  return wrap(StringAttr::get(unwrap(ctx), unwrap(str)));
 }
 
 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
@@ -211,7 +211,7 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
   refs.reserve(numReferences);
   for (intptr_t i = 0; i < numReferences; ++i)
     refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
-  return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx)));
+  return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs));
 }
 
 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
@@ -241,7 +241,7 @@ bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
 }
 
 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
-  return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx)));
+  return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol)));
 }
 
 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
index 447b005..1b9e361 100644 (file)
@@ -148,7 +148,7 @@ StringAttr GpuKernelToBlobPass::translateGPUModuleToBinaryAnnotation(
   auto blob = convertModuleToBlob(llvmModule, loc, name);
   if (!blob)
     return {};
-  return StringAttr::get({blob->data(), blob->size()}, loc->getContext());
+  return StringAttr::get(loc->getContext(), {blob->data(), blob->size()});
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
index 887d3e7..5b62ca4 100644 (file)
@@ -177,12 +177,12 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
   // Set SPIR-V binary shader data as an attribute.
   vulkanLaunchCallOp->setAttr(
       kSPIRVBlobAttrName,
-      StringAttr::get({binary.data(), binary.size()}, loc->getContext()));
+      StringAttr::get(loc->getContext(), {binary.data(), binary.size()}));
 
   // Set entry point name as an attribute.
   vulkanLaunchCallOp->setAttr(
       kSPIRVEntryPointAttrName,
-      StringAttr::get(launchOp.getKernelName(), loc->getContext()));
+      StringAttr::get(loc->getContext(), launchOp.getKernelName()));
 
   launchOp.erase();
 }
index 87026e4..29cf422 100644 (file)
@@ -687,8 +687,8 @@ public:
         rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
     structValue = rewriter.create<LLVM::InsertValueOp>(
         loc, structType, structValue, executionMode,
-        ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)},
-                       context));
+        ArrayAttr::get(context,
+                       {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}));
 
     // Insert extra operands if they exist into execution mode info struct.
     for (unsigned i = 0, e = values.size(); i < e; ++i) {
@@ -696,9 +696,9 @@ public:
       Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
       structValue = rewriter.create<LLVM::InsertValueOp>(
           loc, structType, structValue, entry,
-          ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
-                          rewriter.getIntegerAttr(rewriter.getI32Type(), i)},
-                         context));
+          ArrayAttr::get(context,
+                         {rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
+                          rewriter.getIntegerAttr(rewriter.getI32Type(), i)}));
     }
     rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
     rewriter.eraseOp(op);
@@ -1297,17 +1297,17 @@ public:
     switch (funcOp.function_control()) {
 #define DISPATCH(functionControl, llvmAttr)                                    \
   case functionControl:                                                        \
-    newFuncOp->setAttr("passthrough", ArrayAttr::get({llvmAttr}, context));    \
+    newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr}));    \
     break;
 
       DISPATCH(spirv::FunctionControl::Inline,
-               StringAttr::get("alwaysinline", context));
+               StringAttr::get(context, "alwaysinline"));
       DISPATCH(spirv::FunctionControl::DontInline,
-               StringAttr::get("noinline", context));
+               StringAttr::get(context, "noinline"));
       DISPATCH(spirv::FunctionControl::Pure,
-               StringAttr::get("readonly", context));
+               StringAttr::get(context, "readonly"));
       DISPATCH(spirv::FunctionControl::Const,
-               StringAttr::get("readnone", context));
+               StringAttr::get(context, "readnone"));
 
 #undef DISPATCH
 
index 794f4a5..ea0a425 100644 (file)
@@ -4016,7 +4016,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
     if (failed(applyPartialConversion(m, target, std::move(patterns))))
       signalPassFailure();
     m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
-               StringAttr::get(this->dataLayout, m.getContext()));
+               StringAttr::get(m.getContext(), this->dataLayout));
   }
 };
 } // end namespace
index 9e88250..683de81 100644 (file)
@@ -762,7 +762,7 @@ public:
     if (positionAttrs.size() > 1) {
       auto oneDVectorType = reducedVectorTypeBack(vectorType);
       auto nMinusOnePositionAttrs =
-          ArrayAttr::get(positionAttrs.drop_back(), context);
+          ArrayAttr::get(context, positionAttrs.drop_back());
       extracted = rewriter.create<LLVM::ExtractValueOp>(
           loc, typeConverter->convertType(oneDVectorType), extracted,
           nMinusOnePositionAttrs);
@@ -871,7 +871,7 @@ public:
     if (positionAttrs.size() > 1) {
       oneDVectorType = reducedVectorTypeBack(destVectorType);
       auto nMinusOnePositionAttrs =
-          ArrayAttr::get(positionAttrs.drop_back(), context);
+          ArrayAttr::get(context, positionAttrs.drop_back());
       extracted = rewriter.create<LLVM::ExtractValueOp>(
           loc, typeConverter->convertType(oneDVectorType), extracted,
           nMinusOnePositionAttrs);
@@ -887,7 +887,7 @@ public:
     // Potential insertion of resulting 1-D vector into array.
     if (positionAttrs.size() > 1) {
       auto nMinusOnePositionAttrs =
-          ArrayAttr::get(positionAttrs.drop_back(), context);
+          ArrayAttr::get(context, positionAttrs.drop_back());
       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                       adaptor.dest(), inserted,
                                                       nMinusOnePositionAttrs);
index c1d0820..6ccb59a 100644 (file)
@@ -53,7 +53,7 @@ LogicalResult setMappingAttr(scf::ParallelOp ploopOp,
   }
   ArrayRef<Attribute> mappingAsAttrs(mapping.data(), mapping.size());
   ploopOp->setAttr(getMappingAttrName(),
-                   ArrayAttr::get(mappingAsAttrs, ploopOp.getContext()));
+                   ArrayAttr::get(ploopOp.getContext(), mappingAsAttrs));
   return success();
 }
 } // namespace gpu
index a3960ae..e966687 100644 (file)
@@ -225,7 +225,7 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
     if (genericAttrNamesSet.count(attr.first.strref()) > 0)
       genericAttrs.push_back(attr);
   if (!genericAttrs.empty()) {
-    auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext());
+    auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs);
     p << genericDictAttr;
   }
 
@@ -833,7 +833,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
   // Handle the corner case of the result being a rank 0 shaped type. Return an
   // emtpy ArrayAttr.
   if (mapsConsumer.empty() && !mapsProducer.empty())
-    return ArrayAttr::get(ArrayRef<Attribute>(), context);
+    return ArrayAttr::get(context, ArrayRef<Attribute>());
   if (mapsProducer.empty() || mapsConsumer.empty() ||
       mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
       mapsProducer.size() != mapsConsumer[0].getNumDims())
@@ -854,7 +854,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
         numLhsDims, /*numSymbols =*/0, reassociations, context)));
     reassociations.clear();
   }
-  return ArrayAttr::get(reassociationMaps, context);
+  return ArrayAttr::get(context, reassociationMaps);
 }
 
 namespace {
index 8db4824..c7b7640 100644 (file)
@@ -137,11 +137,11 @@ static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
   // wrong, so abort.
   if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
     return nullptr;
-  return ArrayAttr::get(
-      llvm::to_vector<4>(llvm::map_range(
-          newIndexingMaps,
-          [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })),
-      context);
+  return ArrayAttr::get(context,
+                        llvm::to_vector<4>(llvm::map_range(
+                            newIndexingMaps, [](AffineMap map) -> Attribute {
+                              return AffineMapAttr::get(map);
+                            })));
 }
 
 /// Modify the region of indexed generic op to drop arguments corresponding to
@@ -220,7 +220,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
 
     rewriter.startRootUpdate(op);
     op.indexing_mapsAttr(newIndexingMapAttr);
-    op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
+    op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
     (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
     rewriter.finalizeRootUpdate(op);
     return success();
@@ -282,7 +282,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
       RankedTensorType::get(newShape, type.getElementType()),
       AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
                      newIndexExprs, context),
-      ArrayAttr::get(reassociationMaps, context)};
+      ArrayAttr::get(context, reassociationMaps)};
   return info;
 }
 
index cac0ae0..b893f2b 100644 (file)
@@ -77,9 +77,9 @@ LinalgOp mlir::linalg::interchange(LinalgOp op,
   applyPermutationToVector(itTypesVector, interchangeVector);
 
   op->setAttr(getIndexingMapsAttrName(),
-              ArrayAttr::get(newIndexingMaps, context));
+              ArrayAttr::get(context, newIndexingMaps));
   op->setAttr(getIteratorTypesAttrName(),
-              ArrayAttr::get(itTypesVector, context));
+              ArrayAttr::get(context, itTypesVector));
 
   return op;
 }
index 9b62b42..4ce29b4 100644 (file)
@@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
   });
   for (auto &var : interfaceVarSet) {
     interfaceVars.push_back(SymbolRefAttr::get(
-        cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
+        funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name()));
   }
   return success();
 }
index 0902b29..65ebc54 100644 (file)
@@ -338,7 +338,7 @@ OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
       return a;
   }
   // If this is reached, all inputs were statically known passing.
-  return BoolAttr::get(true, getContext());
+  return BoolAttr::get(getContext(), true);
 }
 
 static LogicalResult verify(AssumingAllOp op) {
@@ -482,10 +482,10 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
   // Both operands are not needed if one is a scalar.
   if (operands[0] &&
       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
-    return BoolAttr::get(true, getContext());
+    return BoolAttr::get(getContext(), true);
   if (operands[1] &&
       operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
-    return BoolAttr::get(true, getContext());
+    return BoolAttr::get(getContext(), true);
 
   if (operands[0] && operands[1]) {
     auto lhsShape = llvm::to_vector<6>(
@@ -494,7 +494,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
         operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
     SmallVector<int64_t, 6> resultShape;
     if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
-      return BoolAttr::get(true, getContext());
+      return BoolAttr::get(getContext(), true);
   }
 
   // Lastly, see if folding can be completed based on what constraints are known
@@ -506,7 +506,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
     return nullptr;
 
   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
-    return BoolAttr::get(true, getContext());
+    return BoolAttr::get(getContext(), true);
 
   // Because a failing witness result here represents an eventual assertion
   // failure, we do not replace it with a constant witness.
@@ -526,7 +526,7 @@ void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
   if (llvm::all_of(operands,
                    [&](Attribute a) { return a && a == operands[0]; }))
-    return BoolAttr::get(true, getContext());
+    return BoolAttr::get(getContext(), true);
 
   // Because a failing witness result here represents an eventual assertion
   // failure, we do not try to replace it with a constant witness. Similarly, we
@@ -573,14 +573,14 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
 
 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
   if (lhs() == rhs())
-    return BoolAttr::get(true, getContext());
+    return BoolAttr::get(getContext(), true);
   auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
   if (lhs == nullptr)
     return {};
   auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
   if (rhs == nullptr)
     return {};
-  return BoolAttr::get(lhs == rhs, getContext());
+  return BoolAttr::get(getContext(), lhs == rhs);
 }
 
 //===----------------------------------------------------------------------===//
index c085c1c..ca2e273 100644 (file)
@@ -844,7 +844,7 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
 
   if (lhs() == rhs()) {
     auto val = applyCmpPredicateToEqualOperands(getPredicate());
-    return BoolAttr::get(val, getContext());
+    return BoolAttr::get(getContext(), val);
   }
 
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
@@ -853,7 +853,7 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
     return {};
 
   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
-  return BoolAttr::get(val, getContext());
+  return BoolAttr::get(getContext(), val);
 }
 
 //===----------------------------------------------------------------------===//
index f20b713..9fe8cf2 100644 (file)
@@ -247,7 +247,7 @@ static void print(OpAsmPrinter &p, ContractionOp op) {
     if (traitAttrsSet.count(attr.first.strref()) > 0)
       attrs.push_back(attr);
 
-  auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
+  auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);
   p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
   p << op.rhs() << ", " << op.acc();
   if (op.masks().size() == 2)
@@ -1445,7 +1445,7 @@ static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
   auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
     return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
   });
-  return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
+  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
 }
 
 static LogicalResult verify(InsertStridedSliceOp op) {
index 8a5206e..bafeccb 100644 (file)
@@ -92,11 +92,11 @@ NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
 
 BoolAttr Builder::getBoolAttr(bool value) {
-  return BoolAttr::get(value, context);
+  return BoolAttr::get(context, value);
 }
 
 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
-  return DictionaryAttr::get(value, context);
+  return DictionaryAttr::get(context, value);
 }
 
 IntegerAttr Builder::getIndexAttr(int64_t value) {
@@ -200,11 +200,11 @@ FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
 }
 
 StringAttr Builder::getStringAttr(StringRef bytes) {
-  return StringAttr::get(bytes, context);
+  return StringAttr::get(context, bytes);
 }
 
 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
-  return ArrayAttr::get(value, context);
+  return ArrayAttr::get(context, value);
 }
 
 FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
@@ -214,12 +214,12 @@ FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
   return getSymbolRefAttr(symName.getValue());
 }
 FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
-  return SymbolRefAttr::get(value, getContext());
+  return SymbolRefAttr::get(getContext(), value);
 }
 SymbolRefAttr
 Builder::getSymbolRefAttr(StringRef value,
                           ArrayRef<FlatSymbolRefAttr> nestedReferences) {
-  return SymbolRefAttr::get(value, nestedReferences, getContext());
+  return SymbolRefAttr::get(getContext(), value, nestedReferences);
 }
 
 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
index 162bed9..58a5b33 100644 (file)
@@ -35,7 +35,7 @@ AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
 // ArrayAttr
 //===----------------------------------------------------------------------===//
 
-ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
+ArrayAttr ArrayAttr::get(MLIRContext *context, ArrayRef<Attribute> value) {
   return Base::get(context, value);
 }
 
@@ -134,8 +134,8 @@ DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
   return findDuplicateElement(array);
 }
 
-DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
-                                   MLIRContext *context) {
+DictionaryAttr DictionaryAttr::get(MLIRContext *context,
+                                   ArrayRef<NamedAttribute> value) {
   if (value.empty())
     return DictionaryAttr::getEmpty(context);
   assert(llvm::all_of(value,
@@ -267,13 +267,12 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
 // SymbolRefAttr
 //===----------------------------------------------------------------------===//
 
-FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
+FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
   return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
 }
 
-SymbolRefAttr SymbolRefAttr::get(StringRef value,
-                                 ArrayRef<FlatSymbolRefAttr> nestedReferences,
-                                 MLIRContext *ctx) {
+SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
+                                 ArrayRef<FlatSymbolRefAttr> nestedReferences) {
   return Base::get(ctx, value, nestedReferences);
 }
 
@@ -294,7 +293,7 @@ ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
 
 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
   if (type.isSignlessInteger(1))
-    return BoolAttr::get(value.getBoolValue(), type.getContext());
+    return BoolAttr::get(type.getContext(), value.getBoolValue());
   return Base::get(type.getContext(), type, value);
 }
 
@@ -377,8 +376,8 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
 // OpaqueAttr
 //===----------------------------------------------------------------------===//
 
-OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
-                           MLIRContext *context) {
+OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect,
+                           StringRef attrData, Type type) {
   return Base::get(context, dialect, attrData, type);
 }
 
@@ -409,7 +408,7 @@ LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
 // StringAttr
 //===----------------------------------------------------------------------===//
 
-StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
+StringAttr StringAttr::get(MLIRContext *context, StringRef bytes) {
   return get(bytes, NoneType::get(context));
 }
 
index 469aa31..db383c6 100644 (file)
@@ -166,7 +166,7 @@ void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
     newAttrs.insert(attr);
   for (auto &attr : getAttrs())
     newAttrs.insert(attr);
-  dest->setAttrs(DictionaryAttr::get(newAttrs.takeVector(), getContext()));
+  dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector()));
 
   // Clone the body.
   getBody().cloneInto(&dest.getBody(), mapper);
index dbfa1bd..8d13a9c 100644 (file)
@@ -872,7 +872,7 @@ void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
     storage->setType(NoneType::get(ctx));
 }
 
-BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
+BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
   return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
 }
 
index b4fe9f8..be31268 100644 (file)
@@ -76,7 +76,7 @@ Operation *Operation::create(Location location, OperationName name,
                              ArrayRef<NamedAttribute> attributes,
                              BlockRange successors, unsigned numRegions) {
   return create(location, name, resultTypes, operands,
-                DictionaryAttr::get(attributes, location.getContext()),
+                DictionaryAttr::get(location.getContext(), attributes),
                 successors, numRegions);
 }
 
index b198600..70133d2 100644 (file)
@@ -46,7 +46,7 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
   assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
   MLIRContext *ctx = symbol->getContext();
 
-  auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx);
+  auto leafRef = FlatSymbolRefAttr::get(ctx, symbolName);
   results.push_back(leafRef);
 
   // Early exit for when 'within' is the parent of 'symbol'.
@@ -67,13 +67,13 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
         getNameIfSymbol(symbolTableOp, symbolNameId);
     if (!symbolTableName)
       return failure();
-    results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx));
+    results.push_back(SymbolRefAttr::get(ctx, *symbolTableName, nestedRefs));
 
     symbolTableOp = symbolTableOp->getParentOp();
     if (symbolTableOp == within)
       break;
     nestedRefs.insert(nestedRefs.begin(),
-                      FlatSymbolRefAttr::get(*symbolTableName, ctx));
+                      FlatSymbolRefAttr::get(ctx, *symbolTableName));
   } while (true);
   return success();
 }
@@ -203,7 +203,7 @@ StringRef SymbolTable::getSymbolName(Operation *symbol) {
 /// Sets the name of the given symbol operation.
 void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
   symbol->setAttr(getSymbolAttrName(),
-                  StringAttr::get(name, symbol->getContext()));
+                  StringAttr::get(symbol->getContext(), name));
 }
 
 /// Returns the visibility of the given symbol operation.
@@ -235,7 +235,7 @@ void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
          "unknown symbol visibility kind");
 
   StringRef visName = vis == Visibility::Private ? "private" : "nested";
-  symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx));
+  symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
 }
 
 /// Returns the nearest symbol table from a given operation `from`. Returns
@@ -603,7 +603,7 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
       // doesn't support parent references.
       if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
           symbol->getParentOp())
-        return {{SymbolRefAttr::get(symName, symbol->getContext()), limit}};
+        return {{SymbolRefAttr::get(symbol->getContext(), symName), limit}};
       return {};
     }
 
@@ -659,7 +659,7 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
 template <typename IRUnit>
 static SmallVector<SymbolScope, 1> collectSymbolScopes(StringRef symbol,
                                                        IRUnit *limit) {
-  return {{SymbolRefAttr::get(symbol, limit->getContext()), limit}};
+  return {{SymbolRefAttr::get(limit->getContext(), symbol), limit}};
 }
 
 /// Returns true if the given reference 'SubRef' is a sub reference of the
@@ -825,11 +825,11 @@ static Attribute rebuildAttrAfterRAUW(
   if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) {
     auto newAttrs = llvm::to_vector<4>(dictAttr.getValue());
     updateAttrs(make_second_range(newAttrs));
-    return DictionaryAttr::get(newAttrs, dictAttr.getContext());
+    return DictionaryAttr::get(dictAttr.getContext(), newAttrs);
   }
   auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue());
   updateAttrs(newAttrs);
-  return ArrayAttr::get(newAttrs, container.getContext());
+  return ArrayAttr::get(container.getContext(), newAttrs);
 }
 
 /// Generates a new symbol reference attribute with a new leaf reference.
@@ -839,8 +839,8 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
     return newLeafAttr;
   auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
   nestedRefs.back() = newLeafAttr;
-  return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs,
-                            oldAttr.getContext());
+  return SymbolRefAttr::get(oldAttr.getContext(), oldAttr.getRootReference(),
+                            nestedRefs);
 }
 
 /// The implementation of SymbolTable::replaceAllSymbolUses below.
@@ -867,7 +867,7 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
 
   // Generate a new attribute to replace the given attribute.
   MLIRContext *ctx = limit->getContext();
-  FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx);
+  FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(ctx, newSymbol);
   for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
     SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
     auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
@@ -883,13 +883,13 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
       if (useRef != scope.symbol) {
         if (scope.symbol.isa<FlatSymbolRefAttr>()) {
           replacementRef =
-              SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx);
+              SymbolRefAttr::get(ctx, newSymbol, useRef.getNestedReferences());
         } else {
           auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
           nestedRefs[scope.symbol.getNestedReferences().size() - 1] =
               newLeafAttr;
           replacementRef =
-              SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx);
+              SymbolRefAttr::get(ctx, useRef.getRootReference(), nestedRefs);
         }
       }
 
index 859e8e2..98f7417 100644 (file)
@@ -148,7 +148,7 @@ Attribute Parser::parseAttribute(Type type) {
       return Attribute();
 
     return type ? StringAttr::get(val, type)
-                : StringAttr::get(val, getContext());
+                : StringAttr::get(getContext(), val);
   }
 
   // Parse a symbol reference attribute.
@@ -176,7 +176,7 @@ Attribute Parser::parseAttribute(Type type) {
 
       std::string nameStr = getToken().getSymbolReference();
       consumeToken(Token::at_identifier);
-      nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
+      nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
     }
 
     return builder.getSymbolRefAttr(nameStr, nestedRefs);
index 2f0b337..52ce37e 100644 (file)
@@ -742,7 +742,8 @@ void OpEmitter::genAttrGetters() {
 
       body << "  ::mlir::MLIRContext* ctx = getContext();\n";
       body << "  ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
-      body << "  return ::mlir::DictionaryAttr::get({\n";
+      body << "  return ::mlir::DictionaryAttr::get(";
+      body << "  ctx, {\n";
       interleave(
           derivedAttrs, body,
           [&](const NamedAttribute &namedAttr) {
@@ -755,7 +756,7 @@ void OpEmitter::genAttrGetters() {
                  << "}";
           },
           ",\n");
-      body << "\n    }, ctx);";
+      body << "});";
     }
   }
 }
index 5595986..52f5223 100644 (file)
@@ -150,7 +150,7 @@ static void emitFactoryDef(llvm::StringRef structName,
   }
 
   const char *getEndInfo = R"(
-  ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(fields, context);
+  ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(context, fields);
   return dict.dyn_cast<{0}>();
 }
 )";
index 0dd9ef9..ef0bdd8 100644 (file)
@@ -67,7 +67,7 @@ TEST(StructsGenTest, ClassofExtraFalse) {
   newValues.push_back(wrongAttr);
 
   // Make a new DictionaryAttr and validate.
-  auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+  auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
 }
 
@@ -88,7 +88,7 @@ TEST(StructsGenTest, ClassofBadNameFalse) {
   auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
   newValues.push_back(wrongAttr);
 
-  auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+  auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
 }
 
@@ -113,7 +113,7 @@ TEST(StructsGenTest, ClassofBadTypeFalse) {
   auto wrongAttr = mlir::NamedAttribute(id, elementsAttr);
   newValues.push_back(wrongAttr);
 
-  auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+  auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
 }
 
@@ -130,7 +130,7 @@ TEST(StructsGenTest, ClassofMissingFalse) {
       expectedValues.begin() + 1, expectedValues.end());
 
   // Make a new DictionaryAttr and validate it is not a validate TestStruct.
-  auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+  auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
 }