From 5550c821897ab77e664977121a0e90ad5be1ff59 Mon Sep 17 00:00:00 2001
From: Tres Popp
Date: Mon, 8 May 2023 16:33:54 +0200
Subject: [PATCH] [mlir] Move casting calls from methods to function calls
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast
functionality in addition to defining methods with the same names. This
change begins the migration from the methods to the corresponding
function calls, which has been decided on as the more consistent style
(see the before/after sketch below).

Note that there still exist classes that only define the methods
directly, such as AffineExpr, and work to support a functional
cast/isa call for them is not included here.

Caveats include:
- The clang-tidy check used here probably has more problems.
- This only touches C++ code, so nothing that is being generated.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants
  for dyn_cast/cast/isa/…"
- Original discussion at
  https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation:
This first patch was created with the following steps. The intention is
to do only automated changes at first, so less time is wasted if the
patch is reverted, and so this first mass change is a clearer example
for other teams that will need to follow similar steps.

Steps are described per line, as comments are removed by git:
0. Retrieve the change from the following link to build clang-tidy with
   an additional check:
   https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy.
2. Run clang-tidy over the entire codebase with all checks disabled
   except the one relevant check, including over header files.
3. Delete the .inc files that were also modified, so the next build
   regenerates them in a pure state.
4. Discard some of the automated changes, for the following reasons:
   - Some files had a variable also named cast.
   - Some files had not included a header file that defines the cast
     functions.
   - Some files are the definitions of the classes that have the
     casting methods, so the code would still refer to the method
     instead of the function unless a namespace prefix were added or
     the method declaration were removed in the same change.
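To make the intended rewrite concrete, here is a minimal before/after sketch of the transformation this patch applies mechanically. The helper functions and the choice of MemRefType are illustrative assumptions for this sketch, not code taken from the patch:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Before: the member-function casting style on Type, now deprecated
// for Type/Attribute/Value. (Hypothetical helper for illustration.)
bool isStaticMemRefBefore(Value v) {
  auto memrefTy = v.getType().dyn_cast<MemRefType>();
  return memrefTy && memrefTy.hasStaticShape();
}

// After: the same check using the free function, as produced by the
// clang-tidy check driving this patch.
bool isStaticMemRefAfter(Value v) {
  auto memrefTy = dyn_cast<MemRefType>(v.getType());
  return memrefTy && memrefTy.hasStaticShape();
}
```

The same pattern applies to isa, cast, and dyn_cast_or_null. Classes such as AffineExpr that only define the member form (e.g. `expr.cast<AffineDimExpr>()`) are left untouched by this change. The exact commands used for the steps above follow.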
```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
               -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc

git restore mlir/lib/IR mlir/lib/Dialect/DLTI/DLTI.cpp\
            mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp\
            mlir/lib/**/IR/\
            mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp\
            mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp\
            mlir/test/lib/Dialect/Test/TestTypes.cpp\
            mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp\
            mlir/test/lib/Dialect/Test/TestAttributes.cpp\
            mlir/unittests/TableGen/EnumsGenTest.cpp\
            mlir/test/python/lib/PythonTestCAPI.cpp\
            mlir/include/mlir/IR/
```

Differential Revision: https://reviews.llvm.org/D150123
---
 .../include/mlir/Bytecode/BytecodeImplementation.h | 4 +-
 .../Conversion/ArithCommon/AttrToLLVMConverter.h | 5 +-
 .../Conversion/MemRefToLLVM/AllocLikeConversion.h | 2 +-
 .../FileLineColLocBreakpointManager.h | 2 +-
 mlir/include/mlir/Dialect/Affine/IR/AffineOps.h | 26 +--
 mlir/include/mlir/Dialect/Arith/IR/Arith.h | 6 +-
 mlir/include/mlir/Dialect/Async/IR/Async.h | 2 +-
 mlir/include/mlir/Dialect/CommonFolders.h | 46 ++---
 .../Func/Transforms/DecomposeCallGraphTypes.h | 2 +-
 mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h | 2 +-
 mlir/include/mlir/Dialect/Quant/UniformSupport.h | 4 +-
 .../mlir/Dialect/SparseTensor/IR/SparseTensor.h | 4 +-
 .../mlir/Dialect/Tosa/Utils/ConversionUtils.h | 2 +-
 mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h | 2 +-
 .../mlir/Dialect/Transform/IR/MatchInterfaces.h | 4 +-
 .../Dialect/Transform/IR/TransformInterfaces.h | 4 +-
 mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 5 +-
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.h | 4 +-
 .../include/mlir/Interfaces/InferTypeOpInterface.h | 4 +-
 .../mlir/Interfaces/ValueBoundsOpInterface.h | 2 +-
 mlir/include/mlir/Transforms/DialectConversion.h | 8 +-
 .../Analysis/AliasAnalysis/LocalAliasAnalysis.cpp | 18 +-
 .../lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp | 6 +-
 mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 4 +-
 mlir/lib/Analysis/Liveness.cpp | 6 +-
 mlir/lib/Analysis/SliceAnalysis.cpp | 4 +-
 mlir/lib/AsmParser/AsmParserState.cpp | 4 +-
 mlir/lib/AsmParser/AttributeParser.cpp | 20 +-
 mlir/lib/AsmParser/DialectSymbolParser.cpp | 2 +-
 mlir/lib/AsmParser/Parser.cpp | 4 +-
 mlir/lib/AsmParser/Parser.h | 2 +-
 mlir/lib/AsmParser/TypeParser.cpp | 8 +-
 mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 2 +-
 mlir/lib/Bytecode/Writer/IRNumbering.cpp | 4 +-
 mlir/lib/CAPI/Dialect/PDL.cpp | 14 +-
 mlir/lib/CAPI/Dialect/Quant.cpp | 69 +++----
 mlir/lib/CAPI/Dialect/SparseTensor.cpp | 15 +-
 mlir/lib/CAPI/Dialect/Transform.cpp | 6 +-
 .../lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 20 +-
 .../AffineToStandard/AffineToStandard.cpp | 4 +-
 mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 24 +--
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 48 ++---
 .../Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp | 4 +-
 mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 12 +-
 .../BufferizationToMemRef.cpp | 4 +-
 .../lib/Conversion/ComplexToLibm/ComplexToLibm.cpp | 8 +-
 .../ComplexToStandard/ComplexToStandard.cpp | 67 +++---
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 31 ++-
 mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 34 ++--
 .../Conversion/GPUCommon/GPUToLLVMConversion.cpp | 14 +-
 .../Conversion/GPUCommon/OpToFuncCallLowering.h | 10 +-
 mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | 29 ++-
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 7 +-
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 4 +-
 mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp | 16 +-
 .../ConvertGPULaunchFuncToVulkanLaunchFunc.cpp | 4 +-
 .../GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp | 15 +-
 mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp | 10 +-
 mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 4 +-
 mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp | 24 +--
 mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp | 10 +-
 .../LinalgToStandard/LinalgToStandard.cpp | 4 +-
 mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp | 20 +-
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 28 +--
 mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 6 +-
 mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp | 26 +--
 .../MemRefToLLVM/AllocLikeConversion.cpp | 6 +-
 mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 99 +++++----
 .../MemRefToSPIRV/MapMemRefStorageClassPass.cpp | 14 +-
 .../lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 46 ++---
 mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 24 +--
 mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp | 6 +-
 .../Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp | 18 +-
 .../Conversion/PDLToPDLInterp/PredicateTree.cpp | 28 +--
 mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp | 2 +-
 mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 11 +-
 .../SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp | 7 +-
 mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 97 +++++----
 .../Conversion/ShapeToStandard/ShapeToStandard.cpp | 28 +--
 .../lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp | 2 +-
 mlir/lib/Conversion/TosaToArith/TosaToArith.cpp | 4 +-
 mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 224 ++++++++++-----------
 .../Conversion/TosaToLinalg/TosaToLinalgNamed.cpp | 70 ++++---
 mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp | 26 +--
 mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | 48 ++---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp | 54 ++---
 mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 25 ++-
 .../lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 33 ++-
 .../Dialect/AMDGPU/Transforms/EmulateAtomics.cpp | 4 +-
 .../lib/Dialect/Affine/Analysis/AffineAnalysis.cpp | 2 +-
 mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp | 4 +-
 mlir/lib/Dialect/Affine/Analysis/Utils.cpp | 24 +--
 mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp | 6 +-
 .../Affine/Transforms/PipelineDataTransfer.cpp | 2 +-
 .../Affine/Transforms/SimplifyAffineStructures.cpp | 4 +-
 .../Dialect/Affine/Transforms/SuperVectorize.cpp | 8 +-
 mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 8 +-
 mlir/lib/Dialect/Affine/Utils/Utils.cpp | 20 +-
 .../Transforms/BufferizableOpInterfaceImpl.cpp | 12 +-
 .../Dialect/Arith/Transforms/EmulateWideInt.cpp | 39 ++--
 .../Dialect/Arith/Transforms/ReifyValueBounds.cpp | 6 +-
 mlir/lib/Dialect/Arith/Utils/Utils.cpp | 28 +--
 .../Async/Transforms/AsyncRuntimeRefCounting.cpp | 6 +-
 .../Async/Transforms/AsyncToAsyncRuntime.cpp | 6 +-
 .../TransformOps/BufferizationTransformOps.cpp | 2 +-
 .../Transforms/BufferDeallocation.cpp | 6 +-
 .../Transforms/BufferOptimizations.cpp | 4 +-
 .../Transforms/BufferResultsToOutParams.cpp | 10 +-
 .../Bufferization/Transforms/BufferUtils.cpp | 8 +-
 .../Dialect/Bufferization/Transforms/Bufferize.cpp | 14 +-
 .../Transforms/EmptyTensorElimination.cpp | 6 +-
 .../Transforms/FuncBufferizableOpInterfaceImpl.cpp | 18 +-
 .../Bufferization/Transforms/OneShotAnalysis.cpp | 42 ++--
 .../Transforms/OneShotModuleBufferize.cpp | 14 +-
 .../Transforms/TensorCopyInsertion.cpp | 6 +-
 .../Dialect/GPU/TransformOps/GPUTransformOps.cpp | 10 +-
 .../Dialect/GPU/Transforms/AllReduceLowering.cpp | 2 +-
 .../Dialect/GPU/Transforms/AsyncRegionRewriter.cpp | 8 +-
 .../lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 4 +-
 .../lib/Dialect/GPU/Transforms/MemoryPromotion.cpp | 8 +-
 mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp | 6 +-
 .../LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp | 6 +-
 .../Dialect/Linalg/TransformOps/LinalgMatchOps.cpp | 4 +-
 .../Linalg/TransformOps/LinalgTransformOps.cpp | 105 +++++-----
 .../lib/Dialect/Linalg/Transforms/ConstantFold.cpp | 14 +-
 .../Linalg/Transforms/ConvertConv2DToImg2Col.cpp | 26 +--
 .../Transforms/ConvertToDestinationStyle.cpp | 26 +--
 .../Linalg/Transforms/DataLayoutPropagation.cpp | 2 +-
 .../Linalg/Transforms/DecomposeLinalgOps.cpp | 10 +-
 mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp | 6 +-
 .../lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 17 +-
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 35 ++--
 .../Linalg/Transforms/ElementwiseToLinalg.cpp | 8 +-
 .../Transforms/EraseUnusedOperandsAndResults.cpp | 2 +-
 .../Transforms/FusePadOpWithLinalgProducer.cpp | 6 +-
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 8 +-
 .../Dialect/Linalg/Transforms/FusionOnTensors.cpp | 8 +-
 .../lib/Dialect/Linalg/Transforms/HoistPadding.cpp | 9 +-
 mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp | 2 +-
 mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 4 +-
 .../Linalg/Transforms/NamedOpConversions.cpp | 6 +-
 mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 4 +-
 mlir/lib/Dialect/Linalg/Transforms/Split.cpp | 2 +-
 .../Dialect/Linalg/Transforms/SplitReduction.cpp | 6 +-
 .../Dialect/Linalg/Transforms/SubsetHoisting.cpp | 8 +-
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 10 +-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 3 +-
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 38 ++--
 .../Dialect/Linalg/Transforms/Vectorization.cpp | 57 +++---
 mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp | 10 +-
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 18 +-
 .../Math/Transforms/AlgebraicSimplification.cpp | 4 +-
 .../Math/Transforms/PolynomialApproximation.cpp | 16 +-
 .../MemRef/TransformOps/MemRefTransformOps.cpp | 2 +-
 .../Dialect/MemRef/Transforms/ComposeSubView.cpp | 8 +-
 .../Dialect/MemRef/Transforms/EmulateWideInt.cpp | 2 +-
 mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp | 8 +-
 .../MemRef/Transforms/ExpandStridedMetadata.cpp | 38 ++--
 .../Transforms/ExtractAddressComputations.cpp | 4 +-
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp | 2 +-
 .../MemRef/Transforms/IndependenceTransforms.cpp | 8 +-
 mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp | 10 +-
 .../Dialect/MemRef/Transforms/NormalizeMemRefs.cpp | 18 +-
 .../Transforms/ResolveShapedTypeResultDims.cpp | 8 +-
 .../MemRef/Transforms/RuntimeOpVerification.cpp | 8 +-
 .../NVGPU/Transforms/MmaSyncTF32Transform.cpp | 2 +-
 .../NVGPU/Transforms/OptimizeSharedMemory.cpp | 2 +-
 mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp | 2 +-
 mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp | 14 +-
 .../Dialect/SCF/TransformOps/SCFTransformOps.cpp | 6 +-
 .../SCF/Transforms/BufferizableOpInterfaceImpl.cpp | 80 ++++----
 .../SCF/Transforms/LoopCanonicalization.cpp | 6 +-
 mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 5 +-
 .../Dialect/SCF/Transforms/TileUsingInterface.cpp | 4 +-
 .../Transforms/DecorateCompositeTypeLayoutPass.cpp | 4 +-
 .../SPIRV/Transforms/LowerABIAttributesPass.cpp | 14 +-
 .../SPIRV/Transforms/RewriteInsertsPass.cpp | 12 +-
 .../Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 59 +++---
 .../SPIRV/Transforms/SPIRVWebGPUTransforms.cpp | 4 +-
 .../SPIRV/Transforms/UnifyAliasedResourcePass.cpp | 36 ++--
 .../lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp | 4 +-
 mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp | 18 +-
 .../Transforms/BufferizableOpInterfaceImpl.cpp | 4 +-
 .../Shape/Transforms/OutlineShapeComputation.cpp | 2 +-
 .../Pipelines/SparseTensorPipelines.cpp | 2 +-
 .../SparseTensor/Transforms/CodegenUtils.cpp | 36 ++--
 .../Dialect/SparseTensor/Transforms/CodegenUtils.h | 6 +-
 .../SparseTensor/Transforms/LoopEmitter.cpp | 4 +-
 .../Dialect/SparseTensor/Transforms/LoopEmitter.h | 4 +-
 .../SparseTensor/Transforms/SparseGPUCodegen.cpp | 8 +-
 .../Transforms/SparseStorageSpecifierToLLVM.cpp | 6 +-
 .../Transforms/SparseTensorCodegen.cpp | 12 +-
 .../Transforms/SparseTensorConversion.cpp | 10 +-
 .../Transforms/SparseTensorRewriting.cpp | 8 +-
 .../Transforms/SparseTensorStorageLayout.h | 4 +-
 .../SparseTensor/Transforms/Sparsification.cpp | 12 +-
 mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp | 22 +-
 .../Transforms/BufferizableOpInterfaceImpl.cpp | 42 ++--
 .../Transforms/ExtractSliceFromReshapeUtils.cpp | 4 +-
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 6 +-
 .../Tosa/Transforms/TosaDecomposeConv2D.cpp | 10 +-
 .../Tosa/Transforms/TosaDecomposeDepthwise.cpp | 16 +-
 .../Tosa/Transforms/TosaDecomposeTransposeConv.cpp | 24 +--
 .../Tosa/Transforms/TosaFoldConstantTranspose.cpp | 6 +-
 .../Dialect/Tosa/Transforms/TosaInferShapes.cpp | 11 +-
 .../Tosa/Transforms/TosaMakeBroadcastable.cpp | 24 +--
 .../lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 2 +-
 mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp | 34 ++--
 mlir/lib/Dialect/Traits.cpp | 16 +-
 .../lib/Dialect/Transform/Transforms/CheckUses.cpp | 6 +-
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 4 +-
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 12 +-
 mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp | 18 +-
 .../Transforms/BufferizableOpInterfaceImpl.cpp | 16 +-
 .../Vector/Transforms/LowerVectorBroadcast.cpp | 2 +-
 .../Vector/Transforms/LowerVectorContract.cpp | 40 ++--
 .../Dialect/Vector/Transforms/LowerVectorMask.cpp | 10 +-
 .../Dialect/Vector/Transforms/LowerVectorScan.cpp | 2 +-
 .../Vector/Transforms/LowerVectorTransfer.cpp | 20 +-
 .../Vector/Transforms/LowerVectorTranspose.cpp | 6 +-
 .../Dialect/Vector/Transforms/VectorDistribute.cpp | 54 ++---
 .../Vector/Transforms/VectorDropLeadUnitDim.cpp | 14 +-
 ...torInsertExtractStridedSliceRewritePatterns.cpp | 31 ++-
 .../Transforms/VectorTransferOpTransforms.cpp | 42 ++--
 .../VectorTransferSplitRewritePatterns.cpp | 6 +-
 .../Dialect/Vector/Transforms/VectorTransforms.cpp | 26 +--
 .../lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 6 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 8 +-
 .../X86Vector/Transforms/LegalizeForLLVMExport.cpp | 4 +-
 mlir/lib/ExecutionEngine/JitRunner.cpp | 24 +--
 mlir/lib/Interfaces/DataLayoutInterfaces.cpp | 52 ++---
 .../lib/Interfaces/DestinationStyleOpInterface.cpp | 4 +-
 mlir/lib/Interfaces/InferIntRangeInterface.cpp | 2 +-
 mlir/lib/Interfaces/InferTypeOpInterface.cpp | 26 +--
 mlir/lib/Interfaces/ValueBoundsOpInterface.cpp | 8 +-
 mlir/lib/Rewrite/ByteCode.cpp | 34 ++--
 mlir/lib/Target/Cpp/TranslateToCpp.cpp | 50 +++--
 mlir/lib/Target/LLVMIR/DebugTranslation.cpp | 24 +--
 .../Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 4 +-
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 8 +-
 .../Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp | 2 +-
 .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 +-
 mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp | 4 +-
 .../Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp | 6 +-
 .../Target/LLVMIR/LoopAnnotationTranslation.cpp | 2 +-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp | 8 +-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 65 +++---
 .../Target/SPIRV/Deserialization/Deserializer.cpp | 20 +-
 .../Target/SPIRV/Deserialization/Deserializer.h | 2 +-
 .../Target/SPIRV/Serialization/SerializeOps.cpp | 20 +-
 mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | 71 ++++---
 mlir/lib/Target/SPIRV/Serialization/Serializer.h | 2 +-
 mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp | 4 +-
 mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp | 4 +-
 mlir/lib/Transforms/CSE.cpp | 2 +-
 mlir/lib/Transforms/Inliner.cpp | 2 +-
 mlir/lib/Transforms/Mem2Reg.cpp | 2 +-
 mlir/lib/Transforms/Utils/DialectConversion.cpp | 8 +-
 mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp | 2 +-
 mlir/lib/Transforms/Utils/RegionUtils.cpp | 12 +-
 mlir/lib/Transforms/ViewOpGraph.cpp | 6 +-
 mlir/test/lib/Analysis/TestAliasAnalysis.cpp | 8 +-
 .../lib/Analysis/TestMemRefStrideCalculation.cpp | 2 +-
 .../TestOneToNTypeConversionPass.cpp | 8 +-
 .../lib/Dialect/Affine/TestReifyValueBounds.cpp | 6 +-
 .../lib/Dialect/Affine/TestVectorizationUtils.cpp | 6 +-
 .../Dialect/Func/TestDecomposeCallGraphTypes.cpp | 4 +-
 .../Dialect/Linalg/TestLinalgFusionTransforms.cpp | 4 +-
 mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp | 4 +-
 .../lib/Dialect/Tensor/TestTensorTransforms.cpp | 2 +-
 mlir/test/lib/Dialect/Test/TestDialect.cpp | 36 ++--
 mlir/test/lib/Dialect/Test/TestPatterns.cpp | 12 +-
 mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp | 18 +-
 .../lib/Dialect/Vector/TestVectorTransforms.cpp | 10 +-
 .../test/lib/IR/TestBuiltinAttributeInterfaces.cpp | 2 +-
 mlir/test/lib/IR/TestDiagnostics.cpp | 2 +-
 mlir/test/lib/IR/TestFunc.cpp | 14 +-
 mlir/test/lib/IR/TestInterfaces.cpp | 4 +-
 mlir/test/lib/IR/TestOpaqueLoc.cpp | 2 +-
 mlir/test/lib/IR/TestPrintDefUse.cpp | 2 +-
 mlir/test/lib/Transforms/TestTopologicalSort.cpp | 2 +-
 .../mlir-linalg-ods-yaml-gen.cpp | 6 +-
 mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp | 8 +-
 mlir/unittests/IR/AttributeTest.cpp | 14 +-
 mlir/unittests/IR/InterfaceAttachmentTest.cpp | 49 +++--
 .../Interfaces/DataLayoutInterfacesTest.cpp | 10 +-
 mlir/unittests/Pass/PassManagerTest.cpp | 8 +-
 287 files changed, 2114 insertions(+), 2189 deletions(-)

diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 60f5475..027df35 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -83,7 +83,7 @@ public:
     Attribute baseResult;
     if (failed(readAttribute(baseResult)))
       return failure();
-    if ((result = baseResult.dyn_cast()))
+    if ((result = dyn_cast(baseResult)))
       return success();
     return emitError() << "expected " << llvm::getTypeName()
                        << ", but got: " << baseResult;
@@ -100,7 +100,7 @@ public:
     Type baseResult;
     if (failed(readType(baseResult)))
       return failure();
-    if ((result = baseResult.dyn_cast()))
+    if ((result = dyn_cast(baseResult)))
      return success();
     return emitError() << "expected " << llvm::getTypeName()
                        << ", but got: " << baseResult;
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 95b12b6..eea16b4 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -38,9 +38,8 @@ public:
     // Get the name of the arith fastmath attribute.
    llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
     // Remove the source fastmath attribute.
- auto arithFMFAttr = - convertedAttr.erase(arithFMFAttrName) - .template dyn_cast_or_null(); + auto arithFMFAttr = dyn_cast_or_null( + convertedAttr.erase(arithFMFAttrName)); if (arithFMFAttr) { llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName(); convertedAttr.set(targetAttrName, diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h index 8684d35..a063623 100644 --- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h +++ b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h @@ -31,7 +31,7 @@ protected: Value input, Value alignment); static MemRefType getMemRefResultType(Operation *op) { - return op->getResult(0).getType().cast(); + return cast(op->getResult(0).getType()); } /// Computes the alignment for the given memory allocation op. diff --git a/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h b/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h index d4f9a6e..7def9e2 100644 --- a/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h +++ b/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h @@ -95,7 +95,7 @@ private: std::optional matchFromLocation(Location initialLoc) const { std::optional match = std::nullopt; initialLoc->walk([&](Location loc) { - auto fileLoc = loc.dyn_cast(); + auto fileLoc = dyn_cast(loc); if (!fileLoc) return WalkResult::advance(); StringRef file = fileLoc.getFilename(); diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h index 008d398..1409d52 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -106,7 +106,7 @@ public: /// Returns the source MemRefType for this DMA operation. Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); } MemRefType getSrcMemRefType() { - return getSrcMemRef().getType().cast(); + return cast(getSrcMemRef().getType()); } /// Returns the rank (number of indices) of the source MemRefType. @@ -115,7 +115,7 @@ public: /// Returns the affine map used to access the source memref. AffineMap getSrcMap() { return getSrcMapAttr().getValue(); } AffineMapAttr getSrcMapAttr() { - return (*this)->getAttr(getSrcMapAttrStrName()).cast(); + return cast((*this)->getAttr(getSrcMapAttrStrName())); } /// Returns the source memref affine map indices for this DMA operation. @@ -127,7 +127,7 @@ public: /// Returns the memory space of the source memref. unsigned getSrcMemorySpace() { - return getSrcMemRef().getType().cast().getMemorySpaceAsInt(); + return cast(getSrcMemRef().getType()).getMemorySpaceAsInt(); } /// Returns the operand index of the destination memref. @@ -138,23 +138,23 @@ public: /// Returns the destination MemRefType for this DMA operation. Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); } MemRefType getDstMemRefType() { - return getDstMemRef().getType().cast(); + return cast(getDstMemRef().getType()); } /// Returns the rank (number of indices) of the destination MemRefType. unsigned getDstMemRefRank() { - return getDstMemRef().getType().cast().getRank(); + return cast(getDstMemRef().getType()).getRank(); } /// Returns the memory space of the source memref. 
unsigned getDstMemorySpace() { - return getDstMemRef().getType().cast().getMemorySpaceAsInt(); + return cast(getDstMemRef().getType()).getMemorySpaceAsInt(); } /// Returns the affine map used to access the destination memref. AffineMap getDstMap() { return getDstMapAttr().getValue(); } AffineMapAttr getDstMapAttr() { - return (*this)->getAttr(getDstMapAttrStrName()).cast(); + return cast((*this)->getAttr(getDstMapAttrStrName())); } /// Returns the destination memref indices for this DMA operation. @@ -172,18 +172,18 @@ public: /// Returns the Tag MemRef for this DMA operation. Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); } MemRefType getTagMemRefType() { - return getTagMemRef().getType().cast(); + return cast(getTagMemRef().getType()); } /// Returns the rank (number of indices) of the tag MemRefType. unsigned getTagMemRefRank() { - return getTagMemRef().getType().cast().getRank(); + return cast(getTagMemRef().getType()).getRank(); } /// Returns the affine map used to access the tag memref. AffineMap getTagMap() { return getTagMapAttr().getValue(); } AffineMapAttr getTagMapAttr() { - return (*this)->getAttr(getTagMapAttrStrName()).cast(); + return cast((*this)->getAttr(getTagMapAttrStrName())); } /// Returns the tag memref indices for this DMA operation. @@ -299,13 +299,13 @@ public: /// Returns the Tag MemRef associated with the DMA operation being waited on. Value getTagMemRef() { return getOperand(0); } MemRefType getTagMemRefType() { - return getTagMemRef().getType().cast(); + return cast(getTagMemRef().getType()); } /// Returns the affine map used to access the tag memref. AffineMap getTagMap() { return getTagMapAttr().getValue(); } AffineMapAttr getTagMapAttr() { - return (*this)->getAttr(getTagMapAttrStrName()).cast(); + return cast((*this)->getAttr(getTagMapAttrStrName())); } /// Returns the tag memref index for this DMA operation. @@ -316,7 +316,7 @@ public: /// Returns the rank (number of indices) of the tag memref. unsigned getTagMemRefRank() { - return getTagMemRef().getType().cast().getRank(); + return cast(getTagMemRef().getType()).getRank(); } /// Impelements the AffineMapAccessInterface. Returns the AffineMapAttr diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h index f285262..1b516ff 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -63,7 +63,7 @@ public: Type type); inline int64_t value() { - return arith::ConstantOp::getValue().cast().getInt(); + return cast(arith::ConstantOp::getValue()).getInt(); } static bool classof(Operation *op); @@ -79,7 +79,7 @@ public: const APFloat &value, FloatType type); inline APFloat value() { - return arith::ConstantOp::getValue().cast().getValue(); + return cast(arith::ConstantOp::getValue()).getValue(); } static bool classof(Operation *op); @@ -94,7 +94,7 @@ public: static void build(OpBuilder &builder, OperationState &result, int64_t value); inline int64_t value() { - return arith::ConstantOp::getValue().cast().getInt(); + return cast(arith::ConstantOp::getValue()).getInt(); } static bool classof(Operation *op); diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h index 585a231..9265e2f 100644 --- a/mlir/include/mlir/Dialect/Async/IR/Async.h +++ b/mlir/include/mlir/Dialect/Async/IR/Async.h @@ -49,7 +49,7 @@ namespace async { /// Returns true if the type is reference counted at runtime. 
inline bool isRefCounted(Type type) { - return type.isa(); + return isa(type); } } // namespace async diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h index 8007027..d3fbc72 100644 --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -36,9 +36,9 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, if (!resultType || !operands[0] || !operands[1]) return {}; - if (operands[0].isa() && operands[1].isa()) { - auto lhs = operands[0].cast(); - auto rhs = operands[1].cast(); + if (isa(operands[0]) && isa(operands[1])) { + auto lhs = cast(operands[0]); + auto rhs = cast(operands[1]); if (lhs.getType() != rhs.getType()) return {}; @@ -50,12 +50,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, return AttrElementT::get(resultType, *calRes); } - if (operands[0].isa() && - operands[1].isa()) { + if (isa(operands[0]) && + isa(operands[1])) { // Both operands are splats so we can avoid expanding the values out and // just fold based on the splat value. - auto lhs = operands[0].cast(); - auto rhs = operands[1].cast(); + auto lhs = cast(operands[0]); + auto rhs = cast(operands[1]); if (lhs.getType() != rhs.getType()) return {}; @@ -67,11 +67,11 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, return DenseElementsAttr::get(cast(resultType), *elementResult); } - if (operands[0].isa() && operands[1].isa()) { + if (isa(operands[0]) && isa(operands[1])) { // Operands are ElementsAttr-derived; perform an element-wise fold by // expanding the values. - auto lhs = operands[0].cast(); - auto rhs = operands[1].cast(); + auto lhs = cast(operands[0]); + auto rhs = cast(operands[1]); if (lhs.getType() != rhs.getType()) return {}; @@ -103,7 +103,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, const CalculationT &calculate) { assert(operands.size() == 2 && "binary op takes two operands"); auto getResultType = [](Attribute attr) -> Type { - if (auto typed = attr.dyn_cast_or_null()) + if (auto typed = dyn_cast_or_null(attr)) return typed.getType(); return {}; }; @@ -158,27 +158,27 @@ Attribute constFoldUnaryOpConditional(ArrayRef operands, if (!operands[0]) return {}; - if (operands[0].isa()) { - auto op = operands[0].cast(); + if (isa(operands[0])) { + auto op = cast(operands[0]); auto res = calculate(op.getValue()); if (!res) return {}; return AttrElementT::get(op.getType(), *res); } - if (operands[0].isa()) { + if (isa(operands[0])) { // Both operands are splats so we can avoid expanding the values out and // just fold based on the splat value. - auto op = operands[0].cast(); + auto op = cast(operands[0]); auto elementResult = calculate(op.getSplatValue()); if (!elementResult) return {}; return DenseElementsAttr::get(op.getType(), *elementResult); - } else if (operands[0].isa()) { + } else if (isa(operands[0])) { // Operands are ElementsAttr-derived; perform an element-wise fold by // expanding the values. 
- auto op = operands[0].cast(); + auto op = cast(operands[0]); auto opIt = op.value_begin(); SmallVector elementResults; @@ -216,18 +216,18 @@ Attribute constFoldCastOp(ArrayRef operands, Type resType, if (!operands[0]) return {}; - if (operands[0].isa()) { - auto op = operands[0].cast(); + if (isa(operands[0])) { + auto op = cast(operands[0]); bool castStatus = true; auto res = calculate(op.getValue(), castStatus); if (!castStatus) return {}; return TargetAttrElementT::get(resType, res); } - if (operands[0].isa()) { + if (isa(operands[0])) { // The operand is a splat so we can avoid expanding the values out and // just fold based on the splat value. - auto op = operands[0].cast(); + auto op = cast(operands[0]); bool castStatus = true; auto elementResult = calculate(op.getSplatValue(), castStatus); @@ -235,10 +235,10 @@ Attribute constFoldCastOp(ArrayRef operands, Type resType, return {}; return DenseElementsAttr::get(cast(resType), elementResult); } - if (operands[0].isa()) { + if (isa(operands[0])) { // Operand is ElementsAttr-derived; perform an element-wise fold by // expanding the value. - auto op = operands[0].cast(); + auto op = cast(operands[0]); bool castStatus = true; auto opIt = op.value_begin(); SmallVector elementResults; diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h index 29bab1d..4af4c65 100644 --- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h +++ b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h @@ -73,7 +73,7 @@ private: [callback = std::forward(callback)]( OpBuilder &builder, Location loc, Type type, Value value, SmallVectorImpl &newValues) -> std::optional { - if (T derivedType = type.dyn_cast()) + if (T derivedType = dyn_cast(type)) return callback(builder, loc, derivedType, value, newValues); return std::nullopt; }; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index 728f952..3725cdd 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -222,7 +222,7 @@ SmallVector convertArrayToIndices(ArrayRef attrs) { SmallVector indices; indices.reserve(attrs.size()); for (Attribute attr : attrs) - indices.push_back(attr.cast().getInt()); + indices.push_back(cast(attr).getInt()); return indices; } diff --git a/mlir/include/mlir/Dialect/Quant/UniformSupport.h b/mlir/include/mlir/Dialect/Quant/UniformSupport.h index 2a26aa8..9ea6d17 100644 --- a/mlir/include/mlir/Dialect/Quant/UniformSupport.h +++ b/mlir/include/mlir/Dialect/Quant/UniformSupport.h @@ -67,7 +67,7 @@ public: static_cast(uniformType.getStorageTypeMin()), static_cast(uniformType.getStorageTypeMax()), uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) { - assert(uniformType.getExpressedType().isa()); + assert(isa(uniformType.getExpressedType())); assert(uniformType.getStorageType().isSignlessInteger()); } @@ -184,7 +184,7 @@ public: storageBitWidth(uniformType.getStorageTypeIntegralWidth()), isSigned(uniformType.isSigned()), quantizationDim(uniformType.getQuantizedDimension()) { - assert(uniformType.getExpressedType().isa()); + assert(isa(uniformType.getExpressedType())); assert(uniformType.getStorageType().isSignlessInteger()); assert(scales.size() == zeroPoints.size()); } diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h index 481c2e6..72c1da8 100644 --- 
a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -95,7 +95,7 @@ template inline RankedTensorType getRankedTensorType(T &&t) { assert(static_cast(std::forward(t)) && "getRankedTensorType got null argument"); - return std::forward(t).getType().template cast(); + return cast(std::forward(t).getType()); } /// Convenience method to abbreviate casting `getType()`. @@ -103,7 +103,7 @@ template inline MemRefType getMemRefType(T &&t) { assert(static_cast(std::forward(t)) && "getMemRefType got null argument"); - return std::forward(t).getType().template cast(); + return cast(std::forward(t).getType()); } /// Convenience method to get a sparse encoding attribute from a type. diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h index b36e40b..f425d37 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h @@ -51,7 +51,7 @@ checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, SmallVector dynTypes; SmallVector dynamicDims; for (const Value ¶m : params) { - auto paramTy = param.getType().cast(); + auto paramTy = cast(param.getType()); if (!paramTy.hasStaticShape()) dynTypes.push_back(paramTy); } diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h index cc846f2..bdacb98 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h @@ -44,7 +44,7 @@ struct ValueKnowledge { // Get the static knowledge intrinsic to `type`. static ValueKnowledge getKnowledgeFromType(Type type) { ValueKnowledge result = getPessimisticValueState(); - if (auto shapedType = type.dyn_cast()) { + if (auto shapedType = dyn_cast(type)) { if (shapedType.hasRank()) { result.hasRank = true; result.sizes.reserve(shapedType.getRank()); diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h index d362524..8f8adff 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h @@ -42,7 +42,7 @@ public: "SingleOpMatchOpTrait is only available on operations with " "MatchOpInterface"); Value operandHandle = cast(op).getOperandHandle(); - if (!operandHandle.getType().isa()) { + if (!isa(operandHandle.getType())) { return op->emitError() << "SingleOpMatchOpTrait requires the op handle " "to be of TransformHandleTypeInterface"; } @@ -82,7 +82,7 @@ public: "MatchOpInterface"); Value operandHandle = cast(op).getOperandHandle(); - if (!operandHandle.getType().isa()) { + if (!isa(operandHandle.getType())) { return op->emitError() << "SingleValueMatchOpTrait requires an operand " "of TransformValueHandleTypeInterface"; } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index e2e2354..39a86d3 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -1144,9 +1144,9 @@ mlir::transform::TransformEachOpTrait::apply( SmallVector emptyPayload; SmallVector emptyParams; for (OpResult r : this->getOperation()->getResults()) { - if (r.getType().isa()) + if (isa(r.getType())) transformResults.setParams(r, emptyParams); - else if (r.getType().isa()) + else if (isa(r.getType())) transformResults.setValues(r, 
ValueRange()); else transformResults.set(r, emptyPayload); diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h index b777611..61c929d 100644 --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -92,9 +92,8 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType()) return reshapeSrcOp.getSrc(); // Reshape of a constant can be replaced with a new constant. - if (auto elements = operands.front().dyn_cast_or_null()) { - return elements.reshape( - reshapeOp.getResult().getType().template cast()); + if (auto elements = dyn_cast_or_null(operands.front())) { + return elements.reshape(cast(reshapeOp.getResult().getType())); } return nullptr; } diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index ed25021..2edb239 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -169,12 +169,12 @@ Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, /// Returns true if `attr` has "parallel" iterator type semantics. inline bool isParallelIterator(Attribute attr) { - return attr.cast().getValue() == IteratorType::parallel; + return cast(attr).getValue() == IteratorType::parallel; } /// Returns true if `attr` has "reduction" iterator type semantics. inline bool isReductionIterator(Attribute attr) { - return attr.cast().getValue() == IteratorType::reduction; + return cast(attr).getValue() == IteratorType::reduction; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h index acfe404..4ead5ec 100644 --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -39,11 +39,11 @@ reifyResultShapes(OpBuilder &b, Operation *op, class ShapeAdaptor { public: ShapeAdaptor(Type t) { - if (auto st = t.dyn_cast()) + if (auto st = dyn_cast(t)) val = st; } ShapeAdaptor(Attribute t) { - if (auto da = t.dyn_cast()) + if (auto da = dyn_cast(t)) val = da; } ShapeAdaptor(ShapedTypeComponents *components) : val(components) {} diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index ac71b73..29e1cf5 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -244,7 +244,7 @@ struct DstValueBoundsOpInterfaceExternalModel auto dstOp = cast(op); assert(value.getDefiningOp() == dstOp); - Value tiedOperand = dstOp.getTiedOpOperand(value.cast())->get(); + Value tiedOperand = dstOp.getTiedOpOperand(cast(value))->get(); cstr.bound(value)[dim] == cstr.getExpr(tiedOperand, dim); } }; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index c31ac29..020c8ce9 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -358,7 +358,7 @@ private: return [callback = std::forward(callback)]( Type type, SmallVectorImpl &results, ArrayRef callStack) -> std::optional { - T derivedType = type.dyn_cast(); + T derivedType = dyn_cast(type); if (!derivedType) return std::nullopt; return callback(derivedType, results, callStack); @@ -380,7 +380,7 @@ private: 
return [callback = std::forward(callback)]( OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> std::optional { - if (T derivedType = resultType.dyn_cast()) + if (T derivedType = dyn_cast(resultType)) return callback(builder, derivedType, inputs, loc); return std::nullopt; }; @@ -395,8 +395,8 @@ private: wrapTypeAttributeConversion(FnT &&callback) { return [callback = std::forward(callback)]( Type type, Attribute attr) -> AttributeConversionResult { - if (T derivedType = type.dyn_cast()) { - if (A derivedAttr = attr.dyn_cast_or_null()) + if (T derivedType = dyn_cast(type)) { + if (A derivedAttr = dyn_cast_or_null(attr)) return callback(derivedType, derivedAttr); } return AttributeConversionResult::na(); diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp index 73ddd81..f205fab 100644 --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -59,11 +59,11 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch, } unsigned firstInputIndex, lastInputIndex; if (region) { - firstInputIndex = inputs[0].cast().getArgNumber(); - lastInputIndex = inputs.back().cast().getArgNumber(); + firstInputIndex = cast(inputs[0]).getArgNumber(); + lastInputIndex = cast(inputs.back()).getArgNumber(); } else { - firstInputIndex = inputs[0].cast().getResultNumber(); - lastInputIndex = inputs.back().cast().getResultNumber(); + firstInputIndex = cast(inputs[0]).getResultNumber(); + lastInputIndex = cast(inputs.back()).getResultNumber(); } if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) { output.push_back(inputValue); @@ -186,9 +186,9 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth, } --maxDepth; - if (BlockArgument arg = value.dyn_cast()) + if (BlockArgument arg = dyn_cast(value)) return collectUnderlyingAddressValues(arg, maxDepth, visited, output); - collectUnderlyingAddressValues(value.cast(), maxDepth, visited, + collectUnderlyingAddressValues(cast(value), maxDepth, visited, output); } @@ -216,10 +216,10 @@ getAllocEffectFor(Value value, Operation *&allocScopeOp) { // Try to get a memory effect interface for the parent operation. 
Operation *op; - if (BlockArgument arg = value.dyn_cast()) + if (BlockArgument arg = dyn_cast(value)) op = arg.getOwner()->getParentOp(); else - op = value.cast().getOwner(); + op = cast(value).getOwner(); MemoryEffectOpInterface interface = dyn_cast(op); if (!interface) return failure(); @@ -305,7 +305,7 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { if (rhsParentOp->isProperAncestor(lhsAllocScope)) return AliasResult::NoAlias; if (rhsParentOp == lhsAllocScope) { - BlockArgument rhsArg = rhs.dyn_cast(); + BlockArgument rhsArg = dyn_cast(rhs); if (rhsArg && rhs.getParentBlock()->isEntryBlock()) return AliasResult::NoAlias; } diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index aa079cf..c866fc6 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -94,7 +94,7 @@ void IntegerRangeAnalysis::visitOperation( })); auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { - auto result = v.dyn_cast(); + auto result = dyn_cast(v); if (!result) return; assert(llvm::is_contained(op->getResults(), result)); @@ -139,7 +139,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( })); auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) { - auto arg = v.dyn_cast(); + auto arg = dyn_cast(v); if (!arg) return; if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg)) @@ -179,7 +179,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( if (loopBound.has_value()) { if (loopBound->is()) { if (auto bound = - loopBound->get().dyn_cast_or_null()) + dyn_cast_or_null(loopBound->get())) return bound.getValue(); } else if (auto value = loopBound->dyn_cast()) { const IntegerValueRangeLattice *lattice = diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index d3bc806..629c482 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -240,7 +240,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors( if (inputs.size() != lattices.size()) { if (point.dyn_cast()) { if (!inputs.empty()) - firstIndex = inputs.front().cast().getResultNumber(); + firstIndex = cast(inputs.front()).getResultNumber(); visitNonControlFlowArgumentsImpl( branch, RegionSuccessor( @@ -248,7 +248,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors( lattices, firstIndex); } else { if (!inputs.empty()) - firstIndex = inputs.front().cast().getArgNumber(); + firstIndex = cast(inputs.front()).getArgNumber(); Region *region = point.get()->getParent(); visitNonControlFlowArgumentsImpl( branch, diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp index 7c04bb4..a8e0dae 100644 --- a/mlir/lib/Analysis/Liveness.cpp +++ b/mlir/lib/Analysis/Liveness.cpp @@ -184,7 +184,7 @@ Liveness::OperationListT Liveness::resolveLiveness(Value value) const { if (Operation *defOp = value.getDefiningOp()) currentBlock = defOp->getBlock(); else - currentBlock = value.cast().getOwner(); + currentBlock = cast(value).getOwner(); toProcess.push_back(currentBlock); visited.insert(currentBlock); @@ -280,7 +280,7 @@ void Liveness::print(raw_ostream &os) const { if (value.getDefiningOp()) os << "val_" << valueIds[value]; else { - auto blockArg = value.cast(); + auto blockArg = cast(value); os << "arg" << blockArg.getArgNumber() << "@" << blockIds[blockArg.getOwner()]; } @@ -404,7 +404,7 @@ 
LivenessBlockInfo::currentlyLiveValues(Operation *op) const { Operation *endOfLiveRange = nullptr; // If it's a live in or a block argument, then the start is the beginning // of the block. - if (isLiveIn(value) || value.isa()) + if (isLiveIn(value) || isa(value)) startOfLiveRange = &block->front(); else startOfLiveRange = block->findAncestorOpInBlock(*startOfLiveRange); diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index bcb23af..7af6a65 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -95,7 +95,7 @@ static void getBackwardSliceImpl(Operation *op, if (auto *definingOp = operand.getDefiningOp()) { if (backwardSlice->count(definingOp) == 0) getBackwardSliceImpl(definingOp, backwardSlice, filter); - } else if (auto blockArg = operand.dyn_cast()) { + } else if (auto blockArg = dyn_cast(operand)) { Block *block = blockArg.getOwner(); Operation *parentOp = block->getParentOp(); // TODO: determine whether we want to recurse backward into the other @@ -132,7 +132,7 @@ void mlir::getBackwardSlice(Value root, SetVector *backwardSlice, getBackwardSlice(definingOp, backwardSlice, filter, inclusive); return; } - Operation *bbAargOwner = root.cast().getOwner()->getParentOp(); + Operation *bbAargOwner = cast(root).getOwner()->getParentOp(); getBackwardSlice(bbAargOwner, backwardSlice, filter, inclusive); } diff --git a/mlir/lib/AsmParser/AsmParserState.cpp b/mlir/lib/AsmParser/AsmParserState.cpp index 29e1f40..e61aba5 100644 --- a/mlir/lib/AsmParser/AsmParserState.cpp +++ b/mlir/lib/AsmParser/AsmParserState.cpp @@ -73,7 +73,7 @@ void AsmParserState::Impl::resolveSymbolUses() { for (auto &it : *opAndUseMapIt.second) { symbolOps.clear(); if (failed(symbolTable.lookupSymbolIn( - opAndUseMapIt.first, it.first.cast(), symbolOps))) + opAndUseMapIt.first, cast(it.first), symbolOps))) continue; for (ArrayRef useRange : it.second) { @@ -301,7 +301,7 @@ void AsmParserState::addUses(Value value, ArrayRef locations) { } // Otherwise, this is a block argument. - BlockArgument arg = value.cast(); + BlockArgument arg = cast(value); auto existingIt = impl->blocksToIdx.find(arg.getOwner()); assert(existingIt != impl->blocksToIdx.end() && "expected valid block definition for block argument"); diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index 1a491bac..9501794 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -348,7 +348,7 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) { else if (!(type = parseType())) return nullptr; } - if (!type.isa()) + if (!isa(type)) return (emitError("floating point value not valid for specified type"), nullptr); return FloatAttr::get(type, isNegative ? -*val : *val); @@ -416,7 +416,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { return nullptr; } - if (auto floatType = type.dyn_cast()) { + if (auto floatType = dyn_cast(type)) { std::optional result; if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative, floatType.getFloatSemantics(), @@ -425,7 +425,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { return FloatAttr::get(floatType, *result); } - if (!type.isa()) + if (!isa(type)) return emitError(loc, "integer literal not valid for specified type"), nullptr; @@ -543,7 +543,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { // Check to see if we parse the literal from a hex string. 
if (hexStorage && - (eltType.isIntOrIndexOrFloat() || eltType.isa())) + (eltType.isIntOrIndexOrFloat() || isa(eltType))) return getHexAttr(loc, type); // Check that the parsed storage size has the same number of elements to the @@ -563,7 +563,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { // Handle complex types in the specific element type cases below. bool isComplex = false; - if (ComplexType complexTy = eltType.dyn_cast()) { + if (ComplexType complexTy = dyn_cast(eltType)) { eltType = complexTy.getElementType(); isComplex = true; } @@ -583,7 +583,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { return DenseElementsAttr::get(type, intValues); } // Handle floating point types. - if (FloatType floatTy = eltType.dyn_cast()) { + if (FloatType floatTy = dyn_cast(eltType)) { std::vector floatValues; if (failed(getFloatAttrElements(loc, floatTy, floatValues))) return nullptr; @@ -711,7 +711,7 @@ DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, /// Build a Dense attribute with hex data for the given type. DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) { Type elementType = type.getElementType(); - if (!elementType.isIntOrIndexOrFloat() && !elementType.isa()) { + if (!elementType.isIntOrIndexOrFloat() && !isa(elementType)) { p.emitError(loc) << "expected floating-point, integer, or complex element type, got " << elementType; @@ -904,7 +904,7 @@ ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { Token token = p.getToken(); std::optional result; - auto floatType = type.cast(); + auto floatType = cast(type); if (p.consumeIf(Token::integer)) { // Parse an integer literal as a float. if (p.parseFloatFromIntegerLiteral(result, token, isNegative, @@ -1025,7 +1025,7 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { return nullptr; } - ShapedType shapedType = attrType.dyn_cast(); + ShapedType shapedType = dyn_cast(attrType); if (!shapedType) { emitError(typeLoc, "`dense_resource` expected a shaped type"); return nullptr; @@ -1048,7 +1048,7 @@ ShapedType Parser::parseElementsLiteralType(Type type) { return nullptr; } - auto sType = type.dyn_cast(); + auto sType = dyn_cast(type); if (!sType) { emitError("elements literal must be a shaped type"); return nullptr; diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp index c98b368..2798145 100644 --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -260,7 +260,7 @@ Attribute Parser::parseExtendedAttr(Type type) { }); // Ensure that the attribute has the same type as requested. 
- auto typedAttr = attr.dyn_cast_or_null(); + auto typedAttr = dyn_cast_or_null(attr); if (type && typedAttr && typedAttr.getType() != type) { emitError("attribute type different than expected: expected ") << type << ", but got " << typedAttr.getType(); diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index ade2465..69116ef 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -1333,7 +1333,7 @@ ParseResult OperationParser::parseGenericOperationAfterOpName( auto type = parseType(); if (!type) return failure(); - auto fnType = type.dyn_cast(); + auto fnType = dyn_cast(type); if (!fnType) return mlir::emitError(typeLoc, "expected function type"); @@ -2352,7 +2352,7 @@ ParseResult OperationParser::codeCompleteSSAUse() { if (!forwardRefPlaceholders.count(result)) detailOS << result.getOwner()->getName() << ": "; } else { - detailOS << "arg #" << frontValue.cast().getArgNumber() + detailOS << "arg #" << cast(frontValue).getArgNumber() << ": "; } diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index c5e3297..749b82c 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -241,7 +241,7 @@ public: return std::nullopt; if (Attribute parsedAttr = parseAttribute(type)) { - attr = parsedAttr.cast(); + attr = cast(parsedAttr); return success(); } return failure(); diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp index 737767c..2110492 100644 --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -129,7 +129,7 @@ Type Parser::parseComplexType() { if (!elementType || parseToken(Token::greater, "expected '>' in complex type")) return nullptr; - if (!elementType.isa() && !elementType.isa()) + if (!isa(elementType) && !isa(elementType)) return emitError(elementTypeLoc, "invalid element type for complex"), nullptr; @@ -207,8 +207,8 @@ Type Parser::parseMemRefType() { if (!attr) return failure(); - if (attr.isa()) { - layout = attr.cast(); + if (isa(attr)) { + layout = cast(attr); } else if (memorySpace) { return emitError("multiple memory spaces specified in memref type"); } else { @@ -383,7 +383,7 @@ Type Parser::parseTensorType() { Attribute encoding; if (consumeIf(Token::comma)) { encoding = parseAttribute(); - if (auto v = encoding.dyn_cast_or_null()) { + if (auto v = dyn_cast_or_null(encoding)) { if (failed(v.verifyEncoding(dimensions, elementType, [&] { return emitError(); }))) return nullptr; diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index e39c568..9344ec9 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -785,7 +785,7 @@ public: Attribute baseResult; if (failed(parseAttribute(reader, baseResult))) return failure(); - if ((result = baseResult.dyn_cast())) + if ((result = dyn_cast(baseResult))) return success(); return reader.emitError("expected attribute of type: ", llvm::getTypeName(), ", but got: ", baseResult); diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp index 7f56e9a..f3a1531 100644 --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -180,7 +180,7 @@ void IRNumberingState::number(Attribute attr) { // have a registered dialect when it got created. We don't want to encode this // as the builtin OpaqueAttr, we want to encode it as if the dialect was // actually loaded. 
- if (OpaqueAttr opaqueAttr = attr.dyn_cast()) { + if (OpaqueAttr opaqueAttr = dyn_cast(attr)) { numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace()); return; } @@ -310,7 +310,7 @@ void IRNumberingState::number(Type type) { // registered dialect when it got created. We don't want to encode this as the // builtin OpaqueType, we want to encode it as if the dialect was actually // loaded. - if (OpaqueType opaqueType = type.dyn_cast()) { + if (OpaqueType opaqueType = dyn_cast(type)) { numbering->dialect = &numberDialect(opaqueType.getDialectNamespace()); return; } diff --git a/mlir/lib/CAPI/Dialect/PDL.cpp b/mlir/lib/CAPI/Dialect/PDL.cpp index 497b2cb..bd8b13c 100644 --- a/mlir/lib/CAPI/Dialect/PDL.cpp +++ b/mlir/lib/CAPI/Dialect/PDL.cpp @@ -21,7 +21,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PDL, pdl, pdl::PDLDialect) //===---------------------------------------------------------------------===// bool mlirTypeIsAPDLType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } //===---------------------------------------------------------------------===// @@ -29,7 +29,7 @@ bool mlirTypeIsAPDLType(MlirType type) { //===---------------------------------------------------------------------===// bool mlirTypeIsAPDLAttributeType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirPDLAttributeTypeGet(MlirContext ctx) { @@ -41,7 +41,7 @@ MlirType mlirPDLAttributeTypeGet(MlirContext ctx) { //===---------------------------------------------------------------------===// bool mlirTypeIsAPDLOperationType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirPDLOperationTypeGet(MlirContext ctx) { @@ -53,7 +53,7 @@ MlirType mlirPDLOperationTypeGet(MlirContext ctx) { //===---------------------------------------------------------------------===// bool mlirTypeIsAPDLRangeType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirPDLRangeTypeGet(MlirType elementType) { @@ -61,7 +61,7 @@ MlirType mlirPDLRangeTypeGet(MlirType elementType) { } MlirType mlirPDLRangeTypeGetElementType(MlirType type) { - return wrap(unwrap(type).cast().getElementType()); + return wrap(cast(unwrap(type)).getElementType()); } //===---------------------------------------------------------------------===// @@ -69,7 +69,7 @@ MlirType mlirPDLRangeTypeGetElementType(MlirType type) { //===---------------------------------------------------------------------===// bool mlirTypeIsAPDLTypeType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirPDLTypeTypeGet(MlirContext ctx) { @@ -81,7 +81,7 @@ MlirType mlirPDLTypeTypeGet(MlirContext ctx) { //===---------------------------------------------------------------------===// bool mlirTypeIsAPDLValueType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirPDLValueTypeGet(MlirContext ctx) { diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp index 065ab3e..0a7181d 100644 --- a/mlir/lib/CAPI/Dialect/Quant.cpp +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -20,7 +20,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect) //===---------------------------------------------------------------------===// bool mlirTypeIsAQuantizedType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } unsigned mlirQuantizedTypeGetSignedFlag() { @@ -40,39 +40,37 @@ int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned, } 
 MlirType mlirQuantizedTypeGetExpressedType(MlirType type) {
-  return wrap(unwrap(type).cast<quant::QuantizedType>().getExpressedType());
+  return wrap(cast<quant::QuantizedType>(unwrap(type)).getExpressedType());
 }
 
 unsigned mlirQuantizedTypeGetFlags(MlirType type) {
-  return unwrap(type).cast<quant::QuantizedType>().getFlags();
+  return cast<quant::QuantizedType>(unwrap(type)).getFlags();
 }
 
 bool mlirQuantizedTypeIsSigned(MlirType type) {
-  return unwrap(type).cast<quant::QuantizedType>().isSigned();
+  return cast<quant::QuantizedType>(unwrap(type)).isSigned();
 }
 
 MlirType mlirQuantizedTypeGetStorageType(MlirType type) {
-  return wrap(unwrap(type).cast<quant::QuantizedType>().getStorageType());
+  return wrap(cast<quant::QuantizedType>(unwrap(type)).getStorageType());
 }
 
 int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) {
-  return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMin();
+  return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMin();
 }
 
 int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) {
-  return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMax();
+  return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMax();
 }
 
 unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) {
-  return unwrap(type)
-      .cast<quant::QuantizedType>()
-      .getStorageTypeIntegralWidth();
+  return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeIntegralWidth();
 }
 
 bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,
                                                 MlirType candidate) {
-  return unwrap(type).cast<quant::QuantizedType>().isCompatibleExpressedType(
-      unwrap(candidate));
+  return cast<quant::QuantizedType>(unwrap(type))
+      .isCompatibleExpressedType(unwrap(candidate));
 }
 
 MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
@@ -81,19 +79,19 @@ MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
 
 MlirType mlirQuantizedTypeCastFromStorageType(MlirType type,
                                               MlirType candidate) {
-  return wrap(unwrap(type).cast<quant::QuantizedType>().castFromStorageType(
-      unwrap(candidate)));
+  return wrap(cast<quant::QuantizedType>(unwrap(type))
+                  .castFromStorageType(unwrap(candidate)));
 }
 
 MlirType mlirQuantizedTypeCastToStorageType(MlirType type) {
   return wrap(quant::QuantizedType::castToStorageType(
-      unwrap(type).cast<quant::QuantizedType>()));
+      cast<quant::QuantizedType>(unwrap(type))));
 }
 
 MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type,
                                                 MlirType candidate) {
-  return wrap(unwrap(type).cast<quant::QuantizedType>().castFromExpressedType(
-      unwrap(candidate)));
+  return wrap(cast<quant::QuantizedType>(unwrap(type))
+                  .castFromExpressedType(unwrap(candidate)));
 }
 
 MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
@@ -102,9 +100,8 @@ MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
 
 MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
                                                      MlirType candidate) {
-  return wrap(
-      unwrap(type).cast<quant::QuantizedType>().castExpressedToStorageType(
-          unwrap(candidate)));
+  return wrap(cast<quant::QuantizedType>(unwrap(type))
+                  .castExpressedToStorageType(unwrap(candidate)));
 }
 
 //===---------------------------------------------------------------------===//
@@ -112,7 +109,7 @@ MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
 //===---------------------------------------------------------------------===//
 
 bool mlirTypeIsAAnyQuantizedType(MlirType type) {
-  return unwrap(type).isa<quant::AnyQuantizedType>();
+  return isa<quant::AnyQuantizedType>(unwrap(type));
 }
 
 MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
@@ -128,7 +125,7 @@ MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
 //===---------------------------------------------------------------------===//
 
 bool mlirTypeIsAUniformQuantizedType(MlirType type) {
-  return unwrap(type).isa<quant::UniformQuantizedType>();
+  return isa<quant::UniformQuantizedType>(unwrap(type));
 }
 
 MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
@@ -141,15 +138,15 @@ MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
 }
 
 double mlirUniformQuantizedTypeGetScale(MlirType type) {
-  return unwrap(type).cast<quant::UniformQuantizedType>().getScale();
+  return cast<quant::UniformQuantizedType>(unwrap(type)).getScale();
 }
 
 int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) {
-  return unwrap(type).cast<quant::UniformQuantizedType>().getZeroPoint();
+  return cast<quant::UniformQuantizedType>(unwrap(type)).getZeroPoint();
 }
 
 bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
-  return unwrap(type).cast<quant::UniformQuantizedType>().isFixedPoint();
+  return cast<quant::UniformQuantizedType>(unwrap(type)).isFixedPoint();
 }
 
 //===---------------------------------------------------------------------===//
@@ -157,7 +154,7 @@ bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
 //===---------------------------------------------------------------------===//
 
 bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
-  return unwrap(type).isa<quant::UniformQuantizedPerAxisType>();
+  return isa<quant::UniformQuantizedPerAxisType>(unwrap(type));
 }
 
 MlirType mlirUniformQuantizedPerAxisTypeGet(
@@ -172,33 +169,29 @@ MlirType mlirUniformQuantizedPerAxisTypeGet(
 }
 
 intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
-  return unwrap(type)
-      .cast<quant::UniformQuantizedPerAxisType>()
+  return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
       .getScales()
       .size();
 }
 
 double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) {
-  return unwrap(type)
-      .cast<quant::UniformQuantizedPerAxisType>()
+  return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
      .getScales()[pos];
 }
 
 int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,
                                                     intptr_t pos) {
-  return unwrap(type)
-      .cast<quant::UniformQuantizedPerAxisType>()
+  return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
      .getZeroPoints()[pos];
 }
 
 int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) {
-  return unwrap(type)
-      .cast<quant::UniformQuantizedPerAxisType>()
+  return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
      .getQuantizedDimension();
 }
 
 bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
-  return unwrap(type).cast<quant::UniformQuantizedPerAxisType>().isFixedPoint();
+  return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)).isFixedPoint();
 }
 
 //===---------------------------------------------------------------------===//
@@ -206,7 +199,7 @@ bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
 //===---------------------------------------------------------------------===//
 
 bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
-  return unwrap(type).isa<quant::CalibratedQuantizedType>();
+  return isa<quant::CalibratedQuantizedType>(unwrap(type));
 }
 
 MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
@@ -216,9 +209,9 @@ MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
 }
 
 double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
-  return unwrap(type).cast<quant::CalibratedQuantizedType>().getMin();
+  return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMin();
 }
 
 double mlirCalibratedQuantizedTypeGetMax(MlirType type) {
-  return unwrap(type).cast<quant::CalibratedQuantizedType>().getMax();
+  return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMax();
 }
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 1aa6d32..795ce51 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -42,7 +42,7 @@ static_assert(
     "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
 
 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
-  return unwrap(attr).isa<SparseTensorEncodingAttr>();
+  return isa<SparseTensorEncodingAttr>(unwrap(attr));
 }
 
 MlirAttribute mlirSparseTensorEncodingAttrGet(
@@ -60,29 +60,28 @@ MlirAttribute mlirSparseTensorEncodingAttrGet(
 }
 
 MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) {
-  return wrap(unwrap(attr).cast<SparseTensorEncodingAttr>().getDimOrdering());
+  return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimOrdering());
 }
 
 MlirAffineMap
 mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr) {
-  return wrap(
-      unwrap(attr).cast<SparseTensorEncodingAttr>().getHigherOrdering());
+  return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getHigherOrdering());
 }
 
 intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
-  return unwrap(attr).cast<SparseTensorEncodingAttr>().getLvlRank();
+  return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
 }
 
 MlirSparseTensorDimLevelType
 mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl) {
   return static_cast<MlirSparseTensorDimLevelType>(
-      unwrap(attr).cast<SparseTensorEncodingAttr>().getLvlType(lvl));
+      cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
 }
 
 int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
-  return unwrap(attr).cast<SparseTensorEncodingAttr>().getPosWidth();
+  return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth();
 }
 
 int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
-  return unwrap(attr).cast<SparseTensorEncodingAttr>().getCrdWidth();
+  return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
 }
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 606b301..90594b6 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -22,7 +22,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform,
 //===---------------------------------------------------------------------===//
 
 bool mlirTypeIsATransformAnyOpType(MlirType type) {
-  return unwrap(type).isa<transform::AnyOpType>();
+  return isa<transform::AnyOpType>(unwrap(type));
 }
 
 MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
@@ -34,7 +34,7 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
 //===---------------------------------------------------------------------===//
 
 bool mlirTypeIsATransformOperationType(MlirType type) {
-  return unwrap(type).isa<transform::OperationType>();
+  return isa<transform::OperationType>(unwrap(type));
 }
 
 MlirType mlirTransformOperationTypeGet(MlirContext ctx,
@@ -44,5 +44,5 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx,
 }
 
 MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
-  return wrap(unwrap(type).cast<transform::OperationType>().getOperationName());
+  return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
 }
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index eeac499..1769b1fa 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -48,7 +48,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     Location loc = gpuOp.getLoc();
     Value memref = adaptor.getMemref();
     Value unconvertedMemref = gpuOp.getMemref();
-    MemRefType memrefType = unconvertedMemref.getType().cast<MemRefType>();
+    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
 
     if (chipset.majorVersion < 9)
       return gpuOp.emitOpError("Raw buffer ops require GCN or higher");
@@ -85,13 +85,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     // so bitcast any floats to integers.
     Type llvmBufferValType = llvmWantedDataType;
     if (atomicCmpData) {
-      if (wantedDataType.isa<VectorType>())
+      if (isa<VectorType>(wantedDataType))
         return gpuOp.emitOpError("vector compare-and-swap does not exist");
-      if (auto floatType = wantedDataType.dyn_cast<FloatType>())
+      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
         llvmBufferValType = this->getTypeConverter()->convertType(
             rewriter.getIntegerType(floatType.getWidth()));
     }
-    if (auto dataVector = wantedDataType.dyn_cast<VectorType>()) {
+    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
       uint32_t elemBits = dataVector.getElementTypeBitWidth();
       uint32_t totalBits = elemBits * dataVector.getNumElements();
       if (totalBits > maxVectorOpWidth)
@@ -312,7 +312,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
 static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
                                 Location loc, Value input) {
   Type inputType = input.getType();
-  if (auto vectorType = inputType.dyn_cast<VectorType>()) {
+  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     if (!vectorType.getElementType().isInteger(8))
       return input;
     int64_t numBytes = vectorType.getNumElements();
@@ -342,10 +342,10 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
   uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
            b = mfma.getBlocks();
   Type sourceElem = mfma.getSourceA().getType();
-  if (auto sourceType = sourceElem.dyn_cast<VectorType>())
+  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
     sourceElem = sourceType.getElementType();
   Type destElem = mfma.getDestC().getType();
-  if (auto destType = destElem.dyn_cast<VectorType>())
+  if (auto destType = dyn_cast<VectorType>(destElem))
     destElem = destType.getElementType();
 
   if (sourceElem.isF32() && destElem.isF32()) {
@@ -406,7 +406,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
     return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
   }
 
-  if (sourceElem.isa<IntegerType>() && destElem.isInteger(32)) {
+  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
     if (m == 32 && n == 32 && k == 4 && b == 2)
       return ROCDL::mfma_i32_32x32x4i8::getOperationName();
     if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -435,7 +435,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
     // Known to be correct because there are no scalar f8 instructions and
     // because a length mismatch will have been caught by the verifier.
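The comment above is the rationale for the asserting form used in the next hunk: `cast` (method or free function) asserts rather than checks, so it is reserved for cases some earlier invariant has already proven, while `dyn_cast` is the checked, failable form. A hedged sketch of the distinction with builtin types (illustrative only):

```cpp
#include "mlir/IR/BuiltinTypes.h"

// Failable form: returns a null type on mismatch, caller must test it.
static mlir::Type elementTypeIfVector(mlir::Type t) {
  if (auto vec = mlir::dyn_cast<mlir::VectorType>(t))
    return vec.getElementType();
  return {};
}

// Asserting form: only valid when `t` is already known to be a vector.
static mlir::Type elementTypeOfVector(mlir::Type t) {
  return mlir::cast<mlir::VectorType>(t).getElementType();
}
```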
Type sourceBElem = - mfma.getSourceB().getType().cast().getElementType(); + cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { if (sourceBElem.isFloat8E5M2FNUZ()) return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(); @@ -453,7 +453,7 @@ static std::optional mfmaOpToIntrinsic(MFMAOp mfma, if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset.minorVersion >= 0x40) { Type sourceBElem = - mfma.getSourceB().getType().cast().getElementType(); + cast(mfma.getSourceB().getType()).getElementType(); if (m == 16 && n == 16 && k == 32 && b == 1) { if (sourceBElem.isFloat8E5M2FNUZ()) return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName(); diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index debb7e8..783745a 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -226,7 +226,7 @@ public: Type resultType = std::get<1>(pair); std::optional reductionOp = arith::symbolizeAtomicRMWKind( - static_cast(reduction.cast().getInt())); + static_cast(cast(reduction).getInt())); assert(reductionOp && "Reduction operation cannot be of None Type"); arith::AtomicRMWKind reductionOpValue = *reductionOp; identityVals.push_back( @@ -246,7 +246,7 @@ public: // For each of the reduction operations get the respective mlir::Value. std::optional reductionOp = arith::symbolizeAtomicRMWKind( - reductions[i].cast().getInt()); + cast(reductions[i]).getInt()); assert(reductionOp && "Reduction Operation cannot be of None Type"); arith::AtomicRMWKind reductionOpValue = *reductionOp; rewriter.setInsertionPoint(&parOp.getBody()->back()); diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 4651c29..3b4b645 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -210,7 +210,7 @@ LogicalResult IndexCastOpLowering::matchAndRewrite( // Handle the scalar and 1D vector cases. Type operandType = adaptor.getIn().getType(); - if (!operandType.isa()) { + if (!isa(operandType)) { Type targetType = this->typeConverter->convertType(resultType); if (targetBits < sourceBits) rewriter.replaceOpWithNewOp(op, targetType, @@ -220,7 +220,7 @@ LogicalResult IndexCastOpLowering::matchAndRewrite( return success(); } - if (!resultType.isa()) + if (!isa(resultType)) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( @@ -255,7 +255,7 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite( Location loc = op.getLoc(); // Handle the scalar and 1D vector cases. - if (!operandType.isa()) { + if (!isa(operandType)) { Type newOverflowType = typeConverter->convertType(overflowResultType); Type structType = LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType}); @@ -269,7 +269,7 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite( return success(); } - if (!sumResultType.isa()) + if (!isa(sumResultType)) return rewriter.notifyMatchFailure(loc, "expected vector result types"); return rewriter.notifyMatchFailure(loc, @@ -295,16 +295,16 @@ LogicalResult MulIExtendedOpLowering::matchAndRewrite( // matching extended multiplication intrinsic, perform regular multiplication // on operands zero-extended to i(2*N) bits, and truncate the results back to // iN types. 
- if (!resultType.isa()) { + if (!isa(resultType)) { // Shift amount necessary to extract the high bits from widened result. TypedAttr shiftValAttr; - if (auto intTy = resultType.dyn_cast()) { + if (auto intTy = dyn_cast(resultType)) { unsigned resultBitwidth = intTy.getWidth(); auto attrTy = rewriter.getIntegerType(resultBitwidth * 2); shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth); } else { - auto vecTy = resultType.cast(); + auto vecTy = cast(resultType); unsigned resultBitwidth = vecTy.getElementTypeBitWidth(); auto attrTy = VectorType::get( vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2)); @@ -330,7 +330,7 @@ LogicalResult MulIExtendedOpLowering::matchAndRewrite( return success(); } - if (!resultType.isa()) + if (!isa(resultType)) return rewriter.notifyMatchFailure(op, "expected vector result type"); return rewriter.notifyMatchFailure(op, @@ -355,7 +355,7 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, Type resultType = op.getResult().getType(); // Handle the scalar and 1D vector cases. - if (!operandType.isa()) { + if (!isa(operandType)) { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(resultType), convertCmpPredicate(op.getPredicate()), @@ -363,7 +363,7 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, return success(); } - if (!resultType.isa()) + if (!isa(resultType)) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( @@ -389,7 +389,7 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, Type resultType = op.getResult().getType(); // Handle the scalar and 1D vector cases. - if (!operandType.isa()) { + if (!isa(operandType)) { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(resultType), convertCmpPredicate(op.getPredicate()), @@ -397,7 +397,7 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, return success(); } - if (!resultType.isa()) + if (!isa(resultType)) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index b6ed244..5d2c1f3 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -261,9 +261,9 @@ public: /// Converts the given `srcAttr` into a boolean attribute if it holds an /// integral value. Returns null attribute if conversion fails. 
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { - if (auto boolAttr = srcAttr.dyn_cast()) + if (auto boolAttr = dyn_cast(srcAttr)) return boolAttr; - if (auto intAttr = srcAttr.dyn_cast()) + if (auto intAttr = dyn_cast(srcAttr)) return builder.getBoolAttr(intAttr.getValue().getBoolValue()); return {}; } @@ -324,7 +324,7 @@ static bool isBoolScalarOrVector(Type type) { if (type.isInteger(1)) return true; - if (auto vecType = type.dyn_cast()) + if (auto vecType = dyn_cast(type)) return vecType.getElementType().isInteger(1); return false; @@ -337,7 +337,7 @@ static bool hasSameBitwidth(Type a, Type b) { unsigned bw = 0; if (type.isIntOrFloat()) bw = type.getIntOrFloatBitWidth(); - else if (auto vecType = type.dyn_cast()) + else if (auto vecType = dyn_cast(type)) bw = vecType.getElementTypeBitWidth() * vecType.getNumElements(); return bw; }; @@ -369,18 +369,18 @@ getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) { LogicalResult ConstantCompositeOpPattern::matchAndRewrite( arith::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto srcType = constOp.getType().dyn_cast(); + auto srcType = dyn_cast(constOp.getType()); if (!srcType || srcType.getNumElements() == 1) return failure(); // arith.constant should only have vector or tenor types. - assert((srcType.isa())); + assert((isa(srcType))); Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) return failure(); - auto dstElementsAttr = constOp.getValue().dyn_cast(); + auto dstElementsAttr = dyn_cast(constOp.getValue()); if (!dstElementsAttr) return failure(); @@ -388,7 +388,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite( // If the composite type has more than one dimensions, perform linearization. if (srcType.getRank() > 1) { - if (srcType.isa()) { + if (isa(srcType)) { dstAttrType = RankedTensorType::get(srcType.getNumElements(), srcType.getElementType()); dstElementsAttr = dstElementsAttr.reshape(dstAttrType); @@ -402,19 +402,19 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite( Type dstElemType; // Tensor types are converted to SPIR-V array types; vector types are // converted to SPIR-V vector/array types. - if (auto arrayType = dstType.dyn_cast()) + if (auto arrayType = dyn_cast(dstType)) dstElemType = arrayType.getElementType(); else - dstElemType = dstType.cast().getElementType(); + dstElemType = cast(dstType).getElementType(); // If the source and destination element types are different, perform // attribute conversion. if (srcElemType != dstElemType) { SmallVector elements; - if (srcElemType.isa()) { + if (isa(srcElemType)) { for (FloatAttr srcAttr : dstElementsAttr.getValues()) { FloatAttr dstAttr = - convertFloatAttr(srcAttr, dstElemType.cast(), rewriter); + convertFloatAttr(srcAttr, cast(dstElemType), rewriter); if (!dstAttr) return failure(); elements.push_back(dstAttr); @@ -424,7 +424,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite( } else { for (IntegerAttr srcAttr : dstElementsAttr.getValues()) { IntegerAttr dstAttr = convertIntegerAttr( - srcAttr, dstElemType.cast(), rewriter); + srcAttr, cast(dstElemType), rewriter); if (!dstAttr) return failure(); elements.push_back(dstAttr); @@ -435,7 +435,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite( // attributes; element attributes only works with builtin types. So we need // to prepare another converted builtin types for the destination elements // attribute. 
- if (dstAttrType.isa()) + if (isa(dstAttrType)) dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); else dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); @@ -456,7 +456,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite( arith::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type srcType = constOp.getType(); - if (auto shapedType = srcType.dyn_cast()) { + if (auto shapedType = dyn_cast(srcType)) { if (shapedType.getNumElements() != 1) return failure(); srcType = shapedType.getElementType(); @@ -465,7 +465,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite( return failure(); Attribute cstAttr = constOp.getValue(); - if (auto elementsAttr = cstAttr.dyn_cast()) + if (auto elementsAttr = dyn_cast(cstAttr)) cstAttr = elementsAttr.getSplatValue(); Type dstType = getTypeConverter()->convertType(srcType); @@ -473,14 +473,14 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite( return failure(); // Floating-point types. - if (srcType.isa()) { - auto srcAttr = cstAttr.cast(); + if (isa(srcType)) { + auto srcAttr = cast(cstAttr); auto dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. if (srcType != dstType) { - dstAttr = convertFloatAttr(srcAttr, dstType.cast(), rewriter); + dstAttr = convertFloatAttr(srcAttr, cast(dstType), rewriter); if (!dstAttr) return failure(); } @@ -502,9 +502,9 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite( // IndexType or IntegerType. Index values are converted to 32-bit integer // values when converting to SPIR-V. - auto srcAttr = cstAttr.cast(); + auto srcAttr = cast(cstAttr); IntegerAttr dstAttr = - convertIntegerAttr(srcAttr, dstType.cast(), rewriter); + convertIntegerAttr(srcAttr, cast(dstType), rewriter); if (!dstAttr) return failure(); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); @@ -678,12 +678,12 @@ ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, return getTypeConversionFailure(rewriter, op); Value allOnes; - if (auto intTy = dstType.dyn_cast()) { + if (auto intTy = dyn_cast(dstType)) { unsigned componentBitwidth = intTy.getWidth(); allOnes = rewriter.create( loc, intTy, rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); - } else if (auto vectorTy = dstType.dyn_cast()) { + } else if (auto vectorTy = dyn_cast(dstType)) { unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); allOnes = rewriter.create( loc, vectorTy, @@ -810,7 +810,7 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite( // There are no direct corresponding instructions in SPIR-V for such cases. // Extend them to 32-bit and do comparision then. 
Type type = rewriter.getI32Type(); - if (auto vectorType = dstType.dyn_cast()) + if (auto vectorType = dyn_cast(dstType)) type = VectorType::get(vectorType.getShape(), type); Value extLhs = rewriter.create(op.getLoc(), type, adaptor.getLhs()); diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp index 691cd23..bdbf276 100644 --- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp +++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp @@ -33,8 +33,8 @@ public: /// arm.neon.intr.sdot LogicalResult matchAndRewrite(Sdot2dOp op, PatternRewriter &rewriter) const override { - Type elemType = op.getB().getType().cast().getElementType(); - int length = op.getB().getType().cast().getShape()[0] * + Type elemType = cast(op.getB().getType()).getElementType(); + int length = cast(op.getB().getType()).getShape()[0] * Sdot2dOp::kReductionSize; VectorType flattenedVectorType = VectorType::get({length}, elemType); Value b2d = op.getB(); diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 38041bd..d1998cf 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -366,12 +366,12 @@ public: static std::optional convertAsyncTypes(Type type, bool useOpaquePointers) { - if (type.isa()) + if (isa(type)) return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers); - if (type.isa()) + if (isa(type)) return AsyncAPI::tokenType(type.getContext()); - if (type.isa()) + if (isa(type)) return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers); return std::nullopt; @@ -656,14 +656,14 @@ public: Type resultType = op->getResultTypes()[0]; // Tokens creation maps to a simple function call. - if (resultType.isa()) { + if (isa(resultType)) { rewriter.replaceOpWithNewOp( op, kCreateToken, converter->convertType(resultType)); return success(); } // To create a value we need to compute the storage requirement. - if (auto value = resultType.dyn_cast()) { + if (auto value = dyn_cast(resultType)) { // Returns the size requirements for the async value storage. auto sizeOf = [&](ValueType valueType) -> Value { auto loc = op->getLoc(); @@ -994,7 +994,7 @@ public: matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Currently we can only add tokens to the group. - if (!op.getOperand().getType().isa()) + if (!isa(op.getOperand().getType())) return rewriter.notifyMatchFailure(op, "only token type is supported"); // Replace with a runtime API function call. diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp index 5f64981..f498d2c 100644 --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -41,11 +41,11 @@ struct CloneOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { // Check for unranked memref types which are currently not supported. 
Type type = op.getType(); - if (type.isa()) { + if (isa(type)) { return rewriter.notifyMatchFailure( op, "UnrankedMemRefType is not supported."); } - MemRefType memrefType = type.cast(); + MemRefType memrefType = cast(type); MemRefLayoutAttrInterface layout; auto allocType = MemRefType::get(memrefType.getShape(), memrefType.getElementType(), diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp index d35165fe..3b83386 100644 --- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp +++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp @@ -26,9 +26,9 @@ namespace { // result type. struct ComplexTypeResolver { std::optional operator()(Type type) const { - auto complexType = type.cast(); + auto complexType = cast(type); auto elementType = complexType.getElementType(); - if (!elementType.isa()) + if (!isa(elementType)) return {}; return elementType.getIntOrFloatBitWidth() == 64; @@ -39,8 +39,8 @@ struct ComplexTypeResolver { // type. struct FloatTypeResolver { std::optional operator()(Type type) const { - auto elementType = type.cast(); - if (!elementType.isa()) + auto elementType = cast(type); + if (!isa(elementType)) return {}; return elementType.getIntOrFloatBitWidth() == 64; diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 5364976..9c05cad 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -57,7 +57,7 @@ struct Atan2OpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto type = op.getType().cast(); + auto type = cast(op.getType()); Type elementType = type.getElementType(); Value lhs = adaptor.getLhs(); @@ -102,10 +102,7 @@ struct ComparisonOpConversion : public OpConversionPattern { matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto type = adaptor.getLhs() - .getType() - .template cast() - .getElementType(); + auto type = cast(adaptor.getLhs().getType()).getElementType(); Value realLhs = rewriter.create(loc, type, adaptor.getLhs()); Value imagLhs = rewriter.create(loc, type, adaptor.getLhs()); @@ -132,8 +129,8 @@ struct BinaryComplexOpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto type = adaptor.getLhs().getType().template cast(); - auto elementType = type.getElementType().template cast(); + auto type = cast(adaptor.getLhs().getType()); + auto elementType = cast(type.getElementType()); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value realLhs = b.create(elementType, adaptor.getLhs()); @@ -160,8 +157,8 @@ struct TrigonometricOpConversion : public OpConversionPattern { matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto type = adaptor.getComplex().getType().template cast(); - auto elementType = type.getElementType().template cast(); + auto type = cast(adaptor.getComplex().getType()); + auto elementType = cast(type.getElementType()); Value real = rewriter.create(loc, elementType, adaptor.getComplex()); @@ -222,8 +219,8 @@ struct DivOpConversion : public OpConversionPattern { 
matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto type = adaptor.getLhs().getType().cast(); - auto elementType = type.getElementType().cast(); + auto type = cast(adaptor.getLhs().getType()); + auto elementType = cast(type.getElementType()); Value lhsReal = rewriter.create(loc, elementType, adaptor.getLhs()); @@ -441,8 +438,8 @@ struct ExpOpConversion : public OpConversionPattern { matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto type = adaptor.getComplex().getType().cast(); - auto elementType = type.getElementType().cast(); + auto type = cast(adaptor.getComplex().getType()); + auto elementType = cast(type.getElementType()); Value real = rewriter.create(loc, elementType, adaptor.getComplex()); @@ -466,8 +463,8 @@ struct Expm1OpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto type = adaptor.getComplex().getType().cast(); - auto elementType = type.getElementType().cast(); + auto type = cast(adaptor.getComplex().getType()); + auto elementType = cast(type.getElementType()); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value exp = b.create(adaptor.getComplex()); @@ -490,8 +487,8 @@ struct LogOpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto type = adaptor.getComplex().getType().cast(); - auto elementType = type.getElementType().cast(); + auto type = cast(adaptor.getComplex().getType()); + auto elementType = cast(type.getElementType()); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value abs = b.create(elementType, adaptor.getComplex()); @@ -511,8 +508,8 @@ struct Log1pOpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto type = adaptor.getComplex().getType().cast(); - auto elementType = type.getElementType().cast(); + auto type = cast(adaptor.getComplex().getType()); + auto elementType = cast(type.getElementType()); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value real = b.create(elementType, adaptor.getComplex()); @@ -550,8 +547,8 @@ struct MulOpConversion : public OpConversionPattern { matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto type = adaptor.getLhs().getType().cast(); - auto elementType = type.getElementType().cast(); + auto type = cast(adaptor.getLhs().getType()); + auto elementType = cast(type.getElementType()); Value lhsReal = b.create(elementType, adaptor.getLhs()); Value lhsRealAbs = b.create(lhsReal); @@ -727,8 +724,8 @@ struct NegOpConversion : public OpConversionPattern { matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto type = adaptor.getComplex().getType().cast(); - auto elementType = type.getElementType().cast(); + auto type = cast(adaptor.getComplex().getType()); + auto elementType = cast(type.getElementType()); Value real = rewriter.create(loc, elementType, adaptor.getComplex()); @@ -773,7 +770,7 @@ struct SqrtOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { 
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto type = op.getType().cast(); + auto type = cast(op.getType()); Type elementType = type.getElementType(); Value arg = adaptor.getComplex(); @@ -837,8 +834,8 @@ struct SignOpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto type = adaptor.getComplex().getType().cast(); - auto elementType = type.getElementType().cast(); + auto type = cast(adaptor.getComplex().getType()); + auto elementType = cast(type.getElementType()); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value real = b.create(elementType, adaptor.getComplex()); @@ -881,8 +878,8 @@ struct TanhOpConversion : public OpConversionPattern { matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto type = adaptor.getComplex().getType().cast(); - auto elementType = type.getElementType().cast(); + auto type = cast(adaptor.getComplex().getType()); + auto elementType = cast(type.getElementType()); // The hyperbolic tangent for complex number can be calculated as follows. // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y)) @@ -913,8 +910,8 @@ struct ConjOpConversion : public OpConversionPattern { matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto type = adaptor.getComplex().getType().cast(); - auto elementType = type.getElementType().cast(); + auto type = cast(adaptor.getComplex().getType()); + auto elementType = cast(type.getElementType()); Value real = rewriter.create(loc, elementType, adaptor.getComplex()); Value imag = @@ -933,7 +930,7 @@ struct ConjOpConversion : public OpConversionPattern { static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, ComplexType type, Value a, Value b, Value c, Value d) { - auto elementType = type.getElementType().cast(); + auto elementType = cast(type.getElementType()); // Compute (a*a+b*b)^(0.5c). 
Value aaPbb = builder.create( @@ -995,8 +992,8 @@ struct PowOpConversion : public OpConversionPattern { matchAndRewrite(complex::PowOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); - auto type = adaptor.getLhs().getType().cast(); - auto elementType = type.getElementType().cast(); + auto type = cast(adaptor.getLhs().getType()); + auto elementType = cast(type.getElementType()); Value a = builder.create(elementType, adaptor.getLhs()); Value b = builder.create(elementType, adaptor.getLhs()); @@ -1015,8 +1012,8 @@ struct RsqrtOpConversion : public OpConversionPattern { matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter); - auto type = adaptor.getComplex().getType().cast(); - auto elementType = type.getElementType().cast(); + auto type = cast(adaptor.getComplex().getType()); + auto elementType = cast(type.getElementType()); Value a = builder.create(elementType, adaptor.getComplex()); Value b = builder.create(elementType, adaptor.getComplex()); diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 4fd5b9c..5867d9f 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -144,13 +144,13 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, size_t argOffset = resultStructType ? 1 : 0; for (auto [index, argType] : llvm::enumerate(type.getInputs())) { Value arg = wrapperFuncOp.getArgument(index + argOffset); - if (auto memrefType = argType.dyn_cast()) { + if (auto memrefType = dyn_cast(argType)) { Value loaded = rewriter.create( loc, typeConverter.convertType(memrefType), arg); MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); continue; } - if (argType.isa()) { + if (isa(argType)) { Value loaded = rewriter.create( loc, typeConverter.convertType(argType), arg); UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); @@ -218,8 +218,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, if (resultStructType) { // Allocate the struct on the stack and pass the pointer. - Type resultType = - wrapperType.cast().getParamType(0); + Type resultType = cast(wrapperType).getParamType(0); Value one = builder.create( loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); @@ -233,8 +232,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, for (Type input : type.getInputs()) { Value arg; int numToDrop = 1; - auto memRefType = input.dyn_cast(); - auto unrankedMemRefType = input.dyn_cast(); + auto memRefType = dyn_cast(input); + auto unrankedMemRefType = dyn_cast(input); if (memRefType || unrankedMemRefType) { numToDrop = memRefType ? MemRefDescriptor::getNumUnpackedValues(memRefType) @@ -301,9 +300,9 @@ static void modifyFuncOpToUseBarePtrCallingConv( // Unranked memrefs are not supported in the bare pointer calling // convention. We should have bailed out before in the presence of // unranked memrefs. 
- assert(!argTy.isa() && + assert(!isa(argTy) && "Unranked memref is not supported"); - auto memrefTy = argTy.dyn_cast(); + auto memrefTy = dyn_cast(argTy); if (!memrefTy) continue; @@ -360,18 +359,18 @@ protected: } if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { SmallVector newArgAttrs( - llvmType.cast().getNumParams()); + cast(llvmType).getNumParams()); for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { // Some LLVM IR attribute have a type attached to them. During FuncOp -> // LLVMFuncOp conversion these types may have changed. Account for that // change by converting attributes' types as well. SmallVector convertedAttrs; - auto attrsDict = argAttrDicts[i].cast(); + auto attrsDict = cast(argAttrDicts[i]); convertedAttrs.reserve(attrsDict.size()); for (const NamedAttribute &attr : attrsDict) { const auto convert = [&](const NamedAttribute &attr) { return TypeAttr::get(getTypeConverter()->convertType( - attr.getValue().cast().getValue())); + cast(attr.getValue()).getValue())); }; if (attr.getName().getValue() == LLVM::LLVMDialect::getByValAttrName()) { @@ -418,7 +417,7 @@ protected: LLVM::Linkage linkage = LLVM::Linkage::External; if (funcOp->hasAttr(linkageAttrName)) { auto attr = - funcOp->getAttr(linkageAttrName).dyn_cast(); + dyn_cast(funcOp->getAttr(linkageAttrName)); if (!attr) { funcOp->emitError() << "Contains " << linkageAttrName << " attribute not of type LLVM::LinkageAttr"; @@ -545,7 +544,7 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { if (useBarePtrCallConv) { for (auto it : callOp->getOperands()) { Type operandType = it.getType(); - if (operandType.isa()) { + if (isa(operandType)) { // Unranked memref is not supported in the bare pointer calling // convention. return failure(); @@ -669,11 +668,11 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { Type oldTy = std::get<0>(it).getType(); Value newOperand = std::get<1>(it); - if (oldTy.isa() && getTypeConverter()->canConvertToBarePtr( - oldTy.cast())) { + if (isa(oldTy) && getTypeConverter()->canConvertToBarePtr( + cast(oldTy))) { MemRefDescriptor memrefDesc(newOperand); newOperand = memrefDesc.allocatedPtr(rewriter, loc); - } else if (oldTy.isa()) { + } else if (isa(oldTy)) { // Unranked memref is not supported in the bare pointer calling // convention. 
return failure(); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index a13acc6..664d077 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -26,22 +26,20 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { BlockArgument attribution = en.value(); - auto type = attribution.getType().dyn_cast(); + auto type = dyn_cast(attribution.getType()); assert(type && type.hasStaticShape() && "unexpected type in attribution"); uint64_t numElements = type.getNumElements(); auto elementType = - typeConverter->convertType(type.getElementType()).template cast(); + cast(typeConverter->convertType(type.getElementType())); auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); std::string name = std::string( llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index())); uint64_t alignment = 0; if (auto alignAttr = - gpuFuncOp - .getWorkgroupAttributionAttr( - en.index(), LLVM::LLVMDialect::getAlignAttrName()) - .dyn_cast_or_null()) + dyn_cast_or_null(gpuFuncOp.getWorkgroupAttributionAttr( + en.index(), LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); auto globalOp = rewriter.create( gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, @@ -100,7 +98,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, global.getAddrSpace()), global.getSymNameAttr()); auto elementType = - global.getType().cast().getElementType(); + cast(global.getType()).getElementType(); Value memory = rewriter.create( loc, getTypeConverter()->getPointerType(elementType, @@ -112,7 +110,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, // otherwise necessary given that memref sizes are fixed, but we can try // and canonicalize that away later. 
Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; - auto type = attribution.getType().cast(); + auto type = cast(attribution.getType()); auto descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, memory); signatureConversion.remapInput(numProperArguments + en.index(), descr); @@ -123,7 +121,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, auto int64Ty = IntegerType::get(rewriter.getContext(), 64); for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { Value attribution = en.value(); - auto type = attribution.getType().cast(); + auto type = cast(attribution.getType()); assert(type && type.hasStaticShape() && "unexpected type in attribution"); // Explicitly drop memory space when lowering private memory @@ -136,10 +134,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); uint64_t alignment = 0; if (auto alignAttr = - gpuFuncOp - .getPrivateAttributionAttr( - en.index(), LLVM::LLVMDialect::getAlignAttrName()) - .dyn_cast_or_null()) + dyn_cast_or_null(gpuFuncOp.getPrivateAttributionAttr( + en.index(), LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); Value allocated = rewriter.create( gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); @@ -164,7 +160,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front()); for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) { - auto memrefTy = en.value().dyn_cast(); + auto memrefTy = dyn_cast(en.value()); if (!memrefTy) continue; assert(memrefTy.hasStaticShape() && @@ -302,7 +298,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( rewriter.create(loc, llvmI32, numArgsThisCall)); for (size_t i = group; i < bound; ++i) { Value arg = adaptor.getArgs()[i]; - if (auto floatType = arg.getType().dyn_cast()) { + if (auto floatType = dyn_cast(arg.getType())) { if (!floatType.isF64()) arg = rewriter.create( loc, typeConverter->convertType(rewriter.getF64Type()), arg); @@ -428,7 +424,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( Type type = arg.getType(); Value promotedArg = arg; assert(type.isIntOrFloat()); - if (type.isa()) { + if (isa(type)) { type = rewriter.getF64Type(); promotedArg = rewriter.create(loc, type, arg); } @@ -462,14 +458,14 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, LLVMTypeConverter &converter) { TypeRange operandTypes(operands); if (llvm::none_of(operandTypes, - [](Type type) { return type.isa(); })) { + [](Type type) { return isa(type); })) { return rewriter.notifyMatchFailure(op, "expected vector operand"); } if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0) return rewriter.notifyMatchFailure(op, "expected no region/successor"); if (op->getNumResults() != 1) return rewriter.notifyMatchFailure(op, "expected single result"); - VectorType vectorType = op->getResult(0).getType().dyn_cast(); + VectorType vectorType = dyn_cast(op->getResult(0).getType()); if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result"); @@ -482,7 +478,7 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, for (int64_t i = 0; i < vectorType.getNumElements(); ++i) { Value index = rewriter.create(loc, indexType, i); auto extractElement = [&](Value operand) -> Value { - if 
(!operand.getType().isa()) + if (!isa(operand.getType())) return operand; return rewriter.create(loc, operand, index); }; diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 3687bd6..43dff49 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -454,7 +454,7 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite( Location loc = op->getLoc(); auto memRefType = hostRegisterOp.getValue().getType(); - auto elementType = memRefType.cast().getElementType(); + auto elementType = cast(memRefType).getElementType(); auto elementSize = getSizeInBytes(loc, elementType, rewriter); auto arguments = getTypeConverter()->promoteOperands( @@ -476,7 +476,7 @@ LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite( Location loc = op->getLoc(); auto memRefType = hostUnregisterOp.getValue().getType(); - auto elementType = memRefType.cast().getElementType(); + auto elementType = cast(memRefType).getElementType(); auto elementSize = getSizeInBytes(loc, elementType, rewriter); auto arguments = getTypeConverter()->promoteOperands( @@ -555,7 +555,7 @@ LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite( } static bool isGpuAsyncTokenType(Value value) { - return value.getType().isa(); + return isa(value.getType()); } // Converts !gpu.async.token operands of `async.yield` to runtime calls. The @@ -591,7 +591,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite( // Returns whether `value` is the result of an LLVM::CallOp to `functionName`. static bool isDefinedByCallTo(Value value, StringRef functionName) { - assert(value.getType().isa()); + assert(isa(value.getType())); if (auto defOp = value.getDefiningOp()) return defOp.getCallee()->equals(functionName); return false; @@ -862,7 +862,7 @@ static Value bitAndAddrspaceCast(Location loc, LLVM::LLVMPointerType destinationType, Value sourcePtr, LLVMTypeConverter &typeConverter) { - auto sourceTy = sourcePtr.getType().cast(); + auto sourceTy = cast(sourcePtr.getType()); if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) sourcePtr = rewriter.create( loc, @@ -879,7 +879,7 @@ static Value bitAndAddrspaceCast(Location loc, LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto memRefType = memcpyOp.getSrc().getType().cast(); + auto memRefType = cast(memcpyOp.getSrc().getType()); if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) || !isConvertibleAndHasIdentityMaps(memRefType) || @@ -919,7 +919,7 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( gpu::MemsetOp memsetOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto memRefType = memsetOp.getDst().getType().cast(); + auto memRefType = cast(memsetOp.getDst().getType()); if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) || !isConvertibleAndHasIdentityMaps(memRefType) || diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 6858569..ebce2d7 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -54,8 +54,8 @@ public: Type resultType = castedOperands.front().getType(); 
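One more shape worth noting before the f32/f64 dispatch below: `isa` also accepts several candidate types at once, which can collapse chained checks. A sketch, illustrative only:

```cpp
#include "mlir/IR/BuiltinTypes.h"

static bool isF32OrF64(mlir::Type type) {
  // Equivalent to isa<Float32Type>(type) || isa<Float64Type>(type).
  return mlir::isa<mlir::Float32Type, mlir::Float64Type>(type);
}
```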
Type funcType = getFunctionType(resultType, castedOperands); - StringRef funcName = getFunctionName( - funcType.cast().getReturnType()); + StringRef funcName = + getFunctionName(cast(funcType).getReturnType()); if (funcName.empty()) return failure(); @@ -78,7 +78,7 @@ public: private: Value maybeCast(Value operand, PatternRewriter &rewriter) const { Type type = operand.getType(); - if (!type.isa()) + if (!isa(type)) return operand; return rewriter.create( @@ -91,9 +91,9 @@ private: } StringRef getFunctionName(Type type) const { - if (type.isa()) + if (isa(type)) return f32Func; - if (type.isa()) + if (isa(type)) return f64Func; return ""; } diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index bf5be54..775dd1e 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -91,7 +91,7 @@ struct WmmaLoadOpToNVVMLowering ? NVVM::MMALayout::col : NVVM::MMALayout::row; gpu::MMAMatrixType retType = - subgroupMmaLoadMatrixOp.getRes().getType().cast(); + cast(subgroupMmaLoadMatrixOp.getRes().getType()); ArrayRef retTypeShape = retType.getShape(); int64_t m = 0; int64_t n = 0; @@ -122,8 +122,7 @@ struct WmmaLoadOpToNVVMLowering // Create nvvm.mma_load op according to the operand types. Value dataPtr = getStridedElementPtr( - loc, - subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast(), + loc, cast(subgroupMmaLoadMatrixOp.getSrcMemref().getType()), adaptor.getSrcMemref(), adaptor.getIndices(), rewriter); Value leadingDim = rewriter.create( @@ -158,7 +157,7 @@ struct WmmaStoreOpToNVVMLowering // Get the shape of the MMAMatrix type being stored. The shape will // choose which intrinsic this op will be lowered to. gpu::MMAMatrixType srcType = - subgroupMmaStoreMatrixOp.getSrc().getType().cast(); + cast(subgroupMmaStoreMatrixOp.getSrc().getType()); ArrayRef srcTypeShape = srcType.getShape(); NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose() ? NVVM::MMALayout::col @@ -170,7 +169,7 @@ struct WmmaStoreOpToNVVMLowering if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0) return rewriter.notifyMatchFailure(op, kInvalidCaseStr); - auto matrixType = adaptor.getSrc().getType().cast(); + auto matrixType = cast(adaptor.getSrc().getType()); for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) { Value toUse = rewriter.create(loc, adaptor.getSrc(), i); @@ -179,7 +178,7 @@ struct WmmaStoreOpToNVVMLowering Value dataPtr = getStridedElementPtr( loc, - subgroupMmaStoreMatrixOp.getDstMemref().getType().cast(), + cast(subgroupMmaStoreMatrixOp.getDstMemref().getType()), adaptor.getDstMemref(), adaptor.getIndices(), rewriter); Value leadingDim = rewriter.create( loc, rewriter.getI32Type(), @@ -214,7 +213,7 @@ struct WmmaMmaOpToNVVMLowering SmallVector unpackedOps; auto unpackOp = [&](Value operand) { - auto structType = operand.getType().cast(); + auto structType = cast(operand.getType()); for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) { Value toUse = rewriter.create(loc, operand, i); unpackedOps.push_back(toUse); @@ -224,10 +223,10 @@ struct WmmaMmaOpToNVVMLowering // Get the shapes of the MMAMatrix type being used. The shapes will // choose which intrinsic this op will be lowered to. 
gpu::MMAMatrixType aType = - subgroupMmaComputeOp.getOpA().getType().cast(); + cast(subgroupMmaComputeOp.getOpA().getType()); ArrayRef aTypeShape = aType.getShape(); gpu::MMAMatrixType cType = - subgroupMmaComputeOp.getOpC().getType().cast(); + cast(subgroupMmaComputeOp.getOpC().getType()); ArrayRef cTypeShape = cType.getShape(); int64_t m = cTypeShape[0]; int64_t n = cTypeShape[1]; @@ -245,7 +244,7 @@ struct WmmaMmaOpToNVVMLowering return rewriter.notifyMatchFailure(op, kInvalidCaseStr); NVVM::MMATypes bElementType = getElementType( - subgroupMmaComputeOp.getOpB().getType().cast()); + cast(subgroupMmaComputeOp.getOpB().getType())); if (bElementType != sourceType) return rewriter.notifyMatchFailure( op, "WMMA compute op input matrix element types must match."); @@ -277,9 +276,9 @@ struct WmmaConstantOpToNVVMLowering Location loc = subgroupMmaConstantOp.getLoc(); Value cst = adaptor.getOperands()[0]; LLVM::LLVMStructType type = convertMMAToLLVMType( - subgroupMmaConstantOp.getType().cast()); + cast(subgroupMmaConstantOp.getType())); // If the element type is a vector create a vector from the operand. - if (auto vecType = type.getBody()[0].dyn_cast()) { + if (auto vecType = dyn_cast(type.getBody()[0])) { Value vecCst = rewriter.create(loc, vecType); for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) { Value idx = rewriter.create( @@ -301,9 +300,9 @@ struct WmmaConstantOpToNVVMLowering static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, Value rhs, bool isMin) { - auto floatType = getElementTypeOrSelf(lhs.getType()).cast(); + auto floatType = cast(getElementTypeOrSelf(lhs.getType())); Type i1Type = builder.getI1Type(); - if (auto vecType = lhs.getType().dyn_cast()) + if (auto vecType = dyn_cast(lhs.getType())) i1Type = VectorType::get(vecType.getShape(), i1Type); Value cmp = builder.create( loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, @@ -355,7 +354,7 @@ struct WmmaElementwiseOpToNVVMLowering Location loc = subgroupMmaElementwiseOp.getLoc(); size_t numOperands = adaptor.getOperands().size(); LLVM::LLVMStructType destType = convertMMAToLLVMType( - subgroupMmaElementwiseOp.getType().cast()); + cast(subgroupMmaElementwiseOp.getType())); Value matrixStruct = rewriter.create(loc, destType); for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) { SmallVector extractedOperands; diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 98f90a3..1ac4e8e 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -54,7 +54,7 @@ using namespace mlir; static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { bool canBeBare = true; for (Type type : func.getArgumentTypes()) - if (auto memrefTy = type.dyn_cast()) + if (auto memrefTy = dyn_cast(type)) canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy); return canBeBare; } @@ -166,9 +166,8 @@ struct LowerGpuOpsToROCDLOpsPass // Manually rewrite known block size attributes so the LLVMIR translation // infrastructure can pick them up. 
     m.walk([ctx](LLVM::LLVMFuncOp op) {
-      if (auto blockSizes =
-              op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName())
-                  .dyn_cast_or_null<DenseI32ArrayAttr>()) {
+      if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
+              op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
         op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
                     blockSizes);
         // Also set up the rocdl.flat_work_group_size attribute to prevent
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index becb28e..feea1e3 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -495,9 +495,9 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
   Type type = arg.getType();
   using MembptrT = FuncT OpHandler::*;
   MembptrT handlerPtr;
-  if (type.isa<FloatType>()) {
+  if (isa<FloatType>(type)) {
     handlerPtr = &OpHandler::floatFunc;
-  } else if (type.isa<IntegerType>()) {
+  } else if (isa<IntegerType>(type)) {
     handlerPtr = &OpHandler::intFunc;
   } else {
     return std::nullopt;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index f7e1356..d64fa6a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -81,9 +81,9 @@ struct WmmaLoadOpToSPIRVLowering
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = subgroupMmaLoadMatrixOp->getLoc();
     gpu::MMAMatrixType retType =
-        subgroupMmaLoadMatrixOp.getRes().getType().cast<gpu::MMAMatrixType>();
+        cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
     auto memrefType =
-        subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast<MemRefType>();
+        cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType());
     Value bufferPtr = spirv::getElementPtr(
         *getTypeConverter<SPIRVTypeConverter>(), memrefType,
         adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
@@ -114,7 +114,7 @@ struct WmmaStoreOpToSPIRVLowering
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = subgroupMmaStoreMatrixOp->getLoc();
     auto memrefType =
-        subgroupMmaStoreMatrixOp.getDstMemref().getType().cast<MemRefType>();
+        cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType());
     Value bufferPtr = spirv::getElementPtr(
         *getTypeConverter<SPIRVTypeConverter>(), memrefType,
         adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter);
@@ -161,7 +161,7 @@ struct WmmaConstantOpToSPIRVLowering
                   ConversionPatternRewriter &rewriter) const override {
     Value cst = adaptor.getOperands()[0];
     auto coopType = convertMMAToSPIRVType(
-        subgroupMmaConstantMatrixOp.getType().cast<gpu::MMAMatrixType>());
+        cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
         subgroupMmaConstantMatrixOp, coopType, cst);
     return success();
@@ -180,11 +180,11 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
                   ConversionPatternRewriter &rewriter) const override {
     // All operands should be of cooperative matrix types.
     for (Value operand : adaptor.getOperands()) {
-      if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+      if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
         return failure();
     }
     auto coopType = convertMMAToSPIRVType(
-        elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+        cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
     return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
                                        adaptor.getOperands()));
   }
@@ -204,7 +204,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
       return failure();
 
     // All operands should be of cooperative matrix types.
     for (Value operand : adaptor.getOperands()) {
-      if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+      if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
         return failure();
     }
 
@@ -236,7 +236,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
       scalar = cc.getConstituents().front();
 
     auto coopType = convertMMAToSPIRVType(
-        elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+        cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
     rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
         elementwiseOp, coopType, ValueRange{matrix, scalar});
     return success();
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index e4ac642..2d22516 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -61,7 +61,7 @@ private:
 
   /// Checks where the given type is supported by Vulkan runtime.
   bool isSupportedType(Type type) {
-    if (auto memRefType = type.dyn_cast_or_null<MemRefType>()) {
+    if (auto memRefType = dyn_cast_or_null<MemRefType>(type)) {
       auto elementType = memRefType.getElementType();
       return memRefType.hasRank() &&
             (memRefType.getRank() >= 1 && memRefType.getRank() <= 3) &&
@@ -197,7 +197,7 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
     // The below cast always succeeds as it has already been verified in
     // 'declareVulkanLaunchFunc' that these are MemRefs with compatible element
     // types.
-    elementTypes.push_back(type.cast<MemRefType>().getElementType());
+    elementTypes.push_back(cast<MemRefType>(type).getElementType());
   }
   vulkanLaunchCallOp->setAttr(kSPIRVElementTypesAttrName,
                               builder.getTypeArrayAttr(elementTypes));
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 78d1f67..036eb02 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -142,11 +142,11 @@ private:
 
   /// Returns a string representation from the given `type`.
   StringRef stringifyType(Type type) {
-    if (type.isa<Float32Type>())
+    if (isa<Float32Type>(type))
       return "Float";
-    if (type.isa<Float16Type>())
+    if (isa<Float16Type>(type))
       return "Half";
-    if (auto intType = type.dyn_cast<IntegerType>()) {
+    if (auto intType = dyn_cast<IntegerType>(type)) {
       if (intType.getWidth() == 32)
         return "Int32";
       if (intType.getWidth() == 16)
@@ -282,7 +282,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
     // Special case for fp16 type. Since it is not a supported type in C we use
     // int16_t and bitcast the descriptor.
-    if (!useOpaquePointers && type.isa<Float16Type>()) {
+    if (!useOpaquePointers && isa<Float16Type>(type)) {
       auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
       ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
           loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
@@ -328,9 +328,8 @@ VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg,
     rank = 0;
     return success();
   }
-  rank = llvmDescriptorTy.getBody()[3]
-             .cast<LLVM::LLVMArrayType>()
-             .getNumElements();
+  rank =
+      cast<LLVM::LLVMArrayType>(llvmDescriptorTy.getBody()[3]).getNumElements();
   return success();
 }
 
@@ -375,7 +374,7 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
     for (auto type : types) {
       std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
                            std::string(stringifyType(type));
-      if (type.isa<Float16Type>())
+      if (isa<Float16Type>(type))
         type = IntegerType::get(&getContext(), 16);
       if (!module.lookupSymbol(fnName)) {
         auto fnType = LLVM::LLVMFunctionType::get(
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 2373765..df9dafc 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -24,8 +24,7 @@ using namespace mlir;
 MemRefDescriptor::MemRefDescriptor(Value descriptor)
     : StructBuilder(descriptor) {
   assert(value != nullptr && "value cannot be null");
-  indexType = value.getType()
-                  .cast<LLVM::LLVMStructType>()
+  indexType = cast<LLVM::LLVMStructType>(value.getType())
                   .getBody()[kOffsetPosInMemRefDescriptor];
 }
 
@@ -193,10 +192,9 @@ void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
 }
 
 LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
-  return value.getType()
-      .cast<LLVM::LLVMStructType>()
-      .getBody()[kAlignedPtrPosInMemRefDescriptor]
-      .cast<LLVM::LLVMPointerType>();
+  return cast<LLVM::LLVMPointerType>(
+      cast<LLVM::LLVMStructType>(value.getType())
+          .getBody()[kAlignedPtrPosInMemRefDescriptor]);
 }
 
 Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 67a2898..c55a62e 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -235,7 +235,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
   SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
   SmallVector<unsigned> unrankedAddressSpaces;
   for (unsigned i = 0, e = operands.size(); i < e; ++i) {
-    if (auto memRefType = origTypes[i].dyn_cast<UnrankedMemRefType>()) {
+    if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
       unrankedMemrefs.emplace_back(operands[i]);
       FailureOr<unsigned> addressSpace =
           getTypeConverter()->getMemRefAddressSpace(memRefType);
@@ -276,7 +276,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
   unsigned unrankedMemrefPos = 0;
   for (unsigned i = 0, e = operands.size(); i < e; ++i) {
     Type type = origTypes[i];
-    if (!type.isa<UnrankedMemRefType>())
+    if (!isa<UnrankedMemRefType>(type))
       continue;
     Value allocationSize = sizes[unrankedMemrefPos++];
     UnrankedMemRefDescriptor desc(operands[i]);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 88d7eaf..cf0d506 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -260,7 +260,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
     if (!resultType)
       return {};
 
-    auto structType = resultType.dyn_cast<LLVM::LLVMStructType>();
+    auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
     if (structType) {
       // Struct types cannot be safely returned via C interface. Make this a
       // pointer argument, instead.
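An aside between hunks, not part of the patch: the MemRefDescriptor hunk above shows how a method-call chain becomes nested free-function calls. A minimal sketch of the same shape, assuming only the MLIR casting headers; `descriptorIndexType` and `kOffsetPos` are hypothetical names:

```cpp
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Hypothetical helper mirroring MemRefDescriptor's constructor above: the
// chain `v.getType().cast<LLVM::LLVMStructType>().getBody()[i]` becomes a
// nested call, with cast<> wrapping only the type operand it checks.
static Type descriptorIndexType(Value descriptor, unsigned kOffsetPos) {
  return cast<LLVM::LLVMStructType>(descriptor.getType()).getBody()[kOffsetPos];
}
```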
@@ -272,7 +272,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
     auto converted = convertType(t);
     if (!converted || !LLVM::isCompatibleType(converted))
       return {};
-    if (t.isa<MemRefType, UnrankedMemRefType>())
+    if (isa<MemRefType, UnrankedMemRefType>(t))
       converted = getPointerType(converted);
     inputs.push_back(converted);
   }
@@ -412,13 +412,13 @@ LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) {
 
 // Check if a memref type can be converted to a bare pointer.
 bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
-  if (type.isa<UnrankedMemRefType>())
+  if (isa<UnrankedMemRefType>(type))
     // Unranked memref is not supported in the bare pointer calling convention.
     return false;
 
   // Check that the memref has static shape, strides and offset. Otherwise, it
   // cannot be lowered to a bare pointer.
-  auto memrefTy = type.cast<MemRefType>();
+  auto memrefTy = cast<MemRefType>(type);
   if (!memrefTy.hasStaticShape())
     return false;
 
@@ -476,7 +476,7 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
 Type LLVMTypeConverter::convertCallingConventionType(Type type,
                                                      bool useBarePtrCallConv) {
   if (useBarePtrCallConv)
-    if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
+    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
       return convertMemRefToBarePtr(memrefTy);
 
   return convertType(type);
@@ -491,7 +491,7 @@ void LLVMTypeConverter::promoteBarePtrsToDescriptors(
   assert(stdTypes.size() == values.size() &&
          "The number of types and values doesn't match");
   for (unsigned i = 0, end = values.size(); i < end; ++i)
-    if (auto memrefTy = stdTypes[i].dyn_cast<MemRefType>())
+    if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
       values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
                                                     memrefTy, values[i]);
 }
@@ -569,19 +569,19 @@ LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
     if (useBarePtrCallConv) {
       // For the bare-ptr calling convention, we only have to extract the
       // aligned pointer of a memref.
-      if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+      if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
        MemRefDescriptor desc(llvmOperand);
        llvmOperand = desc.alignedPtr(builder, loc);
-      } else if (operand.getType().isa<UnrankedMemRefType>()) {
+      } else if (isa<UnrankedMemRefType>(operand.getType())) {
        llvm_unreachable("Unranked memrefs are not supported");
      }
    } else {
-      if (operand.getType().isa<UnrankedMemRefType>()) {
+      if (isa<UnrankedMemRefType>(operand.getType())) {
        UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
                                         promotedOperands);
        continue;
      }
-      if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+      if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
        MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
                                 promotedOperands);
        continue;
@@ -600,7 +600,7 @@ LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
 LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
                                                Type type,
                                                SmallVectorImpl<Type> &result) {
-  if (auto memref = type.dyn_cast<MemRefType>()) {
+  if (auto memref = dyn_cast<MemRefType>(type)) {
     // In signatures, Memref descriptors are expanded into lists of
     // non-aggregate values.
     auto converted =
@@ -610,7 +610,7 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
     result.append(converted.begin(), converted.end());
     return success();
   }
-  if (type.isa<UnrankedMemRefType>()) {
+  if (isa<UnrankedMemRefType>(type)) {
     auto converted = converter.getUnrankedMemRefDescriptorFields();
     if (converted.empty())
       return failure();
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e95c702..732f6c5 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -27,10 +27,10 @@ LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType,
   }
   info.arraySizes.reserve(vectorType.getRank() - 1);
   auto llvmTy = info.llvmNDVectorTy;
-  while (llvmTy.isa<LLVM::LLVMArrayType>()) {
+  while (isa<LLVM::LLVMArrayType>(llvmTy)) {
     info.arraySizes.push_back(
-        llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
-    llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
+        cast<LLVM::LLVMArrayType>(llvmTy).getNumElements());
+    llvmTy = cast<LLVM::LLVMArrayType>(llvmTy).getElementType();
   }
   if (!LLVM::isCompatibleVectorType(llvmTy))
     return info;
@@ -81,7 +81,7 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
     Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
     std::function<Value(Type, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter) {
-  auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
+  auto resultNDVectorType = cast<VectorType>(op->getResult(0).getType());
   auto resultTypeInfo =
       extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
   auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
@@ -114,7 +114,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
     return failure();
 
   auto llvmNDVectorTy = operands[0].getType();
-  if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
+  if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
     return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
                            rewriter);
 
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index f94da68..4d1f35c7 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -42,7 +42,7 @@ static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
     // The underlying descriptor type (e.g. LLVM) does not have layout
     // information. Canonicalizing the type at the level of std when going into
     // a library call avoids needing to introduce DialectCastOp.
-    if (auto memrefType = type.dyn_cast<MemRefType>())
+    if (auto memrefType = dyn_cast<MemRefType>(type))
       result.push_back(makeStridedLayoutDynamic(memrefType));
     else
       result.push_back(type);
@@ -96,7 +96,7 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
   SmallVector<Value, 4> res;
   res.reserve(operands.size());
   for (auto op : operands) {
-    auto memrefType = op.getType().dyn_cast<MemRefType>();
+    auto memrefType = dyn_cast<MemRefType>(op.getType());
     if (!memrefType) {
       res.push_back(op);
       continue;
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index 10832a1..3a56764 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -106,7 +106,7 @@ LogicalResult
 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
   Type opType = op.getType();
   Location loc = op.getLoc();
-  auto vecType = opType.template dyn_cast<VectorType>();
+  auto vecType = dyn_cast<VectorType>(opType);
 
   if (!vecType)
     return rewriter.notifyMatchFailure(op, "not a vector operation");
@@ -117,7 +117,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
   Type resultElementType = vecType.getElementType();
   Attribute initValueAttr;
-  if (resultElementType.isa<FloatType>())
+  if (isa<FloatType>(resultElementType))
     initValueAttr = FloatAttr::get(resultElementType, 0.0);
   else
     initValueAttr = IntegerAttr::get(resultElementType, 0);
@@ -183,7 +183,7 @@ static FunctionType getElementalFuncTypeForOp(Operation *op) {
 ///   }
 /// }
 static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
-  assert(elementType.isa<IntegerType>() &&
+  assert(isa<IntegerType>(elementType) &&
          "non-integer element type for IPowIOp");
 
   ImplicitLocOpBuilder builder =
@@ -361,7 +361,7 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
 LogicalResult
 IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
                                  PatternRewriter &rewriter) const {
-  auto baseType = op.getOperands()[0].getType().dyn_cast<IntegerType>();
+  auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
 
   if (!baseType)
     return rewriter.notifyMatchFailure(op, "non-integer base operand");
@@ -411,8 +411,8 @@ IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
 /// }
 static func::FuncOp createElementFPowIFunc(ModuleOp *module,
                                            FunctionType funcType) {
-  auto baseType = funcType.getInput(0).cast<FloatType>();
-  auto powType = funcType.getInput(1).cast<IntegerType>();
+  auto baseType = cast<FloatType>(funcType.getInput(0));
+  auto powType = cast<IntegerType>(funcType.getInput(1));
   ImplicitLocOpBuilder builder =
       ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
@@ -586,7 +586,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
 LogicalResult
 FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
                                  PatternRewriter &rewriter) const {
-  if (op.getType().template dyn_cast<VectorType>())
+  if (dyn_cast<VectorType>(op.getType()))
     return rewriter.notifyMatchFailure(op, "non-scalar operation");
 
   FunctionType funcType = getElementalFuncTypeForOp(op);
@@ -649,7 +649,7 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
 ///     return %out: i32
 ///   }
 static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
-  if (!elementType.isa<IntegerType>()) {
+  if (!isa<IntegerType>(elementType)) {
     LLVM_DEBUG({
       DBGS() << "non-integer element type for CtlzFunc; type was: ";
       elementType.print(llvm::dbgs());
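An aside between hunks, not part of the patch: one quiet benefit visible in the VecOpToScalarOp hunks above is that the `.template dyn_cast<...>()` disambiguator needed inside templates disappears with the free-function form. A minimal sketch of the dyn_cast-and-bail idiom those patterns use; `scalarizableElementCount` is a hypothetical name:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Hypothetical helper: a null dyn_cast<> result doubles as the "wrong type"
// signal, so no separate isa<> check is needed; callers surface the failure
// via rewriter.notifyMatchFailure, as the patterns above do.
static FailureOr<int64_t> scalarizableElementCount(Type opType) {
  auto vecType = dyn_cast<VectorType>(opType);
  if (!vecType)
    return failure();
  return vecType.getNumElements();
}
```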
@@ -751,7 +751,7 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
 /// operation.
 LogicalResult
 CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
                                 PatternRewriter &rewriter) const {
-  if (op.getType().template dyn_cast<VectorType>())
+  if (dyn_cast<VectorType>(op.getType()))
     return rewriter.notifyMatchFailure(op, "non-scalar operation");
 
   Type type = getElementTypeOrSelf(op.getResult().getType());
@@ -794,7 +794,7 @@ private:
 
 bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
   auto expTy =
-      getElementTypeOrSelf(op.getRhs().getType()).dyn_cast<IntegerType>();
+      dyn_cast<IntegerType>(getElementTypeOrSelf(op.getRhs().getType()));
   return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
 }
 
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index c331f4f..6dc5c41 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -79,14 +79,14 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
     auto resultType = op.getResult().getType();
     auto boolZero = rewriter.getBoolAttr(false);
 
-    if (!operandType.template isa<VectorType>()) {
+    if (!isa<VectorType>(operandType)) {
       LLVM::ConstantOp zero = rewriter.create<LLVM::ConstantOp>(loc, boolZero);
       rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
                                           zero);
       return success();
     }
 
-    auto vectorType = resultType.template dyn_cast<VectorType>();
+    auto vectorType = dyn_cast<VectorType>(resultType);
     if (!vectorType)
       return failure();
 
@@ -122,17 +122,17 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
     auto loc = op.getLoc();
     auto resultType = op.getResult().getType();
-    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
     ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
     ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
 
-    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+    if (!isa<LLVM::LLVMArrayType>(operandType)) {
       LLVM::ConstantOp one;
       if (LLVM::isCompatibleVectorType(operandType)) {
         one = rewriter.create<LLVM::ConstantOp>(
             loc, operandType,
-            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+            SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
       } else {
         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
       }
@@ -143,7 +143,7 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
       return success();
     }
 
-    auto vectorType = resultType.dyn_cast<VectorType>();
+    auto vectorType = dyn_cast<VectorType>(resultType);
     if (!vectorType)
       return rewriter.notifyMatchFailure(op, "expected vector result type");
 
@@ -180,17 +180,17 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
     auto loc = op.getLoc();
     auto resultType = op.getResult().getType();
-    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
     ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
     ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
 
-    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+    if (!isa<LLVM::LLVMArrayType>(operandType)) {
       LLVM::ConstantOp one =
           LLVM::isCompatibleVectorType(operandType)
               ? rewriter.create<LLVM::ConstantOp>(
                     loc, operandType,
-                    SplatElementsAttr::get(resultType.cast<ShapedType>(),
+                    SplatElementsAttr::get(cast<ShapedType>(resultType),
                                            floatOne))
               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
 
@@ -202,7 +202,7 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
       return success();
     }
 
-    auto vectorType = resultType.dyn_cast<VectorType>();
+    auto vectorType = dyn_cast<VectorType>(resultType);
     if (!vectorType)
       return rewriter.notifyMatchFailure(op, "expected vector result type");
 
@@ -240,17 +240,17 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
     auto loc = op.getLoc();
     auto resultType = op.getResult().getType();
-    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
     ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
     ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
 
-    if (!operandType.isa<LLVM::LLVMArrayType>()) {
+    if (!isa<LLVM::LLVMArrayType>(operandType)) {
       LLVM::ConstantOp one;
       if (LLVM::isCompatibleVectorType(operandType)) {
         one = rewriter.create<LLVM::ConstantOp>(
             loc, operandType,
-            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+            SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
       } else {
         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
       }
@@ -261,7 +261,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
       return success();
     }
 
-    auto vectorType = resultType.dyn_cast<VectorType>();
+    auto vectorType = dyn_cast<VectorType>(resultType);
     if (!vectorType)
       return failure();
 
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index d683444..7fd9411 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -75,7 +75,7 @@ LogicalResult
 VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
   auto opType = op.getType();
   auto loc = op.getLoc();
-  auto vecType = opType.template dyn_cast<VectorType>();
+  auto vecType = dyn_cast<VectorType>(opType);
 
   if (!vecType)
     return failure();
@@ -107,7 +107,7 @@ template <typename Op>
 LogicalResult
 PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
   auto opType = op.getType();
-  if (!opType.template isa<Float16Type, BFloat16Type>())
+  if (!isa<Float16Type, BFloat16Type>(opType))
     return failure();
 
   auto loc = op.getLoc();
@@ -127,7 +127,7 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
                                         PatternRewriter &rewriter) const {
   auto module = SymbolTable::getNearestSymbolTable(op);
   auto type = op.getType();
-  if (!type.template isa<Float32Type, Float64Type>())
+  if (!isa<Float32Type, Float64Type>(type))
     return failure();
 
   auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 412f99ce..6630aaf 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -34,7 +34,7 @@ using namespace mlir;
 /// given type is not a 32-bit scalar/vector type.
 static Value getScalarOrVectorI32Constant(Type type, int value,
                                           OpBuilder &builder, Location loc) {
-  if (auto vectorType = type.dyn_cast<VectorType>()) {
+  if (auto vectorType = dyn_cast<VectorType>(type)) {
     if (!vectorType.getElementType().isInteger(32))
       return nullptr;
     SmallVector<int32_t> values(vectorType.getNumElements(), value);
@@ -55,7 +55,7 @@ static bool isSupportedSourceType(Type originalType) {
   if (originalType.isIntOrIndexOrFloat())
     return true;
 
-  if (auto vecTy = originalType.dyn_cast<VectorType>()) {
+  if (auto vecTy = dyn_cast<VectorType>(originalType)) {
     if (!vecTy.getElementType().isIntOrIndexOrFloat())
       return false;
     if (vecTy.isScalable())
@@ -133,10 +133,10 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
       return failure();
 
     FloatType floatType;
-    if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) {
+    if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
       floatType = scalarType;
-    } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
-      floatType = vectorType.getElementType().cast<FloatType>();
+    } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
+      floatType = cast<FloatType>(vectorType.getElementType());
     } else {
       return failure();
     }
@@ -151,7 +151,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
     Value valueMask = rewriter.create<spirv::ConstantOp>(
         loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
 
-    if (auto vectorType = type.dyn_cast<VectorType>()) {
+    if (auto vectorType = dyn_cast<VectorType>(type)) {
       assert(vectorType.getRank() == 1);
       int count = vectorType.getNumElements();
       intType = VectorType::get(count, intType);
@@ -203,9 +203,9 @@ struct CountLeadingZerosPattern final
 
     // We can only support 32-bit integer types for now.
     unsigned bitwidth = 0;
-    if (type.isa<IntegerType>())
+    if (isa<IntegerType>(type))
       bitwidth = type.getIntOrFloatBitWidth();
-    if (auto vectorType = type.dyn_cast<VectorType>())
+    if (auto vectorType = dyn_cast<VectorType>(type))
       bitwidth = vectorType.getElementTypeBitWidth();
     if (bitwidth != 32)
       return failure();
@@ -307,10 +307,10 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
 
     // Get the scalar float type.
     FloatType scalarFloatType;
-    if (auto scalarType = powfOp.getType().dyn_cast<FloatType>()) {
+    if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
       scalarFloatType = scalarType;
-    } else if (auto vectorType = powfOp.getType().dyn_cast<VectorType>()) {
-      scalarFloatType = vectorType.getElementType().cast<FloatType>();
+    } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
+      scalarFloatType = cast<FloatType>(vectorType.getElementType());
     } else {
       return failure();
     }
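An aside between hunks, not part of the patch: CopySignPattern and PowFOpPattern above share a scalar-or-vector dispatch. A minimal sketch, under the assumption (which the patterns enforce via their final `return failure()`) that vector elements are floats; `getScalarFloatType` is a hypothetical name:

```cpp
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper: try each dyn_cast<> in turn; the unchecked cast<> on
// the vector element assumes callers only feed float or float-vector types.
static FloatType getScalarFloatType(Type type) {
  if (auto scalarType = dyn_cast<FloatType>(type))
    return scalarType;
  if (auto vectorType = dyn_cast<VectorType>(type))
    return cast<FloatType>(vectorType.getElementType());
  return FloatType(); // null FloatType signals "unsupported type"
}
```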
@@ -318,7 +318,7 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     // Get int type of the same shape as the float type.
     Type scalarIntType = rewriter.getIntegerType(32);
     Type intType = scalarIntType;
-    if (auto vectorType = adaptor.getRhs().getType().dyn_cast<VectorType>()) {
+    if (auto vectorType = dyn_cast<VectorType>(adaptor.getRhs().getType())) {
       auto shape = vectorType.getShape();
       intType = VectorType::get(shape, scalarIntType);
     }
@@ -374,7 +374,7 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
     auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
     auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
     Value half;
-    if (VectorType vty = ty.dyn_cast<VectorType>()) {
+    if (VectorType vty = dyn_cast<VectorType>(ty)) {
       half = rewriter.create<spirv::ConstantOp>(
           loc, vty,
           DenseElementsAttr::get(vty,
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index 37aa6cf..2fa4315 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -58,7 +58,7 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                  Location loc, Value allocatedPtr,
                                  MemRefType memRefType, Type elementPtrType,
                                  LLVMTypeConverter &typeConverter) {
-  auto allocatedPtrTy = allocatedPtr.getType().cast<LLVM::LLVMPointerType>();
+  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
   unsigned memrefAddrSpace = *typeConverter.getMemRefAddressSpace(memRefType);
   if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
     allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
@@ -114,10 +114,10 @@ unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
     layout = &analysis->getAbove(op);
   }
   Type elementType = memRefType.getElementType();
-  if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
+  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
     return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
                                                        *layout);
-  if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
+  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
     return getTypeConverter()->getUnrankedMemRefDescriptorSize(
         memRefElementType, *layout);
   return layout->getTypeSize(elementType);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index e9fbad3..1a6e5a4 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -184,10 +184,10 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
     rewriter.setInsertionPointToEnd(currentBlock);
     Value src = op.getSource();
-    auto srcType = src.getType().dyn_cast<MemRefType>();
+    auto srcType = dyn_cast<MemRefType>(src.getType());
     Value srcNumElements = computeNumElements(
         srcType, [&]() -> Value { return desc.size(rewriter, loc, 0); });
-    auto dstType = op.getType().cast<MemRefType>();
+    auto dstType = cast<MemRefType>(op.getType());
     Value dstNumElements = computeNumElements(
         dstType, [&]() -> Value { return op.getDynamicResultSize(); });
     Value cond = rewriter.create<LLVM::ICmpOp>(
@@ -342,7 +342,7 @@ struct AssumeAlignmentOpLowering
     unsigned alignment = op.getAlignment();
     auto loc = op.getLoc();
 
-    auto srcMemRefType = op.getMemref().getType().cast<MemRefType>();
+    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
     Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
                                      rewriter);
 
@@ -417,7 +417,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
   matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type operandType = dimOp.getSource().getType();
-    if (operandType.isa<UnrankedMemRefType>()) {
+    if (isa<UnrankedMemRefType>(operandType)) {
       FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
           operandType, dimOp, adaptor.getOperands(), rewriter);
       if (failed(extractedSize))
@@ -425,7 +425,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
       rewriter.replaceOp(dimOp, {*extractedSize});
       return success();
     }
-    if (operandType.isa<MemRefType>()) {
+    if (isa<MemRefType>(operandType)) {
       rewriter.replaceOp(
           dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                             adaptor.getOperands(), rewriter)});
@@ -441,7 +441,7 @@ private:
                   ConversionPatternRewriter &rewriter) const {
     Location loc = dimOp.getLoc();
 
-    auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
+    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
     auto scalarMemRefType =
         MemRefType::get({}, unrankedMemRefType.getElementType());
     FailureOr<unsigned> maybeAddressSpace =
@@ -492,10 +492,7 @@ private:
       return idx;
 
     if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
-      return constantOp.getValue()
-          .cast<IntegerAttr>()
-          .getValue()
-          .getSExtValue();
+      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
 
     return std::nullopt;
   }
@@ -506,7 +503,7 @@ private:
     Location loc = dimOp.getLoc();
 
     // Take advantage if index is constant.
-    MemRefType memRefType = operandType.cast<MemRefType>();
+    MemRefType memRefType = cast<MemRefType>(operandType);
     if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
       int64_t i = *index;
       if (i >= 0 && i < memRefType.getRank()) {
@@ -589,7 +586,7 @@ struct GenericAtomicRMWOpLowering
 
     // Compute the loaded value and branch to the loop block.
     rewriter.setInsertionPointToEnd(initBlock);
-    auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
+    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
     Value init = rewriter.create<LLVM::LoadOp>(
@@ -712,7 +709,7 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
                           Location loc, Value sizeBytes,
                           Operation *op) const override {
     auto getGlobalOp = cast<memref::GetGlobalOp>(op);
-    MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
+    MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());
 
     // This is called after a type conversion, which would have failed if this
     // call fails.
@@ -823,12 +820,12 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Type operandType = op.getMemref().getType();
-    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
+    if (auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(operandType)) {
       UnrankedMemRefDescriptor desc(adaptor.getMemref());
       rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
       return success();
     }
-    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
+    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
       rewriter.replaceOp(
           op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
       return success();
@@ -849,17 +846,17 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
     // and require source and result type to have the same rank. Therefore,
     // perform a sanity check that the underlying structs are the same. Once op
     // semantics are relaxed we can revisit.
-    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
+    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
       return success(typeConverter->convertType(srcType) ==
                      typeConverter->convertType(dstType));
 
     // At least one of the operands is unranked type
-    assert(srcType.isa<UnrankedMemRefType>() ||
-           dstType.isa<UnrankedMemRefType>());
+    assert(isa<UnrankedMemRefType>(srcType) ||
+           isa<UnrankedMemRefType>(dstType));
 
     // Unranked to unranked cast is disallowed
-    return !(srcType.isa<UnrankedMemRefType>() &&
-             dstType.isa<UnrankedMemRefType>())
+    return !(isa<UnrankedMemRefType>(srcType) &&
+             isa<UnrankedMemRefType>(dstType))
               ? success()
               : failure();
   }
 
@@ -872,15 +869,15 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
     auto loc = memRefCastOp.getLoc();
 
     // For ranked/ranked case, just keep the original descriptor.
-    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
+    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
       return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
 
-    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
+    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
       // Casting ranked to unranked memref type
       // Set the rank in the destination from the memref type
       // Allocate space on the stack and copy the src memref descriptor
       // Set the ptr in the destination to the stack space
-      auto srcMemRefType = srcType.cast<MemRefType>();
+      auto srcMemRefType = cast<MemRefType>(srcType);
       int64_t rank = srcMemRefType.getRank();
       // ptr = AllocaOp sizeof(MemRefDescriptor)
       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
@@ -905,7 +902,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
       memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
       rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
 
-    } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
+    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
       // Casting from unranked type to ranked.
       // The operation is assumed to be doing a correct cast. If the destination
       // type mismatches the unranked the type, it is undefined behavior.
@@ -942,7 +939,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
   lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                           ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
-    auto srcType = op.getSource().getType().dyn_cast<MemRefType>();
+    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
 
     MemRefDescriptor srcDesc(adaptor.getSource());
 
@@ -984,8 +981,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
   lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
-    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
-    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
+    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
+    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
 
     // First make sure we have an unranked memref descriptor representation.
     auto makeUnranked = [&, this](Value ranked, MemRefType type) {
@@ -1012,11 +1009,11 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     auto stackSaveOp =
         rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
 
-    auto srcMemRefType = srcType.dyn_cast<MemRefType>();
+    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
     Value unrankedSource =
         srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                       : adaptor.getSource();
-    auto targetMemRefType = targetType.dyn_cast<MemRefType>();
+    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
     Value unrankedTarget =
         targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                          : adaptor.getTarget();
@@ -1055,8 +1052,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
   LogicalResult
   matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto srcType = op.getSource().getType().cast<BaseMemRefType>();
-    auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
+    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
+    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
 
     auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) {
       if (!type.hasStaticShape())
@@ -1077,7 +1074,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     };
 
     auto isContiguousMemrefType = [&](BaseMemRefType type) {
-      auto memrefType = type.dyn_cast<MemRefType>();
+      auto memrefType = dyn_cast<MemRefType>(type);
       // We can use memcpy for memrefs if they have an identity layout or are
       // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
       // special case handled by memrefCopy.
@@ -1105,9 +1102,9 @@ struct MemorySpaceCastOpLowering
     Location loc = op.getLoc();
 
     Type resultType = op.getDest().getType();
-    if (auto resultTypeR = resultType.dyn_cast<MemRefType>()) {
+    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
       auto resultDescType =
-          typeConverter->convertType(resultTypeR).cast<LLVM::LLVMStructType>();
+          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
       Type newPtrType = resultDescType.getBody()[0];
 
       SmallVector<Value> descVals;
@@ -1122,10 +1119,10 @@ struct MemorySpaceCastOpLowering
       rewriter.replaceOp(op, result);
       return success();
     }
-    if (auto resultTypeU = resultType.dyn_cast<UnrankedMemRefType>()) {
+    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
       // Since the type converter won't be doing this for us, get the address
       // space.
-      auto sourceType = op.getSource().getType().cast<UnrankedMemRefType>();
+      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
       FailureOr<unsigned> maybeSourceAddrSpace =
           getTypeConverter()->getMemRefAddressSpace(sourceType);
       if (failed(maybeSourceAddrSpace))
@@ -1217,7 +1214,7 @@ static void extractPointersAndOffset(Location loc,
                                      Value *allocatedPtr, Value *alignedPtr,
                                      Value *offset = nullptr) {
   Type operandType = originalOperand.getType();
-  if (operandType.isa<MemRefType>()) {
+  if (isa<MemRefType>(operandType)) {
     MemRefDescriptor desc(convertedOperand);
     *allocatedPtr = desc.allocatedPtr(rewriter, loc);
     *alignedPtr = desc.alignedPtr(rewriter, loc);
@@ -1228,8 +1225,8 @@ static void extractPointersAndOffset(Location loc,
   // These will all cause assert()s on unconvertible types.
   unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
-      operandType.cast<UnrankedMemRefType>());
-  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
+      cast<UnrankedMemRefType>(operandType));
+  Type elementType = cast<UnrankedMemRefType>(operandType).getElementType();
   Type llvmElementType = typeConverter.convertType(elementType);
   LLVM::LLVMPointerType elementPtrType =
       typeConverter.getPointerType(llvmElementType, memorySpace);
@@ -1273,9 +1270,9 @@ private:
       memref::ReinterpretCastOp castOp,
       memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
     MemRefType targetMemRefType =
-        castOp.getResult().getType().cast<MemRefType>();
-    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
-                                      .dyn_cast_or_null<LLVM::LLVMStructType>();
+        cast<MemRefType>(castOp.getResult().getType());
+    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
+        typeConverter->convertType(targetMemRefType));
     if (!llvmTargetDescriptorTy)
       return failure();
 
@@ -1339,13 +1336,12 @@ private:
                                     Type srcType, memref::ReshapeOp reshapeOp,
                                     memref::ReshapeOp::Adaptor adaptor,
                                     Value *descriptor) const {
-    auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
+    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
     if (shapeMemRefType.hasStaticShape()) {
       MemRefType targetMemRefType =
-          reshapeOp.getResult().getType().cast<MemRefType>();
-      auto llvmTargetDescriptorTy =
-          typeConverter->convertType(targetMemRefType)
-              .dyn_cast_or_null<LLVM::LLVMStructType>();
+          cast<MemRefType>(reshapeOp.getResult().getType());
+      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
+          typeConverter->convertType(targetMemRefType));
       if (!llvmTargetDescriptorTy)
         return failure();
 
@@ -1426,8 +1422,7 @@ private:
     Value resultRank = shapeDesc.size(rewriter, loc, 0);
 
     // Extract address space and element type.
-    auto targetType =
-        reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
+    auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
     unsigned addressSpace =
         *getTypeConverter()->getMemRefAddressSpace(targetType);
     Type elementType = targetType.getElementType();
@@ -1695,7 +1690,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
 
     // Field 1: Copy the allocated pointer, used for malloc/free.
     Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
-    auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
+    auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
     unsigned sourceMemorySpace =
         *getTypeConverter()->getMemRefAddressSpace(srcMemRefType);
     Value bitcastPtr;
@@ -1848,7 +1843,7 @@ public:
     Location loc = extractStridedMetadataOp.getLoc();
     Value source = extractStridedMetadataOp.getSource();
 
-    auto sourceMemRefType = source.getType().cast<MemRefType>();
+    auto sourceMemRefType = cast<MemRefType>(source.getType());
     int64_t rank = sourceMemRefType.getRank();
     SmallVector<Value> results;
     results.reserve(2 + rank * 2);
@@ -1858,7 +1853,7 @@ public:
     Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
     MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
         rewriter, loc, *getTypeConverter(),
-        extractStridedMetadataOp.getBaseBuffer().getType().cast<MemRefType>(),
+        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
         baseBuffer, alignedBuffer);
     results.push_back((Value)dstMemRef);
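An aside at this file boundary, not part of the patch: `extractPointersAndOffset` above shows the isa-guard-then-unchecked-cast idiom that the migration preserves. A minimal sketch; `memrefElementType` is a hypothetical name, and it assumes the argument is some memref type, as the conversion above does:

```cpp
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper: the ranked case returns early, so the later
// cast<UnrankedMemRefType> is unchecked by construction rather than
// re-verified (it would assert if handed a non-memref type).
static Type memrefElementType(Type operandType) {
  if (isa<MemRefType>(operandType))
    return cast<MemRefType>(operandType).getElementType();
  return cast<UnrankedMemRefType>(operandType).getElementType();
}
```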
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index 3f92c6f..55c23d7 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -64,7 +64,7 @@ spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
   // Unknown dialect custom attributes are not supported by default.
   // Downstream callers should plug in more specialized ones.
-  auto intAttr = memorySpaceAttr.dyn_cast<IntegerAttr>();
+  auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
   if (!intAttr)
     return std::nullopt;
   unsigned memorySpace = intAttr.getInt();
@@ -118,7 +118,7 @@ spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr) {
   // Unknown dialect custom attributes are not supported by default.
   // Downstream callers should plug in more specialized ones.
-  auto intAttr = memorySpaceAttr.dyn_cast<IntegerAttr>();
+  auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
   if (!intAttr)
     return std::nullopt;
   unsigned memorySpace = intAttr.getInt();
@@ -177,7 +177,7 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
     auto storageAttr =
         spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
-    if (auto rankedType = memRefType.dyn_cast<MemRefType>()) {
+    if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
       return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
                              rankedType.getLayout(), storageAttr);
     }
@@ -203,9 +203,9 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
 /// Returns true if the given `type` is considered as legal for SPIR-V
 /// conversion.
 static bool isLegalType(Type type) {
-  if (auto memRefType = type.dyn_cast<BaseMemRefType>()) {
+  if (auto memRefType = dyn_cast<BaseMemRefType>(type)) {
     Attribute spaceAttr = memRefType.getMemorySpace();
-    return spaceAttr && spaceAttr.isa<spirv::StorageClassAttr>();
+    return spaceAttr && isa<spirv::StorageClassAttr>(spaceAttr);
   }
   return true;
 }
@@ -213,7 +213,7 @@ static bool isLegalType(Type type) {
 /// Returns true if the given `attr` is considered as legal for SPIR-V
 /// conversion.
 static bool isLegalAttr(Attribute attr) {
-  if (auto typeAttr = attr.dyn_cast<TypeAttr>())
+  if (auto typeAttr = dyn_cast<TypeAttr>(attr))
     return isLegalType(typeAttr.getValue());
   return true;
 }
@@ -266,7 +266,7 @@ LogicalResult MapMemRefStoragePattern::matchAndRewrite(
   llvm::SmallVector<NamedAttribute, 4> newAttrs;
   newAttrs.reserve(op->getAttrs().size());
   for (auto attr : op->getAttrs()) {
-    if (auto typeAttr = attr.getValue().dyn_cast<TypeAttr>()) {
+    if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
       auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
       newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
     } else {
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 9c74feb..efd541b 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -93,11 +93,11 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
 /// can be lowered to SPIR-V.
 static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
   if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
-    auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
     if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
       return false;
   } else if (isa<memref::AllocaOp>(allocOp)) {
-    auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
     if (!sc || sc.getValue() != spirv::StorageClass::Function)
       return false;
   } else {
@@ -110,7 +110,7 @@ static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
     return false;
 
   Type elementType = type.getElementType();
-  if (auto vecType = elementType.dyn_cast<VectorType>())
+  if (auto vecType = dyn_cast<VectorType>(elementType))
     elementType = vecType.getElementType();
   return elementType.isIntOrFloat();
 }
@@ -119,7 +119,7 @@ static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
 /// operations of unsupported integer bitwidths, based on the memref
 /// type. Returns std::nullopt on failure.
 static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
-  auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
   switch (sc.getValue()) {
   case spirv::StorageClass::StorageBuffer:
     return spirv::Scope::Device;
@@ -324,11 +324,11 @@ LogicalResult
 AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                     OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
-  if (atomicOp.getType().isa<FloatType>())
+  if (isa<FloatType>(atomicOp.getType()))
     return rewriter.notifyMatchFailure(atomicOp,
                                        "unimplemented floating-point case");
 
-  auto memrefType = atomicOp.getMemref().getType().cast<MemRefType>();
+  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
   std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
   if (!scope)
     return rewriter.notifyMatchFailure(atomicOp,
@@ -380,7 +380,7 @@ LogicalResult
 DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                   OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
-  MemRefType deallocType = operation.getMemref().getType().cast<MemRefType>();
+  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
   if (!isAllocationSupported(operation, deallocType))
     return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
   rewriter.eraseOp(operation);
@@ -395,7 +395,7 @@ LogicalResult
 IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
   auto loc = loadOp.getLoc();
-  auto memrefType = loadOp.getMemref().getType().cast<MemRefType>();
+  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
   if (!memrefType.getElementType().isSignlessInteger())
     return failure();
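An aside between hunks, not part of the patch: the hunks above are why the null-tolerant variant exists, since `getMemorySpace()` returns a null `Attribute` for the default memory space. A minimal sketch; `isWorkgroupMemRef` is a hypothetical name:

```cpp
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper: only dyn_cast_or_null<> is safe on the possibly-null
// memory-space attribute; plain dyn_cast<> would assert on the null input.
static bool isWorkgroupMemRef(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  return sc && sc.getValue() == spirv::StorageClass::Workgroup;
}
```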
@@ -419,18 +419,18 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   Type pointeeType = pointerType.getPointeeType();
   Type dstType;
   if (typeConverter.allows(spirv::Capability::Kernel)) {
-    if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
       dstType = arrayType.getElementType();
     else
       dstType = pointeeType;
   } else {
     // For Vulkan we need to extract element from wrapping struct and array.
     Type structElemType =
-        pointeeType.cast<spirv::StructType>().getElementType(0);
-    if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+        cast<spirv::StructType>(pointeeType).getElementType(0);
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
       dstType = arrayType.getElementType();
     else
-      dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
   }
   int dstBits = dstType.getIntOrFloatBitWidth();
   assert(dstBits % srcBits == 0);
@@ -509,7 +509,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
 LogicalResult
 LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
-  auto memrefType = loadOp.getMemref().getType().cast<MemRefType>();
+  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
   if (memrefType.getElementType().isSignlessInteger())
     return failure();
   auto loadPtr = spirv::getElementPtr(
@@ -526,7 +526,7 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
 LogicalResult
 IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
-  auto memrefType = storeOp.getMemref().getType().cast<MemRefType>();
+  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
   if (!memrefType.getElementType().isSignlessInteger())
     return failure();
 
@@ -553,18 +553,18 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   Type pointeeType = pointerType.getPointeeType();
   Type dstType;
   if (typeConverter.allows(spirv::Capability::Kernel)) {
-    if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
       dstType = arrayType.getElementType();
     else
       dstType = pointeeType;
   } else {
     // For Vulkan we need to extract element from wrapping struct and array.
     Type structElemType =
-        pointeeType.cast<spirv::StructType>().getElementType(0);
-    if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+        cast<spirv::StructType>(pointeeType).getElementType(0);
+    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
       dstType = arrayType.getElementType();
     else
-      dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
   }
 
   int dstBits = dstType.getIntOrFloatBitWidth();
@@ -651,21 +651,21 @@ LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         loc, "address space casts require kernel capability");
 
-  auto sourceType = addrCastOp.getSource().getType().dyn_cast<MemRefType>();
+  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
   if (!sourceType)
     return rewriter.notifyMatchFailure(
         loc, "SPIR-V lowering requires ranked memref types");
-  auto resultType = addrCastOp.getResult().getType().cast<MemRefType>();
+  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
 
   auto sourceStorageClassAttr =
-      sourceType.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
   if (!sourceStorageClassAttr)
     return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
       diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
     });
   auto resultStorageClassAttr =
-      resultType.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
   if (!resultStorageClassAttr)
     return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
       diag << "result address space " << resultType.getMemorySpace()
@@ -709,7 +709,7 @@ LogicalResult
 StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
-  auto memrefType = storeOp.getMemref().getType().cast<MemRefType>();
+  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
   if (memrefType.getElementType().isSignlessInteger())
     return failure();
   auto storePtr = spirv::getElementPtr(
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4a923fa..3d898e5 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -28,7 +28,7 @@ using namespace mlir;
 /// `gpu.mma.sync` operation.
 static Type inferIntrinsicResultType(Type vectorResultType) {
   MLIRContext *ctx = vectorResultType.getContext();
-  auto a = vectorResultType.cast<LLVM::LLVMStructType>();
+  auto a = cast<LLVM::LLVMStructType>(vectorResultType);
   auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
   auto i32Ty = IntegerType::get(ctx, 32);
   auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
@@ -69,8 +69,8 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                     Type resultType, Value intrinsicResult,
                                     RewriterBase &rewriter) {
   MLIRContext *ctx = rewriter.getContext();
-  auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
-  auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
+  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
+  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
   Type i32Ty = rewriter.getI32Type();
   Type f32Ty = rewriter.getF32Type();
   Type f64Ty = rewriter.getF64Type();
@@ -153,7 +153,7 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
   Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
   Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
   Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
-  auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
+  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
 
   for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
     Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
@@ -172,7 +172,7 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
     // For some element types (i32, f32, f64), we need to unpack the inner
     // vector/array type as well because the intrinsic expects individual
     // scalars to be provided.
-    VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
+    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
     if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                          innerArrayTy.getElementType() == f64Ty ||
                          innerArrayTy.getElementType() == f32Ty)) {
@@ -207,7 +207,7 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
     // of shape (NumRegisters, VectorRegister) where VectorRegister is the
     // vector type of the result and always 32 bits long. We bitcast the result
     // of the NVVM::LdMatrix to this vector type.
-    auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
+    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
     if (!vectorResultType) {
       return failure();
     }
@@ -224,7 +224,7 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
       ldMatrixResultType = rewriter.getI32Type();
     }
 
-    auto srcMemrefType = op.getSrcMemref().getType().cast<MemRefType>();
+    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
     Value srcPtr =
         getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
                              adaptor.getIndices(), rewriter);
@@ -307,7 +307,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
     // TODO: add an attribute to the op to customize this behavior.
     std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
-    if (aType.getElementType().isa<IntegerType>())
+    if (isa<IntegerType>(aType.getElementType()))
       overflow = NVVM::MMAIntOverflow::satfinite;
 
     SmallVector<Value> matA =
@@ -388,7 +388,7 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
   // constant.
   auto dstByteConstOp =
       dyn_cast<LLVM::ConstantOp>(dstBytes.getDefiningOp());
-  auto dstByteAttr = dstByteConstOp.getValue().dyn_cast<IntegerAttr>();
+  auto dstByteAttr = dyn_cast<IntegerAttr>(dstByteConstOp.getValue());
   int64_t dstByteVal = dstByteAttr.getValue().getSExtValue();
 
   assert((dstByteVal == 4 || dstByteVal == 8 || dstByteVal == 16) &&
@@ -537,7 +537,7 @@ struct NVGPUMmaSparseSyncLowering
     // TODO: add an attribute to the op to customize this behavior.
     std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
-    if (aType.getElementType().isa<IntegerType>())
+    if (isa<IntegerType>(aType.getElementType()))
       overflow = NVVM::MMAIntOverflow::satfinite;
 
     SmallVector<Value> matA =
@@ -585,7 +585,7 @@ struct NVGPUAsyncCopyLowering
   matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto dstMemrefType = op.getDst().getType().cast<MemRefType>();
+    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
     Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
                                         adaptor.getDstIndices(), rewriter);
     auto i8Ty = IntegerType::get(op.getContext(), 8);
@@ -599,7 +599,7 @@ struct NVGPUAsyncCopyLowering
     if (!getTypeConverter()->useOpaquePointers())
       dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
 
-    auto srcMemrefType = op.getSrc().getType().cast<MemRefType>();
+    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
     FailureOr<unsigned> srcAddressSpace =
         getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
     if (failed(srcAddressSpace))
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 4f6763b..34ec00f 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -70,7 +70,7 @@ struct RegionLessOpWithVarOperandsConversion
       Value originalVariableOperand = curOp.getVariableOperand(idx);
       if (!originalVariableOperand)
         return failure();
-      if (originalVariableOperand.getType().isa<MemRefType>()) {
+      if (isa<MemRefType>(originalVariableOperand.getType())) {
         // TODO: Support memref type in variable operands
         return rewriter.notifyMatchFailure(curOp,
                                            "memref is not supported yet");
@@ -101,7 +101,7 @@ struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
       Value originalVariableOperand = curOp.getVariableOperand(idx);
       if (!originalVariableOperand)
         return failure();
-      if (originalVariableOperand.getType().isa<MemRefType>()) {
+      if (isa<MemRefType>(originalVariableOperand.getType())) {
         // TODO: Support memref type in variable operands
         return rewriter.notifyMatchFailure(curOp,
                                            "memref is not supported yet");
@@ -143,7 +143,7 @@ struct ReductionOpConversion : public ConvertOpToLLVMPattern<omp::ReductionOp> {
   LogicalResult
   matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (curOp.getAccumulator().getType().isa<MemRefType>()) {
+    if (isa<MemRefType>(curOp.getAccumulator().getType())) {
       // TODO: Support memref type in variable operands
      return rewriter.notifyMatchFailure(curOp, "memref is not supported yet");
    }
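An aside at this file boundary, not part of the patch: the NVGPU and OpenMP hunks above show that `Attribute` values migrate exactly like `Type` values. A minimal sketch; `constantBytes` is a hypothetical name:

```cpp
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;

// Hypothetical helper following the emitCpAsyncOpZfillAsm hunk above:
// dyn_cast<IntegerAttr> takes the Attribute as an argument instead of being
// a member call, and yields a null attribute on mismatch.
static int64_t constantBytes(Attribute value) {
  auto byteAttr = dyn_cast<IntegerAttr>(value);
  return byteAttr ? byteAttr.getValue().getSExtValue() : 0;
}
```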
- bool isOperationValue = val && val.getType().isa(); + bool isOperationValue = val && isa(val.getType()); if (isOperationValue) locOps.insert(val); @@ -280,7 +280,7 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { // The first operation retrieves the representative value of a range. // This applies only when the parent is a range of values and we were // requested to use a representative value (e.g., upward traversal). - if (parentVal.getType().isa() && + if (isa(parentVal.getType()) && usersPos->useRepresentative()) value = builder.create(loc, parentVal, 0); else @@ -327,7 +327,7 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { break; } case Predicates::TypePos: { - if (parentVal.getType().isa()) + if (isa(parentVal.getType())) value = builder.create(loc, parentVal); else value = builder.create(loc, parentVal); @@ -357,11 +357,11 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { case Predicates::TypeLiteralPos: { auto *typePos = cast(pos); Attribute rawTypeAttr = typePos->getValue(); - if (TypeAttr typeAttr = rawTypeAttr.dyn_cast()) + if (TypeAttr typeAttr = dyn_cast(rawTypeAttr)) value = builder.create(loc, typeAttr); else value = builder.create( - loc, rawTypeAttr.cast()); + loc, cast(rawTypeAttr)); break; } default: @@ -410,7 +410,7 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, } case Predicates::TypeQuestion: { auto *ans = cast(answer); - if (val.getType().isa()) + if (isa(val.getType())) builder.create( loc, val, ans->getValue().cast(), success, failure); else @@ -554,7 +554,7 @@ void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, OperationNameAnswer>(val, defaultDest, builder, children); case Predicates::TypeQuestion: - if (val.getType().isa()) { + if (isa(val.getType())) { return createSwitchOp( val, defaultDest, builder, children); } @@ -745,7 +745,7 @@ void PatternLowering::generateRewriter( // Handle the case where there is a single range representing all of the // result types. OperandRange resultTys = operationOp.getTypeValues(); - if (resultTys.size() == 1 && resultTys[0].getType().isa()) { + if (resultTys.size() == 1 && isa(resultTys[0].getType())) { Value &type = rewriteValues[resultTys[0]]; if (!type) { auto results = builder.create(loc, createdOp); @@ -762,7 +762,7 @@ void PatternLowering::generateRewriter( Value &type = rewriteValues[it.value()]; if (type) continue; - bool isVariadic = it.value().getType().isa(); + bool isVariadic = isa(it.value().getType()); seenVariableLength |= isVariadic; // After a variable length result has been seen, we need to use result diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 0342914..7078e238 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -41,14 +41,14 @@ static bool comparePosDepth(Position *lhs, Position *rhs) { /// Returns the number of non-range elements within `values`. 
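// Editorial sketch of the helper defined just below, with the dropped
// template argument restored; pdl::RangeType (from
// mlir/Dialect/PDL/IR/PDLTypes.h) is an assumption inferred from context.
// Free isa<> composes directly with llvm::count_if, which is what makes the
// rewritten lambda read cleanly.
static unsigned getNumNonRangeValuesSketch(mlir::ValueRange values) {
  return llvm::count_if(values.getTypes(), [](mlir::Type type) {
    return !mlir::isa<mlir::pdl::RangeType>(type);
  });
}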
static unsigned getNumNonRangeValues(ValueRange values) { return llvm::count_if(values.getTypes(), - [](Type type) { return !type.isa(); }); + [](Type type) { return !isa(type); }); } static void getTreePredicates(std::vector &predList, Value val, PredicateBuilder &builder, DenseMap &inputs, AttributePosition *pos) { - assert(val.getType().isa() && "expected attribute type"); + assert(isa(val.getType()) && "expected attribute type"); pdl::AttributeOp attr = cast(val.getDefiningOp()); predList.emplace_back(pos, builder.getIsNotNull()); @@ -65,7 +65,7 @@ static void getOperandTreePredicates(std::vector &predList, DenseMap &inputs, Position *pos) { Type valueType = val.getType(); - bool isVariadic = valueType.isa(); + bool isVariadic = isa(valueType); // If this is a typed operand, add a type constraint. TypeSwitch(val.getDefiningOp()) @@ -111,7 +111,7 @@ getTreePredicates(std::vector &predList, Value val, PredicateBuilder &builder, DenseMap &inputs, OperationPosition *pos, std::optional ignoreOperand = std::nullopt) { - assert(val.getType().isa() && "expected operation"); + assert(isa(val.getType()) && "expected operation"); pdl::OperationOp op = cast(val.getDefiningOp()); OperationPosition *opPos = cast(pos); @@ -148,7 +148,7 @@ getTreePredicates(std::vector &predList, Value val, llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) { getTreePredicates( predList, attr, builder, inputs, - builder.getAttribute(opPos, attrName.cast().getValue())); + builder.getAttribute(opPos, cast(attrName).getValue())); } // Process the operands and results of the operation. For all values up to @@ -157,7 +157,7 @@ getTreePredicates(std::vector &predList, Value val, // concrete indices until runtime. If there is only one variadic operand // group, we treat it as all of the operands/results of the operation. /// Operands. - if (operands.size() == 1 && operands[0].getType().isa()) { + if (operands.size() == 1 && isa(operands[0].getType())) { // Ignore the operands if we are performing an upward traversal (in that // case, they have already been visited). if (opPos->isRoot() || opPos->isOperandDefiningOp()) @@ -166,7 +166,7 @@ getTreePredicates(std::vector &predList, Value val, } else { bool foundVariableLength = false; for (const auto &operandIt : llvm::enumerate(operands)) { - bool isVariadic = operandIt.value().getType().isa(); + bool isVariadic = isa(operandIt.value().getType()); foundVariableLength |= isVariadic; // Ignore the specified operand, usually because this position was @@ -182,7 +182,7 @@ getTreePredicates(std::vector &predList, Value val, } } /// Results. - if (types.size() == 1 && types[0].getType().isa()) { + if (types.size() == 1 && isa(types[0].getType())) { getTreePredicates(predList, types.front(), builder, inputs, builder.getType(builder.getAllResults(opPos))); return; @@ -190,7 +190,7 @@ getTreePredicates(std::vector &predList, Value val, bool foundVariableLength = false; for (auto [idx, typeValue] : llvm::enumerate(types)) { - bool isVariadic = typeValue.getType().isa(); + bool isVariadic = isa(typeValue.getType()); foundVariableLength |= isVariadic; auto *resultPos = foundVariableLength @@ -301,7 +301,7 @@ static void getResultPredicates(pdl::ResultsOp op, // Ensure that the result isn't null if the result has an index. 
auto *parentPos = cast(inputs.lookup(op.getParent())); - bool isVariadic = op.getType().isa(); + bool isVariadic = isa(op.getType()); std::optional index = op.getIndex(); resultPos = builder.getResultGroup(parentPos, index, isVariadic); if (index) @@ -458,7 +458,7 @@ static void buildCostGraph(ArrayRef roots, RootOrderingGraph &graph, // Special case when we pass all the operands in one range. // For those, the index is empty. if (operands.size() == 1 && - operands[0].getType().isa()) { + isa(operands[0].getType())) { toVisit.emplace(operands[0], entry.value, std::nullopt, entry.depth + 1); return; @@ -514,7 +514,7 @@ static bool useOperandGroup(pdl::OperationOp op, unsigned index) { OperandRange operands = op.getOperandValues(); assert(index < operands.size() && "operand index out of range"); for (unsigned i = 0; i <= index; ++i) - if (operands[i].getType().isa()) + if (isa(operands[i].getType())) return true; return false; } @@ -542,7 +542,7 @@ static void visitUpward(std::vector &predList, } else if (useOperandGroup(operationOp, *opIndex.index)) { // We are querying an operand group. Type type = operationOp.getOperandValues()[*opIndex.index].getType(); - bool variadic = type.isa(); + bool variadic = isa(type); operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic); } else { // We are querying an individual operand. @@ -578,7 +578,7 @@ static void visitUpward(std::vector &predList, // Traverse up a group of results. auto *opPos = dyn_cast(pos); assert(opPos && "operations and results must be interleaved"); - bool isVariadic = value.getType().isa(); + bool isVariadic = isa(value.getType()); if (opIndex.index) pos = builder.getResultGroup(opPos, opIndex.index, isVariadic); else diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index c1a57d3..f9245ad 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -441,7 +441,7 @@ static LogicalResult processParallelLoop( Value iv, lowerBound, upperBound, step; std::tie(mappingAttribute, iv, lowerBound, upperBound, step) = config; auto annotation = - mappingAttribute.dyn_cast(); + dyn_cast(mappingAttribute); if (!annotation) return parallelOp.emitOpError() << "expected mapping attribute for lowering to GPU"; diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index e91cd0c..5008775 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -51,8 +51,7 @@ static bool matchSimpleReduction(Block &block) { Value reducedVal = matchReduction({block.getArguments()[1]}, /*redPos=*/0, combinerOps); - if (!reducedVal || !reducedVal.isa() || - combinerOps.size() != 1) + if (!reducedVal || !isa(reducedVal) || combinerOps.size() != 1) return false; return isa(combinerOps[0]) && @@ -155,7 +154,7 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) { /// Returns an attribute with the minimum (if `min` is set) or the maximum value /// (otherwise) for the given float type. static Attribute minMaxValueForFloat(Type type, bool min) { - auto fltType = type.cast(); + auto fltType = cast(type); return FloatAttr::get( type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min)); } @@ -164,7 +163,7 @@ static Attribute minMaxValueForFloat(Type type, bool min) { /// the maximum value (otherwise) for the given integer type, regardless of its /// signedness semantics (only the width is considered). 
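// Editorial sketch of the helper defined just below, with the stripped
// IntegerType template argument restored as an assumption. Only the width is
// consulted, so signed and unsigned integer types are handled alike.
static mlir::Attribute minMaxValueForSignedIntSketch(mlir::Type type,
                                                     bool min) {
  auto intType = mlir::cast<mlir::IntegerType>(type);
  unsigned bitwidth = intType.getWidth();
  return mlir::IntegerAttr::get(
      type, min ? llvm::APInt::getSignedMinValue(bitwidth)
                : llvm::APInt::getSignedMaxValue(bitwidth));
}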
static Attribute minMaxValueForSignedInt(Type type, bool min) { - auto intType = type.cast(); + auto intType = cast(type); unsigned bitwidth = intType.getWidth(); return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth) : llvm::APInt::getSignedMaxValue(bitwidth)); @@ -174,7 +173,7 @@ static Attribute minMaxValueForSignedInt(Type type, bool min) { /// the maximum value (otherwise) for the given integer type, regardless of its /// signedness semantics (only the width is considered). static Attribute minMaxValueForUnsignedInt(Type type, bool min) { - auto intType = type.cast(); + auto intType = cast(type); unsigned bitwidth = intType.getWidth(); return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth)); @@ -388,7 +387,7 @@ struct ParallelOpLowering : public OpRewritePattern { reductionVariables.reserve(parallelOp.getNumReductions()); for (Value init : parallelOp.getInitVals()) { assert((LLVM::isCompatibleType(init.getType()) || - init.getType().isa()) && + isa(init.getType())) && "cannot create a reduction variable if the type is not an LLVM " "pointer element"); Value storage = rewriter.create( diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp index 08da805..4e17c96 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -220,9 +220,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern { auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands); for (const auto &operand : llvm::enumerate(kernelOperands)) { // Check if the kernel's operand is a ranked memref. - auto memRefType = launchOp.getKernelOperand(operand.index()) - .getType() - .dyn_cast(); + auto memRefType = dyn_cast( + launchOp.getKernelOperand(operand.index()).getType()); if (!memRefType) return failure(); @@ -241,7 +240,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern { // LLVM dialect global variable. 
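// Editorial sketch: the statement that follows, with the spirv::PointerType
// template argument restored (an assumption; the extraction dropped it). The
// SPIR-V global's pointee type is what gets converted to an LLVM global type.
mlir::Type pointeeSketch =
    mlir::cast<mlir::spirv::PointerType>(spirvGlobal.getType())
        .getPointeeType();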
spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; auto pointeeType = - spirvGlobal.getType().cast().getPointeeType(); + cast(spirvGlobal.getType()).getPointeeType(); auto dstGlobalType = typeConverter->convertType(pointeeType); if (!dstGlobalType) return failure(); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index b938947..8b43808 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -37,7 +37,7 @@ using namespace mlir; static bool isSignedIntegerOrVector(Type type) { if (type.isSignedInteger()) return true; - if (auto vecType = type.dyn_cast()) + if (auto vecType = dyn_cast(type)) return vecType.getElementType().isSignedInteger(); return false; } @@ -46,18 +46,18 @@ static bool isSignedIntegerOrVector(Type type) { static bool isUnsignedIntegerOrVector(Type type) { if (type.isUnsignedInteger()) return true; - if (auto vecType = type.dyn_cast()) + if (auto vecType = dyn_cast(type)) return vecType.getElementType().isUnsignedInteger(); return false; } /// Returns the bit width of integer, float or vector of float or integer values static unsigned getBitWidth(Type type) { - assert((type.isIntOrFloat() || type.isa()) && + assert((type.isIntOrFloat() || isa(type)) && "bitwidth is not supported for this type"); if (type.isIntOrFloat()) return type.getIntOrFloatBitWidth(); - auto vecType = type.dyn_cast(); + auto vecType = dyn_cast(type); auto elementType = vecType.getElementType(); assert(elementType.isIntOrFloat() && "only integers and floats have a bitwidth"); @@ -66,29 +66,29 @@ static unsigned getBitWidth(Type type) { /// Returns the bit width of LLVMType integer or vector. static unsigned getLLVMTypeBitWidth(Type type) { - return (LLVM::isCompatibleVectorType(type) ? LLVM::getVectorElementType(type) - : type) - .cast() + return cast((LLVM::isCompatibleVectorType(type) + ? LLVM::getVectorElementType(type) + : type)) .getWidth(); } /// Creates `IntegerAttribute` with all bits set for given type static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { - if (auto vecType = type.dyn_cast()) { - auto integerType = vecType.getElementType().cast(); + if (auto vecType = dyn_cast(type)) { + auto integerType = cast(vecType.getElementType()); return builder.getIntegerAttr(integerType, -1); } - auto integerType = type.cast(); + auto integerType = cast(type); return builder.getIntegerAttr(integerType, -1); } /// Creates `llvm.mlir.constant` with all bits set for the given type. static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter) { - if (srcType.isa()) { + if (isa(srcType)) { return rewriter.create( loc, dstType, - SplatElementsAttr::get(srcType.cast(), + SplatElementsAttr::get(cast(srcType), minusOneIntegerAttribute(srcType, rewriter))); } return rewriter.create( @@ -98,14 +98,14 @@ static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. 
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value) { - if (auto vecType = srcType.dyn_cast()) { - auto floatType = vecType.getElementType().cast(); + if (auto vecType = dyn_cast(srcType)) { + auto floatType = cast(vecType.getElementType()); return rewriter.create( loc, dstType, SplatElementsAttr::get(vecType, rewriter.getFloatAttr(floatType, value))); } - auto floatType = srcType.cast(); + auto floatType = cast(srcType); return rewriter.create( loc, dstType, rewriter.getFloatAttr(floatType, value)); } @@ -157,7 +157,7 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { - if (auto vectorType = srcType.dyn_cast()) { + if (auto vectorType = dyn_cast(srcType)) { unsigned numElements = vectorType.getNumElements(); return broadcast(loc, value, numElements, typeConverter, rewriter); } @@ -251,7 +251,7 @@ static std::optional convertArrayType(spirv::ArrayType type, TypeConverter &converter) { unsigned stride = type.getArrayStride(); Type elementType = type.getElementType(); - auto sizeInBytes = elementType.cast().getSizeInBytes(); + auto sizeInBytes = cast(elementType).getSizeInBytes(); if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride)) return std::nullopt; @@ -319,10 +319,9 @@ public: indices.insert(indices.begin(), zero); rewriter.replaceOpWithNewOp( op, dstType, - typeConverter.convertType(op.getBasePtr() - .getType() - .cast() - .getPointeeType()), + typeConverter.convertType( + cast(op.getBasePtr().getType()) + .getPointeeType()), adaptor.getBasePtr(), indices); return success(); } @@ -397,7 +396,7 @@ public: matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = constOp.getType(); - if (!srcType.isa() && !srcType.isIntOrFloat()) + if (!isa(srcType) && !srcType.isIntOrFloat()) return failure(); auto dstType = typeConverter.convertType(srcType); @@ -413,15 +412,15 @@ public: isUnsignedIntegerOrVector(srcType)) { auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); - if (srcType.isa()) { - auto dstElementsAttr = constOp.getValue().cast(); + if (isa(srcType)) { + auto dstElementsAttr = cast(constOp.getValue()); rewriter.replaceOpWithNewOp( constOp, dstType, dstElementsAttr.mapValues( signlessType, [&](const APInt &value) { return value; })); return success(); } - auto srcAttr = constOp.getValue().cast(); + auto srcAttr = cast(constOp.getValue()); auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); @@ -454,17 +453,17 @@ public: // Create a constant that holds the size of the `Base`. IntegerType integerType; - if (auto vecType = srcType.dyn_cast()) - integerType = vecType.getElementType().cast(); + if (auto vecType = dyn_cast(srcType)) + integerType = cast(vecType.getElementType()); else - integerType = srcType.cast(); + integerType = cast(srcType); auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); Value size = - srcType.isa() + isa(srcType) ? 
rewriter.create( loc, dstType, - SplatElementsAttr::get(srcType.cast(), baseSize)) + SplatElementsAttr::get(cast(srcType), baseSize)) : rewriter.create(loc, dstType, baseSize); // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit @@ -573,9 +572,9 @@ public: return failure(); Type containerType = op.getComposite().getType(); - if (containerType.isa()) { + if (isa(containerType)) { Location loc = op.getLoc(); - IntegerAttr value = op.getIndices()[0].cast(); + IntegerAttr value = cast(op.getIndices()[0]); Value index = createI32ConstantOf(loc, rewriter, value.getInt()); rewriter.replaceOpWithNewOp( op, dstType, adaptor.getComposite(), index); @@ -605,9 +604,9 @@ public: return failure(); Type containerType = op.getComposite().getType(); - if (containerType.isa()) { + if (isa(containerType)) { Location loc = op.getLoc(); - IntegerAttr value = op.getIndices()[0].cast(); + IntegerAttr value = cast(op.getIndices()[0]); Value index = createI32ConstantOf(loc, rewriter, value.getInt()); rewriter.replaceOpWithNewOp( op, dstType, adaptor.getComposite(), adaptor.getObject(), index); @@ -732,7 +731,7 @@ public: if (op.getInitializer()) return failure(); - auto srcType = op.getType().cast(); + auto srcType = cast(op.getType()); auto dstType = typeConverter.convertType(srcType.getPointeeType()); if (!dstType) return failure(); @@ -946,12 +945,12 @@ public: Location loc = notOp.getLoc(); IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); - auto mask = srcType.template isa() - ? rewriter.create( - loc, dstType, - SplatElementsAttr::get( - srcType.template cast(), minusOne)) - : rewriter.create(loc, dstType, minusOne); + auto mask = + isa(srcType) + ? rewriter.create( + loc, dstType, + SplatElementsAttr::get(cast(srcType), minusOne)) + : rewriter.create(loc, dstType, minusOne); rewriter.template replaceOpWithNewOp(notOp, dstType, notOp.getOperand(), mask); return success(); @@ -1262,9 +1261,9 @@ public: ConversionPatternRewriter &rewriter) const override { auto srcType = varOp.getType(); // Initialization is supported for scalars and vectors only. 
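// Editorial sketch of the guard just below, with the dropped template
// arguments restored as assumptions (spirv::PointerType for the source
// pointer, VectorType for the initializer check):
auto pointerToSketch =
    mlir::cast<mlir::spirv::PointerType>(srcType).getPointeeType();
if (init && !pointerToSketch.isIntOrFloat() &&
    !mlir::isa<mlir::VectorType>(pointerToSketch))
  return mlir::failure();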
- auto pointerTo = srcType.cast().getPointeeType(); + auto pointerTo = cast(srcType).getPointeeType(); auto init = varOp.getInitializer(); - if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa()) + if (init && !pointerTo.isIntOrFloat() && !isa(pointerTo)) return failure(); auto dstType = typeConverter.convertType(srcType); @@ -1303,7 +1302,7 @@ public: return failure(); if (typeConverter.useOpaquePointers() && - dstType.isa()) { + isa(dstType)) { rewriter.replaceOp(bitcastOp, adaptor.getOperand()); return success(); } @@ -1416,8 +1415,8 @@ public: auto components = adaptor.getComponents(); auto vector1 = adaptor.getVector1(); auto vector2 = adaptor.getVector2(); - int vector1Size = vector1.getType().cast().getNumElements(); - int vector2Size = vector2.getType().cast().getNumElements(); + int vector1Size = cast(vector1.getType()).getNumElements(); + int vector2Size = cast(vector2.getType()).getNumElements(); if (vector1Size == vector2Size) { rewriter.replaceOpWithNewOp( op, vector1, vector2, @@ -1426,16 +1425,16 @@ public: } auto dstType = typeConverter.convertType(op.getType()); - auto scalarType = dstType.cast().getElementType(); + auto scalarType = cast(dstType).getElementType(); auto componentsArray = components.getValue(); auto *context = rewriter.getContext(); auto llvmI32Type = IntegerType::get(context, 32); Value targetOp = rewriter.create(loc, dstType); for (unsigned i = 0; i < componentsArray.size(); i++) { - if (!componentsArray[i].isa()) + if (!isa(componentsArray[i])) return op.emitError("unable to support non-constant component"); - int indexVal = componentsArray[i].cast().getInt(); + int indexVal = cast(componentsArray[i]).getInt(); if (indexVal == -1) continue; diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index 16cbfca..a3e51ae 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -59,7 +59,7 @@ public: matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // For now, only error-free types are supported by this lowering. - if (op.getType().template isa()) + if (isa(op.getType())) return failure(); rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), @@ -127,7 +127,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // For now, this lowering is only defined on `tensor` operands, not // on shapes. - if (op.getType().isa()) + if (isa(op.getType())) return failure(); auto loc = op.getLoc(); @@ -189,7 +189,7 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite( // For now, this lowering supports only extent tensors, not `shape.shape` // types. - if (op.getType().isa()) + if (isa(op.getType())) return failure(); auto loc = op.getLoc(); @@ -242,7 +242,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite( // For now, this lowering is only defined on `tensor` operands, not // on shapes. if (!llvm::all_of(op.getShapes(), - [](Value v) { return !v.getType().isa(); })) + [](Value v) { return !isa(v.getType()); })) return failure(); auto loc = op.getLoc(); @@ -363,13 +363,13 @@ LogicalResult GetExtentOpConverter::matchAndRewrite( GetExtentOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, only error-free types are supported by this lowering. - if (op.getType().isa()) + if (isa(op.getType())) return failure(); // Derive shape extent directly from shape origin if possible. 
This // circumvents the necessity to materialize the shape in memory. if (auto shapeOfOp = op.getShape().getDefiningOp()) { - if (shapeOfOp.getArg().getType().isa()) { + if (isa(shapeOfOp.getArg().getType())) { rewriter.replaceOpWithNewOp(op, shapeOfOp.getArg(), adaptor.getDim()); return success(); @@ -397,7 +397,7 @@ LogicalResult RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering supports only error-free types. - if (op.getType().isa()) + if (isa(op.getType())) return failure(); rewriter.replaceOpWithNewOp(op, adaptor.getShape(), 0); @@ -420,7 +420,7 @@ LogicalResult ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering is only defined on `tensor` operands. - if (op.getShape().getType().isa()) + if (isa(op.getShape().getType())) return failure(); auto loc = op.getLoc(); @@ -499,7 +499,7 @@ LogicalResult ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (!llvm::all_of(op.getShapes(), - [](Value v) { return !v.getType().isa(); })) + [](Value v) { return !isa(v.getType()); })) return failure(); Type i1Ty = rewriter.getI1Type(); @@ -570,18 +570,18 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // For now, only error-free types are supported by this lowering. - if (op.getType().isa()) + if (isa(op.getType())) return failure(); // For ranked tensor arguments, lower to `tensor.from_elements`. auto loc = op.getLoc(); Value tensor = adaptor.getArg(); Type tensorTy = tensor.getType(); - if (tensorTy.isa()) { + if (isa(tensorTy)) { // Build values for individual extents. SmallVector extentValues; - RankedTensorType rankedTensorTy = tensorTy.cast(); + RankedTensorType rankedTensorTy = cast(tensorTy); int64_t rank = rankedTensorTy.getRank(); for (int64_t i = 0; i < rank; i++) { if (rankedTensorTy.isDynamicDim(i)) { @@ -634,7 +634,7 @@ LogicalResult SplitAtOpConversion::matchAndRewrite( // Error conditions are not implemented, only lower if all operands and // results are extent tensors. 
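// Editorial sketch: the recurring guard in these ShapeToStandard lowerings,
// with the dropped template arguments restored as assumptions. The shape
// dialect's error-carrying types (shape.shape / shape.size) are rejected;
// only error-free extent tensors and index types are lowered.
if (mlir::isa<mlir::shape::ShapeType>(op.getType()) ||
    mlir::isa<mlir::shape::SizeType>(op.getType()))
  return mlir::failure();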
if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()}, - [](Value v) { return v.getType().isa(); })) + [](Value v) { return isa(v.getType()); })) return failure(); ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -667,7 +667,7 @@ public: LogicalResult matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!adaptor.getInput().getType().isa()) + if (!isa(adaptor.getInput().getType())) return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); rewriter.replaceOpWithNewOp(op, op.getType(), diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp index ed13ab3..373952c 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp @@ -44,7 +44,7 @@ public: LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto tensorType = extractOp.getTensor().getType().cast(); + auto tensorType = cast(extractOp.getTensor().getType()); if (!tensorType.hasStaticShape()) return rewriter.notifyMatchFailure(extractOp, "non-static tensor"); diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp index 1790e3d..c025fb9 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp @@ -34,14 +34,14 @@ public: }; Type matchContainerType(Type element, Type container) { - if (auto shapedTy = container.dyn_cast()) + if (auto shapedTy = dyn_cast(container)) return shapedTy.clone(element); return element; } TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) { - if (auto shapedTy = type.dyn_cast()) { + if (auto shapedTy = dyn_cast(type)) { Type eTy = shapedTy.getElementType(); APInt valueInt(eTy.getIntOrFloatBitWidth(), value); return DenseIntElementsAttr::get(shapedTy, valueInt); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 3f970be..6aa0751 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -36,7 +36,7 @@ static arith::ConstantOp createConstFromIntAttribute(Operation *op, const std::string &attrName, Type requiredAttrType, OpBuilder &rewriter) { auto castedN = static_cast( - op->getAttr(attrName).cast().getValue().getSExtValue()); + cast(op->getAttr(attrName)).getValue().getSExtValue()); return rewriter.create( op->getLoc(), IntegerAttr::get(requiredAttrType, castedN)); } @@ -47,13 +47,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, PatternRewriter &rewriter) { Location loc = op->getLoc(); auto elementTy = - op->getOperand(0).getType().cast().getElementType(); + cast(op->getOperand(0).getType()).getElementType(); // tosa::AbsOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { auto zero = rewriter.create( loc, rewriter.getZeroAttr(elementTy)); auto cmp = rewriter.create(loc, arith::CmpIPredicate::sgt, @@ -63,21 +63,21 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } // tosa::AddOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); - if (isa(op) && elementTy.isa()) + if (isa(op) && 
isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::SubOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::MulOp - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { if (dyn_cast(op).getShift() != 0) { (void)rewriter.notifyMatchFailure(op, "Cannot have shift value for float"); @@ -87,21 +87,21 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } // tosa::DivOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::ReciprocalOp - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { auto one = rewriter.create(loc, FloatAttr::get(elementTy, 1)); return rewriter.create(loc, resultTypes, one, args[0]); } - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { Value a = args[0]; Value b = args[1]; auto shift = - op->getAttr("shift").cast().getValue().getSExtValue(); + cast(op->getAttr("shift")).getValue().getSExtValue(); if (shift > 0) { auto shiftConst = rewriter.create(loc, shift, /*bitwidth=*/8); @@ -134,17 +134,17 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } // tosa::NegateOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); - if (isa(op) && elementTy.isa() && + if (isa(op) && isa(elementTy) && !cast(op).getQuantizationInfo()) { auto constant = rewriter.create(loc, IntegerAttr::get(elementTy, 0)); return rewriter.create(loc, resultTypes, constant, args[0]); } - if (isa(op) && elementTy.isa() && + if (isa(op) && isa(elementTy) && cast(op).getQuantizationInfo()) { auto quantizationInfo = cast(op).getQuantizationInfo(); int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth(); @@ -190,15 +190,15 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } // tosa::BitwiseAndOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::BitwiseOrOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::BitwiseNotOp - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { auto allOnesAttr = rewriter.getIntegerAttr( elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth())); auto allOnes = rewriter.create(loc, allOnesAttr); @@ -206,21 +206,21 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } // tosa::BitwiseXOrOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::LogicalLeftShiftOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::LogicalRightShiftOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::ArithmeticRightShiftOp - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { auto result = rewriter.create(loc, resultTypes, args); - auto round = op->getAttr("round").cast().getValue(); + auto round = cast(op->getAttr("round")).getValue(); if (!round) { return result; } @@ -256,7 +256,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } // tosa::ClzOp - if (isa(op) && elementTy.isa()) { + 
if (isa(op) && isa(elementTy)) { return rewriter.create(loc, elementTy, args[0]); } @@ -280,27 +280,27 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, return rewriter.create(loc, resultTypes, args); // tosa::PowOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::RsqrtOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::LogOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::ExpOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::TanhOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::GreaterOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, arith::CmpFPredicate::OGT, args[0], args[1]); @@ -309,7 +309,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, args[0], args[1]); // tosa::GreaterEqualOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, arith::CmpFPredicate::OGE, args[0], args[1]); @@ -318,7 +318,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, args[0], args[1]); // tosa::EqualOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, arith::CmpFPredicate::OEQ, args[0], args[1]); @@ -328,13 +328,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, // tosa::SelectOp if (isa(op)) { - elementTy = op->getOperand(1).getType().cast().getElementType(); - if (elementTy.isa() || elementTy.isa()) + elementTy = cast(op->getOperand(1).getType()).getElementType(); + if (isa(elementTy) || isa(elementTy)) return rewriter.create(loc, args[0], args[1], args[2]); } // tosa::MaximumOp - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args[0], args[1]); } @@ -345,7 +345,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } // tosa::MinimumOp - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args[0], args[1]); } @@ -356,21 +356,21 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } // tosa::CeilOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::FloorOp - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::ClampOp - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { bool losesInfo = false; - APFloat minApf = op->getAttr("min_fp").cast().getValue(); - APFloat maxApf = op->getAttr("max_fp").cast().getValue(); - minApf.convert(elementTy.cast().getFloatSemantics(), + APFloat minApf = cast(op->getAttr("min_fp")).getValue(); + APFloat maxApf = cast(op->getAttr("max_fp")).getValue(); + minApf.convert(cast(elementTy).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo); - maxApf.convert(elementTy.cast().getFloatSemantics(), + maxApf.convert(cast(elementTy).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo); auto min = rewriter.create( loc, elementTy, rewriter.getFloatAttr(elementTy, minApf)); @@ -379,12 +379,12 @@ 
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, return clampFloatHelper(loc, args[0], min, max, rewriter); } - if (isa(op) && elementTy.isa()) { - auto intTy = elementTy.cast(); + if (isa(op) && isa(elementTy)) { + auto intTy = cast(elementTy); int32_t min = static_cast( - op->getAttr("min_int").cast().getValue().getSExtValue()); + cast(op->getAttr("min_int")).getValue().getSExtValue()); int32_t max = static_cast( - op->getAttr("max_int").cast().getValue().getSExtValue()); + cast(op->getAttr("max_int")).getValue().getSExtValue()); if (intTy.isUnsignedInteger()) { min = std::max(min, 0); @@ -408,7 +408,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, } // tosa::SigmoidOp - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { auto one = rewriter.create(loc, FloatAttr::get(elementTy, 1)); auto negate = rewriter.create(loc, resultTypes, args[0]); @@ -427,11 +427,11 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, if (srcTy == dstTy) return args.front(); - if (srcTy.isa() && dstTy.isa() && bitExtend) + if (isa(srcTy) && isa(dstTy) && bitExtend) return rewriter.create(loc, resultTypes, args, std::nullopt); - if (srcTy.isa() && dstTy.isa() && !bitExtend) + if (isa(srcTy) && isa(dstTy) && !bitExtend) return rewriter.create(loc, resultTypes, args, std::nullopt); @@ -440,13 +440,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, return rewriter.create(loc, resultTypes, args, std::nullopt); - if (srcTy.isInteger(1) && dstTy.isa() && bitExtend) + if (srcTy.isInteger(1) && isa(dstTy) && bitExtend) return rewriter.create(loc, resultTypes, args, std::nullopt); // Unsigned integers need an unrealized cast so that they can be passed // to UIToFP. - if (srcTy.isUnsignedInteger() && dstTy.isa()) { + if (srcTy.isUnsignedInteger() && isa(dstTy)) { auto unrealizedCast = rewriter .create( @@ -463,7 +463,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, std::nullopt); // Casting to boolean, floats need to only be checked as not-equal to zero. - if (srcTy.isa() && dstTy.isInteger(1)) { + if (isa(srcTy) && dstTy.isInteger(1)) { Value zero = rewriter.create( loc, rewriter.getFloatAttr(srcTy, 0.0)); return rewriter.create(loc, arith::CmpFPredicate::UNE, @@ -490,18 +490,18 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, // Casting to boolean, integers need to only be checked as not-equal to // zero. 
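// Editorial sketch of the int-to-bool cast lowering just below, with the
// stripped names restored as assumptions (IntegerType; arith::ConstantIntOp
// and arith::CmpIOp as the created ops): any nonzero integer becomes true.
if (mlir::isa<mlir::IntegerType>(srcTy) && dstTy.isInteger(1)) {
  mlir::Value zero = rewriter.create<mlir::arith::ConstantIntOp>(
      loc, 0, srcTy.getIntOrFloatBitWidth());
  return rewriter.create<mlir::arith::CmpIOp>(
      loc, mlir::arith::CmpIPredicate::ne, args.front(), zero);
}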
- if (srcTy.isa() && dstTy.isInteger(1)) { + if (isa(srcTy) && dstTy.isInteger(1)) { Value zero = rewriter.create( loc, 0, srcTy.getIntOrFloatBitWidth()); return rewriter.create(loc, arith::CmpIPredicate::ne, args.front(), zero); } - if (srcTy.isa() && dstTy.isa() && bitExtend) + if (isa(srcTy) && isa(dstTy) && bitExtend) return rewriter.create(loc, resultTypes, args, std::nullopt); - if (srcTy.isa() && dstTy.isa() && !bitExtend) { + if (isa(srcTy) && isa(dstTy) && !bitExtend) { return rewriter.create(loc, dstTy, args[0]); } } @@ -520,7 +520,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation, "All TOSA elementwise ops should only return a single result."); auto results = operation->getResults(); - auto resultTy = operation->getResult(0).getType().dyn_cast(); + auto resultTy = dyn_cast(operation->getResult(0).getType()); if (!resultTy) return rewriter.notifyMatchFailure(operation, @@ -538,10 +538,10 @@ elementwiseMatchAndRewriteHelper(Operation *operation, SmallVector emptyTensors; SmallVector dynDims; - dynDims.resize(results.front().getType().cast().getRank()); + dynDims.resize(cast(results.front().getType()).getRank()); for (auto arg : operation->getOperands()) { - auto operandTy = arg.getType().cast(); + auto operandTy = cast(arg.getType()); for (int i = 0; i < operandTy.getRank(); i++) { if (operandTy.isDynamicDim(i) && !dynDims[i]) dynDims[i] = rewriter.create(loc, arg, i); @@ -551,7 +551,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation, SmallVector filteredDims = condenseValues(dynDims); for (auto result : results) { - auto resultTy = result.getType().template cast(); + auto resultTy = cast(result.getType()); emptyTensors.push_back(rewriter.create( loc, resultTy.getShape(), resultTy.getElementType(), filteredDims)); opResultTypes.push_back(result.getType()); @@ -566,7 +566,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation, // Input indexing maps may be broadcasted. for (Value operand : operation->getOperands()) { - ShapedType type = operand.getType().cast(); + ShapedType type = cast(operand.getType()); if (type.getShape() == resultTy.getShape()) { operands.push_back(operand); @@ -627,33 +627,33 @@ elementwiseMatchAndRewriteHelper(Operation *operation, // attribute type varies depending on the element type required. 
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.getFloatAttr(elementTy, 0.0); - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.getIntegerAttr(elementTy, 0); - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.getFloatAttr(elementTy, 1.0); - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.getIntegerAttr(elementTy, 1); - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.getFloatAttr( elementTy, APFloat::getLargest( - elementTy.cast().getFloatSemantics(), false)); + cast(elementTy).getFloatSemantics(), false)); - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.getIntegerAttr( elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())); - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.getFloatAttr( elementTy, APFloat::getLargest( - elementTy.cast().getFloatSemantics(), true)); + cast(elementTy).getFloatSemantics(), true)); - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.getIntegerAttr( elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())); @@ -663,12 +663,12 @@ static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, if (isa(op) && elementTy.isInteger(1)) return rewriter.getIntegerAttr(elementTy, APInt::getZero(1)); - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.getFloatAttr( elementTy, APFloat::getLargest( - elementTy.cast().getFloatSemantics(), true)); + cast(elementTy).getFloatSemantics(), true)); - if (isa(op) && elementTy.isa()) + if (isa(op) && isa(elementTy)) return rewriter.getIntegerAttr( elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())); @@ -682,37 +682,37 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { Location loc = op->getLoc(); - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args); } - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args); } - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args); } - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args); } - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args[0], args[1]); } - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { auto predicate = rewriter.create( loc, arith::CmpIPredicate::slt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args[0], args[1]); } - if (isa(op) && elementTy.isa()) { + if (isa(op) && isa(elementTy)) { auto predicate = rewriter.create( loc, arith::CmpIPredicate::sgt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); @@ -733,8 +733,8 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op, static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, PatternRewriter &rewriter) { auto loc = op->getLoc(); - auto inputTy = op->getOperand(0).getType().template cast(); - auto resultTy = op->getResult(0).getType().template 
cast(); + auto inputTy = cast(op->getOperand(0).getType()); + auto resultTy = cast(op->getResult(0).getType()); auto elementTy = resultTy.getElementType(); Value input = op->getOperand(0); @@ -799,7 +799,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, SmallVector reassociationMap; uint64_t expandInputRank = - linalgOp.getResults()[0].getType().cast().getRank(); + cast(linalgOp.getResults()[0].getType()).getRank(); reassociationMap.resize(expandInputRank); for (uint64_t i = 0; i < expandInputRank; i++) { @@ -848,14 +848,14 @@ public: auto loc = op.getLoc(); auto input = op->getOperand(0); - auto resultTy = op.getType().cast(); + auto resultTy = cast(op.getType()); SmallVector dynDims; - dynDims.resize(op->getResult(0).getType().cast().getRank()); + dynDims.resize(cast(op->getResult(0).getType()).getRank()); SmallVector inputExprs; inputExprs.resize(resultTy.getRank()); - auto operandTy = input.getType().cast(); + auto operandTy = cast(input.getType()); for (const auto &permutation : llvm::enumerate(perms.getValues())) { auto index = permutation.index(); auto value = permutation.value().getZExtValue(); @@ -893,8 +893,8 @@ public: PatternRewriter &rewriter) const final { auto loc = op.getLoc(); auto input = op.getInput(); - auto inputTy = op.getInput().getType().cast(); - auto outputTy = op.getOutput().getType().cast(); + auto inputTy = cast(op.getInput().getType()); + auto outputTy = cast(op.getOutput().getType()); unsigned rank = inputTy.getRank(); // This is an illegal configuration. terminate and log an error @@ -1036,7 +1036,7 @@ public: // Saturate to the output size. IntegerType outIntType = - blockArgs.back().getType().cast(); + cast(blockArgs.back().getType()); unsigned outBitWidth = outIntType.getWidth(); int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue(); @@ -1089,8 +1089,8 @@ public: Location loc = op.getLoc(); ImplicitLocOpBuilder builder(loc, rewriter); auto input = op.getInput(); - auto inputTy = input.getType().cast(); - auto resultTy = op.getType().cast(); + auto inputTy = cast(input.getType()); + auto resultTy = cast(op.getType()); const bool isBilinear = op.getMode() == "BILINEAR"; auto inputH = inputTy.getDimSize(1); @@ -1186,8 +1186,8 @@ public: Location loc = op.getLoc(); ImplicitLocOpBuilder builder(loc, rewriter); auto input = op.getInput(); - auto inputTy = input.getType().dyn_cast(); - auto resultTy = op.getType().dyn_cast(); + auto inputTy = dyn_cast(input.getType()); + auto resultTy = dyn_cast(op.getType()); if (!inputTy || !resultTy) return rewriter.notifyMatchFailure(op, @@ -1282,8 +1282,8 @@ public: Location loc = op.getLoc(); ImplicitLocOpBuilder b(loc, rewriter); auto input = op.getInput(); - auto inputTy = input.getType().cast(); - auto resultTy = op.getType().cast(); + auto inputTy = cast(input.getType()); + auto resultTy = cast(op.getType()); auto resultETy = resultTy.getElementType(); auto imageH = inputTy.getShape()[1]; @@ -1573,8 +1573,8 @@ public: PatternRewriter &rewriter) const final { auto loc = op.getLoc(); Value input = op.getInput(); - auto inputTy = input.getType().template cast(); - auto resultTy = op.getType().template cast(); + auto inputTy = cast(input.getType()); + auto resultTy = cast(op.getType()); auto axis = op.getAxis(); SmallVector dynDims; @@ -1635,9 +1635,9 @@ struct TileConverter : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto input = op.getInput1(); - auto inputTy = input.getType().cast(); + auto inputTy = 
cast(input.getType()); auto inputShape = inputTy.getShape(); - auto resultTy = op.getType().cast(); + auto resultTy = cast(op.getType()); auto elementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); @@ -1710,14 +1710,14 @@ public: PatternRewriter &rewriter) const final { auto loc = argmaxOp.getLoc(); Value input = argmaxOp.getInput(); - auto inputTy = input.getType().cast(); - auto resultTy = argmaxOp.getOutput().getType().cast(); + auto inputTy = cast(input.getType()); + auto resultTy = cast(argmaxOp.getOutput().getType()); auto inElementTy = inputTy.getElementType(); auto outElementTy = resultTy.getElementType(); int axis = argmaxOp.getAxis(); auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy); - if (!outElementTy.isa()) + if (!isa(outElementTy)) return rewriter.notifyMatchFailure( argmaxOp, "tosa.arg_max to linalg.* requires integer-like result type"); @@ -1792,10 +1792,10 @@ public: rewriter.create(loc, axis)); Value predicate; - if (inElementTy.isa()) { + if (isa(inElementTy)) { predicate = rewriter.create( nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); - } else if (inElementTy.isa()) { + } else if (isa(inElementTy)) { predicate = rewriter.create( nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); } else { @@ -1830,8 +1830,8 @@ public: auto indices = adaptor.getOperands()[1]; auto valuesTy = - op.getValues().getType().dyn_cast_or_null(); - auto resultTy = op.getType().cast(); + dyn_cast_or_null(op.getValues().getType()); + auto resultTy = cast(op.getType()); if (!valuesTy) return rewriter.notifyMatchFailure(op, "unranked tensors not supported"); @@ -1904,9 +1904,9 @@ public: auto loc = op.getLoc(); Value input = op.getInput(); Value table = op.getTable(); - auto inputTy = input.getType().cast(); - auto tableTy = table.getType().cast(); - auto resultTy = op.getType().cast(); + auto inputTy = cast(input.getType()); + auto tableTy = cast(table.getType()); + auto resultTy = cast(op.getType()); auto inputElementTy = inputTy.getElementType(); auto tableElementTy = tableTy.getElementType(); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 61413b2..c55a548 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -36,7 +36,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef pad, if (llvm::all_of(pad, [](int64_t p) { return p == 0; })) return input; - ShapedType inputTy = input.getType().cast(); + ShapedType inputTy = cast(input.getType()); Type inputETy = inputTy.getElementType(); auto inputShape = inputTy.getShape(); @@ -67,7 +67,7 @@ static mlir::Value linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias, Value conv, Value result, ArrayRef indexingMaps) { - ShapedType resultTy = conv.getType().cast(); + ShapedType resultTy = cast(conv.getType()); return rewriter .create( loc, resultTy, ValueRange({bias, conv}), result, indexingMaps, @@ -125,7 +125,7 @@ static SmallVector inferDynamicDimsForConv( ArrayRef padAttr, ArrayRef strideAttr, ArrayRef dilationAttr, ArrayRef inputSizeDims, ArrayRef kernelSizeDims, OpBuilder &rewriter) { - ShapedType inputTy = input.getType().cast(); + ShapedType inputTy = cast(input.getType()); Type inputETy = inputTy.getElementType(); int64_t inputRank = inputTy.getRank(); @@ -187,11 +187,10 @@ public: Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = 
input.getType().template cast(); - ShapedType weightTy = weight.getType().template cast(); - ShapedType biasTy = bias.getType().template cast(); - ShapedType resultTy = - op->getResult(0).getType().template cast(); + ShapedType inputTy = cast(input.getType()); + ShapedType weightTy = cast(weight.getType()); + ShapedType biasTy = cast(bias.getType()); + ShapedType resultTy = cast(op->getResult(0).getType()); Type inputETy = inputTy.getElementType(); Type resultETy = resultTy.getElementType(); @@ -353,18 +352,18 @@ public: Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = input.getType().cast(); - ShapedType weightTy = weight.getType().cast(); - ShapedType biasTy = bias.getType().cast(); - ShapedType resultTy = op->getResult(0).getType().cast(); + ShapedType inputTy = cast(input.getType()); + ShapedType weightTy = cast(weight.getType()); + ShapedType biasTy = cast(bias.getType()); + ShapedType resultTy = cast(op->getResult(0).getType()); int64_t resultRank = resultTy.getRank(); Type inputETy = inputTy.getElementType(); Type resultETy = resultTy.getElementType(); - auto padAttr = op->getAttr("pad").cast(); - auto strideTosaAttr = op->getAttr("stride").cast(); - auto dilationTosaAttr = op->getAttr("dilation").cast(); + auto padAttr = cast(op->getAttr("pad")); + auto strideTosaAttr = cast(op->getAttr("stride")); + auto dilationTosaAttr = cast(op->getAttr("dilation")); if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -382,7 +381,7 @@ public: IntegerAttr kZp; if (isQuantized) { auto quantizationInfo = - op->getAttr("quantization_info").cast(); + cast(op->getAttr("quantization_info")); iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()); kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()); } @@ -394,7 +393,7 @@ public: TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy); if (isQuantized) { auto quantizationInfo = - op->getAttr("quantization_info").cast(); + cast(op->getAttr("quantization_info")); int64_t iZp = quantizationInfo.getInputZp(); int64_t intMin = @@ -505,14 +504,14 @@ public: ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); - auto outputTy = op.getType().cast(); + auto outputTy = cast(op.getType()); auto outputElementTy = outputTy.getElementType(); - auto firstOperandTy = op->getOperand(0).getType().cast(); - auto secondOperandTy = op->getOperand(1).getType().cast(); + auto firstOperandTy = cast(op->getOperand(0).getType()); + auto secondOperandTy = cast(op->getOperand(1).getType()); SmallVector dynDims; - dynDims.resize(op->getResult(0).getType().cast().getRank()); + dynDims.resize(cast(op->getResult(0).getType()).getRank()); if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) { dynDims[0] = rewriter.create(loc, op->getOperand(0), 0); @@ -564,20 +563,20 @@ public: matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); - auto outputTy = op.getType().cast(); + auto outputTy = cast(op.getType()); auto input = op.getInput(); - auto inputTy = input.getType().cast(); + auto inputTy = cast(input.getType()); auto bias = op.getBias(); auto weight = op.getWeight(); - auto weightTy = weight.getType().cast(); + auto weightTy = cast(weight.getType()); auto weightShape = weightTy.getShape(); auto outputETy = outputTy.getElementType(); SmallVector dynDims; - dynDims.resize(op->getResult(0).getType().cast().getRank()); + 
dynDims.resize(cast(op->getResult(0).getType()).getRank()); if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) { dynDims[0] = rewriter.create(loc, input, 0); @@ -676,9 +675,9 @@ public: PatternRewriter &rewriter) const final { Location loc = op.getLoc(); Value input = op.getInput(); - ShapedType inputTy = input.getType().cast(); + ShapedType inputTy = cast(input.getType()); - ShapedType resultTy = op.getType().template cast(); + ShapedType resultTy = cast(op.getType()); Type resultETy = inputTy.getElementType(); auto dynamicDimsOr = @@ -691,11 +690,10 @@ public: TypedAttr initialAttr; if (resultETy.isF32()) initialAttr = rewriter.getFloatAttr( - resultETy, - APFloat::getLargest(resultETy.cast().getFloatSemantics(), - true)); + resultETy, APFloat::getLargest( + cast(resultETy).getFloatSemantics(), true)); - if (resultETy.isa()) + if (isa(resultETy)) initialAttr = rewriter.getIntegerAttr( resultETy, APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth())); @@ -747,14 +745,14 @@ public: PatternRewriter &rewriter) const final { Location loc = op.getLoc(); Value input = op.getInput(); - ShapedType inputTy = input.getType().cast(); + ShapedType inputTy = cast(input.getType()); Type inElementTy = inputTy.getElementType(); - ShapedType resultTy = op.getType().template cast(); - Type resultETy = op.getType().cast().getElementType(); + ShapedType resultTy = cast(op.getType()); + Type resultETy = cast(op.getType()).getElementType(); Type accETy = - inElementTy.isa() ? rewriter.getI32Type() : inElementTy; + isa(inElementTy) ? rewriter.getI32Type() : inElementTy; ShapedType accTy = resultTy.clone(accETy); auto dynamicDimsOr = @@ -872,7 +870,7 @@ public: // a div however for quantized values input normalization had // to be applied. Value poolVal = args[0]; - if (accETy.isa()) { + if (isa(accETy)) { auto countF = rewriter.create(loc, accETy, count); poolVal = rewriter.create(loc, poolVal, countF) ->getResult(0); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index 5e46fab0..5e6e971 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -134,8 +134,8 @@ public: LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - ShapedType operandTy = adaptor.getInput1().getType().cast(); - ShapedType resultTy = reshape.getType().template cast(); + ShapedType operandTy = cast(adaptor.getInput1().getType()); + ShapedType resultTy = cast(reshape.getType()); bool isDynamic = !operandTy.hasStaticShape(); if (isDynamic && resultTy.getRank() != 1) { @@ -172,8 +172,8 @@ public: LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - ShapedType operandTy = adaptor.getInput1().getType().cast(); - ShapedType resultTy = reshape.getType().template cast(); + ShapedType operandTy = cast(adaptor.getInput1().getType()); + ShapedType resultTy = cast(reshape.getType()); bool isDynamic = !operandTy.hasStaticShape(); if (isDynamic && operandTy.getRank() != 1) { @@ -211,8 +211,8 @@ public: LogicalResult matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - ShapedType operandTy = adaptor.getInput1().getType().cast(); - ShapedType resultTy = reshape.getType().template cast(); + ShapedType operandTy = cast(adaptor.getInput1().getType()); + ShapedType resultTy = cast(reshape.getType()); bool isDynamic = 
!operandTy.hasStaticShape(); SmallVector intermediateShape; @@ -247,7 +247,7 @@ public: Value input = adaptor.getInput(); SmallVector strides, sizes; ArrayRef starts = sliceOp.getStart(); - strides.resize(sliceOp.getType().template cast().getRank(), 1); + strides.resize(cast(sliceOp.getType()).getRank(), 1); SmallVector dynSizes; for (const auto &i : llvm::enumerate(sliceOp.getSize())) { @@ -284,7 +284,7 @@ public: auto input = padOp.getInput1(); auto padding = padOp.getPadding(); - ShapedType inputTy = input.getType().cast(); + ShapedType inputTy = cast(input.getType()); Type elementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); @@ -297,11 +297,11 @@ public: loc, padOp.getPadConst(), ValueRange({})); } else { TypedAttr constantAttr; - if (elementTy.isa()) { + if (isa(elementTy)) { constantAttr = rewriter.getFloatAttr(elementTy, 0.0); - } else if (elementTy.isa() && !padOp.getQuantizationInfo()) { + } else if (isa(elementTy) && !padOp.getQuantizationInfo()) { constantAttr = rewriter.getIntegerAttr(elementTy, 0); - } else if (elementTy.isa() && padOp.getQuantizationInfo()) { + } else if (isa(elementTy) && padOp.getQuantizationInfo()) { int64_t value = padOp.getQuantizationInfo()->getInputZp(); constantAttr = rewriter.getIntegerAttr(elementTy, value); } @@ -355,8 +355,8 @@ struct ConcatConverter : public OpConversionPattern { LogicalResult matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto inputType = op.getOperand(0).getType().template cast(); - auto resultType = op.getType().dyn_cast(); + auto inputType = cast(op.getOperand(0).getType()); + auto resultType = dyn_cast(op.getType()); Location loc = op.getLoc(); int axis = op.getAxis(); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 5de3bef..a78402e 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -123,7 +123,7 @@ static bool isTransposeMatrixLoadMap(AffineMap permutationMap) { // constant stride. static std::optional getMemrefConstantHorizontalStride(ShapedType type) { - auto memrefType = type.dyn_cast(); + auto memrefType = dyn_cast(type); if (!memrefType) return false; // If the memref is 0 or 1D the horizontal stride is 0. @@ -193,10 +193,10 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) { /// Return true if the constant is a splat to a 2D vector so that it can be /// converted to a MMA constant matrix op. static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) { - auto vecType = constantOp.getType().dyn_cast(); + auto vecType = dyn_cast(constantOp.getType()); if (!vecType || vecType.getRank() != 2) return false; - return constantOp.getValue().isa(); + return isa(constantOp.getValue()); } /// Return true if this is a broadcast from scalar to a 2D vector. @@ -268,11 +268,11 @@ extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) { // matrixB and matrixC operands. vector.extract_strided_slice op // is not supported on registers containing matrixA operands. 
if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B) - return (op->getResult(0).getType().cast() == - (*contractOp).getRhs().getType().cast()); + return (cast(op->getResult(0).getType()) == + cast((*contractOp).getRhs().getType())); if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C) - return (op->getResult(0).getType().cast() == - (*contractOp).getAcc().getType().cast()); + return (cast(op->getResult(0).getType()) == + cast((*contractOp).getAcc().getType())); return false; } @@ -344,11 +344,11 @@ static SetVector getOpToConvert(mlir::Operation *op, bool useNvGpu) { auto hasVectorDest = [](Operation *op) { return llvm::any_of(op->getResultTypes(), - [](Type t) { return t.isa(); }); + [](Type t) { return isa(t); }); }; auto hasVectorSrc = [](Operation *op) { return llvm::any_of(op->getOperandTypes(), - [](Type t) { return t.isa(); }); + [](Type t) { return isa(t); }); }; SetVector opToConvert; op->walk([&](vector::ContractionOp contract) { @@ -448,8 +448,8 @@ struct CombineTransferReadOpTranspose final (extOp = source.getDefiningOp())) { source = extOp->getOperand(0); resultType = - VectorType::get(resultType.cast().getShape(), - source.getType().cast().getElementType()); + VectorType::get(cast(resultType).getShape(), + cast(source.getType()).getElementType()); } auto transferReadOp = source.getDefiningOp(); @@ -553,7 +553,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, bool isSignedExtend = isa(user); if (isSignedExtend || isa(user)) { elType = IntegerType::get( - op.getContext(), elType.cast().getWidth(), + op.getContext(), cast(elType).getWidth(), isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned); mappingResult = user->getResult(0); fragType = inferFragType(user); @@ -610,7 +610,7 @@ getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info) { SmallVector shape{regInfo.numRegistersPerFragment, regInfo.elementsPerRegister}; Type elType = regInfo.registerLLVMType; - if (auto vecType = elType.dyn_cast()) + if (auto vecType = dyn_cast(elType)) elType = vecType.getElementType(); return VectorType::get(shape, elType); } @@ -637,7 +637,7 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, } VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); - auto dense = op.getValue().dyn_cast(); + auto dense = dyn_cast(op.getValue()); if (!dense) { LLVM_DEBUG(DBGS() << "not a splat\n"); return rewriter.notifyMatchFailure(op, "not a splat"); @@ -782,7 +782,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, // If we are not transposing, then we can use vectorized loads. Otherwise, we // must load each element individually. if (!isTransposeLoad) { - if (!loadedElType.isa()) { + if (!isa(loadedElType)) { loadedElType = VectorType::get({1}, loadedElType); } @@ -805,7 +805,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, rewriter.getI64ArrayAttr(i)); } } else { - if (auto vecType = loadedElType.dyn_cast()) { + if (auto vecType = dyn_cast(loadedElType)) { loadedElType = vecType.getElementType(); } for (int i = 0; i < vectorType.getShape()[0]; i++) { @@ -838,7 +838,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, /// Return true if this is a shared memory memref type. 
static bool isSharedMemory(MemRefType type) { auto addressSpace = - type.getMemorySpace().dyn_cast_or_null(); + dyn_cast_or_null(type.getMemorySpace()); if (addressSpace && addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace()) return true; @@ -860,7 +860,7 @@ convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); bool isLdMatrixCompatible = - isSharedMemory(op.getSource().getType().cast()) && + isSharedMemory(cast(op.getSource().getType())) && nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128; VectorType vecTy = op.getVectorType(); @@ -929,7 +929,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl &results) { for (auto attr : arrayAttr) - results.push_back(attr.cast().getInt()); + results.push_back(cast(attr).getInt()); } static LogicalResult @@ -1041,9 +1041,9 @@ convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, itC == valueMapping.end()) return rewriter.notifyMatchFailure(op, "no mapping"); Value opA = itA->second, opB = itB->second, opC = itC->second; - int64_t m = op.getLhs().getType().cast().getShape()[0]; - int64_t n = op.getRhs().getType().cast().getShape()[0]; - int64_t k = op.getLhs().getType().cast().getShape()[1]; + int64_t m = cast(op.getLhs().getType()).getShape()[0]; + int64_t n = cast(op.getRhs().getType()).getShape()[0]; + int64_t k = cast(op.getLhs().getType()).getShape()[1]; Value matmul = rewriter.create( op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k})); valueMapping[op.getResult()] = matmul; @@ -1060,11 +1060,11 @@ convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, assert(constantSupportsMMAMatrixType(op)); auto splat = - op.getValue().cast().getSplatValue(); + cast(op.getValue()).getSplatValue(); auto scalarConstant = rewriter.create(op.getLoc(), splat.getType(), splat); const char *fragType = inferFragType(op); - auto vecType = op.getType().cast(); + auto vecType = cast(op.getType()); gpu::MMAMatrixType type = gpu::MMAMatrixType::get( vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); auto matrix = rewriter.create( diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 05def0f..4175f8f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -256,8 +256,8 @@ public: return failure(); // Resolve address. - auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType()) - .template cast(); + auto vtype = cast( + this->typeConverter->convertType(loadOrStoreOp.getVectorType())); Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), adaptor.getIndices(), rewriter); Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype, @@ -277,7 +277,7 @@ public: LogicalResult matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - MemRefType memRefType = gather.getBaseType().dyn_cast(); + MemRefType memRefType = dyn_cast(gather.getBaseType()); assert(memRefType && "The base should be bufferized"); if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter()))) @@ -296,7 +296,7 @@ public: auto llvmNDVectorTy = adaptor.getIndexVec().getType(); // Handle the simple case of 1-D vector. 
- if (!llvmNDVectorTy.isa()) { + if (!isa(llvmNDVectorTy)) { auto vType = gather.getVectorType(); // Resolve address. Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), @@ -501,7 +501,7 @@ static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral, static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - auto floatType = llvmType.cast(); + auto floatType = cast(llvmType); return rewriter.create( loc, llvmType, rewriter.getFloatAttr( @@ -513,7 +513,7 @@ static Value createReductionNeutralValue(ReductionNeutralFPMin neutral, static Value createReductionNeutralValue(ReductionNeutralFPMax neutral, ConversionPatternRewriter &rewriter, Location loc, Type llvmType) { - auto floatType = llvmType.cast(); + auto floatType = cast(llvmType); return rewriter.create( loc, llvmType, rewriter.getFloatAttr( @@ -585,9 +585,9 @@ static Value createIntegerReductionComparisonOpLowering( /// with vector types. static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, Value rhs, bool isMin) { - auto floatType = getElementTypeOrSelf(lhs.getType()).cast(); + auto floatType = cast(getElementTypeOrSelf(lhs.getType())); Type i1Type = builder.getI1Type(); - if (auto vecType = lhs.getType().dyn_cast()) + if (auto vecType = dyn_cast(lhs.getType())) i1Type = VectorType::get(vecType.getShape(), i1Type); Value cmp = builder.create( loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, @@ -768,7 +768,7 @@ public: return success(); } - if (!eltType.isa()) + if (!isa(eltType)) return failure(); // Floating-point reductions: add/mul/min/max @@ -966,14 +966,14 @@ public: // For all other cases, insert the individual values individually. int64_t v1Dim = v1Type.getDimSize(0); Type eltType; - if (auto arrayType = llvmType.dyn_cast()) + if (auto arrayType = dyn_cast(llvmType)) eltType = arrayType.getElementType(); else - eltType = llvmType.cast().getElementType(); + eltType = cast(llvmType).getElementType(); Value insert = rewriter.create(loc, llvmType); int64_t insPos = 0; for (const auto &en : llvm::enumerate(maskArrayAttr)) { - int64_t extPos = en.value().cast().getInt(); + int64_t extPos = cast(en.value()).getInt(); Value value = adaptor.getV1(); if (extPos >= v1Dim) { extPos -= v1Dim; @@ -1046,7 +1046,7 @@ public: } // One-shot extraction of vector from array (only requires extractvalue). - if (resultType.isa()) { + if (isa(resultType)) { SmallVector indices; for (auto idx : positionArrayAttr.getAsRange()) indices.push_back(idx.getInt()); @@ -1062,13 +1062,13 @@ public: if (positionAttrs.size() > 1) { SmallVector nMinusOnePosition; for (auto idx : positionAttrs.drop_back()) - nMinusOnePosition.push_back(idx.cast().getInt()); + nMinusOnePosition.push_back(cast(idx).getInt()); extracted = rewriter.create(loc, extracted, nMinusOnePosition); } // Remaining extraction of element from 1-D LLVM vector - auto position = positionAttrs.back().cast(); + auto position = cast(positionAttrs.back()); auto i64Type = IntegerType::get(rewriter.getContext(), 64); auto constant = rewriter.create(loc, i64Type, position); extracted = @@ -1169,7 +1169,7 @@ public: } // One-shot insertion of a vector into an array (only requires insertvalue). - if (sourceType.isa()) { + if (isa(sourceType)) { Value inserted = rewriter.create( loc, adaptor.getDest(), adaptor.getSource(), LLVM::convertArrayToIndices(positionArrayAttr)); @@ -1180,7 +1180,7 @@ public: // Potential extraction of 1-D vector from array. 
Value extracted = adaptor.getDest(); auto positionAttrs = positionArrayAttr.getValue(); - auto position = positionAttrs.back().cast(); + auto position = cast(positionAttrs.back()); auto oneDVectorType = destVectorType; if (positionAttrs.size() > 1) { oneDVectorType = reducedVectorTypeBack(destVectorType); @@ -1333,7 +1333,7 @@ public: ConversionPatternRewriter &rewriter) const override { auto loc = castOp->getLoc(); MemRefType sourceMemRefType = - castOp.getOperand().getType().cast(); + cast(castOp.getOperand().getType()); MemRefType targetMemRefType = castOp.getType(); // Only static shape casts supported atm. @@ -1342,13 +1342,13 @@ public: return failure(); auto llvmSourceDescriptorTy = - adaptor.getOperands()[0].getType().dyn_cast(); + dyn_cast(adaptor.getOperands()[0].getType()); if (!llvmSourceDescriptorTy) return failure(); MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); - auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) - .dyn_cast_or_null(); + auto llvmTargetDescriptorTy = dyn_cast_or_null( + typeConverter->convertType(targetMemRefType)); if (!llvmTargetDescriptorTy) return failure(); @@ -1418,7 +1418,7 @@ public: LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { auto dstType = op.getType(); - if (dstType.getRank() != 1 || !dstType.cast().isScalable()) + if (dstType.getRank() != 1 || !cast(dstType).isScalable()) return failure(); IntegerType idxType = force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); @@ -1465,7 +1465,7 @@ public: // Make sure element type has runtime support. PrintConversion conversion = PrintConversion::None; - VectorType vectorType = printType.dyn_cast(); + VectorType vectorType = dyn_cast(printType); Type eltType = vectorType ? vectorType.getElementType() : printType; auto parent = printOp->getParentOfType(); Operation *printer; @@ -1481,7 +1481,7 @@ public: printer = LLVM::lookupOrCreatePrintBF16Fn(parent); } else if (eltType.isIndex()) { printer = LLVM::lookupOrCreatePrintU64Fn(parent); - } else if (auto intTy = eltType.dyn_cast()) { + } else if (auto intTy = dyn_cast(eltType)) { // Integers need a zero or sign extension on the operand // (depending on the source type) as well as a signed or // unsigned print method. Up to 64-bit is supported. @@ -1536,7 +1536,7 @@ private: void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, Value value, Type type, Operation *printer, int64_t rank, PrintConversion conversion) const { - VectorType vectorType = type.dyn_cast(); + VectorType vectorType = dyn_cast(type); Location loc = op->getLoc(); if (!vectorType) { assert(rank == 0 && "The scalar case expects rank == 0"); @@ -1610,7 +1610,7 @@ struct VectorSplatOpLowering : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - VectorType resultType = splatOp.getType().cast(); + VectorType resultType = cast(splatOp.getType()); if (resultType.getRank() > 1) return failure(); @@ -1633,7 +1633,7 @@ struct VectorSplatOpLowering : public ConvertOpToLLVMPattern { auto v = rewriter.create( splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); - int64_t width = splatOp.getType().cast().getDimSize(0); + int64_t width = cast(splatOp.getType()).getDimSize(0); SmallVector zeroValues(width, 0); // Shuffle the value across the desired number of elements. 
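The hunks above all apply the same mechanical rewrite, with each call keeping its template argument (the angle brackets did not survive rendering here), e.g. `value.getType().dyn_cast<VectorType>()` becomes `dyn_cast<VectorType>(value.getType())`. A minimal self-contained sketch of the target style; the headers and helper name are illustrative assumptions, not part of the patch:

```
// Sketch only (not from the patch): the member-to-free-function rewrite,
// with the template argument written out explicitly.
#include "mlir/IR/BuiltinTypes.h" // VectorType
#include "mlir/Support/LLVM.h"    // re-exports llvm::cast/dyn_cast/isa

using namespace mlir;

static int64_t numElementsOrZero(Type type) {
  // Deprecated member form:       type.dyn_cast<VectorType>()
  // Preferred free-function form: dyn_cast<VectorType>(type)
  if (auto vecType = dyn_cast<VectorType>(type))
    return vecType.getNumElements();
  return 0;
}
```

One practical payoff visible in the hunks above: in dependent contexts the member form needed the `.template cast` disambiguator, which the free-function form drops entirely.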
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 1a47dd1..9456b89 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -258,7 +258,7 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp, /// Return true if this transfer op operates on a source tensor. template static bool isTensorOp(OpTy xferOp) { - if (xferOp.getShapedType().template isa()) { + if (isa(xferOp.getShapedType())) { if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) { // TransferWriteOps on tensors have a result. assert(xferOp->getNumResults() > 0); @@ -314,7 +314,7 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) { /// /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>> static MemRefType unpackOneDim(MemRefType type) { - auto vectorType = type.getElementType().dyn_cast(); + auto vectorType = dyn_cast(type.getElementType()); auto memrefShape = type.getShape(); SmallVector newMemrefShape; newMemrefShape.append(memrefShape.begin(), memrefShape.end()); @@ -408,8 +408,8 @@ struct Strategy { getXferIndices(b, xferOp, iv, xferIndices); Location loc = xferOp.getLoc(); - auto bufferType = buffer.getType().dyn_cast(); - auto vecType = bufferType.getElementType().dyn_cast(); + auto bufferType = dyn_cast(buffer.getType()); + auto vecType = dyn_cast(bufferType.getElementType()); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); auto newXferOp = b.create( loc, vecType, xferOp.getSource(), xferIndices, @@ -432,8 +432,8 @@ struct Strategy { storeIndices.push_back(iv); Location loc = xferOp.getLoc(); - auto bufferType = buffer.getType().dyn_cast(); - auto vecType = bufferType.getElementType().dyn_cast(); + auto bufferType = dyn_cast(buffer.getType()); + auto vecType = dyn_cast(bufferType.getElementType()); auto vec = b.create(loc, vecType, xferOp.getPadding()); b.create(loc, vec, buffer, storeIndices); @@ -698,7 +698,7 @@ struct TransferOpConversion : public VectorToSCFPattern { // Find and cast data buffer. How the buffer can be found depends on OpTy. ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter); auto dataBuffer = Strategy::getBuffer(xferOp); - auto dataBufferType = dataBuffer.getType().template dyn_cast(); + auto dataBufferType = dyn_cast(dataBuffer.getType()); auto castedDataType = unpackOneDim(dataBufferType); auto castedDataBuffer = locB.create(castedDataType, dataBuffer); @@ -707,8 +707,7 @@ struct TransferOpConversion : public VectorToSCFPattern { Value castedMaskBuffer; if (xferOp.getMask()) { auto maskBuffer = getMaskBuffer(xferOp); - auto maskBufferType = - maskBuffer.getType().template dyn_cast(); + auto maskBufferType = dyn_cast(maskBuffer.getType()); if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) { // Do not unpack a dimension of the mask, if: // * To-be-unpacked transfer op dimension is a broadcast. 
@@ -889,7 +888,7 @@ struct UnrollTransferReadConversion SmallVector &indices) const { if (auto insertOp = getInsertOp(xferOp)) { for (Attribute attr : insertOp.getPosition()) - indices.push_back(attr.dyn_cast().getInt()); + indices.push_back(dyn_cast(attr).getInt()); } } @@ -908,7 +907,7 @@ struct UnrollTransferReadConversion auto insertOp = getInsertOp(xferOp); auto vec = getResultVector(xferOp, rewriter); - auto vecType = vec.getType().dyn_cast(); + auto vecType = dyn_cast(vec.getType()); auto xferVecType = xferOp.getVectorType(); auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(), xferVecType.getElementType()); @@ -1016,7 +1015,7 @@ struct UnrollTransferWriteConversion SmallVector &indices) const { if (auto extractOp = getExtractOp(xferOp)) { for (Attribute attr : extractOp.getPosition()) - indices.push_back(attr.dyn_cast().getInt()); + indices.push_back(dyn_cast(attr).getInt()); } } @@ -1235,7 +1234,7 @@ struct TransferOp1dConversion : public VectorToSCFPattern { if (xferOp.getTransferRank() == 0) return failure(); auto map = xferOp.getPermutationMap(); - auto memRefType = xferOp.getShapedType().template dyn_cast(); + auto memRefType = dyn_cast(xferOp.getShapedType()); if (!memRefType) return failure(); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 50017b7..35171b3 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -43,7 +43,7 @@ static int getNumBits(Type type) { // TODO: This does not take into account any memory layout or widening // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even // though in practice it will likely be stored as in a 4xi64 vector register. - if (auto vectorType = type.dyn_cast()) + if (auto vectorType = dyn_cast(type)) return vectorType.getNumElements() * vectorType.getElementTypeBitWidth(); return type.getIntOrFloatBitWidth(); } @@ -95,7 +95,7 @@ struct VectorBroadcastConvert final if (!resultType) return failure(); - if (resultType.isa()) { + if (isa(resultType)) { rewriter.replaceOp(castOp, adaptor.getSource()); return success(); } @@ -116,7 +116,7 @@ struct VectorExtractOpConvert final matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only support extracting a scalar value now. - VectorType resultVectorType = extractOp.getType().dyn_cast(); + VectorType resultVectorType = dyn_cast(extractOp.getType()); if (resultVectorType && resultVectorType.getNumElements() > 1) return failure(); @@ -124,7 +124,7 @@ struct VectorExtractOpConvert final if (!dstType) return failure(); - if (adaptor.getVector().getType().isa()) { + if (isa(adaptor.getVector().getType())) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } @@ -156,7 +156,7 @@ struct VectorExtractStridedSliceOpConvert final Value srcVector = adaptor.getOperands().front(); // Extract vector<1xT> case. 
- if (dstType.isa()) { + if (isa(dstType)) { rewriter.replaceOpWithNewOp(extractOp, srcVector, offset); return success(); @@ -203,7 +203,7 @@ struct VectorInsertOpConvert final return success(); } - if (insertOp.getSourceType().isa() || + if (isa(insertOp.getSourceType()) || !spirv::CompositeType::isValid(insertOp.getDestVectorType())) return failure(); int32_t id = getFirstIntValue(insertOp.getPosition()); @@ -224,7 +224,7 @@ struct VectorExtractElementOpConvert final if (!resultType) return failure(); - if (adaptor.getVector().getType().isa()) { + if (isa(adaptor.getVector().getType())) { rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } @@ -252,7 +252,7 @@ struct VectorInsertElementOpConvert final if (!vectorType) return failure(); - if (vectorType.isa()) { + if (isa(vectorType)) { rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } @@ -285,18 +285,17 @@ struct VectorInsertStridedSliceOpConvert final return failure(); uint64_t offset = getFirstIntValue(insertOp.getOffsets()); - if (srcVector.getType().isa()) { - assert(!dstVector.getType().isa()); + if (isa(srcVector.getType())) { + assert(!isa(dstVector.getType())); rewriter.replaceOpWithNewOp( insertOp, dstVector.getType(), srcVector, dstVector, rewriter.getI32ArrayAttr(offset)); return success(); } - uint64_t totalSize = - dstVector.getType().cast().getNumElements(); + uint64_t totalSize = cast(dstVector.getType()).getNumElements(); uint64_t insertSize = - srcVector.getType().cast().getNumElements(); + cast(srcVector.getType()).getNumElements(); SmallVector indices(totalSize); std::iota(indices.begin(), indices.end(), 0); @@ -324,7 +323,7 @@ struct VectorReductionPattern final if (!resultType) return failure(); - auto srcVectorType = adaptor.getVector().getType().dyn_cast(); + auto srcVectorType = dyn_cast(adaptor.getVector().getType()); if (!srcVectorType || srcVectorType.getRank() != 1) return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source"); @@ -393,10 +392,10 @@ public: Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) return failure(); - if (dstType.isa()) { + if (isa(dstType)) { rewriter.replaceOp(op, adaptor.getInput()); } else { - auto dstVecType = dstType.cast(); + auto dstVecType = cast(dstType); SmallVector source(dstVecType.getNumElements(), adaptor.getInput()); rewriter.replaceOpWithNewOp(op, dstType, @@ -422,7 +421,7 @@ struct VectorShuffleOpConvert final if (oldSourceType.getNumElements() > 1) { SmallVector components = llvm::to_vector<4>( llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t { - return attr.cast().getValue().getZExtValue(); + return cast(attr).getValue().getZExtValue(); })); rewriter.replaceOpWithNewOp( shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(), diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp index 0c69cdc..d07d651 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp @@ -65,7 +65,7 @@ static void patchOperandSegmentSizes(ArrayRef attrs, newAttrs.push_back(attr); continue; } - auto segmentAttr = attr.getValue().cast(); + auto segmentAttr = cast(attr.getValue()); MLIRContext *context = segmentAttr.getContext(); DenseI32ArrayAttr newSegments; switch (action) { @@ -128,7 +128,7 @@ LogicalResult RawBufferAtomicByCasPattern::matchAndRewrite( Value prevLoadForCompare = prevLoad; Value atomicResForCompare = atomicRes; - if (auto floatDataTy = 
dataType.dyn_cast()) { + if (auto floatDataTy = dyn_cast(dataType)) { Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth()); prevLoadForCompare = rewriter.create(loc, equivInt, prevLoad); diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp index 9b0ad3c..4b3730a 100644 --- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp @@ -136,7 +136,7 @@ static bool isLocallyDefined(Value v, Operation *enclosingOp) { bool mlir::affine::isLoopMemoryParallel(AffineForOp forOp) { // Any memref-typed iteration arguments are treated as serializing. if (llvm::any_of(forOp.getResultTypes(), - [](Type type) { return type.isa(); })) + [](Type type) { return isa(type); })) return false; // Collect all load and store ops in loop nest rooted at 'forOp'. diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp index 9db1e99..c97e99c 100644 --- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp @@ -162,7 +162,7 @@ uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) { /// conservative. static bool isAccessIndexInvariant(Value iv, Value index) { assert(isAffineForInductionVar(iv) && "iv must be a AffineForOp"); - assert(index.getType().isa() && "index must be of IndexType"); + assert(isa(index.getType()) && "index must be of IndexType"); SmallVector affineApplyOps; getReachableAffineApplyOps({index}, affineApplyOps); @@ -262,7 +262,7 @@ static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp, template static bool isVectorElement(LoadOrStoreOp memoryOp) { auto memRefType = memoryOp.getMemRefType(); - return memRefType.getElementType().template isa(); + return isa(memRefType.getElementType()); } using VectorizableOpFun = std::function; diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp index 438296f..4433d94 100644 --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -190,7 +190,7 @@ void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId, if (!hasEdge(srcId, dstId, value)) { outEdges[srcId].push_back({dstId, value}); inEdges[dstId].push_back({srcId, value}); - if (value.getType().isa()) + if (isa(value.getType())) memrefEdgeCount[value]++; } } @@ -200,7 +200,7 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId, Value value) { assert(inEdges.count(dstId) > 0); assert(outEdges.count(srcId) > 0); - if (value.getType().isa()) { + if (isa(value.getType())) { assert(memrefEdgeCount.count(value) > 0); memrefEdgeCount[value]--; } @@ -289,7 +289,7 @@ void MemRefDependenceGraph::gatherDefiningNodes( // By definition of edge, if the edge value is a non-memref value, // then the dependence is between a graph node which defines an SSA value // and another graph node which uses the SSA value. - if (!edge.value.getType().isa()) + if (!isa(edge.value.getType())) definingNodes.insert(edge.id); } @@ -473,7 +473,7 @@ void MemRefDependenceGraph::forEachMemRefEdge( ArrayRef edges, const std::function &callback) { for (const auto &edge : edges) { // Skip if 'edge' is not a memref dependence edge. - if (!edge.value.getType().isa()) + if (!isa(edge.value.getType())) continue; assert(nodes.count(edge.id) > 0); // Skip if 'edge.id' is not a loop nest. 
@@ -808,13 +808,13 @@ std::optional ComputationSliceState::isMaximal() const { } unsigned MemRefRegion::getRank() const { - return memref.getType().cast().getRank(); + return cast(memref.getType()).getRank(); } std::optional MemRefRegion::getConstantBoundingSizeAndShape( SmallVectorImpl *shape, std::vector> *lbs, SmallVectorImpl *lbDivisors) const { - auto memRefType = memref.getType().cast(); + auto memRefType = cast(memref.getType()); unsigned rank = memRefType.getRank(); if (shape) shape->reserve(rank); @@ -875,7 +875,7 @@ std::optional MemRefRegion::getConstantBoundingSizeAndShape( void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap, AffineMap &ubMap) const { assert(pos < cst.getNumDimVars() && "invalid position"); - auto memRefType = memref.getType().cast(); + auto memRefType = cast(memref.getType()); unsigned rank = memRefType.getRank(); assert(rank == cst.getNumDimVars() && "inconsistent memref region"); @@ -1049,7 +1049,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, // to guard against potential over-approximation from projection. // TODO: Support dynamic memref dimensions. if (addMemRefDimBounds) { - auto memRefType = memref.getType().cast(); + auto memRefType = cast(memref.getType()); for (unsigned r = 0; r < rank; r++) { cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0); if (memRefType.isDynamicDim(r)) @@ -1071,7 +1071,7 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) { unsigned sizeInBits; if (elementType.isIntOrFloat()) { sizeInBits = elementType.getIntOrFloatBitWidth(); - } else if (auto vectorType = elementType.dyn_cast()) { + } else if (auto vectorType = dyn_cast(elementType)) { if (vectorType.getElementType().isIntOrFloat()) sizeInBits = vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); @@ -1085,7 +1085,7 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) { // Returns the size of the region. std::optional MemRefRegion::getRegionSize() { - auto memRefType = memref.getType().cast(); + auto memRefType = cast(memref.getType()); if (!memRefType.getLayout().isIdentity()) { LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); @@ -1119,7 +1119,7 @@ mlir::affine::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) { if (!memRefType.hasStaticShape()) return std::nullopt; auto elementType = memRefType.getElementType(); - if (!elementType.isIntOrFloat() && !elementType.isa()) + if (!elementType.isIntOrFloat() && !isa(elementType)) return std::nullopt; auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType); @@ -1708,7 +1708,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { } unsigned MemRefAccess::getRank() const { - return memref.getType().cast().getRank(); + return cast(memref.getType()).getRank(); } bool MemRefAccess::isStore() const { diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index 89f0a9e..2a9416f 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -289,7 +289,7 @@ bool MemRefDependenceGraph::init() { // memref type. Call Op that returns one or more memref type results // is already taken care of, by the previous conditions. 
if (llvm::any_of(op.getOperandTypes(), - [&](Type t) { return t.isa(); })) { + [&](Type t) { return isa(t); })) { Node node(nextNodeId++, &op); nodes.insert({node.id, node}); } @@ -379,7 +379,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, OpBuilder top(forInst->getParentRegion()); // Create new memref type based on slice bounds. auto oldMemRef = cast(srcStoreOpInst).getMemRef(); - auto oldMemRefType = oldMemRef.getType().cast(); + auto oldMemRefType = cast(oldMemRef.getType()); unsigned rank = oldMemRefType.getRank(); // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. @@ -516,7 +516,7 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId, return WalkResult::advance(); for (Value v : op->getOperands()) // Collect memref values only. - if (v.getType().isa()) + if (isa(v.getType())) memRefValues.insert(v); return WalkResult::advance(); }); diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp index 7d815f7..7029251 100644 --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -88,7 +88,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({}); }; - auto oldMemRefType = oldMemRef.getType().cast(); + auto oldMemRefType = cast(oldMemRef.getType()); auto newMemRefType = doubleShape(oldMemRefType); // The double buffer is allocated right before 'forOp'. diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp index 8987a82..4961807 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -100,9 +100,9 @@ void SimplifyAffineStructures::runOnOperation() { SmallVector opsToSimplify; func.walk([&](Operation *op) { for (auto attr : op->getAttrs()) { - if (auto mapAttr = attr.getValue().dyn_cast()) + if (auto mapAttr = dyn_cast(attr.getValue())) simplifyAndUpdateAttribute(op, attr.getName(), mapAttr); - else if (auto setAttr = attr.getValue().dyn_cast()) + else if (auto setAttr = dyn_cast(attr.getValue())) simplifyAndUpdateAttribute(op, attr.getName(), setAttr); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index 1d34732..b23a2cc 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -838,7 +838,7 @@ void VectorizationState::registerValueVectorReplacementImpl(Value replaced, Value replacement) { assert(!valueVectorReplacement.contains(replaced) && "Vector replacement already registered"); - assert(replacement.getType().isa() && + assert(isa(replacement.getType()) && "Expected vector type in vector replacement"); valueVectorReplacement.map(replaced, replacement); } @@ -883,7 +883,7 @@ void VectorizationState::registerValueScalarReplacementImpl(Value replaced, Value replacement) { assert(!valueScalarReplacement.contains(replaced) && "Scalar value replacement already registered"); - assert(!replacement.getType().isa() && + assert(!isa(replacement.getType()) && "Expected scalar type in scalar replacement"); valueScalarReplacement.map(replaced, replacement); } @@ -946,7 +946,7 @@ isVectorizableLoopPtrFactory(const DenseSet ¶llelLoops, /// strategy on the scalar type. 
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy) { - assert(!scalarTy.isa() && "Expected scalar type"); + assert(!isa(scalarTy) && "Expected scalar type"); return VectorType::get(strategy->vectorSizes, scalarTy); } @@ -1137,7 +1137,7 @@ static Value vectorizeOperand(Value operand, VectorizationState &state) { // An vector operand that is not in the replacement map should never reach // this point. Reaching this point could mean that the code was already // vectorized and we shouldn't try to vectorize already vectorized code. - assert(!operand.getType().isa() && + assert(!isa(operand.getType()) && "Vector op not found in replacement map"); // Vectorize constant. diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 94203ec..01c7c77 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1852,7 +1852,7 @@ static void getMultiLevelStrides(const MemRefRegion ®ion, int64_t numEltPerStride = 1; int64_t stride = 1; for (int d = bufferShape.size() - 1; d >= 1; d--) { - int64_t dimSize = region.memref.getType().cast().getDimSize(d); + int64_t dimSize = cast(region.memref.getType()).getDimSize(d); stride *= dimSize; numEltPerStride *= bufferShape[d]; // A stride is needed only if the region has a shorter extent than the @@ -1891,7 +1891,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef, return ubMap.getNumInputs() == ubOperands.size(); })); - unsigned rank = memref.getType().cast().getRank(); + unsigned rank = cast(memref.getType()).getRank(); assert(lbMaps.size() == rank && "wrong number of lb maps"); assert(ubMaps.size() == rank && "wrong number of ub maps"); @@ -2003,7 +2003,7 @@ static LogicalResult generateCopy( auto loc = region.loc; auto memref = region.memref; - auto memRefType = memref.getType().cast(); + auto memRefType = cast(memref.getType()); if (!memRefType.getLayout().isIdentity()) { LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); @@ -2276,7 +2276,7 @@ static bool getFullMemRefAsRegion(Operation *op, unsigned numParamLoopIVs, assert(false && "expected load or store op"); return false; } - auto memRefType = region->memref.getType().cast(); + auto memRefType = cast(region->memref.getType()); if (!memRefType.hasStaticShape()) return false; diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index e454567..4e02b61 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1119,9 +1119,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, ArrayRef symbolOperands, bool allowNonDereferencingOps) { - unsigned newMemRefRank = newMemRef.getType().cast().getRank(); + unsigned newMemRefRank = cast(newMemRef.getType()).getRank(); (void)newMemRefRank; // unused in opt mode - unsigned oldMemRefRank = oldMemRef.getType().cast().getRank(); + unsigned oldMemRefRank = cast(oldMemRef.getType()).getRank(); (void)oldMemRefRank; // unused in opt mode if (indexRemap) { assert(indexRemap.getNumSymbols() == symbolOperands.size() && @@ -1134,8 +1134,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( } // Assert same elemental type. 
- assert(oldMemRef.getType().cast().getElementType() == - newMemRef.getType().cast().getElementType()); + assert(cast(oldMemRef.getType()).getElementType() == + cast(newMemRef.getType()).getElementType()); SmallVector usePositions; for (const auto &opEntry : llvm::enumerate(op->getOperands())) { @@ -1172,7 +1172,7 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( // Perform index rewrites for the dereferencing op and then replace the op NamedAttribute oldMapAttrPair = affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef); - AffineMap oldMap = oldMapAttrPair.getValue().cast().getValue(); + AffineMap oldMap = cast(oldMapAttrPair.getValue()).getValue(); unsigned oldMapNumInputs = oldMap.getNumInputs(); SmallVector oldMapOperands( op->operand_begin() + memRefOperandPos + 1, @@ -1294,9 +1294,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( ArrayRef symbolOperands, Operation *domOpFilter, Operation *postDomOpFilter, bool allowNonDereferencingOps, bool replaceInDeallocOp) { - unsigned newMemRefRank = newMemRef.getType().cast().getRank(); + unsigned newMemRefRank = cast(newMemRef.getType()).getRank(); (void)newMemRefRank; // unused in opt mode - unsigned oldMemRefRank = oldMemRef.getType().cast().getRank(); + unsigned oldMemRefRank = cast(oldMemRef.getType()).getRank(); (void)oldMemRefRank; if (indexRemap) { assert(indexRemap.getNumSymbols() == symbolOperands.size() && @@ -1309,8 +1309,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( } // Assert same elemental type. - assert(oldMemRef.getType().cast().getElementType() == - newMemRef.getType().cast().getElementType()); + assert(cast(oldMemRef.getType()).getElementType() == + cast(newMemRef.getType()).getElementType()); std::unique_ptr domInfo; std::unique_ptr postDomInfo; @@ -1734,7 +1734,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) { SmallVector> tileSizePos; (void)getTileSizePos(layoutMap, tileSizePos); if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) { - MemRefType oldMemRefType = oldMemRef.getType().cast(); + MemRefType oldMemRefType = cast(oldMemRef.getType()); SmallVector newDynamicSizes; createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b, newDynamicSizes); diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index 9602d53..85e0725 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -34,7 +34,7 @@ struct ConstantOpInterface return constantOp->emitError("could not infer memory space"); // Only ranked tensors are supported. - if (!constantOp.getType().isa()) + if (!isa(constantOp.getType())) return failure(); // Only constants inside a module are supported. @@ -58,7 +58,7 @@ struct ConstantOpInterface bool isWritable(Operation *op, Value value, const AnalysisState &state) const { // Memory locations returned by memref::GetGlobalOp may not be written to. 
- assert(value.isa()); + assert(isa(value)); return false; } }; @@ -84,21 +84,21 @@ struct IndexCastOpInterface LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto castOp = cast(op); - auto resultTensorType = castOp.getType().cast(); + auto resultTensorType = cast(castOp.getType()); FailureOr source = getBuffer(rewriter, castOp.getIn(), options); if (failed(source)) return failure(); - auto sourceType = source->getType().cast(); + auto sourceType = cast(source->getType()); // Result type should have same layout and address space as the source type. BaseMemRefType resultType; - if (auto rankedMemRefType = sourceType.dyn_cast()) { + if (auto rankedMemRefType = dyn_cast(sourceType)) { resultType = MemRefType::get( rankedMemRefType.getShape(), resultTensorType.getElementType(), rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace()); } else { - auto unrankedMemrefType = sourceType.cast(); + auto unrankedMemrefType = cast(sourceType); resultType = UnrankedMemRefType::get(resultTensorType.getElementType(), unrankedMemrefType.getMemorySpace()); } diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp index 22ec425..1a50b4a 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp @@ -63,10 +63,10 @@ static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, Location loc, Type type, const APInt &value) { TypedAttr attr; - if (auto intTy = type.dyn_cast()) { + if (auto intTy = dyn_cast(type)) { attr = rewriter.getIntegerAttr(type, value); } else { - auto vecTy = type.cast(); + auto vecTy = cast(type); attr = SplatElementsAttr::get(vecTy, value); } @@ -78,10 +78,10 @@ static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, Location loc, Type type, int64_t value) { unsigned elementBitWidth = 0; - if (auto intTy = type.dyn_cast()) + if (auto intTy = dyn_cast(type)) elementBitWidth = intTy.getWidth(); else - elementBitWidth = type.cast().getElementTypeBitWidth(); + elementBitWidth = cast(type).getElementTypeBitWidth(); return createScalarOrSplatConstant(rewriter, loc, type, APInt(elementBitWidth, value)); @@ -95,7 +95,7 @@ static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset) { - ArrayRef shape = input.getType().cast().getShape(); + ArrayRef shape = cast(input.getType()).getShape(); assert(lastOffset < shape.back() && "Offset out of bounds"); // Scalarize the result in case of 1D vectors. @@ -125,7 +125,7 @@ extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, // `input` is a scalar, this is a noop. static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input) { - auto vecTy = input.getType().dyn_cast(); + auto vecTy = dyn_cast(input.getType()); if (!vecTy) return input; @@ -142,7 +142,7 @@ static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, /// `input` is a scalar, this is a noop. 
static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input) { - auto vecTy = input.getType().dyn_cast(); + auto vecTy = dyn_cast(input.getType()); if (!vecTy) return input; @@ -159,11 +159,11 @@ static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset) { - ArrayRef shape = dest.getType().cast().getShape(); + ArrayRef shape = cast(dest.getType()).getShape(); assert(lastOffset < shape.back() && "Offset out of bounds"); // Handle scalar source. - if (source.getType().isa()) + if (isa(source.getType())) return rewriter.create(loc, source, dest, lastOffset); SmallVector offsets(shape.size(), 0); @@ -215,14 +215,14 @@ struct ConvertConstant final : OpConversionPattern { unsigned newBitWidth = newType.getElementTypeBitWidth(); Attribute oldValue = op.getValueAttr(); - if (auto intAttr = oldValue.dyn_cast()) { + if (auto intAttr = dyn_cast(oldValue)) { auto [low, high] = getHalves(intAttr.getValue(), newBitWidth); auto newAttr = DenseElementsAttr::get(newType, {low, high}); rewriter.replaceOpWithNewOp(op, newAttr); return success(); } - if (auto splatAttr = oldValue.dyn_cast()) { + if (auto splatAttr = dyn_cast(oldValue)) { auto [low, high] = getHalves(splatAttr.getSplatValue(), newBitWidth); int64_t numSplatElems = splatAttr.getNumElements(); @@ -238,7 +238,7 @@ struct ConvertConstant final : OpConversionPattern { return success(); } - if (auto elemsAttr = oldValue.dyn_cast()) { + if (auto elemsAttr = dyn_cast(oldValue)) { int64_t numElems = elemsAttr.getNumElements(); SmallVector values; values.reserve(numElems * 2); @@ -527,9 +527,8 @@ struct ConvertMaxMin final : OpConversionPattern { Location loc = op->getLoc(); Type oldTy = op.getType(); - auto newTy = this->getTypeConverter() - ->convertType(oldTy) - .template dyn_cast_or_null(); + auto newTy = dyn_cast_or_null( + this->getTypeConverter()->convertType(oldTy)); if (!newTy) return rewriter.notifyMatchFailure( loc, llvm::formatv("unsupported type: {0}", op.getType())); @@ -549,11 +548,11 @@ struct ConvertMaxMin final : OpConversionPattern { /// Returns true iff the type is `index` or `vector<...index>`. static bool isIndexOrIndexVector(Type type) { - if (type.isa()) + if (isa(type)) return true; - if (auto vectorTy = type.dyn_cast()) - if (vectorTy.getElementType().isa()) + if (auto vectorTy = dyn_cast(type)) + if (isa(vectorTy.getElementType())) return true; return false; @@ -610,7 +609,7 @@ struct ConvertIndexCastIndexToInt final : OpConversionPattern { // Emit an index cast over the matching narrow type. Type narrowTy = rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth()); - if (auto vecTy = resultType.dyn_cast()) + if (auto vecTy = dyn_cast(resultType)) narrowTy = VectorType::get(vecTy.getShape(), narrowTy); // Sign or zero-extend the result. Let the matching conversion pattern @@ -1116,7 +1115,7 @@ arith::WideIntEmulationConverter::WideIntEmulationConverter( // Vector case. 
addConversion([this](VectorType ty) -> std::optional { - auto intTy = ty.getElementType().dyn_cast(); + auto intTy = dyn_cast(ty.getElementType()); if (!intTy) return ty; diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp index 787d498..8eddd81 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -86,12 +86,12 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, continue; } - assert(value.getType().cast().isDynamicDim(*dim) && + assert(cast(value.getType()).isDynamicDim(*dim) && "expected dynamic dim"); - if (value.getType().isa()) { + if (isa(value.getType())) { // A tensor dimension is used: generate a tensor.dim. operands.push_back(b.create(loc, value, *dim)); - } else if (value.getType().isa()) { + } else if (isa(value.getType())) { // A memref dimension is used: generate a memref.dim. operands.push_back(b.create(loc, value, *dim)); } else { diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index 45a4bf7..fb363c8 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -58,7 +58,7 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr) { if (auto value = ofr.dyn_cast()) return value; - auto attr = ofr.dyn_cast().dyn_cast(); + auto attr = dyn_cast(ofr.dyn_cast()); assert(attr && "expect the op fold result casts to an integer attribute"); return b.create(loc, attr.getValue().getSExtValue()); } @@ -73,8 +73,8 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, if (targetIsIndex ^ valueIsIndex) return b.create(loc, targetType, value); - auto targetIntegerType = targetType.dyn_cast(); - auto valueIntegerType = value.getType().dyn_cast(); + auto targetIntegerType = dyn_cast(targetType); + auto valueIntegerType = dyn_cast(value.getType()); assert(targetIntegerType && valueIntegerType && "unexpected cast between types other than integers and index"); assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); @@ -88,9 +88,9 @@ Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast) { if (operand.getType() == toType) return operand; - if (auto toIntType = toType.dyn_cast()) { + if (auto toIntType = dyn_cast(toType)) { // If operand is floating point, cast directly to the int type. - if (operand.getType().isa()) { + if (isa(operand.getType())) { if (isUnsignedCast) return b.create(loc, toType, operand); return b.create(loc, toType, operand); @@ -98,7 +98,7 @@ Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand, // Cast index operands directly to the int type. if (operand.getType().isIndex()) return b.create(loc, toType, operand); - if (auto fromIntType = operand.getType().dyn_cast()) { + if (auto fromIntType = dyn_cast(operand.getType())) { // Either extend or truncate. if (toIntType.getWidth() > fromIntType.getWidth()) { if (isUnsignedCast) @@ -108,15 +108,15 @@ Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand, if (toIntType.getWidth() < fromIntType.getWidth()) return b.create(loc, toType, operand); } - } else if (auto toFloatType = toType.dyn_cast()) { + } else if (auto toFloatType = dyn_cast(toType)) { // If operand is integer, cast directly to the float type. // Note that it is unclear how to cast from BF16<->FP16. 
- if (operand.getType().isa()) { + if (isa(operand.getType())) { if (isUnsignedCast) return b.create(loc, toFloatType, operand); return b.create(loc, toFloatType, operand); } - if (auto fromFloatType = operand.getType().dyn_cast()) { + if (auto fromFloatType = dyn_cast(operand.getType())) { if (toFloatType.getWidth() > fromFloatType.getWidth()) return b.create(loc, toFloatType, operand); if (toFloatType.getWidth() < fromFloatType.getWidth()) @@ -141,27 +141,27 @@ Value ArithBuilder::_and(Value lhs, Value rhs) { return b.create(loc, lhs, rhs); } Value ArithBuilder::add(Value lhs, Value rhs) { - if (lhs.getType().isa()) + if (isa(lhs.getType())) return b.create(loc, lhs, rhs); return b.create(loc, lhs, rhs); } Value ArithBuilder::sub(Value lhs, Value rhs) { - if (lhs.getType().isa()) + if (isa(lhs.getType())) return b.create(loc, lhs, rhs); return b.create(loc, lhs, rhs); } Value ArithBuilder::mul(Value lhs, Value rhs) { - if (lhs.getType().isa()) + if (isa(lhs.getType())) return b.create(loc, lhs, rhs); return b.create(loc, lhs, rhs); } Value ArithBuilder::sgt(Value lhs, Value rhs) { - if (lhs.getType().isa()) + if (isa(lhs.getType())) return b.create(loc, arith::CmpFPredicate::OGT, lhs, rhs); return b.create(loc, arith::CmpIPredicate::sgt, lhs, rhs); } Value ArithBuilder::slt(Value lhs, Value rhs) { - if (lhs.getType().isa()) + if (isa(lhs.getType())) return b.create(loc, arith::CmpFPredicate::OLT, lhs, rhs); return b.create(loc, arith::CmpIPredicate::slt, lhs, rhs); } diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp index 7db078a..04f131e 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp @@ -528,9 +528,9 @@ void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() { Operation *op = operand.getOwner(); Type type = operand.get().getType(); - bool isToken = type.isa(); - bool isGroup = type.isa(); - bool isValue = type.isa(); + bool isToken = isa(type); + bool isGroup = isa(type); + bool isValue = isa(type); // Drop reference after async token or group error check (coro await). if (auto await = dyn_cast(op)) diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp index 25cb618..db7550d 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -161,7 +161,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) { // We treat TokenType as state update marker to represent side-effects of // async computations - bool isStateful = func.getCallableResults().front().isa(); + bool isStateful = isa(func.getCallableResults().front()); std::optional retToken; if (isStateful) @@ -535,7 +535,7 @@ public: ConversionPatternRewriter &rewriter) const override { // We can only await on one the `AwaitableType` (for `await` it can be // a `token` or a `value`, for `await_all` it must be a `group`). - if (!op.getOperand().getType().template isa()) + if (!isa(op.getOperand().getType())) return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); // Check if await operation is inside the coroutine function. @@ -646,7 +646,7 @@ public: getReplacementValue(AwaitOp op, Value operand, ConversionPatternRewriter &rewriter) const override { // Load from the async value storage. 
- auto valueType = operand.getType().cast().getValueType(); + auto valueType = cast(operand.getType()).getValueType(); return rewriter.create(op->getLoc(), valueType, operand); } }; diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index ed95a62..5e36b55 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -59,7 +59,7 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults, // This transform op is currently restricted to ModuleOps and function ops. // Such ops are modified in-place. - transformResults.set(getTransformed().cast(), payloadOps); + transformResults.set(cast(getTransformed()), payloadOps); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp index cf51aa5..b813b24 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -280,7 +280,7 @@ private: // defined in a non-dominated block or it is defined in the same block // but the current value is not dominated by the source value. if (!dominators.dominates(definingBlock, parentBlock) || - (definingBlock == parentBlock && value.isa())) { + (definingBlock == parentBlock && isa(value))) { toProcess.emplace_back(value, parentBlock); valuesToFree.insert(value); } else if (visitedValues.insert(std::make_tuple(value, definingBlock)) @@ -307,8 +307,8 @@ private: // Add new allocs and additional clone operations. for (Value value : valuesToFree) { - if (failed(value.isa() - ? introduceBlockArgCopy(value.cast()) + if (failed(isa(value) + ? introduceBlockArgCopy(cast(value)) : introduceValueCopyForRegionResult(value))) return failure(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp index 83b2ef6..278664a 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp @@ -43,7 +43,7 @@ static bool isKnownControlFlowInterface(Operation *op) { /// exceed the stack space. static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes, unsigned maxRankOfAllocatedMemRef) { - auto type = alloc.getType().dyn_cast(); + auto type = dyn_cast(alloc.getType()); if (!type || !alloc.getDefiningOp()) return false; if (!type.hasStaticShape()) { @@ -355,7 +355,7 @@ public: OpBuilder builder(startOperation); Operation *allocOp = alloc.getDefiningOp(); Operation *alloca = builder.create( - alloc.getLoc(), alloc.getType().cast(), + alloc.getLoc(), cast(alloc.getType()), allocOp->getOperands(), allocOp->getAttrs()); // Replace the original alloc by a newly created alloca. 
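Values migrate with the same pattern as types and attributes; in the BufferDeallocation hunk above the dropped template argument is plausibly `BlockArgument`, given the `introduceBlockArgCopy` callee. A brief sketch under the same assumptions (illustrative helper, standard headers):

```
// Sketch only (not from the patch): Value casts use the same free functions.
#include "mlir/IR/Value.h"     // Value, BlockArgument
#include "mlir/Support/LLVM.h" // re-exports llvm::cast/isa

using namespace mlir;

static bool isFirstBlockArgument(Value value) {
  // Deprecated: value.isa<BlockArgument>() / value.cast<BlockArgument>()
  // Preferred:  isa<BlockArgument>(value) / cast<BlockArgument>(value)
  return isa<BlockArgument>(value) &&
         cast<BlockArgument>(value).getArgNumber() == 0;
}
```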
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 7b63335..dd359c2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -53,7 +53,7 @@ updateFuncOp(func::FuncOp func,
   SmallVector<Type, 6> erasedResultTypes;
   BitVector erasedResultIndices(functionType.getNumResults());
   for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
-    if (auto memrefType = resultType.value().dyn_cast<MemRefType>()) {
+    if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
       if (!hasStaticIdentityLayout(memrefType) &&
           !hasFullyDynamicLayoutMap(memrefType)) {
         // Only buffers with static identity layout can be allocated. These can
@@ -103,7 +103,7 @@ static void updateReturnOps(func::FuncOp func,
     SmallVector<Value> copyIntoOutParams;
     SmallVector<Value> keepAsReturnOperands;
     for (Value operand : op.getOperands()) {
-      if (operand.getType().isa<MemRefType>())
+      if (isa<MemRefType>(operand.getType()))
         copyIntoOutParams.push_back(operand);
       else
         keepAsReturnOperands.push_back(operand);
@@ -137,7 +137,7 @@ updateCalls(ModuleOp module,
     SmallVector<Value> replaceWithNewCallResults;
     SmallVector<Value> replaceWithOutParams;
     for (OpResult result : op.getResults()) {
-      if (result.getType().isa<MemRefType>())
+      if (isa<MemRefType>(result.getType()))
         replaceWithOutParams.push_back(result);
       else
         replaceWithNewCallResults.push_back(result);
@@ -145,13 +145,13 @@ updateCalls(ModuleOp module,
     SmallVector<Value> outParams;
     OpBuilder builder(op);
     for (Value memref : replaceWithOutParams) {
-      if (!memref.getType().cast<MemRefType>().hasStaticShape()) {
+      if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
         op.emitError()
             << "cannot create out param for dynamically shaped result";
         didFail = true;
         return;
       }
-      auto memrefType = memref.getType().cast<MemRefType>();
+      auto memrefType = cast<MemRefType>(memref.getType());
       auto allocType =
           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
                           AffineMap(), memrefType.getMemorySpace());
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index b9776e2..f8231ca 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -68,7 +68,7 @@ void BufferPlacementAllocs::build(Operation *op) {
         [=](MemoryEffects::EffectInstance &it) {
           Value value = it.getValue();
           return isa<MemoryEffects::Allocate>(it.getEffect()) && value &&
-                 value.isa<OpResult>() &&
+                 isa<OpResult>(value) &&
                  it.getResource() !=
                      SideEffects::AutomaticAllocationScopeResource::get();
         });
@@ -149,7 +149,7 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
 FailureOr<memref::GlobalOp>
 bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
                             Attribute memorySpace) {
-  auto type = constantOp.getType().cast<RankedTensorType>();
+  auto type = cast<RankedTensorType>(constantOp.getType());
   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
   if (!moduleOp)
     return failure();
@@ -185,14 +185,14 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
                        : IntegerAttr();
 
   BufferizeTypeConverter typeConverter;
-  auto memrefType = typeConverter.convertType(type).cast<MemRefType>();
+  auto memrefType = cast<MemRefType>(typeConverter.convertType(type));
   if (memorySpace)
     memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
   auto global = globalBuilder.create<memref::GlobalOp>(
       constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
       /*sym_visibility=*/globalBuilder.getStringAttr("private"),
       /*type=*/memrefType,
-      /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
+      /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()),
       /*constant=*/true,
       /*alignment=*/memrefAlignment);
   symbolTable.insert(global);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 4eabfcc..24aaff0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -44,7 +44,7 @@ using namespace mlir::bufferization;
 static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                  ValueRange inputs, Location loc) {
   assert(inputs.size() == 1);
-  assert(inputs[0].getType().isa<BaseMemRefType>());
+  assert(isa<BaseMemRefType>(inputs[0].getType()));
   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
 }
 
@@ -66,11 +66,11 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
                               ValueRange inputs, Location loc) -> Value {
     assert(inputs.size() == 1 && "expected exactly one input");
 
-    if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
+    if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
       // MemRef to MemRef cast.
       assert(inputType != type && "expected different types");
       // Unranked to ranked and ranked to unranked casts must be explicit.
-      auto rankedDestType = type.dyn_cast<MemRefType>();
+      auto rankedDestType = dyn_cast<MemRefType>(type);
       if (!rankedDestType)
         return nullptr;
       FailureOr<Value> replacement =
@@ -80,7 +80,7 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
       return *replacement;
     }
 
-    if (inputs[0].getType().isa<TensorType>()) {
+    if (isa<TensorType>(inputs[0].getType())) {
      // Tensor to MemRef cast.
       return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
     }
@@ -222,7 +222,7 @@ struct OneShotBufferizePass
           parseLayoutMapOption(unknownTypeConversion);
       opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
                                        const BufferizationOptions &options) {
-        auto tensorType = value.getType().cast<TensorType>();
+        auto tensorType = cast<TensorType>(value.getType());
         if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
           return bufferization::getMemRefTypeWithStaticIdentityLayout(
               tensorType, memorySpace);
@@ -325,7 +325,7 @@ mlir::bufferization::createFinalizingBufferizePass() {
 // BufferizableOpInterface-based Bufferization
 //===----------------------------------------------------------------------===//
 
-static bool isaTensor(Type t) { return t.isa<TensorType>(); }
+static bool isaTensor(Type t) { return isa<TensorType>(t); }
 
 /// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) { @@ -549,7 +549,7 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() { options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, const BufferizationOptions &options) { return getMemRefTypeWithStaticIdentityLayout( - value.getType().cast(), memorySpace); + cast(value.getType()), memorySpace); }; options.opFilter.allowDialect(); return options; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index 5fc1257..58475d2 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -33,12 +33,12 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, Operation *insertionPoint, const SmallVector &neededValues) { for (Value val : neededValues) { - if (auto bbArg = val.dyn_cast()) { + if (auto bbArg = dyn_cast(val)) { Block *owner = bbArg.getOwner(); if (!owner->findAncestorOpInBlock(*insertionPoint)) return false; } else { - auto opResult = val.cast(); + auto opResult = cast(val); if (!domInfo.dominates(opResult.getOwner(), insertionPoint)) return false; } @@ -75,7 +75,7 @@ findValidInsertionPoint(Operation *emptyTensorOp, // * in case of an OpResult: There must be at least one op right after the // defining op (the anchor op or one of its // parents). - if (auto bbArg = val.dyn_cast()) { + if (auto bbArg = dyn_cast(val)) { insertionPointCandidates.push_back( &bbArg.getOwner()->getOperations().front()); } else { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index bf14e46..f73efc1 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -60,7 +60,7 @@ static BaseMemRefType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options) { auto tensorType = - funcOp.getFunctionType().getInput(index).dyn_cast(); + dyn_cast(funcOp.getFunctionType().getInput(index)); assert(tensorType && "expected TensorType"); BaseMemRefType memrefType = options.functionArgTypeConverterFn( @@ -71,7 +71,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, if (!layoutAttr) return memrefType; - auto rankedMemrefType = memrefType.dyn_cast(); + auto rankedMemrefType = dyn_cast(memrefType); assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); return MemRefType::get( rankedMemrefType.getShape(), rankedMemrefType.getElementType(), @@ -224,7 +224,7 @@ struct CallOpInterface for (const auto &it : llvm::enumerate(callOp.getResultTypes())) { unsigned returnValIdx = it.index(); Type returnType = it.value(); - if (!returnType.isa()) { + if (!isa(returnType)) { // Non-tensor values are returned. retValMapping[returnValIdx] = resultTypes.size(); resultTypes.push_back(returnType); @@ -242,7 +242,7 @@ struct CallOpInterface Value tensorOperand = opOperand.get(); // Non-tensor operands are just copied. 
- if (!tensorOperand.getType().isa()) { + if (!isa(tensorOperand.getType())) { newOperands[idx] = tensorOperand; continue; } @@ -342,7 +342,7 @@ struct FuncOpInterface SmallVector argTypes; for (const auto &it : llvm::enumerate(funcType.getInputs())) { Type argType = it.value(); - if (auto tensorType = argType.dyn_cast()) { + if (auto tensorType = dyn_cast(argType)) { argTypes.push_back( getBufferizedFunctionArgType(funcOp, it.index(), options)); continue; @@ -356,7 +356,7 @@ struct FuncOpInterface if (funcOp.getBody().empty()) { SmallVector retTypes; for (Type resultType : funcType.getResults()) { - if (resultType.isa()) + if (isa(resultType)) return funcOp->emitError() << "cannot bufferize bodiless function " << "that returns a tensor"; retTypes.push_back(resultType); @@ -373,7 +373,7 @@ struct FuncOpInterface // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg. Block &frontBlock = funcOp.getBody().front(); for (BlockArgument &bbArg : frontBlock.getArguments()) { - auto tensorType = bbArg.getType().dyn_cast(); + auto tensorType = dyn_cast(bbArg.getType()); // Non-tensor types stay the same. if (!tensorType) continue; @@ -404,7 +404,7 @@ struct FuncOpInterface SmallVector returnValues; for (OpOperand &returnOperand : returnOp->getOpOperands()) { Value returnVal = returnOperand.get(); - auto tensorType = returnVal.getType().dyn_cast(); + auto tensorType = dyn_cast(returnVal.getType()); rewriter.setInsertionPoint(returnOp); // If not a tensor type just forward it. @@ -436,7 +436,7 @@ struct FuncOpInterface bool isWritable(Operation *op, Value value, const AnalysisState &state) const { auto funcOp = cast(op); - BlockArgument bbArg = value.dyn_cast(); + BlockArgument bbArg = dyn_cast(value); assert(bbArg && "expected BlockArgument"); // "bufferization.writable" overrides other writability decisions. This is diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index db7d453..6da5126 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -66,7 +66,7 @@ MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState) using namespace mlir; using namespace mlir::bufferization; -static bool isaTensor(Type t) { return t.isa(); } +static bool isaTensor(Type t) { return isa(t); } //===----------------------------------------------------------------------===// // Bufferization-specific attribute manipulation. @@ -85,11 +85,11 @@ static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) { SmallVector inPlaceVector; if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) { inPlaceVector = SmallVector(llvm::to_vector<4>( - attr.cast().getAsValueRange())); + cast(attr).getAsValueRange())); } else { inPlaceVector = SmallVector(op->getNumOperands(), "none"); for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) + if (isa(opOperand.get().getType())) inPlaceVector[opOperand.getOperandNumber()] = "false"; } inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false"; @@ -107,12 +107,12 @@ OneShotAnalysisState::OneShotAnalysisState( // Set up alias sets. 
op->walk([&](Operation *op) { for (Value v : op->getResults()) - if (v.getType().isa()) + if (isa(v.getType())) createAliasInfoEntry(v); for (Region &r : op->getRegions()) for (Block &b : r.getBlocks()) for (auto bbArg : b.getArguments()) - if (bbArg.getType().isa()) + if (isa(bbArg.getType())) createAliasInfoEntry(bbArg); }); @@ -121,7 +121,7 @@ OneShotAnalysisState::OneShotAnalysisState( if (!options.isOpAllowed(bufferizableOp)) return WalkResult::skip(); for (OpOperand &opOperand : bufferizableOp->getOpOperands()) - if (opOperand.get().getType().isa()) + if (isa(opOperand.get().getType())) if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) bufferizeInPlace(opOperand); return WalkResult::advance(); @@ -187,13 +187,13 @@ void OneShotAnalysisState::gatherYieldedTensors(Operation *op) { for (OpOperand &returnValOperand : returnOp->getOpOperands()) { Value returnVal = returnValOperand.get(); // Skip non-tensor values. - if (!returnVal.getType().isa()) + if (!isa(returnVal.getType())) continue; // Add all aliases of the returned value. But only the ones that are in // the same block. applyOnAliases(returnVal, [&](Value v) { - if (auto bbArg = v.dyn_cast()) { + if (auto bbArg = dyn_cast(v)) { if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp()) yieldedTensors.insert(bbArg); return; @@ -217,7 +217,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) { // Check all tensor OpResults. for (OpResult opResult : op->getOpResults()) { - if (!opResult.getType().isa()) + if (!isa(opResult.getType())) continue; // If there is no preceding definition, the tensor contents are @@ -259,7 +259,7 @@ bool OneShotAnalysisState::isWritable(Value value) const { return bufferizableOp.isWritable(value, *this); // Query BufferizableOpInterface to see if the BlockArgument is writable. 
- if (auto bbArg = value.dyn_cast()) + if (auto bbArg = dyn_cast(value)) if (auto bufferizableOp = getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp())) return bufferizableOp.isWritable(bbArg, *this); @@ -431,12 +431,12 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]"; readingOp->setAttr(readAttr, b.getUnitAttr()); - if (auto opResult = definition.dyn_cast()) { + if (auto opResult = dyn_cast(definition)) { std::string defAttr = id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]"; opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr()); } else { - auto bbArg = definition.cast(); + auto bbArg = cast(definition); std::string defAttr = id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]"; bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr()); @@ -581,7 +581,7 @@ hasReadAfterWriteInterference(const DenseSet &usesRead, continue; } } else { - auto bbArg = definition.cast(); + auto bbArg = cast(definition); Block *block = bbArg.getOwner(); if (!block->findAncestorOpInBlock(*conflictingWritingOp)) { LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg " @@ -715,12 +715,12 @@ static void annotateNonWritableTensor(Value value) { static int64_t counter = 0; OpBuilder b(value.getContext()); std::string id = "W_" + std::to_string(counter++); - if (auto opResult = value.dyn_cast()) { + if (auto opResult = dyn_cast(value)) { std::string attr = id + "[NOT-WRITABLE: result " + std::to_string(opResult.getResultNumber()) + "]"; opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr()); } else { - auto bbArg = value.cast(); + auto bbArg = cast(value); std::string attr = id + "[NOT-WRITABLE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]"; bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr()); @@ -812,7 +812,7 @@ LogicalResult OneShotAnalysisState::analyzeSingleOp(Operation *op, const DominanceInfo &domInfo) { for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) + if (isa(opOperand.get().getType())) if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo))) return failure(); return success(); @@ -831,7 +831,7 @@ static void equivalenceAnalysis(SmallVector &ops, for (Operation *op : ops) { if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) { for (OpResult opResult : op->getOpResults()) { - if (!opResult.getType().isa()) + if (!isa(opResult.getType())) continue; AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); if (aliases.getNumAliases() == 0) @@ -958,7 +958,7 @@ static LogicalResult checkAliasInfoConsistency(Operation *op, } for (OpOperand &opOperand : op->getOpOperands()) { - if (opOperand.get().getType().isa()) { + if (isa(opOperand.get().getType())) { if (wouldCreateReadAfterWriteInterference( opOperand, domInfo, state, /*checkConsistencyOnly=*/true)) { @@ -984,7 +984,7 @@ annotateOpsWithBufferizationMarkers(Operation *op, // Add __inplace_operands_attr__. op->walk([&](Operation *op) { for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) + if (isa(opOperand.get().getType())) setInPlaceOpOperand(opOperand, state.isInPlace(opOperand)); }); } @@ -1031,12 +1031,12 @@ static LogicalResult assertNoAllocsReturned(Operation *op, for (OpOperand &returnValOperand : returnOp->getOpOperands()) { Value returnVal = returnValOperand.get(); // Skip non-tensor values. 
- if (!returnVal.getType().isa()) + if (!isa(returnVal.getType())) continue; bool foundEquivValue = false; state.applyOnEquivalenceClass(returnVal, [&](Value equivVal) { - if (auto bbArg = equivVal.dyn_cast()) { + if (auto bbArg = dyn_cast(equivVal)) { Operation *definingOp = bbArg.getOwner()->getParentOp(); if (definingOp->isProperAncestor(returnOp)) foundEquivValue = true; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 27b560a..d0af1c2 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -109,9 +109,9 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal, SmallVector equivBbArgs; if (op->hasAttr(kEquivalentArgsAttr)) { - auto attr = op->getAttr(kEquivalentArgsAttr).cast(); + auto attr = cast(op->getAttr(kEquivalentArgsAttr)); equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { - return a.cast().getValue().getSExtValue(); + return cast(a).getValue().getSExtValue(); })); } else { equivBbArgs.append(op->getNumOperands(), -1); @@ -132,10 +132,10 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, // return value may alias with any tensor bbArg. FunctionType type = funcOp.getFunctionType(); for (const auto &inputIt : llvm::enumerate(type.getInputs())) { - if (!inputIt.value().isa()) + if (!isa(inputIt.value())) continue; for (const auto &resultIt : llvm::enumerate(type.getResults())) { - if (!resultIt.value().isa()) + if (!isa(resultIt.value())) continue; int64_t returnIdx = resultIt.index(); int64_t bbArgIdx = inputIt.index(); @@ -150,9 +150,9 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, assert(returnOp && "expected func with single return op"); for (OpOperand &returnVal : returnOp->getOpOperands()) - if (returnVal.get().getType().isa()) + if (isa(returnVal.get().getType())) for (BlockArgument bbArg : funcOp.getArguments()) - if (bbArg.getType().isa()) { + if (isa(bbArg.getType())) { int64_t returnIdx = returnVal.getOperandNumber(); int64_t bbArgIdx = bbArg.getArgNumber(); if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) { @@ -193,7 +193,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state, for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e; ++idx) { // Skip non-tensor arguments. - if (!funcOp.getFunctionType().getInput(idx).isa()) + if (!isa(funcOp.getFunctionType().getInput(idx))) continue; bool isRead; bool isWritten; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp index 4cd19b4..b12ea25 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp @@ -58,7 +58,7 @@ resolveUsesInRepetitiveRegions(Operation *op, for (OpOperand &opOperand : bufferizableOp->getOpOperands()) { Value operand = opOperand.get(); // Skip non-tensor operands. - if (!operand.getType().isa()) + if (!isa(operand.getType())) continue; // Skip operands that do not bufferize to memory writes. if (!bufferizableOp.bufferizesToMemoryWrite(opOperand, state)) @@ -85,7 +85,7 @@ resolveUsesInRepetitiveRegions(Operation *op, // Insert a tensor copy and replace all uses inside of repetitive regions. 
rewriter.setInsertionPoint(bufferizableOp); auto tensorCopy = rewriter.create( - bufferizableOp->getLoc(), operand.getType().cast(), + bufferizableOp->getLoc(), cast(operand.getType()), /*dynamicSizes=*/ValueRange(), /*copy=*/operand, /*memory_space=*/IntegerAttr()); for (OpOperand *use : usesInsideRegion) @@ -137,7 +137,7 @@ mlir::bufferization::insertTensorCopies(Operation *op, SmallVector escapeAttrValue; bool foundTensorResult = false; for (OpResult opResult : op->getOpResults()) { - if (!opResult.getType().isa() || + if (!isa(opResult.getType()) || !bufferizableOp.bufferizesToAllocation(opResult)) { escapeAttrValue.push_back(false); continue; diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index 39c9e5e..8cd2ccf 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -257,19 +257,19 @@ checkMappingAttributeTypes(std::optional transformOp, bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) { - return attr.isa(); + return isa(attr); }); bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) { - return attr.isa(); + return isa(attr); }); bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) { - return attr.isa(); + return isa(attr); }); bool hasLinearMapping = llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) { - return attr.isa(); + return isa(attr); }); int64_t countMappingTypes = 0; countMappingTypes += hasBlockMapping ? 1 : 0; @@ -520,7 +520,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( ArrayRef{forallMappingAttrs}.take_front( forallOp.getInductionVars().size()))) { Value peIdOp = mappingIdOps[static_cast( - dim.cast().getMappingId())]; + cast(dim).getMappingId())]; bvm.map(iv, peIdOp); } diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp index 0a584a7..ca9f2ac 100644 --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -214,7 +214,7 @@ private: /// Returns an accumulator factory that creates an op specified by opName. AccumulatorFactory getFactory(gpu::AllReduceOperation opName) { - bool isFloatingPoint = valueType.isa(); + bool isFloatingPoint = isa(valueType); switch (opName) { case gpu::AllReduceOperation::ADD: return isFloatingPoint ? getFactory() diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp index 0890bf2..1fbe66f 100644 --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -158,9 +158,9 @@ async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, transform(executeOp.getResultTypes(), std::back_inserter(resultTypes), [](Type type) { // Extract value type from !async.value. 
- if (auto valueType = type.dyn_cast()) + if (auto valueType = dyn_cast(type)) return valueType.getValueType(); - assert(type.isa() && "expected token type"); + assert(isa(type) && "expected token type"); return type; }); transform(results, std::back_inserter(resultTypes), @@ -305,9 +305,9 @@ struct GpuAsyncRegionPass::SingleTokenUseCallback { executeOp.getBodyResults(), [](OpResult result) { if (result.use_empty() || result.hasOneUse()) return false; - auto valueType = result.getType().dyn_cast(); + auto valueType = dyn_cast(result.getType()); return valueType && - valueType.getValueType().isa(); + isa(valueType.getValueType()); }); if (multiUseResults.empty()) return; diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 91c1c76..b1e2f91 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -338,7 +338,7 @@ public: if (!resultAttr) return failure(); - dataLayoutSpec = resultAttr.dyn_cast(); + dataLayoutSpec = dyn_cast(resultAttr); if (!dataLayoutSpec) return failure(); } @@ -410,7 +410,7 @@ private: SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { for (SymbolTable::SymbolUse symbolUse : *symbolUses) { StringRef symbolName = - symbolUse.getSymbolRef().cast().getValue(); + cast(symbolUse.getSymbolRef()).getValue(); if (symbolTable.lookup(symbolName)) continue; diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp index ea9c396..21de15e 100644 --- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp @@ -30,7 +30,7 @@ using namespace mlir::gpu; /// single-iteration loops. Maps the innermost loops to thread dimensions, in /// reverse order to enable access coalescing in the innermost loop. static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) { - auto memRefType = from.getType().cast(); + auto memRefType = cast(from.getType()); auto rank = memRefType.getRank(); SmallVector lbs, ubs, steps; @@ -121,8 +121,8 @@ static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) { /// pointed to by "from". In case a smaller block would be sufficient, the /// caller can create a subview of the memref and promote it instead. static void insertCopies(Region ®ion, Location loc, Value from, Value to) { - auto fromType = from.getType().cast(); - auto toType = to.getType().cast(); + auto fromType = cast(from.getType()); + auto toType = cast(to.getType()); (void)fromType; (void)toType; assert(fromType.getShape() == toType.getShape()); @@ -143,7 +143,7 @@ static void insertCopies(Region ®ion, Location loc, Value from, Value to) { /// copies will be inserted in the beginning and in the end of the function. void mlir::promoteToWorkgroupMemory(GPUFuncOp op, unsigned arg) { Value value = op.getArgument(arg); - auto type = value.getType().dyn_cast(); + auto type = dyn_cast(value.getType()); assert(type && type.hasStaticShape() && "can only promote memrefs"); // Get the type of the buffer in the workgroup memory. diff --git a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp index 71d2776..8b09f44 100644 --- a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp +++ b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp @@ -67,7 +67,7 @@ LogicalResult DynParametricAttrConstraint::verify( ConstraintVerifier &context) const { // Check that the base is the expected one. 
- auto dynAttr = attr.dyn_cast(); + auto dynAttr = dyn_cast(attr); if (!dynAttr || dynAttr.getAttrDef() != attrDef) { if (emitError) { StringRef dialectName = attrDef->getDialect()->getNamespace(); @@ -102,7 +102,7 @@ LogicalResult DynParametricTypeConstraint::verify( function_ref emitError, Attribute attr, ConstraintVerifier &context) const { // Check that the base is a TypeAttr. - auto typeAttr = attr.dyn_cast(); + auto typeAttr = dyn_cast(attr); if (!typeAttr) { if (emitError) return emitError() << "expected type, got attribute '" << attr; @@ -110,7 +110,7 @@ LogicalResult DynParametricTypeConstraint::verify( } // Check that the type base is the expected one. - auto dynType = typeAttr.getValue().dyn_cast(); + auto dynType = dyn_cast(typeAttr.getValue()); if (!dynType || dynType.getTypeDef() != typeDef) { if (emitError) { StringRef dialectName = typeDef->getDialect()->getNamespace(); diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp index 9f38b0c..ecdadd3 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp @@ -25,11 +25,11 @@ using namespace mlir; /// Attempt to extract a filename for the given loc. static FileLineColLoc extractFileLoc(Location loc) { - if (auto fileLoc = loc.dyn_cast()) + if (auto fileLoc = dyn_cast(loc)) return fileLoc; - if (auto nameLoc = loc.dyn_cast()) + if (auto nameLoc = dyn_cast(loc)) return extractFileLoc(nameLoc.getChildLoc()); - if (auto opaqueLoc = loc.dyn_cast()) + if (auto opaqueLoc = dyn_cast(loc)) return extractFileLoc(opaqueLoc.getFallbackLocation()); return FileLineColLoc(); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp index 1936a53..02909bb 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -607,7 +607,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation( return diag; Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position)); - if (getResult().getType().isa()) { + if (isa(getResult().getType())) { results.setValues(cast(getResult()), result); return DiagnosedSilenceableFailure::success(); } @@ -648,7 +648,7 @@ transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op, LogicalResult transform::MatchStructuredResultOp::verify() { if ((getAny() || getSingle()) ^ - getResult().getType().isa()) { + isa(getResult().getType())) { return emitOpError() << "expects either the any/single keyword or the type " "value handle result type"; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 74eb3a2..ea8d285 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -87,7 +87,7 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations( SmallVector &result, ArrayRef ofrs) { for (OpFoldResult ofr : ofrs) { if (ofr.is()) { - if (!ofr.get().isa()) + if (!isa(ofr.get())) return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; result.push_back(ofr); continue; @@ -155,7 +155,7 @@ transform::BufferizeToAllocationOp::apply(transform::TransformResults &results, llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) { return linalg::bufferizeToAllocation(rewriter, v, 
memorySpace); })); - results.setValues(getTransformed().cast(), transformed); + results.setValues(cast(getTransformed()), transformed); return DiagnosedSilenceableFailure::success(); } @@ -276,7 +276,7 @@ static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, if (!sizesAttr) return parser.emitError(opLoc) << "expected '" << sizesAttrName << "' attribute"; - auto sizesArrayAttr = sizesAttr.dyn_cast(); + auto sizesArrayAttr = dyn_cast(sizesAttr); if (!sizesArrayAttr) return parser.emitError(opLoc) << "'" << sizesAttrName << "' attribute must be an array"; @@ -389,7 +389,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, // Tile the producer. int64_t resultNumber = - sliceOpToTile.getSource().cast().getResultNumber(); + cast(sliceOpToTile.getSource()).getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); FailureOr tileAndFuseResult = @@ -411,10 +411,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, // Replace the extract op. auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], - sliceOpToTile->getResult(0) - .getType() - .cast() - .getShape()); + cast(sliceOpToTile->getResult(0).getType()).getShape()); assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); return tileAndFuseResult->tiledOps; @@ -482,7 +479,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( // Replace the use in the tileableProducer before tiling: clone, replace and // then tile. - int64_t resultNumber = pUse->get().cast().getResultNumber(); + int64_t resultNumber = cast(pUse->get()).getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); // Gather destination tensors. @@ -516,10 +513,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( // Replace the extract op. auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0], - sliceOpToTile->getResult(0) - .getType() - .cast() - .getShape()); + cast(sliceOpToTile->getResult(0).getType()).getShape()); assert(succeeded(maybeRankReduced) && "unexpected shape"); rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); @@ -568,7 +562,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, // TODO: Generalize to other type of ops. assert(!isa(use->getOwner()) && "Parallel insert slice is not a valid clone destination"); - unsigned resultNumber = use->get().cast().getResultNumber(); + unsigned resultNumber = cast(use->get()).getResultNumber(); LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); OpBuilder::InsertionGuard guard(rewriter); @@ -587,8 +581,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, ArrayRef producerOps = state.getPayloadOps(getProducerOp()); // If nothing to fuse, propagate success. 
if (producerOps.empty()) { - results.set(getFusedOp().cast(), - SmallVector{}); + results.set(cast(getFusedOp()), SmallVector{}); return DiagnosedSilenceableFailure::success(); } ArrayRef containingOps = state.getPayloadOps(getContainingOp()); @@ -671,7 +664,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } - results.set(getFusedOp().cast(), fusedOps); + results.set(cast(getFusedOp()), fusedOps); return DiagnosedSilenceableFailure::success(); } @@ -865,7 +858,7 @@ transform::MatchOp::apply(transform::TransformResults &results, }; payloadOps.front()->walk(matchFun); - results.set(getResult().cast(), res); + results.set(cast(getResult()), res); return DiagnosedSilenceableFailure::success(); } @@ -901,7 +894,7 @@ static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( LinalgOp target, transform::ApplyToEachResultList &results, TransformState &state) { - if (getLowSize().getType().isa()) { + if (isa(getLowSize().getType())) { if (target.hasDynamicShape()) { auto diag = emitSilenceableError() << "cannot compute parametric tile sizes for dynamically " @@ -923,7 +916,7 @@ DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( spec->lowTileSize * spec->lowTripCount}), [&builder, this](int64_t value) { return builder.getIntegerAttr( - getLowSize().getType().cast().getType(), value); + cast(getLowSize().getType()).getType(), value); })); return DiagnosedSilenceableFailure::success(); } @@ -958,7 +951,7 @@ void transform::MultiTileSizesOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getTarget(), effects); producesHandle(getResults(), effects); - if (getLowSize().getType().isa()) + if (isa(getLowSize().getType())) onlyReadsPayload(effects); else modifiesPayload(effects); @@ -1006,7 +999,7 @@ transform::PackOp::apply(transform::TransformResults &transformResults, ArrayRef targetOps = state.getPayloadOps(getTarget()); // If nothing to pack, propagate success. if (targetOps.empty()) { - transformResults.set(getPackedOp().cast(), {}); + transformResults.set(cast(getPackedOp()), {}); return DiagnosedSilenceableFailure::success(); } // Fail on multi-op handles. @@ -1036,7 +1029,7 @@ transform::PackOp::apply(transform::TransformResults &transformResults, if (failed(maybeResult)) return emitDefiniteFailure("data tiling failed"); - transformResults.set(getPackedOp().cast(), + transformResults.set(cast(getPackedOp()), maybeResult->packedLinalgOp.getOperation()); return DiagnosedSilenceableFailure::success(); } @@ -1242,7 +1235,7 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults, } results.push_back(linalgOp); } - transformResults.set(getPackedOp().cast(), results); + transformResults.set(cast(getPackedOp()), results); return DiagnosedSilenceableFailure::success(); } @@ -1322,9 +1315,9 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults, ArrayRef linalgOps = state.getPayloadOps(getTargetLinalgOp()); // Step 1. If nothing to pack, propagate success. 
if (packOrUnpackOps.empty()) { - transformResults.set(getPackedOp().cast(), {}); - transformResults.set(getPackOp().cast(), {}); - transformResults.set(getUnPackOp().cast(), {}); + transformResults.set(cast(getPackedOp()), {}); + transformResults.set(cast(getPackOp()), {}); + transformResults.set(cast(getUnPackOp()), {}); return DiagnosedSilenceableFailure::success(); } @@ -1366,7 +1359,7 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults, if (unPackOp) { assert(!packOp && "packOp must be null on entry when unPackOp is not null"); OpOperand *packUse = linalgOp.getDpsInitOperand( - unPackOp.getSource().cast().getResultNumber()); + cast(unPackOp.getSource()).getResultNumber()); packOp = dyn_cast_or_null(packUse->get().getDefiningOp()); if (!packOp || !packOp.getResult().hasOneUse()) return emitSilenceableError() << "could not find matching pack op"; @@ -1400,14 +1393,14 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults, assert(succeeded(res) && "unexpected packTranspose failure"); // Step 4. Return results. - transformResults.set(getPackOp().cast(), {res->transposedPackOp}); - transformResults.set(getPackedOp().cast(), + transformResults.set(cast(getPackOp()), {res->transposedPackOp}); + transformResults.set(cast(getPackedOp()), {res->transposedLinalgOp}); if (unPackOp) { - transformResults.set(getUnPackOp().cast(), + transformResults.set(cast(getUnPackOp()), {res->transposedUnPackOp}); } else { - transformResults.set(getUnPackOp().cast(), {}); + transformResults.set(cast(getUnPackOp()), {}); } return DiagnosedSilenceableFailure::success(); @@ -1430,14 +1423,14 @@ transform::PadOp::applyToOne(LinalgOp target, SmallVector paddingValues; for (auto const &it : llvm::zip(getPaddingValues(), target->getOperandTypes())) { - auto attr = std::get<0>(it).dyn_cast(); + auto attr = dyn_cast(std::get<0>(it)); if (!attr) { emitOpError("expects padding values to be typed attributes"); return DiagnosedSilenceableFailure::definiteFailure(); } Type elementType = getElementTypeOrSelf(std::get<1>(it)); // Try to parse string attributes to obtain an attribute of element type. - if (auto stringAttr = attr.dyn_cast()) { + if (auto stringAttr = dyn_cast(attr)) { auto parsedAttr = dyn_cast_if_present( parseAttribute(stringAttr, getContext(), elementType, /*numRead=*/nullptr, /*isKnownNullTerminated=*/true)); @@ -1462,9 +1455,9 @@ transform::PadOp::applyToOne(LinalgOp target, // Extract the transpose vectors. 
SmallVector> transposePaddings; - for (Attribute transposeVector : getTransposePaddings().cast()) + for (Attribute transposeVector : cast(getTransposePaddings())) transposePaddings.push_back( - extractFromI64ArrayAttr(transposeVector.cast())); + extractFromI64ArrayAttr(cast(transposeVector))); TrackingListener listener(state, *this); IRRewriter rewriter(getContext(), &listener); @@ -1549,13 +1542,13 @@ DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply( return emitDefiniteFailure() << "could not build packing loop nest"; if (result->clonedLoopIvs.empty()) { - transformResults.set(getPackingLoop().cast(), + transformResults.set(cast(getPackingLoop()), result->hoistedPadOp.getOperation()); return DiagnosedSilenceableFailure::success(); } auto outerPackedLoop = scf::getForInductionVarOwner(result->clonedLoopIvs.front()); - transformResults.set(getPackingLoop().cast(), + transformResults.set(cast(getPackingLoop()), outerPackedLoop.getOperation()); return DiagnosedSilenceableFailure::success(); } @@ -1643,7 +1636,7 @@ transform::PromoteOp::applyToOne(LinalgOp target, if (mapping.size() > 1) return emitDefaultDefiniteFailure(target); - auto addressSpace = mapping[0].cast(); + auto addressSpace = cast(mapping[0]); if (addressSpace.getAddressSpace() == gpu::GPUDialect::getWorkgroupAddressSpace()) { @@ -1711,7 +1704,7 @@ transform::ReplaceOp::apply(TransformResults &transformResults, rewriter.replaceOp(target, replacement->getResults()); replacements.push_back(replacement); } - transformResults.set(getReplacement().cast(), replacements); + transformResults.set(cast(getReplacement()), replacements); return DiagnosedSilenceableFailure::success(); } @@ -1828,7 +1821,7 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { auto diag = DiagnosedSilenceableFailure::success(); - if (getDynamicSplitPoint().getType().isa()) { + if (isa(getDynamicSplitPoint().getType())) { splitPoints = llvm::to_vector(llvm::map_range( state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { if (op->getNumResults() != 1 || @@ -1909,8 +1902,8 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, return diag; } - results.set(getFirst().cast(), first); - results.set(getSecond().cast(), second); + results.set(cast(getFirst()), first); + results.set(cast(getSecond()), second); return DiagnosedSilenceableFailure::success(); } @@ -2212,12 +2205,12 @@ transform::TileOp::apply(TransformResults &transformResults, dynamicSizeProducers.reserve(getDynamicSizes().size()); paramSizes.reserve(getDynamicSizes().size()); for (Value transformValue : getDynamicSizes()) { - if (transformValue.getType().isa()) { + if (isa(transformValue.getType())) { dynamicSizeProducers.push_back({}); ArrayRef params = state.getParams(transformValue); paramSizes.push_back( llvm::to_vector(llvm::map_range(params, [](Attribute attr) { - return attr.cast().getValue().getSExtValue(); + return cast(attr).getValue().getSExtValue(); }))); if (paramSizes.back().size() != targets.size()) { @@ -2247,7 +2240,7 @@ transform::TileOp::apply(TransformResults &transformResults, for (Operation *op : dynamicSizeProducers.back()) { if (op->getNumResults() == 1 && - op->getResult(0).getType().isa()) + isa(op->getResult(0).getType())) continue; DiagnosedSilenceableFailure diag = @@ -2283,7 +2276,7 @@ transform::TileOp::apply(TransformResults &transformResults, for (OpFoldResult ofr : getMixedSizes()) { if (auto attr = ofr.dyn_cast()) { 
sizes.push_back(b.create( - getLoc(), attr.cast().getInt())); + getLoc(), cast(attr).getInt())); continue; } ArrayRef dynamicSizes = dynamicSizeProducers[dynamicIdx]; @@ -2320,9 +2313,9 @@ transform::TileOp::apply(TransformResults &transformResults, loops[en2.index()].push_back(en2.value()); } - transformResults.set(getTiledLinalgOp().cast(), tiled); + transformResults.set(cast(getTiledLinalgOp()), tiled); for (const auto &en : llvm::enumerate(loops)) - transformResults.set(getLoops()[en.index()].cast(), en.value()); + transformResults.set(cast(getLoops()[en.index()]), en.value()); return DiagnosedSilenceableFailure::success(); } @@ -2582,8 +2575,8 @@ transform::TileToForallOp::apply(transform::TransformResults &transformResults, tiledOps.push_back(tilingResult.tiledOp); } - transformResults.set(getForallOp().cast(), tileOps); - transformResults.set(getTiledOp().cast(), tiledOps); + transformResults.set(cast(getForallOp()), tileOps); + transformResults.set(cast(getTiledOp()), tiledOps); return DiagnosedSilenceableFailure::success(); } @@ -2678,7 +2671,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults, for (Operation *op : dynamicSizeProducers.back()) { if (op->getNumResults() == 1 && - op->getResult(0).getType().isa()) + isa(op->getResult(0).getType())) continue; DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected sizes to be produced by ops " @@ -2712,7 +2705,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults, for (OpFoldResult ofr : getMixedSizes()) { if (auto attr = ofr.dyn_cast()) { sizes.push_back(b.create( - getLoc(), attr.cast().getInt())); + getLoc(), cast(attr).getInt())); } else { sizes.push_back( dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); @@ -2737,9 +2730,9 @@ transform::TileToScfForOp::apply(TransformResults &transformResults, loops[en2.index()].push_back(en2.value()); } - transformResults.set(getTiledLinalgOp().cast(), tiled); + transformResults.set(cast(getTiledLinalgOp()), tiled); for (const auto &en : llvm::enumerate(loops)) - transformResults.set(getLoops()[en.index()].cast(), en.value()); + transformResults.set(cast(getLoops()[en.index()]), en.value()); return DiagnosedSilenceableFailure::success(); } @@ -2899,7 +2892,7 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply( for (OpFoldResult sz : getMixedVectorSizes()) { if (sz.is()) { auto attr = sz.get(); - vectorSizes.push_back(attr.cast().getInt()); + vectorSizes.push_back(cast(attr).getInt()); continue; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp index 1a7d7a1..6b06c32 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -64,20 +64,20 @@ public: if (genericOp.getNumDpsInits() != 1) return failure(); - auto outputType = genericOp.getResultTypes().front().dyn_cast(); + auto outputType = dyn_cast(genericOp.getResultTypes().front()); // Require the output types to be static given that we are generating // constants. if (!outputType || !outputType.hasStaticShape()) return failure(); if (!llvm::all_of(genericOp.getInputs(), [](Value input) { - return input.getType().isa(); + return isa(input.getType()); })) return failure(); // Make sure all element types are the same. 
auto getOperandElementType = [](Value value) { - return value.getType().cast().getElementType(); + return cast(value.getType()).getElementType(); }; if (!llvm::all_equal( llvm::map_range(genericOp->getOperands(), getOperandElementType))) @@ -138,7 +138,7 @@ public: // unify the following cases but they have lifetime as the MLIRContext. SmallVector intOutputValues; SmallVector fpOutputValues; - if (elementType.template isa()) + if (isa(elementType)) fpOutputValues.resize(numElements, APFloat(0.f)); else intOutputValues.resize(numElements); @@ -174,7 +174,7 @@ public: auto inputShapes = llvm::to_vector<4>( llvm::map_range(genericOp.getInputs(), [](Value value) { - return value.getType().cast().getShape(); + return cast(value.getType()).getShape(); })); // Given a `linearIndex`, remap it to a linear index to access linalg op @@ -205,7 +205,7 @@ public: } }; - bool isFloat = elementType.isa(); + bool isFloat = isa(elementType); if (isFloat) { SmallVector> inFpRanges; for (int i = 0; i < numInputs; ++i) @@ -282,7 +282,7 @@ struct FoldConstantTranspose : public FoldConstantBase { // The yield op should return the block argument corresponds to the input. for (Value yieldVal : yieldOp.getValues()) { - auto yieldArg = yieldVal.dyn_cast(); + auto yieldArg = dyn_cast(yieldVal); if (!yieldArg || yieldArg.getOwner() != &body) return nullptr; if (yieldArg.getArgNumber() != 0) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp index 5423cf8..48c2459 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -29,7 +29,7 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) { } static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) { - bool isInt = x.getType().isa(); + bool isInt = isa(x.getType()); if (isInt) return builder.create(loc, x, y); return builder.create(loc, x, y); @@ -42,7 +42,7 @@ static Value createMul(Location loc, Value x, Value y, Type accType, convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false); Value yConvert = convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false); - if (accType.isa()) + if (isa(accType)) return builder.create(loc, xConvert, yConvert); return builder.create(loc, xConvert, yConvert); } @@ -74,9 +74,9 @@ static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, FailureOr> rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { - auto inputType = convOp.getInputs()[0].getType().cast(); - auto filterType = convOp.getInputs()[1].getType().cast(); - auto outputType = convOp.getOutputs()[0].getType().cast(); + auto inputType = cast(convOp.getInputs()[0].getType()); + auto filterType = cast(convOp.getInputs()[1].getType()); + auto outputType = cast(convOp.getOutputs()[0].getType()); if (!filterType.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -210,9 +210,9 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { FailureOr> rewriteInIm2Col(RewriterBase &rewriter, linalg::DepthwiseConv2DNhwcHwcOp convOp) { - auto inputType = convOp.getInputs()[0].getType().cast(); - auto filterType = convOp.getInputs()[1].getType().cast(); - auto outputType = convOp.getOutputs()[0].getType().cast(); + auto inputType = cast(convOp.getInputs()[0].getType()); + auto filterType = cast(convOp.getInputs()[1].getType()); + auto outputType = cast(convOp.getOutputs()[0].getType()); if 
(!filterType.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -230,7 +230,7 @@ rewriteInIm2Col(RewriterBase &rewriter, Location loc = convOp.getLoc(); auto transposeOperand = [&](Value operand, ArrayRef indices) { - auto operandTensorType = operand.getType().cast(); + auto operandTensorType = cast(operand.getType()); auto nloops = indices.size(); ArrayRef inputShape = operandTensorType.getShape(); @@ -272,7 +272,7 @@ rewriteInIm2Col(RewriterBase &rewriter, Value inputT = transposeOperand(input, {0, 3, 1, 2}); Value filterT = transposeOperand(filter, {2, 0, 1}); ArrayRef filterTShape = - filterT.getType().cast().getShape(); + cast(filterT.getType()).getShape(); ArrayRef outputShape = outputType.getShape(); int n = outputShape[0]; @@ -360,9 +360,9 @@ rewriteInIm2Col(RewriterBase &rewriter, FailureOr> rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { - auto inputType = convOp.getInputs()[0].getType().cast(); - auto filterType = convOp.getInputs()[1].getType().cast(); - auto outputType = convOp.getOutputs()[0].getType().cast(); + auto inputType = cast(convOp.getInputs()[0].getType()); + auto filterType = cast(convOp.getInputs()[1].getType()); + auto outputType = cast(convOp.getOutputs()[0].getType()); if (!filterType.hasStaticShape()) return rewriter.notifyMatchFailure( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index 3ec5094..a81a48d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -66,12 +66,12 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter, Attribute constYieldedValue; // Is the yielded value a bbArg defined outside of the PadOp? bool outsideBbArg = - yieldedValue.isa() && - yieldedValue.cast().getOwner()->getParentOp() != + isa(yieldedValue) && + cast(yieldedValue).getOwner()->getParentOp() != padOp.getOperation(); // Is the yielded value an OpResult defined outside of the PadOp? bool outsideOpResult = - yieldedValue.isa() && + isa(yieldedValue) && yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation(); bool invariantYieldedValue = outsideBbArg || outsideOpResult; if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) { @@ -120,19 +120,19 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter, static SmallVector reifyOrComputeDynamicSizes(OpBuilder &b, Value value) { - auto tensorType = value.getType().cast(); + auto tensorType = cast(value.getType()); if (tensorType.hasStaticShape()) return {}; // Try to reify dynamic sizes. ReifiedRankedShapedTypeDims reifiedShape; - if (value.isa() && + if (isa(value) && succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) { SmallVector dynSizes; for (int64_t i = 0; i < tensorType.getRank(); ++i) { if (tensorType.isDynamicDim(i)) dynSizes.push_back( - reifiedShape[value.cast().getResultNumber()][i] + reifiedShape[cast(value).getResultNumber()][i] .get()); } return dynSizes; @@ -153,12 +153,12 @@ static Value createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value, Attribute memorySpace = {}) { OpBuilder::InsertionGuard g(rewriter); - auto tensorType = value.getType().cast(); + auto tensorType = cast(value.getType()); // Create buffer allocation. 
- auto memrefType = bufferization::getMemRefTypeWithStaticIdentityLayout( - tensorType, memorySpace) - .cast(); + auto memrefType = + cast(bufferization::getMemRefTypeWithStaticIdentityLayout( + tensorType, memorySpace)); SmallVector dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value); Value alloc = rewriter.create(loc, memrefType, dynamicSizes); @@ -206,7 +206,7 @@ FailureOr mlir::linalg::rewriteInDestinationPassingStyle( RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) { Location loc = fromElementsOp.getLoc(); RankedTensorType tensorType = - fromElementsOp.getType().cast(); + cast(fromElementsOp.getType()); auto shape = tensorType.getShape(); // Create tensor.empty. @@ -247,7 +247,7 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, return failure(); Location loc = generateOp.getLoc(); - RankedTensorType tensorType = generateOp.getType().cast(); + RankedTensorType tensorType = cast(generateOp.getType()); // Create tensor.empty. auto emptyOp = @@ -339,7 +339,7 @@ Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value, llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; })); OpBuilder::InsertionGuard g(rewriter); - if (auto bbArg = value.dyn_cast()) { + if (auto bbArg = dyn_cast(value)) { rewriter.setInsertionPointToStart(bbArg.getOwner()); } else { rewriter.setInsertionPointAfter(value.getDefiningOp()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index e5764cb..1ddd8b1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -640,7 +640,7 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) { auto loc = genericOp.getLoc(); Value unPackDest = producerUnPackOp.getDest(); auto genericOutType = - genericOp.getDpsInitOperand(0)->get().getType().cast(); + cast(genericOp.getDpsInitOperand(0)->get().getType()); if (producerUnPackOp.getDestType() != genericOutType || !genericOutType.hasStaticShape()) { unPackDest = tensor::UnPackOp::createDestinationTensor( diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp index e381b0a..42f87a1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -132,12 +132,12 @@ SmallVector permuteValues(ArrayRef values, static Value getZero(OpBuilder &b, Location loc, Type elementType) { assert(elementType.isIntOrIndexOrFloat() && "expected scalar type while computing zero value"); - if (elementType.isa()) + if (isa(elementType)) return b.create(loc, 0, elementType); if (elementType.isIndex()) return b.create(loc, 0); // Assume float. 
- auto floatType = elementType.cast(); + auto floatType = cast(elementType); return b.create( loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); } @@ -179,7 +179,7 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, if (resultNumber) { newInitValues.push_back( genericOp.getDpsInitOperand(*resultNumber)->get()); - OpResult result = genericOp.getResult(*resultNumber).cast(); + OpResult result = cast(genericOp.getResult(*resultNumber)); newResultTypes.push_back(result.getType()); peeledGenericOpIndexingMaps.push_back( genericOp.getIndexingMapMatchingResult(result)); @@ -231,7 +231,7 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp, })); for (auto resultNum : llvm::seq(origNumResults, peeledGenericOpNumResults)) { - OpResult result = peeledGenericOp.getResult(resultNum).cast(); + OpResult result = cast(peeledGenericOp.getResult(resultNum)); indexingMaps.push_back( peeledGenericOp.getIndexingMapMatchingResult(result)); } @@ -348,7 +348,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, /// the peeled operation. SmallVector replacements; for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) { - OpResult opr = yieldValue.value().dyn_cast(); + OpResult opr = dyn_cast(yieldValue.value()); if (!opr || opr.getOwner() != peeledScalarOperation) replacements.push_back(residualGenericOp.getResult(yieldValue.index())); else diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 5fd4885..bf91a70 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -32,7 +32,7 @@ static Value sourceMaterializationCallback(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { assert(inputs.size() == 1); auto inputType = inputs[0].getType(); - if (inputType.isa()) + if (isa(inputType)) return nullptr; // A detensored value is converted back by creating a new tensor from its @@ -320,9 +320,9 @@ struct LinalgDetensorize // * Add the argument to blockArgsToDetensor. // * Walk the use-def chain backwards to add each predecessor's // terminator-operands corresponding to currentItem to workList. - if (currentItem.dyn_cast()) { + if (dyn_cast(currentItem)) { BlockArgument currentItemBlockArgument = - currentItem.cast(); + cast(currentItem); Block *ownerBlock = currentItemBlockArgument.getOwner(); // Function arguments are not detensored/converted. diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 4a2c0a6..d8eccb9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -308,7 +308,7 @@ struct MoveInitOperandsToInput : public OpRewritePattern { for (OpOperand *op : candidates) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfterValue(op->get()); - auto elemType = op->get().getType().cast().getElementType(); + auto elemType = cast(op->get().getType()).getElementType(); auto empty = rewriter.create( loc, tensor::createDimValues(rewriter, loc, op->get()), elemType); @@ -387,7 +387,7 @@ replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand, // Early return for memrefs with affine maps to represent that we will always // leave them unchanged. 
Type actualType = opOperand->get().getType(); - if (auto memref = actualType.dyn_cast()) { + if (auto memref = dyn_cast(actualType)) { if (!memref.getLayout().isIdentity()) return std::nullopt; } @@ -437,7 +437,7 @@ struct ReplaceUnitExtents : public OpRewritePattern { ArrayRef reassociation, Location loc, PatternRewriter &rewriter) const { // There are no results for memref outputs. - auto origResultType = origOutput.getType().cast(); + auto origResultType = cast(origOutput.getType()); if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { unsigned rank = origResultType.getRank(); SmallVector offsets(rank, rewriter.getIndexAttr(0)); @@ -459,7 +459,7 @@ struct ReplaceUnitExtents : public OpRewritePattern { Value collapseValue(Value operand, ArrayRef targetShape, ArrayRef reassociation, Location loc, PatternRewriter &rewriter) const { - if (auto memrefType = operand.getType().dyn_cast()) { + if (auto memrefType = dyn_cast(operand.getType())) { if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { FailureOr rankReducingExtract = memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand, @@ -478,7 +478,7 @@ struct ReplaceUnitExtents : public OpRewritePattern { return rewriter.create(loc, targetType, operand, reassociation); } - if (auto tensorType = operand.getType().dyn_cast()) { + if (auto tensorType = dyn_cast(operand.getType())) { if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) { FailureOr rankReducingExtract = tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand, @@ -502,7 +502,7 @@ struct ReplaceUnitExtents : public OpRewritePattern { PatternRewriter &rewriter) const override { // Skip the pattern if the op has any tensor with special encoding. if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) { - auto tensorType = type.dyn_cast(); + auto tensorType = dyn_cast(type); return tensorType && tensorType.getEncoding() != nullptr; })) return failure(); @@ -607,11 +607,10 @@ struct RankReducedExtractSliceOp if (!reassociation || reassociation->size() == static_cast(resultType.getRank())) return failure(); - auto rankReducedType = + auto rankReducedType = cast( tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( reassociation->size(), sliceOp.getSourceType(), offsets, sizes, - strides) - .cast(); + strides)); Location loc = sliceOp.getLoc(); Value newSlice = rewriter.create( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index bf728a6..33ff4a3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -87,7 +87,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) { // type. Producer must have full tensor semantics to avoid potential // aliasing between producer and consumer memrefs. if (!producer.hasTensorSemantics() || - !fusedOperand->get().getType().isa()) + !isa(fusedOperand->get().getType())) return false; // Verify that @@ -232,14 +232,14 @@ static void generateFusedElementwiseOpRegion( // forward the yield operand. auto producerYieldOp = cast(producerBlock.getTerminator()); unsigned producerResultNumber = - fusedOperand->get().cast().getResultNumber(); + cast(fusedOperand->get()).getResultNumber(); Value replacement = mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber)); // Sanity checks, if replacement is not already in the mapper then it must be // produced outside. 
if (replacement == producerYieldOp.getOperand(producerResultNumber)) { - if (auto bb = replacement.dyn_cast()) + if (auto bb = dyn_cast(replacement)) assert(bb.getOwner() != &producerBlock && "yielded block argument must have been mapped"); else @@ -278,7 +278,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand) { assert(areElementwiseOpsFusable(fusedOperand) && "expected elementwise operation pre-conditions to pass"); - auto producerResult = fusedOperand->get().cast(); + auto producerResult = cast(fusedOperand->get()); auto producer = cast(producerResult.getOwner()); auto consumer = cast(fusedOperand->getOwner()); // TODO: allow fusing the producer of an output operand. @@ -357,7 +357,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, fusedOutputOperands.push_back(opOperand->get()); fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand)); Type resultType = opOperand->get().getType(); - if (!resultType.isa()) + if (!isa(resultType)) fusedResultTypes.push_back(resultType); } @@ -512,7 +512,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp, return genericOp.hasTensorSemantics() && llvm::all_of(genericOp.getIndexingMaps().getValue(), [](Attribute attr) { - return attr.cast() + return cast(attr) .getValue() .isProjectedPermutation(); }) && @@ -776,7 +776,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, continue; } if (auto opOperandType = - opOperand->get().getType().dyn_cast()) { + dyn_cast(opOperand->get().getType())) { AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); RankedTensorType expandedOperandType = getExpandedType(opOperandType, indexingMap, expansionInfo); @@ -805,7 +805,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, SmallVector outputs; for (OpOperand *opOperand : genericOp.getDpsInitOperands()) { AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); - auto opOperandType = opOperand->get().getType().cast(); + auto opOperandType = cast(opOperand->get().getType()); RankedTensorType expandedOutputType = getExpandedType(opOperandType, indexingMap, expansionInfo); if (expandedOutputType != opOperand->get().getType()) { @@ -921,7 +921,7 @@ struct FoldReshapeWithGenericOpByExpansion LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp, PatternRewriter &rewriter) const override { // Fold only if all constraints of fusing with reshape by expansion are met. - auto producerResult = reshapeOp.getSrc().dyn_cast(); + auto producerResult = dyn_cast(reshapeOp.getSrc()); if (!producerResult) { return rewriter.notifyMatchFailure(reshapeOp, "source not produced by an operation"); @@ -959,8 +959,9 @@ struct FoldReshapeWithGenericOpByExpansion // same type as the returns of the original generic op, the consumer reshape // op can be replaced by the source of the collapse_shape op that defines // the replacement. 
- Value reshapeReplacement = (*replacementValues) - [reshapeOp.getSrc().cast().getResultNumber()]; + Value reshapeReplacement = + (*replacementValues)[cast(reshapeOp.getSrc()) + .getResultNumber()]; if (auto collapseOp = reshapeReplacement.getDefiningOp()) { reshapeReplacement = collapseOp.getSrc(); @@ -1447,7 +1448,7 @@ FailureOr> mlir::linalg::collapseGenericOpIterationDims( .createLoopRanges(rewriter, genericOp.getLoc()); auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) { if (auto attr = ofr.dyn_cast()) - return attr.cast().getInt() == value; + return cast(attr).getInt() == value; llvm::APInt actual; return matchPattern(ofr.get(), m_ConstantInt(&actual)) && actual.getSExtValue() == value; @@ -1521,8 +1522,8 @@ FailureOr> mlir::linalg::collapseGenericOpIterationDims( Value collapsedOpResult = collapsedGenericOp->getResult(originalResult.index()); auto originalResultType = - originalResult.value().getType().cast(); - auto collapsedOpResultType = collapsedOpResult.getType().cast(); + cast(originalResult.value().getType()); + auto collapsedOpResultType = cast(collapsedOpResult.getType()); if (collapsedOpResultType.getRank() != originalResultType.getRank()) { AffineMap indexingMap = genericOp.getIndexingMapMatchingResult(originalResult.value()); @@ -1671,7 +1672,7 @@ public: return false; }; - auto resultValue = opOperand->get().dyn_cast(); + auto resultValue = dyn_cast(opOperand->get()); if (!def || !resultValue || !isScalarOrSplatConstantOp(def)) continue; @@ -1756,7 +1757,7 @@ struct RemoveOutsDependency : public OpRewritePattern { for (OpOperand *opOperand : op.getDpsInitOperands()) { if (!op.payloadUsesValueFromOperand(opOperand)) { Value operandVal = opOperand->get(); - auto operandType = operandVal.getType().dyn_cast(); + auto operandType = dyn_cast(operandVal.getType()); if (!operandType) continue; @@ -1810,7 +1811,7 @@ struct FoldFillWithGenericOp : public OpRewritePattern { fillFound = true; Value fillVal = fillOp.value(); auto resultType = - fillOp.result().getType().cast().getElementType(); + cast(fillOp.result().getType()).getElementType(); Value convertedVal = convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType, /*isUnsignedCast =*/false); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp index 549764d..18026cc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -28,7 +28,7 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { // TODO: The conversion pattern can be made to work for `any_of` here, but // it's more complex as it requires tracking which operands are scalars. return llvm::all_of(op->getOperandTypes(), - [](Type type) { return type.isa(); }); + [](Type type) { return isa(type); }); } /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over @@ -67,7 +67,7 @@ getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) { // Extract static / dynamic shape mix from the first operand. 
Value firstOperand = operands.front(); - auto rankedTensorType = t.cast(); + auto rankedTensorType = cast(t); auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape()); auto dynamicShape = linalg::createDynamicDimensions(b, loc, firstOperand); @@ -87,7 +87,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { return rewriter.notifyMatchFailure( op, "requires elementwise op on ranked tensors"); - auto rank = op->getResult(0).getType().cast().getRank(); + auto rank = cast(op->getResult(0).getType()).getRank(); SmallVector indexingMaps( op->getNumResults() + op->getNumOperands(), rewriter.getMultiDimIdentityMap(rank)); @@ -104,7 +104,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { auto resultTypes = llvm::to_vector<6>( llvm::map_range(op->getResultTypes(), [](Type type) { - return type.cast().getElementType(); + return cast(type).getElementType(); })); auto *scalarOp = builder.create(loc, op->getName().getIdentifier(), diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp index defa027..c89fc5b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp @@ -89,7 +89,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults Location loc = genericOp.getLoc(); SmallVector newResultTypes; for (Value v : newOutputOperands) - if (v.getType().isa()) + if (isa(v.getType())) newResultTypes.push_back(v.getType()); auto newOp = rewriter.create( loc, newResultTypes, newInputOperands, newOutputOperands, diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp index b6e2ffc..703db837 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp @@ -86,12 +86,12 @@ struct FusePadOp : OpRewritePattern { // result of the generic op. The low pad values are the offsets, the size of // the source is the size of the slice. // TODO: This insert/extract could be potentially made a utility method. 
- unsigned resultNumber = source.cast().getResultNumber(); + unsigned resultNumber = cast(source).getResultNumber(); SmallVector offsets = padOp.getMixedLowPad(); SmallVector sizes; sizes.reserve(offsets.size()); - for (const auto &shape : llvm::enumerate( - source.getType().cast().getShape())) { + for (const auto &shape : + llvm::enumerate(cast(source.getType()).getShape())) { if (ShapedType::isDynamic(shape.value())) { sizes.push_back( rewriter.create(loc, source, shape.index()) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 6f9b608..cf3fd4b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -151,7 +151,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, SmallVector resultTypes; resultTypes.reserve(producer->getNumResults()); for (OpOperand *operand : producer.getDpsInitOperands()) { - auto tensorType = operand->get().getType().dyn_cast(); + auto tensorType = dyn_cast(operand->get().getType()); if (!tensorType) continue; unsigned rank = tensorType.getRank(); @@ -210,20 +210,20 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap, // dependence tracking since the dependence tracking is similar to what is done // w.r.t to buffers. static void getProducerOfTensor(Value tensor, OpResult &opResult) { - if (!tensor.getType().isa()) + if (!isa(tensor.getType())) return; while (true) { LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor); if (auto linalgOp = tensor.getDefiningOp()) { - opResult = tensor.cast(); + opResult = cast(tensor); return; } if (auto sliceOp = tensor.getDefiningOp()) { tensor = sliceOp.getSource(); continue; } - if (auto blockArg = tensor.dyn_cast()) { + if (auto blockArg = dyn_cast(tensor)) { if (auto forOp = blockArg.getDefiningOp()) { tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber()); continue; diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index d8ecc80..87aade3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -227,7 +227,7 @@ SmallVector TileLoopNest::getTiedBBArgs(BlockArgument bbArg) { return {}; bbArgs.push_back(bbArg); OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg); - bbArg = iterArg->get().dyn_cast(); + bbArg = dyn_cast(iterArg->get()); } // Reverse the block arguments to order them from outer to inner. @@ -358,13 +358,13 @@ FailureOr TileLoopNest::fuseProducer(OpBuilder &b, // Check if the producer is a LinalgOp possibly passed by iteration argument. OpOperand *iterArg = nullptr; - auto producerResult = sliceOp.getSource().dyn_cast(); - if (auto bbArg = sliceOp.getSource().dyn_cast()) { + auto producerResult = dyn_cast(sliceOp.getSource()); + if (auto bbArg = dyn_cast(sliceOp.getSource())) { iterArg = getTiedIterArg(bbArg); // Check the iteration argument may be used to pass in the producer output. 
if (!iterArg || hasOtherUses(bbArg, sliceOp)) return failure(); - producerResult = iterArg->get().dyn_cast(); + producerResult = dyn_cast(iterArg->get()); } if (!producerResult || !isa(producerResult.getOwner())) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index 251f7d8..21d83d2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -549,7 +549,7 @@ static FailureOr buildPackingLoopNestImpl( int paddedRank = paddedTensorType.getRank(); // Step 0. Populate bvm with opToHoist.getSource if relevant. - BlockArgument bbArg = opToHoist.getSource().dyn_cast(); + BlockArgument bbArg = dyn_cast(opToHoist.getSource()); while (bbArg) { auto forOp = dyn_cast(bbArg.getOwner()->getParentOp()); if (!forOp) @@ -558,7 +558,7 @@ static FailureOr buildPackingLoopNestImpl( break; OpOperand &operand = forOp.getOpOperandForRegionIterArg(bbArg); bvm.map(bbArg, operand.get()); - bbArg = operand.get().dyn_cast(); + bbArg = dyn_cast(operand.get()); } // Step 1. iteratively clone loops and push `hoistedPackedTensor`. @@ -754,9 +754,8 @@ static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp, if (!destOp) break; LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n"); - source = - destOp.getDpsInitOperand(source.cast().getResultNumber()) - ->get(); + source = destOp.getDpsInitOperand(cast(source).getResultNumber()) + ->get(); } LLVM_DEBUG(DBGS() << "--final source: " << source << "\n"); LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 13ec4d9..01b893a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -86,7 +86,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) { [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); func.walk([&](vector::TransferReadOp transferRead) { - if (!transferRead.getShapedType().isa()) + if (!isa(transferRead.getShapedType())) return WalkResult::advance(); LLVM_DEBUG(DBGS() << "Candidate for hoisting: " diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 23c831f..d91d8c4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -162,7 +162,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc, SmallVector, 8> indexing; SmallVector outputBuffers; for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) { - if (!outputOperand->get().getType().isa()) + if (!isa(outputOperand->get().getType())) continue; indexing.push_back(makeCanonicalAffineApplies( b, loc, linalgOp.getMatchingIndexingMap(outputOperand), @@ -242,7 +242,7 @@ static FailureOr linalgOpToLoopsImpl(RewriterBase &rewriter, return failure(); // The induction variable is a block argument of the entry block of the // loop operation. 
-    BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
+    BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
     if (!ivVal)
       return failure();
     loopSet.insert(ivVal.getOwner()->getParentOp());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
index cabd342..93fa5ff 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -44,9 +44,9 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
   auto result = operation->getResult(0);
-  auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
-  auto initTy = init.getType().dyn_cast<RankedTensorType>();
-  auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
+  auto kernelTy = dyn_cast<RankedTensorType>(kernel.getType());
+  auto initTy = dyn_cast<RankedTensorType>(init.getType());
+  auto resultTy = dyn_cast<RankedTensorType>(result.getType());
   if (!kernelTy || !initTy || !resultTy)
     return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 4fcffea..d39cd0e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -292,9 +292,9 @@ promoteSubViews(ImplicitLocOpBuilder &b,
       })
       .Case([&](ComplexType t) {
         Value tmp;
-        if (auto et = t.getElementType().dyn_cast<FloatType>())
+        if (auto et = dyn_cast<FloatType>(t.getElementType()))
           tmp = b.create<arith::ConstantOp>(FloatAttr::get(et, 0.0));
-        else if (auto et = t.getElementType().cast<IntegerType>())
+        else if (auto et = cast<IntegerType>(t.getElementType()))
           tmp = b.create<arith::ConstantOp>(IntegerAttr::get(et, 0));
         return b.create<complex::CreateOp>(t, tmp, tmp);
       })
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index 344b289..203ae43 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -93,7 +93,7 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
       {iterationSpace[dimension].offset, iterationSpace[dimension].size,
        minSplitPoint});
   if (auto attr = remainingSize.dyn_cast<Attribute>()) {
-    if (attr.cast<IntegerAttr>().getValue().isZero())
+    if (cast<IntegerAttr>(attr).getValue().isZero())
       return {op, TilingInterface()};
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index b4d95b7..982b024 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -113,7 +113,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     }
     Type newType = RankedTensorType::get(
         newShape,
-        operand->get().getType().cast<RankedTensorType>().getElementType());
+        cast<RankedTensorType>(operand->get().getType()).getElementType());
     Value newInput = b.create<tensor::ExpandShapeOp>(
         loc, newType, operand->get(), reassociation);
     newInputs.push_back(newInput);
@@ -309,7 +309,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
   fillOps.reserve(op.getNumDpsInits());
   for (auto it : llvm::zip(op.getDpsInitOperands(), neutralElements)) {
     Value rankedTensor = std::get<0>(it)->get();
-    auto t = rankedTensor.getType().cast<RankedTensorType>();
+    auto t = cast<RankedTensorType>(rankedTensor.getType());
     RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
         reductionDimSize / splitFactor, insertSplitDimension);
     SmallVector<Value> dims =
@@ -383,7 +383,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
            combinerOps)) {
     Value reindexedOutput = std::get<0>(it);
     Value originalOutput = std::get<1>(it)->get();
-    auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
+    auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
     Operation *combinerOp = std::get<2>(it);
     AffineMap map =
b.getMultiDimIdentityMap(originalOutputType.getRank() + 1); diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp index c0355a1..f455678 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp @@ -65,7 +65,7 @@ static FailureOr findHoistableMatchingExtractSlice(RewriterBase &rewriter, tensor::InsertSliceOp insertSliceOp, BlockArgument srcTensor) { - assert(srcTensor.getType().isa() && "not a ranked tensor"); + assert(isa(srcTensor.getType()) && "not a ranked tensor"); auto forOp = cast(srcTensor.getOwner()->getParentOp()); @@ -92,7 +92,7 @@ findHoistableMatchingExtractSlice(RewriterBase &rewriter, // Skip insert_slice whose vector is defined within the loop: we need to // hoist that definition first otherwise dominance violations trigger. - if (!extractSliceOp.getSource().isa() && + if (!isa(extractSliceOp.getSource()) && !forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) { LLVM_DEBUG(DBGS() << "------transfer_read vector is loop-dependent\n"); continue; @@ -119,7 +119,7 @@ static FailureOr findHoistableMatchingTransferRead(RewriterBase &rewriter, vector::TransferWriteOp transferWriteOp, BlockArgument srcTensor) { - if (!srcTensor.getType().isa()) + if (!isa(srcTensor.getType())) return failure(); auto forOp = cast(srcTensor.getOwner()->getParentOp()); @@ -152,7 +152,7 @@ findHoistableMatchingTransferRead(RewriterBase &rewriter, // transfer_read may be of a vector that is defined within the loop: we // traverse it by virtue of bypassing disjoint subset operations rooted at // a bbArg and yielding a matching yield. - if (!read.getSource().isa() && + if (!isa(read.getSource()) && !forOp.isDefinedOutsideOfLoop(read.getSource())) { LLVM_DEBUG(DBGS() << "------transfer_read vector appears loop " "dependent but will be tested for disjointness as " diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 1ff1166..57798fc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -49,7 +49,7 @@ static bool isZero(OpFoldResult v) { if (!v) return false; if (auto attr = v.dyn_cast()) { - IntegerAttr intAttr = attr.dyn_cast(); + IntegerAttr intAttr = dyn_cast(attr); return intAttr && intAttr.getValue().isZero(); } if (auto cst = v.get().getDefiningOp()) @@ -105,7 +105,7 @@ void mlir::linalg::transformIndexOps( static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b, OpFoldResult value) { if (auto attr = value.dyn_cast()) { - assert(attr.cast().getValue().isStrictlyPositive() && + assert(cast(attr).getValue().isStrictlyPositive() && "expected strictly positive tile size and divisor"); return; } @@ -587,8 +587,8 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef tileSizes, SmallVector loops; loops.reserve(ivs.size()); for (auto iv : ivs) { - if (iv.isa()) { - loops.push_back(iv.cast().getOwner()->getParentOp()); + if (isa(iv)) { + loops.push_back(cast(iv).getOwner()->getParentOp()); assert(loops.back() && "no owner found for induction variable!"); } else { // TODO: Instead of doing this, try to recover the ops used instead of the @@ -712,7 +712,7 @@ FailureOr linalg::tileReductionUsingForall( outOffsets[reductionDim] = forallOp.getInductionVars().front(); // TODO: use SubsetExtractOpInterface once it is available. 
tiledDpsInitOperands.push_back(b.create( - loc, initOperand->get().getType().cast(), + loc, cast(initOperand->get().getType()), destBbArgs[destNum], outOffsets, sizes, strides)); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 1c3745f..36f13fa 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -365,8 +365,7 @@ struct LinalgOpPartialReductionInterface // Then create a new reduction that only reduce the newly added dimension // from the previous op. - int64_t intermRank = - partialReduce[0].getType().cast().getRank(); + int64_t intermRank = cast(partialReduce[0].getType()).getRank(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); SmallVector reductionIteratorTypes; SmallVector exprs; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index a9e8ac0..2300895 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -89,7 +89,7 @@ static FailureOr padOperandToSmallestStaticBoundingBox( // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp. OpOperand *currOpOperand = opOperand; while (auto linalgOp = currOpOperand->get().getDefiningOp()) { - OpResult result = currOpOperand->get().cast(); + OpResult result = cast(currOpOperand->get()); currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber()); } @@ -133,7 +133,7 @@ static FailureOr padOperandToSmallestStaticBoundingBox( // If the size is an attribute add it directly to `paddedShape`. if (en.value().is()) { paddedShape[shapeIdx++] = - en.value().get().dyn_cast().getInt(); + dyn_cast(en.value().get()).getInt(); LLVM_DEBUG( DBGS() << "------dim is an attr, add it to padded shape, SKIP\n"); continue; @@ -232,7 +232,7 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, for (const auto &en : llvm::enumerate(paddedOp->getResults())) { Value paddedResult = en.value(); int64_t resultNumber = en.index(); - int64_t rank = paddedResult.getType().cast().getRank(); + int64_t rank = cast(paddedResult.getType()).getRank(); SmallVector offsets(rank, rewriter.getIndexAttr(0)); SmallVector sizes; SmallVector strides(rank, rewriter.getIndexAttr(1)); @@ -476,7 +476,7 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, tensor::PackOp packOp) { // 1. Filter out NYI cases. auto packedTensorType = - packOp->getResultTypes().front().cast(); + cast(packOp->getResultTypes().front()); if (llvm::any_of(packOp.getStaticInnerTiles(), [](int64_t size) { return ShapedType::isDynamic(size); })) { return rewriter.notifyMatchFailure( @@ -639,7 +639,7 @@ FailureOr linalg::lowerUnPack(RewriterBase &rewriter, int64_t packedRank = packedTensorType.getRank(); OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1); - auto destTensorType = unPackOp.getDest().getType().cast(); + auto destTensorType = cast(unPackOp.getDest().getType()); if (unPackOp.isLikeUnPad()) { // This unpack is just a plain unpad. // Just extract the slice from the higher ranked tensor. @@ -889,7 +889,7 @@ static LinalgOp transposeOneLinalgOperandAndReplace( // Sanity check of the expected transposed tensor type. 
auto tensorType = permuteShape( - opOperand.get().getType().cast(), permutation); + cast(opOperand.get().getType()), permutation); (void)tensorType; assert(tensorType == transposedValue.getType() && "expected tensor type mismatch"); @@ -1050,8 +1050,8 @@ LogicalResult PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const { - auto inputShapedType = padOp.getSource().getType().cast(); - auto resultShapedType = padOp.getResult().getType().cast(); + auto inputShapedType = cast(padOp.getSource().getType()); + auto resultShapedType = cast(padOp.getResult().getType()); // Bail on non-static shapes. if (!inputShapedType.hasStaticShape()) @@ -1068,7 +1068,7 @@ PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp, Operation *definingOp = padValue.getDefiningOp(); if (definingOp && definingOp->getBlock() == &block) return failure(); - if (!definingOp && padValue.cast().getOwner() == &block) + if (!definingOp && cast(padValue).getOwner() == &block) return failure(); // Create tensor with the padded shape @@ -1134,7 +1134,7 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp, return val; return rewriter .create( - padOp.getLoc(), ofr.get().cast().getInt()) + padOp.getLoc(), cast(ofr.get()).getInt()) .getResult(); }; @@ -1514,9 +1514,9 @@ FailureOr DownscaleSizeOneWindowed2DConvolution:: Value kernel = convOp.getInputs().back(); Value output = convOp.getOutputs().front(); - auto inputType = input.getType().dyn_cast(); - auto kernelType = kernel.getType().dyn_cast(); - auto outputType = output.getType().dyn_cast(); + auto inputType = dyn_cast(input.getType()); + auto kernelType = dyn_cast(kernel.getType()); + auto outputType = dyn_cast(output.getType()); auto kernelShape = kernelType.getShape(); auto outputShape = outputType.getShape(); @@ -1638,9 +1638,9 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( Value kernel = convOp.getInputs().back(); Value output = convOp.getOutputs().front(); - auto inputType = input.getType().dyn_cast(); - auto kernelType = kernel.getType().dyn_cast(); - auto outputType = output.getType().dyn_cast(); + auto inputType = dyn_cast(input.getType()); + auto kernelType = dyn_cast(kernel.getType()); + auto outputType = dyn_cast(output.getType()); auto kernelShape = kernelType.getShape(); auto outputShape = outputType.getShape(); @@ -1706,9 +1706,9 @@ DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp, Value kernel = convOp.getInputs().back(); Value output = convOp.getOutputs().front(); - auto inputType = input.getType().dyn_cast(); - auto kernelType = kernel.getType().dyn_cast(); - auto outputType = output.getType().dyn_cast(); + auto inputType = dyn_cast(input.getType()); + auto kernelType = dyn_cast(kernel.getType()); + auto outputType = dyn_cast(output.getType()); auto kernelShape = kernelType.getShape(); auto outputShape = outputType.getShape(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 56b4516..2236d1b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -563,7 +563,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, loc, value, outputOperand->get(), indices, writeMap); } else { // 0-d case is still special: do not invert the reindexing writeMap. 
- if (!value.getType().isa()) + if (!isa(value.getType())) value = rewriter.create(loc, vectorType, value); assert(value.getType() == vectorType && "incorrect type"); write = rewriter.create( @@ -864,7 +864,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, targetShape.back() == 1) return VectorMemoryAccessKind::Gather; - auto inputShape = extractOp.getTensor().getType().cast(); + auto inputShape = cast(extractOp.getTensor().getType()); // 2. Assume that it's a gather load when reading _from_ a tensor for which // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`. @@ -1024,8 +1024,8 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, const IRMapping &bvm) { Value reduceVec = bvm.lookup(reduceValue); Value outputVec = bvm.lookup(initialValue); - auto reduceType = reduceVec.getType().dyn_cast(); - auto outputType = outputVec.getType().dyn_cast(); + auto reduceType = dyn_cast(reduceVec.getType()); + auto outputType = dyn_cast(outputVec.getType()); // Reduce only if needed as the value may already have been reduce for // contraction vectorization. if (!reduceType || @@ -1082,7 +1082,7 @@ vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op, // 4 . Check if the operation is a reduction. SmallVector> reductionOperands; for (Value operand : op->getOperands()) { - auto blockArg = operand.dyn_cast(); + auto blockArg = dyn_cast(operand); if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() || blockArg.getArgNumber() < linalgOp.getNumDpsInputs()) continue; @@ -1107,7 +1107,7 @@ vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op, // a. first get the first max ranked shape. SmallVector firstMaxRankedShape; for (Value operand : op->getOperands()) { - auto vt = bvm.lookup(operand).getType().dyn_cast(); + auto vt = dyn_cast(bvm.lookup(operand).getType()); if (vt && firstMaxRankedShape.size() < vt.getShape().size()) firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end()); } @@ -1230,7 +1230,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, // 3.c. Not all ops support 0-d vectors, extract the scalar for now. // TODO: remove this. - if (readValue.getType().cast().getRank() == 0) + if (cast(readValue.getType()).getRank() == 0) readValue = rewriter.create(loc, readValue); LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue @@ -1528,8 +1528,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp, LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, memref::CopyOp copyOp) { - auto srcType = copyOp.getSource().getType().cast(); - auto dstType = copyOp.getTarget().getType().cast(); + auto srcType = cast(copyOp.getSource().getType()); + auto dstType = cast(copyOp.getTarget().getType()); if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) return failure(); @@ -1549,7 +1549,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, Value readValue = rewriter.create( loc, readType, copyOp.getSource(), indices, rewriter.getMultiDimIdentityMap(srcType.getRank())); - if (readValue.getType().cast().getRank() == 0) { + if (cast(readValue.getType()).getRank() == 0) { readValue = rewriter.create(loc, readValue); readValue = rewriter.create(loc, writeType, readValue); } @@ -1566,7 +1566,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, /// Helper function that retrieves the value of an IntegerAttr. 
 static int64_t getIntFromAttr(Attribute attr) {
-  return attr.cast<IntegerAttr>().getInt();
+  return cast<IntegerAttr>(attr).getInt();
 }
 /// Given an ArrayRef of OpFoldResults, return a vector of Values.
@@ -1836,8 +1836,8 @@ struct PadOpVectorizationWithTransferWritePattern
     if (hasSameTensorSize(castOp.getSource(), afterTrimming))
       return true;
-    auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
-    auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
+    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
+    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
     // Only RankedTensorType supported.
     if (!t1 || !t2)
       return false;
@@ -1946,7 +1946,7 @@ struct PadOpVectorizationWithInsertSlicePattern
     if (!padValue)
       return failure();
     // Dynamic shapes not supported.
-    if (!padOp.getResult().getType().cast<ShapedType>().hasStaticShape())
+    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
       return failure();
     // Pad result not used as destination.
     if (insertOp.getDest() == padOp.getResult())
@@ -2074,7 +2074,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   memref::CopyOp copyOp;
   for (auto &u : subView.getUses()) {
     if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
-      assert(newCopyOp.getTarget().getType().isa<MemRefType>());
+      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
       if (newCopyOp.getTarget() != subView)
         continue;
       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
@@ -2091,7 +2091,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   FillOp maybeFillOp;
   for (auto &u : viewOrAlloc.getUses()) {
     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
-      assert(newFillOp.output().getType().isa<MemRefType>());
+      assert(isa<MemRefType>(newFillOp.output().getType()));
       if (newFillOp.output() != viewOrAlloc)
         continue;
       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
@@ -2162,7 +2162,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
     return rewriter.notifyMatchFailure(xferOp, "no copy found");
   // `out` is the subview copied into that we replace.
-  assert(copyOp.getTarget().getType().isa<MemRefType>());
+  assert(isa<MemRefType>(copyOp.getTarget().getType()));
   Value out = copyOp.getTarget();
   // Forward vector.transfer into copy.
@@ -2204,7 +2204,7 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
 namespace {
 bool isCastOfBlockArgument(Operation *op) {
   return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
-         op->getOperand(0).isa<BlockArgument>();
+         isa<BlockArgument>(op->getOperand(0));
 }
 bool isSupportedPoolKind(vector::CombiningKind kind) {
@@ -2268,9 +2268,9 @@ struct Conv1DGenerator
     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
     resShaped = linalgOp.getDpsInitOperand(0)->get();
-    lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
-    rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
-    resShapedType = resShaped.getType().dyn_cast<ShapedType>();
+    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
+    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
+    resShapedType = dyn_cast<ShapedType>(resShaped.getType());
     if (!lhsShapedType || !rhsShapedType || !resShapedType)
       return;
     // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
@@ -2717,8 +2717,8 @@ struct Conv1DGenerator
   /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                      Value lhs, Value rhs, Value res) {
-    auto rhsTy = rhs.getType().cast<ShapedType>();
-    auto resTy = res.getType().cast<ShapedType>();
+    auto rhsTy = cast<ShapedType>(rhs.getType());
+    auto resTy = cast<ShapedType>(res.getType());
     // TODO(suderman): Change this to use a vector.ima intrinsic.
lhs = promote(rewriter, loc, lhs, resTy); @@ -2730,7 +2730,7 @@ struct Conv1DGenerator if (!lhs || !rhs) return nullptr; - if (resTy.getElementType().isa()) + if (isa(resTy.getElementType())) return rewriter.create(loc, lhs, rhs, res); auto mul = rewriter.create(loc, lhs, rhs); @@ -2863,15 +2863,14 @@ private: // Otherwise, check for one or zero `ext` predecessor. The `ext` operands // must be block arguments or extension of block arguments. bool setOperKind(Operation *reduceOp) { - int numBlockArguments = - llvm::count_if(reduceOp->getOperands(), - [](Value v) { return v.isa(); }); + int numBlockArguments = llvm::count_if( + reduceOp->getOperands(), [](Value v) { return isa(v); }); switch (numBlockArguments) { case 1: { // Will be convolution if feeder is a MulOp. // Otherwise, if it can be pooling. auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) { - return !v.isa(); + return !isa(v); }); Operation *feedOp = (*feedValIt).getDefiningOp(); if (isCastOfBlockArgument(feedOp)) { @@ -2880,7 +2879,7 @@ private: poolExtOp = feedOp->getName().getIdentifier(); } else if (!(isa(feedOp) && llvm::all_of(feedOp->getOperands(), [](Value v) { - if (v.isa()) + if (isa(v)) return true; if (Operation *op = v.getDefiningOp()) return isCastOfBlockArgument(op); diff --git a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp index 12b55ef..f7376c0 100644 --- a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp @@ -43,16 +43,16 @@ namespace mlir { namespace linalg { Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim) { - if (val.getType().isa()) + if (isa(val.getType())) return b.createOrFold(loc, val, dim); - if (val.getType().isa()) + if (isa(val.getType())) return b.createOrFold(loc, val, dim); llvm_unreachable("Expected MemRefType or TensorType"); } OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim) { - auto shapedType = val.getType().cast(); + auto shapedType = cast(val.getType()); if (!shapedType.hasRank() || shapedType.isDynamicDim(dim)) return createOrFoldDimOp(b, loc, val, dim); return b.getIndexAttr(shapedType.getDimSize(dim)); @@ -60,7 +60,7 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, SmallVector createDynamicDimensions(OpBuilder &b, Location loc, Value val) { - auto shapedType = val.getType().cast(); + auto shapedType = cast(val.getType()); assert(shapedType.hasRank() && "`val` must have a static rank"); SmallVector res; res.reserve(shapedType.getRank()); @@ -73,7 +73,7 @@ SmallVector createDynamicDimensions(OpBuilder &b, Location loc, SmallVector getMixedDimensions(OpBuilder &b, Location loc, Value val) { - auto shapedType = val.getType().cast(); + auto shapedType = cast(val.getType()); assert(shapedType.hasRank() && "`val` must have a static rank"); SmallVector dynamicDims = createDynamicDimensions(b, loc, val); return getMixedValues(shapedType.getShape(), dynamicDims, b); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 5e3413a..ef31668 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -281,7 +281,7 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, auto linalgOp = current.getDefiningOp(); if (!linalgOp) break; - OpResult opResult = current.cast(); + OpResult opResult = cast(current); current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get(); } auto 
padOp = current ? current.getDefiningOp() : nullptr; @@ -331,7 +331,7 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor, Value outputTensor, ArrayRef transposeVector) { - auto resultTensorType = outputTensor.getType().cast(); + auto resultTensorType = cast(outputTensor.getType()); Type elementType = resultTensorType.getElementType(); assert(isPermutationVector(transposeVector) && @@ -366,9 +366,9 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor, } GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) { - auto memrefTypeTo = to.getType().cast(); + auto memrefTypeTo = cast(to.getType()); #ifndef NDEBUG - auto memrefTypeFrom = from.getType().cast(); + auto memrefTypeFrom = cast(from.getType()); assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() && "`from` and `to` memref must have the same rank"); #endif // NDEBUG @@ -650,7 +650,7 @@ void GenerateLoopNest::doit( static Value materializeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, const SliceParameters &sliceParams) { - auto shapedType = valueToTile.getType().dyn_cast(); + auto shapedType = dyn_cast(valueToTile.getType()); auto *sliceOp = TypeSwitch(shapedType) .Case([&](MemRefType) { return builder.create( @@ -685,7 +685,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef lbs, ArrayRef ubs, ArrayRef subShapeSizes, bool omitPartialTileCheck) { - auto shapedType = valueToTile.getType().dyn_cast(); + auto shapedType = dyn_cast(valueToTile.getType()); assert(shapedType && "only shaped types can be tiled"); ArrayRef shape = shapedType.getShape(); int64_t rank = shapedType.getRank(); @@ -889,7 +889,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, // subdomains explicit. Type operandType = opOperand.get().getType(); - if (!isTiled(map, tileSizes) && !(operandType.isa() && + if (!isTiled(map, tileSizes) && !(isa(operandType) && linalgOp.isDpsInit(&opOperand))) { allSliceParams.push_back(std::nullopt); LLVM_DEBUG(llvm::dbgs() @@ -971,7 +971,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes) { auto size = it.value(); curr.push_back(dim); auto attr = size.dyn_cast(); - if (attr && attr.cast().getInt() == 1) + if (attr && cast(attr).getInt() == 1) continue; reassociation.emplace_back(ReassociationIndices{}); std::swap(reassociation.back(), curr); @@ -989,7 +989,7 @@ std::optional getNeutralElement(Operation *op) { // Builder only used as helper for attribute creation. OpBuilder b(op->getContext()); Type resultType = op->getResult(0).getType(); - if (auto floatType = resultType.dyn_cast()) { + if (auto floatType = dyn_cast(resultType)) { const llvm::fltSemantics &semantic = floatType.getFloatSemantics(); if (isa(op)) return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic)); diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp index c5e008e..dcace48 100644 --- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -64,7 +64,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op, // Maybe broadcasts scalar value into vector type compatible with `op`. 
   auto bcast = [&](Value value) -> Value {
-    if (auto vec = op.getType().dyn_cast<VectorType>())
+    if (auto vec = dyn_cast<VectorType>(op.getType()))
       return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
     return value;
   };
@@ -167,7 +167,7 @@ PowIStrengthReduction::matchAndRewrite(
   // Maybe broadcasts scalar value into vector type compatible with `op`.
   auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
-    if (auto vec = op.getType().template dyn_cast<VectorType>())
+    if (auto vec = dyn_cast<VectorType>(op.getType()))
       return rewriter.create<vector::BroadcastOp>(loc, vec, value);
     return value;
   };
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 6d286a3..a3efc6e 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -40,7 +40,7 @@ using namespace mlir::vector;
 // Returns vector shape if the type is a vector. Returns an empty shape if it is
 // not a vector.
 static ArrayRef<int64_t> vectorShape(Type type) {
-  auto vectorType = type.dyn_cast<VectorType>();
+  auto vectorType = dyn_cast<VectorType>(type);
   return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
 }
@@ -54,14 +54,14 @@ static ArrayRef<int64_t> vectorShape(Value value) {
 // Broadcasts scalar type into vector type (iff shape is non-scalar).
 static Type broadcast(Type type, ArrayRef<int64_t> shape) {
-  assert(!type.isa<VectorType>() && "must be scalar type");
+  assert(!isa<VectorType>(type) && "must be scalar type");
   return !shape.empty() ? VectorType::get(shape, type) : type;
 }
 // Broadcasts scalar value into vector (iff shape is non-scalar).
 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
                        ArrayRef<int64_t> shape) {
-  assert(!value.getType().isa<VectorType>() && "must be scalar value");
+  assert(!isa<VectorType>(value.getType()) && "must be scalar value");
   auto type = broadcast(value.getType(), shape);
   return !shape.empty() ? builder.create<vector::BroadcastOp>(type, value) : value;
 }
@@ -92,7 +92,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
   assert(!operands.empty() && "operands must be not empty");
   assert(vectorWidth > 0 && "vector width must be larger than 0");
-  VectorType inputType = operands[0].getType().cast<VectorType>();
+  VectorType inputType = cast<VectorType>(operands[0].getType());
   ArrayRef<int64_t> inputShape = inputType.getShape();
   // If input shape matches target vector width, we can just call the
@@ -118,7 +118,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
   for (unsigned i = 0; i < operands.size(); ++i) {
     auto operand = operands[i];
-    auto eltType = operand.getType().cast<VectorType>().getElementType();
+    auto eltType = cast<VectorType>(operand.getType()).getElementType();
     auto expandedType = VectorType::get(expandedShape, eltType);
     expandedOperands[i] =
         builder.create<vector::ShapeCastOp>(expandedType, operand);
@@ -145,7 +145,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
   }
   // Stitch results together into one large vector.
-  Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
+  Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
   Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
   Value result = builder.create<arith::ConstantOp>(
       resultExpandedType, builder.getZeroAttr(resultExpandedType));
@@ -318,9 +318,9 @@ LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
   // Create F32 equivalent type.
Type newType; - if (auto shaped = origType.dyn_cast()) { + if (auto shaped = dyn_cast(origType)) { newType = shaped.clone(rewriter.getF32Type()); - } else if (origType.isa()) { + } else if (isa(origType)) { newType = rewriter.getF32Type(); } else { return rewriter.notifyMatchFailure(op, diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index 7f702e1..ae2472d 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -69,7 +69,7 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply( results.push_back(*newBuffer); } - transformResults.set(getResult().cast(), results); + transformResults.set(cast(getResult()), results); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp index 369f225..9b1d85b2 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp @@ -57,7 +57,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern { // always 1. if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) { Attribute attr = valueOrAttr.dyn_cast(); - return attr && attr.cast().getInt() == 1; + return attr && cast(attr).getInt() == 1; })) { strides = SmallVector(sourceOp.getMixedStrides().size(), rewriter.getI64IntegerAttr(1)); @@ -93,8 +93,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern { // If both offsets are static we can simply calculate the combined // offset statically. offsets.push_back(rewriter.getI64IntegerAttr( - opOffsetAttr.cast().getInt() + - sourceOffsetAttr.cast().getInt())); + cast(opOffsetAttr).getInt() + + cast(sourceOffsetAttr).getInt())); } else { // When either offset is dynamic, we must emit an additional affine // transformation to add the two offsets together dynamically. 
@@ -102,7 +102,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern { SmallVector affineApplyOperands; for (auto valueOrAttr : {opOffset, sourceOffset}) { if (auto attr = valueOrAttr.dyn_cast()) { - expr = expr + attr.cast().getInt(); + expr = expr + cast(attr).getInt(); } else { expr = expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp index 6202b57..57f0141 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp @@ -149,7 +149,7 @@ void memref::populateMemRefWideIntEmulationConversions( arith::WideIntEmulationConverter &typeConverter) { typeConverter.addConversion( [&typeConverter](MemRefType ty) -> std::optional { - auto intTy = ty.getElementType().dyn_cast(); + auto intTy = dyn_cast(ty.getElementType()); if (!intTy) return ty; diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp index 38fb113..8a276eb 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -89,11 +89,11 @@ public: LogicalResult matchAndRewrite(memref::ReshapeOp op, PatternRewriter &rewriter) const final { - auto shapeType = op.getShape().getType().cast(); + auto shapeType = cast(op.getShape().getType()); if (!shapeType.hasStaticShape()) return failure(); - int64_t rank = shapeType.cast().getDimSize(0); + int64_t rank = cast(shapeType).getDimSize(0); SmallVector sizes, strides; sizes.resize(rank); strides.resize(rank); @@ -106,7 +106,7 @@ public: if (op.getType().isDynamicDim(i)) { Value index = rewriter.create(loc, i); size = rewriter.create(loc, op.getShape(), index); - if (!size.getType().isa()) + if (!isa(size.getType())) size = rewriter.create( loc, rewriter.getIndexType(), size); sizes[i] = size; @@ -141,7 +141,7 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase { op.getKind() != arith::AtomicRMWKind::minf; }); target.addDynamicallyLegalOp([](memref::ReshapeOp op) { - return !op.getShape().getType().cast().hasStaticShape(); + return !cast(op.getShape().getType()).hasStaticShape(); }); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index ea372bf..ff2c410 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -62,7 +62,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter, // Build a plain extract_strided_metadata(memref) from subview(memref). Location origLoc = subview.getLoc(); Value source = subview.getSource(); - auto sourceType = source.getType().cast(); + auto sourceType = cast(source.getType()); unsigned sourceRank = sourceType.getRank(); auto newExtractStridedMetadata = @@ -115,7 +115,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter, // The final result is . // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all // the values. - auto subType = subview.getType().cast(); + auto subType = cast(subview.getType()); unsigned subRank = subType.getRank(); // The sizes of the final type are defined directly by the input sizes of @@ -338,7 +338,7 @@ SmallVector getExpandedStrides(memref::ExpandShapeOp expandShape, // Collect the statically known information about the original stride. 
Value source = expandShape.getSrc(); - auto sourceType = source.getType().cast(); + auto sourceType = cast(source.getType()); auto [strides, offset] = getStridesAndOffset(sourceType); OpFoldResult origStride = ShapedType::isDynamic(strides[groupId]) @@ -358,10 +358,9 @@ SmallVector getExpandedStrides(memref::ExpandShapeOp expandShape, AffineExpr s0 = builder.getAffineSymbolExpr(0); AffineExpr s1 = builder.getAffineSymbolExpr(1); for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) { - int64_t baseExpandedStride = expandedStrides[doneStrideIdx] - .get() - .cast() - .getInt(); + int64_t baseExpandedStride = + cast(expandedStrides[doneStrideIdx].get()) + .getInt(); expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply( builder, expandShape.getLoc(), (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1, @@ -372,10 +371,9 @@ SmallVector getExpandedStrides(memref::ExpandShapeOp expandShape, // Now apply the origStride to the remaining dimensions. AffineExpr s0 = builder.getAffineSymbolExpr(0); for (; doneStrideIdx < groupSize; ++doneStrideIdx) { - int64_t baseExpandedStride = expandedStrides[doneStrideIdx] - .get() - .cast() - .getInt(); + int64_t baseExpandedStride = + cast(expandedStrides[doneStrideIdx].get()) + .getInt(); expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply( builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride}); } @@ -445,7 +443,7 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder, // Build the affine expr of the product of the original sizes involved in that // group. Value source = collapseShape.getSrc(); - auto sourceType = source.getType().cast(); + auto sourceType = cast(source.getType()); SmallVector reassocGroup = collapseShape.getReassociationIndices()[groupId]; @@ -479,7 +477,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, "Reassociation group should have at least one dimension"); Value source = collapseShape.getSrc(); - auto sourceType = source.getType().cast(); + auto sourceType = cast(source.getType()); auto [strides, offset] = getStridesAndOffset(sourceType); @@ -562,7 +560,7 @@ public: // extract_strided_metadata(reassociative_reshape_like(memref)). 
Location origLoc = reshape.getLoc(); Value source = reshape.getSrc(); - auto sourceType = source.getType().cast(); + auto sourceType = cast(source.getType()); unsigned sourceRank = sourceType.getRank(); auto newExtractStridedMetadata = @@ -650,8 +648,7 @@ public: if (!allocLikeOp) return failure(); - auto memRefType = - allocLikeOp.getResult().getType().template cast(); + auto memRefType = cast(allocLikeOp.getResult().getType()); if (!memRefType.getLayout().isIdentity()) return rewriter.notifyMatchFailure( allocLikeOp, "alloc-like operations should have been normalized"); @@ -688,7 +685,7 @@ public: SmallVector results; results.reserve(rank * 2 + 2); - auto baseBufferType = op.getBaseBuffer().getType().cast(); + auto baseBufferType = cast(op.getBaseBuffer().getType()); int64_t offset = 0; if (allocLikeOp.getType() == baseBufferType) results.push_back(allocLikeOp); @@ -737,7 +734,7 @@ public: if (!getGlobalOp) return failure(); - auto memRefType = getGlobalOp.getResult().getType().cast(); + auto memRefType = cast(getGlobalOp.getResult().getType()); if (!memRefType.getLayout().isIdentity()) { return rewriter.notifyMatchFailure( getGlobalOp, @@ -759,7 +756,7 @@ public: SmallVector results; results.reserve(rank * 2 + 2); - auto baseBufferType = op.getBaseBuffer().getType().cast(); + auto baseBufferType = cast(op.getBaseBuffer().getType()); int64_t offset = 0; if (getGlobalOp.getType() == baseBufferType) results.push_back(getGlobalOp); @@ -838,8 +835,7 @@ class ExtractStridedMetadataOpReinterpretCastFolder return rewriter.notifyMatchFailure( reinterpretCastOp, "reinterpret_cast source's type is incompatible"); - auto memrefType = - reinterpretCastOp.getResult().getType().cast(); + auto memrefType = cast(reinterpretCastOp.getResult().getType()); unsigned rank = memrefType.getRank(); SmallVector results; results.resize_for_overwrite(rank * 2 + 2); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp index 5141b5f..05ba6a3 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp @@ -120,7 +120,7 @@ template static FailureOr getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) { Value src = transferLikeOp.getSource(); - if (src.getType().isa()) + if (isa(src.getType())) return src; return failure(); } @@ -240,7 +240,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern { return rewriter.notifyMatchFailure(loadStoreLikeOp, "source is not a memref"); Value srcMemRef = *failureOrSrcMemRef; - auto ldStTy = srcMemRef.getType().cast(); + auto ldStTy = cast(srcMemRef.getType()); unsigned loadStoreRank = ldStTy.getRank(); // Don't waste compile time if there is nothing to rewrite. 
if (loadStoreRank == 0) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 72675b0..2c30e98 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -148,7 +148,7 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, if (collapseShapeOp.getReassociationIndices().empty()) { auto zeroAffineMap = rewriter.getConstantAffineMap(0); int64_t srcRank = - collapseShapeOp.getViewSource().getType().cast().getRank(); + cast(collapseShapeOp.getViewSource().getType()).getRank(); for (int64_t i = 0; i < srcRank; i++) { OpFoldResult ofr = affine::makeComposedFoldedAffineApply( rewriter, loc, zeroAffineMap, dynamicIndices); diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp index aa1d27d..68b72ef 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -71,11 +71,9 @@ propagateSubViewOp(RewriterBase &rewriter, UnrealizedConversionCastOp conversionOp, SubViewOp op) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); - auto newResultType = - SubViewOp::inferRankReducedResultType( - op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(), - op.getMixedSizes(), op.getMixedStrides()) - .cast(); + auto newResultType = cast(SubViewOp::inferRankReducedResultType( + op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(), + op.getMixedSizes(), op.getMixedStrides())); Value newSubview = rewriter.create( op.getLoc(), newResultType, conversionOp.getOperand(0), op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp index ee1adcc..eb1df2a8 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -61,11 +61,11 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter, OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(subviewUse); Type newType = memref::SubViewOp::inferRankReducedResultType( - subviewUse.getType().getShape(), val.getType().cast(), + subviewUse.getType().getShape(), cast(val.getType()), subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), subviewUse.getStaticStrides()); Value newSubview = rewriter.create( - subviewUse->getLoc(), newType.cast(), val, + subviewUse->getLoc(), cast(newType), val, subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), subviewUse.getMixedStrides()); @@ -209,9 +209,9 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, for (int64_t i = 0, e = originalShape.size(); i != e; ++i) sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]); // Strides is [1, 1 ... 1 ]. 
- auto dstMemref = memref::SubViewOp::inferRankReducedResultType( - originalShape, mbMemRefType, offsets, sizes, strides) - .cast(); + auto dstMemref = + cast(memref::SubViewOp::inferRankReducedResultType( + originalShape, mbMemRefType, offsets, sizes, strides)); Value subview = rewriter.create(loc, dstMemref, mbAlloc, offsets, sizes, strides); LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp index c252433..aa21497 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -180,7 +180,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { llvm::seq(0, callOp.getNumResults())) { Value oldMemRef = callOp.getResult(resIndex); if (auto oldMemRefType = - oldMemRef.getType().dyn_cast()) + dyn_cast(oldMemRef.getType())) if (!oldMemRefType.getLayout().isIdentity() && !isMemRefNormalizable(oldMemRef.getUsers())) return WalkResult::interrupt(); @@ -192,7 +192,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { for (unsigned argIndex : llvm::seq(0, funcOp.getNumArguments())) { BlockArgument oldMemRef = funcOp.getArgument(argIndex); - if (auto oldMemRefType = oldMemRef.getType().dyn_cast()) + if (auto oldMemRefType = dyn_cast(oldMemRef.getType())) if (!oldMemRefType.getLayout().isIdentity() && !isMemRefNormalizable(oldMemRef.getUsers())) return false; @@ -226,7 +226,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp, funcOp.walk([&](func::ReturnOp returnOp) { for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) { Type opType = operandEn.value().getType(); - MemRefType memrefType = opType.dyn_cast(); + MemRefType memrefType = dyn_cast(opType); // If type is not memref or if the memref type is same as that in // function's return signature then no update is required. if (!memrefType || memrefType == resultTypes[operandEn.index()]) @@ -284,7 +284,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp, if (oldResult.getType() == newResult.getType()) continue; AffineMap layoutMap = - oldResult.getType().cast().getLayout().getAffineMap(); + cast(oldResult.getType()).getLayout().getAffineMap(); if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult, /*extraIndices=*/{}, /*indexRemap=*/layoutMap, @@ -358,7 +358,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp, for (unsigned argIndex : llvm::seq(0, functionType.getNumInputs())) { Type argType = functionType.getInput(argIndex); - MemRefType memrefType = argType.dyn_cast(); + MemRefType memrefType = dyn_cast(argType); // Check whether argument is of MemRef type. Any other argument type can // simply be part of the final function signature. if (!memrefType) { @@ -422,11 +422,11 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp, // Replace all uses of the old memrefs. Value oldMemRef = op->getResult(resIndex); Value newMemRef = newOp->getResult(resIndex); - MemRefType oldMemRefType = oldMemRef.getType().dyn_cast(); + MemRefType oldMemRefType = dyn_cast(oldMemRef.getType()); // Check whether the operation result is MemRef type. if (!oldMemRefType) continue; - MemRefType newMemRefType = newMemRef.getType().cast(); + MemRefType newMemRefType = cast(newMemRef.getType()); if (oldMemRefType == newMemRefType) continue; // TODO: Assume single layout map. Multiple maps not supported. 
@@ -466,7 +466,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp, for (unsigned resIndex : llvm::seq(0, functionType.getNumResults())) { Type resType = functionType.getResult(resIndex); - MemRefType memrefType = resType.dyn_cast(); + MemRefType memrefType = dyn_cast(resType); // Check whether result is of MemRef type. Any other argument type can // simply be part of the final function signature. if (!memrefType) { @@ -507,7 +507,7 @@ Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp, bool resultTypeNormalized = false; for (unsigned resIndex : llvm::seq(0, oldOp->getNumResults())) { auto resultType = oldOp->getResult(resIndex).getType(); - MemRefType memrefType = resultType.dyn_cast(); + MemRefType memrefType = dyn_cast(resultType); // Check whether the operation result is MemRef type. if (!memrefType) { resultTypes.push_back(resultType); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 8c544bb..526c1c6 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -40,7 +40,7 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy dimOp, PatternRewriter &rewriter) const override { - OpResult dimValue = dimOp.getSource().template dyn_cast(); + OpResult dimValue = dyn_cast(dimOp.getSource()); if (!dimValue) return failure(); auto shapedTypeOp = @@ -61,8 +61,8 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern { return failure(); Value resultShape = reifiedResultShapes[dimValue.getResultNumber()]; - auto resultShapeType = resultShape.getType().dyn_cast(); - if (!resultShapeType || !resultShapeType.getElementType().isa()) + auto resultShapeType = dyn_cast(resultShape.getType()); + if (!resultShapeType || !isa(resultShapeType.getElementType())) return failure(); Location loc = dimOp->getLoc(); @@ -82,7 +82,7 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy dimOp, PatternRewriter &rewriter) const override { - OpResult dimValue = dimOp.getSource().template dyn_cast(); + OpResult dimValue = dyn_cast(dimOp.getSource()); if (!dimValue) return failure(); std::optional dimIndex = dimOp.getConstantIndex(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 9ffb315..05a069d 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -38,14 +38,14 @@ struct CastOpInterface void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto castOp = cast(op); - auto srcType = castOp.getSource().getType().cast(); + auto srcType = cast(castOp.getSource().getType()); // Nothing to check if the result is an unranked memref. - auto resultType = castOp.getType().dyn_cast(); + auto resultType = dyn_cast(castOp.getType()); if (!resultType) return; - if (srcType.isa()) { + if (isa(srcType)) { // Check rank. Value srcRank = builder.create(loc, castOp.getSource()); Value resultRank = @@ -75,7 +75,7 @@ struct CastOpInterface // Check dimension sizes. for (const auto &it : llvm::enumerate(resultType.getShape())) { // Static dim size -> static/dynamic dim size does not need verification. 
-    if (auto rankedSrcType = srcType.dyn_cast<MemRefType>())
+    if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
       if (!rankedSrcType.isDynamicDim(it.index()))
         continue;
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
index 292738d..b9dd174 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -42,7 +42,7 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
     Location location = op->getLoc();
 
     if (op->hasAttr(op.getTf32EnabledAttrName()) ||
-        !op.getMatrixA().getType().cast<VectorType>().getElementType().isF32())
+        !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
       return failure();
 
     if (precision == MmaSyncF32Lowering::Unkown)
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
index 07e9ae9..486c7868 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -180,7 +180,7 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
 mlir::LogicalResult
 mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                 Value memrefValue) {
-  auto memRefType = memrefValue.getType().dyn_cast<MemRefType>();
+  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
   if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
     return failure();
 
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 7525f9f..5a0018c 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -63,7 +63,7 @@ FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
     info.vectorType = writeOp.getVectorType();
   } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                  arith::ConstantOp>(op)) {
-    info.vectorType = op->getResult(0).getType().cast<VectorType>();
+    info.vectorType = cast<VectorType>(op->getResult(0).getType());
   } else {
     return op->emitError()
            << "unhandled operation type in nvgpu.mma.sync conversion path";
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index ddd8ae0..5ee53ea 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -14,13 +14,13 @@ using namespace mlir;
 using namespace mlir::quant;
 
 static bool isQuantizablePrimitiveType(Type inputType) {
-  return inputType.isa<FloatType>();
+  return isa<FloatType>(inputType);
 }
 
 ExpressedToQuantizedConverter
 ExpressedToQuantizedConverter::forInputType(Type inputType) {
-  if (inputType.isa<TensorType, VectorType>()) {
-    Type elementType = inputType.cast<ShapedType>().getElementType();
+  if (isa<TensorType, VectorType>(inputType)) {
+    Type elementType = cast<ShapedType>(inputType).getElementType();
     if (!isQuantizablePrimitiveType(elementType))
       return ExpressedToQuantizedConverter{inputType, nullptr};
     return ExpressedToQuantizedConverter{inputType, elementType};
@@ -34,11 +34,11 @@ ExpressedToQuantizedConverter::forInputType(Type inputType) {
 Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
   assert(expressedType && "convert() on unsupported conversion");
 
-  if (auto tensorType = inputType.dyn_cast<RankedTensorType>())
+  if (auto tensorType = dyn_cast<RankedTensorType>(inputType))
     return RankedTensorType::get(tensorType.getShape(), elementalType);
-  if (auto tensorType = inputType.dyn_cast<UnrankedTensorType>())
+  if (auto tensorType = dyn_cast<UnrankedTensorType>(inputType))
     return UnrankedTensorType::get(elementalType);
-  if (auto vectorType = inputType.dyn_cast<VectorType>())
+  if (auto vectorType = dyn_cast<VectorType>(inputType))
     return VectorType::get(vectorType.getShape(),
                            elementalType);
 
   // If the expressed types match, just use the new elemental type.
@@ -50,7 +50,7 @@ Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
 
 ElementsAttr
 UniformQuantizedPerAxisValueConverter::convert(Attribute realValue) {
-  if (auto attr = realValue.dyn_cast<DenseFPElementsAttr>()) {
+  if (auto attr = dyn_cast<DenseFPElementsAttr>(realValue)) {
     return convert(attr);
   }
   // TODO: handles sparse elements attribute
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 18425de..2da7473 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -49,7 +49,7 @@ transform::GetParentForOp::apply(transform::TransformResults &results,
     }
     parents.insert(loop);
   }
-  results.set(getResult().cast<OpResult>(), parents.getArrayRef());
+  results.set(cast<OpResult>(getResult()), parents.getArrayRef());
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -116,8 +116,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
     functions.push_back(*outlined);
     calls.push_back(call);
   }
-  results.set(getFunction().cast<OpResult>(), functions);
-  results.set(getCall().cast<OpResult>(), calls);
+  results.set(cast<OpResult>(getFunction()), functions);
+  results.set(cast<OpResult>(getCall()), calls);
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 13f0d76..ad395a9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -30,8 +30,8 @@ namespace {
 /// Helper function for loop bufferization. Cast the given buffer to the given
 /// memref type.
 static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
-  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
-  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
+  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
+  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
   // If the buffer already has the correct type, no cast is needed.
   if (buffer.getType() == type)
     return buffer;
@@ -78,7 +78,7 @@ struct ConditionOpInterface
     SmallVector<Value> newArgs;
     for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
       Value value = it.value();
-      if (value.getType().isa<TensorType>()) {
+      if (isa<TensorType>(value.getType())) {
         FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
         if (failed(maybeBuffer))
           return failure();
@@ -141,7 +141,7 @@ struct ExecuteRegionOpInterface
     rewriter.setInsertionPointAfter(newOp);
     SmallVector<Value> newResults;
     for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
-      if (it.value().isa<TensorType>()) {
+      if (isa<TensorType>(it.value())) {
         newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
             executeRegionOp.getLoc(), newOp->getResult(it.index())));
       } else {
@@ -183,7 +183,7 @@ struct IfOpInterface
     // Compute bufferized result types.
     SmallVector<Type> newTypes;
     for (Value result : ifOp.getResults()) {
-      if (!result.getType().isa<TensorType>()) {
+      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
@@ -218,13 +218,13 @@ struct IfOpInterface
     assert(value.getDefiningOp() == op && "invalid valid");
 
     // Determine buffer types of the true/false branches.
- auto opResult = value.cast(); + auto opResult = cast(value); auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber()); auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber()); BaseMemRefType thenBufferType, elseBufferType; - if (thenValue.getType().isa()) { + if (isa(thenValue.getType())) { // True branch was already bufferized. - thenBufferType = thenValue.getType().cast(); + thenBufferType = cast(thenValue.getType()); } else { auto maybeBufferType = bufferization::getBufferType(thenValue, options, fixedTypes); @@ -232,9 +232,9 @@ struct IfOpInterface return failure(); thenBufferType = *maybeBufferType; } - if (elseValue.getType().isa()) { + if (isa(elseValue.getType())) { // False branch was already bufferized. - elseBufferType = elseValue.getType().cast(); + elseBufferType = cast(elseValue.getType()); } else { auto maybeBufferType = bufferization::getBufferType(elseValue, options, fixedTypes); @@ -253,7 +253,7 @@ struct IfOpInterface // Layout maps are different: Promote to fully dynamic layout map. return getMemRefTypeWithFullyDynamicLayout( - opResult.getType().cast(), thenBufferType.getMemorySpace()); + cast(opResult.getType()), thenBufferType.getMemorySpace()); } }; @@ -262,7 +262,7 @@ struct IfOpInterface static DenseSet getTensorIndices(ValueRange values) { DenseSet result; for (const auto &it : llvm::enumerate(values)) - if (it.value().getType().isa()) + if (isa(it.value().getType())) result.insert(it.index()); return result; } @@ -275,8 +275,8 @@ DenseSet getEquivalentBuffers(Block::BlockArgListType bbArgs, unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size()); DenseSet result; for (unsigned int i = 0; i < minSize; ++i) { - if (!bbArgs[i].getType().isa() || - !yieldedValues[i].getType().isa()) + if (!isa(bbArgs[i].getType()) || + !isa(yieldedValues[i].getType())) continue; if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i])) result.insert(i); @@ -291,7 +291,7 @@ getBuffers(RewriterBase &rewriter, MutableArrayRef operands, const BufferizationOptions &options) { SmallVector result; for (OpOperand &opOperand : operands) { - if (opOperand.get().getType().isa()) { + if (isa(opOperand.get().getType())) { FailureOr resultBuffer = getBuffer(rewriter, opOperand.get(), options); if (failed(resultBuffer)) @@ -361,9 +361,9 @@ static FailureOr computeLoopRegionIterArgBufferType( // Compute the buffer type of the yielded value. BaseMemRefType yieldedValueBufferType; - if (yieldedValue.getType().isa()) { + if (isa(yieldedValue.getType())) { // scf.yield was already bufferized. - yieldedValueBufferType = yieldedValue.getType().cast(); + yieldedValueBufferType = cast(yieldedValue.getType()); } else { auto maybeBufferType = bufferization::getBufferType(yieldedValue, options, newFixedTypes); @@ -379,7 +379,7 @@ static FailureOr computeLoopRegionIterArgBufferType( // If there is a mismatch between the yielded buffer type and the iter_arg // buffer type, the buffer type must be promoted to a fully dynamic layout // map. 
- auto yieldedRanked = yieldedValueBufferType.cast(); + auto yieldedRanked = cast(yieldedValueBufferType); #ifndef NDEBUG auto iterRanked = initArgBufferType->cast(); assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) && @@ -388,7 +388,7 @@ static FailureOr computeLoopRegionIterArgBufferType( "expected same memory space"); #endif // NDEBUG return getMemRefTypeWithFullyDynamicLayout( - iterArg.getType().cast(), + cast(iterArg.getType()), yieldedRanked.getMemorySpace()); } @@ -516,16 +516,16 @@ struct ForOpInterface const DenseMap &fixedTypes) const { auto forOp = cast(op); assert(getOwnerOfValue(value) == op && "invalid value"); - assert(value.getType().isa() && "expected tensor type"); + assert(isa(value.getType()) && "expected tensor type"); // Get result/argument number. unsigned resultNum; - if (auto bbArg = value.dyn_cast()) { + if (auto bbArg = dyn_cast(value)) { resultNum = forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg)) .getResultNumber(); } else { - resultNum = value.cast().getResultNumber(); + resultNum = cast(value).getResultNumber(); } // Compute the bufferized type. @@ -560,7 +560,7 @@ struct ForOpInterface Value initArg = it.value(); Value result = forOp->getResult(it.index()); // If the type is not a tensor, bufferization doesn't need to touch it. - if (!result.getType().isa()) { + if (!isa(result.getType())) { castedInitArgs.push_back(initArg); continue; } @@ -611,7 +611,7 @@ struct ForOpInterface auto yieldOp = cast(forOp.getLoopBody().front().getTerminator()); for (OpResult opResult : op->getOpResults()) { - if (!opResult.getType().isa()) + if (!isa(opResult.getType())) continue; // Note: This is overly strict. We should check for aliasing bufferized @@ -736,7 +736,7 @@ struct WhileOpInterface for (int64_t idx = 0; idx < static_cast(conditionOp.getArgs().size()); ++idx) { Value value = conditionOp.getArgs()[idx]; - if (!value.getType().isa() || + if (!isa(value.getType()) || (equivalentYieldsAfter.contains(idx) && equivalentYieldsBefore.contains(idx))) { beforeYieldValues.push_back(value); @@ -786,7 +786,7 @@ struct WhileOpInterface Value initArg = it.value(); Value beforeArg = whileOp.getBeforeArguments()[it.index()]; // If the type is not a tensor, bufferization doesn't need to touch it. - if (!beforeArg.getType().isa()) { + if (!isa(beforeArg.getType())) { castedInitArgs.push_back(initArg); continue; } @@ -799,7 +799,7 @@ struct WhileOpInterface // The result types of a WhileOp are the same as the "after" bbArg types. SmallVector argsTypesAfter = llvm::to_vector( llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) { - if (!bbArg.getType().isa()) + if (!isa(bbArg.getType())) return bbArg.getType(); // TODO: error handling return bufferization::getBufferType(bbArg, options)->cast(); @@ -848,10 +848,10 @@ struct WhileOpInterface const DenseMap &fixedTypes) const { auto whileOp = cast(op); assert(getOwnerOfValue(value) == op && "invalid value"); - assert(value.getType().isa() && "expected tensor type"); + assert(isa(value.getType()) && "expected tensor type"); // Case 1: Block argument of the "before" region. - if (auto bbArg = value.dyn_cast()) { + if (auto bbArg = dyn_cast(value)) { if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) { Value initArg = whileOp.getInits()[bbArg.getArgNumber()]; auto yieldOp = whileOp.getYieldOp(); @@ -865,18 +865,18 @@ struct WhileOpInterface // The bufferized "after" bbArg type can be directly computed from the // bufferized "before" bbArg type. 
unsigned resultNum; - if (auto opResult = value.dyn_cast()) { + if (auto opResult = dyn_cast(value)) { resultNum = opResult.getResultNumber(); - } else if (value.cast().getOwner()->getParent() == + } else if (cast(value).getOwner()->getParent() == &whileOp.getAfter()) { - resultNum = value.cast().getArgNumber(); + resultNum = cast(value).getArgNumber(); } else { llvm_unreachable("invalid value"); } Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum]; - if (!conditionYieldedVal.getType().isa()) { + if (!isa(conditionYieldedVal.getType())) { // scf.condition was already bufferized. - return conditionYieldedVal.getType().cast(); + return cast(conditionYieldedVal.getType()); } return bufferization::getBufferType(conditionYieldedVal, options, fixedTypes); @@ -902,7 +902,7 @@ struct WhileOpInterface auto conditionOp = whileOp.getConditionOp(); for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { - if (!it.value().getType().isa()) + if (!isa(it.value().getType())) continue; if (!state.areEquivalentBufferizedValues( it.value(), conditionOp->getBlock()->getArgument(it.index()))) @@ -913,7 +913,7 @@ struct WhileOpInterface auto yieldOp = whileOp.getYieldOp(); for (const auto &it : llvm::enumerate(yieldOp.getResults())) { - if (!it.value().getType().isa()) + if (!isa(it.value().getType())) continue; if (!state.areEquivalentBufferizedValues( it.value(), yieldOp->getBlock()->getArgument(it.index()))) @@ -971,7 +971,7 @@ struct YieldOpInterface SmallVector newResults; for (const auto &it : llvm::enumerate(yieldOp.getResults())) { Value value = it.value(); - if (value.getType().isa()) { + if (isa(value.getType())) { FailureOr maybeBuffer = getBuffer(rewriter, value, options); if (failed(maybeBuffer)) return failure(); @@ -1110,7 +1110,7 @@ struct ForallOpInterface const DenseMap &fixedTypes) const { auto forallOp = cast(op); - if (auto bbArg = value.dyn_cast()) + if (auto bbArg = dyn_cast(value)) // A tensor block argument has the same bufferized type as the // corresponding output operand. return bufferization::getBufferType( @@ -1119,8 +1119,8 @@ struct ForallOpInterface // The bufferized result type is the same as the bufferized type of the // corresponding output operand. 
return bufferization::getBufferType( - forallOp.getOutputs()[value.cast().getResultNumber()], - options, fixedTypes); + forallOp.getOutputs()[cast(value).getResultNumber()], options, + fixedTypes); } bool isRepetitiveRegion(Operation *op, unsigned index) const { diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index 2450a0e..9959149 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -43,7 +43,7 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) { while (value) { if (value == forOp.getRegionIterArgs()[arg]) return true; - OpResult opResult = value.dyn_cast(); + OpResult opResult = dyn_cast(value); if (!opResult) return false; @@ -91,7 +91,7 @@ struct DimOfIterArgFolder : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy dimOp, PatternRewriter &rewriter) const override { - auto blockArg = dimOp.getSource().template dyn_cast(); + auto blockArg = dyn_cast(dimOp.getSource()); if (!blockArg) return failure(); auto forOp = dyn_cast(blockArg.getParentBlock()->getParentOp()); @@ -139,7 +139,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern { auto forOp = dimOp.getSource().template getDefiningOp(); if (!forOp) return failure(); - auto opResult = dimOp.getSource().template cast(); + auto opResult = cast(dimOp.getSource()); unsigned resultNumber = opResult.getResultNumber(); if (!isShapePreserving(forOp, resultNumber)) return failure(); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index 6a9f725..a85985b 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -164,8 +164,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op, clone->walk([&](Operation *nested) { for (OpOperand &operand : nested->getOpOperands()) { Operation *def = operand.get().getDefiningOp(); - if ((def && !clone->isAncestor(def)) || - operand.get().isa()) + if ((def && !clone->isAncestor(def)) || isa(operand.get())) callback(&operand); } }); @@ -346,7 +345,7 @@ void LoopPipelinerInternal::createKernel( rewriter.setInsertionPointAfter(newOp); continue; } - auto arg = operand->get().dyn_cast(); + auto arg = dyn_cast(operand->get()); if (arg && arg.getOwner() == forOp.getBody()) { // If the value is a loop carried value coming from stage N + 1 remap, // it will become a direct use. 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 131e821..224bec3 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -496,7 +496,7 @@ getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef loops) { std::optional destinationIterArg; auto loopIt = loops.rbegin(); - while (auto iterArg = source->get().dyn_cast()) { + while (auto iterArg = dyn_cast(source->get())) { scf::ForOp loop = *loopIt; if (iterArg.getOwner()->getParentOp() != loop) break; @@ -505,7 +505,7 @@ getUntiledProducerFromSliceSource(OpOperand *source, } if (loopIt == loops.rend()) destinationIterArg = source; - return {source->get().dyn_cast(), destinationIterArg}; + return {dyn_cast(source->get()), destinationIterArg}; } /// Implementation of fusing producer of a single slice by computing the diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp index f154840..c22cb67 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp @@ -42,8 +42,8 @@ public: PatternRewriter &rewriter) const override { SmallVector globalVarAttrs; - auto ptrType = op.getType().cast(); - auto pointeeType = ptrType.getPointeeType().cast(); + auto ptrType = cast(op.getType()); + auto pointeeType = cast(ptrType.getPointeeType()); spirv::StructType structType = VulkanLayoutUtils::decorateType(pointeeType); if (!structType) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index c0ab215..9f2755da0 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -51,19 +51,19 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, // info create a variable of type !spirv.ptr>. If // not it must already be a !spirv.ptr>. auto varType = funcOp.getFunctionType().getInput(argIndex); - if (varType.cast().isScalarOrVector()) { + if (cast(varType).isScalarOrVector()) { auto storageClass = abiInfo.getStorageClass(); if (!storageClass) return nullptr; varType = spirv::PointerType::get(spirv::StructType::get(varType), *storageClass); } - auto varPtrType = varType.cast(); - auto varPointeeType = varPtrType.getPointeeType().cast(); + auto varPtrType = cast(varType); + auto varPointeeType = cast(varPtrType.getPointeeType()); // Set the offset information. varPointeeType = - VulkanLayoutUtils::decorateType(varPointeeType).cast(); + cast(VulkanLayoutUtils::decorateType(varPointeeType)); if (!varPointeeType) return nullptr; @@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp, // Starting with version 1.4, the interface’s storage classes are all // storage classes used in declaring all global variables referenced by the // entry point’s call tree." We should consider the target environment here. - switch (var.getType().cast().getStorageClass()) { + switch (cast(var.getType()).getStorageClass()) { case spirv::StorageClass::Input: case spirv::StorageClass::Output: interfaceVarSet.insert(var.getOperation()); @@ -247,7 +247,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite( // at the start of the function. It is probably better to do the load just // before the use. 
There might be multiple loads and currently there is no // easy way to replace all uses with a sequence of operations. - if (argType.value().cast().isScalarOrVector()) { + if (cast(argType.value()).isScalarOrVector()) { auto zero = spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter); auto loadPtr = rewriter.create( @@ -287,7 +287,7 @@ void LowerABIAttributesPass::runOnOperation() { typeConverter.addSourceMaterialization([](OpBuilder &builder, spirv::PointerType type, ValueRange inputs, Location loc) { - if (inputs.size() != 1 || !inputs[0].getType().isa()) + if (inputs.size() != 1 || !isa(inputs[0].getType())) return Value(); return builder.create(loc, type, inputs[0]).getResult(); }); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp index 51c36bd..f38282f 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp @@ -84,15 +84,13 @@ void RewriteInsertsPass::runOnOperation() { LogicalResult RewriteInsertsPass::collectInsertionChain( spirv::CompositeInsertOp op, SmallVectorImpl &insertions) { - auto indicesArrayAttr = op.getIndices().cast(); + auto indicesArrayAttr = cast(op.getIndices()); // TODO: handle nested composite object. if (indicesArrayAttr.size() == 1) { - auto numElements = op.getComposite() - .getType() - .cast() + auto numElements = cast(op.getComposite().getType()) .getNumElements(); - auto index = indicesArrayAttr[0].cast().getInt(); + auto index = cast(indicesArrayAttr[0]).getInt(); // Need a last index to collect a sequential chain. if (index + 1 != numElements) return failure(); @@ -109,9 +107,9 @@ LogicalResult RewriteInsertsPass::collectInsertionChain( return failure(); --index; - indicesArrayAttr = op.getIndices().cast(); + indicesArrayAttr = cast(op.getIndices()); if ((indicesArrayAttr.size() != 1) || - (indicesArrayAttr[0].cast().getInt() != index)) + (cast(indicesArrayAttr[0]).getInt() != index)) return failure(); } } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 5a5cdfe..793b025 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -139,7 +139,7 @@ bool SPIRVTypeConverter::allows(spirv::Capability capability) { // SPIR-V dialect. Keeping it local till the use case arises. static std::optional getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { - if (type.isa()) { + if (isa(type)) { auto bitWidth = type.getIntOrFloatBitWidth(); // According to the SPIR-V spec: // "There is no physical size or bit pattern defined for values with boolean @@ -152,21 +152,21 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { return bitWidth / 8; } - if (auto complexType = type.dyn_cast()) { + if (auto complexType = dyn_cast(type)) { auto elementSize = getTypeNumBytes(options, complexType.getElementType()); if (!elementSize) return std::nullopt; return 2 * *elementSize; } - if (auto vecType = type.dyn_cast()) { + if (auto vecType = dyn_cast(type)) { auto elementSize = getTypeNumBytes(options, vecType.getElementType()); if (!elementSize) return std::nullopt; return vecType.getNumElements() * *elementSize; } - if (auto memRefType = type.dyn_cast()) { + if (auto memRefType = dyn_cast(type)) { // TODO: Layout should also be controlled by the ABI attributes. For now // using the layout from MemRef. 
int64_t offset; @@ -198,7 +198,7 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { return (offset + memrefSize) * *elementSize; } - if (auto tensorType = type.dyn_cast()) { + if (auto tensorType = dyn_cast(type)) { if (!tensorType.hasStaticShape()) return std::nullopt; @@ -246,12 +246,12 @@ convertScalarType(const spirv::TargetEnv &targetEnv, return nullptr; } - if (auto floatType = type.dyn_cast()) { + if (auto floatType = dyn_cast(type)) { LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); return Builder(targetEnv.getContext()).getF32Type(); } - auto intType = type.cast(); + auto intType = cast(type); LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); return IntegerType::get(targetEnv.getContext(), /*width=*/32, intType.getSignedness()); @@ -319,8 +319,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv, // Get extension and capability requirements for the given type. SmallVector, 1> extensions; SmallVector, 2> capabilities; - type.cast().getExtensions(extensions, storageClass); - type.cast().getCapabilities(capabilities, storageClass); + cast(type).getExtensions(extensions, storageClass); + cast(type).getCapabilities(capabilities, storageClass); // If all requirements are met, then we can accept this type as-is. if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && @@ -415,8 +415,8 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, << "using non-8-bit storage for bool types unimplemented"); return nullptr; } - auto elementType = IntegerType::get(type.getContext(), numBoolBits) - .dyn_cast(); + auto elementType = dyn_cast( + IntegerType::get(type.getContext(), numBoolBits)); if (!elementType) return nullptr; Type arrayElemType = @@ -487,7 +487,7 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv, static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type) { - auto attr = type.getMemorySpace().dyn_cast_or_null(); + auto attr = dyn_cast_or_null(type.getMemorySpace()); if (!attr) { LLVM_DEBUG( llvm::dbgs() @@ -499,7 +499,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, } spirv::StorageClass storageClass = attr.getValue(); - if (type.getElementType().isa()) { + if (isa(type.getElementType())) { if (type.getElementTypeBitWidth() == 1) return convertBoolMemrefType(targetEnv, options, type, storageClass); if (type.getElementTypeBitWidth() < 8) @@ -508,17 +508,17 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, Type arrayElemType; Type elementType = type.getElementType(); - if (auto vecType = elementType.dyn_cast()) { + if (auto vecType = dyn_cast(elementType)) { arrayElemType = convertVectorType(targetEnv, options, vecType, storageClass); - } else if (auto complexType = elementType.dyn_cast()) { + } else if (auto complexType = dyn_cast(elementType)) { arrayElemType = convertComplexType(targetEnv, options, complexType, storageClass); - } else if (auto scalarType = elementType.dyn_cast()) { + } else if (auto scalarType = dyn_cast(elementType)) { arrayElemType = convertScalarType(targetEnv, options, scalarType, storageClass); - } else if (auto indexType = elementType.dyn_cast()) { - type = convertIndexElementType(type, options).cast(); + } else if (auto indexType = dyn_cast(elementType)) { + type = cast(convertIndexElementType(type, options)); arrayElemType = type.getElementType(); } else { LLVM_DEBUG( @@ -583,7 +583,7 @@ 
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, addConversion([this](IndexType /*indexType*/) { return getIndexType(); }); addConversion([this](IntegerType intType) -> std::optional { - if (auto scalarType = intType.dyn_cast()) + if (auto scalarType = dyn_cast(intType)) return convertScalarType(this->targetEnv, this->options, scalarType); if (intType.getWidth() < 8) return convertSubByteIntegerType(this->options, intType); @@ -591,7 +591,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, }); addConversion([this](FloatType floatType) -> std::optional { - if (auto scalarType = floatType.dyn_cast()) + if (auto scalarType = dyn_cast(floatType)) return convertScalarType(this->targetEnv, this->options, scalarType); return Type(); }); @@ -784,7 +784,7 @@ static spirv::PointerType getPushConstantStorageType(unsigned elementCount, static spirv::GlobalVariableOp getPushConstantVariable(Block &body, unsigned elementCount) { for (auto varOp : body.getOps()) { - auto ptrType = varOp.getType().dyn_cast(); + auto ptrType = dyn_cast(varOp.getType()); if (!ptrType) continue; @@ -792,10 +792,9 @@ static spirv::GlobalVariableOp getPushConstantVariable(Block &body, // block statically used per shader entry point." So we should always reuse // the existing one. if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { - auto numElements = ptrType.getPointeeType() - .cast() - .getElementType(0) - .cast() + auto numElements = cast( + cast(ptrType.getPointeeType()) + .getElementType(0)) .getNumElements(); if (numElements == elementCount) return varOp; @@ -926,8 +925,8 @@ Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter, linearizeIndex(indices, strides, offset, indexType, loc, builder); } Type pointeeType = - basePtr.getType().cast().getPointeeType(); - if (pointeeType.isa()) { + cast(basePtr.getType()).getPointeeType(); + if (isa(pointeeType)) { linearizedIndices.push_back(linearIndex); return builder.create(loc, basePtr, linearizedIndices); @@ -1015,7 +1014,7 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { // Ensure that all types have been converted to SPIRV types. 
if (llvm::any_of(valueTypes, - [](Type t) { return !t.isa(); })) + [](Type t) { return !isa(t); })) return false; // Special treatment for global variables, whose type requirements are @@ -1029,13 +1028,13 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { SmallVector, 8> typeCapabilities; for (Type valueType : valueTypes) { typeExtensions.clear(); - valueType.cast().getExtensions(typeExtensions); + cast(valueType).getExtensions(typeExtensions); if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, typeExtensions))) return false; typeCapabilities.clear(); - valueType.cast().getCapabilities(typeCapabilities); + cast(valueType).getCapabilities(typeCapabilities); if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, typeCapabilities))) return false; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp index 3cd4937..44fea86 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp @@ -41,7 +41,7 @@ namespace { //===----------------------------------------------------------------------===// Attribute getScalarOrSplatAttr(Type type, int64_t value) { APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value); - if (auto intTy = type.dyn_cast()) + if (auto intTy = dyn_cast(type)) return IntegerAttr::get(intTy, sizedValue); return SplatElementsAttr::get(cast(type), sizedValue); @@ -149,7 +149,7 @@ struct ExpandMulExtendedPattern final : OpRewritePattern { // Currently, WGSL only supports 32-bit integer types. Any other integer // types should already have been promoted/demoted to i32. - auto elemTy = getElementTypeOrSelf(lhs.getType()).cast(); + auto elemTy = cast(getElementTypeOrSelf(lhs.getType())); if (elemTy.getIntOrFloatBitWidth() != 32) return rewriter.notifyMatchFailure( loc, diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp index 97f16d1..ea856c7 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -65,16 +65,16 @@ static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) { /// `!spirv.ptr>>`. Returns null type /// otherwise. static Type getRuntimeArrayElementType(Type type) { - auto ptrType = type.dyn_cast(); + auto ptrType = dyn_cast(type); if (!ptrType) return {}; - auto structType = ptrType.getPointeeType().dyn_cast(); + auto structType = dyn_cast(ptrType.getPointeeType()); if (!structType || structType.getNumElements() != 1) return {}; auto rtArrayType = - structType.getElementType(0).dyn_cast(); + dyn_cast(structType.getElementType(0)); if (!rtArrayType) return {}; @@ -97,7 +97,7 @@ deduceCanonicalResource(ArrayRef types) { for (const auto &indexedTypes : llvm::enumerate(types)) { spirv::SPIRVType type = indexedTypes.value(); assert(type.isScalarOrVector()); - if (auto vectorType = type.dyn_cast()) { + if (auto vectorType = dyn_cast(type)) { if (vectorType.getNumElements() % 2 != 0) return std::nullopt; // Odd-sized vector has special layout // requirements. @@ -277,7 +277,7 @@ void ResourceAliasAnalysis::recordIfUnifiable( if (!elementType) return; // Unexpected resource variable type. - auto type = elementType.cast(); + auto type = cast(elementType); if (!type.isScalarOrVector()) return; // Unexpected resource element type. 
@@ -370,7 +370,7 @@ struct ConvertAccessChain : public ConvertAliasResource { Location loc = acOp.getLoc(); - if (srcElemType.isIntOrFloat() && dstElemType.isa()) { + if (srcElemType.isIntOrFloat() && isa(dstElemType)) { // The source indices are for a buffer with scalar element types. Rewrite // them into a buffer with vector element types. We need to scale the last // index for the vector as a whole, then add one level of index for inside @@ -398,7 +398,7 @@ struct ConvertAccessChain : public ConvertAliasResource { } if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) || - (srcElemType.isa() && dstElemType.isa())) { + (isa(srcElemType) && isa(dstElemType))) { // The source indices are for a buffer with larger bitwidth scalar/vector // element types. Rewrite them into a buffer with smaller bitwidth element // types. We only need to scale the last index. @@ -433,10 +433,10 @@ struct ConvertLoad : public ConvertAliasResource { LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcPtrType = loadOp.getPtr().getType().cast(); - auto srcElemType = srcPtrType.getPointeeType().cast(); - auto dstPtrType = adaptor.getPtr().getType().cast(); - auto dstElemType = dstPtrType.getPointeeType().cast(); + auto srcPtrType = cast(loadOp.getPtr().getType()); + auto srcElemType = cast(srcPtrType.getPointeeType()); + auto dstPtrType = cast(adaptor.getPtr().getType()); + auto dstElemType = cast(dstPtrType.getPointeeType()); Location loc = loadOp.getLoc(); auto newLoadOp = rewriter.create(loc, adaptor.getPtr()); @@ -454,7 +454,7 @@ struct ConvertLoad : public ConvertAliasResource { } if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) || - (srcElemType.isa() && dstElemType.isa())) { + (isa(srcElemType) && isa(dstElemType))) { // The source and destination have scalar types of different bitwidths, or // vector types of different component counts. For such cases, we load // multiple smaller bitwidth values and construct a larger bitwidth one. @@ -495,13 +495,13 @@ struct ConvertLoad : public ConvertAliasResource { // type. Type vectorType = srcElemType; - if (!srcElemType.isa()) + if (!isa(srcElemType)) vectorType = VectorType::get({ratio}, dstElemType); // If both the source and destination are vector types, we need to make // sure the scalar type is the same for composite construction later. 
-    if (auto srcElemVecType = srcElemType.dyn_cast<VectorType>())
-      if (auto dstElemVecType = dstElemType.dyn_cast<VectorType>()) {
+    if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
+      if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
         if (srcElemVecType.getElementType() !=
             dstElemVecType.getElementType()) {
           int64_t count =
@@ -515,7 +515,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
       Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
           loc, vectorType, components);
-      if (!srcElemType.isa<VectorType>())
+      if (!isa<VectorType>(srcElemType))
         vectorValue =
             rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
       rewriter.replaceOp(loadOp, vectorValue);
@@ -534,9 +534,9 @@ struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
   matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcElemType =
-        storeOp.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
+        cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
     auto dstElemType =
-        adaptor.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
+        cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
     if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
       return rewriter.notifyMatchFailure(storeOp, "not scalar type");
     if (!areSameBitwidthScalarType(srcElemType, dstElemType))
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6e09a84..095db6b 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -159,13 +159,13 @@ void UpdateVCEPass::runOnOperation() {
     SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
     for (Type valueType : valueTypes) {
       typeExtensions.clear();
-      valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+      cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
       if (failed(checkAndUpdateExtensionRequirements(
               op, targetEnv, typeExtensions, deducedExtensions)))
         return WalkResult::interrupt();
 
       typeCapabilities.clear();
-      valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+      cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
       if (failed(checkAndUpdateCapabilityRequirements(
               op, targetEnv, typeCapabilities, deducedCapabilities)))
         return WalkResult::interrupt();
diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
index 67d61f8..b19495b 100644
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -53,7 +53,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
     // must be a runtime array.
     assert(memberSize != std::numeric_limits<Size>().max() ||
            (i + 1 == e &&
-            structType.getElementType(i).isa<spirv::RuntimeArrayType>()));
+            isa<spirv::RuntimeArrayType>(structType.getElementType(i))));
     // According to the Vulkan spec:
     // "A structure has a base alignment equal to the largest base alignment of
     // any of its members."
@@ -79,23 +79,23 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
 
 Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
                                      VulkanLayoutUtils::Size &alignment) {
-  if (type.isa<spirv::ScalarType>()) {
+  if (isa<spirv::ScalarType>(type)) {
     alignment = getScalarTypeAlignment(type);
     // Vulkan spec does not specify any padding for a scalar type.
     size = alignment;
     return type;
   }
-  if (auto structType = type.dyn_cast<spirv::StructType>())
+  if (auto structType = dyn_cast<spirv::StructType>(type))
     return decorateType(structType, size, alignment);
-  if (auto arrayType = type.dyn_cast<spirv::ArrayType>())
+  if (auto arrayType = dyn_cast<spirv::ArrayType>(type))
     return decorateType(arrayType, size, alignment);
-  if (auto vectorType = type.dyn_cast<VectorType>())
+  if (auto vectorType = dyn_cast<VectorType>(type))
     return decorateType(vectorType, size, alignment);
-  if (auto arrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
+  if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
     size = std::numeric_limits<Size>().max();
     return decorateType(arrayType, alignment);
   }
-  if (type.isa<spirv::PointerType>()) {
+  if (isa<spirv::PointerType>(type)) {
     // TODO: Add support for `PhysicalStorageBufferAddresses`.
     return nullptr;
   }
@@ -161,13 +161,13 @@ VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
 }
 
 bool VulkanLayoutUtils::isLegalType(Type type) {
-  auto ptrType = type.dyn_cast<spirv::PointerType>();
+  auto ptrType = dyn_cast<spirv::PointerType>(type);
   if (!ptrType) {
     return true;
   }
 
   auto storageClass = ptrType.getStorageClass();
-  auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+  auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
   if (!structType) {
     return true;
   }
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index fc67fea..4a567f4 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -64,7 +64,7 @@ struct AssumingOpInterface
     rewriter.setInsertionPointAfter(newOp);
     SmallVector<Value> newResults;
     for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
-      if (it.value().isa<TensorType>()) {
+      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            assumingOp.getLoc(), newOp->getResult(it.index())));
      } else {
@@ -116,7 +116,7 @@ struct AssumingYieldOpInterface
     auto yieldOp = cast<shape::AssumingYieldOp>(op);
     SmallVector<Value> newResults;
     for (Value value : yieldOp.getOperands()) {
-      if (value.getType().isa<TensorType>()) {
+      if (isa<TensorType>(value.getType())) {
         FailureOr<Value> buffer = getBuffer(rewriter, value, options);
         if (failed(buffer))
           return failure();
diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
index f23a090..1a6f868 100644
--- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
@@ -133,7 +133,7 @@ void constructShapeFunc(
   for (shape::WithOp withOp : allWithOps) {
     Value value = withOp.getOperand();
     Value shape = withOp.getShape();
-    RankedTensorType rankedType = value.getType().dyn_cast<RankedTensorType>();
+    RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
     if (rankedType == nullptr)
       continue;
 
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 99a619c..990f8f7 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -41,7 +41,7 @@ getBufferizationOptions(bool analysisOnly) {
   options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
                                       const BufferizationOptions &options) {
     return getMemRefTypeWithStaticIdentityLayout(
-        value.getType().cast<TensorType>(), memorySpace);
+        cast<TensorType>(value.getType()), memorySpace);
   };
   if (analysisOnly) {
     options.testAnalysisOnly = true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 6fd55c7..ace8a88 100644
---
a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -28,7 +28,7 @@ using namespace mlir::sparse_tensor; static std::optional> genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) { if (auto constOp = tensor.getDefiningOp()) { - if (auto a = constOp.getValue().dyn_cast()) { + if (auto a = dyn_cast(constOp.getValue())) { auto coordinates = builder.create(loc, a.getIndices()); auto values = builder.create(loc, a.getValues()); return std::make_pair(coordinates, values); @@ -94,7 +94,7 @@ OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) { if (tp.isIndex()) return OverheadType::kIndex; - if (auto intTp = tp.dyn_cast()) + if (auto intTp = dyn_cast(tp)) return overheadTypeEncoding(intTp.getWidth()); llvm_unreachable("Unknown overhead type"); } @@ -169,7 +169,7 @@ PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { return PrimaryType::kI16; if (elemTp.isInteger(8)) return PrimaryType::kI8; - if (auto complexTp = elemTp.dyn_cast()) { + if (auto complexTp = dyn_cast(elemTp)) { auto complexEltTp = complexTp.getElementType(); if (complexEltTp.isF64()) return PrimaryType::kC64; @@ -205,10 +205,10 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value, return value; // int <=> index - if (srcTp.isa() || dstTp.isa()) + if (isa(srcTp) || isa(dstTp)) return builder.create(loc, dstTp, value); - const auto srcIntTp = srcTp.dyn_cast_or_null(); + const auto srcIntTp = dyn_cast_or_null(srcTp); const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false; return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast); } @@ -216,7 +216,7 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value, Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem, Value s) { Value load = builder.create(loc, mem, s); - if (!load.getType().isa()) { + if (!isa(load.getType())) { if (load.getType().getIntOrFloatBitWidth() < 64) load = builder.create(loc, builder.getI64Type(), load); load = @@ -226,14 +226,14 @@ Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem, } mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { - if (tp.isa()) + if (isa(tp)) return builder.getFloatAttr(tp, 1.0); - if (tp.isa()) + if (isa(tp)) return builder.getIndexAttr(1); - if (auto intTp = tp.dyn_cast()) + if (auto intTp = dyn_cast(tp)) return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); - if (tp.isa()) { - auto shapedTp = tp.cast(); + if (isa(tp)) { + auto shapedTp = cast(tp); if (auto one = getOneAttr(builder, shapedTp.getElementType())) return DenseElementsAttr::get(shapedTp, one); } @@ -244,13 +244,13 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, Value v) { Type tp = v.getType(); Value zero = constantZero(builder, loc, tp); - if (tp.isa()) + if (isa(tp)) return builder.create(loc, arith::CmpFPredicate::UNE, v, zero); if (tp.isIntOrIndex()) return builder.create(loc, arith::CmpIPredicate::ne, v, zero); - if (tp.dyn_cast()) + if (dyn_cast(tp)) return builder.create(loc, v, zero); llvm_unreachable("Non-numeric type"); } @@ -580,12 +580,12 @@ void sparse_tensor::foreachInSparseConstant( } // Remap value. 
Value val; - if (attr.getElementType().isa()) { - auto valAttr = elems[i].second.cast(); + if (isa(attr.getElementType())) { + auto valAttr = cast(elems[i].second); val = builder.create(loc, attr.getElementType(), valAttr); } else { - auto valAttr = elems[i].second.cast(); + auto valAttr = cast(elems[i].second); val = builder.create(loc, valAttr); } assert(val); @@ -597,7 +597,7 @@ SmallVector sparse_tensor::loadAll(OpBuilder &builder, Location loc, size_t size, Value mem, size_t offsetIdx, Value offsetVal) { #ifndef NDEBUG - const auto memTp = mem.getType().cast(); + const auto memTp = cast(mem.getType()); assert(memTp.getRank() == 1); const DynSize memSh = memTp.getDimSize(0); assert(ShapedType::isDynamic(memSh) || memSh >= static_cast(size)); @@ -619,7 +619,7 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs, size_t offsetIdx, Value offsetVal) { #ifndef NDEBUG const size_t vsize = vs.size(); - const auto memTp = mem.getType().cast(); + const auto memTp = cast(mem.getType()); assert(memTp.getRank() == 1); const DynSize memSh = memTp.getDimSize(0); assert(ShapedType::isDynamic(memSh) || memSh >= static_cast(vsize)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index e04475ea..9e76289 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -260,7 +260,7 @@ Value reshapeValuesToLevels(OpBuilder &builder, Location loc, /// `IntegerType`), this also works for `RankedTensorType` and `VectorType` /// (for which it generates a constant `DenseElementsAttr` of zeros). inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { - if (auto ctp = tp.dyn_cast()) { + if (auto ctp = dyn_cast(tp)) { auto zeroe = builder.getZeroAttr(ctp.getElementType()); auto zeroa = builder.getArrayAttr({zeroe, zeroe}); return builder.create(loc, tp, zeroa); @@ -271,7 +271,7 @@ inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { /// Generates a 1-valued constant of the given type. This supports all /// the same types as `constantZero`. inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { - if (auto ctp = tp.dyn_cast()) { + if (auto ctp = dyn_cast(tp)) { auto zeroe = builder.getZeroAttr(ctp.getElementType()); auto onee = getOneAttr(builder, ctp.getElementType()); auto zeroa = builder.getArrayAttr({onee, zeroe}); @@ -350,7 +350,7 @@ inline Value constantDimLevelTypeEncoding(OpBuilder &builder, Location loc, } inline bool isZeroRankedTensorOrScalar(Type type) { - auto rtp = type.dyn_cast(); + auto rtp = dyn_cast(type); return !rtp || rtp.getRank() == 0; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp index 731a1a9..d61e545 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -350,7 +350,7 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc, // on positions. for (TensorId t = 0, numTensors = getNumTensors(); t < numTensors; t++) { const Value tensor = tensors[t]; - const auto rtp = tensor.getType().dyn_cast(); + const auto rtp = dyn_cast(tensor.getType()); if (!rtp) // Skips only scalar, zero ranked tensor still need to be bufferized and // (probably) filled with zeros by users. 
@@ -432,7 +432,7 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc, Type indexType = builder.getIndexType(); Value c0 = constantZero(builder, loc, indexType); for (TensorId t = 0, e = tensors.size(); t < e; t++) { - auto rtp = tensors[t].getType().dyn_cast(); + auto rtp = dyn_cast(tensors[t].getType()); if (!rtp) continue; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h index 67a3f3d..0371578 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h @@ -415,11 +415,11 @@ private: // check `dstLvl < dstLvlRank` at the top; and only here need to // assert that `reassoc.size() == dstLvlRank`. assert(dstLvl < reassoc.size() && "Level is out-of-bounds"); - const auto srcLvls = reassoc[dstLvl].cast(); + const auto srcLvls = cast(reassoc[dstLvl]); return llvm::to_vector<2>( llvm::map_range(srcLvls, [&](Attribute srcLvl) -> Level { // TODO: replace this with the converter for `LevelAttr`. - return srcLvl.cast().getValue().getZExtValue(); + return cast(srcLvl).getValue().getZExtValue(); })); } return {dstLvl}; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp index c99c26b..bb52e08 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -100,7 +100,7 @@ static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc, /// completion. Needs to cast the buffer to a unranked buffer. static Value genHostRegisterMemref(OpBuilder &builder, Location loc, Value mem) { - MemRefType memTp = mem.getType().cast(); + MemRefType memTp = cast(mem.getType()); UnrankedMemRefType resTp = UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0); Value cast = builder.create(loc, resTp, mem); @@ -133,7 +133,7 @@ static void genBlockingWait(OpBuilder &builder, Location loc, /// that feature does not seem to be fully supported yet. 
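The `LoopEmitter.h` hunk maps an `ArrayAttr` of integer attributes to plain integers inside a lambda. The same idiom in free-function form, as a standalone sketch:

```cpp
#include "mlir/IR/BuiltinAttributes.h" // ArrayAttr, IntegerAttr
#include "mlir/Support/LLVM.h"         // SmallVector, cast
#include "llvm/ADT/SmallVector.h"
#include <cstdint>

using namespace mlir;

SmallVector<uint64_t> attrToIndices(ArrayAttr levels) {
  SmallVector<uint64_t> out;
  out.reserve(levels.size());
  for (Attribute lvl : levels) // ArrayAttr iterates over Attribute
    out.push_back(cast<IntegerAttr>(lvl).getValue().getZExtValue());
  return out;
}
```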
static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem, Value token) { - auto tp = mem.getType().cast(); + auto tp = cast(mem.getType()); auto elemTp = tp.getElementType(); auto shape = tp.getShape(); auto memTp = MemRefType::get(shape, elemTp); @@ -304,7 +304,7 @@ struct ForallRewriter : public OpRewritePattern { for (OpOperand &o : op->getOpOperands()) { Value val = o.get(); Block *block; - if (auto arg = val.dyn_cast()) + if (auto arg = dyn_cast(val)) block = arg.getOwner(); else block = val.getDefiningOp()->getBlock(); @@ -321,7 +321,7 @@ struct ForallRewriter : public OpRewritePattern { Type tp = val.getType(); if (val.getDefiningOp()) constants.push_back(val); - else if (tp.isa() || tp.isIntOrIndex()) + else if (isa(tp) || tp.isIntOrIndex()) scalars.push_back(val); else if (isa(tp)) buffers.push_back(val); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp index 0c68c4d..f34ed97 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -111,9 +111,9 @@ Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc, Value metaData = builder.create(loc, structType); SpecifierStructBuilder md(metaData); if (!source) { - auto memSizeArrayType = structType.cast() - .getBody()[kMemSizePosInSpecifier] - .cast(); + auto memSizeArrayType = + cast(cast(structType) + .getBody()[kMemSizePosInSpecifier]); Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); // Fill memSizes array with zero. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index aebf054..88f79bf 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -80,7 +80,7 @@ static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx) { idx = genCast(builder, loc, idx, builder.getIndexType()); val = genCast(builder, loc, val, - mem.getType().cast().getElementType()); + cast(mem.getType()).getElementType()); builder.create(loc, val, mem, idx); } @@ -253,7 +253,7 @@ static void createAllocFields(OpBuilder &builder, Location loc, case SparseTensorFieldKind::CrdMemRef: case SparseTensorFieldKind::ValMemRef: field = createAllocation( - builder, loc, fType.cast(), + builder, loc, cast(fType), (fKind == SparseTensorFieldKind::PosMemRef) ? posHeuristic : (fKind == SparseTensorFieldKind::CrdMemRef) ? crdHeuristic : valHeuristic, @@ -779,7 +779,7 @@ public: fields.reserve(desc.getNumFields()); // Memcpy on memref fields. for (auto field : desc.getMemRefFields()) { - auto memrefTp = field.getType().cast(); + auto memrefTp = cast(field.getType()); auto size = rewriter.create(loc, field, 0); auto copied = rewriter.create(loc, memrefTp, ValueRange{size}); @@ -1128,7 +1128,7 @@ public: auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource()); SmallVector fields; foreachFieldAndTypeInSparseTensor( - SparseTensorType(op.getResult().getType().cast()), + SparseTensorType(cast(op.getResult().getType())), [&rewriter, &fields, srcDesc, loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl, DimLevelType /*dlt*/) -> bool { @@ -1143,7 +1143,7 @@ public: // values. 
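The `SparseStorageSpecifierToLLVM.cpp` hunk shows how nested casts change shape: a left-to-right method chain becomes wrapped calls that read outside-in. A sketch, assuming the struct field is an `LLVM::LLVMArrayType` (the concrete target type is one of the angle-bracket arguments lost in this rendering):

```cpp
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" // LLVM::LLVMStructType, LLVM::LLVMArrayType
#include "mlir/Support/LLVM.h"

using namespace mlir;

// Before: structType.cast<LLVM::LLVMStructType>().getBody()[i]
//             .cast<LLVM::LLVMArrayType>()
// After: the outermost cast now leads the expression.
LLVM::LLVMArrayType structFieldAsArray(Type structType, unsigned i) {
  return cast<LLVM::LLVMArrayType>(
      cast<LLVM::LLVMStructType>(structType).getBody()[i]);
}
```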
Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0); auto dstMem = rewriter.create( - loc, fTp.cast(), sz); + loc, cast(fTp), sz); if (fTp != srcMem.getType()) { // Converts elements type. scf::buildLoopNest( @@ -1397,7 +1397,7 @@ struct SparsePackOpConverter : public OpConversionPattern { } assert(field); - if (auto memrefTp = field.getType().dyn_cast(); + if (auto memrefTp = dyn_cast(field.getType()); memrefTp && memrefTp.getRank() > 1) { ReassociationIndices reassociation; for (int i = 0, e = memrefTp.getRank(); i < e; i++) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 8d0c854..906f700 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -399,7 +399,7 @@ static void genAddEltCall(OpBuilder &builder, Location loc, Type eltType, /// (which can be either dim- or lvl-coords, depending on context). static Value genGetNextCall(OpBuilder &builder, Location loc, Value iter, Value coords, Value elemPtr) { - Type elemTp = elemPtr.getType().cast().getElementType(); + Type elemTp = cast(elemPtr.getType()).getElementType(); SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)}; SmallVector params{iter, coords, elemPtr}; Type i1 = builder.getI1Type(); @@ -1045,7 +1045,7 @@ public: matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resTp = op.getType(); - Type posTp = resTp.cast().getElementType(); + Type posTp = cast(resTp).getElementType(); SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)}; Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel()); replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl}, @@ -1064,7 +1064,7 @@ public: ConversionPatternRewriter &rewriter) const override { // TODO: use `SparseTensorType::getCrdType` instead. Type resType = op.getType(); - const Type crdTp = resType.cast().getElementType(); + const Type crdTp = cast(resType).getElementType(); SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)}; Location loc = op->getLoc(); @@ -1096,7 +1096,7 @@ public: LogicalResult matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto resType = op.getType().cast(); + auto resType = cast(op.getType()); rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType, adaptor.getOperands())); return success(); @@ -1113,7 +1113,7 @@ public: ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); // Query values array size for the actually stored values size. - Type eltType = op.getTensor().getType().cast().getElementType(); + Type eltType = cast(op.getTensor().getType()).getElementType(); auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType); Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands()); rewriter.replaceOpWithNewOp(op, values, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 2a4bbb0..ca27794 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -79,7 +79,7 @@ static bool isSampling(GenericOp op) { // Helper to detect chain of multiplications that do not involve x. 
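The `SparsePackOpConverter` hunk migrates a cast inside a C++17 if-with-initializer; only the initializer changes, and the condition keeps using the bound name. Sketch:

```cpp
#include "mlir/IR/BuiltinTypes.h" // MemRefType
#include "mlir/Support/LLVM.h"

using namespace mlir;

bool isMultiDimMemRef(Type t) {
  // Before: if (auto memrefTp = t.dyn_cast<MemRefType>(); memrefTp && ...)
  if (auto memrefTp = dyn_cast<MemRefType>(t);
      memrefTp && memrefTp.getRank() > 1)
    return true;
  return false;
}
```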
static bool isMulChain(Value val, Value x) { - if (auto arg = val.dyn_cast()) + if (auto arg = dyn_cast(val)) return arg != x; if (auto *def = val.getDefiningOp()) { if (isa(def) || isa(def)) @@ -105,7 +105,7 @@ static bool isSumOfMul(GenericOp op) { // Helper to detect direct yield of a zero value. static bool isZeroYield(GenericOp op) { auto yieldOp = cast(op.getRegion().front().getTerminator()); - if (auto arg = yieldOp.getOperand(0).dyn_cast()) { + if (auto arg = dyn_cast(yieldOp.getOperand(0))) { if (arg.getOwner()->getParentOp() == op) { return isZeroValue(op->getOperand(arg.getArgNumber())); } @@ -719,7 +719,7 @@ private: bool fromSparseConst = false; if (auto constOp = op.getSource().getDefiningOp()) { - if (constOp.getValue().dyn_cast()) { + if (dyn_cast(constOp.getValue())) { fromSparseConst = true; } } @@ -972,7 +972,7 @@ public: // Special-case: for each over a sparse constant uses its own rewriting // rule. if (auto constOp = input.getDefiningOp()) { - if (auto attr = constOp.getValue().dyn_cast()) { + if (auto attr = dyn_cast(constOp.getValue())) { return genForeachOnSparseConstant(op, rewriter, attr); } } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h index 788ad28..a51fcc5 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -450,7 +450,7 @@ inline Value genTuple(OpBuilder &builder, Location loc, inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { auto tuple = getTuple(tensor); - SparseTensorType stt(tuple.getResultTypes()[0].cast()); + SparseTensorType stt(cast(tuple.getResultTypes()[0])); return SparseTensorDescriptor(stt, tuple.getInputs()); } @@ -458,7 +458,7 @@ inline MutSparseTensorDescriptor getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl &fields) { auto tuple = getTuple(tensor); fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); - SparseTensorType stt(tuple.getResultTypes()[0].cast()); + SparseTensorType stt(cast(tuple.getResultTypes()[0])); return MutSparseTensorDescriptor(stt, fields); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index afeabb3..681ba21 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -88,9 +88,9 @@ public: // Overrides method from AffineExprVisitor. void visitDimExpr(AffineDimExpr expr) { if (pickedDim == nullptr || - pickIterType == iterTypes[expr.getPosition()] - .cast() - .getValue()) { + pickIterType == + cast(iterTypes[expr.getPosition()]) + .getValue()) { pickedDim = expr; } } @@ -344,7 +344,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map, // we can't use `getRankedTensorType`/`getSparseTensorType` here. // However, we don't need to handle `StorageSpecifierType`, so we // can use `SparseTensorType` once we guard against non-tensors. 
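`Value` participates in the same scheme as `Type` and `Attribute`: the `isMulChain` and `isZeroYield` hunks turn `val.dyn_cast<BlockArgument>()` into `dyn_cast<BlockArgument>(val)`. A sketch of the recurring "block argument or defining op" dispatch:

```cpp
#include "mlir/IR/Operation.h" // Operation, Block
#include "mlir/IR/Value.h"     // Value, BlockArgument
#include "mlir/Support/LLVM.h"

using namespace mlir;

// The block a value lives in: its owner block for a block argument,
// otherwise the block of its defining operation.
Block *owningBlock(Value val) {
  if (auto arg = dyn_cast<BlockArgument>(val))
    return arg.getOwner();
  return val.getDefiningOp()->getBlock();
}
```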
- const auto rtp = tensor.getType().dyn_cast(); + const auto rtp = dyn_cast(tensor.getType()); if (!rtp) return 0; const SparseTensorType stt(rtp); @@ -1243,7 +1243,7 @@ static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopOrd at, Location loc = op.getLoc(); if (atStart) { auto dynShape = {ShapedType::kDynamic}; - Type etp = tensor.getType().cast().getElementType(); + Type etp = cast(tensor.getType()).getElementType(); Type t1 = MemRefType::get(dynShape, etp); Type t2 = MemRefType::get(dynShape, builder.getI1Type()); Type t3 = MemRefType::get(dynShape, builder.getIndexType()); @@ -1833,7 +1833,7 @@ public: // required for sparse tensor slice rank reducing too. Level maxLvlRank = 0; for (auto operand : op.getOperands()) { - if (auto rtp = operand.getType().dyn_cast()) { + if (auto rtp = dyn_cast(operand.getType())) { maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank()); } } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index e8558c1..ae31af0 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -1061,8 +1061,8 @@ bool Merger::maybeZero(ExprId e) const { if (expr.kind == TensorExp::Kind::kInvariant) { if (auto c = expr.val.getDefiningOp()) { ArrayAttr arrayAttr = c.getValue(); - return arrayAttr[0].cast().getValue().isZero() && - arrayAttr[1].cast().getValue().isZero(); + return cast(arrayAttr[0]).getValue().isZero() && + cast(arrayAttr[1]).getValue().isZero(); } if (auto c = expr.val.getDefiningOp()) return c.value() == 0; @@ -1077,7 +1077,7 @@ Type Merger::inferType(ExprId e, Value src) const { Type dtp = exp(e).val.getType(); // Inspect source type. For vector types, apply the same // vectorization to the destination type. - if (auto vtp = src.getType().dyn_cast()) + if (auto vtp = dyn_cast(src.getType())) return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims()); return dtp; } @@ -1085,7 +1085,7 @@ Type Merger::inferType(ExprId e, Value src) const { /// Ensures that sparse compiler can generate code for expression. static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) { // Arguments are always admissible. - if (v.isa()) + if (isa(v)) return true; // Accept index anywhere. 
Operation *def = v.getDefiningOp(); @@ -1113,7 +1113,7 @@ static bool isAdmissibleBranch(Operation *op, Region ®ion) { } std::optional Merger::buildTensorExp(linalg::GenericOp op, Value v) { - if (auto arg = v.dyn_cast()) { + if (auto arg = dyn_cast(v)) { const TensorId tid = makeTensorId(arg.getArgNumber()); // Any argument of the generic op that is not marked as a scalar // argument is considered a tensor, indexed by the implicit loop @@ -1346,8 +1346,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, case TensorExp::Kind::kAbsF: return rewriter.create(loc, v0); case TensorExp::Kind::kAbsC: { - auto type = v0.getType().cast(); - auto eltType = type.getElementType().cast(); + auto type = cast(v0.getType()); + auto eltType = cast(type.getElementType()); return rewriter.create(loc, eltType, v0); } case TensorExp::Kind::kAbsI: @@ -1407,13 +1407,13 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0, case TensorExp::Kind::kTruncI: return rewriter.create(loc, inferType(e, v0), v0); case TensorExp::Kind::kCIm: { - auto type = v0.getType().cast(); - auto eltType = type.getElementType().cast(); + auto type = cast(v0.getType()); + auto eltType = cast(type.getElementType()); return rewriter.create(loc, eltType, v0); } case TensorExp::Kind::kCRe: { - auto type = v0.getType().cast(); - auto eltType = type.getElementType().cast(); + auto type = cast(v0.getType()); + auto eltType = cast(type.getElementType()); return rewriter.create(loc, eltType, v0); } case TensorExp::Kind::kBitCast: diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 57e5df4..d93d886 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -60,20 +60,20 @@ struct CastOpInterface // type in case the input is an unranked tensor type. // Case 1: Casting an unranked tensor - if (castOp.getSource().getType().isa()) { + if (isa(castOp.getSource().getType())) { // When casting to a ranked tensor, we cannot infer any static offset or // strides from the source. Assume fully dynamic. return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace); } // Case 2: Casting to an unranked tensor type - if (castOp.getType().isa()) { + if (isa(castOp.getType())) { return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace); } // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not // change. - auto rankedResultType = castOp.getType().cast(); + auto rankedResultType = cast(castOp.getType()); return MemRefType::get( rankedResultType.getShape(), rankedResultType.getElementType(), maybeSrcBufferType->cast().getLayout(), memorySpace); @@ -158,7 +158,7 @@ struct CollapseShapeOpInterface if (failed(maybeBuffer)) return failure(); Value buffer = *maybeBuffer; - auto bufferType = buffer.getType().cast(); + auto bufferType = cast(buffer.getType()); if (tensorResultType.getRank() == 0) { // 0-d collapses must go through a different op builder. 
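The `Merger::buildExp` hunks extract the float element type of a complex value in two asserting steps. The free-function spelling of that pair, assuming the builtin `ComplexType`/`FloatType` targets:

```cpp
#include "mlir/IR/BuiltinTypes.h" // ComplexType, FloatType
#include "mlir/Support/LLVM.h"

using namespace mlir;

FloatType complexElementFloatType(Type t) {
  auto complexTy = cast<ComplexType>(t);              // asserts t is complex
  return cast<FloatType>(complexTy.getElementType()); // asserts elem is float
}
```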
@@ -383,11 +383,9 @@ struct ExtractSliceOpInterface SmallVector mixedOffsets = extractSliceOp.getMixedOffsets(); SmallVector mixedSizes = extractSliceOp.getMixedSizes(); SmallVector mixedStrides = extractSliceOp.getMixedStrides(); - return memref::SubViewOp::inferRankReducedResultType( - extractSliceOp.getType().getShape(), - srcMemrefType->cast(), mixedOffsets, mixedSizes, - mixedStrides) - .cast(); + return cast(memref::SubViewOp::inferRankReducedResultType( + extractSliceOp.getType().getShape(), srcMemrefType->cast(), + mixedOffsets, mixedSizes, mixedStrides)); } }; @@ -459,7 +457,7 @@ struct FromElementsOpInterface auto fromElementsOp = cast(op); // Should the buffer be deallocated? bool dealloc = shouldDeallocateOpResult( - fromElementsOp.getResult().cast(), options); + cast(fromElementsOp.getResult()), options); // TODO: Implement memory space for this op. if (options.defaultMemorySpace != Attribute()) @@ -467,7 +465,7 @@ struct FromElementsOpInterface // Allocate a buffer for the result. Location loc = op->getLoc(); - auto tensorType = fromElementsOp.getType().cast(); + auto tensorType = cast(fromElementsOp.getType()); auto shape = tensorType.getShape(); // TODO: Create alloc_tensor ops during TensorCopyInsertion. FailureOr tensorAlloc = @@ -540,7 +538,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc, ValueRange dynamicSizes, Region &generateBody) { assert(generateBody.hasOneBlock() && "expected body with single block"); - auto tensorType = tensorDestination.getType().cast(); + auto tensorType = cast(tensorDestination.getType()); assert(generateBody.getNumArguments() == tensorType.getRank() && "rank mismatch"); @@ -579,7 +577,7 @@ struct GenerateOpInterface auto generateOp = cast(op); // Should the buffer be deallocated? bool dealloc = shouldDeallocateOpResult( - generateOp.getResult().cast(), options); + cast(generateOp.getResult()), options); // TODO: Implement memory space for this op. if (options.defaultMemorySpace != Attribute()) @@ -800,12 +798,11 @@ struct InsertSliceOpInterface return failure(); // Take a subview of the destination buffer. - auto dstMemrefType = dstMemref->getType().cast(); + auto dstMemrefType = cast(dstMemref->getType()); auto subviewMemRefType = - memref::SubViewOp::inferRankReducedResultType( + cast(memref::SubViewOp::inferRankReducedResultType( insertSliceOp.getSourceType().getShape(), dstMemrefType, - mixedOffsets, mixedSizes, mixedStrides) - .cast(); + mixedOffsets, mixedSizes, mixedStrides)); Value subView = rewriter.create( loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, mixedStrides); @@ -900,7 +897,7 @@ struct PadOpInterface // Should the buffer be deallocated? bool dealloc = - shouldDeallocateOpResult(padOp.getResult().cast(), options); + shouldDeallocateOpResult(cast(padOp.getResult()), options); // Allocate a buffer for the padded result. FailureOr tensorAlloc = allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), @@ -992,7 +989,7 @@ struct ReshapeOpInterface return failure(); auto resultMemRefType = getMemRefType( reshapeOp.getResult(), options, /*layout=*/{}, - srcBuffer->getType().cast().getMemorySpace()); + cast(srcBuffer->getType()).getMemorySpace()); replaceOpWithNewBufferizedOp( rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer); return success(); @@ -1039,14 +1036,13 @@ struct ParallelInsertSliceOpInterface return failure(); // Take a subview of the destination buffer. 
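The `ExtractSliceOpInterface` and `InsertSliceOpInterface` hunks show the main readability win: a cast that used to dangle after a multi-line call now wraps the whole expression. A schematic sketch, with the long `inferRankReducedResultType` call reduced to a plain `inferred` parameter:

```cpp
#include "mlir/IR/BuiltinTypes.h" // MemRefType
#include "mlir/Support/LLVM.h"

using namespace mlir;

MemRefType resultAsMemRef(Type inferred) {
  // Before:
  //   return inferred
  //       .cast<MemRefType>();
  // After: the target type is visible up front, not at the end of the chain.
  return cast<MemRefType>(inferred);
}
```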
- auto destBufferType = destBuffer->getType().cast(); + auto destBufferType = cast(destBuffer->getType()); auto subviewMemRefType = - memref::SubViewOp::inferRankReducedResultType( + cast(memref::SubViewOp::inferRankReducedResultType( parallelInsertSliceOp.getSourceType().getShape(), destBufferType, parallelInsertSliceOp.getMixedOffsets(), parallelInsertSliceOp.getMixedSizes(), - parallelInsertSliceOp.getMixedStrides()) - .cast(); + parallelInsertSliceOp.getMixedStrides())); Value subview = rewriter.create( parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer, parallelInsertSliceOp.getMixedOffsets(), diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp index b5e75e0..968d68e 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp @@ -29,7 +29,7 @@ using namespace mlir::tensor; /// Get the dimension size of a value of RankedTensor type at the static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, Value rankedTensor, int64_t dimIdx) { - RankedTensorType tensorType = rankedTensor.getType().cast(); + RankedTensorType tensorType = cast(rankedTensor.getType()); if (!tensorType.isDynamicDim(dimIdx)) { return b.getIndexAttr(tensorType.getDimSize(dimIdx)); } @@ -41,7 +41,7 @@ static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, static SmallVector getShapeDimSizes(OpBuilder &b, Location loc, Value rankedTensor) { SmallVector dimSizes; - RankedTensorType tensorType = rankedTensor.getType().cast(); + RankedTensorType tensorType = cast(rankedTensor.getType()); for (unsigned i = 0; i < tensorType.getRank(); i++) dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, i)); return dimSizes; diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp index 71dddd1..4ecb800 100644 --- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp @@ -44,7 +44,7 @@ PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source, SmallVector mlir::tensor::createDynamicDimValues(OpBuilder &b, Location loc, Value rankedTensor) { - auto tensorTy = rankedTensor.getType().cast(); + auto tensorTy = cast(rankedTensor.getType()); SmallVector dynamicDims; for (const auto &en : llvm::enumerate(tensorTy.getShape())) { if (en.value() == ShapedType::kDynamic) @@ -57,7 +57,7 @@ SmallVector mlir::tensor::createDynamicDimValues(OpBuilder &b, FailureOr mlir::tensor::createDimValue(OpBuilder &b, Location loc, Value rankedTensor, int64_t dim) { - auto tensorTy = rankedTensor.getType().dyn_cast(); + auto tensorTy = dyn_cast(rankedTensor.getType()); if (!tensorTy) return failure(); auto shape = tensorTy.getShape(); @@ -70,7 +70,7 @@ FailureOr mlir::tensor::createDimValue(OpBuilder &b, Location loc, SmallVector mlir::tensor::createDimValues(OpBuilder &b, Location loc, Value rankedTensor) { - auto tensorTy = rankedTensor.getType().cast(); + auto tensorTy = cast(rankedTensor.getType()); SmallVector dims; for (const auto &en : llvm::enumerate(tensorTy.getShape())) { if (ShapedType::isDynamic(en.value())) { diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp index 7b47338..44f64f7 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp @@ -34,9 +34,9 @@ struct Conv2DIsFullyConnected : public 
OpRewritePattern { PatternRewriter &rewriter) const override { Value input = op.getInput(); Value weight = op.getWeight(); - ShapedType inputType = input.getType().cast(); - ShapedType weightType = weight.getType().cast(); - ShapedType resultType = op.getType().cast(); + ShapedType inputType = cast(input.getType()); + ShapedType weightType = cast(weight.getType()); + ShapedType resultType = cast(op.getType()); auto numDynamic = llvm::count_if(inputType.getShape(), ShapedType::isDynamic); @@ -66,7 +66,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern { auto quantizationInfo = op.getQuantizationInfo(); int64_t iZp = quantizationInfo->getInputZp(); - if (!validIntegerRange(inputETy.cast(), iZp)) + if (!validIntegerRange(cast(inputETy), iZp)) return rewriter.notifyMatchFailure( op, "tosa.conv op quantization has zp outside of input range"); @@ -116,7 +116,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern { weightShape[3]}; auto revisedWeightShapeType = RankedTensorType::get( revisedWeightShape, - weight.getType().dyn_cast().getElementType()); + dyn_cast(weight.getType()).getElementType()); auto reshapedWeight = rewriter .create( op.getLoc(), revisedWeightShapeType, weight, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index 81ec7fd..488e46d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -28,9 +28,9 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { PatternRewriter &rewriter) const override { Value input = op.getInput(); Value weight = op.getWeight(); - ShapedType inputType = input.getType().cast(); - ShapedType weightType = weight.getType().cast(); - ShapedType resultType = op.getOutput().getType().cast(); + ShapedType inputType = cast(input.getType()); + ShapedType weightType = cast(weight.getType()); + ShapedType resultType = cast(op.getOutput().getType()); if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && resultType.hasStaticShape())) { @@ -52,7 +52,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1}; inputType = RankedTensorType::get( revisedInputShape, - input.getType().dyn_cast().getElementType()); + dyn_cast(input.getType()).getElementType()); input = rewriter .create( op.getLoc(), inputType, input, @@ -76,7 +76,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { auto applyZp = [&](Value val, int64_t zp) -> Value { if (zp == 0) return val; - auto ety = val.getType().cast().getElementType(); + auto ety = cast(val.getType()).getElementType(); auto zpTy = RankedTensorType::get({}, ety); auto zpAttr = DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp)); @@ -126,17 +126,17 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]}; auto mulShapeType = RankedTensorType::get( mulShape, - weight.getType().dyn_cast().getElementType()); + dyn_cast(weight.getType()).getElementType()); Value mulValue = rewriter .create(op.getLoc(), mulShapeType, input, weight, /*shift=*/0) .getResult(); // Reshape output to [N, H, W, C * M]. 
- auto outputShape = op.getOutput().getType().cast().getShape(); + auto outputShape = cast(op.getOutput().getType()).getShape(); auto outputShapeType = RankedTensorType::get( outputShape, - input.getType().dyn_cast().getElementType()); + dyn_cast(input.getType()).getElementType()); auto outputValue = rewriter.create( op.getLoc(), outputShapeType, mulValue, rewriter.getDenseI64ArrayAttr(outputShape)); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index 74533de..87563c1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -56,7 +56,7 @@ TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy, // Compute the knowledge based on the inferred type. auto inferredKnowledge = mlir::tosa::ValueKnowledge::getPessimisticValueState(); - inferredKnowledge.dtype = resultTy.cast().getElementType(); + inferredKnowledge.dtype = cast(resultTy).getElementType(); inferredKnowledge.hasRank = predictedShape.hasRank(); if (predictedShape.hasRank()) { for (auto dim : predictedShape.getDims()) { @@ -83,10 +83,10 @@ public: Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = input.getType().cast(); - ShapedType weightTy = weight.getType().cast(); - ShapedType biasTy = bias.getType().cast(); - ShapedType resultTy = op->getResult(0).getType().cast(); + ShapedType inputTy = cast(input.getType()); + ShapedType weightTy = cast(weight.getType()); + ShapedType biasTy = cast(bias.getType()); + ShapedType resultTy = cast(op->getResult(0).getType()); llvm::ArrayRef stride = op.getStride(); llvm::ArrayRef pad = op.getOutPad(); @@ -146,10 +146,10 @@ public: Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = input.getType().cast(); - ShapedType weightTy = weight.getType().cast(); - ShapedType biasTy = bias.getType().cast(); - ShapedType resultTy = op->getResult(0).getType().cast(); + ShapedType inputTy = cast(input.getType()); + ShapedType weightTy = cast(weight.getType()); + ShapedType biasTy = cast(bias.getType()); + ShapedType resultTy = cast(op->getResult(0).getType()); Type inputETy = inputTy.getElementType(); Type weightETy = weightTy.getElementType(); @@ -202,7 +202,7 @@ public: weight, weightPaddingVal); } - weightTy = weight.getType().cast(); + weightTy = cast(weight.getType()); weightHeight = weightTy.getDimSize(1); weightWidth = weightTy.getDimSize(2); @@ -231,7 +231,7 @@ public: weight = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(weightETy), weight, rewriter.getDenseI64ArrayAttr(weightReshapeDims1)); - ShapedType restridedWeightTy = weight.getType().cast(); + ShapedType restridedWeightTy = cast(weight.getType()); weight = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(weightETy), weight, @@ -297,7 +297,7 @@ public: } // Factor the resulting width / height. 
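The TOSA decompositions above repeatedly pull an element type out of an operand and rebuild a ranked tensor type around a new shape. That recurring pair, in the new spelling:

```cpp
#include "mlir/IR/BuiltinTypes.h" // ShapedType, RankedTensorType
#include "mlir/Support/LLVM.h"
#include <cstdint>

using namespace mlir;

RankedTensorType withNewShape(Type operandTy, ArrayRef<int64_t> newShape) {
  // Before: operandTy.cast<ShapedType>().getElementType()
  Type elemTy = cast<ShapedType>(operandTy).getElementType();
  return RankedTensorType::get(newShape, elemTy);
}
```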
- ShapedType convTy = conv2d.getType().cast(); + ShapedType convTy = cast(conv2d.getType()); Type convETy = convTy.getElementType(); int64_t convHeight = convTy.getDimSize(1); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp index 9e2102e..302e279 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp @@ -72,7 +72,7 @@ DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType, auto baseType = inputType.getElementType(); // Handle possible integer types - if (auto intType = baseType.dyn_cast()) { + if (auto intType = dyn_cast(baseType)) { switch (intType.getWidth()) { case 1: return transposeType(attr, inputType, outputType, permValues); @@ -102,7 +102,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override { - auto outputType = op.getType().cast(); + auto outputType = cast(op.getType()); // TOSA supports quantized types. if (!outputType.getElementType().isIntOrIndexOrFloat()) return failure(); @@ -122,7 +122,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern { permAttr.getValues(), [](const APInt &val) { return val.getSExtValue(); })); - auto inputType = op.getInput1().getType().cast(); + auto inputType = cast(op.getInput1().getType()); auto resultAttr = transpose(inputValues, inputType, outputType, permValues); rewriter.replaceOpWithNewOp(op, outputType, resultAttr); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp index 0c03cec..3e2da9d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp @@ -54,7 +54,7 @@ void propagateShapesToTosaIf( for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) { auto inferredTy = shapesStorage[op.getOperand(i)]; auto blockArg = frontBlock.getArgument(i - 1); - auto oldType = blockArg.getType().cast(); + auto oldType = cast(blockArg.getType()); if (inferredTy.hasRank()) { Type newType = oldType.clone(inferredTy.getDims()); @@ -89,7 +89,7 @@ void propagateShapesToTosaWhile( // loop body / condition for tosa.while. llvm::SmallVector argTypes; for (auto operand : op.getOperands()) { - auto operandTy = operand.getType().cast(); + auto operandTy = cast(operand.getType()); auto shapedTypeComponent = shapesStorage[operand]; if (shapedTypeComponent.hasRank()) { auto newTy = operandTy.clone(shapedTypeComponent.getDims()); @@ -188,7 +188,7 @@ void propagateShapesToTosaWhile( void propagateShapesInRegion(Region ®ion) { DenseMap shapesStorage; auto setShapes = [&](Value val, Type t) { - if (auto st = t.dyn_cast()) + if (auto st = dyn_cast(t)) shapesStorage[val] = st; else shapesStorage[val] = t; @@ -247,8 +247,7 @@ void propagateShapesInRegion(Region ®ion) { // Compute the knowledge based on the inferred type. 
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); - inferredKnowledge.dtype = - resultTy.cast().getElementType(); + inferredKnowledge.dtype = cast(resultTy).getElementType(); inferredKnowledge.hasRank = predictedShape.hasRank(); if (predictedShape.hasRank()) { for (auto dim : predictedShape.getDims()) { @@ -274,7 +273,7 @@ void propagateShapesInRegion(Region ®ion) { for (auto it : shapesStorage) { auto result = it.second; if (result.hasRank()) { - Type t = it.first.getType().cast().clone(result.getDims()); + Type t = cast(it.first.getType()).clone(result.getDims()); it.first.setType(t); } } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp index b18e3b4..bcfcbbb 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -82,8 +82,8 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, RankedTensorType outputType, Value &input1, Value &input2) { - auto input1Ty = input1.getType().dyn_cast(); - auto input2Ty = input2.getType().dyn_cast(); + auto input1Ty = dyn_cast(input1.getType()); + auto input2Ty = dyn_cast(input2.getType()); if (!input1Ty || !input2Ty) { return rewriter.notifyMatchFailure(loc, "input not a ranked tensor"); @@ -106,9 +106,9 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, } ArrayRef higherRankShape = - higherTensorValue.getType().cast().getShape(); + cast(higherTensorValue.getType()).getShape(); ArrayRef lowerRankShape = - lowerTensorValue.getType().cast().getShape(); + cast(lowerTensorValue.getType()).getShape(); SmallVector reshapeOutputShape; @@ -116,7 +116,7 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, .failed()) return rewriter.notifyMatchFailure(loc, "fail to compute a reshape type"); - auto reshapeInputType = lowerTensorValue.getType().cast(); + auto reshapeInputType = cast(lowerTensorValue.getType()); auto reshapeOutputType = RankedTensorType::get( ArrayRef(reshapeOutputShape), reshapeInputType.getElementType()); @@ -155,7 +155,7 @@ struct ConvertTosaOp : public OpRewritePattern { Value input2 = tosaBinaryOp.getInput2(); Value output = tosaBinaryOp.getResult(); - auto outputType = output.getType().dyn_cast(); + auto outputType = dyn_cast(output.getType()); if (!outputType) return failure(); @@ -183,7 +183,7 @@ struct ConvertTosaOp : public OpRewritePattern { Value input2 = tosaBinaryOp.getInput2(); int32_t shift = tosaBinaryOp.getShift(); Value output = tosaBinaryOp.getResult(); - auto outputType = output.getType().dyn_cast(); + auto outputType = dyn_cast(output.getType()); if (!outputType) return failure(); @@ -214,7 +214,7 @@ struct ConvertTosaOp Value input2 = tosaBinaryOp.getInput2(); int32_t round = tosaBinaryOp.getRound(); Value output = tosaBinaryOp.getResult(); - auto outputType = output.getType().dyn_cast(); + auto outputType = dyn_cast(output.getType()); if (!outputType) return failure(); @@ -242,7 +242,7 @@ struct ConvertTosaOp : public OpRewritePattern { Value input3 = tosaOp.getOnFalse(); Value output = tosaOp.getResult(); - auto outputType = output.getType().dyn_cast(); + auto outputType = dyn_cast(output.getType()); if (!outputType) return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor"); @@ -265,9 +265,9 @@ struct ConvertTosaOp : public OpRewritePattern { tosaOp, "cannot rewrite as the rank of all operands is already aligned"); - int32_t result1Rank = 
input1.getType().cast().getRank(); - int32_t result2Rank = input2.getType().cast().getRank(); - int32_t result3Rank = input3.getType().cast().getRank(); + int32_t result1Rank = cast(input1.getType()).getRank(); + int32_t result2Rank = cast(input2.getType()).getRank(); + int32_t result3Rank = cast(input3.getType()).getRank(); if ((result1Rank != result2Rank) || (result2Rank != result3Rank)) return rewriter.notifyMatchFailure( diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 4cb727b..5605080 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -106,7 +106,7 @@ void TosaValidation::runOnOperation() { getOperation().walk([&](Operation *op) { for (Value operand : op->getOperands()) { if ((profileType == TosaProfileEnum::BaseInference) && - getElementTypeOrSelf(operand).isa()) { + isa(getElementTypeOrSelf(operand))) { return signalPassFailure(); } if (getElementTypeOrSelf(operand).isF64()) { diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp index 0b5fc45..1c4ae1f 100644 --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -116,16 +116,16 @@ ConvOpQuantizationAttr mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight) { - auto inputType = input.getType().dyn_cast(); - auto weightType = weight.getType().dyn_cast(); + auto inputType = dyn_cast(input.getType()); + auto weightType = dyn_cast(weight.getType()); if (!inputType || !weightType) return nullptr; auto inputQType = GET_UQTYPE(inputType); auto weightPerTensorQType = GET_UQTYPE(weightType); - auto weightPerAxisQType = weightType.getElementType() - .dyn_cast(); + auto weightPerAxisQType = + dyn_cast(weightType.getElementType()); // Weights must be either per-tensor quantized or per-axis quantized. 
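The `QuantUtils.cpp` hunks test for per-tensor versus per-axis quantized element types with `dyn_cast`. A sketch of the per-axis test, assuming the quant dialect's `UniformQuantizedPerAxisType` and its header location at the time of this patch:

```cpp
#include "mlir/Dialect/Quant/QuantTypes.h" // quant::UniformQuantizedPerAxisType
#include "mlir/IR/BuiltinTypes.h"          // ShapedType
#include "mlir/Support/LLVM.h"

using namespace mlir;

// Before: weightType.getElementType()
//             .dyn_cast<quant::UniformQuantizedPerAxisType>()
bool isPerAxisQuantized(ShapedType weightType) {
  return isa<quant::UniformQuantizedPerAxisType>(weightType.getElementType());
}
```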
assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) && @@ -160,8 +160,8 @@ MatMulOpQuantizationAttr mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b) { - auto aType = a.getType().dyn_cast(); - auto bType = b.getType().dyn_cast(); + auto aType = dyn_cast(a.getType()); + auto bType = dyn_cast(b.getType()); if (!aType || !bType) return nullptr; @@ -189,8 +189,8 @@ UnaryOpQuantizationAttr mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType) { - auto inputType = input.getType().dyn_cast(); - auto outputType = outputRawType.dyn_cast(); + auto inputType = dyn_cast(input.getType()); + auto outputType = dyn_cast(outputRawType); if (!inputType || !outputType) return nullptr; @@ -215,7 +215,7 @@ mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder, Value input) { - auto inputType = input.getType().dyn_cast(); + auto inputType = dyn_cast(input.getType()); if (!inputType) return nullptr; @@ -235,8 +235,8 @@ PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder, Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight) { - auto inputType = input.getType().dyn_cast(); - auto weightType = weight.getType().dyn_cast(); + auto inputType = dyn_cast(input.getType()); + auto weightType = dyn_cast(weight.getType()); assert(inputType && weightType && "Could not extract input or weight tensors from Conv op"); @@ -250,7 +250,7 @@ Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); unsigned weightBits = weightQType.getStorageTypeIntegralWidth(); - auto outputShapedType = outputType.dyn_cast(); + auto outputShapedType = dyn_cast(outputType); assert(outputShapedType && "Could not extract output shape type from Conv op"); @@ -274,8 +274,8 @@ Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType, auto convfunc = quant::ExpressedToQuantizedConverter::forInputType(inputDType); - auto minElems = minAttr.dyn_cast(); - auto maxElems = maxAttr.dyn_cast(); + auto minElems = dyn_cast(minAttr); + auto maxElems = dyn_cast(maxAttr); SmallVector min, max; @@ -291,12 +291,12 @@ Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType, for (auto i : maxElems) max.push_back(FloatAttr::getValueAsDouble(i)); } else { // Just a single FP value. - auto minVal = minAttr.dyn_cast(); + auto minVal = dyn_cast(minAttr); if (minVal) min.push_back(minVal.getValueAsDouble()); else return {}; - auto maxVal = maxAttr.dyn_cast(); + auto maxVal = dyn_cast(maxAttr); if (maxVal) max.push_back(maxVal.getValueAsDouble()); else @@ -309,7 +309,7 @@ Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType, builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. - auto shape = inputDType.dyn_cast(); + auto shape = dyn_cast(inputDType); if (!shape) return {}; if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp index a1a3032..2ae67dc 100644 --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -116,7 +116,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef shape1, /// Returns the shape of the given type. 
Scalars will be considered as having a /// shape with zero dimensions. static ArrayRef getShape(Type type) { - if (auto sType = type.dyn_cast()) + if (auto sType = dyn_cast(type)) return sType.getShape(); return {}; } @@ -142,8 +142,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2, // If one of the types is unranked tensor, then the other type shouldn't be // vector and the result should have unranked tensor type. - if (type1.isa() || type2.isa()) { - if (type1.isa() || type2.isa()) + if (isa(type1) || isa(type2)) { + if (isa(type1) || isa(type2)) return {}; return UnrankedTensorType::get(elementType); } @@ -151,7 +151,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2, // Returns the type kind if the given type is a vector or ranked tensor type. // Returns std::nullopt otherwise. auto getCompositeTypeKind = [](Type type) -> std::optional { - if (type.isa()) + if (isa(type)) return type.getTypeID(); return std::nullopt; }; @@ -189,8 +189,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2, template static std::tuple hasTensorOrVectorType(iterator_range types) { return std::make_tuple( - llvm::any_of(types, [](Type t) { return t.isa(); }), - llvm::any_of(types, [](Type t) { return t.isa(); })); + llvm::any_of(types, [](Type t) { return isa(t); }), + llvm::any_of(types, [](Type t) { return isa(t); })); } static bool isCompatibleInferredReturnShape(ArrayRef inferred, @@ -242,7 +242,7 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) { return op->emitError("cannot broadcast vector with tensor"); auto rankedOperands = make_filter_range( - op->getOperandTypes(), [](Type t) { return t.isa(); }); + op->getOperandTypes(), [](Type t) { return isa(t); }); // If all operands are unranked, then all result shapes are possible. if (rankedOperands.empty()) @@ -261,7 +261,7 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) { } auto rankedResults = make_filter_range( - op->getResultTypes(), [](Type t) { return t.isa(); }); + op->getResultTypes(), [](Type t) { return isa(t); }); // If all of the results are unranked then no further verification. if (rankedResults.empty()) diff --git a/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp b/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp index 05fba01..45fa644 100644 --- a/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp @@ -148,14 +148,14 @@ public: // TODO: when this ported to the dataflow analysis infra, we should have // proper support for region-based control flow. Operation *valueSource = - operand.get().isa() + isa(operand.get()) ? operand.get().getDefiningOp() : operand.get().getParentBlock()->getParentOp(); auto iface = cast(valueSource); SmallVector instances; iface.getEffectsOnResource(transform::TransformMappingResource::get(), instances); - assert((operand.get().isa() || + assert((isa(operand.get()) || hasEffect(instances, operand.get())) && "expected the op defining the value to have an allocation effect " "on it"); @@ -182,7 +182,7 @@ public: // value is defined in the middle of the block, i.e., is not a block // argument. 
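The `Traits.cpp` hunks chain several `isa` tests with `||`. Worth noting: `llvm::isa` is variadic, so such chains can also be collapsed into a single call; this patch keeps the original structure and only moves the argument. Sketch:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

bool isAnyTensor(Type t) {
  // Equivalent to isa<RankedTensorType>(t) || isa<UnrankedTensorType>(t).
  return isa<RankedTensorType, UnrankedTensorType>(t);
}
```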
bool isOutermost = ancestor == ancestors.front(); - bool isFromBlockPartial = isOutermost && operand.get().isa(); + bool isFromBlockPartial = isOutermost && isa(operand.get()); // Check if the value may be freed by operations between its definition // (allocation) point in its block and the terminator of the block or the diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 94fa2d3..8538892 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -162,7 +162,7 @@ ArrayAttr mlir::getReassociationIndicesAttribute( SmallVector reassociationAttr = llvm::to_vector<4>(llvm::map_range( reassociation, [&](const ReassociationIndices &indices) -> Attribute { - return b.getI64ArrayAttr(indices).cast(); + return cast(b.getI64ArrayAttr(indices)); })); return b.getArrayAttr(reassociationAttr); } @@ -267,7 +267,7 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible( } bool mlir::hasNonIdentityLayout(Type type) { - if (auto memrefType = type.dyn_cast()) + if (auto memrefType = dyn_cast(type)) return !memrefType.getLayout().isIdentity(); return false; } diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 45edd5f..09137d3 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -19,7 +19,7 @@ bool isZeroIndex(OpFoldResult v) { if (!v) return false; if (auto attr = v.dyn_cast()) { - IntegerAttr intAttr = attr.dyn_cast(); + IntegerAttr intAttr = dyn_cast(attr); return intAttr && intAttr.getValue().isZero(); } if (auto cst = v.get().getDefiningOp()) @@ -53,7 +53,7 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr, SmallVectorImpl &staticVec) { auto v = ofr.dyn_cast(); if (!v) { - APInt apInt = ofr.get().cast().getValue(); + APInt apInt = cast(ofr.get()).getValue(); staticVec.push_back(apInt.getSExtValue()); return; } @@ -71,8 +71,8 @@ void dispatchIndexOpFoldResults(ArrayRef ofrs, /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. SmallVector extractFromI64ArrayAttr(Attribute attr) { return llvm::to_vector<4>( - llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { - return a.cast().getInt(); + llvm::map_range(cast(attr), [](Attribute a) -> int64_t { + return cast(a).getInt(); })); } @@ -124,7 +124,7 @@ std::optional getConstantIntValue(OpFoldResult ofr) { } // Case 2: Check for IntegerAttr. 
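The `StaticValueUtils.cpp` hunks make a useful boundary visible: `OpFoldResult` is a `PointerUnion<Attribute, Value>`, and the union's own `dyn_cast` member is not part of this migration, so `ofr.dyn_cast<Attribute>()` stays put while the inner `Attribute` cast moves to the free form. A sketch of that mixed style, assuming the API as it stood at this patch:

```cpp
#include "mlir/IR/BuiltinAttributes.h" // IntegerAttr
#include "mlir/IR/OpDefinition.h"      // OpFoldResult
#include "mlir/Support/LLVM.h"
#include <cstdint>
#include <optional>

using namespace mlir;

std::optional<int64_t> constantIntValue(OpFoldResult ofr) {
  // PointerUnion member dyn_cast: untouched by this patch.
  Attribute attr = ofr.dyn_cast<Attribute>();
  // MLIR Attribute cast: migrated to the free, null-tolerant form.
  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
    return intAttr.getValue().getSExtValue();
  return std::nullopt;
}
```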
Attribute attr = ofr.dyn_cast(); - if (auto intAttr = attr.dyn_cast_or_null()) + if (auto intAttr = dyn_cast_or_null(attr)) return intAttr.getValue().getSExtValue(); return std::nullopt; } @@ -184,7 +184,7 @@ decomposeMixedValues(Builder &b, SmallVector dynamicValues; for (const auto &it : mixedValues) { if (it.is()) { - staticValues.push_back(it.get().cast().getInt()); + staticValues.push_back(cast(it.get()).getInt()); } else { staticValues.push_back(ShapedType::kDynamic); dynamicValues.push_back(it.get()); diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp index aed39f8..a2977901 100644 --- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp @@ -21,9 +21,9 @@ bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) { if (indexingMaps.size() != 3) return false; - auto map0 = indexingMaps[0].cast().getValue(); - auto map1 = indexingMaps[1].cast().getValue(); - auto map2 = indexingMaps[2].cast().getValue(); + auto map0 = cast(indexingMaps[0]).getValue(); + auto map1 = cast(indexingMaps[1]).getValue(); + auto map2 = cast(indexingMaps[2]).getValue(); if (map0.getNumResults() != 2 || map1.getNumResults() != 2 || map2.getNumResults() != 2 || map0.getNumInputs() != 3 || @@ -47,9 +47,9 @@ bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) { if (indexingMaps.size() != 3) return false; - auto map0 = indexingMaps[0].cast().getValue(); - auto map1 = indexingMaps[1].cast().getValue(); - auto map2 = indexingMaps[2].cast().getValue(); + auto map0 = cast(indexingMaps[0]).getValue(); + auto map1 = cast(indexingMaps[1]).getValue(); + auto map2 = cast(indexingMaps[2]).getValue(); if (map0.getNumResults() != 2 || map1.getNumResults() != 2 || map2.getNumResults() != 2 || map0.getNumInputs() != 3 || @@ -73,9 +73,9 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) { if (indexingMaps.size() != 3) return false; - auto map0 = indexingMaps[0].cast().getValue(); - auto map1 = indexingMaps[1].cast().getValue(); - auto map2 = indexingMaps[2].cast().getValue(); + auto map0 = cast(indexingMaps[0]).getValue(); + auto map1 = cast(indexingMaps[1]).getValue(); + auto map2 = cast(indexingMaps[2]).getValue(); if (map0.getNumResults() != 3 || map1.getNumResults() != 3 || map2.getNumResults() != 3 || map0.getNumInputs() != 4 || diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index a643104..ad7e367 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -30,14 +30,14 @@ struct TransferReadOpInterface vector::TransferReadOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - assert(opOperand.get().getType().isa() && + assert(isa(opOperand.get().getType()) && "only tensor types expected"); return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - assert(opOperand.get().getType().isa() && + assert(isa(opOperand.get().getType()) && "only tensor types expected"); return false; } @@ -50,7 +50,7 @@ struct TransferReadOpInterface LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto readOp = cast(op); - assert(readOp.getShapedType().isa() && + assert(isa(readOp.getShapedType()) && "only tensor types expected"); FailureOr buffer = getBuffer(rewriter, 
readOp.getSource(), options); if (failed(buffer)) @@ -74,7 +74,7 @@ struct TransferWriteOpInterface LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto writeOp = cast(op); - assert(writeOp.getShapedType().isa() && + assert(isa(writeOp.getShapedType()) && "only tensor types expected"); // Create a new transfer_write on buffer that doesn't have a return value. @@ -99,14 +99,14 @@ struct GatherOpInterface vector::GatherOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - assert(opOperand.get().getType().isa() && + assert(isa(opOperand.get().getType()) && "only tensor types expected"); return true; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - assert(opOperand.get().getType().isa() && + assert(isa(opOperand.get().getType()) && "only tensor types expected"); return false; } @@ -119,7 +119,7 @@ struct GatherOpInterface LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto gatherOp = cast(op); - assert(gatherOp.getBaseType().isa() && + assert(isa(gatherOp.getBaseType()) && "only tensor types expected"); FailureOr buffer = getBuffer(rewriter, gatherOp.getBase(), options); if (failed(buffer)) @@ -266,7 +266,7 @@ struct YieldOpInterface // may get dropped during the bufferization of vector.mask. SmallVector newResults; for (Value value : yieldOp.getOperands()) { - if (value.getType().isa()) { + if (isa(value.getType())) { FailureOr maybeBuffer = getBuffer(rewriter, value, options); if (failed(maybeBuffer)) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index ad538fe..7c606e0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -49,7 +49,7 @@ public: PatternRewriter &rewriter) const override { auto loc = op.getLoc(); VectorType dstType = op.getResultVectorType(); - VectorType srcType = op.getSourceType().dyn_cast(); + VectorType srcType = dyn_cast(op.getSourceType()); Type eltType = dstType.getElementType(); // Scalar to any vector can use splat. diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 16751f8..986c5f8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -96,9 +96,9 @@ static Value reshapeLoad(Location loc, Value val, VectorType type, return rewriter.create(loc, lowType, val, posAttr); } // Unroll leading dimensions. - VectorType vType = lowType.cast(); + VectorType vType = cast(lowType); Type resType = VectorType::Builder(type).dropDim(index); - auto resVectorType = resType.cast(); + auto resVectorType = cast(resType); Value result = rewriter.create( loc, resVectorType, rewriter.getZeroAttr(resVectorType)); for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) { @@ -126,7 +126,7 @@ static Value reshapeStore(Location loc, Value val, Value result, } // Unroll leading dimensions. 
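The bufferization-interface hunks above migrate casts inside `assert`s as well, leaving the message strings untouched. Sketch:

```cpp
#include "mlir/IR/BuiltinTypes.h" // TensorType
#include "mlir/Support/LLVM.h"
#include <cassert>

using namespace mlir;

void expectTensor(Type t) {
  // Before: assert(t.isa<TensorType>() && "only tensor types expected");
  assert(isa<TensorType>(t) && "only tensor types expected");
  (void)t; // keep release builds warning-free
}
```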
Type lowType = VectorType::Builder(type).dropDim(0); - VectorType vType = lowType.cast(); + VectorType vType = cast(lowType); Type insType = VectorType::Builder(vType).dropDim(0); for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { auto posAttr = rewriter.getI64ArrayAttr(d); @@ -160,7 +160,7 @@ createContractArithOp(Location loc, Value x, Value y, Value acc, // Only valid for integer types. return std::nullopt; // Special case for fused multiply-add. - if (acc && acc.getType().isa() && kind == CombiningKind::ADD) { + if (acc && isa(acc.getType()) && kind == CombiningKind::ADD) { Value fma = rewriter.create(loc, x, y, acc); if (mask) // The fma op doesn't need explicit masking. However, fma ops used in @@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator Value promote(Value v, Type dstElementType) { Type elementType = v.getType(); - auto vecType = elementType.dyn_cast(); + auto vecType = dyn_cast(elementType); if (vecType) elementType = vecType.getElementType(); if (elementType == dstElementType) @@ -426,7 +426,7 @@ struct UnrolledOuterProductGenerator Type promotedType = dstElementType; if (vecType) promotedType = VectorType::get(vecType.getShape(), promotedType); - if (dstElementType.isa()) + if (isa(dstElementType)) return rewriter.create(loc, promotedType, v); return rewriter.create(loc, promotedType, v); } @@ -438,7 +438,7 @@ struct UnrolledOuterProductGenerator if (mask && !maybeMask.has_value()) return failure(); - Type resElementType = res.getType().cast().getElementType(); + Type resElementType = cast(res.getType()).getElementType(); for (int64_t k = 0; k < reductionSize; ++k) { Value extractA = rewriter.create(loc, lhs, k); Value extractB = rewriter.create(loc, rhs, k); @@ -684,7 +684,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, return failure(); } - VectorType dstType = op.getResultType().cast(); + VectorType dstType = cast(op.getResultType()); assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 && "Expected dst type of rank 1 or 2"); @@ -695,7 +695,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, // ExtractOp does not allow dynamic indexing, we must unroll explicitly. Value res = rewriter.create(loc, dstType, rewriter.getZeroAttr(dstType)); - bool isInt = dstType.getElementType().isa(); + bool isInt = isa(dstType.getElementType()); for (unsigned r = 0; r < dstRows; ++r) { Value a = rewriter.create(op.getLoc(), lhs, r); for (unsigned c = 0; c < dstColumns; ++c) { @@ -789,7 +789,7 @@ struct ContractOpToElementwise } else { // If the parallel dimension doesn't exist we will have to broadcast it. lhsDims.push_back( - contractOp.getResultType().cast().getDimSize(i)); + cast(contractOp.getResultType()).getDimSize(i)); lhsTranspose.push_back(lhsDims.size() - 1); } std::optional rhsDim = @@ -799,7 +799,7 @@ struct ContractOpToElementwise } else { // If the parallel dimension doesn't exist we will have to broadcast it. rhsDims.push_back( - contractOp.getResultType().cast().getDimSize(i)); + cast(contractOp.getResultType()).getDimSize(i)); rhsTranspose.push_back(rhsDims.size() - 1); } } @@ -969,7 +969,7 @@ FailureOr ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, Value mask) const { VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); - VectorType resType = op.getResultType().cast(); + VectorType resType = cast(op.getResultType()); // Find the iterator type index and result index. 
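The `UnrolledOuterProductGenerator::promote` hunk above detects whether a type is a vector and, if so, mirrors its shape onto a promoted element type. As a standalone sketch (ignoring scalable dimensions for brevity):

```cpp
#include "mlir/IR/BuiltinTypes.h" // VectorType
#include "mlir/Support/LLVM.h"

using namespace mlir;

// Scalars pass through; vectors are rebuilt with the same shape.
Type mirrorVectorShape(Type srcType, Type newElemType) {
  if (auto vecTy = dyn_cast<VectorType>(srcType))
    return VectorType::get(vecTy.getShape(), newElemType);
  return newElemType;
}
```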
SmallVector iMap = op.getIndexingMapsArray(); int64_t iterIndex = -1; @@ -1044,10 +1044,10 @@ FailureOr ContractionOpLowering::lowerReduction( VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); Type resType = op.getResultType(); - if (resType.isa()) + if (isa(resType)) return rewriter.notifyMatchFailure(op, "did not expect a VectorType result"); - bool isInt = resType.isa(); + bool isInt = isa(resType); // Use iterator index 0. int64_t iterIndex = 0; SmallVector iMap = op.getIndexingMapsArray(); @@ -1133,10 +1133,10 @@ public: auto loc = op.getLoc(); VectorType lhsType = op.getOperandVectorTypeLHS(); - VectorType rhsType = op.getOperandTypeRHS().dyn_cast(); + VectorType rhsType = dyn_cast(op.getOperandTypeRHS()); VectorType resType = op.getResultVectorType(); Type eltType = resType.getElementType(); - bool isInt = eltType.isa(); + bool isInt = isa(eltType); Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0]; vector::CombiningKind kind = op.getKind(); @@ -1231,7 +1231,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, return failure(); Type dstElementType = op.getType(); - if (auto vecType = dstElementType.dyn_cast()) + if (auto vecType = dyn_cast(dstElementType)) dstElementType = vecType.getElementType(); if (elementType != dstElementType) return failure(); @@ -1259,8 +1259,8 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, return failure(); // At this point lhs and rhs are in row-major. - VectorType lhsType = lhs.getType().cast(); - VectorType rhsType = rhs.getType().cast(); + VectorType lhsType = cast(lhs.getType()); + VectorType rhsType = cast(rhs.getType()); int64_t lhsRows = lhsType.getDimSize(0); int64_t lhsColumns = lhsType.getDimSize(1); int64_t rhsColumns = rhsType.getDimSize(1); @@ -1289,7 +1289,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, llvm_unreachable("invalid contraction semantics"); Value res = - elementType.isa() + isa(elementType) ? static_cast(rew.create(loc, op.getAcc(), mul)) : static_cast( rew.create(loc, op.getAcc(), mul)); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 3f26558..a0ed056 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -52,7 +52,7 @@ public: LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { - auto dstType = op.getResult().getType().cast(); + auto dstType = cast(op.getResult().getType()); int64_t rank = dstType.getRank(); if (rank <= 1) return rewriter.notifyMatchFailure( @@ -112,7 +112,7 @@ public: if (rank == 0) { assert(dimSizes.size() == 1 && "Expected exactly one dim size for a 0-D vector"); - bool value = dimSizes[0].cast().getInt() == 1; + bool value = cast(dimSizes[0]).getInt() == 1; rewriter.replaceOpWithNewOp( op, dstType, DenseIntElementsAttr::get( @@ -122,14 +122,14 @@ public: } // Scalable constant masks can only be lowered for the "none set" case. 
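The mask-lowering hunks below repeatedly read integer dimension sizes out of an `ArrayAttr`; in the migrated form that idiom is `cast<IntegerAttr>(attr).getInt()`. A short sketch (hypothetical helper name, MLIR headers assumed):

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical helper collecting the integers held by an ArrayAttr.
static SmallVector<int64_t> attrToInts(ArrayAttr attrs) {
  SmallVector<int64_t> result;
  for (Attribute a : attrs)
    // Before: a.cast<IntegerAttr>().getInt()
    result.push_back(cast<IntegerAttr>(a).getInt());
  return result;
}
```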
- if (dstType.cast().isScalable()) { + if (cast(dstType).isScalable()) { rewriter.replaceOpWithNewOp( op, DenseElementsAttr::get(dstType, false)); return success(); } int64_t trueDim = std::min(dstType.getDimSize(0), - dimSizes[0].cast().getInt()); + cast(dimSizes[0]).getInt()); if (rank == 1) { // Express constant 1-D case in explicit vector form: @@ -146,7 +146,7 @@ public: VectorType::get(dstType.getShape().drop_front(), eltType); SmallVector newDimSizes; for (int64_t r = 1; r < rank; r++) - newDimSizes.push_back(dimSizes[r].cast().getInt()); + newDimSizes.push_back(cast(dimSizes[r]).getInt()); Value trueVal = rewriter.create( loc, lowType, rewriter.getI64ArrayAttr(newDimSizes)); Value result = rewriter.create( diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index eb2deba..463aab1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -48,7 +48,7 @@ static Value genOperator(Location loc, Value x, Value y, PatternRewriter &rewriter) { using vector::CombiningKind; - auto elType = x.getType().cast().getElementType(); + auto elType = cast(x.getType()).getElementType(); bool isInt = elType.isIntOrIndex(); Value combinedResult{nullptr}; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index f15d0c8..4f68526 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -29,7 +29,7 @@ inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, size_t index = 0; for (unsigned pos : permutation) newInBoundsValues[pos] = - attr.getValue()[index++].cast().getValue(); + cast(attr.getValue()[index++]).getValue(); return builder.getBoolArrayAttr(newInBoundsValues); } @@ -37,7 +37,7 @@ inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, /// dimensions. static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, int64_t addedRank) { - auto originalVecType = vec.getType().cast(); + auto originalVecType = cast(vec.getType()); SmallVector newShape(addedRank, 1); newShape.append(originalVecType.getShape().begin(), originalVecType.getShape().end()); @@ -257,7 +257,7 @@ struct TransferWriteNonPermutationLowering // All the new dimensions added are inbound. SmallVector newInBoundsValues(missingInnerDim.size(), true); for (Attribute attr : op.getInBounds().value().getValue()) { - newInBoundsValues.push_back(attr.cast().getValue()); + newInBoundsValues.push_back(cast(attr).getValue()); } newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues); } @@ -315,7 +315,7 @@ struct TransferOpReduceRank : public OpRewritePattern { // In the meantime, lower these to a scalar load when they pop up. 
if (reducedShapeRank == 0) { Value newRead; - if (op.getShapedType().isa()) { + if (isa(op.getShapedType())) { newRead = rewriter.create( op.getLoc(), op.getSource(), op.getIndices()); } else { @@ -397,7 +397,7 @@ struct TransferReadToVectorLoadLowering &broadcastedDims)) return rewriter.notifyMatchFailure(read, "not minor identity + bcast"); - auto memRefType = read.getShapedType().dyn_cast(); + auto memRefType = dyn_cast(read.getShapedType()); if (!memRefType) return rewriter.notifyMatchFailure(read, "not a memref source"); @@ -418,11 +418,11 @@ struct TransferReadToVectorLoadLowering // `vector.load` supports vector types as memref's elements only when the // resulting vector type is the same as the element type. auto memrefElTy = memRefType.getElementType(); - if (memrefElTy.isa() && memrefElTy != unbroadcastedVectorType) + if (isa(memrefElTy) && memrefElTy != unbroadcastedVectorType) return rewriter.notifyMatchFailure(read, "incompatible element type"); // Otherwise, element types of the memref and the vector must match. - if (!memrefElTy.isa() && + if (!isa(memrefElTy) && memrefElTy != read.getVectorType().getElementType()) return rewriter.notifyMatchFailure(read, "non-matching element type"); @@ -543,7 +543,7 @@ struct TransferWriteToVectorStoreLowering diag << "permutation map is not minor identity: " << write; }); - auto memRefType = write.getShapedType().dyn_cast(); + auto memRefType = dyn_cast(write.getShapedType()); if (!memRefType) return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { diag << "not a memref type: " << write; @@ -558,13 +558,13 @@ struct TransferWriteToVectorStoreLowering // `vector.store` supports vector types as memref's elements only when the // type of the vector value being written is the same as the element type. auto memrefElTy = memRefType.getElementType(); - if (memrefElTy.isa() && memrefElTy != write.getVectorType()) + if (isa(memrefElTy) && memrefElTy != write.getVectorType()) return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { diag << "elemental type mismatch: " << write; }); // Otherwise, element types of the memref and the vector must match. 
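In the transfer-lowering patterns here, `dyn_cast<MemRefType>` doubles as the guard: it yields a null handle when the shaped type is not a memref (e.g. a tensor), and the pattern bails out. A sketch of the shape of that check (hypothetical predicate name, MLIR headers assumed):

```cpp
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical predicate mirroring the guard used by these patterns.
static bool isStaticMemRef(Type sourceType) {
  // Before: sourceType.dyn_cast<MemRefType>()
  auto memRefType = dyn_cast<MemRefType>(sourceType);
  if (!memRefType)
    return false; // not a memref source
  return memRefType.hasStaticShape();
}
```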
- if (!memrefElTy.isa() && + if (!isa(memrefElTy) && memrefElTy != write.getVectorType().getElementType()) return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { diag << "elemental type mismatch: " << write; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 42c1aa5..7d804dd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -156,7 +156,7 @@ static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2, /// dst[511:384] := SELECT4(v2[511:0], mask[7:6]) static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) { - assert(v1.getType().cast().getShape()[0] == 16 && + assert(cast(v1.getType()).getShape()[0] == 16 && "expected a vector with length=16"); SmallVector shuffleMask; auto appendToMask = [&](int64_t base, uint8_t control) { @@ -291,7 +291,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m, vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd); auto reshInputType = VectorType::get( - {m, n}, source.getType().cast().getElementType()); + {m, n}, cast(source.getType()).getElementType()); Value res = b.create(reshInputType, b.getZeroAttr(reshInputType)); for (int64_t i = 0; i < m; ++i) @@ -329,7 +329,7 @@ public: // Set up convenience transposition table. SmallVector transp; for (auto attr : op.getTransp()) - transp.push_back(attr.cast().getInt()); + transp.push_back(cast(attr).getInt()); if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) && succeeded(isTranspose2DSlice(op))) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 2b5706a..e56aa62 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -62,8 +62,8 @@ struct DistributedLoadStoreHelper { Value laneId, Value zero) : sequentialVal(sequentialVal), distributedVal(distributedVal), laneId(laneId), zero(zero) { - sequentialVectorType = sequentialVal.getType().dyn_cast(); - distributedVectorType = distributedVal.getType().dyn_cast(); + sequentialVectorType = dyn_cast(sequentialVal.getType()); + distributedVectorType = dyn_cast(distributedVal.getType()); if (sequentialVectorType && distributedVectorType) distributionMap = calculateImplicitMap(sequentialVectorType, distributedVectorType); @@ -89,7 +89,7 @@ struct DistributedLoadStoreHelper { "Must store either the preregistered distributed or the " "preregistered sequential value."); // Scalar case can directly use memref.store. - if (!val.getType().isa()) + if (!isa(val.getType())) return b.create(loc, val, buffer, zero); // Vector case must use vector::TransferWriteOp which will later lower to @@ -131,7 +131,7 @@ struct DistributedLoadStoreHelper { Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) { // Scalar case can directly use memref.store. - if (!type.isa()) + if (!isa(type)) return b.create(loc, buffer, zero); // Other cases must be vector atm. 
@@ -149,7 +149,7 @@ struct DistributedLoadStoreHelper { } SmallVector inBounds(indices.size(), true); return b.create( - loc, type.cast(), buffer, indices, + loc, cast(type), buffer, indices, ArrayRef(inBounds.begin(), inBounds.end())); } @@ -630,14 +630,14 @@ struct WarpOpElementwise : public OpRewritePattern { Location loc = warpOp.getLoc(); for (OpOperand &operand : elementWise->getOpOperands()) { Type targetType; - if (auto vecType = distributedVal.getType().dyn_cast()) { + if (auto vecType = dyn_cast(distributedVal.getType())) { // If the result type is a vector, the operands must also be vectors. - auto operandType = operand.get().getType().cast(); + auto operandType = cast(operand.get().getType()); targetType = VectorType::get(vecType.getShape(), operandType.getElementType()); } else { auto operandType = operand.get().getType(); - assert(!operandType.isa() && + assert(!isa(operandType) && "unexpected yield of vector from op with scalar result type"); targetType = operandType; } @@ -687,7 +687,7 @@ struct WarpOpConstant : public OpRewritePattern { if (!yieldOperand) return failure(); auto constantOp = yieldOperand->get().getDefiningOp(); - auto dense = constantOp.getValue().dyn_cast(); + auto dense = dyn_cast(constantOp.getValue()); if (!dense) return failure(); unsigned operandIndex = yieldOperand->getOperandNumber(); @@ -737,8 +737,8 @@ struct WarpOpTransferRead : public OpRewritePattern { SmallVector indices(read.getIndices().begin(), read.getIndices().end()); - auto sequentialType = read.getResult().getType().cast(); - auto distributedType = distributedVal.getType().cast(); + auto sequentialType = cast(read.getResult().getType()); + auto distributedType = cast(distributedVal.getType()); AffineMap map = calculateImplicitMap(sequentialType, distributedType); AffineMap indexMap = map.compose(read.getPermutationMap()); OpBuilder::InsertionGuard g(rewriter); @@ -752,7 +752,7 @@ struct WarpOpTransferRead : public OpRewritePattern { unsigned indexPos = indexExpr.getPosition(); unsigned vectorPos = std::get<1>(it).cast().getPosition(); int64_t scale = - distributedVal.getType().cast().getDimSize(vectorPos); + cast(distributedVal.getType()).getDimSize(vectorPos); indices[indexPos] = affine::makeComposedAffineApply( rewriter, read.getLoc(), d0 + scale * d1, {indices[indexPos], warpOp.getLaneid()}); @@ -845,7 +845,7 @@ struct WarpOpForwardOperand : public OpRewritePattern { resultIndex = operand.getOperandNumber(); break; } - auto arg = operand.get().dyn_cast(); + auto arg = dyn_cast(operand.get()); if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation()) continue; Value warpOperand = warpOp.getArgs()[arg.getArgNumber()]; @@ -874,7 +874,7 @@ struct WarpOpBroadcast : public OpRewritePattern { auto broadcastOp = operand->get().getDefiningOp(); Location loc = broadcastOp.getLoc(); auto destVecType = - warpOp->getResultTypes()[operandNumber].cast(); + cast(warpOp->getResultTypes()[operandNumber]); SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( rewriter, warpOp, {broadcastOp.getSource()}, @@ -914,7 +914,7 @@ struct WarpOpExtract : public OpRewritePattern { // Rewrite vector.extract with 1d source to vector.extractelement. 
if (extractSrcType.getRank() == 1) { assert(extractOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = extractOp.getPosition()[0].cast().getInt(); + int64_t pos = cast(extractOp.getPosition()[0]).getInt(); rewriter.setInsertionPoint(extractOp); rewriter.replaceOpWithNewOp( extractOp, extractOp.getVector(), @@ -946,8 +946,8 @@ struct WarpOpExtract : public OpRewritePattern { // Find the distributed dimension. There should be exactly one. auto distributedType = - warpOp.getResult(operandNumber).getType().cast(); - auto yieldedType = operand->get().getType().cast(); + cast(warpOp.getResult(operandNumber).getType()); + auto yieldedType = cast(operand->get().getType()); int64_t distributedDim = -1; for (int64_t i = 0; i < yieldedType.getRank(); ++i) { if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) { @@ -1083,7 +1083,7 @@ struct WarpOpInsertElement : public OpRewritePattern { auto insertOp = operand->get().getDefiningOp(); VectorType vecType = insertOp.getDestVectorType(); VectorType distrType = - warpOp.getResult(operandNumber).getType().cast(); + cast(warpOp.getResult(operandNumber).getType()); bool hasPos = static_cast(insertOp.getPosition()); // Yield destination vector, source scalar and position from warp op. @@ -1171,7 +1171,7 @@ struct WarpOpInsert : public OpRewritePattern { // Rewrite vector.insert with 1d dest to vector.insertelement. if (insertOp.getDestVectorType().getRank() == 1) { assert(insertOp.getPosition().size() == 1 && "expected 1 index"); - int64_t pos = insertOp.getPosition()[0].cast().getInt(); + int64_t pos = cast(insertOp.getPosition()[0]).getInt(); rewriter.setInsertionPoint(insertOp); rewriter.replaceOpWithNewOp( insertOp, insertOp.getSource(), insertOp.getDest(), @@ -1199,8 +1199,8 @@ struct WarpOpInsert : public OpRewritePattern { // Find the distributed dimension. There should be exactly one. auto distrDestType = - warpOp.getResult(operandNumber).getType().cast(); - auto yieldedType = operand->get().getType().cast(); + cast(warpOp.getResult(operandNumber).getType()); + auto yieldedType = cast(operand->get().getType()); int64_t distrDestDim = -1; for (int64_t i = 0; i < yieldedType.getRank(); ++i) { if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) { @@ -1213,7 +1213,7 @@ struct WarpOpInsert : public OpRewritePattern { assert(distrDestDim != -1 && "could not find distributed dimension"); // Compute the distributed source vector type. 
- VectorType srcVecType = insertOp.getSourceType().cast(); + VectorType srcVecType = cast(insertOp.getSourceType()); SmallVector distrSrcShape(srcVecType.getShape().begin(), srcVecType.getShape().end()); // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32> @@ -1248,7 +1248,7 @@ struct WarpOpInsert : public OpRewritePattern { int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim); SmallVector newPos = llvm::to_vector( llvm::map_range(insertOp.getPosition(), [](Attribute attr) { - return attr.cast().getInt(); + return cast(attr).getInt(); })); // tid of inserting lane: pos / elementsPerLane Value insertingLane = rewriter.create( @@ -1337,7 +1337,7 @@ struct WarpOpScfForOp : public OpRewritePattern { if (!escapingValues.insert(operand->get())) return; Type distType = operand->get().getType(); - if (auto vecType = distType.cast()) { + if (auto vecType = cast(distType)) { AffineMap map = distributionMapFn(operand->get()); distType = getDistributedType(vecType, map, warpOp.getWarpSize()); } @@ -1359,7 +1359,7 @@ struct WarpOpScfForOp : public OpRewritePattern { for (OpOperand &yieldOperand : yield->getOpOperands()) { if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) continue; - auto forResult = yieldOperand.get().cast(); + auto forResult = cast(yieldOperand.get()); newOperands.push_back( newWarpOp.getResult(yieldOperand.getOperandNumber())); yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]); @@ -1463,7 +1463,7 @@ struct WarpOpReduction : public OpRewritePattern { auto reductionOp = cast(yieldOperand->get().getDefiningOp()); - auto vectorType = reductionOp.getVector().getType().cast(); + auto vectorType = cast(reductionOp.getVector().getType()); // Only rank 1 vectors supported. if (vectorType.getRank() != 1) return rewriter.notifyMatchFailure( @@ -1564,7 +1564,7 @@ void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) { // operations from there. 
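The hoisting logic that follows filters operations by result type with `isa` inside a predicate lambda, another common shape for the migrated calls. Sketched below (hypothetical helper name, MLIR headers assumed):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Hypothetical predicate with the same shape as the hoisting check below.
static bool producesVector(Operation &op) {
  return llvm::any_of(op.getResults(), [](Value result) {
    // Before: result.getType().isa<VectorType>()
    return isa<VectorType>(result.getType());
  });
}
```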
for (auto &op : body->without_terminator()) { bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) { - return result.getType().isa(); + return isa(result.getType()); }); if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody)) opsToMove.insert(&op); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 6105e87..8b24441 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -136,10 +136,10 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern { Type oldSrcType = insertOp.getSourceType(); Type newSrcType = oldSrcType; int64_t oldSrcRank = 0, newSrcRank = 0; - if (auto type = oldSrcType.dyn_cast()) { + if (auto type = dyn_cast(oldSrcType)) { newSrcType = trimLeadingOneDims(type); oldSrcRank = type.getRank(); - newSrcRank = newSrcType.cast().getRank(); + newSrcRank = cast(newSrcType).getRank(); } VectorType oldDstType = insertOp.getDestVectorType(); @@ -199,7 +199,7 @@ struct CastAwayTransferReadLeadingOneDim if (read.getMask()) return failure(); - auto shapedType = read.getSource().getType().cast(); + auto shapedType = cast(read.getSource().getType()); if (shapedType.getElementType() != read.getVectorType().getElementType()) return failure(); @@ -247,7 +247,7 @@ struct CastAwayTransferWriteLeadingOneDim if (write.getMask()) return failure(); - auto shapedType = write.getSource().getType().dyn_cast(); + auto shapedType = dyn_cast(write.getSource().getType()); if (shapedType.getElementType() != write.getVectorType().getElementType()) return failure(); @@ -284,7 +284,7 @@ struct CastAwayTransferWriteLeadingOneDim LogicalResult mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, RewriterBase &rewriter) { - VectorType oldAccType = contractOp.getAccType().dyn_cast(); + VectorType oldAccType = dyn_cast(contractOp.getAccType()); if (oldAccType == nullptr) return failure(); if (oldAccType.getRank() < 2) @@ -418,7 +418,7 @@ public: PatternRewriter &rewriter) const override { if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) return failure(); - auto vecType = op->getResultTypes()[0].dyn_cast(); + auto vecType = dyn_cast(op->getResultTypes()[0]); if (!vecType) return failure(); VectorType newVecType = trimLeadingOneDims(vecType); @@ -427,7 +427,7 @@ public: int64_t dropDim = vecType.getRank() - newVecType.getRank(); SmallVector newOperands; for (Value operand : op->getOperands()) { - if (auto opVecType = operand.getType().dyn_cast()) { + if (auto opVecType = dyn_cast(operand.getType())) { newOperands.push_back(rewriter.create( op->getLoc(), operand, splatZero(dropDim))); } else { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index 313a3f9..37216ce 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -21,7 +21,7 @@ using namespace mlir::vector; // Helper that picks the proper sequence for inserting. 
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, Value into, int64_t offset) { - auto vectorType = into.getType().cast(); + auto vectorType = cast(into.getType()); if (vectorType.getRank() > 1) return rewriter.create(loc, from, into, offset); return rewriter.create( @@ -32,7 +32,7 @@ static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, // Helper that picks the proper sequence for extracting. static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, int64_t offset) { - auto vectorType = vector.getType().cast(); + auto vectorType = cast(vector.getType()); if (vectorType.getRank() > 1) return rewriter.create(loc, vector, offset); return rewriter.create( @@ -134,10 +134,10 @@ public: } int64_t offset = - op.getOffsets().getValue().front().cast().getInt(); + cast(op.getOffsets().getValue().front()).getInt(); int64_t size = srcType.getShape().front(); int64_t stride = - op.getStrides().getValue().front().cast().getInt(); + cast(op.getStrides().getValue().front()).getInt(); auto loc = op.getLoc(); Value res = op.getDest(); @@ -174,7 +174,7 @@ public: off += stride, ++idx) { // 1. extract the proper subvector (or element) from source Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx); - if (extractedSource.getType().isa()) { + if (isa(extractedSource.getType())) { // 2. If we have a vector, extract the proper subvector from destination // Otherwise we are at the element level and no need to recurse. Value extractedDest = extractOne(rewriter, loc, op.getDest(), off); @@ -208,11 +208,10 @@ public: assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); int64_t offset = - op.getOffsets().getValue().front().cast().getInt(); - int64_t size = - op.getSizes().getValue().front().cast().getInt(); + cast(op.getOffsets().getValue().front()).getInt(); + int64_t size = cast(op.getSizes().getValue().front()).getInt(); int64_t stride = - op.getStrides().getValue().front().cast().getInt(); + cast(op.getStrides().getValue().front()).getInt(); assert(dstType.getElementType().isSignlessIntOrIndexOrFloat()); @@ -254,11 +253,10 @@ public: return failure(); int64_t offset = - op.getOffsets().getValue().front().cast().getInt(); - int64_t size = - op.getSizes().getValue().front().cast().getInt(); + cast(op.getOffsets().getValue().front()).getInt(); + int64_t size = cast(op.getSizes().getValue().front()).getInt(); int64_t stride = - op.getStrides().getValue().front().cast().getInt(); + cast(op.getStrides().getValue().front()).getInt(); Location loc = op.getLoc(); SmallVector elements; @@ -300,11 +298,10 @@ public: assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); int64_t offset = - op.getOffsets().getValue().front().cast().getInt(); - int64_t size = - op.getSizes().getValue().front().cast().getInt(); + cast(op.getOffsets().getValue().front()).getInt(); + int64_t size = cast(op.getSizes().getValue().front()).getInt(); int64_t stride = - op.getStrides().getValue().front().cast().getInt(); + cast(op.getStrides().getValue().front()).getInt(); auto loc = op.getLoc(); auto elemType = dstType.getElementType(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 3a06d9b..68d8c92 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -261,7 +261,7 @@ static MemRefType dropUnitDims(MemRefType inputType, 
ArrayRef offsets, llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; })); Type rankReducedType = memref::SubViewOp::inferRankReducedResultType( targetShape, inputType, offsets, sizes, strides); - return canonicalizeStridedLayout(rankReducedType.cast()); + return canonicalizeStridedLayout(cast(rankReducedType)); } /// Creates a rank-reducing memref.subview op that drops unit dims from its @@ -269,7 +269,7 @@ static MemRefType dropUnitDims(MemRefType inputType, ArrayRef offsets, static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter, mlir::Location loc, Value input) { - MemRefType inputType = input.getType().cast(); + MemRefType inputType = cast(input.getType()); assert(inputType.hasStaticShape()); SmallVector subViewOffsets(inputType.getRank(), 0); SmallVector subViewStrides(inputType.getRank(), 1); @@ -304,9 +304,9 @@ class TransferReadDropUnitDimsPattern PatternRewriter &rewriter) const override { auto loc = transferReadOp.getLoc(); Value vector = transferReadOp.getVector(); - VectorType vectorType = vector.getType().cast(); + VectorType vectorType = cast(vector.getType()); Value source = transferReadOp.getSource(); - MemRefType sourceType = source.getType().dyn_cast(); + MemRefType sourceType = dyn_cast(source.getType()); // TODO: support tensor types. if (!sourceType || !sourceType.hasStaticShape()) return failure(); @@ -347,9 +347,9 @@ class TransferWriteDropUnitDimsPattern PatternRewriter &rewriter) const override { auto loc = transferWriteOp.getLoc(); Value vector = transferWriteOp.getVector(); - VectorType vectorType = vector.getType().cast(); + VectorType vectorType = cast(vector.getType()); Value source = transferWriteOp.getSource(); - MemRefType sourceType = source.getType().dyn_cast(); + MemRefType sourceType = dyn_cast(source.getType()); // TODO: support tensor type. if (!sourceType || !sourceType.hasStaticShape()) return failure(); @@ -406,7 +406,7 @@ static int64_t hasMatchingInnerContigousShape(MemRefType memrefType, /// input starting at `firstDimToCollapse`. static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, Value input, int64_t firstDimToCollapse) { - ShapedType inputType = input.getType().cast(); + ShapedType inputType = cast(input.getType()); if (inputType.getRank() == 1) return input; SmallVector reassociation; @@ -451,9 +451,9 @@ class FlattenContiguousRowMajorTransferReadPattern PatternRewriter &rewriter) const override { auto loc = transferReadOp.getLoc(); Value vector = transferReadOp.getVector(); - VectorType vectorType = vector.getType().cast(); + VectorType vectorType = cast(vector.getType()); Value source = transferReadOp.getSource(); - MemRefType sourceType = source.getType().dyn_cast(); + MemRefType sourceType = dyn_cast(source.getType()); // Contiguity check is valid on tensors only. 
if (!sourceType) return failure(); @@ -481,7 +481,7 @@ class FlattenContiguousRowMajorTransferReadPattern Value collapsedSource = collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim); MemRefType collapsedSourceType = - collapsedSource.getType().dyn_cast(); + dyn_cast(collapsedSource.getType()); int64_t collapsedRank = collapsedSourceType.getRank(); assert(collapsedRank == firstContiguousInnerDim + 1); SmallVector dimExprs{ @@ -494,7 +494,7 @@ class FlattenContiguousRowMajorTransferReadPattern loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap); flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); rewriter.replaceOpWithNewOp( - transferReadOp, vector.getType().cast(), flatRead); + transferReadOp, cast(vector.getType()), flatRead); return success(); } }; @@ -511,9 +511,9 @@ class FlattenContiguousRowMajorTransferWritePattern PatternRewriter &rewriter) const override { auto loc = transferWriteOp.getLoc(); Value vector = transferWriteOp.getVector(); - VectorType vectorType = vector.getType().cast(); + VectorType vectorType = cast(vector.getType()); Value source = transferWriteOp.getSource(); - MemRefType sourceType = source.getType().dyn_cast(); + MemRefType sourceType = dyn_cast(source.getType()); // Contiguity check is valid on tensors only. if (!sourceType) return failure(); @@ -541,7 +541,7 @@ class FlattenContiguousRowMajorTransferWritePattern Value collapsedSource = collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim); MemRefType collapsedSourceType = - collapsedSource.getType().cast(); + cast(collapsedSource.getType()); int64_t collapsedRank = collapsedSourceType.getRank(); assert(collapsedRank == firstContiguousInnerDim + 1); SmallVector dimExprs{ @@ -610,7 +610,7 @@ class RewriteScalarExtractElementOfTransferRead *getConstantIntValue(ofr)); } } - if (xferOp.getSource().getType().isa()) { + if (isa(xferOp.getSource().getType())) { rewriter.replaceOpWithNewOp(extractOp, xferOp.getSource(), newIndices); } else { @@ -637,7 +637,7 @@ class RewriteScalarExtractOfTransferRead LogicalResult matchAndRewrite(vector::ExtractOp extractOp, PatternRewriter &rewriter) const override { // Only match scalar extracts. - if (extractOp.getType().isa()) + if (isa(extractOp.getType())) return failure(); auto xferOp = extractOp.getVector().getDefiningOp(); if (!xferOp) @@ -660,7 +660,7 @@ class RewriteScalarExtractOfTransferRead SmallVector newIndices(xferOp.getIndices().begin(), xferOp.getIndices().end()); for (const auto &it : llvm::enumerate(extractOp.getPosition())) { - int64_t offset = it.value().cast().getInt(); + int64_t offset = cast(it.value()).getInt(); int64_t idx = newIndices.size() - extractOp.getPosition().size() + it.index(); OpFoldResult ofr = affine::makeComposedFoldedAffineApply( @@ -673,7 +673,7 @@ class RewriteScalarExtractOfTransferRead extractOp.getLoc(), *getConstantIntValue(ofr)); } } - if (xferOp.getSource().getType().isa()) { + if (isa(xferOp.getSource().getType())) { rewriter.replaceOpWithNewOp(extractOp, xferOp.getSource(), newIndices); } else { @@ -714,7 +714,7 @@ class RewriteScalarWrite : public OpRewritePattern { xferOp.getVector(), pos); } // Construct a scalar store. - if (xferOp.getSource().getType().isa()) { + if (isa(xferOp.getSource().getType())) { rewriter.replaceOpWithNewOp( xferOp, scalar, xferOp.getSource(), xferOp.getIndices()); } else { @@ -732,12 +732,12 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter, // Run store to load forwarding first since it can expose more dead store // opportunity. 
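The scalar read/write rewrites above choose between memref and tensor ops by testing the source type. Note that `isa` also accepts several candidate types at once, which keeps such dispatches compact. A sketch (the enum and helper are invented for illustration; MLIR headers assumed):

```cpp
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

enum class SourceKind { MemRef, Tensor, Other }; // hypothetical

static SourceKind classifySource(Type t) {
  // Before: t.isa<MemRefType>()
  if (isa<MemRefType>(t))
    return SourceKind::MemRef;
  // The variadic form tests several types in one call.
  if (isa<RankedTensorType, UnrankedTensorType>(t))
    return SourceKind::Tensor;
  return SourceKind::Other;
}
```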
rootOp->walk([&](vector::TransferReadOp read) { - if (read.getShapedType().isa()) + if (isa(read.getShapedType())) opt.storeToLoadForwarding(read); }); opt.removeDeadOp(); rootOp->walk([&](vector::TransferWriteOp write) { - if (write.getShapedType().isa()) + if (isa(write.getShapedType())) opt.deadStoreOp(write); }); opt.removeDeadOp(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp index 34a7ce1..6dacb1e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -190,7 +190,7 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp, Location loc = xferOp.getLoc(); int64_t memrefRank = xferOp.getShapedType().getRank(); // TODO: relax this precondition, will require rank-reducing subviews. - assert(memrefRank == alloc.getType().cast().getRank() && + assert(memrefRank == cast(alloc.getType()).getRank() && "Expected memref rank to match the alloc rank"); ValueRange leadingIndices = xferOp.indices().take_front(xferOp.getLeadingShapedRank()); @@ -571,8 +571,8 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer( } MemRefType compatibleMemRefType = - getCastCompatibleMemRefType(xferOp.getShapedType().cast(), - alloc.getType().cast()); + getCastCompatibleMemRefType(cast(xferOp.getShapedType()), + cast(alloc.getType())); if (!compatibleMemRefType) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 44f3a10..d634d6a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -93,9 +93,9 @@ struct ShapeCastOpFolder : public OpRewritePattern { PatternRewriter &rewriter) const override { // Check if 'shapeCastOp' has vector source/result type. auto sourceVectorType = - shapeCastOp.getSource().getType().dyn_cast_or_null(); + dyn_cast_or_null(shapeCastOp.getSource().getType()); auto resultVectorType = - shapeCastOp.getResult().getType().dyn_cast_or_null(); + dyn_cast_or_null(shapeCastOp.getResult().getType()); if (!sourceVectorType || !resultVectorType) return failure(); @@ -105,7 +105,7 @@ struct ShapeCastOpFolder : public OpRewritePattern { if (!sourceShapeCastOp) return failure(); auto operandSourceVectorType = - sourceShapeCastOp.getSource().getType().cast(); + cast(sourceShapeCastOp.getSource().getType()); auto operandResultVectorType = sourceShapeCastOp.getType(); // Check if shape cast operations invert each other. @@ -342,7 +342,7 @@ struct CombineContractBroadcast if (!broadcast) continue; // contractionOp can only take vector as operands. - auto srcType = broadcast.getSourceType().dyn_cast(); + auto srcType = dyn_cast(broadcast.getSourceType()); if (!srcType || srcType.getRank() == broadcast.getResultVectorType().getRank()) continue; @@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast return failure(); Type castResTy = getElementTypeOrSelf(op->getResult(0)); - if (auto vecTy = bcastOp.getSourceType().dyn_cast()) + if (auto vecTy = dyn_cast(bcastOp.getSourceType())) castResTy = VectorType::get(vecTy.getShape(), castResTy); auto *castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), @@ -530,7 +530,7 @@ struct ReorderElementwiseOpsOnTranspose final // This is a constant. Create a reverse transpose op for it. 
auto vectorType = VectorType::get( srcType.getShape(), - operand.getType().cast().getElementType()); + cast(operand.getType()).getElementType()); srcValues.push_back(rewriter.create( operand.getLoc(), vectorType, operand, rewriter.getI64ArrayAttr(invOrder))); @@ -539,7 +539,7 @@ struct ReorderElementwiseOpsOnTranspose final auto vectorType = VectorType::get( srcType.getShape(), - op->getResultTypes()[0].cast().getElementType()); + cast(op->getResultTypes()[0]).getElementType()); Operation *elementwiseOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, vectorType, op->getAttrs()); @@ -693,7 +693,7 @@ struct BubbleDownBitCastForStridedSliceExtract } SmallVector dims = - llvm::to_vector<4>(extractOp.getType().cast().getShape()); + llvm::to_vector<4>(cast(extractOp.getType()).getShape()); dims.back() = dims.back() / expandRatio; VectorType newExtractType = VectorType::get(dims, castSrcType.getElementType()); @@ -996,7 +996,7 @@ public: LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { auto dstType = op.getType(); - if (dstType.cast().isScalable()) + if (cast(dstType).isScalable()) return failure(); int64_t rank = dstType.getRank(); if (rank > 1) @@ -1026,7 +1026,7 @@ class DropInnerMostUnitDims : public OpRewritePattern { if (readOp.getMask()) return failure(); - auto srcType = readOp.getSource().getType().dyn_cast(); + auto srcType = dyn_cast(readOp.getSource().getType()); if (!srcType || !srcType.hasStaticShape()) return failure(); @@ -1060,13 +1060,13 @@ class DropInnerMostUnitDims : public OpRewritePattern { MemRefType resultMemrefType; MemRefLayoutAttrInterface layout = srcType.getLayout(); - if (layout.isa() && layout.isIdentity()) { + if (isa(layout) && layout.isIdentity()) { resultMemrefType = MemRefType::get( srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), nullptr, srcType.getMemorySpace()); } else { MemRefLayoutAttrInterface updatedLayout; - if (auto strided = layout.dyn_cast()) { + if (auto strided = dyn_cast(layout)) { auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop)); updatedLayout = StridedLayoutAttr::get(strided.getContext(), @@ -1099,7 +1099,7 @@ class DropInnerMostUnitDims : public OpRewritePattern { loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(), strides); auto permMap = getTransferMinorIdentityMap( - rankedReducedView.getType().cast(), resultTargetVecType); + cast(rankedReducedView.getType()), resultTargetVecType); Value result = rewriter.create( loc, resultTargetVecType, rankedReducedView, readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index f56e7cf..5eee318 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -316,7 +316,7 @@ struct UnrollContractionPattern auto targetShape = getTargetShape(options, contractOp); if (!targetShape) return failure(); - auto dstVecType = contractOp.getResultType().cast(); + auto dstVecType = cast(contractOp.getResultType()); SmallVector originalSize = *contractOp.getShapeForUnroll(); Location loc = contractOp.getLoc(); @@ -491,7 +491,7 @@ struct UnrollElementwisePattern : public RewritePattern { auto targetShape = getTargetShape(options, op); if (!targetShape) return failure(); - auto dstVecType = op->getResult(0).getType().cast(); + auto dstVecType = cast(op->getResult(0).getType()); 
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
@@ -512,7 +512,7 @@ struct UnrollElementwisePattern : public RewritePattern {
           getVectorOffset(ratioStrides, i, *targetShape);
       SmallVector<Value> extractOperands;
       for (OpOperand &operand : op->getOpOperands()) {
-        auto vecType = operand.get().getType().template dyn_cast<VectorType>();
+        auto vecType = dyn_cast<VectorType>(operand.get().getType());
         if (!vecType) {
           extractOperands.push_back(operand.get());
           continue;
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index e77a13a..a1451fb 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -36,9 +36,9 @@ using namespace mlir;
 /// the type of `source`.
 Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
-  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
+  if (isa<UnrankedMemRefType, MemRefType>(source.getType()))
     return b.createOrFold<memref::DimOp>(loc, source, dim);
-  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
+  if (isa<UnrankedTensorType, RankedTensorType>(source.getType()))
     return b.createOrFold<tensor::DimOp>(loc, source, dim);
   llvm_unreachable("Expected MemRefType or TensorType");
 }
@@ -89,7 +89,7 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
   SmallVector<int64_t> transp;
   for (auto attr : op.getTransp())
-    transp.push_back(attr.cast<IntegerAttr>().getInt());
+    transp.push_back(cast<IntegerAttr>(attr).getInt());
 
   // Check whether the two source vector dimensions that are greater than one
   // must be transposed with each other so that we can apply one of the 2-D
@@ -223,7 +223,7 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
     }
     return false;
   } else if (op.getNumResults() == 1) {
-    if (auto v = op.getResult(0).getType().dyn_cast<VectorType>()) {
+    if (auto v = dyn_cast<VectorType>(op.getResult(0).getType())) {
       superVectorType = v;
     } else {
       // Not a vector type.
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index e806db7..b36f297 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -22,11 +22,11 @@ using namespace mlir::x86vector;
 /// Extracts the "main" vector element type from the given X86Vector operation.
 template <typename OpTy>
 static Type getSrcVectorElementType(OpTy op) {
-  return op.getSrc().getType().template cast<VectorType>().getElementType();
+  return cast<VectorType>(op.getSrc().getType()).getElementType();
 }
 template <>
 Type getSrcVectorElementType(Vp2IntersectOp op) {
-  return op.getA().getType().template cast<VectorType>().getElementType();
+  return cast<VectorType>(op.getA().getType()).getElementType();
 }
 
 namespace {
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index 18405a9..9464ce8 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -288,30 +288,27 @@ template <typename Type>
 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
 template <>
 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
-  auto resultType = mainFunction.getFunctionType()
-                        .cast<LLVM::LLVMFunctionType>()
-                        .getReturnType()
-                        .dyn_cast<IntegerType>();
+  auto resultType = dyn_cast<IntegerType>(
+      cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
+          .getReturnType());
   if (!resultType || resultType.getWidth() != 32)
     return makeStringError("only single i32 function result supported");
   return Error::success();
 }
 template <>
 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
-  auto resultType = mainFunction.getFunctionType()
-                        .cast<LLVM::LLVMFunctionType>()
-                        .getReturnType()
-                        .dyn_cast<IntegerType>();
+  auto resultType = dyn_cast<IntegerType>(
+      cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
+          .getReturnType());
   if (!resultType || resultType.getWidth() != 64)
     return makeStringError("only single i64 function result supported");
   return Error::success();
 }
 template <>
 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
-  if (!mainFunction.getFunctionType()
-           .cast<LLVM::LLVMFunctionType>()
-           .getReturnType()
-           .isa<Float32Type>())
+  if (!isa<Float32Type>(
+          cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
+              .getReturnType()))
     return makeStringError("only single f32 function result supported");
   return Error::success();
 }
@@ -324,8 +321,7 @@ Error compileAndExecuteSingleReturnFunction(
   if (!mainFunction || mainFunction.isExternal())
     return makeStringError("entry point not found");
 
-  if (mainFunction.getFunctionType()
-          .cast<LLVM::LLVMFunctionType>()
+  if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
          .getNumParams() != 0)
     return makeStringError("function inputs not supported");
 
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index cc04fa5..e335e15 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -37,7 +37,7 @@ using namespace mlir;
 static unsigned getIndexBitwidth(DataLayoutEntryListRef params) {
   if (params.empty())
     return 64;
-  auto attr = params.front().getValue().cast<IntegerAttr>();
+  auto attr = cast<IntegerAttr>(params.front().getValue());
   return attr.getValue().getZExtValue();
 }
 
@@ -51,10 +51,10 @@ mlir::detail::getDefaultTypeSize(Type type, const DataLayout &dataLayout,
 unsigned mlir::detail::getDefaultTypeSizeInBits(Type type,
                                                 const DataLayout &dataLayout,
                                                 DataLayoutEntryListRef params) {
-  if (type.isa<IntegerType, FloatType>())
+  if (isa<IntegerType, FloatType>(type))
     return type.getIntOrFloatBitWidth();
 
-  if (auto ctype = type.dyn_cast<ComplexType>()) {
+  if (auto ctype = dyn_cast<ComplexType>(type)) {
     auto et = ctype.getElementType();
     auto innerAlignment =
         getDefaultPreferredAlignment(et, dataLayout, params) * 8;
@@ -66,7 +66,7 @@ unsigned mlir::detail::getDefaultTypeSizeInBits(Type type,
   }
 
   // Index is an integer of some bitwidth.
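Chained method casts, as in the JitRunner hunks above, become nested calls in the new style, so the expression reads inside-out. A hedged sketch of that shape (hypothetical helper name, MLIR LLVM-dialect headers assumed):

```cpp
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical check mirroring the i32 case above.
static bool returnsI32(Type functionType) {
  // Before: functionType.cast<LLVM::LLVMFunctionType>()
  //             .getReturnType().dyn_cast<IntegerType>()
  auto resultType = dyn_cast<IntegerType>(
      cast<LLVM::LLVMFunctionType>(functionType).getReturnType());
  return resultType && resultType.getWidth() == 32;
}
```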
- if (type.isa()) + if (isa(type)) return dataLayout.getTypeSizeInBits( IntegerType::get(type.getContext(), getIndexBitwidth(params))); @@ -75,12 +75,12 @@ unsigned mlir::detail::getDefaultTypeSizeInBits(Type type, // there is no bit-packing at the moment element sizes are taken in bytes and // multiplied with 8 bits. // TODO: make this extensible. - if (auto vecType = type.dyn_cast()) + if (auto vecType = dyn_cast(type)) return vecType.getNumElements() / vecType.getShape().back() * llvm::PowerOf2Ceil(vecType.getShape().back()) * dataLayout.getTypeSize(vecType.getElementType()) * 8; - if (auto typeInterface = type.dyn_cast()) + if (auto typeInterface = dyn_cast(type)) return typeInterface.getTypeSizeInBits(dataLayout, params); reportMissingDataLayout(type); @@ -104,7 +104,7 @@ findEntryForIntegerType(IntegerType intType, static unsigned extractABIAlignment(DataLayoutEntryInterface entry) { auto values = - entry.getValue().cast().getValues(); + cast(entry.getValue()).getValues(); return *values.begin() / 8u; } @@ -134,24 +134,24 @@ unsigned mlir::detail::getDefaultABIAlignment( Type type, const DataLayout &dataLayout, ArrayRef params) { // Natural alignment is the closest power-of-two number above. - if (type.isa()) + if (isa(type)) return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type)); - if (auto fltType = type.dyn_cast()) + if (auto fltType = dyn_cast(type)) return getFloatTypeABIAlignment(fltType, dataLayout, params); // Index is an integer of some bitwidth. - if (type.isa()) + if (isa(type)) return dataLayout.getTypeABIAlignment( IntegerType::get(type.getContext(), getIndexBitwidth(params))); - if (auto intType = type.dyn_cast()) + if (auto intType = dyn_cast(type)) return getIntegerTypeABIAlignment(intType, params); - if (auto ctype = type.dyn_cast()) + if (auto ctype = dyn_cast(type)) return getDefaultABIAlignment(ctype.getElementType(), dataLayout, params); - if (auto typeInterface = type.dyn_cast()) + if (auto typeInterface = dyn_cast(type)) return typeInterface.getABIAlignment(dataLayout, params); reportMissingDataLayout(type); @@ -159,7 +159,7 @@ unsigned mlir::detail::getDefaultABIAlignment( static unsigned extractPreferredAlignment(DataLayoutEntryInterface entry) { auto values = - entry.getValue().cast().getValues(); + cast(entry.getValue()).getValues(); return *std::next(values.begin(), values.size() - 1) / 8u; } @@ -187,27 +187,27 @@ unsigned mlir::detail::getDefaultPreferredAlignment( Type type, const DataLayout &dataLayout, ArrayRef params) { // Preferred alignment is same as natural for floats and vectors. - if (type.isa()) + if (isa(type)) return dataLayout.getTypeABIAlignment(type); - if (auto fltType = type.dyn_cast()) + if (auto fltType = dyn_cast(type)) return getFloatTypePreferredAlignment(fltType, dataLayout, params); // Preferred alignment is the closest power-of-two number above for integers // (ABI alignment may be smaller). 
- if (auto intType = type.dyn_cast()) + if (auto intType = dyn_cast(type)) return getIntegerTypePreferredAlignment(intType, dataLayout, params); - if (type.isa()) { + if (isa(type)) { return dataLayout.getTypePreferredAlignment( IntegerType::get(type.getContext(), getIndexBitwidth(params))); } - if (auto ctype = type.dyn_cast()) + if (auto ctype = dyn_cast(type)) return getDefaultPreferredAlignment(ctype.getElementType(), dataLayout, params); - if (auto typeInterface = type.dyn_cast()) + if (auto typeInterface = dyn_cast(type)) return typeInterface.getPreferredAlignment(dataLayout, params); reportMissingDataLayout(type); @@ -232,7 +232,7 @@ mlir::detail::getDefaultStackAlignment(DataLayoutEntryInterface entry) { if (entry == DataLayoutEntryInterface()) return 0; - auto value = entry.getValue().cast(); + auto value = cast(entry.getValue()); return value.getValue().getZExtValue(); } @@ -543,19 +543,19 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec, for (const auto &kvp : types) { auto sampleType = kvp.second.front().getKey().get(); - if (sampleType.isa()) { + if (isa(sampleType)) { assert(kvp.second.size() == 1 && "expected one data layout entry for non-parametric 'index' type"); - if (!kvp.second.front().getValue().isa()) + if (!isa(kvp.second.front().getValue())) return emitError(loc) << "expected integer attribute in the data layout entry for " << sampleType; continue; } - if (sampleType.isa()) { + if (isa(sampleType)) { for (DataLayoutEntryInterface entry : kvp.second) { - auto value = entry.getValue().dyn_cast(); + auto value = dyn_cast(entry.getValue()); if (!value || !value.getElementType().isSignlessInteger(32)) { emitError(loc) << "expected a dense i32 elements attribute in the " "data layout entry " @@ -587,7 +587,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec, if (isa(&sampleType.getDialect())) return emitError(loc) << "unexpected data layout for a built-in type"; - auto dlType = sampleType.dyn_cast(); + auto dlType = dyn_cast(sampleType); if (!dlType) return emitError(loc) << "data layout specified for a type that does not support it"; diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp index aff6a8f..a9bab23 100644 --- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp +++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp @@ -29,9 +29,9 @@ LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) { SmallVector outputBufferOperands, outputTensorOperands; for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) { Type type = operand->get().getType(); - if (type.isa()) { + if (isa(type)) { outputBufferOperands.push_back(operand); - } else if (type.isa()) { + } else if (isa(type)) { outputTensorOperands.push_back(operand); } else { return op->emitOpError("expected that operand #") diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp index 9248b11..cc31104 100644 --- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp +++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp @@ -30,7 +30,7 @@ const APInt &ConstantIntRanges::smax() const { return smaxVal; } unsigned ConstantIntRanges::getStorageBitwidth(Type type) { if (type.isIndex()) return IndexType::kInternalStorageBitWidth; - if (auto integerType = type.dyn_cast()) + if (auto integerType = dyn_cast(type)) return integerType.getWidth(); // Non-integer types have their bounds stored in width 0 `APInt`s. 
return 0; diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp index ebb10e0..80ed2cc 100644 --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -36,7 +36,7 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op, // a correct result. int64_t resultIdx = 0; for (OpResult result : op->getResults()) { - auto shapedType = result.getType().dyn_cast(); + auto shapedType = dyn_cast(result.getType()); if (!shapedType) continue; if (!shapedType.hasRank()) { @@ -69,7 +69,7 @@ bool ShapeAdaptor::hasRank() const { if (val.isNull()) return false; if (auto t = val.dyn_cast()) - return t.cast().hasRank(); + return cast(t).hasRank(); if (val.is()) return true; return val.get()->hasRank(); @@ -79,7 +79,7 @@ Type ShapeAdaptor::getElementType() const { if (val.isNull()) return nullptr; if (auto t = val.dyn_cast()) - return t.cast().getElementType(); + return cast(t).getElementType(); if (val.is()) return nullptr; return val.get()->getElementType(); @@ -88,10 +88,10 @@ Type ShapeAdaptor::getElementType() const { void ShapeAdaptor::getDims(SmallVectorImpl &res) const { assert(hasRank()); if (auto t = val.dyn_cast()) { - ArrayRef vals = t.cast().getShape(); + ArrayRef vals = cast(t).getShape(); res.assign(vals.begin(), vals.end()); } else if (auto attr = val.dyn_cast()) { - auto dattr = attr.cast(); + auto dattr = cast(attr); res.clear(); res.reserve(dattr.size()); for (auto it : dattr.getValues()) @@ -111,9 +111,9 @@ void ShapeAdaptor::getDims(ShapedTypeComponents &res) const { int64_t ShapeAdaptor::getDimSize(int index) const { assert(hasRank()); if (auto t = val.dyn_cast()) - return t.cast().getDimSize(index); + return cast(t).getDimSize(index); if (auto attr = val.dyn_cast()) - return attr.cast() + return cast(attr) .getValues()[index] .getSExtValue(); auto *stc = val.get(); @@ -123,9 +123,9 @@ int64_t ShapeAdaptor::getDimSize(int index) const { int64_t ShapeAdaptor::getRank() const { assert(hasRank()); if (auto t = val.dyn_cast()) - return t.cast().getRank(); + return cast(t).getRank(); if (auto attr = val.dyn_cast()) - return attr.cast().size(); + return cast(attr).size(); return val.get()->getDims().size(); } @@ -134,9 +134,9 @@ bool ShapeAdaptor::hasStaticShape() const { return false; if (auto t = val.dyn_cast()) - return t.cast().hasStaticShape(); + return cast(t).hasStaticShape(); if (auto attr = val.dyn_cast()) { - auto dattr = attr.cast(); + auto dattr = cast(attr); for (auto index : dattr.getValues()) if (ShapedType::isDynamic(index.getSExtValue())) return false; @@ -150,10 +150,10 @@ int64_t ShapeAdaptor::getNumElements() const { assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); if (auto t = val.dyn_cast()) - return t.cast().getNumElements(); + return cast(t).getNumElements(); if (auto attr = val.dyn_cast()) { - auto dattr = attr.cast(); + auto dattr = cast(attr); int64_t num = 1; for (auto index : dattr.getValues()) { num *= index.getZExtValue(); diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index 95fb785..1fbe42c 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -34,7 +34,7 @@ static std::optional getConstantIntValue(OpFoldResult ofr) { } // Case 2: Check for IntegerAttr. 
Attribute attr = ofr.dyn_cast(); - if (auto intAttr = attr.dyn_cast_or_null()) + if (auto intAttr = dyn_cast_or_null(attr)) return intAttr.getValue().getSExtValue(); return std::nullopt; } @@ -137,8 +137,8 @@ int64_t ValueBoundsConstraintSet::getPos(Value value, std::optional dim) const { #ifndef NDEBUG assertValidValueDim(value, dim); - assert((value.isa() || - value.cast().getOwner()->isEntryBlock()) && + assert((isa(value) || + cast(value).getOwner()->isEntryBlock()) && "unstructured control flow is not supported"); #endif // NDEBUG @@ -149,7 +149,7 @@ int64_t ValueBoundsConstraintSet::getPos(Value value, } static Operation *getOwnerOfValue(Value value) { - if (auto bbArg = value.dyn_cast()) + if (auto bbArg = dyn_cast(value)) return bbArg.getOwner()->getParentOp(); return value.getDefiningOp(); } diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index eca0297..c8c4428 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -402,7 +402,7 @@ struct ByteCodeWriter { .Case( [](Type) { return PDLValue::Kind::Operation; }) .Case([](pdl::RangeType rangeTy) { - if (rangeTy.getElementType().isa()) + if (isa(rangeTy.getElementType())) return PDLValue::Kind::TypeRange; return PDLValue::Kind::ValueRange; }) @@ -538,11 +538,11 @@ void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; auto processRewriterValue = [&](Value val) { valueToMemIndex.try_emplace(val, index++); - if (pdl::RangeType rangeType = val.getType().dyn_cast()) { + if (pdl::RangeType rangeType = dyn_cast(val.getType())) { Type elementTy = rangeType.getElementType(); - if (elementTy.isa()) + if (isa(elementTy)) valueToRangeIndex.try_emplace(val, typeRangeIndex++); - else if (elementTy.isa()) + else if (isa(elementTy)) valueToRangeIndex.try_emplace(val, valueRangeIndex++); } }; @@ -611,13 +611,13 @@ void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, /*dummyValue*/ 0); // Check to see if this value is a range type. - if (auto rangeTy = value.getType().dyn_cast()) { + if (auto rangeTy = dyn_cast(value.getType())) { Type eleType = rangeTy.getElementType(); - if (eleType.isa()) + if (isa(eleType)) defRangeIt->second.opRangeIndex = 0; - else if (eleType.isa()) + else if (isa(eleType)) defRangeIt->second.typeRangeIndex = 0; - else if (eleType.isa()) + else if (isa(eleType)) defRangeIt->second.valueRangeIndex = 0; } }; @@ -792,14 +792,14 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op, #endif // Range results also need to append the range storage index. 
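The ValueBoundsOpInterface hunk above uses the `_or_null` variant because the attribute handle may legitimately be null there; the plain free functions require a non-null input. A minimal sketch (hypothetical helper name, MLIR headers assumed):

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include <cstdint>
#include <optional>

using namespace mlir;

// Hypothetical helper: reads the value of a possibly-null IntegerAttr.
static std::optional<int64_t> constantValue(Attribute attr) {
  // Before: attr.dyn_cast_or_null<IntegerAttr>()
  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
    return intAttr.getValue().getSExtValue();
  return std::nullopt;
}
```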
- if (result.getType().isa()) + if (isa(result.getType())) writer.append(getRangeStorageIndex(result)); writer.append(result); } } void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { Value lhs = op.getLhs(); - if (lhs.getType().isa()) { + if (isa(lhs.getType())) { writer.append(OpCode::AreRangesEqual); writer.appendPDLValueKind(lhs); writer.append(op.getLhs(), op.getRhs(), op.getSuccessors()); @@ -945,7 +945,7 @@ void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { writer.append(OpCode::GetOperands, index.value_or(std::numeric_limits::max()), op.getInputOp()); - if (result.getType().isa()) + if (isa(result.getType())) writer.append(getRangeStorageIndex(result)); else writer.append(std::numeric_limits::max()); @@ -965,7 +965,7 @@ void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { writer.append(OpCode::GetResults, index.value_or(std::numeric_limits::max()), op.getInputOp()); - if (result.getType().isa()) + if (isa(result.getType())) writer.append(getRangeStorageIndex(result)); else writer.append(std::numeric_limits::max()); @@ -979,7 +979,7 @@ void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) { } void Generator::generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer) { - if (op.getType().isa()) { + if (isa(op.getType())) { Value result = op.getResult(); writer.append(OpCode::GetValueRangeTypes, result, getRangeStorageIndex(result), op.getValue()); @@ -1016,7 +1016,7 @@ void Generator::generate(pdl_interp::SwitchOperandCountOp op, void Generator::generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer) { auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) { - return OperationName(attr.cast().getValue(), ctx); + return OperationName(cast(attr).getValue(), ctx); }); writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases, op.getSuccessors()); @@ -1566,7 +1566,7 @@ void ByteCodeExecutor::executeCheckTypes() { Attribute rhs = read(); LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); - selectJump(*lhs == rhs.cast().getAsValueRange()); + selectJump(*lhs == cast(rhs).getAsValueRange()); } void ByteCodeExecutor::executeContinue() { @@ -1581,7 +1581,7 @@ void ByteCodeExecutor::executeCreateConstantTypeRange() { LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n"); unsigned memIndex = read(); unsigned rangeIndex = read(); - ArrayAttr typesAttr = read().cast(); + ArrayAttr typesAttr = cast(read()); LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); assignRangeToMemory(typesAttr.getAsValueRange(), memIndex, @@ -1743,7 +1743,7 @@ void ByteCodeExecutor::executeGetAttributeType() { unsigned memIndex = read(); Attribute attr = read(); Type type; - if (auto typedAttr = attr.dyn_cast()) + if (auto typedAttr = dyn_cast(attr)) type = typedAttr.getType(); LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index b0e278c..b7eb1f0 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -190,7 +190,7 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, // the FuncOp. if (emitter.shouldDeclareVariablesAtTop()) { // Skip the assignment if the emitc.constant has no value. 
-    if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
+    if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
       if (oAttr.getValue().empty())
         return success();
     }
@@ -201,7 +201,7 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
   }
 
   // Emit a variable declaration for an emitc.constant op without value.
-  if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
+  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
     if (oAttr.getValue().empty())
       // The semicolon gets printed by the emitOperation function.
       return emitter.emitVariableDeclaration(result,
@@ -333,7 +333,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
   os << callOp.getCallee();
 
   auto emitArgs = [&](Attribute attr) -> LogicalResult {
-    if (auto t = attr.dyn_cast<IntegerAttr>()) {
+    if (auto t = dyn_cast<IntegerAttr>(attr)) {
       // Index attributes are treated specially as operand index.
       if (t.getType().isIndex()) {
         int64_t idx = t.getInt();
@@ -759,11 +759,11 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
   };
 
   // Print floating point attributes.
-  if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
+  if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
     printFloat(fAttr.getValue());
     return success();
   }
-  if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
+  if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
     os << '{';
     interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
     os << '}';
@@ -771,21 +771,19 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
   }
 
   // Print integer attributes.
-  if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
-    if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
+  if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
+    if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
       printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
       return success();
     }
-    if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
+    if (auto iType = dyn_cast<IndexType>(iAttr.getType())) {
       printInt(iAttr.getValue(), false);
       return success();
     }
   }
-  if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
-    if (auto iType = dense.getType()
-                         .cast<ShapedType>()
-                         .getElementType()
-                         .dyn_cast<IntegerType>()) {
+  if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
+    if (auto iType = dyn_cast<IntegerType>(
+            cast<ShapedType>(dense.getType()).getElementType())) {
       os << '{';
       interleaveComma(dense, os, [&](const APInt &val) {
         printInt(val, shouldMapToUnsigned(iType.getSignedness()));
@@ -793,10 +791,8 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
       os << '}';
       return success();
     }
-    if (auto iType = dense.getType()
-                         .cast<ShapedType>()
-                         .getElementType()
-                         .dyn_cast<IndexType>()) {
+    if (auto iType = dyn_cast<IndexType>(
+            cast<ShapedType>(dense.getType()).getElementType())) {
       os << '{';
       interleaveComma(dense, os,
                       [&](const APInt &val) { printInt(val, false); });
@@ -806,13 +802,13 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
   }
 
   // Print opaque attributes.
-  if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
+  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
     os << oAttr.getValue();
     return success();
   }
 
   // Print symbolic reference attributes.
-  if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
+  if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
     if (sAttr.getNestedReferences().size() > 1)
       return emitError(loc, "attribute has more than 1 nested reference");
     os << sAttr.getRootReference().getValue();
@@ -820,7 +816,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
   }
 
   // Print type attributes.
-  if (auto type = attr.dyn_cast<TypeAttr>())
+  if (auto type = dyn_cast<TypeAttr>(attr))
     return emitType(loc, type.getValue());
 
   return emitError(loc, "cannot emit attribute: ") << attr;
@@ -957,7 +953,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
 }
 
 LogicalResult CppEmitter::emitType(Location loc, Type type) {
-  if (auto iType = type.dyn_cast<IntegerType>()) {
+  if (auto iType = dyn_cast<IntegerType>(type)) {
     switch (iType.getWidth()) {
     case 1:
       return (os << "bool"), success();
@@ -973,7 +969,7 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
       return emitError(loc, "cannot emit integer type ") << type;
     }
   }
-  if (auto fType = type.dyn_cast<FloatType>()) {
+  if (auto fType = dyn_cast<FloatType>(type)) {
     switch (fType.getWidth()) {
     case 32:
       return (os << "float"), success();
@@ -983,9 +979,9 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
       return emitError(loc, "cannot emit float type ") << type;
     }
   }
-  if (auto iType = type.dyn_cast<IndexType>())
+  if (auto iType = dyn_cast<IndexType>(type))
     return (os << "size_t"), success();
-  if (auto tType = type.dyn_cast<TensorType>()) {
+  if (auto tType = dyn_cast<TensorType>(type)) {
     if (!tType.hasRank())
       return emitError(loc, "cannot emit unranked tensor type");
     if (!tType.hasStaticShape())
@@ -1001,13 +997,13 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
     os << ">";
     return success();
   }
-  if (auto tType = type.dyn_cast<TupleType>())
+  if (auto tType = dyn_cast<TupleType>(type))
     return emitTupleType(loc, tType.getTypes());
-  if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
+  if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
     os << oType.getValue();
     return success();
   }
-  if (auto pType = type.dyn_cast<emitc::PointerType>()) {
+  if (auto pType = dyn_cast<emitc::PointerType>(type)) {
     if (failed(emitType(loc, pType.getPointee())))
       return failure();
     os << "*";
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index f409aa4..87d02f8 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -21,8 +21,8 @@ using namespace mlir::LLVM::detail;
 /// A utility walker that interrupts if the operation has valid debug
 /// information.
 static WalkResult interruptIfValidLocation(Operation *op) {
-  return op->getLoc().isa<UnknownLoc>() ? WalkResult::advance()
-                                        : WalkResult::interrupt();
+  return isa<UnknownLoc>(op->getLoc()) ? WalkResult::advance()
+                                       : WalkResult::interrupt();
 }
 
 DebugTranslation::DebugTranslation(Operation *module, llvm::Module &llvmModule)
@@ -45,7 +45,7 @@ DebugTranslation::DebugTranslation(Operation *module, llvm::Module &llvmModule)
   if (auto targetTripleAttr =
           module->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName())) {
     auto targetTriple =
-        llvm::Triple(targetTripleAttr.cast<StringAttr>().getValue());
+        llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue());
     if (targetTriple.isKnownWindowsMSVCEnvironment()) {
       // Dwarf debugging files will be generated by default, unless "CodeView"
       // is set explicitly. Windows/MSVC should use CodeView instead.
@@ -68,8 +68,8 @@ void DebugTranslation::translate(LLVMFuncOp func, llvm::Function &llvmFunc) {
   const bool hasCallWithoutDebugInfo =
       func.walk([&](LLVM::CallOp call) {
             return call.getLoc()->walk([](Location l) {
-              return l.isa<UnknownLoc>() ? WalkResult::interrupt()
-                                         : WalkResult::advance();
+              return isa<UnknownLoc>(l) ? WalkResult::interrupt()
+                                        : WalkResult::advance();
             });
           })
           .wasInterrupted();
@@ -273,7 +273,7 @@ const llvm::DILocation *
 DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
                                const llvm::DILocation *inlinedAt) {
   // LLVM doesn't have a representation for unknown.
-  if (!scope || loc.isa<UnknownLoc>())
+  if (!scope || isa<UnknownLoc>(loc))
     return nullptr;
 
   // Check for a cached instance.
@@ -282,12 +282,12 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
     return existingIt->second;
 
   const llvm::DILocation *llvmLoc = nullptr;
-  if (auto callLoc = loc.dyn_cast<CallSiteLoc>()) {
+  if (auto callLoc = dyn_cast<CallSiteLoc>(loc)) {
     // For callsites, the caller is fed as the inlinedAt for the callee.
     const auto *callerLoc = translateLoc(callLoc.getCaller(), scope, inlinedAt);
     llvmLoc = translateLoc(callLoc.getCallee(), scope, callerLoc);
 
-  } else if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
+  } else if (auto fileLoc = dyn_cast<FileLineColLoc>(loc)) {
     llvm::DILocalScope *locationScope = scope;
     // Only construct a new DIFile when no local scope is present. This
     // prioritizes existing DI information when it's present.
@@ -300,12 +300,12 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
                                     fileLoc.getColumn(), locationScope,
                                     const_cast<llvm::DILocation *>(inlinedAt));
 
-  } else if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
+  } else if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
     ArrayRef<Location> locations = fusedLoc.getLocations();
 
     // Check for a scope encoded with the location.
     if (auto scopedAttr =
-            fusedLoc.getMetadata().dyn_cast_or_null<LLVM::DIScopeAttr>())
+            dyn_cast_or_null<LLVM::DIScopeAttr>(fusedLoc.getMetadata()))
       scope = translate(scopedAttr);
 
     // For fused locations, merge each of the nodes.
@@ -315,10 +315,10 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
                                         llvmLoc, translateLoc(locIt, scope, inlinedAt));
     }
 
-  } else if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
+  } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
     llvmLoc = translateLoc(nameLoc.getChildLoc(), scope, inlinedAt);
 
-  } else if (auto opaqueLoc = loc.dyn_cast<OpaqueLoc>()) {
+  } else if (auto opaqueLoc = dyn_cast<OpaqueLoc>(loc)) {
     llvmLoc = translateLoc(opaqueLoc.getFallbackLocation(), scope, inlinedAt);
   } else {
     llvm_unreachable("unknown location kind");
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 5c98f92..c12d7f5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -231,9 +231,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
       Attribute attr = it.value();
       if (!attr)
         continue;
-      DictionaryAttr dAttr = attr.cast<DictionaryAttr>();
+      DictionaryAttr dAttr = cast<DictionaryAttr>(attr);
       TypeAttr tAttr =
-          dAttr.get(InlineAsmOp::getElementTypeAttrName()).cast<TypeAttr>();
+          cast<TypeAttr>(dAttr.get(InlineAsmOp::getElementTypeAttrName()));
       llvm::AttrBuilder b(moduleTranslation.getLLVMContext());
       llvm::Type *ty = moduleTranslation.convertType(tAttr.getValue());
       b.addTypeAttr(llvm::Attribute::ElementType, ty);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index d7f1bb6..eec8456 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -162,7 +162,7 @@ public:
           ->addOperand(llvmMetadataNode);
     };
     if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
-      if (!attribute.getValue().dyn_cast<ArrayAttr>())
+      if (!dyn_cast<ArrayAttr>(attribute.getValue()))
         return failure();
       SmallVector<int64_t> values =
           extractFromI64ArrayAttr(attribute.getValue());
@@ -172,7 +172,7 @@ public:
       if (values.size() > 2)
         generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
     } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
-      if (!attribute.getValue().dyn_cast<ArrayAttr>())
+      if (!dyn_cast<ArrayAttr>(attribute.getValue()))
         return failure();
       SmallVector<int64_t> values =
           extractFromI64ArrayAttr(attribute.getValue());
@@ -183,10 +183,10 @@ public:
         generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
     } else if (attribute.getName() ==
                NVVM::NVVMDialect::getMinctasmAttrName()) {
-      auto value = attribute.getValue().dyn_cast<IntegerAttr>();
+      auto value = dyn_cast<IntegerAttr>(attribute.getValue());
       generateMetadata(value.getInt(), "minctasm");
     } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
-      auto value = attribute.getValue().dyn_cast<IntegerAttr>();
+      auto value = dyn_cast<IntegerAttr>(attribute.getValue());
      generateMetadata(value.getInt(), "maxnreg");
     } else if (attribute.getName() ==
                NVVM::NVVMDialect::getKernelFuncAttrName()) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index 91ff174..392d34c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -99,7 +99,7 @@ processOperands(llvm::IRBuilderBase &builder,
     llvm::Value *dataPtr;
     llvm::Value *dataSize;
 
-    if (data.getType().isa<LLVM::LLVMPointerType>()) {
+    if (isa<LLVM::LLVMPointerType>(data.getType())) {
       dataPtrBase = dataValue;
       dataPtr = dataValue;
       dataSize = accBuilder->getSizeInBytes(dataValue);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 84c39c0..750f715 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -367,7 +367,7 @@ convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
   if (criticalOp.getNameAttr()) {
     // The verifiers in OpenMP Dialect guarentee that all the pointers are
     // non-null
-    auto symbolRef = criticalOp.getNameAttr().cast<SymbolRefAttr>();
+    auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
     auto criticalDeclareOp =
         SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
                                                                      symbolRef);
@@ -389,7 +389,7 @@ static omp::ReductionDeclareOp findReductionDecl(omp::WsLoopOp container,
   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) {
     if (container.getReductionVars()[i] != reduction.getAccumulator())
       continue;
-    reductionSymbol = (*container.getReductions())[i].cast<SymbolRefAttr>();
+    reductionSymbol = cast<SymbolRefAttr>((*container.getReductions())[i]);
     break;
   }
   assert(reductionSymbol &&
@@ -705,7 +705,7 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
        llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) {
     llvm::omp::RTLDependenceKindTy type;
     switch (
-        std::get<1>(dep).cast<mlir::omp::ClauseTaskDependAttr>().getValue()) {
+        cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
     case mlir::omp::ClauseTaskDepend::taskdependin:
       type = llvm::omp::RTLDependenceKindTy::DepIn;
       break;
@@ -1379,7 +1379,7 @@ static LogicalResult processMapOperand(
     llvm::Value *mapOpPtr;
     llvm::Value *mapOpSize;
 
-    if (mapOp.getType().isa<LLVM::LLVMPointerType>()) {
+    if (isa<LLVM::LLVMPointerType>(mapOp.getType())) {
       mapOpPtrBase = mapOpValue;
       mapOpPtr = mapOpValue;
       mapOpSize = ompBuilder->getSizeInBytes(mapOpValue);
@@ -1410,7 +1410,7 @@ static LogicalResult processMapOperand(
         {builder.getInt32(0), builder.getInt32(index)});
     builder.CreateStore(mapOpSize, sizeGEP);
 
-    mapTypeFlags.push_back(mapTypeOp.dyn_cast<mlir::IntegerAttr>().getInt());
+    mapTypeFlags.push_back(dyn_cast<mlir::IntegerAttr>(mapTypeOp).getInt());
     llvm::Constant *mapName =
         mlir::LLVM::createMappingInformation(mapOp.getLoc(), *ompBuilder);
     mapNames.push_back(mapName);
@@ -1445,7 +1445,7 @@ convertOmpTargetData(Operation
*op, llvm::IRBuilderBase &builder, if (auto constOp = mlir::dyn_cast( devId.getDefiningOp())) if (auto intAttr = - constOp.getValue().dyn_cast()) + dyn_cast(constOp.getValue())) deviceID = intAttr.getInt(); numMapOperands = dataOp.getMapOperands().size(); @@ -1464,7 +1464,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, if (auto constOp = mlir::dyn_cast( devId.getDefiningOp())) if (auto intAttr = - constOp.getValue().dyn_cast()) + dyn_cast(constOp.getValue())) deviceID = intAttr.getInt(); numMapOperands = enterDataOp.getMapOperands().size(); @@ -1483,7 +1483,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, if (auto constOp = mlir::dyn_cast( devId.getDefiningOp())) if (auto intAttr = - constOp.getValue().dyn_cast()) + dyn_cast(constOp.getValue())) deviceID = intAttr.getInt(); numMapOperands = exitDataOp.getMapOperands().size(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp index fd739cf..2145b95 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp @@ -16,7 +16,7 @@ llvm::Constant * mlir::LLVM::createSourceLocStrFromLocation(Location loc, llvm::OpenMPIRBuilder &builder, StringRef name, uint32_t &strLen) { - if (auto fileLoc = loc.dyn_cast()) { + if (auto fileLoc = dyn_cast(loc)) { StringRef fileName = fileLoc.getFilename(); unsigned lineNo = fileLoc.getLine(); unsigned colNo = fileLoc.getColumn(); @@ -32,7 +32,7 @@ llvm::Constant * mlir::LLVM::createMappingInformation(Location loc, llvm::OpenMPIRBuilder &builder) { uint32_t strLen; - if (auto nameLoc = loc.dyn_cast()) { + if (auto nameLoc = dyn_cast(loc)) { StringRef name = nameLoc.getName(); return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name, strLen); diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp index 826bac9..5ab7028 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp @@ -109,7 +109,7 @@ public: auto func = dyn_cast(op); if (!func) return failure(); - auto value = attribute.getValue().dyn_cast(); + auto value = dyn_cast(attribute.getValue()); if (!value) return failure(); @@ -125,7 +125,7 @@ public: auto func = dyn_cast(op); if (!func) return failure(); - auto value = attribute.getValue().dyn_cast(); + auto value = dyn_cast(attribute.getValue()); if (!value) return failure(); @@ -142,7 +142,7 @@ public: auto func = dyn_cast(op); if (!func) return failure(); - auto value = attribute.getValue().dyn_cast(); + auto value = dyn_cast(attribute.getValue()); if (!value) return failure(); llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp index 8e6906a..d3c5bb4 100644 --- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp @@ -190,7 +190,7 @@ void LoopAnnotationConversion::convertLoopOptions(LoopUnswitchAttr options) { void LoopAnnotationConversion::convertLocation(FusedLoc location) { auto localScopeAttr = - location.getMetadata().dyn_cast_or_null(); + dyn_cast_or_null(location.getMetadata()); if (!localScopeAttr) return; auto *localScope = dyn_cast( diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 
ba37893..f4ea801 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -623,7 +623,7 @@ void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
 
 /// Returns if `type` is a scalar integer or floating-point type.
 static bool isScalarType(Type type) {
-  return type.isa<IntegerType, FloatType>();
+  return isa<IntegerType, FloatType>(type);
 }
 
 /// Returns `type` if it is a builtin integer or floating-point vector type that
@@ -970,7 +970,7 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
   // Convert constants that can be represented as attributes.
   if (Attribute attr = getConstantAsAttr(constant)) {
     Type type = convertType(constant->getType());
-    if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>()) {
+    if (auto symbolRef = dyn_cast<FlatSymbolRefAttr>(attr)) {
       return builder.create<AddressOfOp>(loc, type, symbolRef.getValue())
           .getResult();
     }
@@ -1047,7 +1047,7 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
 
   // Generate an UndefOp as root value and insert the aggregate elements.
   Type rootType = convertType(constant->getType());
-  bool isArrayOrStruct = rootType.isa<LLVMArrayType, LLVMStructType>();
+  bool isArrayOrStruct = isa<LLVMArrayType, LLVMStructType>(rootType);
   assert((isArrayOrStruct || LLVM::isCompatibleVectorType(rootType)) &&
          "unrecognized aggregate type");
   Value root = builder.create<UndefOp>(loc, rootType);
@@ -1609,7 +1609,7 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
   clearBlockAndValueMapping();
 
   auto functionType =
-      convertType(func->getFunctionType()).dyn_cast<LLVMFunctionType>();
+      dyn_cast<LLVMFunctionType>(convertType(func->getFunctionType()));
   if (func->isIntrinsic() &&
       iface.isConvertibleIntrinsic(func->getIntrinsicID()))
     return success();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index f8854d7..9d796c5 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -73,7 +73,7 @@ translateDataLayout(DataLayoutSpecInterface attribute,
     if (!key)
       continue;
     if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
-      auto value = entry.getValue().cast<StringAttr>();
+      auto value = cast<StringAttr>(entry.getValue());
       bool isLittleEndian =
           value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle;
       layoutStream << "-" << (isLittleEndian ? "e" : "E");
@@ -81,7 +81,7 @@ translateDataLayout(DataLayoutSpecInterface attribute,
       continue;
     }
     if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) {
-      auto value = entry.getValue().cast<IntegerAttr>();
+      auto value = cast<IntegerAttr>(entry.getValue());
       uint64_t space = value.getValue().getZExtValue();
       // Skip the default address space.
       if (space == 0)
@@ -91,7 +91,7 @@ translateDataLayout(DataLayoutSpecInterface attribute,
       continue;
     }
     if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) {
-      auto value = entry.getValue().cast<IntegerAttr>();
+      auto value = cast<IntegerAttr>(entry.getValue());
       uint64_t alignment = value.getValue().getZExtValue();
       // Skip the default stack alignment.
       if (alignment == 0)
@@ -112,14 +112,14 @@ translateDataLayout(DataLayoutSpecInterface attribute,
     if (!type)
       continue;
     // Data layout for the index type is irrelevant at this point.
-    if (type.isa<IndexType>())
+    if (isa<IndexType>(type))
       continue;
     layoutStream << "-";
     LogicalResult result =
         llvm::TypeSwitch<Type, LogicalResult>(type)
            .Case([&](Type type) -> LogicalResult {
-              if (auto intType = type.dyn_cast<IntegerType>()) {
+              if (auto intType = dyn_cast<IntegerType>(type)) {
                 if (intType.getSignedness() != IntegerType::Signless)
                   return emitError(*loc)
                          << "unsupported data layout for non-signless integer "
@@ -250,7 +250,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
 
   // Compute the shape of all dimensions but the innermost. Note that the
   // innermost dimension may be that of the vector element type.
-  bool hasVectorElementType = type.getElementType().isa<VectorType>();
+  bool hasVectorElementType = isa<VectorType>(type.getElementType());
   unsigned numAggregates =
       denseElementsAttr.getNumElements() /
       (hasVectorElementType ? 1
@@ -261,7 +261,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
 
   // Handle the case of vector splat, LLVM has special support for it.
   if (denseElementsAttr.isSplat() &&
-      (type.isa<VectorType>() || hasVectorElementType)) {
+      (isa<VectorType>(type) || hasVectorElementType)) {
     llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
         innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
         moduleTranslation);
@@ -277,8 +277,8 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
   // In case of non-splat, create a constructor for the innermost constant from
   // a piece of raw data.
   std::function<llvm::Constant *(StringRef)> buildCstData;
-  if (type.isa<TensorType>()) {
-    auto vectorElementType = type.getElementType().dyn_cast<VectorType>();
+  if (isa<TensorType>(type)) {
+    auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
     if (vectorElementType && vectorElementType.getRank() == 1) {
       buildCstData = [&](StringRef data) {
         return llvm::ConstantDataVector::getRaw(
@@ -290,7 +290,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
             innermostLLVMType);
       };
     }
-  } else if (type.isa<VectorType>()) {
+  } else if (isa<VectorType>(type)) {
     buildCstData = [&](StringRef data) {
       return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
                                               innermostLLVMType);
@@ -326,7 +326,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
   if (!attr)
     return llvm::UndefValue::get(llvmType);
   if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
-    auto arrayAttr = attr.dyn_cast<ArrayAttr>();
+    auto arrayAttr = dyn_cast<ArrayAttr>(attr);
     if (!arrayAttr || arrayAttr.size() != 2) {
       emitError(loc, "expected struct type to be a complex number");
       return nullptr;
@@ -344,11 +344,11 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
   }
   // For integer types, we allow a mismatch in sizes as the index type in
   // MLIR might have a different size than the index type in the LLVM module.
-  if (auto intAttr = attr.dyn_cast<IntegerAttr>())
+  if (auto intAttr = dyn_cast<IntegerAttr>(attr))
     return llvm::ConstantInt::get(
         llvmType,
         intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
-  if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
+  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
     if (llvmType !=
         llvm::Type::getFloatingPointTy(llvmType->getContext(),
                                        floatAttr.getValue().getSemantics())) {
@@ -357,10 +357,10 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     }
     return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
   }
-  if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>())
+  if (auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr))
     return llvm::ConstantExpr::getBitCast(
         moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType);
-  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+  if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) {
     llvm::Type *elementType;
     uint64_t numElements;
     bool isScalable = false;
@@ -401,13 +401,13 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
 
   // Try using raw elements data if possible.
   if (llvm::Constant *result =
-          convertDenseElementsAttr(loc, attr.dyn_cast<DenseElementsAttr>(),
+          convertDenseElementsAttr(loc, dyn_cast<DenseElementsAttr>(attr),
                                    llvmType, moduleTranslation)) {
     return result;
   }
 
   // Fall back to element-by-element construction otherwise.
-  if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
+  if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
     assert(elementsAttr.getShapedType().hasStaticShape());
     assert(!elementsAttr.getShapedType().getShape().empty() &&
            "unexpected empty elements attribute shape");
@@ -428,7 +428,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     return result;
   }
 
-  if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
+  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
     return llvm::ConstantDataArray::get(
         moduleTranslation.getLLVMContext(),
         ArrayRef<char>{stringAttr.getValue().data(),
@@ -685,7 +685,7 @@ LogicalResult ModuleTranslation::convertGlobals() {
     if (op.getValueOrNull()) {
      // String attributes are treated separately because they cannot appear as
       // in-function constants and are thus not supported by getLLVMConstant.
-      if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
+      if (auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) {
        cst = llvm::ConstantDataArray::getString(
             llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
         type = cst->getType();
@@ -763,11 +763,10 @@ LogicalResult ModuleTranslation::convertGlobals() {
         ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
     for (auto symbolAndPriority : range) {
       llvm::Function *f = lookupFunction(
-          std::get<0>(symbolAndPriority).cast<FlatSymbolRefAttr>().getValue());
-      appendGlobalFn(
-          *llvmModule, f,
-          std::get<1>(symbolAndPriority).cast<IntegerAttr>().getInt(),
-          /*Data=*/nullptr);
+          cast<FlatSymbolRefAttr>(std::get<0>(symbolAndPriority)).getValue());
+      appendGlobalFn(*llvmModule, f,
+                     cast<IntegerAttr>(std::get<1>(symbolAndPriority)).getInt(),
+                     /*Data=*/nullptr);
     }
   }
 
@@ -830,20 +829,20 @@ forwardPassthroughAttributes(Location loc,
                              std::optional<ArrayAttr> attributes,
     return success();
 
   for (Attribute attr : *attributes) {
-    if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
+    if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
      if (failed(
               checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue())))
         return failure();
       continue;
     }
 
-    auto arrayAttr = attr.dyn_cast<ArrayAttr>();
+    auto arrayAttr = dyn_cast<ArrayAttr>(attr);
     if (!arrayAttr || arrayAttr.size() != 2)
       return emitError(loc)
              << "expected 'passthrough' to contain string or array attributes";
 
-    auto keyAttr = arrayAttr[0].dyn_cast<StringAttr>();
-    auto valueAttr = arrayAttr[1].dyn_cast<StringAttr>();
+    auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
+    auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
     if (!keyAttr || !valueAttr)
       return emitError(loc)
             << "expected arrays within 'passthrough' to contain two strings";
@@ -985,7 +984,7 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
 
     // Convert result attributes.
     if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
-      DictionaryAttr resultAttrs = allResultAttrs[0].cast<DictionaryAttr>();
+      DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
       llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
     }
 
@@ -1133,7 +1132,7 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
     return;
   }
 
-  SymbolRefAttr tagRef = tagRefs[0].cast<SymbolRefAttr>();
+  SymbolRefAttr tagRef = cast<SymbolRefAttr>(tagRefs[0]);
   llvm::MDNode *node = getTBAANode(op, tagRef);
   inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
 }
@@ -1192,7 +1191,7 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
       // The type references are in 1, 3, 5, etc. positions.
       unsigned opNum = 1;
       for (Attribute typeAttr : tdOp.getMembers()) {
-        refNames.push_back(typeAttr.cast<FlatSymbolRefAttr>().getValue());
+        refNames.push_back(cast<FlatSymbolRefAttr>(typeAttr).getValue());
         operandIndices.push_back(opNum);
         opNum += 2;
       }
@@ -1299,7 +1298,7 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
   auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
   if (auto dataLayoutAttr =
           m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
-    llvmModule->setDataLayout(dataLayoutAttr.cast<StringAttr>().getValue());
+    llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue());
   } else {
     FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout(""));
     if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) {
@@ -1319,7 +1318,7 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
   }
   if (auto targetTripleAttr =
           m->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
-    llvmModule->setTargetTriple(targetTripleAttr.cast<StringAttr>().getValue());
+    llvmModule->setTargetTriple(cast<StringAttr>(targetTripleAttr).getValue());
 
   // Inject declarations for `malloc` and `free` functions that can be used in
   // memref allocation/deallocation coming from standard ops lowering.
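The attribute rewrites above are purely mechanical. As a rough sketch of the pattern, under the free-function casting support this patch relies on (function names here are invented for illustration, not taken from the patch):

```cpp
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;

// Before: attr.cast<StringAttr>() -- after: cast<StringAttr>(attr).
// Semantics are unchanged: cast asserts on a type mismatch.
StringRef keyOf(Attribute attr) { return cast<StringAttr>(attr).getValue(); }

// Before: attr.dyn_cast<IntegerAttr>() -- after: dyn_cast<IntegerAttr>(attr).
// dyn_cast still yields a null attribute on mismatch.
int64_t intOrZero(Attribute attr) {
  if (auto intAttr = dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt();
  return 0;
}
```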
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 4c3713f..1724808 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -364,11 +364,11 @@ spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
   }
 
   Type fnType = getType(operands[3]);
-  if (!fnType || !fnType.isa<FunctionType>()) {
+  if (!fnType || !isa<FunctionType>(fnType)) {
     return emitError(unknownLoc, "unknown function type from ")
            << operands[3];
   }
-  auto functionType = fnType.cast<FunctionType>();
+  auto functionType = cast<FunctionType>(fnType);
 
   if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
       (functionType.getNumResults() == 1 &&
@@ -562,7 +562,7 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
     return emitError(unknownLoc, "unknown result type : ")
            << operands[wordIndex];
   }
-  auto ptrType = type.dyn_cast<spirv::PointerType>();
+  auto ptrType = dyn_cast<spirv::PointerType>(type);
   if (!ptrType) {
     return emitError(unknownLoc,
                      "expected a result type to be a spirv.ptr, found : ")
@@ -623,7 +623,7 @@ IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
   if (!constInfo) {
     return nullptr;
   }
-  return constInfo->first.dyn_cast<IntegerAttr>();
+  return dyn_cast<IntegerAttr>(constInfo->first);
 }
 
 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
@@ -825,7 +825,7 @@ spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
            << operands[2] << "can only come from normal constant right now";
   }
 
-  if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
+  if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
     count = intVal.getValue().getZExtValue();
   } else {
     return emitError(unknownLoc, "OpTypeArray count must come from a "
@@ -1172,7 +1172,7 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
 
   auto resultID = operands[1];
 
-  if (auto intType = resultType.dyn_cast<IntegerType>()) {
+  if (auto intType = dyn_cast<IntegerType>(resultType)) {
     auto bitwidth = intType.getWidth();
     if (failed(checkOperandSizeForBitwidth(bitwidth))) {
       return failure();
@@ -1205,7 +1205,7 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
     return success();
   }
 
-  if (auto floatType = resultType.dyn_cast<FloatType>()) {
+  if (auto floatType = dyn_cast<FloatType>(resultType)) {
     auto bitwidth = floatType.getWidth();
     if (failed(checkOperandSizeForBitwidth(bitwidth))) {
       return failure();
@@ -1295,12 +1295,12 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   }
 
   auto resultID = operands[1];
-  if (auto vectorType = resultType.dyn_cast<VectorType>()) {
+  if (auto vectorType = dyn_cast<VectorType>(resultType)) {
     auto attr = DenseElementsAttr::get(vectorType, elements);
     // For normal constants, we just record the attribute (and its type) for
     // later materialization at use sites.
     constantMap.try_emplace(resultID, attr, resultType);
-  } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
+  } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
     auto attr = opBuilder.getArrayAttr(elements);
     constantMap.try_emplace(resultID, attr, resultType);
   } else {
@@ -1444,7 +1444,7 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   }
 
   auto resultID = operands[1];
-  if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) {
+  if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
     auto attr = opBuilder.getZeroAttr(resultType);
     // For normal constants, we just record the attribute (and its type) for
    // later materialization at use sites.
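Several hunks above keep the `_or_null` variant rather than plain `dyn_cast`; the distinction matters because accessors like `getAttr` can return a null handle. A hedged sketch with an invented helper name:

```cpp
#include <optional>

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Invented helper: op->getAttr(name) may return a null Attribute, so
// dyn_cast_or_null is required here; plain dyn_cast asserts on null input.
std::optional<int64_t> getIntAttr(Operation *op, StringRef name) {
  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(op->getAttr(name)))
    return intAttr.getInt();
  return std::nullopt;
}
```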
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 487b667..613e4f6 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -244,7 +244,7 @@ private: Type getUndefType(uint32_t id) { return undefMap.lookup(id); } /// Returns true if the given `type` is for SPIR-V void type. - bool isVoidType(Type type) const { return type.isa(); } + bool isVoidType(Type type) const { return isa(type); } /// Processes a SPIR-V type instruction with given `opcode` and `operands` and /// registers the type into `module`. diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp index d863ab4..f3e8a4b 100644 --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -98,7 +98,7 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { auto constituents = op.getConstituents(); for (auto index : llvm::seq(0, constituents.size())) { - auto constituent = constituents[index].dyn_cast(); + auto constituent = dyn_cast(constituents[index]); auto constituentName = constituent.getValue(); auto constituentID = getSpecConstID(constituentName); @@ -280,7 +280,7 @@ LogicalResult Serializer::processVariableOp(spirv::VariableOp op) { auto attr = op->getAttr(spirv::attributeName()); if (attr) { operands.push_back( - static_cast(attr.cast().getValue())); + static_cast(cast(attr).getValue())); } elidedAttrs.push_back(spirv::attributeName()); for (auto arg : op.getODSOperands(0)) { @@ -491,7 +491,7 @@ LogicalResult Serializer::processBranchConditionalOp( if (auto weights = condBranchOp.getBranchWeights()) { for (auto val : weights->getValue()) - arguments.push_back(val.cast().getInt()); + arguments.push_back(cast(val).getInt()); } if (failed(emitDebugLine(functionBody, condBranchOp.getLoc()))) @@ -554,7 +554,7 @@ Serializer::processOp(spirv::EntryPointOp op) { // Add the interface values. if (auto interface = op.getInterface()) { for (auto var : interface.getValue()) { - auto id = getVariableID(var.cast().getValue()); + auto id = getVariableID(cast(var).getValue()); if (!id) { return op.emitError( "referencing undefined global variable." 
@@ -617,7 +617,7 @@ Serializer::processOp(spirv::FunctionCallOp op) { operands.push_back(valueID); } - if (!resultTy.isa()) + if (!isa(resultTy)) valueIDMap[op.getResult(0)] = funcCallID; encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands); @@ -638,28 +638,28 @@ Serializer::processOp(spirv::CopyMemoryOp op) { if (auto attr = op->getAttr("memory_access")) { operands.push_back( - static_cast(attr.cast().getValue())); + static_cast(cast(attr).getValue())); } elidedAttrs.push_back("memory_access"); if (auto attr = op->getAttr("alignment")) { operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); + cast(attr).getValue().getZExtValue())); } elidedAttrs.push_back("alignment"); if (auto attr = op->getAttr("source_memory_access")) { operands.push_back( - static_cast(attr.cast().getValue())); + static_cast(cast(attr).getValue())); } elidedAttrs.push_back("source_memory_access"); if (auto attr = op->getAttr("source_alignment")) { operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); + cast(attr).getValue().getZExtValue())); } elidedAttrs.push_back("source_alignment"); @@ -689,7 +689,7 @@ LogicalResult Serializer::processOp( for (Value operand : op->getOperands()) operands.push_back(getValueID(operand)); spirv::StorageClass resultStorage = - resultTy.cast().getStorageClass(); + cast(resultTy).getStorageClass(); operands.push_back(static_cast(resultStorage)); encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit, operands); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 292ed97..b1c5dfd 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -144,7 +144,7 @@ void Serializer::printValueIDMap(raw_ostream &os) { << "id = " << valueIDPair.second << ' '; if (auto *op = val.getDefiningOp()) { os << "from op '" << op->getName() << "'"; - } else if (auto arg = val.dyn_cast()) { + } else if (auto arg = dyn_cast(val)) { Block *block = arg.getOwner(); os << "from argument of block " << block << ' '; os << " in op '" << block->getParentOp()->getName() << "'"; @@ -176,7 +176,7 @@ void Serializer::processCapability() { void Serializer::processDebugInfo() { if (!options.emitDebugInfo) return; - auto fileLoc = module.getLoc().dyn_cast(); + auto fileLoc = dyn_cast(module.getLoc()); auto fileName = fileLoc ? 
fileLoc.getFilename().strref() : ""; fileID = getNextID(); SmallVector operands; @@ -221,13 +221,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, case spirv::Decoration::Binding: case spirv::Decoration::DescriptorSet: case spirv::Decoration::Location: - if (auto intAttr = attr.getValue().dyn_cast()) { + if (auto intAttr = dyn_cast(attr.getValue())) { args.push_back(intAttr.getValue().getZExtValue()); break; } return emitError(loc, "expected integer attribute for ") << attrName; case spirv::Decoration::BuiltIn: - if (auto strAttr = attr.getValue().dyn_cast()) { + if (auto strAttr = dyn_cast(attr.getValue())) { auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); if (enumVal) { args.push_back(static_cast(*enumVal)); @@ -245,7 +245,7 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, case spirv::Decoration::Restrict: case spirv::Decoration::RelaxedPrecision: // For unit attributes, the args list has no values so we do nothing - if (auto unitAttr = attr.getValue().dyn_cast()) + if (auto unitAttr = dyn_cast(attr.getValue())) break; return emitError(loc, "expected unit attribute for ") << attrName; default: @@ -307,13 +307,13 @@ LogicalResult Serializer::processMemberDecoration( // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and // PushConstant Storage Classes must be explicitly laid out." bool Serializer::isInterfaceStructPtrType(Type type) const { - if (auto ptrType = type.dyn_cast()) { + if (auto ptrType = dyn_cast(type)) { switch (ptrType.getStorageClass()) { case spirv::StorageClass::PhysicalStorageBuffer: case spirv::StorageClass::PushConstant: case spirv::StorageClass::StorageBuffer: case spirv::StorageClass::Uniform: - return ptrType.getPointeeType().isa(); + return isa(ptrType.getPointeeType()); default: break; } @@ -343,8 +343,8 @@ Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, auto typeEnum = spirv::Opcode::OpTypeVoid; bool deferSerialization = false; - if ((type.isa() && - succeeded(prepareFunctionType(loc, type.cast(), typeEnum, + if ((isa(type) && + succeeded(prepareFunctionType(loc, cast(type), typeEnum, operands))) || succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, deferSerialization, serializationCtx))) { @@ -390,7 +390,7 @@ LogicalResult Serializer::prepareBasicType( return success(); } - if (auto intType = type.dyn_cast()) { + if (auto intType = dyn_cast(type)) { if (intType.getWidth() == 1) { typeEnum = spirv::Opcode::OpTypeBool; return success(); @@ -406,13 +406,13 @@ LogicalResult Serializer::prepareBasicType( return success(); } - if (auto floatType = type.dyn_cast()) { + if (auto floatType = dyn_cast(type)) { typeEnum = spirv::Opcode::OpTypeFloat; operands.push_back(floatType.getWidth()); return success(); } - if (auto vectorType = type.dyn_cast()) { + if (auto vectorType = dyn_cast(type)) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, serializationCtx))) { @@ -424,7 +424,7 @@ LogicalResult Serializer::prepareBasicType( return success(); } - if (auto imageType = type.dyn_cast()) { + if (auto imageType = dyn_cast(type)) { typeEnum = spirv::Opcode::OpTypeImage; uint32_t sampledTypeID = 0; if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) @@ -440,7 +440,7 @@ LogicalResult Serializer::prepareBasicType( return success(); } - if (auto arrayType = type.dyn_cast()) { + if (auto arrayType = dyn_cast(type)) { typeEnum = spirv::Opcode::OpTypeArray; uint32_t 
elementTypeID = 0; if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, @@ -455,10 +455,10 @@ LogicalResult Serializer::prepareBasicType( return processTypeDecoration(loc, arrayType, resultID); } - if (auto ptrType = type.dyn_cast()) { + if (auto ptrType = dyn_cast(type)) { uint32_t pointeeTypeID = 0; spirv::StructType pointeeStruct = - ptrType.getPointeeType().dyn_cast(); + dyn_cast(ptrType.getPointeeType()); if (pointeeStruct && pointeeStruct.isIdentified() && serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { @@ -510,7 +510,7 @@ LogicalResult Serializer::prepareBasicType( return success(); } - if (auto runtimeArrayType = type.dyn_cast()) { + if (auto runtimeArrayType = dyn_cast(type)) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), elementTypeID, serializationCtx))) { @@ -521,7 +521,7 @@ LogicalResult Serializer::prepareBasicType( return processTypeDecoration(loc, runtimeArrayType, resultID); } - if (auto sampledImageType = type.dyn_cast()) { + if (auto sampledImageType = dyn_cast(type)) { typeEnum = spirv::Opcode::OpTypeSampledImage; uint32_t imageTypeID = 0; if (failed( @@ -532,7 +532,7 @@ LogicalResult Serializer::prepareBasicType( return success(); } - if (auto structType = type.dyn_cast()) { + if (auto structType = dyn_cast(type)) { if (structType.isIdentified()) { if (failed(processName(resultID, structType.getIdentifier()))) return failure(); @@ -581,7 +581,7 @@ LogicalResult Serializer::prepareBasicType( } if (auto cooperativeMatrixType = - type.dyn_cast()) { + dyn_cast(type)) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), elementTypeID, serializationCtx))) { @@ -600,7 +600,7 @@ LogicalResult Serializer::prepareBasicType( return success(); } - if (auto jointMatrixType = type.dyn_cast()) { + if (auto jointMatrixType = dyn_cast(type)) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, jointMatrixType.getElementType(), elementTypeID, serializationCtx))) { @@ -621,7 +621,7 @@ LogicalResult Serializer::prepareBasicType( return success(); } - if (auto matrixType = type.dyn_cast()) { + if (auto matrixType = dyn_cast(type)) { uint32_t elementTypeID = 0; if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, serializationCtx))) { @@ -684,12 +684,12 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType, } uint32_t resultID = 0; - if (auto attr = valueAttr.dyn_cast()) { - int rank = attr.getType().dyn_cast().getRank(); + if (auto attr = dyn_cast(valueAttr)) { + int rank = dyn_cast(attr.getType()).getRank(); SmallVector index(rank); resultID = prepareDenseElementsConstant(loc, constType, attr, /*dim=*/0, index); - } else if (auto arrayAttr = valueAttr.dyn_cast()) { + } else if (auto arrayAttr = dyn_cast(valueAttr)) { resultID = prepareArrayConstant(loc, constType, arrayAttr); } @@ -712,7 +712,7 @@ uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, uint32_t resultID = getNextID(); SmallVector operands = {typeID, resultID}; operands.reserve(attr.size() + 2); - auto elementType = constType.cast().getElementType(); + auto elementType = cast(constType).getElementType(); for (Attribute elementAttr : attr) { if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { operands.push_back(elementID); @@ -732,16 +732,16 @@ uint32_t Serializer::prepareDenseElementsConstant(Location loc, Type constType, DenseElementsAttr valueAttr, int dim, MutableArrayRef index) { - auto shapedType = 
valueAttr.getType().dyn_cast(); + auto shapedType = dyn_cast(valueAttr.getType()); assert(dim <= shapedType.getRank()); if (shapedType.getRank() == dim) { - if (auto attr = valueAttr.dyn_cast()) { + if (auto attr = dyn_cast(valueAttr)) { return attr.getType().getElementType().isInteger(1) ? prepareConstantBool(loc, attr.getValues()[index]) : prepareConstantInt(loc, attr.getValues()[index]); } - if (auto attr = valueAttr.dyn_cast()) { + if (auto attr = dyn_cast(valueAttr)) { return prepareConstantFp(loc, attr.getValues()[index]); } return 0; @@ -755,7 +755,7 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, uint32_t resultID = getNextID(); SmallVector operands = {typeID, resultID}; operands.reserve(shapedType.getDimSize(dim) + 2); - auto elementType = constType.cast().getElementType(0); + auto elementType = cast(constType).getElementType(0); for (int i = 0; i < shapedType.getDimSize(dim); ++i) { index[dim] = i; if (auto elementID = prepareDenseElementsConstant( @@ -773,13 +773,13 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, bool isSpec) { - if (auto floatAttr = valueAttr.dyn_cast()) { + if (auto floatAttr = dyn_cast(valueAttr)) { return prepareConstantFp(loc, floatAttr, isSpec); } - if (auto boolAttr = valueAttr.dyn_cast()) { + if (auto boolAttr = dyn_cast(valueAttr)) { return prepareConstantBool(loc, boolAttr, isSpec); } - if (auto intAttr = valueAttr.dyn_cast()) { + if (auto intAttr = dyn_cast(valueAttr)) { return prepareConstantInt(loc, intAttr, isSpec); } @@ -797,8 +797,7 @@ uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, // Process the type for this bool literal uint32_t typeID = 0; - if (failed( - processType(loc, boolAttr.cast().getType(), typeID))) { + if (failed(processType(loc, cast(boolAttr).getType(), typeID))) { return 0; } @@ -1246,7 +1245,7 @@ LogicalResult Serializer::emitDebugLine(SmallVectorImpl &binary, return success(); } - auto fileLoc = loc.dyn_cast(); + auto fileLoc = dyn_cast(loc); if (fileLoc) encodeInstructionInto(binary, spirv::Opcode::OpLine, {fileID, fileLoc.getLine(), fileLoc.getColumn()}); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h index ab9b901..4b2ebf6 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -156,7 +156,7 @@ private: Type getVoidType() { return mlirBuilder.getNoneType(); } - bool isVoidType(Type type) const { return type.isa(); } + bool isVoidType(Type type) const { return isa(type); } /// Returns true if the given type is a pointer type to a struct in some /// interface storage class. diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp index 8225680..00b6816 100644 --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -239,7 +239,7 @@ void CodeGen::genImpl(const ast::ReplaceStmt *stmt) { // replacement values. bool usesReplOperation = replValues.size() == 1 && - replValues.front().getType().isa(); + isa(replValues.front().getType()); builder.create( loc, rootExpr, usesReplOperation ? replValues[0] : Value(), usesReplOperation ? 
 ValueRange() : ValueRange(replValues));
@@ -441,7 +441,7 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
   if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
     if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
       Type mlirType = genType(expr->getType());
-      if (mlirType.isa<pdl::ValueType>())
+      if (isa<pdl::ValueType>(mlirType))
         return builder.create<pdl_interp::GetResultOp>(
             loc, mlirType, parentExprs[0], builder.getI32IntegerAttr(0));
       return builder.create<pdl_interp::GetResultsOp>(loc, mlirType,
                                                       parentExprs[0]);
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index 5e4dc07..7278aba 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -58,7 +58,7 @@ getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
                    StringRef uriScheme, const lsp::URIForFile *uri = nullptr) {
   std::optional<lsp::Location> location;
   loc->walk([&](Location nestedLoc) {
-    FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
+    FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
     if (!fileLoc)
       return WalkResult::advance();
 
@@ -91,7 +91,7 @@ static void collectLocationsFromLoc(Location loc,
                                     const lsp::URIForFile &uri) {
   SetVector<Location> visitedLocs;
   loc->walk([&](Location nestedLoc) {
-    FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
+    FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
     if (!fileLoc || !visitedLocs.insert(nestedLoc))
       return WalkResult::advance();
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index e98cccc..2884295 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -136,7 +136,7 @@ void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
   // If the existing operation has an unknown location and the current
   // operation doesn't, then set the existing op's location to that of the
   // current op.
-  if (existing->getLoc().isa<UnknownLoc>() && !op->getLoc().isa<UnknownLoc>())
+  if (isa<UnknownLoc>(existing->getLoc()) && !isa<UnknownLoc>(op->getLoc()))
     existing->setLoc(op->getLoc());
 
   ++numCSE;
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 5d6eaad..57ccb3b 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -345,7 +345,7 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
       // TODO: Support inlining nested call references.
      CallInterfaceCallable callable = call.getCallableForCallee();
       if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
-        if (!symRef.isa<FlatSymbolRefAttr>())
+        if (!isa<FlatSymbolRefAttr>(symRef))
           continue;
       }
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index a4bf97c..45d6f7d 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -99,7 +99,7 @@ MemorySlotPromoter::MemorySlotPromoter(
       info(std::move(info)) {
 #ifndef NDEBUG
   auto isResultOrNewBlockArgument = [&]() {
-    if (BlockArgument arg = slot.ptr.dyn_cast<BlockArgument>())
+    if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
       return arg.getOwner()->getParentOp() == allocator;
     return slot.ptr.getDefiningOp() == allocator;
   };
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f812d4d..615c8e4 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -401,7 +401,7 @@ static Value buildUnresolvedTargetMaterialization(
     SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
   Block *insertBlock = input.getParentBlock();
   Block::iterator insertPt = insertBlock->begin();
-  if (OpResult inputRes = input.dyn_cast<OpResult>())
+  if (OpResult inputRes = dyn_cast<OpResult>(input))
     insertPt = ++inputRes.getOwner()->getIterator();
 
   return buildUnresolvedMaterialization(
@@ -1033,7 +1033,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
       if (!repl)
         continue;
 
-      if (repl.isa<BlockArgument>()) {
+      if (isa<BlockArgument>(repl)) {
         arg.replaceAllUsesWith(repl);
         continue;
       }
@@ -1041,7 +1041,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
       // If the replacement value is an operation, we check to make sure that we
       // don't replace uses that are within the parent operation of the
       // replacement value.
-      Operation *replOp = repl.cast<OpResult>().getOwner();
+      Operation *replOp = cast<OpResult>(repl).getOwner();
       Block *replBlock = replOp->getBlock();
       arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
         Operation *user = operand.getOwner();
@@ -2615,7 +2615,7 @@ static void computeNecessaryMaterializations(
     }
 
     // Check to see if this is an argument materialization.
-    auto isBlockArg = [](Value v) { return v.isa<BlockArgument>(); };
+    auto isBlockArg = [](Value v) { return isa<BlockArgument>(v); };
     if (llvm::any_of(op->getOperands(), isBlockArg) ||
         llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
       mat->setKind(UnresolvedMaterialization::Argument);
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 13e7fa8..4490851 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -384,7 +384,7 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
       assert(castKind == getCastKindName(CastKind::Argument) &&
             "unexpected value of cast kind attribute");
       assert(llvm::all_of(operands,
-                          [&](Value v) { return v.isa<BlockArgument>(); }));
+                          [&](Value v) { return isa<BlockArgument>(v); }));
       maybeResult = typeConverter.materializeArgumentConversion(
           rewriter, castOp->getLoc(), resultTypes.front(),
           castOp.getOperands());
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 2933758..b95af9c 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -244,17 +244,17 @@ public:
   bool wasProvenLive(Value value) {
     // TODO: For results that are removable, e.g. for region based control flow,
     // we could allow for these values to be tracked independently.
-    if (OpResult result = value.dyn_cast<OpResult>())
+    if (OpResult result = dyn_cast<OpResult>(value))
       return wasProvenLive(result.getOwner());
-    return wasProvenLive(value.cast<BlockArgument>());
+    return wasProvenLive(cast<BlockArgument>(value));
   }
   bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
   void setProvedLive(Value value) {
     // TODO: For results that are removable, e.g. for region based control flow,
     // we could allow for these values to be tracked independently.
-    if (OpResult result = value.dyn_cast<OpResult>())
+    if (OpResult result = dyn_cast<OpResult>(value))
       return setProvedLive(result.getOwner());
-    setProvedLive(value.cast<BlockArgument>());
+    setProvedLive(cast<BlockArgument>(value));
   }
   void setProvedLive(BlockArgument arg) {
     changed |= liveValues.insert(arg).second;
@@ -538,11 +538,11 @@ unsigned BlockEquivalenceData::getOrderOf(Value value) const {
   assert(value.getParentBlock() == block && "expected value of this block");
 
   // Arguments use the argument number as the order index.
-  if (BlockArgument arg = value.dyn_cast<BlockArgument>())
+  if (BlockArgument arg = dyn_cast<BlockArgument>(value))
     return arg.getArgNumber();
 
   // Otherwise, the result order is offset from the parent op's order.
-  OpResult result = value.cast<OpResult>();
+  OpResult result = cast<OpResult>(value);
   auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
   assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
   return opOrderIt->second + result.getResultNumber();
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 4598b56..def8a14 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -145,13 +145,13 @@ private:
     int64_t largeAttrLimit = getLargeAttributeSizeLimit();
 
     // Always emit splat attributes.
-    if (attr.isa<SplatElementsAttr>()) {
+    if (isa<SplatElementsAttr>(attr)) {
       attr.print(os);
       return;
     }
 
     // Elide "big" elements attributes.
-    auto elements = attr.dyn_cast<ElementsAttr>();
+    auto elements = dyn_cast<ElementsAttr>(attr);
     if (elements && elements.getNumElements() > largeAttrLimit) {
       os << std::string(elements.getShapedType().getRank(), '[') << "..."
          << std::string(elements.getShapedType().getRank(), ']') << " : "
@@ -159,7 +159,7 @@ private:
       return;
     }
 
-    auto array = attr.dyn_cast<ArrayAttr>();
+    auto array = dyn_cast<ArrayAttr>(attr);
     if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
       os << "[...]";
       return;
diff --git a/mlir/test/lib/Analysis/TestAliasAnalysis.cpp b/mlir/test/lib/Analysis/TestAliasAnalysis.cpp
index b563be4..95c16a6 100644
--- a/mlir/test/lib/Analysis/TestAliasAnalysis.cpp
+++ b/mlir/test/lib/Analysis/TestAliasAnalysis.cpp
@@ -24,7 +24,7 @@ static void printAliasOperand(Operation *op) {
   llvm::errs() << op->getAttrOfType<StringAttr>("test.ptr").getValue();
 }
 static void printAliasOperand(Value value) {
-  if (BlockArgument arg = value.dyn_cast<BlockArgument>()) {
+  if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
     Region *region = arg.getParentRegion();
     unsigned parentBlockNumber =
         std::distance(region->begin(), arg.getOwner()->getIterator());
@@ -37,7 +37,7 @@ static void printAliasOperand(Value value) {
     llvm::errs() << "#" << arg.getArgNumber();
     return;
   }
-  OpResult result = value.cast<OpResult>();
+  OpResult result = cast<OpResult>(value);
   printAliasOperand(result.getOwner());
   llvm::errs() << "#" << result.getResultNumber();
 }
@@ -156,7 +156,7 @@ struct TestAliasAnalysisModRefPass
 
 /// Check if value is function argument.
 static bool isFuncArg(Value val) {
-  auto blockArg = val.dyn_cast<BlockArgument>();
+  auto blockArg = dyn_cast<BlockArgument>(val);
   if (!blockArg)
     return false;
 
@@ -166,7 +166,7 @@ static bool isFuncArg(Value val) {
 
 /// Check if value has "restrict" attribute. Value must be a function argument.
 static bool isRestrict(Value val) {
-  auto blockArg = val.cast<BlockArgument>();
+  auto blockArg = cast<BlockArgument>(val);
   auto func =
       mlir::cast<FunctionOpInterface>(blockArg.getOwner()->getParentOp());
 
   return !!func.getArgAttr(blockArg.getArgNumber(),
diff --git a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
index c9e72f8..968e10b 100644
--- a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
+++ b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
@@ -32,7 +32,7 @@ struct TestMemRefStrideCalculation
 void TestMemRefStrideCalculation::runOnOperation() {
   llvm::outs() << "Testing: " << getOperation().getName() << "\n";
   getOperation().walk([&](memref::AllocOp allocOp) {
-    auto memrefType = allocOp.getResult().getType().cast<MemRefType>();
+    auto memrefType = cast<MemRefType>(allocOp.getResult().getType());
     int64_t offset;
     SmallVector<int64_t, 4> strides;
     if (failed(getStridesAndOffset(memrefType, strides, offset))) {
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index c3f2098..e1ccc1b 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -102,7 +102,7 @@ public:
   matchAndRewrite(::test::GetTupleElementOp op, OpAdaptor adaptor,
                   OneToNPatternRewriter &rewriter) const override {
     // Construct mapping for tuple element types.
-    auto stateType = op->getOperand(0).getType().cast<TupleType>();
+    auto stateType = cast<TupleType>(op->getOperand(0).getType());
     TypeRange originalElementTypes = stateType.getTypes();
     OneToNTypeMapping elementMapping(originalElementTypes);
     if (failed(typeConverter->convertSignatureArgs(originalElementTypes,
@@ -148,7 +148,7 @@ static void populateDecomposeTuplesTestPatterns(TypeConverter &typeConverter,
 static std::optional<SmallVector<Value>>
 buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
                         Location loc) {
-  TupleType inputType = input.getType().dyn_cast<TupleType>();
+  TupleType inputType = dyn_cast<TupleType>(input.getType());
   if (!inputType)
     return {};
 
@@ -156,7 +156,7 @@ buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
   for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) {
     Value element = builder.create<::test::GetTupleElementOp>(
         loc, elementType, input, builder.getI32IntegerAttr(idx));
-    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+    if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
       // Recurse if the current element is also a tuple.
       SmallVector<Type> flatRecursiveTypes;
       nestedTupleType.getFlattenedTypes(flatRecursiveTypes);
@@ -186,7 +186,7 @@ static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
   elements.reserve(resultType.getTypes().size());
   ValueRange::iterator inputIt = inputs.begin();
   for (Type elementType : resultType.getTypes()) {
-    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+    if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
       // Determine how many input values are needed for the nested elements of
       // the nested TupleType and advance inputIt by that number.
       // TODO: We only need the *number* of nested types, not the types itself.
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 1bf3ce4..dff619e 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -81,7 +81,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
         return WalkResult::skip();
       }
       Value value = op->getOperand(0);
-      if (value.getType().isa<IndexType>() !=
+      if (isa<IndexType>(value.getType()) !=
           !op->hasAttrOfType<IntegerAttr>("dim")) {
         // Op should have "dim" attribute if and only if the operand is an
         // index-typed value.
@@ -119,7 +119,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
       if (reifyToFuncArgs) {
         // Reify in terms of function block arguments.
         stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
-          auto bbArg = v.dyn_cast<BlockArgument>();
+          auto bbArg = dyn_cast<BlockArgument>(v);
           if (!bbArg)
             return false;
           return isa<FunctionOpInterface>(
@@ -166,7 +166,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
         return WalkResult::skip();
       }
       Value constOp = rewriter.create<arith::ConstantIndexOp>(
-          op->getLoc(), reified->get<Attribute>().cast<IntegerAttr>().getInt());
+          op->getLoc(), cast<IntegerAttr>(reified->get<Attribute>()).getInt());
       rewriter.replaceOp(op, constOp);
       return WalkResult::skip();
     }
diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
index 85dd071..f8588fa 100644
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -127,7 +127,7 @@ void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
     // As a consequence we write only Ops with a single return type for the
     // purpose of this test. If we need to test more intricate behavior in the
     // future we can always extend.
-    auto superVectorType = opInst->getResult(0).getType().cast<VectorType>();
+    auto superVectorType = cast<VectorType>(opInst->getResult(0).getType());
     auto ratio =
         computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());
     if (!ratio) {
@@ -211,8 +211,8 @@ void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) {
   maps.reserve(matches.size());
   for (auto m : llvm::reverse(matches)) {
     auto *opInst = m.getMatchedOperation();
-    auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
-                   .cast<AffineMapAttr>()
+    auto map = cast<AffineMapAttr>(
+                   opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName))
                   .getValue();
     maps.push_back(map);
   }
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 41e1666..10aba73 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -27,7 +27,7 @@ static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
     Type elementType = resultType.getType(i);
     Value element = builder.create<test::GetTupleElementOp>(
         loc, elementType, value, builder.getI32IntegerAttr(i));
-    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+    if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
       // Recurse if the current element is also a tuple.
       if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element,
                                      values)))
@@ -50,7 +50,7 @@ static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
   elements.reserve(resultType.getTypes().size());
   ValueRange::iterator inputIt = inputs.begin();
   for (Type elementType : resultType.getTypes()) {
-    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+    if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
       // Determine how many input values are needed for the nested elements of
       // the nested TupleType and advance inputIt by that number.
       // TODO: We only need the *number* of nested types, not the types itself.
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 5050498..2231e42 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -38,9 +38,9 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
   bool changed = false;
   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
     for (OpOperand &opOperand : linalgOp->getOpOperands()) {
-      if (opOperand.get().getType().isa<MemRefType>())
+      if (isa<MemRefType>(opOperand.get().getType()))
         continue;
-      if (opOperand.get().getType().isa<RankedTensorType>()) {
+      if (isa<RankedTensorType>(opOperand.get().getType())) {
         // Tile and Fuse tensor input.
         if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs())
           continue;
diff --git a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
index 449a3e9..0f08758 100644
--- a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
+++ b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
@@ -61,9 +61,9 @@ void ReportShapeFnPass::runOnOperation() {
       if (attr) {
         auto lookup = [&](Attribute attr) {
           return cast<shape::FunctionLibraryOp>(
-              SymbolTable::lookupSymbolIn(module, attr.cast<SymbolRefAttr>()));
+              SymbolTable::lookupSymbolIn(module, cast<SymbolRefAttr>(attr)));
         };
-        if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
+        if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
           libraries.reserve(arrayAttr.size());
           for (auto attr : arrayAttr)
             libraries.push_back(lookup(attr));
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 6dc8b4a..46fe865 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -113,7 +113,7 @@ static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
     if (!op.getSource().hasOneUse())
       return false;
 
-    auto resultType = op.getResult().getType().cast<RankedTensorType>();
+    auto resultType = cast<RankedTensorType>(op.getResult().getType());
     constexpr int64_t kConstantFoldingMaxNumElements = 1024;
     return resultType.getNumElements() <= kConstantFoldingMaxNumElements;
   };
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 4660d9a..715c77b 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -49,7 +49,7 @@ Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
 }
 LogicalResult MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
                                         InFlightDiagnostic *diag) {
-  StringAttr strAttr = attr.dyn_cast<StringAttr>();
+  StringAttr strAttr = dyn_cast<StringAttr>(attr);
   if (!strAttr) {
     if (diag)
       *diag << "Expect StringAttr but got " << attr;
@@ -221,7 +221,7 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
   //===------------------------------------------------------------------===//
 
   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
-    StringAttr strAttr = attr.dyn_cast<StringAttr>();
+    StringAttr strAttr = dyn_cast<StringAttr>(attr);
     if (!strAttr)
       return AliasResult::NoAlias;
@@ -246,16 +246,16 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
   }
 
   AliasResult getAlias(Type type, raw_ostream &os) const final {
-    if (auto tupleType = type.dyn_cast<TupleType>()) {
+    if (auto tupleType = dyn_cast<TupleType>(type)) {
       if (tupleType.size() > 0 &&
           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
-            return elemType.isa<SimpleAType>();
+            return isa<SimpleAType>(elemType);
           })) {
         os << "test_tuple";
         return AliasResult::FinalAlias;
       }
     }
-    if (auto intType = type.dyn_cast<TestIntegerType>()) {
+    if (auto intType = dyn_cast<TestIntegerType>(type)) {
       if (intType.getSignedness() ==
               TestIntegerType::SignednessSemantics::Unsigned &&
           intType.getWidth() == 8) {
@@ -263,7 +263,7 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
         return AliasResult::FinalAlias;
       }
     }
-    if (auto recType = type.dyn_cast<TestRecursiveType>()) {
+    if (auto recType = dyn_cast<TestRecursiveType>(type)) {
       if (recType.getName() == "type_to_alias") {
         // We only make alias for a specific recursive type.
         os << "testrec";
@@ -1230,7 +1230,7 @@ void PolyForOp::getAsmBlockArgumentNames(Region &region,
   auto args = getRegion().front().getArguments();
   auto e = std::min(arrayAttr.size(), args.size());
   for (unsigned i = 0; i < e; ++i) {
-    if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
+    if (auto strAttr = dyn_cast<StringAttr>(arrayAttr[i]))
       setNameFn(args[i], strAttr.getValue());
   }
 }
@@ -1252,7 +1252,7 @@ static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) {
 }
 
 static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) {
-  p.printOptionalLocationSpecifier(loc.cast<LocationAttr>());
+  p.printOptionalLocationSpecifier(cast<LocationAttr>(loc));
 }
 
 //===----------------------------------------------------------------------===//
@@ -1376,7 +1376,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   // Create return type consisting of the last element of the first operand.
   auto operandType = operands.front().getType();
-  auto sval = operandType.dyn_cast<ShapedType>();
+  auto sval = dyn_cast<ShapedType>(operandType);
   if (!sval) {
     return emitOptionalError(location, "only shaped type operands allowed");
   }
@@ -1384,7 +1384,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
   auto type = IntegerType::get(context, 17);
 
   Attribute encoding;
-  if (auto rankedTy = sval.dyn_cast<RankedTensorType>())
+  if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
     encoding = rankedTy.getEncoding();
   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
   return success();
@@ -1404,7 +1404,7 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
   Location loc = getLoc();
   shapes.reserve(operands.size());
   for (Value operand : llvm::reverse(operands)) {
-    auto rank = operand.getType().cast<RankedTensorType>().getRank();
+    auto rank = cast<RankedTensorType>(operand.getType()).getRank();
     auto currShape = llvm::to_vector<4>(
         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
@@ -1421,7 +1421,7 @@ LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
   Location loc = getLoc();
   shapes.reserve(getNumOperands());
   for (Value operand : llvm::reverse(getOperands())) {
-    auto tensorType = operand.getType().cast<RankedTensorType>();
+    auto tensorType = cast<RankedTensorType>(operand.getType());
     auto currShape = llvm::to_vector<4>(llvm::map_range(
         llvm::seq<int64_t>(0, tensorType.getRank()),
         [&](int64_t dim) -> OpFoldResult {
@@ -1471,12 +1471,12 @@ void SideEffectOp::getEffects(
   // If there is one, it is an array of dictionary attributes that hold
   // information on the effects of this operation.
   for (Attribute element : effectsAttr) {
-    DictionaryAttr effectElement = element.cast<DictionaryAttr>();
+    DictionaryAttr effectElement = cast<DictionaryAttr>(element);
 
     // Get the specific memory effect.
     MemoryEffects::Effect *effect =
         StringSwitch<MemoryEffects::Effect *>(
-            effectElement.get("effect").cast<StringAttr>().getValue())
+            cast<StringAttr>(effectElement.get("effect")).getValue())
             .Case("allocate", MemoryEffects::Allocate::get())
             .Case("free", MemoryEffects::Free::get())
             .Case("read", MemoryEffects::Read::get())
@@ -1491,7 +1491,7 @@ void SideEffectOp::getEffects(
     if (effectElement.get("on_result"))
       effects.emplace_back(effect, getResult(), resource);
     else if (Attribute ref = effectElement.get("on_reference"))
-      effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
+      effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
     else
       effects.emplace_back(effect, resource);
   }
@@ -1556,7 +1556,7 @@ void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
       llvm::raw_svector_ostream tmpStream(resultNameStr);
       p.printOperand(getResult(i), tmpStream);
 
-      auto expectedName = getNames()[i].dyn_cast<StringAttr>();
+      auto expectedName = dyn_cast<StringAttr>(getNames()[i]);
       if (!expectedName ||
           tmpStream.str().drop_front() != expectedName.getValue()) {
         namesDisagree = true;
@@ -1576,7 +1576,7 @@ void StringAttrPrettyNameOp::getAsmResultNames(
   auto value = getNames();
 
   for (size_t i = 0, e = value.size(); i != e; ++i)
-    if (auto str = value[i].dyn_cast<StringAttr>())
+    if (auto str = dyn_cast<StringAttr>(value[i]))
      if (!str.getValue().empty())
         setNameFn(getResult(i), str.getValue());
 }
@@ -1585,7 +1585,7 @@ void CustomResultsNameOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   ArrayAttr value = getNames();
   for (size_t i = 0, e = value.size(); i != e; ++i)
-    if (auto str = value[i].dyn_cast<StringAttr>())
+    if (auto str = dyn_cast<StringAttr>(value[i]))
       if (!str.getValue().empty())
         setNameFn(getResult(i), str.getValue());
 }
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index adaa6e1..a61ba8e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -153,7 +153,7 @@ struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
 
   LogicalResult matchAndRewrite(AnyAttrOfOp op,
                                 PatternRewriter &rewriter) const override {
-    auto intAttr = op.getAttr().dyn_cast<IntegerAttr>();
+    auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
     if (!intAttr)
       return failure();
     int64_t val = intAttr.getInt();
@@ -1271,11 +1271,11 @@ struct TestTypeConversionProducer
     Type convertedType = getTypeConverter()
                              ? getTypeConverter()->convertType(resultType)
                              : resultType;
-    if (resultType.isa<FloatType>())
+    if (isa<FloatType>(resultType))
       resultType = rewriter.getF64Type();
     else if (resultType.isInteger(16))
       resultType = rewriter.getIntegerType(64);
-    else if (resultType.isa<test::TestRecursiveType>() &&
+    else if (isa<test::TestRecursiveType>(resultType) &&
              convertedType != resultType)
       resultType = convertedType;
     else
@@ -1430,8 +1430,8 @@ struct TestTypeConversionDriver
           inputs.empty())
         return builder.create<TestTypeProducerOp>(loc, resultType);
       // Allow producing an i64 from an integer.
-      if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
-          inputs[0].getType().isa<IntegerType>())
+      if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
+          isa<IntegerType>(inputs[0].getType()))
         return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
       // Otherwise, fail.
       return nullptr;
@@ -1440,7 +1440,7 @@ struct TestTypeConversionDriver
     // Initialize the conversion target.
     mlir::ConversionTarget target(getContext());
     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
-      auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>();
+      auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
       return op.getType().isF64() || op.getType().isInteger(64) ||
              (recursiveType &&
               recursiveType.getName() == "outer_converted_type");
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index c147ff4..9642301 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -42,20 +42,20 @@ ConvertTosaNegateOp::matchAndRewrite(Operation *op,
   auto tosaNegateOp = cast<tosa::NegateOp>(op);
 
   auto inputType =
-      tosaNegateOp.getInput1().getType().dyn_cast<mlir::RankedTensorType>();
+      dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getInput1().getType());
   // skip if input is not ranked tensor type
   if (!inputType)
     return failure();
 
   // skip if it's not ranked tensor type.
   auto outputType =
-      tosaNegateOp.getResult().getType().dyn_cast<mlir::RankedTensorType>();
+      dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getResult().getType());
   if (!outputType)
     return failure();
 
   // skip if output is not per-tensor quantized type.
   auto outputElementType =
-      outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+      dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());
   if (!outputElementType)
     return failure();
@@ -112,14 +112,14 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
   auto tosaConv2DOp = cast<tosa::Conv2DOp>(op);
 
   auto inputType =
-      tosaConv2DOp.getInput().getType().dyn_cast<mlir::RankedTensorType>();
+      dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getInput().getType());
   // skip if input is not ranked tensor type
   if (!inputType)
     return failure();
 
   auto weightType =
-      tosaConv2DOp.getWeight().getType().dyn_cast<mlir::RankedTensorType>();
+      dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getWeight().getType());
 
   // skip if wt is not ranked tensor type
   if (!weightType)
@@ -127,16 +127,16 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
 
   // skip if it's not ranked tensor type.
   auto outputType =
-      tosaConv2DOp.getResult().getType().dyn_cast<mlir::RankedTensorType>();
+      dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getResult().getType());
   if (!outputType)
     return failure();
 
   auto inputQType =
-      inputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+      dyn_cast<mlir::quant::UniformQuantizedType>(inputType.getElementType());
   auto weightQType =
-      weightType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+      dyn_cast<mlir::quant::UniformQuantizedType>(weightType.getElementType());
   auto outputQType =
-      outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+      dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());
 
   // Works on quantized type only.
   if (!(inputQType && weightQType && outputQType))
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index dd853aa..d0c79ab 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -89,7 +89,7 @@ private:
       auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
       if (!extract)
         return std::nullopt;
-      auto vecType = extract.getResult().getType().cast<VectorType>();
+      auto vecType = cast<VectorType>(extract.getResult().getType());
       if (dstVec && dstVec != vecType)
         return std::nullopt;
       dstVec = vecType;
@@ -430,7 +430,7 @@ static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
   static constexpr int64_t kSharedMemorySpace = 3;
   // Compute type of shared memory buffer.
   MemRefType memrefType;
-  if (auto vectorType = type.dyn_cast<VectorType>()) {
+  if (auto vectorType = dyn_cast<VectorType>(type)) {
     memrefType =
         MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
                         kSharedMemorySpace);
@@ -535,7 +535,7 @@ struct TestVectorDistribution
       // Create a map (d0, d1) -> (d1) to distribute along the inner
      // dimension. Once we support n-d distribution we can add more
       // complex cases.
-      VectorType vecType = val.getType().dyn_cast<VectorType>();
+      VectorType vecType = dyn_cast<VectorType>(val.getType());
       int64_t vecRank = vecType ? vecType.getRank() : 0;
       OpBuilder builder(val.getContext());
       if (vecRank == 0)
@@ -642,9 +642,9 @@ struct TestCreateVectorBroadcast
       if (op->getName().getStringRef() != "test_create_broadcast")
         return;
       auto targetShape =
-          op->getResult(0).getType().cast<VectorType>().getShape();
+          cast<VectorType>(op->getResult(0).getType()).getShape();
       auto arrayAttr =
-          op->getAttr("broadcast_dims").cast<DenseI64ArrayAttr>().asArrayRef();
+          cast<DenseI64ArrayAttr>(op->getAttr("broadcast_dims")).asArrayRef();
       llvm::SetVector<int64_t> broadcastedDims;
       broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
       OpBuilder b(op);
diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
index 9313f40..498de3d 100644
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -34,7 +34,7 @@ struct TestElementsAttrInterface
   void runOnOperation() override {
     getOperation().walk([&](Operation *op) {
       for (NamedAttribute attr : op->getAttrs()) {
-        auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
+        auto elementsAttr = dyn_cast<ElementsAttr>(attr.getValue());
         if (!elementsAttr)
           continue;
         testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
diff --git a/mlir/test/lib/IR/TestDiagnostics.cpp b/mlir/test/lib/IR/TestDiagnostics.cpp
index 1f5b29d..578486c 100644
--- a/mlir/test/lib/IR/TestDiagnostics.cpp
+++ b/mlir/test/lib/IR/TestDiagnostics.cpp
@@ -36,7 +36,7 @@ struct TestDiagnosticFilterPass
     // Build a diagnostic handler that has filtering capabilities.
     auto filterFn = [&](Location loc) {
       // Ignore non-file locations.
-      FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>();
+      FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc);
       if (!fileLoc)
         return true;
diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
index 171d46a..4589788 100644
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -35,13 +35,13 @@ struct TestFuncInsertArg
     SmallVector<Location, 4> locsToInsert;
     for (auto insert : inserts.getAsRange<ArrayAttr>()) {
       indicesToInsert.push_back(
-          insert[0].cast<IntegerAttr>().getValue().getZExtValue());
-      typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
+          cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
+      typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
       attrsToInsert.push_back(insert.size() > 2
-                                  ? insert[2].cast<DictionaryAttr>()
+                                  ? cast<DictionaryAttr>(insert[2])
                                   : DictionaryAttr::get(&getContext()));
       locsToInsert.push_back(insert.size() > 3
-                                 ? Location(insert[3].cast<LocationAttr>())
+                                 ? Location(cast<LocationAttr>(insert[3]))
                                  : unknownLoc);
     }
     func->removeAttr("test.insert_args");
@@ -72,10 +72,10 @@ struct TestFuncInsertResult
     SmallVector<DictionaryAttr, 4> attrsToInsert;
     for (auto insert : inserts.getAsRange<ArrayAttr>()) {
       indicesToInsert.push_back(
-          insert[0].cast<IntegerAttr>().getValue().getZExtValue());
-      typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
+          cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
+      typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
       attrsToInsert.push_back(insert.size() > 2
-                                  ? insert[2].cast<DictionaryAttr>()
+                                  ? cast<DictionaryAttr>(insert[2])
                                   : DictionaryAttr::get(&getContext()));
     }
     func->removeAttr("test.insert_results");
diff --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp
index 633d530..2dd3fe2 100644
--- a/mlir/test/lib/IR/TestInterfaces.cpp
+++ b/mlir/test/lib/IR/TestInterfaces.cpp
@@ -27,7 +27,7 @@ struct TestTypeInterfaces
   void runOnOperation() override {
     getOperation().walk([](Operation *op) {
       for (Type type : op->getResultTypes()) {
-        if (auto testInterface = type.dyn_cast<TestTypeInterface>()) {
+        if (auto testInterface = dyn_cast<TestTypeInterface>(type)) {
           testInterface.printTypeA(op->getLoc());
           testInterface.printTypeB(op->getLoc());
           testInterface.printTypeC(op->getLoc());
@@ -37,7 +37,7 @@ struct TestTypeInterfaces
           TestTypeInterface result = testInterface.printTypeRet(op->getLoc());
           (void)result;
         }
-        if (auto testType = type.dyn_cast<TestType>())
+        if (auto testType = dyn_cast<TestType>(type))
           testType.printTypeE(op->getLoc());
       }
     });
diff --git a/mlir/test/lib/IR/TestOpaqueLoc.cpp b/mlir/test/lib/IR/TestOpaqueLoc.cpp
index 977d2b0..c0ce896 100644
--- a/mlir/test/lib/IR/TestOpaqueLoc.cpp
+++ b/mlir/test/lib/IR/TestOpaqueLoc.cpp
@@ -74,7 +74,7 @@ struct TestOpaqueLoc
     ScopedDiagnosticHandler diagHandler(&getContext(), [](Diagnostic &diag) {
       auto &os = llvm::outs();
 
-      if (diag.getLocation().isa<OpaqueLoc>()) {
+      if (isa<OpaqueLoc>(diag.getLocation())) {
         MyLocation *loc = OpaqueLoc::getUnderlyingLocationOrNull<MyLocation *>(
             diag.getLocation());
         if (loc)
diff --git a/mlir/test/lib/IR/TestPrintDefUse.cpp b/mlir/test/lib/IR/TestPrintDefUse.cpp
index 0656036..5d489a3 100644
--- a/mlir/test/lib/IR/TestPrintDefUse.cpp
+++ b/mlir/test/lib/IR/TestPrintDefUse.cpp
@@ -34,7 +34,7 @@ struct TestPrintDefUsePass
         } else {
           // If there is no defining op, the Value is necessarily a Block
          // argument.
-          auto blockArg = operand.cast<BlockArgument>();
+          auto blockArg = cast<BlockArgument>(operand);
           llvm::outs() << "  - Operand produced by Block argument, number "
                        << blockArg.getArgNumber() << "\n";
         }
diff --git a/mlir/test/lib/Transforms/TestTopologicalSort.cpp b/mlir/test/lib/Transforms/TestTopologicalSort.cpp
index 4ad5b5c..a8cc7a5 100644
--- a/mlir/test/lib/Transforms/TestTopologicalSort.cpp
+++ b/mlir/test/lib/Transforms/TestTopologicalSort.cpp
@@ -42,7 +42,7 @@ struct TestTopologicalSortAnalysisPass
         // If the root has an "ordered" attribute, we fill the selectedOps
        // vector in a certain order.
         int64_t pos =
-            selected->getAttr("selected").cast<IntegerAttr>().getInt();
+            cast<IntegerAttr>(selected->getAttr("selected")).getInt();
         if (pos >= static_cast<int64_t>(selectedOps.size()))
           selectedOps.append(pos + 1 - selectedOps.size(), nullptr);
         selectedOps[pos] = selected;
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 52ea148..35f9015 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -317,10 +317,10 @@ struct ScalarTraits<SerializedAffineMap> {
                          SerializedAffineMap &value) {
     assert(rawYamlContext);
     auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
-    if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext)
-                        .dyn_cast_or_null<AffineMapAttr>())
+    if (auto attr = dyn_cast_or_null<AffineMapAttr>(
+            mlir::parseAttribute(scalar, yamlContext->mlirContext)))
       value.affineMapAttr = attr;
-    else if (!value.affineMapAttr || !value.affineMapAttr.isa<AffineMapAttr>())
+    else if (!value.affineMapAttr || !isa<AffineMapAttr>(value.affineMapAttr))
       return "could not parse as an affine map attribute";
     return StringRef();
   }
diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
index f390934..aa19b5c 100644
--- a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
+++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
@@ -36,18 +36,18 @@ TEST_F(LLVMIRTest, MutualReferencedSubElementTypes) {
   ASSERT_EQ(subElementTypes.size(), 4U);
 
   // !llvm.ptr<struct<"bar",...>>
-  ASSERT_TRUE(subElementTypes[0].isa<LLVMPointerType>());
+  ASSERT_TRUE(isa<LLVMPointerType>(subElementTypes[0]));
 
   // !llvm.struct<"bar",...>
-  auto structType = subElementTypes[1].dyn_cast<LLVMStructType>();
+  auto structType = dyn_cast<LLVMStructType>(subElementTypes[1]);
   ASSERT_TRUE(bool(structType));
   ASSERT_TRUE(structType.getName().equals("bar"));
 
   // !llvm.ptr<struct<"foo",...>>
-  ASSERT_TRUE(subElementTypes[2].isa<LLVMPointerType>());
+  ASSERT_TRUE(isa<LLVMPointerType>(subElementTypes[2]));
 
   // !llvm.struct<"foo",...>
-  structType = subElementTypes[3].dyn_cast<LLVMStructType>();
+  structType = dyn_cast<LLVMStructType>(subElementTypes[3]);
   ASSERT_TRUE(bool(structType));
   ASSERT_TRUE(structType.getName().equals("foo"));
 }
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 94345d00..f01cc02 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -278,7 +278,7 @@ static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
 
   // Check that we cast to this attribute when possible.
   Attribute genericAttr = attr;
-  EXPECT_TRUE(genericAttr.template isa<AttrT>());
+  EXPECT_TRUE(isa<AttrT>(genericAttr));
 }
 template <typename AttrT, typename T>
 static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
@@ -330,9 +330,9 @@ TEST(DenseResourceElementsAttrTest, CheckNoCast) {
   Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
       type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
 
-  EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>());
-  EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>());
-  EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>());
+  EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr));
+  EXPECT_FALSE(isa<DenseF32ResourceElementsAttr>(i32ResourceAttr));
+  EXPECT_FALSE(isa<DenseBoolResourceElementsAttr>(i32ResourceAttr));
 }
 
 TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
@@ -407,17 +407,17 @@ TEST(SparseElementsAttrTest, GetZero) {
 
   // Only index (0, 0) contains an element, others are supposed to return
   // the zero/empty value.
   auto zeroIntValue =
-      sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>();
+      cast<IntegerAttr>(sparseInt.getValues<Attribute>()[{1, 1}]);
   EXPECT_EQ(zeroIntValue.getInt(), 0);
   EXPECT_TRUE(zeroIntValue.getType() == intTy);
 
   auto zeroFloatValue =
-      sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>();
+      cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]);
   EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
   EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
 
   auto zeroStringValue =
-      sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>();
+      cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]);
   EXPECT_TRUE(zeroStringValue.getValue().empty());
   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
 }
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index d5e19d2..fe85516 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -61,11 +61,11 @@ TEST(InterfaceAttachment, Type) {
 
   // Check that the type has no interface.
   IntegerType i8 = IntegerType::get(&context, 8);
-  ASSERT_FALSE(i8.isa<TestExternalTypeInterface>());
+  ASSERT_FALSE(isa<TestExternalTypeInterface>(i8));
 
   // Attach an interface and check that the type now has the interface.
   IntegerType::attachInterface<Model>(context);
-  TestExternalTypeInterface iface = i8.dyn_cast<TestExternalTypeInterface>();
+  TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8);
   ASSERT_TRUE(iface != nullptr);
   EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
   EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
@@ -74,9 +74,9 @@ TEST(InterfaceAttachment, Type) {
 
   // Same, but with the default implementation overridden.
   FloatType flt = Float32Type::get(&context);
-  ASSERT_FALSE(flt.isa<TestExternalTypeInterface>());
+  ASSERT_FALSE(isa<TestExternalTypeInterface>(flt));
   Float32Type::attachInterface<OverridingModel>(context);
-  iface = flt.dyn_cast<TestExternalTypeInterface>();
+  iface = dyn_cast<TestExternalTypeInterface>(flt);
   ASSERT_TRUE(iface != nullptr);
   EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
@@ -86,7 +86,7 @@ TEST(InterfaceAttachment, Type) {
   // Other contexts shouldn't have the attribute attached.
   MLIRContext other;
   IntegerType i8other = IntegerType::get(&other, 8);
-  EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>());
+  EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other));
 }
 
 /// External interface model for the test type from the test dialect.
@@ -111,7 +111,7 @@ TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
   MLIRContext context(registry);
   context.loadDialect<test::TestDialect>();
   test::TestType testType = test::TestType::get(&context);
-  auto iface = testType.dyn_cast<TestExternalTypeInterface>();
+  auto iface = dyn_cast<TestExternalTypeInterface>(testType);
   ASSERT_TRUE(iface != nullptr);
   EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
   EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
@@ -130,9 +130,9 @@ TEST(InterfaceAttachment, TypeDelayedContextAppend) {
   MLIRContext context;
   context.loadDialect<test::TestDialect>();
   test::TestType testType = test::TestType::get(&context);
-  EXPECT_FALSE(testType.isa<TestExternalTypeInterface>());
+  EXPECT_FALSE(isa<TestExternalTypeInterface>(testType));
   context.appendDialectRegistry(registry);
-  EXPECT_TRUE(testType.isa<TestExternalTypeInterface>());
+  EXPECT_TRUE(isa<TestExternalTypeInterface>(testType));
 }
 
 TEST(InterfaceAttachment, RepeatedRegistration) {
@@ -156,13 +156,13 @@ TEST(InterfaceAttachment, TypeBuiltinDelayed) {
   MLIRContext context(registry);
   IntegerType i16 = IntegerType::get(&context, 16);
-  EXPECT_TRUE(i16.isa<TestExternalTypeInterface>());
+  EXPECT_TRUE(isa<TestExternalTypeInterface>(i16));
 
   MLIRContext initiallyEmpty;
   IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
-  EXPECT_FALSE(i32.isa<TestExternalTypeInterface>());
+  EXPECT_FALSE(isa<TestExternalTypeInterface>(i32));
   initiallyEmpty.appendDialectRegistry(registry);
-  EXPECT_TRUE(i32.isa<TestExternalTypeInterface>());
+  EXPECT_TRUE(isa<TestExternalTypeInterface>(i32));
 }
 
 /// The interface provides a default implementation that expects
@@ -181,9 +181,8 @@ struct TestExternalFallbackTypeVectorModel
     : public TestExternalFallbackTypeInterface::FallbackModel<
           TestExternalFallbackTypeVectorModel> {
   unsigned getBitwidth(Type type) const {
-    IntegerType elementType = type.cast<VectorType>()
-                                  .getElementType()
-                                  .dyn_cast_or_null<IntegerType>();
+    IntegerType elementType =
+        dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType());
     return elementType ? elementType.getWidth() : 0;
   }
 };
@@ -193,16 +192,16 @@ TEST(InterfaceAttachment, Fallback) {
 
   // Just check that we can attach the interface.
   IntegerType i8 = IntegerType::get(&context, 8);
-  ASSERT_FALSE(i8.isa<TestExternalFallbackTypeInterface>());
+  ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8));
   IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
-  ASSERT_TRUE(i8.isa<TestExternalFallbackTypeInterface>());
+  ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8));
 
   // Call the method so it is guaranteed not to be instantiated.
   VectorType vec = VectorType::get({42}, i8);
-  ASSERT_FALSE(vec.isa<TestExternalFallbackTypeInterface>());
+  ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec));
   VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
-  ASSERT_TRUE(vec.isa<TestExternalFallbackTypeInterface>());
-  EXPECT_EQ(vec.cast<TestExternalFallbackTypeInterface>().getBitwidth(), 8u);
+  ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec));
+  EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u);
 }
 
 /// External model for attribute interfaces.
@@ -210,7 +209,7 @@ struct TestExternalIntegerAttrModel
     : public TestExternalAttrInterface::ExternalModel<
           TestExternalIntegerAttrModel, IntegerAttr> {
   const Dialect *getDialectPtr(Attribute attr) const {
-    return &attr.cast<IntegerAttr>().getDialect();
+    return &cast<IntegerAttr>(attr).getDialect();
   }
 
   static int getSomeNumber() { return 42; }
@@ -222,9 +221,9 @@ TEST(InterfaceAttachment, Attribute) {
 
   // Attribute interfaces use the exact same mechanism as types, so just check
   // that the basics work for attributes.
   IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
-  ASSERT_FALSE(attr.isa<TestExternalAttrInterface>());
+  ASSERT_FALSE(isa<TestExternalAttrInterface>(attr));
   IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
-  auto iface = attr.dyn_cast<TestExternalAttrInterface>();
+  auto iface = dyn_cast<TestExternalAttrInterface>(attr);
   ASSERT_TRUE(iface != nullptr);
   EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
   EXPECT_EQ(iface.getSomeNumber(), 42);
@@ -253,14 +252,14 @@ TEST(InterfaceAttachmentTest, AttributeDelayed) {
   MLIRContext context(registry);
   context.loadDialect<test::TestDialect>();
   auto attr = test::SimpleAAttr::get(&context);
-  EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
+  EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
 
   MLIRContext initiallyEmpty;
   initiallyEmpty.loadDialect<test::TestDialect>();
   attr = test::SimpleAAttr::get(&initiallyEmpty);
-  EXPECT_FALSE(attr.isa<TestExternalAttrInterface>());
+  EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
   initiallyEmpty.appendDialectRegistry(registry);
-  EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
+  EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
 }
 
 /// External interface model for the module operation. Only provides non-default
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index 7e0a8f5..6601f32 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -152,16 +152,16 @@ struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
   static unsigned getTypeSizeInBits(Type type, const DataLayout &dataLayout,
                                     DataLayoutEntryListRef params) {
     // Make a recursive query.
-    if (type.isa<FloatType>())
+    if (isa<FloatType>(type))
       return dataLayout.getTypeSizeInBits(
          IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth()));
 
     // Handle built-in types that are not handled by the default process.
-    if (auto iType = type.dyn_cast<IntegerType>()) {
+    if (auto iType = dyn_cast<IntegerType>(type)) {
       for (DataLayoutEntryInterface entry : params)
         if (entry.getKey().dyn_cast<Type>() == type)
           return 8 *
-                 entry.getValue().cast<IntegerAttr>().getValue().getZExtValue();
+                 cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
       return 8 * iType.getIntOrFloatBitWidth();
     }
@@ -217,7 +217,7 @@ struct DLTestDialect : Dialect {
   void printAttribute(Attribute attr,
                       DialectAsmPrinter &printer) const override {
     printer << "spec<";
-    llvm::interleaveComma(attr.cast<CustomDataLayoutSpec>().getEntries(),
+    llvm::interleaveComma(cast<CustomDataLayoutSpec>(attr).getEntries(),
                           printer);
     printer << ">";
   }
@@ -244,7 +244,7 @@ struct DLTestDialect : Dialect {
   }
 
   void printType(Type type, DialectAsmPrinter &printer) const override {
-    if (type.isa<SingleQueryType>())
+    if (isa<SingleQueryType>(type))
       printer << "single_query";
     else
       printer << "no_layout";
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 24e8702..97349d6 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -75,12 +75,12 @@ TEST(PassManagerTest, OpSpecificAnalysis) {
 
   // Verify that each function got annotated with expected attributes.
   for (func::FuncOp func : module->getOps<func::FuncOp>()) {
-    ASSERT_TRUE(func->getAttr("isFunc").isa<BoolAttr>());
-    EXPECT_TRUE(func->getAttr("isFunc").cast<BoolAttr>().getValue());
+    ASSERT_TRUE(isa<BoolAttr>(func->getAttr("isFunc")));
+    EXPECT_TRUE(cast<BoolAttr>(func->getAttr("isFunc")).getValue());
 
     bool isSecret = func.getName() == "secret";
-    ASSERT_TRUE(func->getAttr("isSecret").isa<BoolAttr>());
-    EXPECT_EQ(func->getAttr("isSecret").cast<BoolAttr>().getValue(), isSecret);
+    ASSERT_TRUE(isa<BoolAttr>(func->getAttr("isSecret")));
+    EXPECT_EQ(cast<BoolAttr>(func->getAttr("isSecret")).getValue(), isSecret);
   }
 }
-- 
2.7.4
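
For readers new to the target style, the following minimal sketch shows the before/after pattern that every hunk above applies. It is illustrative only and not part of the diff; `argNumberOrResultNumber` is a hypothetical helper, while `Value`, `BlockArgument`, `OpResult`, and the free `isa`/`cast`/`dyn_cast` functions are the actual MLIR/LLVM APIs being migrated to. It mirrors the `BlockEquivalenceData::getOrderOf` hunk in RegionUtils.cpp.

```
#include "mlir/IR/Value.h"

using namespace mlir;

// Hypothetical helper: classify a Value, once with the deprecated
// member-function casts (shown in comments) and once with the free
// functions this patch migrates to.
unsigned argNumberOrResultNumber(Value value) {
  // Old style, being phased out:
  //   if (BlockArgument arg = value.dyn_cast<BlockArgument>())
  //     return arg.getArgNumber();
  //   return value.cast<OpResult>().getResultNumber();

  // New style, as applied throughout this patch:
  if (BlockArgument arg = dyn_cast<BlockArgument>(value))
    return arg.getArgNumber();
  // A Value that is not a BlockArgument must be an OpResult; cast asserts.
  return cast<OpResult>(value).getResultNumber();
}
```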