Revert "Revert "Reorder MLIRContext location in BuiltinAttributes.h""
authorTres Popp <tpopp@google.com>
Mon, 8 Feb 2021 08:44:03 +0000 (09:44 +0100)
committerTres Popp <tpopp@google.com>
Mon, 8 Feb 2021 09:39:58 +0000 (10:39 +0100)
This reverts commit 511dd4f4383b1c2873beac4dbea2df302f1f9d0c along with
a couple fixes.

Original message:
Now the context is the first, rather than the last input.

This better matches the rest of the infrastructure and makes
it easier to move these types to being declaratively specified.

Phabricator: https://reviews.llvm.org/D96111

34 files changed:
debuginfo-tests/llvm-prettyprinters/gdb/mlir-support.cpp
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Lower/FIRBuilder.cpp
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/FunctionSupport.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/SymbolInterfaces.td
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/StructsGen.cpp
mlir/unittests/TableGen/StructsGenTest.cpp

index 629ef1668c8f1919fd001b2964da9040c2b7b867..2633e4b19ebcea1b06aca5419be3a0b748ea9dca 100644 (file)
@@ -34,8 +34,8 @@ mlir::Attribute UnitAttr = mlir::UnitAttr::get(&Context);
 mlir::Attribute FloatAttr = mlir::FloatAttr::get(FloatType, 1.0);
 mlir::Attribute IntegerAttr = mlir::IntegerAttr::get(IntegerType, 10);
 mlir::Attribute TypeAttr = mlir::TypeAttr::get(IndexType);
-mlir::Attribute ArrayAttr = mlir::ArrayAttr::get({UnitAttr}, &Context);
-mlir::Attribute StringAttr = mlir::StringAttr::get("foo", &Context);
+mlir::Attribute ArrayAttr = mlir::ArrayAttr::get(&Context, {UnitAttr});
+mlir::Attribute StringAttr = mlir::StringAttr::get(&Context, "foo");
 mlir::Attribute ElementsAttr = mlir::DenseElementsAttr::get(
     VectorType.cast<mlir::ShapedType>(), llvm::ArrayRef<float>{2.0f, 3.0f});
 
index 8f3670b29d748c52a8d1dda1e70c7c7a324ca672..cde53725b4a4f2dba068989c423427ee0c3d97c5 100644 (file)
@@ -267,27 +267,27 @@ class fir_AllocatableOp<string mnemonic, list<OpTrait> traits = []> :
     static constexpr llvm::StringRef inType() { return "in_type"; }
     static constexpr llvm::StringRef lenpName() { return "len_param_count"; }
     mlir::Type getAllocatedType();
-    
+
     bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; }
-    
+
     unsigned numLenParams() {
       if (auto val = (*this)->getAttrOfType<mlir::IntegerAttr>(lenpName()))
         return val.getInt();
       return 0;
     }
-    
+
     operand_range getLenParams() {
       return {operand_begin(), operand_begin() + numLenParams()};
     }
-    
+
     unsigned numShapeOperands() {
       return operand_end() - operand_begin() + numLenParams();
     }
-    
+
     operand_range getShapeOperands() {
       return {operand_begin() + numLenParams(), operand_end()};
     }
-    
+
     static mlir::Type getRefTy(mlir::Type ty);
 
     /// Get the input type of the allocation
@@ -1131,7 +1131,7 @@ def fir_EmboxCharOp : fir_Op<"emboxchar", [NoSideEffect]> {
   }];
 
   let arguments = (ins AnyReferenceLike:$memref, AnyIntegerLike:$len);
-  
+
   let results = (outs fir_BoxCharType);
 
   let assemblyFormat = [{
@@ -1563,7 +1563,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
     p.printFunctionalType((*this)->getOperandTypes(),
         (*this)->getResultTypes());
   }];
-  
+
   let verifier = [{
     auto refTy = ref().getType();
     if (fir::isa_ref_type(refTy)) {
@@ -1598,7 +1598,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     OpBuilderDAG<(ins "Type":$type, "ValueRange":$operands,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
-             
+
   let extraClassDeclaration = [{
     static constexpr llvm::StringRef baseType() { return "base_type"; }
     mlir::Type getBaseType();
@@ -1686,7 +1686,7 @@ def fir_FieldIndexOp : fir_OneResultOp<"field_index", [NoSideEffect]> {
 
   let printer = [{
     p << getOperationName() << ' '
-      << (*this)->getAttrOfType<mlir::StringAttr>(fieldAttrName()).getValue() 
+      << (*this)->getAttrOfType<mlir::StringAttr>(fieldAttrName()).getValue()
       << ", " << (*this)->getAttr(typeAttrName());
     if (getNumOperands()) {
       p << '(';
@@ -2007,7 +2007,7 @@ def fir_IterWhileOp : region_Op<"iterate_while",
       CArg<"ValueRange", "llvm::None">:$iterArgs,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
   ];
-  
+
   let extraClassDeclaration = [{
     mlir::Block *getBody() { return &region().front(); }
     mlir::Value getIterateVar() { return getBody()->getArgument(1); }
@@ -2276,11 +2276,11 @@ def fir_ConstfOp : fir_Op<"constf", [NoSideEffect]> {
   }];
 
   let arguments = (ins FirRealAttr:$constant);
-  
+
   let results = (outs fir_RealType:$res);
 
   let assemblyFormat = "`(` $constant `)` attr-dict `:` type($res)";
-  
+
   let verifier = [{
     if (!getType().isa<fir::RealType>())
       return emitOpError("must be a !fir.real type");
@@ -2357,7 +2357,7 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> {
   }];
 
   let results = (outs fir_ComplexType);
-  
+
   let parser = [{
     fir::RealAttr realp;
     fir::RealAttr imagp;
@@ -2455,7 +2455,7 @@ def fir_CmpcOp : fir_Op<"cmpc",
 
 def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {
   let summary = "convert a symbol to an SSA value";
-  
+
   let description = [{
     Convert a symbol (a function or global reference) to an SSA-value to be
     used in other Operations.
@@ -2474,7 +2474,7 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {
 
 def fir_ConvertOp : fir_OneResultOp<"convert", [NoSideEffect]> {
   let summary = "encapsulates all Fortran scalar type conversions";
-  
+
   let description = [{
     Generalized type conversion. Convert the ssa value from type T to type U.
     Not all pairs of types have conversions. When types T and U are the same
@@ -2705,7 +2705,7 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
     mlir::Type resultType() {
       return fir::AllocaOp::wrapResultType(getType());
     }
-    
+
     /// Return the initializer attribute if it exists, or a null attribute.
     Attribute getValueOrNull() { return initVal().getValueOr(Attribute()); }
 
@@ -2728,9 +2728,9 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
     }
 
     mlir::FlatSymbolRefAttr getSymbol() {
-      return mlir::FlatSymbolRefAttr::get(
+      return mlir::FlatSymbolRefAttr::get(getContext(),
           (*this)->getAttrOfType<mlir::StringAttr>(
-              mlir::SymbolTable::getSymbolAttrName()).getValue(), getContext());
+              mlir::SymbolTable::getSymbolAttrName()).getValue());
     }
   }];
 }
@@ -2772,7 +2772,7 @@ def fir_GlobalLenOp : fir_Op<"global_len", []> {
   }];
 
   let printer = [{
-    p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName()) 
+    p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName())
       << ", " << (*this)->getAttr(intAttrName());
   }];
 
index 3f470d61c28668bae805c0df88427253024d8558..0a8473b73268bb23fa56a7d89fea8ced84a8d622 100644 (file)
@@ -173,7 +173,7 @@ mlir::Value Fortran::lower::FirOpBuilder::createConvert(mlir::Location loc,
 
 fir::StringLitOp Fortran::lower::FirOpBuilder::createStringLit(
     mlir::Location loc, mlir::Type eleTy, llvm::StringRef data) {
-  auto strAttr = mlir::StringAttr::get(data, getContext());
+  auto strAttr = mlir::StringAttr::get(getContext(), data);
   auto valTag = mlir::Identifier::get(fir::StringLitOp::value(), getContext());
   mlir::NamedAttribute dataAttr(valTag, strAttr);
   auto sizeTag = mlir::Identifier::get(fir::StringLitOp::size(), getContext());
index 3883ce2ed0c8b5e4ad73115ee223c05b9ef62810..8523a83711927a010015f18385c860961a676ef5 100644 (file)
@@ -107,7 +107,7 @@ private:
                                              ModuleOp module) {
     auto *context = module.getContext();
     if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
-      return SymbolRefAttr::get("printf", context);
+      return SymbolRefAttr::get(context, "printf");
 
     // Create a function declaration for printf, the signature is:
     //   * `i32 (i8*, ...)`
@@ -120,7 +120,7 @@ private:
     PatternRewriter::InsertionGuard insertGuard(rewriter);
     rewriter.setInsertionPointToStart(module.getBody());
     rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
-    return SymbolRefAttr::get("printf", context);
+    return SymbolRefAttr::get(context, "printf");
   }
 
   /// Return a value representing an access into a global string with the given
index 3883ce2ed0c8b5e4ad73115ee223c05b9ef62810..8523a83711927a010015f18385c860961a676ef5 100644 (file)
@@ -107,7 +107,7 @@ private:
                                              ModuleOp module) {
     auto *context = module.getContext();
     if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
-      return SymbolRefAttr::get("printf", context);
+      return SymbolRefAttr::get(context, "printf");
 
     // Create a function declaration for printf, the signature is:
     //   * `i32 (i8*, ...)`
@@ -120,7 +120,7 @@ private:
     PatternRewriter::InsertionGuard insertGuard(rewriter);
     rewriter.setInsertionPointToStart(module.getBody());
     rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
-    return SymbolRefAttr::get("printf", context);
+    return SymbolRefAttr::get(context, "printf");
   }
 
   /// Return a value representing an access into a global string with the given
index 794417e996521fdfc2a1cffa6e706ddf28731438..b903c0928d1b425f1023489d99dc615beb601b5e 100644 (file)
@@ -31,7 +31,7 @@ inline bool isRowMajorMatmul(ArrayAttr indexingMaps) {
   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
-  auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
+  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
   return indexingMaps == maps;
 }
 
@@ -42,7 +42,7 @@ inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) {
   auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
   auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
   auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
-  auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
+  auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
   return indexingMaps == maps;
 }
 
index 34e7e8cfce12cb6fefb042eab3c762cf9eedab84..571c9126f1631c05a314c40f427e9fa45d711a40 100644 (file)
@@ -69,7 +69,7 @@ public:
   using Base::Base;
   using ValueType = ArrayRef<Attribute>;
 
-  static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
+  static ArrayAttr get(MLIRContext *context, ArrayRef<Attribute> value);
 
   ArrayRef<Attribute> getValue() const;
   Attribute operator[](unsigned idx) const;
@@ -126,8 +126,8 @@ public:
   /// attributes. This method assumes that the provided list is unordered. If
   /// the caller can guarantee that the attributes are ordered by name,
   /// getWithSorted should be used instead.
-  static DictionaryAttr get(ArrayRef<NamedAttribute> value,
-                            MLIRContext *context);
+  static DictionaryAttr get(MLIRContext *context,
+                            ArrayRef<NamedAttribute> value);
 
   /// Construct a dictionary with an array of values that is known to already be
   /// sorted by name and uniqued.
@@ -250,7 +250,7 @@ public:
   using Attribute::Attribute;
   using ValueType = bool;
 
-  static BoolAttr get(bool value, MLIRContext *context);
+  static BoolAttr get(MLIRContext *context, bool value);
 
   /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to
   /// avoid bringing in all of IntegerAttrs methods.
@@ -292,8 +292,8 @@ public:
   using Base::Base;
 
   /// Get or create a new OpaqueAttr with the provided dialect and string data.
-  static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type,
-                        MLIRContext *context);
+  static OpaqueAttr get(MLIRContext *context, Identifier dialect,
+                        StringRef attrData, Type type);
 
   /// Get or create a new OpaqueAttr with the provided dialect and string data.
   /// If the given identifier is not a valid namespace for a dialect, then a
@@ -325,7 +325,7 @@ public:
   using ValueType = StringRef;
 
   /// Get an instance of a StringAttr with the given string.
-  static StringAttr get(StringRef bytes, MLIRContext *context);
+  static StringAttr get(MLIRContext *context, StringRef bytes);
 
   /// Get an instance of a StringAttr with the given string and Type.
   static StringAttr get(StringRef bytes, Type type);
@@ -348,13 +348,12 @@ public:
   using Base::Base;
 
   /// Construct a symbol reference for the given value name.
-  static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx);
+  static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value);
 
   /// Construct a symbol reference for the given value name, and a set of nested
   /// references that are further resolve to a nested symbol.
-  static SymbolRefAttr get(StringRef value,
-                           ArrayRef<FlatSymbolRefAttr> references,
-                           MLIRContext *ctx);
+  static SymbolRefAttr get(MLIRContext *ctx, StringRef value,
+                           ArrayRef<FlatSymbolRefAttr> references);
 
   /// Returns the name of the top level symbol reference, i.e. the root of the
   /// reference path.
@@ -377,8 +376,8 @@ public:
   using ValueType = StringRef;
 
   /// Construct a symbol reference for the given value name.
-  static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) {
-    return SymbolRefAttr::get(value, ctx);
+  static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) {
+    return SymbolRefAttr::get(ctx, value);
   }
 
   /// Returns the name of the held symbol reference.
index be8a68979203073c9a8af7864b529981642271d4..c2eec87272407c4ef72a4731f595654e4eddbf70 100644 (file)
@@ -569,7 +569,7 @@ void FunctionLike<ConcreteType>::setArgAttrs(
   if (attributes.empty())
     return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
   Operation *op = this->getOperation();
-  op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
+  op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes));
 }
 
 template <typename ConcreteType>
@@ -646,7 +646,7 @@ void FunctionLike<ConcreteType>::setResultAttrs(
   if (attributes.empty())
     return (void)this->getOperation()->removeAttr(nameOut);
   Operation *op = this->getOperation();
-  op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
+  op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes));
 }
 
 template <typename ConcreteType>
index 45b9c490fd212904395fe32a8c93cf2c5df18608..70cd55dbbb138895af8d8a14c5ebdf7c87174d1b 100644 (file)
@@ -315,7 +315,7 @@ public:
     attrs = newAttrs;
   }
   void setAttrs(ArrayRef<NamedAttribute> newAttrs) {
-    setAttrs(DictionaryAttr::get(newAttrs, getContext()));
+    setAttrs(DictionaryAttr::get(getContext(), newAttrs));
   }
 
   /// Return the specified attribute if present, null otherwise.
index c5f252e45a2018b4f79fabff5a11c69bc76003fe..a7b1fd8cfe6433529f062266f66565d0804b129b 100644 (file)
@@ -44,7 +44,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
       /*defaultImplementation=*/[{
         this->getOperation()->setAttr(
             mlir::SymbolTable::getSymbolAttrName(),
-            StringAttr::get(name, this->getOperation()->getContext()));
+            StringAttr::get(this->getOperation()->getContext(), name));
       }]
     >,
     InterfaceMethod<"Gets the visibility of this symbol.",
index 90ed9cb0ad02d1be62842f48c7b5299380a333cd..9e61e3a9d6e05ac9974381a0eb56d72c20731298 100644 (file)
@@ -42,9 +42,9 @@ bool mlirAttributeIsAArray(MlirAttribute attr) {
 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
                                MlirAttribute const *elements) {
   SmallVector<Attribute, 8> attrs;
-  return wrap(ArrayAttr::get(
-      unwrapList(static_cast<size_t>(numElements), elements, attrs),
-      unwrap(ctx)));
+  return wrap(
+      ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
+                                             elements, attrs)));
 }
 
 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
@@ -71,7 +71,7 @@ MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
     attributes.emplace_back(
         Identifier::get(unwrap(elements[i].name), unwrap(ctx)),
         unwrap(elements[i].attribute));
-  return wrap(DictionaryAttr::get(attributes, unwrap(ctx)));
+  return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
 }
 
 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
@@ -137,7 +137,7 @@ bool mlirAttributeIsABool(MlirAttribute attr) {
 }
 
 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
-  return wrap(BoolAttr::get(value, unwrap(ctx)));
+  return wrap(BoolAttr::get(unwrap(ctx), value));
 }
 
 bool mlirBoolAttrGetValue(MlirAttribute attr) {
@@ -163,9 +163,9 @@ bool mlirAttributeIsAOpaque(MlirAttribute attr) {
 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
                                 intptr_t dataLength, const char *data,
                                 MlirType type) {
-  return wrap(
-      OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
-                      StringRef(data, dataLength), unwrap(type), unwrap(ctx)));
+  return wrap(OpaqueAttr::get(
+      unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
+      StringRef(data, dataLength), unwrap(type)));
 }
 
 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
@@ -185,7 +185,7 @@ bool mlirAttributeIsAString(MlirAttribute attr) {
 }
 
 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
-  return wrap(StringAttr::get(unwrap(str), unwrap(ctx)));
+  return wrap(StringAttr::get(unwrap(ctx), unwrap(str)));
 }
 
 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
@@ -211,7 +211,7 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
   refs.reserve(numReferences);
   for (intptr_t i = 0; i < numReferences; ++i)
     refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
-  return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx)));
+  return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs));
 }
 
 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
@@ -241,7 +241,7 @@ bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
 }
 
 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
-  return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx)));
+  return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol)));
 }
 
 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
index 447b00567776d5199ab554a76a6768d8564fbc85..1b9e36180114604cf2e30de021027857f3aba3e2 100644 (file)
@@ -148,7 +148,7 @@ StringAttr GpuKernelToBlobPass::translateGPUModuleToBinaryAnnotation(
   auto blob = convertModuleToBlob(llvmModule, loc, name);
   if (!blob)
     return {};
-  return StringAttr::get({blob->data(), blob->size()}, loc->getContext());
+  return StringAttr::get(loc->getContext(), {blob->data(), blob->size()});
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
index 887d3e798af74629dc56f50c73e7e4c91ea0822b..5b62ca455dea305b67c8399759d60856a40e6f1b 100644 (file)
@@ -177,12 +177,12 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
   // Set SPIR-V binary shader data as an attribute.
   vulkanLaunchCallOp->setAttr(
       kSPIRVBlobAttrName,
-      StringAttr::get({binary.data(), binary.size()}, loc->getContext()));
+      StringAttr::get(loc->getContext(), {binary.data(), binary.size()}));
 
   // Set entry point name as an attribute.
   vulkanLaunchCallOp->setAttr(
       kSPIRVEntryPointAttrName,
-      StringAttr::get(launchOp.getKernelName(), loc->getContext()));
+      StringAttr::get(loc->getContext(), launchOp.getKernelName()));
 
   launchOp.erase();
 }
index 87026e4483e6eaf35fb3e91e59e57b33fea98c83..29cf42205a563408b8ea18b8a66d9fa29dd5d4c8 100644 (file)
@@ -687,8 +687,8 @@ public:
         rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
     structValue = rewriter.create<LLVM::InsertValueOp>(
         loc, structType, structValue, executionMode,
-        ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)},
-                       context));
+        ArrayAttr::get(context,
+                       {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}));
 
     // Insert extra operands if they exist into execution mode info struct.
     for (unsigned i = 0, e = values.size(); i < e; ++i) {
@@ -696,9 +696,9 @@ public:
       Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
       structValue = rewriter.create<LLVM::InsertValueOp>(
           loc, structType, structValue, entry,
-          ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
-                          rewriter.getIntegerAttr(rewriter.getI32Type(), i)},
-                         context));
+          ArrayAttr::get(context,
+                         {rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
+                          rewriter.getIntegerAttr(rewriter.getI32Type(), i)}));
     }
     rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
     rewriter.eraseOp(op);
@@ -1297,17 +1297,17 @@ public:
     switch (funcOp.function_control()) {
 #define DISPATCH(functionControl, llvmAttr)                                    \
   case functionControl:                                                        \
-    newFuncOp->setAttr("passthrough", ArrayAttr::get({llvmAttr}, context));    \
+    newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr}));    \
     break;
 
       DISPATCH(spirv::FunctionControl::Inline,
-               StringAttr::get("alwaysinline", context));
+               StringAttr::get(context, "alwaysinline"));
       DISPATCH(spirv::FunctionControl::DontInline,
-               StringAttr::get("noinline", context));
+               StringAttr::get(context, "noinline"));
       DISPATCH(spirv::FunctionControl::Pure,
-               StringAttr::get("readonly", context));
+               StringAttr::get(context, "readonly"));
       DISPATCH(spirv::FunctionControl::Const,
-               StringAttr::get("readnone", context));
+               StringAttr::get(context, "readnone"));
 
 #undef DISPATCH
 
index 794f4a5d6c1e1126a1fb103b12386b9ce2ea5648..ea0a4259637cb2036818d74a7f480ac482d863dd 100644 (file)
@@ -4016,7 +4016,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
     if (failed(applyPartialConversion(m, target, std::move(patterns))))
       signalPassFailure();
     m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
-               StringAttr::get(this->dataLayout, m.getContext()));
+               StringAttr::get(m.getContext(), this->dataLayout));
   }
 };
 } // end namespace
index 9e88250e2cab14f921f8e219191c1384530ce2ef..683de815a54e468858ee694fce32fc481a1704f9 100644 (file)
@@ -762,7 +762,7 @@ public:
     if (positionAttrs.size() > 1) {
       auto oneDVectorType = reducedVectorTypeBack(vectorType);
       auto nMinusOnePositionAttrs =
-          ArrayAttr::get(positionAttrs.drop_back(), context);
+          ArrayAttr::get(context, positionAttrs.drop_back());
       extracted = rewriter.create<LLVM::ExtractValueOp>(
           loc, typeConverter->convertType(oneDVectorType), extracted,
           nMinusOnePositionAttrs);
@@ -871,7 +871,7 @@ public:
     if (positionAttrs.size() > 1) {
       oneDVectorType = reducedVectorTypeBack(destVectorType);
       auto nMinusOnePositionAttrs =
-          ArrayAttr::get(positionAttrs.drop_back(), context);
+          ArrayAttr::get(context, positionAttrs.drop_back());
       extracted = rewriter.create<LLVM::ExtractValueOp>(
           loc, typeConverter->convertType(oneDVectorType), extracted,
           nMinusOnePositionAttrs);
@@ -887,7 +887,7 @@ public:
     // Potential insertion of resulting 1-D vector into array.
     if (positionAttrs.size() > 1) {
       auto nMinusOnePositionAttrs =
-          ArrayAttr::get(positionAttrs.drop_back(), context);
+          ArrayAttr::get(context, positionAttrs.drop_back());
       inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
                                                       adaptor.dest(), inserted,
                                                       nMinusOnePositionAttrs);
index c1d0820e1cc72524d7061ae28da09a9765b564ab..6ccb59aff35a3f11fe9d79149f8d948416491434 100644 (file)
@@ -53,7 +53,7 @@ LogicalResult setMappingAttr(scf::ParallelOp ploopOp,
   }
   ArrayRef<Attribute> mappingAsAttrs(mapping.data(), mapping.size());
   ploopOp->setAttr(getMappingAttrName(),
-                   ArrayAttr::get(mappingAsAttrs, ploopOp.getContext()));
+                   ArrayAttr::get(ploopOp.getContext(), mappingAsAttrs));
   return success();
 }
 } // namespace gpu
index a3960ae94b2706e746d2162cad4149c746d8648a..e96668779401865ef4d398c715af0bfeb91c92e7 100644 (file)
@@ -225,7 +225,7 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
     if (genericAttrNamesSet.count(attr.first.strref()) > 0)
       genericAttrs.push_back(attr);
   if (!genericAttrs.empty()) {
-    auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext());
+    auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs);
     p << genericDictAttr;
   }
 
@@ -833,7 +833,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
   // Handle the corner case of the result being a rank 0 shaped type. Return an
   // emtpy ArrayAttr.
   if (mapsConsumer.empty() && !mapsProducer.empty())
-    return ArrayAttr::get(ArrayRef<Attribute>(), context);
+    return ArrayAttr::get(context, ArrayRef<Attribute>());
   if (mapsProducer.empty() || mapsConsumer.empty() ||
       mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
       mapsProducer.size() != mapsConsumer[0].getNumDims())
@@ -854,7 +854,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
         numLhsDims, /*numSymbols =*/0, reassociations, context)));
     reassociations.clear();
   }
-  return ArrayAttr::get(reassociationMaps, context);
+  return ArrayAttr::get(context, reassociationMaps);
 }
 
 namespace {
index 8db4824cbbd2cc5dc4287c9ebb820fa795255025..c7b76404b2f88ecaa84b1ccf47d055d2ddd70ade 100644 (file)
@@ -137,11 +137,11 @@ static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
   // wrong, so abort.
   if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
     return nullptr;
-  return ArrayAttr::get(
-      llvm::to_vector<4>(llvm::map_range(
-          newIndexingMaps,
-          [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })),
-      context);
+  return ArrayAttr::get(context,
+                        llvm::to_vector<4>(llvm::map_range(
+                            newIndexingMaps, [](AffineMap map) -> Attribute {
+                              return AffineMapAttr::get(map);
+                            })));
 }
 
 /// Modify the region of indexed generic op to drop arguments corresponding to
@@ -220,7 +220,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
 
     rewriter.startRootUpdate(op);
     op.indexing_mapsAttr(newIndexingMapAttr);
-    op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
+    op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
     (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
     rewriter.finalizeRootUpdate(op);
     return success();
@@ -282,7 +282,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
       RankedTensorType::get(newShape, type.getElementType()),
       AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
                      newIndexExprs, context),
-      ArrayAttr::get(reassociationMaps, context)};
+      ArrayAttr::get(context, reassociationMaps)};
   return info;
 }
 
index cac0ae0d081c44602432fa52103bbb2fd1743213..b893f2ba672115c2d2352d3ab54143edf5ff5505 100644 (file)
@@ -77,9 +77,9 @@ LinalgOp mlir::linalg::interchange(LinalgOp op,
   applyPermutationToVector(itTypesVector, interchangeVector);
 
   op->setAttr(getIndexingMapsAttrName(),
-              ArrayAttr::get(newIndexingMaps, context));
+              ArrayAttr::get(context, newIndexingMaps));
   op->setAttr(getIteratorTypesAttrName(),
-              ArrayAttr::get(itTypesVector, context));
+              ArrayAttr::get(context, itTypesVector));
 
   return op;
 }
index 9b62b4289c7732ee753f5301ad96f8e02eca73c9..4ce29b4a83972c4f55406033c7fb166cab3e478e 100644 (file)
@@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
   });
   for (auto &var : interfaceVarSet) {
     interfaceVars.push_back(SymbolRefAttr::get(
-        cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
+        funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name()));
   }
   return success();
 }
index 0902b297ddd35f9117c311a48070333554ac688f..65ebc54aeeb367051a945c01bf1b2c534e553063 100644 (file)
@@ -338,7 +338,7 @@ OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
       return a;
   }
   // If this is reached, all inputs were statically known passing.
-  return BoolAttr::get(true, getContext());
+  return BoolAttr::get(getContext(), true);
 }
 
 static LogicalResult verify(AssumingAllOp op) {
@@ -482,10 +482,10 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
   // Both operands are not needed if one is a scalar.
   if (operands[0] &&
       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
-    return BoolAttr::get(true, getContext());
+    return BoolAttr::get(getContext(), true);
   if (operands[1] &&
       operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
-    return BoolAttr::get(true, getContext());
+    return BoolAttr::get(getContext(), true);
 
   if (operands[0] && operands[1]) {
     auto lhsShape = llvm::to_vector<6>(
@@ -494,7 +494,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
         operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
     SmallVector<int64_t, 6> resultShape;
     if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
-      return BoolAttr::get(true, getContext());
+      return BoolAttr::get(getContext(), true);
   }
 
   // Lastly, see if folding can be completed based on what constraints are known
@@ -506,7 +506,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
     return nullptr;
 
   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
-    return BoolAttr::get(true, getContext());
+    return BoolAttr::get(getContext(), true);
 
   // Because a failing witness result here represents an eventual assertion
   // failure, we do not replace it with a constant witness.
@@ -526,7 +526,7 @@ void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
 OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
   if (llvm::all_of(operands,
                    [&](Attribute a) { return a && a == operands[0]; }))
-    return BoolAttr::get(true, getContext());
+    return BoolAttr::get(getContext(), true);
 
   // Because a failing witness result here represents an eventual assertion
   // failure, we do not try to replace it with a constant witness. Similarly, we
@@ -573,14 +573,14 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
 
 OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
   if (lhs() == rhs())
-    return BoolAttr::get(true, getContext());
+    return BoolAttr::get(getContext(), true);
   auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
   if (lhs == nullptr)
     return {};
   auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
   if (rhs == nullptr)
     return {};
-  return BoolAttr::get(lhs == rhs, getContext());
+  return BoolAttr::get(getContext(), lhs == rhs);
 }
 
 //===----------------------------------------------------------------------===//
index c085c1cd33a719df67736396a96ad0f6ce1abd2b..ca2e2731df0398161763a429e5a8133387c8fea7 100644 (file)
@@ -844,7 +844,7 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
 
   if (lhs() == rhs()) {
     auto val = applyCmpPredicateToEqualOperands(getPredicate());
-    return BoolAttr::get(val, getContext());
+    return BoolAttr::get(getContext(), val);
   }
 
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
@@ -853,7 +853,7 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
     return {};
 
   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
-  return BoolAttr::get(val, getContext());
+  return BoolAttr::get(getContext(), val);
 }
 
 //===----------------------------------------------------------------------===//
index f20b713e8e77adc2eca0562f1c5334bbe2d0d838..9fe8cf23c162089e3340a3ea320a07411a00c17a 100644 (file)
@@ -247,7 +247,7 @@ static void print(OpAsmPrinter &p, ContractionOp op) {
     if (traitAttrsSet.count(attr.first.strref()) > 0)
       attrs.push_back(attr);
 
-  auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
+  auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);
   p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
   p << op.rhs() << ", " << op.acc();
   if (op.masks().size() == 2)
@@ -1445,7 +1445,7 @@ static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
   auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
     return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
   });
-  return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
+  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
 }
 
 static LogicalResult verify(InsertStridedSliceOp op) {
index 8a5206eb0b1c8c9d75bdf5dbabbfd63a28cf21ed..bafeccbd53ea8c98d788d12fd7595d7816d63e68 100644 (file)
@@ -92,11 +92,11 @@ NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
 
 BoolAttr Builder::getBoolAttr(bool value) {
-  return BoolAttr::get(value, context);
+  return BoolAttr::get(context, value);
 }
 
 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
-  return DictionaryAttr::get(value, context);
+  return DictionaryAttr::get(context, value);
 }
 
 IntegerAttr Builder::getIndexAttr(int64_t value) {
@@ -200,11 +200,11 @@ FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
 }
 
 StringAttr Builder::getStringAttr(StringRef bytes) {
-  return StringAttr::get(bytes, context);
+  return StringAttr::get(context, bytes);
 }
 
 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
-  return ArrayAttr::get(value, context);
+  return ArrayAttr::get(context, value);
 }
 
 FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
@@ -214,12 +214,12 @@ FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
   return getSymbolRefAttr(symName.getValue());
 }
 FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
-  return SymbolRefAttr::get(value, getContext());
+  return SymbolRefAttr::get(getContext(), value);
 }
 SymbolRefAttr
 Builder::getSymbolRefAttr(StringRef value,
                           ArrayRef<FlatSymbolRefAttr> nestedReferences) {
-  return SymbolRefAttr::get(value, nestedReferences, getContext());
+  return SymbolRefAttr::get(getContext(), value, nestedReferences);
 }
 
 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
index 162bed96e3f4f2aa7e333848be6aa2be0c482dc6..58a5b3370364fd4f63f10a150e8b505b876f504a 100644 (file)
@@ -35,7 +35,7 @@ AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
 // ArrayAttr
 //===----------------------------------------------------------------------===//
 
-ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
+ArrayAttr ArrayAttr::get(MLIRContext *context, ArrayRef<Attribute> value) {
   return Base::get(context, value);
 }
 
@@ -134,8 +134,8 @@ DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
   return findDuplicateElement(array);
 }
 
-DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
-                                   MLIRContext *context) {
+DictionaryAttr DictionaryAttr::get(MLIRContext *context,
+                                   ArrayRef<NamedAttribute> value) {
   if (value.empty())
     return DictionaryAttr::getEmpty(context);
   assert(llvm::all_of(value,
@@ -267,13 +267,12 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
 // SymbolRefAttr
 //===----------------------------------------------------------------------===//
 
-FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
+FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
   return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
 }
 
-SymbolRefAttr SymbolRefAttr::get(StringRef value,
-                                 ArrayRef<FlatSymbolRefAttr> nestedReferences,
-                                 MLIRContext *ctx) {
+SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
+                                 ArrayRef<FlatSymbolRefAttr> nestedReferences) {
   return Base::get(ctx, value, nestedReferences);
 }
 
@@ -294,7 +293,7 @@ ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
 
 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
   if (type.isSignlessInteger(1))
-    return BoolAttr::get(value.getBoolValue(), type.getContext());
+    return BoolAttr::get(type.getContext(), value.getBoolValue());
   return Base::get(type.getContext(), type, value);
 }
 
@@ -377,8 +376,8 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
 // OpaqueAttr
 //===----------------------------------------------------------------------===//
 
-OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
-                           MLIRContext *context) {
+OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect,
+                           StringRef attrData, Type type) {
   return Base::get(context, dialect, attrData, type);
 }
 
@@ -409,7 +408,7 @@ LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
 // StringAttr
 //===----------------------------------------------------------------------===//
 
-StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
+StringAttr StringAttr::get(MLIRContext *context, StringRef bytes) {
   return get(bytes, NoneType::get(context));
 }
 
index 469aa310140c75fb15e02f388ad46946df686463..db383c691c7cdc5fdde8c0ad849d59600fe6dbc4 100644 (file)
@@ -166,7 +166,7 @@ void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
     newAttrs.insert(attr);
   for (auto &attr : getAttrs())
     newAttrs.insert(attr);
-  dest->setAttrs(DictionaryAttr::get(newAttrs.takeVector(), getContext()));
+  dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector()));
 
   // Clone the body.
   getBody().cloneInto(&dest.getBody(), mapper);
index dbfa1bdf6f7e41835496ad7df97e95b198298579..8d13a9c4af329f3139e17afc0dc5a366ff186a03 100644 (file)
@@ -872,7 +872,7 @@ void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
     storage->setType(NoneType::get(ctx));
 }
 
-BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
+BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
   return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
 }
 
index b4fe9f854dda792b849f28c2d2f74daa98d6913c..be312689cebb57c182472b4af47f4d53370d0857 100644 (file)
@@ -76,7 +76,7 @@ Operation *Operation::create(Location location, OperationName name,
                              ArrayRef<NamedAttribute> attributes,
                              BlockRange successors, unsigned numRegions) {
   return create(location, name, resultTypes, operands,
-                DictionaryAttr::get(attributes, location.getContext()),
+                DictionaryAttr::get(location.getContext(), attributes),
                 successors, numRegions);
 }
 
index b198600e92422a07d4249c091118c76367049032..70133d22482ffd51f8f4310e419442f3c9136a89 100644 (file)
@@ -46,7 +46,7 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
   assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
   MLIRContext *ctx = symbol->getContext();
 
-  auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx);
+  auto leafRef = FlatSymbolRefAttr::get(ctx, symbolName);
   results.push_back(leafRef);
 
   // Early exit for when 'within' is the parent of 'symbol'.
@@ -67,13 +67,13 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
         getNameIfSymbol(symbolTableOp, symbolNameId);
     if (!symbolTableName)
       return failure();
-    results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx));
+    results.push_back(SymbolRefAttr::get(ctx, *symbolTableName, nestedRefs));
 
     symbolTableOp = symbolTableOp->getParentOp();
     if (symbolTableOp == within)
       break;
     nestedRefs.insert(nestedRefs.begin(),
-                      FlatSymbolRefAttr::get(*symbolTableName, ctx));
+                      FlatSymbolRefAttr::get(ctx, *symbolTableName));
   } while (true);
   return success();
 }
@@ -203,7 +203,7 @@ StringRef SymbolTable::getSymbolName(Operation *symbol) {
 /// Sets the name of the given symbol operation.
 void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
   symbol->setAttr(getSymbolAttrName(),
-                  StringAttr::get(name, symbol->getContext()));
+                  StringAttr::get(symbol->getContext(), name));
 }
 
 /// Returns the visibility of the given symbol operation.
@@ -235,7 +235,7 @@ void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
          "unknown symbol visibility kind");
 
   StringRef visName = vis == Visibility::Private ? "private" : "nested";
-  symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx));
+  symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
 }
 
 /// Returns the nearest symbol table from a given operation `from`. Returns
@@ -603,7 +603,7 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
       // doesn't support parent references.
       if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
           symbol->getParentOp())
-        return {{SymbolRefAttr::get(symName, symbol->getContext()), limit}};
+        return {{SymbolRefAttr::get(symbol->getContext(), symName), limit}};
       return {};
     }
 
@@ -659,7 +659,7 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
 template <typename IRUnit>
 static SmallVector<SymbolScope, 1> collectSymbolScopes(StringRef symbol,
                                                        IRUnit *limit) {
-  return {{SymbolRefAttr::get(symbol, limit->getContext()), limit}};
+  return {{SymbolRefAttr::get(limit->getContext(), symbol), limit}};
 }
 
 /// Returns true if the given reference 'SubRef' is a sub reference of the
@@ -825,11 +825,11 @@ static Attribute rebuildAttrAfterRAUW(
   if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) {
     auto newAttrs = llvm::to_vector<4>(dictAttr.getValue());
     updateAttrs(make_second_range(newAttrs));
-    return DictionaryAttr::get(newAttrs, dictAttr.getContext());
+    return DictionaryAttr::get(dictAttr.getContext(), newAttrs);
   }
   auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue());
   updateAttrs(newAttrs);
-  return ArrayAttr::get(newAttrs, container.getContext());
+  return ArrayAttr::get(container.getContext(), newAttrs);
 }
 
 /// Generates a new symbol reference attribute with a new leaf reference.
@@ -839,8 +839,8 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
     return newLeafAttr;
   auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
   nestedRefs.back() = newLeafAttr;
-  return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs,
-                            oldAttr.getContext());
+  return SymbolRefAttr::get(oldAttr.getContext(), oldAttr.getRootReference(),
+                            nestedRefs);
 }
 
 /// The implementation of SymbolTable::replaceAllSymbolUses below.
@@ -867,7 +867,7 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
 
   // Generate a new attribute to replace the given attribute.
   MLIRContext *ctx = limit->getContext();
-  FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx);
+  FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(ctx, newSymbol);
   for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
     SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
     auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
@@ -883,13 +883,13 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
       if (useRef != scope.symbol) {
         if (scope.symbol.isa<FlatSymbolRefAttr>()) {
           replacementRef =
-              SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx);
+              SymbolRefAttr::get(ctx, newSymbol, useRef.getNestedReferences());
         } else {
           auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
           nestedRefs[scope.symbol.getNestedReferences().size() - 1] =
               newLeafAttr;
           replacementRef =
-              SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx);
+              SymbolRefAttr::get(ctx, useRef.getRootReference(), nestedRefs);
         }
       }
 
index 859e8e279917a9cd17a556506a42b932a64d0f70..98f74174e5a3ef81656917519d84c3e34d536066 100644 (file)
@@ -148,7 +148,7 @@ Attribute Parser::parseAttribute(Type type) {
       return Attribute();
 
     return type ? StringAttr::get(val, type)
-                : StringAttr::get(val, getContext());
+                : StringAttr::get(getContext(), val);
   }
 
   // Parse a symbol reference attribute.
@@ -176,7 +176,7 @@ Attribute Parser::parseAttribute(Type type) {
 
       std::string nameStr = getToken().getSymbolReference();
       consumeToken(Token::at_identifier);
-      nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
+      nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
     }
 
     return builder.getSymbolRefAttr(nameStr, nestedRefs);
index 2f0b3379d152bdabe5b71e9b4fcf1970b9d56251..52ce37eb79ab65e833660389e04b98a0e14a1107 100644 (file)
@@ -742,7 +742,8 @@ void OpEmitter::genAttrGetters() {
 
       body << "  ::mlir::MLIRContext* ctx = getContext();\n";
       body << "  ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
-      body << "  return ::mlir::DictionaryAttr::get({\n";
+      body << "  return ::mlir::DictionaryAttr::get(";
+      body << "  ctx, {\n";
       interleave(
           derivedAttrs, body,
           [&](const NamedAttribute &namedAttr) {
@@ -755,7 +756,7 @@ void OpEmitter::genAttrGetters() {
                  << "}";
           },
           ",\n");
-      body << "\n    }, ctx);";
+      body << "});";
     }
   }
 }
index 5595986e4016b280cf72caa93837ce122461e409..52f52238701795861d0138307365a1eabd5c877b 100644 (file)
@@ -150,7 +150,7 @@ static void emitFactoryDef(llvm::StringRef structName,
   }
 
   const char *getEndInfo = R"(
-  ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(fields, context);
+  ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(context, fields);
   return dict.dyn_cast<{0}>();
 }
 )";
index 0dd9ef9de3e678def820c830c474726541a9c5de..ef0bdd81ee3a5d3da88e0ec0c34d495edcc90c31 100644 (file)
@@ -67,7 +67,7 @@ TEST(StructsGenTest, ClassofExtraFalse) {
   newValues.push_back(wrongAttr);
 
   // Make a new DictionaryAttr and validate.
-  auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+  auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
 }
 
@@ -88,7 +88,7 @@ TEST(StructsGenTest, ClassofBadNameFalse) {
   auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
   newValues.push_back(wrongAttr);
 
-  auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+  auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
 }
 
@@ -113,7 +113,7 @@ TEST(StructsGenTest, ClassofBadTypeFalse) {
   auto wrongAttr = mlir::NamedAttribute(id, elementsAttr);
   newValues.push_back(wrongAttr);
 
-  auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+  auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
 }
 
@@ -130,7 +130,7 @@ TEST(StructsGenTest, ClassofMissingFalse) {
       expectedValues.begin() + 1, expectedValues.end());
 
   // Make a new DictionaryAttr and validate it is not a validate TestStruct.
-  auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+  auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
 }