From c2c83e97c3ac98ddf5bd685cbfba3f620f59fa51 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Mon, 8 Feb 2021 09:44:03 +0100 Subject: [PATCH] Revert "Revert "Reorder MLIRContext location in BuiltinAttributes.h"" This reverts commit 511dd4f4383b1c2873beac4dbea2df302f1f9d0c along with a couple fixes. Original message: Now the context is the first, rather than the last input. This better matches the rest of the infrastructure and makes it easier to move these types to being declaratively specified. Phabricator: https://reviews.llvm.org/D96111 --- .../llvm-prettyprinters/gdb/mlir-support.cpp | 4 +- .../include/flang/Optimizer/Dialect/FIROps.td | 40 +++++++++---------- flang/lib/Lower/FIRBuilder.cpp | 2 +- mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp | 4 +- mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp | 4 +- .../mlir/Dialect/Utils/StructuredOpsUtils.h | 4 +- mlir/include/mlir/IR/BuiltinAttributes.h | 25 ++++++------ mlir/include/mlir/IR/FunctionSupport.h | 4 +- mlir/include/mlir/IR/Operation.h | 2 +- mlir/include/mlir/IR/SymbolInterfaces.td | 2 +- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 22 +++++----- .../GPUCommon/ConvertKernelFuncToBlob.cpp | 2 +- ...ConvertGPULaunchFuncToVulkanLaunchFunc.cpp | 4 +- .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 20 +++++----- .../StandardToLLVM/StandardToLLVM.cpp | 2 +- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 6 +-- .../GPU/Transforms/ParallelLoopMapper.cpp | 2 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 6 +-- .../Linalg/Transforms/DropUnitDims.cpp | 14 +++---- .../Dialect/Linalg/Transforms/Interchange.cpp | 4 +- .../Transforms/LowerABIAttributesPass.cpp | 2 +- mlir/lib/Dialect/Shape/IR/Shape.cpp | 16 ++++---- mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 4 +- mlir/lib/Dialect/Vector/VectorOps.cpp | 4 +- mlir/lib/IR/Builders.cpp | 12 +++--- mlir/lib/IR/BuiltinAttributes.cpp | 21 +++++----- mlir/lib/IR/BuiltinDialect.cpp | 2 +- mlir/lib/IR/MLIRContext.cpp | 2 +- mlir/lib/IR/Operation.cpp | 2 +- mlir/lib/IR/SymbolTable.cpp | 28 ++++++------- mlir/lib/Parser/AttributeParser.cpp | 4 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 5 ++- mlir/tools/mlir-tblgen/StructsGen.cpp | 2 +- mlir/unittests/TableGen/StructsGenTest.cpp | 8 ++-- 34 files changed, 142 insertions(+), 143 deletions(-) diff --git a/debuginfo-tests/llvm-prettyprinters/gdb/mlir-support.cpp b/debuginfo-tests/llvm-prettyprinters/gdb/mlir-support.cpp index 629ef1668c8f..2633e4b19ebc 100644 --- a/debuginfo-tests/llvm-prettyprinters/gdb/mlir-support.cpp +++ b/debuginfo-tests/llvm-prettyprinters/gdb/mlir-support.cpp @@ -34,8 +34,8 @@ mlir::Attribute UnitAttr = mlir::UnitAttr::get(&Context); mlir::Attribute FloatAttr = mlir::FloatAttr::get(FloatType, 1.0); mlir::Attribute IntegerAttr = mlir::IntegerAttr::get(IntegerType, 10); mlir::Attribute TypeAttr = mlir::TypeAttr::get(IndexType); -mlir::Attribute ArrayAttr = mlir::ArrayAttr::get({UnitAttr}, &Context); -mlir::Attribute StringAttr = mlir::StringAttr::get("foo", &Context); +mlir::Attribute ArrayAttr = mlir::ArrayAttr::get(&Context, {UnitAttr}); +mlir::Attribute StringAttr = mlir::StringAttr::get(&Context, "foo"); mlir::Attribute ElementsAttr = mlir::DenseElementsAttr::get( VectorType.cast(), llvm::ArrayRef{2.0f, 3.0f}); diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 8f3670b29d74..cde53725b4a4 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -267,27 +267,27 @@ class fir_AllocatableOp traits = []> : static constexpr llvm::StringRef inType() { return "in_type"; } static constexpr llvm::StringRef lenpName() { return "len_param_count"; } mlir::Type getAllocatedType(); - + bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; } - + unsigned numLenParams() { if (auto val = (*this)->getAttrOfType(lenpName())) return val.getInt(); return 0; } - + operand_range getLenParams() { return {operand_begin(), operand_begin() + numLenParams()}; } - + unsigned numShapeOperands() { return operand_end() - operand_begin() + numLenParams(); } - + operand_range getShapeOperands() { return {operand_begin() + numLenParams(), operand_end()}; } - + static mlir::Type getRefTy(mlir::Type ty); /// Get the input type of the allocation @@ -1131,7 +1131,7 @@ def fir_EmboxCharOp : fir_Op<"emboxchar", [NoSideEffect]> { }]; let arguments = (ins AnyReferenceLike:$memref, AnyIntegerLike:$len); - + let results = (outs fir_BoxCharType); let assemblyFormat = [{ @@ -1563,7 +1563,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> { p.printFunctionalType((*this)->getOperandTypes(), (*this)->getResultTypes()); }]; - + let verifier = [{ auto refTy = ref().getType(); if (fir::isa_ref_type(refTy)) { @@ -1598,7 +1598,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> { CArg<"ArrayRef", "{}">:$attrs)>, OpBuilderDAG<(ins "Type":$type, "ValueRange":$operands, CArg<"ArrayRef", "{}">:$attrs)>]; - + let extraClassDeclaration = [{ static constexpr llvm::StringRef baseType() { return "base_type"; } mlir::Type getBaseType(); @@ -1686,7 +1686,7 @@ def fir_FieldIndexOp : fir_OneResultOp<"field_index", [NoSideEffect]> { let printer = [{ p << getOperationName() << ' ' - << (*this)->getAttrOfType(fieldAttrName()).getValue() + << (*this)->getAttrOfType(fieldAttrName()).getValue() << ", " << (*this)->getAttr(typeAttrName()); if (getNumOperands()) { p << '('; @@ -2007,7 +2007,7 @@ def fir_IterWhileOp : region_Op<"iterate_while", CArg<"ValueRange", "llvm::None">:$iterArgs, CArg<"ArrayRef", "{}">:$attributes)> ]; - + let extraClassDeclaration = [{ mlir::Block *getBody() { return ®ion().front(); } mlir::Value getIterateVar() { return getBody()->getArgument(1); } @@ -2276,11 +2276,11 @@ def fir_ConstfOp : fir_Op<"constf", [NoSideEffect]> { }]; let arguments = (ins FirRealAttr:$constant); - + let results = (outs fir_RealType:$res); let assemblyFormat = "`(` $constant `)` attr-dict `:` type($res)"; - + let verifier = [{ if (!getType().isa()) return emitOpError("must be a !fir.real type"); @@ -2357,7 +2357,7 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> { }]; let results = (outs fir_ComplexType); - + let parser = [{ fir::RealAttr realp; fir::RealAttr imagp; @@ -2455,7 +2455,7 @@ def fir_CmpcOp : fir_Op<"cmpc", def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> { let summary = "convert a symbol to an SSA value"; - + let description = [{ Convert a symbol (a function or global reference) to an SSA-value to be used in other Operations. @@ -2474,7 +2474,7 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> { def fir_ConvertOp : fir_OneResultOp<"convert", [NoSideEffect]> { let summary = "encapsulates all Fortran scalar type conversions"; - + let description = [{ Generalized type conversion. Convert the ssa value from type T to type U. Not all pairs of types have conversions. When types T and U are the same @@ -2705,7 +2705,7 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> { mlir::Type resultType() { return fir::AllocaOp::wrapResultType(getType()); } - + /// Return the initializer attribute if it exists, or a null attribute. Attribute getValueOrNull() { return initVal().getValueOr(Attribute()); } @@ -2728,9 +2728,9 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> { } mlir::FlatSymbolRefAttr getSymbol() { - return mlir::FlatSymbolRefAttr::get( + return mlir::FlatSymbolRefAttr::get(getContext(), (*this)->getAttrOfType( - mlir::SymbolTable::getSymbolAttrName()).getValue(), getContext()); + mlir::SymbolTable::getSymbolAttrName()).getValue()); } }]; } @@ -2772,7 +2772,7 @@ def fir_GlobalLenOp : fir_Op<"global_len", []> { }]; let printer = [{ - p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName()) + p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName()) << ", " << (*this)->getAttr(intAttrName()); }]; diff --git a/flang/lib/Lower/FIRBuilder.cpp b/flang/lib/Lower/FIRBuilder.cpp index 3f470d61c286..0a8473b73268 100644 --- a/flang/lib/Lower/FIRBuilder.cpp +++ b/flang/lib/Lower/FIRBuilder.cpp @@ -173,7 +173,7 @@ mlir::Value Fortran::lower::FirOpBuilder::createConvert(mlir::Location loc, fir::StringLitOp Fortran::lower::FirOpBuilder::createStringLit( mlir::Location loc, mlir::Type eleTy, llvm::StringRef data) { - auto strAttr = mlir::StringAttr::get(data, getContext()); + auto strAttr = mlir::StringAttr::get(getContext(), data); auto valTag = mlir::Identifier::get(fir::StringLitOp::value(), getContext()); mlir::NamedAttribute dataAttr(valTag, strAttr); auto sizeTag = mlir::Identifier::get(fir::StringLitOp::size(), getContext()); diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index 3883ce2ed0c8..8523a8371192 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -107,7 +107,7 @@ private: ModuleOp module) { auto *context = module.getContext(); if (module.lookupSymbol("printf")) - return SymbolRefAttr::get("printf", context); + return SymbolRefAttr::get(context, "printf"); // Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` @@ -120,7 +120,7 @@ private: PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); rewriter.create(module.getLoc(), "printf", llvmFnType); - return SymbolRefAttr::get("printf", context); + return SymbolRefAttr::get(context, "printf"); } /// Return a value representing an access into a global string with the given diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 3883ce2ed0c8..8523a8371192 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -107,7 +107,7 @@ private: ModuleOp module) { auto *context = module.getContext(); if (module.lookupSymbol("printf")) - return SymbolRefAttr::get("printf", context); + return SymbolRefAttr::get(context, "printf"); // Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` @@ -120,7 +120,7 @@ private: PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); rewriter.create(module.getLoc(), "printf", llvmFnType); - return SymbolRefAttr::get("printf", context); + return SymbolRefAttr::get(context, "printf"); } /// Return a value representing an access into a global string with the given diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h index 794417e99652..b903c0928d1b 100644 --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -31,7 +31,7 @@ inline bool isRowMajorMatmul(ArrayAttr indexingMaps) { auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context)); auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context)); auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context)); - auto maps = ArrayAttr::get({mapA, mapB, mapC}, context); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); return indexingMaps == maps; } @@ -42,7 +42,7 @@ inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) { auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context)); auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context)); auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context)); - auto maps = ArrayAttr::get({mapA, mapB, mapC}, context); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); return indexingMaps == maps; } diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h index 34e7e8cfce12..571c9126f163 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -69,7 +69,7 @@ public: using Base::Base; using ValueType = ArrayRef; - static ArrayAttr get(ArrayRef value, MLIRContext *context); + static ArrayAttr get(MLIRContext *context, ArrayRef value); ArrayRef getValue() const; Attribute operator[](unsigned idx) const; @@ -126,8 +126,8 @@ public: /// attributes. This method assumes that the provided list is unordered. If /// the caller can guarantee that the attributes are ordered by name, /// getWithSorted should be used instead. - static DictionaryAttr get(ArrayRef value, - MLIRContext *context); + static DictionaryAttr get(MLIRContext *context, + ArrayRef value); /// Construct a dictionary with an array of values that is known to already be /// sorted by name and uniqued. @@ -250,7 +250,7 @@ public: using Attribute::Attribute; using ValueType = bool; - static BoolAttr get(bool value, MLIRContext *context); + static BoolAttr get(MLIRContext *context, bool value); /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to /// avoid bringing in all of IntegerAttrs methods. @@ -292,8 +292,8 @@ public: using Base::Base; /// Get or create a new OpaqueAttr with the provided dialect and string data. - static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type, - MLIRContext *context); + static OpaqueAttr get(MLIRContext *context, Identifier dialect, + StringRef attrData, Type type); /// Get or create a new OpaqueAttr with the provided dialect and string data. /// If the given identifier is not a valid namespace for a dialect, then a @@ -325,7 +325,7 @@ public: using ValueType = StringRef; /// Get an instance of a StringAttr with the given string. - static StringAttr get(StringRef bytes, MLIRContext *context); + static StringAttr get(MLIRContext *context, StringRef bytes); /// Get an instance of a StringAttr with the given string and Type. static StringAttr get(StringRef bytes, Type type); @@ -348,13 +348,12 @@ public: using Base::Base; /// Construct a symbol reference for the given value name. - static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx); + static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value); /// Construct a symbol reference for the given value name, and a set of nested /// references that are further resolve to a nested symbol. - static SymbolRefAttr get(StringRef value, - ArrayRef references, - MLIRContext *ctx); + static SymbolRefAttr get(MLIRContext *ctx, StringRef value, + ArrayRef references); /// Returns the name of the top level symbol reference, i.e. the root of the /// reference path. @@ -377,8 +376,8 @@ public: using ValueType = StringRef; /// Construct a symbol reference for the given value name. - static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) { - return SymbolRefAttr::get(value, ctx); + static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) { + return SymbolRefAttr::get(ctx, value); } /// Returns the name of the held symbol reference. diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h index be8a68979203..c2eec8727240 100644 --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -569,7 +569,7 @@ void FunctionLike::setArgAttrs( if (attributes.empty()) return (void)static_cast(this)->removeAttr(nameOut); Operation *op = this->getOperation(); - op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext())); + op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes)); } template @@ -646,7 +646,7 @@ void FunctionLike::setResultAttrs( if (attributes.empty()) return (void)this->getOperation()->removeAttr(nameOut); Operation *op = this->getOperation(); - op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext())); + op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes)); } template diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 45b9c490fd21..70cd55dbbb13 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -315,7 +315,7 @@ public: attrs = newAttrs; } void setAttrs(ArrayRef newAttrs) { - setAttrs(DictionaryAttr::get(newAttrs, getContext())); + setAttrs(DictionaryAttr::get(getContext(), newAttrs)); } /// Return the specified attribute if present, null otherwise. diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td index c5f252e45a20..a7b1fd8cfe64 100644 --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -44,7 +44,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> { /*defaultImplementation=*/[{ this->getOperation()->setAttr( mlir::SymbolTable::getSymbolAttrName(), - StringAttr::get(name, this->getOperation()->getContext())); + StringAttr::get(this->getOperation()->getContext(), name)); }] >, InterfaceMethod<"Gets the visibility of this symbol.", diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 90ed9cb0ad02..9e61e3a9d6e0 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -42,9 +42,9 @@ bool mlirAttributeIsAArray(MlirAttribute attr) { MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, MlirAttribute const *elements) { SmallVector attrs; - return wrap(ArrayAttr::get( - unwrapList(static_cast(numElements), elements, attrs), - unwrap(ctx))); + return wrap( + ArrayAttr::get(unwrap(ctx), unwrapList(static_cast(numElements), + elements, attrs))); } intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { @@ -71,7 +71,7 @@ MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, attributes.emplace_back( Identifier::get(unwrap(elements[i].name), unwrap(ctx)), unwrap(elements[i].attribute)); - return wrap(DictionaryAttr::get(attributes, unwrap(ctx))); + return wrap(DictionaryAttr::get(unwrap(ctx), attributes)); } intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { @@ -137,7 +137,7 @@ bool mlirAttributeIsABool(MlirAttribute attr) { } MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { - return wrap(BoolAttr::get(value, unwrap(ctx))); + return wrap(BoolAttr::get(unwrap(ctx), value)); } bool mlirBoolAttrGetValue(MlirAttribute attr) { @@ -163,9 +163,9 @@ bool mlirAttributeIsAOpaque(MlirAttribute attr) { MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, intptr_t dataLength, const char *data, MlirType type) { - return wrap( - OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), - StringRef(data, dataLength), unwrap(type), unwrap(ctx))); + return wrap(OpaqueAttr::get( + unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), + StringRef(data, dataLength), unwrap(type))); } MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { @@ -185,7 +185,7 @@ bool mlirAttributeIsAString(MlirAttribute attr) { } MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { - return wrap(StringAttr::get(unwrap(str), unwrap(ctx))); + return wrap(StringAttr::get(unwrap(ctx), unwrap(str))); } MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { @@ -211,7 +211,7 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, refs.reserve(numReferences); for (intptr_t i = 0; i < numReferences; ++i) refs.push_back(unwrap(references[i]).cast()); - return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx))); + return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs)); } MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { @@ -241,7 +241,7 @@ bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { } MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { - return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx))); + return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol))); } MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { diff --git a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp index 447b00567776..1b9e36180114 100644 --- a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp @@ -148,7 +148,7 @@ StringAttr GpuKernelToBlobPass::translateGPUModuleToBinaryAnnotation( auto blob = convertModuleToBlob(llvmModule, loc, name); if (!blob) return {}; - return StringAttr::get({blob->data(), blob->size()}, loc->getContext()); + return StringAttr::get(loc->getContext(), {blob->data(), blob->size()}); } std::unique_ptr> diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp index 887d3e798af7..5b62ca455dea 100644 --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -177,12 +177,12 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc( // Set SPIR-V binary shader data as an attribute. vulkanLaunchCallOp->setAttr( kSPIRVBlobAttrName, - StringAttr::get({binary.data(), binary.size()}, loc->getContext())); + StringAttr::get(loc->getContext(), {binary.data(), binary.size()})); // Set entry point name as an attribute. vulkanLaunchCallOp->setAttr( kSPIRVEntryPointAttrName, - StringAttr::get(launchOp.getKernelName(), loc->getContext())); + StringAttr::get(loc->getContext(), launchOp.getKernelName())); launchOp.erase(); } diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 87026e4483e6..29cf42205a56 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -687,8 +687,8 @@ public: rewriter.create(loc, llvmI32Type, executionModeAttr); structValue = rewriter.create( loc, structType, structValue, executionMode, - ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}, - context)); + ArrayAttr::get(context, + {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)})); // Insert extra operands if they exist into execution mode info struct. for (unsigned i = 0, e = values.size(); i < e; ++i) { @@ -696,9 +696,9 @@ public: Value entry = rewriter.create(loc, llvmI32Type, attr); structValue = rewriter.create( loc, structType, structValue, entry, - ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1), - rewriter.getIntegerAttr(rewriter.getI32Type(), i)}, - context)); + ArrayAttr::get(context, + {rewriter.getIntegerAttr(rewriter.getI32Type(), 1), + rewriter.getIntegerAttr(rewriter.getI32Type(), i)})); } rewriter.create(loc, ArrayRef({structValue})); rewriter.eraseOp(op); @@ -1297,17 +1297,17 @@ public: switch (funcOp.function_control()) { #define DISPATCH(functionControl, llvmAttr) \ case functionControl: \ - newFuncOp->setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \ + newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \ break; DISPATCH(spirv::FunctionControl::Inline, - StringAttr::get("alwaysinline", context)); + StringAttr::get(context, "alwaysinline")); DISPATCH(spirv::FunctionControl::DontInline, - StringAttr::get("noinline", context)); + StringAttr::get(context, "noinline")); DISPATCH(spirv::FunctionControl::Pure, - StringAttr::get("readonly", context)); + StringAttr::get(context, "readonly")); DISPATCH(spirv::FunctionControl::Const, - StringAttr::get("readnone", context)); + StringAttr::get(context, "readnone")); #undef DISPATCH diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 794f4a5d6c1e..ea0a4259637c 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -4016,7 +4016,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase { if (failed(applyPartialConversion(m, target, std::move(patterns)))) signalPassFailure(); m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), - StringAttr::get(this->dataLayout, m.getContext())); + StringAttr::get(m.getContext(), this->dataLayout)); } }; } // end namespace diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9e88250e2cab..683de815a54e 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -762,7 +762,7 @@ public: if (positionAttrs.size() > 1) { auto oneDVectorType = reducedVectorTypeBack(vectorType); auto nMinusOnePositionAttrs = - ArrayAttr::get(positionAttrs.drop_back(), context); + ArrayAttr::get(context, positionAttrs.drop_back()); extracted = rewriter.create( loc, typeConverter->convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); @@ -871,7 +871,7 @@ public: if (positionAttrs.size() > 1) { oneDVectorType = reducedVectorTypeBack(destVectorType); auto nMinusOnePositionAttrs = - ArrayAttr::get(positionAttrs.drop_back(), context); + ArrayAttr::get(context, positionAttrs.drop_back()); extracted = rewriter.create( loc, typeConverter->convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); @@ -887,7 +887,7 @@ public: // Potential insertion of resulting 1-D vector into array. if (positionAttrs.size() > 1) { auto nMinusOnePositionAttrs = - ArrayAttr::get(positionAttrs.drop_back(), context); + ArrayAttr::get(context, positionAttrs.drop_back()); inserted = rewriter.create(loc, llvmResultType, adaptor.dest(), inserted, nMinusOnePositionAttrs); diff --git a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp index c1d0820e1cc7..6ccb59aff35a 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp @@ -53,7 +53,7 @@ LogicalResult setMappingAttr(scf::ParallelOp ploopOp, } ArrayRef mappingAsAttrs(mapping.data(), mapping.size()); ploopOp->setAttr(getMappingAttrName(), - ArrayAttr::get(mappingAsAttrs, ploopOp.getContext())); + ArrayAttr::get(ploopOp.getContext(), mappingAsAttrs)); return success(); } } // namespace gpu diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index a3960ae94b27..e96668779401 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -225,7 +225,7 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) { if (genericAttrNamesSet.count(attr.first.strref()) > 0) genericAttrs.push_back(attr); if (!genericAttrs.empty()) { - auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext()); + auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs); p << genericDictAttr; } @@ -833,7 +833,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef mapsProducer, // Handle the corner case of the result being a rank 0 shaped type. Return an // emtpy ArrayAttr. if (mapsConsumer.empty() && !mapsProducer.empty()) - return ArrayAttr::get(ArrayRef(), context); + return ArrayAttr::get(context, ArrayRef()); if (mapsProducer.empty() || mapsConsumer.empty() || mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() || mapsProducer.size() != mapsConsumer[0].getNumDims()) @@ -854,7 +854,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef mapsProducer, numLhsDims, /*numSymbols =*/0, reassociations, context))); reassociations.clear(); } - return ArrayAttr::get(reassociationMaps, context); + return ArrayAttr::get(context, reassociationMaps); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 8db4824cbbd2..c7b76404b2f8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -137,11 +137,11 @@ static ArrayAttr replaceUnitDims(DenseSet &unitDims, // wrong, so abort. if (!inversePermutation(concatAffineMaps(newIndexingMaps))) return nullptr; - return ArrayAttr::get( - llvm::to_vector<4>(llvm::map_range( - newIndexingMaps, - [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })), - context); + return ArrayAttr::get(context, + llvm::to_vector<4>(llvm::map_range( + newIndexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + }))); } /// Modify the region of indexed generic op to drop arguments corresponding to @@ -220,7 +220,7 @@ struct FoldUnitDimLoops : public OpRewritePattern { rewriter.startRootUpdate(op); op.indexing_mapsAttr(newIndexingMapAttr); - op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context)); + op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes)); (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter); rewriter.finalizeRootUpdate(op); return success(); @@ -282,7 +282,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap, RankedTensorType::get(newShape, type.getElementType()), AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(), newIndexExprs, context), - ArrayAttr::get(reassociationMaps, context)}; + ArrayAttr::get(context, reassociationMaps)}; return info; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp index cac0ae0d081c..b893f2ba6721 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -77,9 +77,9 @@ LinalgOp mlir::linalg::interchange(LinalgOp op, applyPermutationToVector(itTypesVector, interchangeVector); op->setAttr(getIndexingMapsAttrName(), - ArrayAttr::get(newIndexingMaps, context)); + ArrayAttr::get(context, newIndexingMaps)); op->setAttr(getIteratorTypesAttrName(), - ArrayAttr::get(itTypesVector, context)); + ArrayAttr::get(context, itTypesVector)); return op; } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 9b62b4289c77..4ce29b4a8397 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp, }); for (auto &var : interfaceVarSet) { interfaceVars.push_back(SymbolRefAttr::get( - cast(var).sym_name(), funcOp.getContext())); + funcOp.getContext(), cast(var).sym_name())); } return success(); } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 0902b297ddd3..65ebc54aeeb3 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -338,7 +338,7 @@ OpFoldResult AssumingAllOp::fold(ArrayRef operands) { return a; } // If this is reached, all inputs were statically known passing. - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); } static LogicalResult verify(AssumingAllOp op) { @@ -482,10 +482,10 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef operands) { // Both operands are not needed if one is a scalar. if (operands[0] && operands[0].cast().getNumElements() == 0) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); if (operands[1] && operands[1].cast().getNumElements() == 0) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); if (operands[0] && operands[1]) { auto lhsShape = llvm::to_vector<6>( @@ -494,7 +494,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef operands) { operands[1].cast().getValues()); SmallVector resultShape; if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); } // Lastly, see if folding can be completed based on what constraints are known @@ -506,7 +506,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef operands) { return nullptr; if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); // Because a failing witness result here represents an eventual assertion // failure, we do not replace it with a constant witness. @@ -526,7 +526,7 @@ void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, OpFoldResult CstrEqOp::fold(ArrayRef operands) { if (llvm::all_of(operands, [&](Attribute a) { return a && a == operands[0]; })) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); // Because a failing witness result here represents an eventual assertion // failure, we do not try to replace it with a constant witness. Similarly, we @@ -573,14 +573,14 @@ OpFoldResult CstrRequireOp::fold(ArrayRef operands) { OpFoldResult ShapeEqOp::fold(ArrayRef operands) { if (lhs() == rhs()) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); auto lhs = operands[0].dyn_cast_or_null(); if (lhs == nullptr) return {}; auto rhs = operands[1].dyn_cast_or_null(); if (rhs == nullptr) return {}; - return BoolAttr::get(lhs == rhs, getContext()); + return BoolAttr::get(getContext(), lhs == rhs); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index c085c1cd33a7..ca2e2731df03 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -844,7 +844,7 @@ OpFoldResult CmpIOp::fold(ArrayRef operands) { if (lhs() == rhs()) { auto val = applyCmpPredicateToEqualOperands(getPredicate()); - return BoolAttr::get(val, getContext()); + return BoolAttr::get(getContext(), val); } auto lhs = operands.front().dyn_cast_or_null(); @@ -853,7 +853,7 @@ OpFoldResult CmpIOp::fold(ArrayRef operands) { return {}; auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); - return BoolAttr::get(val, getContext()); + return BoolAttr::get(getContext(), val); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index f20b713e8e77..9fe8cf23c162 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -247,7 +247,7 @@ static void print(OpAsmPrinter &p, ContractionOp op) { if (traitAttrsSet.count(attr.first.strref()) > 0) attrs.push_back(attr); - auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); + auto dictAttr = DictionaryAttr::get(op.getContext(), attrs); p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", "; p << op.rhs() << ", " << op.acc(); if (op.masks().size() == 2) @@ -1445,7 +1445,7 @@ static ArrayAttr makeI64ArrayAttr(ArrayRef values, auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v)); }); - return ArrayAttr::get(llvm::to_vector<8>(attrs), context); + return ArrayAttr::get(context, llvm::to_vector<8>(attrs)); } static LogicalResult verify(InsertStridedSliceOp op) { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 8a5206eb0b1c..bafeccbd53ea 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -92,11 +92,11 @@ NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) { UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); } BoolAttr Builder::getBoolAttr(bool value) { - return BoolAttr::get(value, context); + return BoolAttr::get(context, value); } DictionaryAttr Builder::getDictionaryAttr(ArrayRef value) { - return DictionaryAttr::get(value, context); + return DictionaryAttr::get(context, value); } IntegerAttr Builder::getIndexAttr(int64_t value) { @@ -200,11 +200,11 @@ FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) { } StringAttr Builder::getStringAttr(StringRef bytes) { - return StringAttr::get(bytes, context); + return StringAttr::get(context, bytes); } ArrayAttr Builder::getArrayAttr(ArrayRef value) { - return ArrayAttr::get(value, context); + return ArrayAttr::get(context, value); } FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { @@ -214,12 +214,12 @@ FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { return getSymbolRefAttr(symName.getValue()); } FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) { - return SymbolRefAttr::get(value, getContext()); + return SymbolRefAttr::get(getContext(), value); } SymbolRefAttr Builder::getSymbolRefAttr(StringRef value, ArrayRef nestedReferences) { - return SymbolRefAttr::get(value, nestedReferences, getContext()); + return SymbolRefAttr::get(getContext(), value, nestedReferences); } ArrayAttr Builder::getBoolArrayAttr(ArrayRef values) { diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 162bed96e3f4..58a5b3370364 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -35,7 +35,7 @@ AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } // ArrayAttr //===----------------------------------------------------------------------===// -ArrayAttr ArrayAttr::get(ArrayRef value, MLIRContext *context) { +ArrayAttr ArrayAttr::get(MLIRContext *context, ArrayRef value) { return Base::get(context, value); } @@ -134,8 +134,8 @@ DictionaryAttr::findDuplicate(SmallVectorImpl &array, return findDuplicateElement(array); } -DictionaryAttr DictionaryAttr::get(ArrayRef value, - MLIRContext *context) { +DictionaryAttr DictionaryAttr::get(MLIRContext *context, + ArrayRef value) { if (value.empty()) return DictionaryAttr::getEmpty(context); assert(llvm::all_of(value, @@ -267,13 +267,12 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, // SymbolRefAttr //===----------------------------------------------------------------------===// -FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { +FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { return Base::get(ctx, value, llvm::None).cast(); } -SymbolRefAttr SymbolRefAttr::get(StringRef value, - ArrayRef nestedReferences, - MLIRContext *ctx) { +SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, + ArrayRef nestedReferences) { return Base::get(ctx, value, nestedReferences); } @@ -294,7 +293,7 @@ ArrayRef SymbolRefAttr::getNestedReferences() const { IntegerAttr IntegerAttr::get(Type type, const APInt &value) { if (type.isSignlessInteger(1)) - return BoolAttr::get(value.getBoolValue(), type.getContext()); + return BoolAttr::get(type.getContext(), value.getBoolValue()); return Base::get(type.getContext(), type, value); } @@ -377,8 +376,8 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } // OpaqueAttr //===----------------------------------------------------------------------===// -OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, - MLIRContext *context) { +OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect, + StringRef attrData, Type type) { return Base::get(context, dialect, attrData, type); } @@ -409,7 +408,7 @@ LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, // StringAttr //===----------------------------------------------------------------------===// -StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { +StringAttr StringAttr::get(MLIRContext *context, StringRef bytes) { return get(bytes, NoneType::get(context)); } diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp index 469aa310140c..db383c691c7c 100644 --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -166,7 +166,7 @@ void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { newAttrs.insert(attr); for (auto &attr : getAttrs()) newAttrs.insert(attr); - dest->setAttrs(DictionaryAttr::get(newAttrs.takeVector(), getContext())); + dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector())); // Clone the body. getBody().cloneInto(&dest.getBody(), mapper); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index dbfa1bdf6f7e..8d13a9c4af32 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -872,7 +872,7 @@ void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, storage->setType(NoneType::get(ctx)); } -BoolAttr BoolAttr::get(bool value, MLIRContext *context) { +BoolAttr BoolAttr::get(MLIRContext *context, bool value) { return value ? context->getImpl().trueAttr : context->getImpl().falseAttr; } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index b4fe9f854dda..be312689cebb 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -76,7 +76,7 @@ Operation *Operation::create(Location location, OperationName name, ArrayRef attributes, BlockRange successors, unsigned numRegions) { return create(location, name, resultTypes, operands, - DictionaryAttr::get(attributes, location.getContext()), + DictionaryAttr::get(location.getContext(), attributes), successors, numRegions); } diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index b198600e9242..70133d22482f 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -46,7 +46,7 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName, assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor"); MLIRContext *ctx = symbol->getContext(); - auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx); + auto leafRef = FlatSymbolRefAttr::get(ctx, symbolName); results.push_back(leafRef); // Early exit for when 'within' is the parent of 'symbol'. @@ -67,13 +67,13 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName, getNameIfSymbol(symbolTableOp, symbolNameId); if (!symbolTableName) return failure(); - results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx)); + results.push_back(SymbolRefAttr::get(ctx, *symbolTableName, nestedRefs)); symbolTableOp = symbolTableOp->getParentOp(); if (symbolTableOp == within) break; nestedRefs.insert(nestedRefs.begin(), - FlatSymbolRefAttr::get(*symbolTableName, ctx)); + FlatSymbolRefAttr::get(ctx, *symbolTableName)); } while (true); return success(); } @@ -203,7 +203,7 @@ StringRef SymbolTable::getSymbolName(Operation *symbol) { /// Sets the name of the given symbol operation. void SymbolTable::setSymbolName(Operation *symbol, StringRef name) { symbol->setAttr(getSymbolAttrName(), - StringAttr::get(name, symbol->getContext())); + StringAttr::get(symbol->getContext(), name)); } /// Returns the visibility of the given symbol operation. @@ -235,7 +235,7 @@ void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) { "unknown symbol visibility kind"); StringRef visName = vis == Visibility::Private ? "private" : "nested"; - symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx)); + symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName)); } /// Returns the nearest symbol table from a given operation `from`. Returns @@ -603,7 +603,7 @@ static SmallVector collectSymbolScopes(Operation *symbol, // doesn't support parent references. if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) == symbol->getParentOp()) - return {{SymbolRefAttr::get(symName, symbol->getContext()), limit}}; + return {{SymbolRefAttr::get(symbol->getContext(), symName), limit}}; return {}; } @@ -659,7 +659,7 @@ static SmallVector collectSymbolScopes(Operation *symbol, template static SmallVector collectSymbolScopes(StringRef symbol, IRUnit *limit) { - return {{SymbolRefAttr::get(symbol, limit->getContext()), limit}}; + return {{SymbolRefAttr::get(limit->getContext(), symbol), limit}}; } /// Returns true if the given reference 'SubRef' is a sub reference of the @@ -825,11 +825,11 @@ static Attribute rebuildAttrAfterRAUW( if (auto dictAttr = container.dyn_cast()) { auto newAttrs = llvm::to_vector<4>(dictAttr.getValue()); updateAttrs(make_second_range(newAttrs)); - return DictionaryAttr::get(newAttrs, dictAttr.getContext()); + return DictionaryAttr::get(dictAttr.getContext(), newAttrs); } auto newAttrs = llvm::to_vector<4>(container.cast().getValue()); updateAttrs(newAttrs); - return ArrayAttr::get(newAttrs, container.getContext()); + return ArrayAttr::get(container.getContext(), newAttrs); } /// Generates a new symbol reference attribute with a new leaf reference. @@ -839,8 +839,8 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, return newLeafAttr; auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); nestedRefs.back() = newLeafAttr; - return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs, - oldAttr.getContext()); + return SymbolRefAttr::get(oldAttr.getContext(), oldAttr.getRootReference(), + nestedRefs); } /// The implementation of SymbolTable::replaceAllSymbolUses below. @@ -867,7 +867,7 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) { // Generate a new attribute to replace the given attribute. MLIRContext *ctx = limit->getContext(); - FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx); + FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(ctx, newSymbol); for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr); auto walkFn = [&](SymbolTable::SymbolUse symbolUse, @@ -883,13 +883,13 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) { if (useRef != scope.symbol) { if (scope.symbol.isa()) { replacementRef = - SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx); + SymbolRefAttr::get(ctx, newSymbol, useRef.getNestedReferences()); } else { auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences()); nestedRefs[scope.symbol.getNestedReferences().size() - 1] = newLeafAttr; replacementRef = - SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx); + SymbolRefAttr::get(ctx, useRef.getRootReference(), nestedRefs); } } diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp index 859e8e279917..98f74174e5a3 100644 --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -148,7 +148,7 @@ Attribute Parser::parseAttribute(Type type) { return Attribute(); return type ? StringAttr::get(val, type) - : StringAttr::get(val, getContext()); + : StringAttr::get(getContext(), val); } // Parse a symbol reference attribute. @@ -176,7 +176,7 @@ Attribute Parser::parseAttribute(Type type) { std::string nameStr = getToken().getSymbolReference(); consumeToken(Token::at_identifier); - nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); + nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr)); } return builder.getSymbolRefAttr(nameStr, nestedRefs); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 2f0b3379d152..52ce37eb79ab 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -742,7 +742,8 @@ void OpEmitter::genAttrGetters() { body << " ::mlir::MLIRContext* ctx = getContext();\n"; body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n"; - body << " return ::mlir::DictionaryAttr::get({\n"; + body << " return ::mlir::DictionaryAttr::get("; + body << " ctx, {\n"; interleave( derivedAttrs, body, [&](const NamedAttribute &namedAttr) { @@ -755,7 +756,7 @@ void OpEmitter::genAttrGetters() { << "}"; }, ",\n"); - body << "\n }, ctx);"; + body << "});"; } } } diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp index 5595986e4016..52f522387017 100644 --- a/mlir/tools/mlir-tblgen/StructsGen.cpp +++ b/mlir/tools/mlir-tblgen/StructsGen.cpp @@ -150,7 +150,7 @@ static void emitFactoryDef(llvm::StringRef structName, } const char *getEndInfo = R"( - ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(fields, context); + ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(context, fields); return dict.dyn_cast<{0}>(); } )"; diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp index 0dd9ef9de3e6..ef0bdd81ee3a 100644 --- a/mlir/unittests/TableGen/StructsGenTest.cpp +++ b/mlir/unittests/TableGen/StructsGenTest.cpp @@ -67,7 +67,7 @@ TEST(StructsGenTest, ClassofExtraFalse) { newValues.push_back(wrongAttr); // Make a new DictionaryAttr and validate. - auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); } @@ -88,7 +88,7 @@ TEST(StructsGenTest, ClassofBadNameFalse) { auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second); newValues.push_back(wrongAttr); - auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); } @@ -113,7 +113,7 @@ TEST(StructsGenTest, ClassofBadTypeFalse) { auto wrongAttr = mlir::NamedAttribute(id, elementsAttr); newValues.push_back(wrongAttr); - auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); } @@ -130,7 +130,7 @@ TEST(StructsGenTest, ClassofMissingFalse) { expectedValues.begin() + 1, expectedValues.end()); // Make a new DictionaryAttr and validate it is not a validate TestStruct. - auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); } -- 2.34.1