[mlir] Move casting calls from methods to function calls
authorTres Popp <tpopp@google.com>
Fri, 26 May 2023 08:17:47 +0000 (10:17 +0200)
committerTres Popp <tpopp@google.com>
Fri, 26 May 2023 08:29:55 +0000 (10:29 +0200)
The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast
functionality in addition to defining methods with the same name.
This change begins the migration of uses of the method to the
corresponding function call as has been decided as more consistent.

Note that there still exist classes that only define methods directly,
such as AffineExpr, and this does not include work currently to support
a functional cast/isa call.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants
  for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation:
This patch updates all remaining uses of the deprecated functionality in
mlir/. This was done with clang-tidy as described below and further
modifications to GPUBase.td and OpenMPOpsInterfaces.td.

Steps are described per line, as comments are removed by git:
0. Retrieve the change from the following to build clang-tidy with an
   additional check:
   main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy
2. Run clang-tidy over your entire codebase while disabling all checks
   and enabling the one relevant one. Run on all header files also.
3. Delete .inc files that were also modified, so the next build rebuilds
   them to a pure state.

```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
               -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc
```

Differential Revision: https://reviews.llvm.org/D151542

117 files changed:
mlir/examples/toy/Ch2/mlir/Dialect.cpp
mlir/examples/toy/Ch3/mlir/Dialect.cpp
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h
mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/TypeRange.h
mlir/include/mlir/Interfaces/SideEffectInterfaces.h
mlir/include/mlir/Pass/AnalysisManager.h
mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Analysis/DataFlowFramework.cpp
mlir/lib/AsmParser/Parser.cpp
mlir/lib/CAPI/Interfaces/Interfaces.cpp
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Debug/DebuggerExecutionContextHook.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Arith/Utils/Utils.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
mlir/lib/Dialect/DLTI/DLTI.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Transforms/Split.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Block.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/lib/IR/Region.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/IR/TypeRange.cpp
mlir/lib/IR/Types.cpp
mlir/lib/IR/Unit.cpp
mlir/lib/IR/Value.cpp
mlir/lib/Interfaces/DataLayoutInterfaces.cpp
mlir/lib/Interfaces/InferTypeOpInterface.cpp
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
mlir/lib/Pass/PassDetail.h
mlir/lib/TableGen/Operator.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleImport.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
mlir/lib/Transforms/Inliner.cpp
mlir/lib/Transforms/Utils/FoldUtils.cpp
mlir/lib/Transforms/Utils/InliningUtils.cpp
mlir/test/lib/Analysis/TestDataFlowFramework.cpp
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp

index ef07af2..df9105c 100644 (file)
@@ -54,7 +54,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
 
   // If the type is a function type, it contains the input and result types of
   // this operation.
-  if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+  if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
     if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
                                result.operands))
       return mlir::failure();
@@ -133,13 +133,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
 mlir::LogicalResult ConstantOp::verify() {
   // If the return type of the constant is not an unranked tensor, the shape
   // must match the shape of the attribute holding the data.
-  auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+  auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
   if (!resultType)
     return success();
 
   // Check that the rank of the attribute type matches the rank of the constant
   // result type.
-  auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+  auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
   if (attrType.getRank() != resultType.getRank()) {
     return emitOpError("return type must match the one of the attached value "
                        "attribute: ")
@@ -269,8 +269,8 @@ mlir::LogicalResult ReturnOp::verify() {
   auto resultType = results.front();
 
   // Check that the result type of the function matches the operand type.
-  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
-      resultType.isa<mlir::UnrankedTensorType>())
+  if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+      llvm::isa<mlir::UnrankedTensorType>(resultType))
     return mlir::success();
 
   return emitError() << "type of return operand (" << inputType
@@ -289,8 +289,8 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
 }
 
 mlir::LogicalResult TransposeOp::verify() {
-  auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
-  auto resultType = getType().dyn_cast<RankedTensorType>();
+  auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+  auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
   if (!inputType || !resultType)
     return mlir::success();
 
index 43f8d5b..ca076f2 100644 (file)
@@ -54,7 +54,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
 
   // If the type is a function type, it contains the input and result types of
   // this operation.
-  if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+  if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
     if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
                                result.operands))
       return mlir::failure();
@@ -133,13 +133,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
 mlir::LogicalResult ConstantOp::verify() {
   // If the return type of the constant is not an unranked tensor, the shape
   // must match the shape of the attribute holding the data.
-  auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+  auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
   if (!resultType)
     return success();
 
   // Check that the rank of the attribute type matches the rank of the constant
   // result type.
-  auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+  auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
   if (attrType.getRank() != resultType.getRank()) {
     return emitOpError("return type must match the one of the attached value "
                        "attribute: ")
@@ -269,8 +269,8 @@ mlir::LogicalResult ReturnOp::verify() {
   auto resultType = results.front();
 
   // Check that the result type of the function matches the operand type.
-  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
-      resultType.isa<mlir::UnrankedTensorType>())
+  if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+      llvm::isa<mlir::UnrankedTensorType>(resultType))
     return mlir::success();
 
   return emitError() << "type of return operand (" << inputType
@@ -289,8 +289,8 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
 }
 
 mlir::LogicalResult TransposeOp::verify() {
-  auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
-  auto resultType = getType().dyn_cast<RankedTensorType>();
+  auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+  auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
   if (!inputType || !resultType)
     return mlir::success();
 
index d533e58..e841518 100644 (file)
@@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
 
   // If the type is a function type, it contains the input and result types of
   // this operation.
-  if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+  if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
     if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
                                result.operands))
       return mlir::failure();
@@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
 mlir::LogicalResult ConstantOp::verify() {
   // If the return type of the constant is not an unranked tensor, the shape
   // must match the shape of the attribute holding the data.
-  auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+  auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
   if (!resultType)
     return success();
 
   // Check that the rank of the attribute type matches the rank of the constant
   // result type.
-  auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+  auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
   if (attrType.getRank() != resultType.getRank()) {
     return emitOpError("return type must match the one of the attached value "
                        "attribute: ")
@@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (inputs.size() != 1 || outputs.size() != 1)
     return false;
   // The inputs must be Tensors with the same element type.
-  TensorType input = inputs.front().dyn_cast<TensorType>();
-  TensorType output = outputs.front().dyn_cast<TensorType>();
+  TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
+  TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
   if (!input || !output || input.getElementType() != output.getElementType())
     return false;
   // The shape is required to match if both types are ranked.
@@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
   auto resultType = results.front();
 
   // Check that the result type of the function matches the operand type.
-  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
-      resultType.isa<mlir::UnrankedTensorType>())
+  if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+      llvm::isa<mlir::UnrankedTensorType>(resultType))
     return mlir::success();
 
   return emitError() << "type of return operand (" << inputType
@@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
 }
 
 void TransposeOp::inferShapes() {
-  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+  auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
   SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
 }
 
 mlir::LogicalResult TransposeOp::verify() {
-  auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
-  auto resultType = getType().dyn_cast<RankedTensorType>();
+  auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+  auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
   if (!inputType || !resultType)
     return mlir::success();
 
index cf3e492..d45baa1 100644 (file)
@@ -94,7 +94,7 @@ struct ShapeInferencePass
   /// operands inferred.
   static bool allOperandsInferred(Operation *op) {
     return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
-      return operandType.isa<RankedTensorType>();
+      return llvm::isa<RankedTensorType>(operandType);
     });
   }
 
@@ -102,7 +102,7 @@ struct ShapeInferencePass
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {
     return llvm::any_of(op->getResultTypes(), [](Type resultType) {
-      return !resultType.isa<RankedTensorType>();
+      return !llvm::isa<RankedTensorType>(resultType);
     });
   }
 };
index 4f03266..c2a99aa 100644 (file)
@@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
 
   // If the type is a function type, it contains the input and result types of
   // this operation.
-  if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+  if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
     if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
                                result.operands))
       return mlir::failure();
@@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
 mlir::LogicalResult ConstantOp::verify() {
   // If the return type of the constant is not an unranked tensor, the shape
   // must match the shape of the attribute holding the data.
-  auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+  auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
   if (!resultType)
     return success();
 
   // Check that the rank of the attribute type matches the rank of the constant
   // result type.
-  auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+  auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
   if (attrType.getRank() != resultType.getRank()) {
     return emitOpError("return type must match the one of the attached value "
                        "attribute: ")
@@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (inputs.size() != 1 || outputs.size() != 1)
     return false;
   // The inputs must be Tensors with the same element type.
-  TensorType input = inputs.front().dyn_cast<TensorType>();
-  TensorType output = outputs.front().dyn_cast<TensorType>();
+  TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
+  TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
   if (!input || !output || input.getElementType() != output.getElementType())
     return false;
   // The shape is required to match if both types are ranked.
@@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
   auto resultType = results.front();
 
   // Check that the result type of the function matches the operand type.
-  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
-      resultType.isa<mlir::UnrankedTensorType>())
+  if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+      llvm::isa<mlir::UnrankedTensorType>(resultType))
     return mlir::success();
 
   return emitError() << "type of return operand (" << inputType
@@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
 }
 
 void TransposeOp::inferShapes() {
-  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+  auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
   SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
 }
 
 mlir::LogicalResult TransposeOp::verify() {
-  auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
-  auto resultType = getType().dyn_cast<RankedTensorType>();
+  auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+  auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
   if (!inputType || !resultType)
     return mlir::success();
 
index 9881755..fd589dd 100644 (file)
@@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
 static void lowerOpToLoops(Operation *op, ValueRange operands,
                            PatternRewriter &rewriter,
                            LoopIterationFn processIteration) {
-  auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
+  auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
   auto loc = op->getLoc();
 
   // Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
 
     // When lowering the constant operation, we allocate and assign the constant
     // values to a corresponding memref allocation.
-    auto tensorType = op.getType().cast<RankedTensorType>();
+    auto tensorType = llvm::cast<RankedTensorType>(op.getType());
     auto memRefType = convertTensorToMemRef(tensorType);
     auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
 
@@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
   target.addIllegalDialect<toy::ToyDialect>();
   target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
     return llvm::none_of(op->getOperandTypes(),
-                         [](Type type) { return type.isa<TensorType>(); });
+                         [](Type type) { return llvm::isa<TensorType>(type); });
   });
 
   // Now that the conversion target has been defined, we just need to provide
index cf3e492..d45baa1 100644 (file)
@@ -94,7 +94,7 @@ struct ShapeInferencePass
   /// operands inferred.
   static bool allOperandsInferred(Operation *op) {
     return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
-      return operandType.isa<RankedTensorType>();
+      return llvm::isa<RankedTensorType>(operandType);
     });
   }
 
@@ -102,7 +102,7 @@ struct ShapeInferencePass
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {
     return llvm::any_of(op->getResultTypes(), [](Type resultType) {
-      return !resultType.isa<RankedTensorType>();
+      return !llvm::isa<RankedTensorType>(resultType);
     });
   }
 };
index 4f03266..c2a99aa 100644 (file)
@@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
 
   // If the type is a function type, it contains the input and result types of
   // this operation.
-  if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+  if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
     if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
                                result.operands))
       return mlir::failure();
@@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
 mlir::LogicalResult ConstantOp::verify() {
   // If the return type of the constant is not an unranked tensor, the shape
   // must match the shape of the attribute holding the data.
-  auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+  auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
   if (!resultType)
     return success();
 
   // Check that the rank of the attribute type matches the rank of the constant
   // result type.
-  auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+  auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
   if (attrType.getRank() != resultType.getRank()) {
     return emitOpError("return type must match the one of the attached value "
                        "attribute: ")
@@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (inputs.size() != 1 || outputs.size() != 1)
     return false;
   // The inputs must be Tensors with the same element type.
-  TensorType input = inputs.front().dyn_cast<TensorType>();
-  TensorType output = outputs.front().dyn_cast<TensorType>();
+  TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
+  TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
   if (!input || !output || input.getElementType() != output.getElementType())
     return false;
   // The shape is required to match if both types are ranked.
@@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
   auto resultType = results.front();
 
   // Check that the result type of the function matches the operand type.
-  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
-      resultType.isa<mlir::UnrankedTensorType>())
+  if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+      llvm::isa<mlir::UnrankedTensorType>(resultType))
     return mlir::success();
 
   return emitError() << "type of return operand (" << inputType
@@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
 }
 
 void TransposeOp::inferShapes() {
-  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+  auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
   SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
 }
 
 mlir::LogicalResult TransposeOp::verify() {
-  auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
-  auto resultType = getType().dyn_cast<RankedTensorType>();
+  auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+  auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
   if (!inputType || !resultType)
     return mlir::success();
 
index 9881755..fd589dd 100644 (file)
@@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
 static void lowerOpToLoops(Operation *op, ValueRange operands,
                            PatternRewriter &rewriter,
                            LoopIterationFn processIteration) {
-  auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
+  auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
   auto loc = op->getLoc();
 
   // Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
 
     // When lowering the constant operation, we allocate and assign the constant
     // values to a corresponding memref allocation.
-    auto tensorType = op.getType().cast<RankedTensorType>();
+    auto tensorType = llvm::cast<RankedTensorType>(op.getType());
     auto memRefType = convertTensorToMemRef(tensorType);
     auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
 
@@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
   target.addIllegalDialect<toy::ToyDialect>();
   target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
     return llvm::none_of(op->getOperandTypes(),
-                         [](Type type) { return type.isa<TensorType>(); });
+                         [](Type type) { return llvm::isa<TensorType>(type); });
   });
 
   // Now that the conversion target has been defined, we just need to provide
index 06e5096..a10588e 100644 (file)
@@ -61,7 +61,7 @@ public:
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
+    auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
     auto memRefShape = memRefType.getShape();
     auto loc = op->getLoc();
 
index cf3e492..d45baa1 100644 (file)
@@ -94,7 +94,7 @@ struct ShapeInferencePass
   /// operands inferred.
   static bool allOperandsInferred(Operation *op) {
     return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
-      return operandType.isa<RankedTensorType>();
+      return llvm::isa<RankedTensorType>(operandType);
     });
   }
 
@@ -102,7 +102,7 @@ struct ShapeInferencePass
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {
     return llvm::any_of(op->getResultTypes(), [](Type resultType) {
-      return !resultType.isa<RankedTensorType>();
+      return !llvm::isa<RankedTensorType>(resultType);
     });
   }
 };
index 6432403..1b77f8c 100644 (file)
@@ -101,7 +101,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
 
   // If the type is a function type, it contains the input and result types of
   // this operation.
-  if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+  if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
     if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
                                result.operands))
       return mlir::failure();
@@ -179,9 +179,9 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
 static mlir::LogicalResult verifyConstantForType(mlir::Type type,
                                                  mlir::Attribute opaqueValue,
                                                  mlir::Operation *op) {
-  if (type.isa<mlir::TensorType>()) {
+  if (llvm::isa<mlir::TensorType>(type)) {
     // Check that the value is an elements attribute.
-    auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
+    auto attrValue = llvm::dyn_cast<mlir::DenseFPElementsAttr>(opaqueValue);
     if (!attrValue)
       return op->emitError("constant of TensorType must be initialized by "
                            "a DenseFPElementsAttr, got ")
@@ -189,13 +189,13 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
 
     // If the return type of the constant is not an unranked tensor, the shape
     // must match the shape of the attribute holding the data.
-    auto resultType = type.dyn_cast<mlir::RankedTensorType>();
+    auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(type);
     if (!resultType)
       return success();
 
     // Check that the rank of the attribute type matches the rank of the
     // constant result type.
-    auto attrType = attrValue.getType().cast<mlir::RankedTensorType>();
+    auto attrType = llvm::cast<mlir::RankedTensorType>(attrValue.getType());
     if (attrType.getRank() != resultType.getRank()) {
       return op->emitOpError("return type must match the one of the attached "
                              "value attribute: ")
@@ -213,11 +213,11 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
     }
     return mlir::success();
   }
-  auto resultType = type.cast<StructType>();
+  auto resultType = llvm::cast<StructType>(type);
   llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();
 
   // Verify that the initializer is an Array.
-  auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
+  auto attrValue = llvm::dyn_cast<ArrayAttr>(opaqueValue);
   if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
     return op->emitError("constant of StructType must be initialized by an "
                          "ArrayAttr with the same number of elements, got ")
@@ -283,8 +283,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (inputs.size() != 1 || outputs.size() != 1)
     return false;
   // The inputs must be Tensors with the same element type.
-  TensorType input = inputs.front().dyn_cast<TensorType>();
-  TensorType output = outputs.front().dyn_cast<TensorType>();
+  TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
+  TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
   if (!input || !output || input.getElementType() != output.getElementType())
     return false;
   // The shape is required to match if both types are ranked.
@@ -426,8 +426,8 @@ mlir::LogicalResult ReturnOp::verify() {
   auto resultType = results.front();
 
   // Check that the result type of the function matches the operand type.
-  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
-      resultType.isa<mlir::UnrankedTensorType>())
+  if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+      llvm::isa<mlir::UnrankedTensorType>(resultType))
     return mlir::success();
 
   return emitError() << "type of return operand (" << inputType
@@ -442,7 +442,7 @@ mlir::LogicalResult ReturnOp::verify() {
 void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
                            mlir::Value input, size_t index) {
   // Extract the result type from the input type.
-  StructType structTy = input.getType().cast<StructType>();
+  StructType structTy = llvm::cast<StructType>(input.getType());
   assert(index < structTy.getNumElementTypes());
   mlir::Type resultType = structTy.getElementTypes()[index];
 
@@ -451,7 +451,7 @@ void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
 }
 
 mlir::LogicalResult StructAccessOp::verify() {
-  StructType structTy = getInput().getType().cast<StructType>();
+  StructType structTy = llvm::cast<StructType>(getInput().getType());
   size_t indexValue = getIndex();
   if (indexValue >= structTy.getNumElementTypes())
     return emitOpError()
@@ -474,14 +474,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
 }
 
 void TransposeOp::inferShapes() {
-  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+  auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
   SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
 }
 
 mlir::LogicalResult TransposeOp::verify() {
-  auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
-  auto resultType = getType().dyn_cast<RankedTensorType>();
+  auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+  auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
   if (!inputType || !resultType)
     return mlir::success();
 
@@ -598,7 +598,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
       return nullptr;
 
     // Check that the type is either a TensorType or another StructType.
-    if (!elementType.isa<mlir::TensorType, StructType>()) {
+    if (!llvm::isa<mlir::TensorType, StructType>(elementType)) {
       parser.emitError(typeLoc, "element type for a struct must either "
                                 "be a TensorType or a StructType, got: ")
           << elementType;
@@ -619,7 +619,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
 void ToyDialect::printType(mlir::Type type,
                            mlir::DialectAsmPrinter &printer) const {
   // Currently the only toy type is a struct type.
-  StructType structType = type.cast<StructType>();
+  StructType structType = llvm::cast<StructType>(type);
 
   // Print the struct type according to the parser format.
   printer << "struct<";
@@ -653,9 +653,9 @@ mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
                                                  mlir::Attribute value,
                                                  mlir::Type type,
                                                  mlir::Location loc) {
-  if (type.isa<StructType>())
+  if (llvm::isa<StructType>(type))
     return builder.create<StructConstantOp>(loc, type,
-                                            value.cast<mlir::ArrayAttr>());
+                                            llvm::cast<mlir::ArrayAttr>(value));
   return builder.create<ConstantOp>(loc, type,
-                                    value.cast<mlir::DenseElementsAttr>());
+                                    llvm::cast<mlir::DenseElementsAttr>(value));
 }
index 9881755..fd589dd 100644 (file)
@@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
 static void lowerOpToLoops(Operation *op, ValueRange operands,
                            PatternRewriter &rewriter,
                            LoopIterationFn processIteration) {
-  auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
+  auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
   auto loc = op->getLoc();
 
   // Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
 
     // When lowering the constant operation, we allocate and assign the constant
     // values to a corresponding memref allocation.
-    auto tensorType = op.getType().cast<RankedTensorType>();
+    auto tensorType = llvm::cast<RankedTensorType>(op.getType());
     auto memRefType = convertTensorToMemRef(tensorType);
     auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
 
@@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
   target.addIllegalDialect<toy::ToyDialect>();
   target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
     return llvm::none_of(op->getOperandTypes(),
-                         [](Type type) { return type.isa<TensorType>(); });
+                         [](Type type) { return llvm::isa<TensorType>(type); });
   });
 
   // Now that the conversion target has been defined, we just need to provide
index 06e5096..a10588e 100644 (file)
@@ -61,7 +61,7 @@ public:
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
+    auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
     auto memRefShape = memRefType.getShape();
     auto loc = op->getLoc();
 
index cf3e492..d45baa1 100644 (file)
@@ -94,7 +94,7 @@ struct ShapeInferencePass
   /// operands inferred.
   static bool allOperandsInferred(Operation *op) {
     return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
-      return operandType.isa<RankedTensorType>();
+      return llvm::isa<RankedTensorType>(operandType);
     });
   }
 
@@ -102,7 +102,7 @@ struct ShapeInferencePass
   /// shaped result.
   static bool returnsDynamicShape(Operation *op) {
     return llvm::any_of(op->getResultTypes(), [](Type resultType) {
-      return !resultType.isa<RankedTensorType>();
+      return !llvm::isa<RankedTensorType>(resultType);
     });
   }
 };
index 62b00d9..09be97e 100644 (file)
@@ -31,7 +31,8 @@ OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
 
 /// Fold simple struct access operations that access into a constant.
 OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
-  auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
+  auto structAttr =
+      llvm::dyn_cast_if_present<mlir::ArrayAttr>(adaptor.getInput());
   if (!structAttr)
     return nullptr;
 
index 7def9e2..e62b9c0 100644 (file)
@@ -62,19 +62,19 @@ class FileLineColLocBreakpointManager
 public:
   Breakpoint *match(const Action &action) const override {
     for (const IRUnit &unit : action.getContextIRUnits()) {
-      if (auto *op = unit.dyn_cast<Operation *>()) {
+      if (auto *op = llvm::dyn_cast_if_present<Operation *>(unit)) {
         if (auto match = matchFromLocation(op->getLoc()))
           return *match;
         continue;
       }
-      if (auto *block = unit.dyn_cast<Block *>()) {
+      if (auto *block = llvm::dyn_cast_if_present<Block *>(unit)) {
         for (auto &op : block->getOperations()) {
           if (auto match = matchFromLocation(op.getLoc()))
             return *match;
         }
         continue;
       }
-      if (Region *region = unit.dyn_cast<Region *>()) {
+      if (Region *region = llvm::dyn_cast_if_present<Region *>(unit)) {
         if (auto match = matchFromLocation(region->getLoc()))
           return *match;
         continue;
index ddef020..63f18eb 100644 (file)
@@ -110,27 +110,27 @@ class MMAMatrixOf<list<Type> allowedTypes> :
   "gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
 
 // Types for all sparse handles.
-def GPU_SparseEnvHandle : 
-  DialectType<GPU_Dialect, 
-    CPred<"$_self.isa<::mlir::gpu::SparseEnvHandleType>()">, 
-    "sparse environment handle type">, 
+def GPU_SparseEnvHandle :
+  DialectType<GPU_Dialect,
+    CPred<"llvm::isa<::mlir::gpu::SparseEnvHandleType>($_self)">,
+    "sparse environment handle type">,
   BuildableType<"mlir::gpu::SparseEnvHandleType::get($_builder.getContext())">;
 
-def GPU_SparseDnVecHandle : 
-  DialectType<GPU_Dialect, 
-    CPred<"$_self.isa<::mlir::gpu::SparseDnVecHandleType>()">, 
+def GPU_SparseDnVecHandle :
+  DialectType<GPU_Dialect,
+    CPred<"llvm::isa<::mlir::gpu::SparseDnVecHandleType>($_self)">,
     "dense vector handle type">,
   BuildableType<"mlir::gpu::SparseDnVecHandleType::get($_builder.getContext())">;
 
-def GPU_SparseDnMatHandle : 
-  DialectType<GPU_Dialect, 
-    CPred<"$_self.isa<::mlir::gpu::SparseDnMatHandleType>()">, 
+def GPU_SparseDnMatHandle :
+  DialectType<GPU_Dialect,
+    CPred<"llvm::isa<::mlir::gpu::SparseDnMatHandleType>($_self)">,
     "dense matrix handle type">,
   BuildableType<"mlir::gpu::SparseDnMatHandleType::get($_builder.getContext())">;
 
-def GPU_SparseSpMatHandle : 
-  DialectType<GPU_Dialect, 
-    CPred<"$_self.isa<::mlir::gpu::SparseSpMatHandleType>()">, 
+def GPU_SparseSpMatHandle :
+  DialectType<GPU_Dialect,
+    CPred<"llvm::isa<::mlir::gpu::SparseSpMatHandleType>($_self)">,
     "sparse matrix handle type">,
   BuildableType<"mlir::gpu::SparseSpMatHandleType::get($_builder.getContext())">;
 
index 6f83c05..0331f9f 100644 (file)
@@ -95,7 +95,7 @@ def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
       /*methodName=*/"getDeclareTargetDeviceType",
       (ins), [{}], [{
         if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
-          if (auto dAttr = dTar.dyn_cast_or_null<mlir::omp::DeclareTargetAttr>())
+          if (auto dAttr = llvm::dyn_cast_or_null<mlir::omp::DeclareTargetAttr>(dTar))
             return dAttr.getDeviceType().getValue();
         return {};
       }]>,
@@ -108,7 +108,7 @@ def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
       /*methodName=*/"getDeclareTargetCaptureClause",
       (ins), [{}], [{
         if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
-          if (auto dAttr = dTar.dyn_cast_or_null<mlir::omp::DeclareTargetAttr>())
+          if (auto dAttr = llvm::dyn_cast_or_null<mlir::omp::DeclareTargetAttr>(dTar))
             return dAttr.getCaptureClause().getValue();
         return {};
       }]>
index 79313b6..acb3556 100644 (file)
@@ -115,7 +115,7 @@ public:
   static bool classof(Type type);
 
   /// Allow implicit conversion to ShapedType.
-  operator ShapedType() const { return cast<ShapedType>(); }
+  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
 };
 
 //===----------------------------------------------------------------------===//
@@ -169,7 +169,7 @@ public:
   unsigned getMemorySpaceAsInt() const;
 
   /// Allow implicit conversion to ShapedType.
-  operator ShapedType() const { return cast<ShapedType>(); }
+  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
 };
 
 } // namespace mlir
index 37a7c24..99fabab 100644 (file)
@@ -217,13 +217,15 @@ private:
   }
 
   static bool isEmptyKey(mlir::TypeRange range) {
-    if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
+    if (const auto *type =
+            llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
       return type == getEmptyKeyPointer();
     return false;
   }
 
   static bool isTombstoneKey(mlir::TypeRange range) {
-    if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
+    if (const auto *type =
+            llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
       return type == getTombstoneKeyPointer();
     return false;
   }
index 306f4cf..ac42f38 100644 (file)
@@ -163,12 +163,12 @@ public:
 
   /// Return the value the effect is applied on, or nullptr if there isn't a
   /// known value being affected.
-  Value getValue() const { return value ? value.dyn_cast<Value>() : Value(); }
+  Value getValue() const { return value ? llvm::dyn_cast_if_present<Value>(value) : Value(); }
 
   /// Return the symbol reference the effect is applied on, or nullptr if there
   /// isn't a known smbol being affected.
   SymbolRefAttr getSymbolRef() const {
-    return value ? value.dyn_cast<SymbolRefAttr>() : SymbolRefAttr();
+    return value ? llvm::dyn_cast_if_present<SymbolRefAttr>(value) : SymbolRefAttr();
   }
 
   /// Return the resource that the effect applies to.
index 9821a68..f9db261 100644 (file)
@@ -254,7 +254,7 @@ struct NestedAnalysisMap {
   /// Returns the parent analysis map for this analysis map, or null if this is
   /// the top-level map.
   const NestedAnalysisMap *getParent() const {
-    return parentOrInstrumentor.dyn_cast<NestedAnalysisMap *>();
+    return llvm::dyn_cast_if_present<NestedAnalysisMap *>(parentOrInstrumentor);
   }
 
   /// Returns a pass instrumentation object for the current operation. This
index db25c23..a6d53b2 100644 (file)
@@ -89,7 +89,7 @@ void SparseConstantPropagation::visitOperation(
 
     // Merge in the result of the fold, either a constant or a value.
     OpFoldResult foldResult = std::get<1>(it);
-    if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
+    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
       LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
       propagateIfChanged(lattice,
                          lattice->join(ConstantValue(attr, op->getDialect())));
index 8ff71b5..d681604 100644 (file)
@@ -31,7 +31,7 @@ void Executable::print(raw_ostream &os) const {
 }
 
 void Executable::onUpdate(DataFlowSolver *solver) const {
-  if (auto *block = point.dyn_cast<Block *>()) {
+  if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
     // Re-invoke the analyses on the block itself.
     for (DataFlowAnalysis *analysis : subscribers)
       solver->enqueue({block, analysis});
@@ -39,7 +39,7 @@ void Executable::onUpdate(DataFlowSolver *solver) const {
     for (DataFlowAnalysis *analysis : subscribers)
       for (Operation &op : *block)
         solver->enqueue({&op, analysis});
-  } else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
+  } else if (auto *programPoint = llvm::dyn_cast_if_present<GenericProgramPoint *>(point)) {
     // Re-invoke the analysis on the successor block.
     if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
       for (DataFlowAnalysis *analysis : subscribers)
@@ -219,7 +219,7 @@ void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
 LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
   if (point.is<Block *>())
     return success();
-  auto *op = point.dyn_cast<Operation *>();
+  auto *op = llvm::dyn_cast_if_present<Operation *>(point);
   if (!op)
     return emitError(point.getLoc(), "unknown program point kind");
 
index 6450891..77ef87d 100644 (file)
@@ -33,9 +33,9 @@ LogicalResult AbstractDenseDataFlowAnalysis::initialize(Operation *top) {
 }
 
 LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) {
-  if (auto *op = point.dyn_cast<Operation *>())
+  if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
     processOperation(op);
-  else if (auto *block = point.dyn_cast<Block *>())
+  else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
     visitBlock(block);
   else
     return failure();
index c866fc6..c832405 100644 (file)
@@ -181,7 +181,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
         if (auto bound =
                 dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
           return bound.getValue();
-      } else if (auto value = loopBound->dyn_cast<Value>()) {
+      } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
         const IntegerValueRangeLattice *lattice =
             getLatticeElementFor(op, value);
         if (lattice != nullptr)
index 629c482..f5cf866 100644 (file)
@@ -66,9 +66,9 @@ AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
 }
 
 LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
-  if (Operation *op = point.dyn_cast<Operation *>())
+  if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
     visitOperation(op);
-  else if (Block *block = point.dyn_cast<Block *>())
+  else if (Block *block = llvm::dyn_cast_if_present<Block *>(point))
     visitBlock(block);
   else
     return failure();
@@ -238,7 +238,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
 
     unsigned firstIndex = 0;
     if (inputs.size() != lattices.size()) {
-      if (point.dyn_cast<Operation *>()) {
+      if (llvm::dyn_cast_if_present<Operation *>(point)) {
         if (!inputs.empty())
           firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
         visitNonControlFlowArgumentsImpl(
@@ -316,9 +316,9 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
 
 LogicalResult
 AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
-  if (Operation *op = point.dyn_cast<Operation *>())
+  if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
     visitOperation(op);
-  else if (point.dyn_cast<Block *>())
+  else if (llvm::dyn_cast_if_present<Block *>(point))
     // For backward dataflow, we don't have to do any work for the blocks
     // themselves. CFG edges between blocks are processed by the BranchOp
     // logic in `visitOperation`, and entry blocks for functions are tied
index 9c8a889..47caf26 100644 (file)
@@ -39,21 +39,21 @@ void ProgramPoint::print(raw_ostream &os) const {
     os << "<NULL POINT>";
     return;
   }
-  if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
+  if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
     return programPoint->print(os);
-  if (auto *op = dyn_cast<Operation *>())
+  if (auto *op = llvm::dyn_cast<Operation *>(*this))
     return op->print(os);
-  if (auto value = dyn_cast<Value>())
+  if (auto value = llvm::dyn_cast<Value>(*this))
     return value.print(os);
   return get<Block *>()->print(os);
 }
 
 Location ProgramPoint::getLoc() const {
-  if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
+  if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
     return programPoint->getLoc();
-  if (auto *op = dyn_cast<Operation *>())
+  if (auto *op = llvm::dyn_cast<Operation *>(*this))
     return op->getLoc();
-  if (auto value = dyn_cast<Value>())
+  if (auto value = llvm::dyn_cast<Value>(*this))
     return value.getLoc();
   return get<Block *>()->getParent()->getLoc();
 }
index 75f4d4d..3b562e0 100644 (file)
@@ -2060,7 +2060,7 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
   if (parseToken(Token::r_paren, "expected ')' in location"))
     return failure();
 
-  if (auto *op = opOrArgument.dyn_cast<Operation *>())
+  if (auto *op = llvm::dyn_cast_if_present<Operation *>(opOrArgument))
     op->setLoc(directLoc);
   else
     opOrArgument.get<BlockArgument>().setLoc(directLoc);
index 3144a33..d3fd6b4 100644 (file)
@@ -47,7 +47,7 @@ SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
 DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
   DictionaryAttr attributeDict;
   if (!mlirAttributeIsNull(attributes))
-    attributeDict = unwrap(attributes).cast<DictionaryAttr>();
+    attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
   return attributeDict;
 }
 
index 1d1923d..8defd89 100644 (file)
@@ -1190,9 +1190,9 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
 // TODO: safer and more flexible to store data type in actual op instead?
 static Type getSpMatElemType(Value spMat) {
   if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
-    return op.getValues().getType().cast<MemRefType>().getElementType();
+    return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
   if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
-    return op.getValues().getType().cast<MemRefType>().getElementType();
+    return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
   llvm_unreachable("cannot find spmat def");
 }
 
@@ -1235,7 +1235,7 @@ LogicalResult ConvertCreateDnVecOpToGpuRuntimeCallPattern::matchAndRewrite(
       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
   if (!getTypeConverter()->useOpaquePointers())
     pVec = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVec);
-  Type dType = op.getMemref().getType().cast<MemRefType>().getElementType();
+  Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
   auto handle =
@@ -1271,7 +1271,7 @@ LogicalResult ConvertCreateDnMatOpToGpuRuntimeCallPattern::matchAndRewrite(
       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
   if (!getTypeConverter()->useOpaquePointers())
     pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
-  Type dType = op.getMemref().getType().cast<MemRefType>().getElementType();
+  Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
   auto handle =
@@ -1315,8 +1315,8 @@ LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
     pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
     pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
   }
-  Type iType = op.getColIdxs().getType().cast<MemRefType>().getElementType();
-  Type dType = op.getValues().getType().cast<MemRefType>().getElementType();
+  Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
+  Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
   auto iw = rewriter.create<LLVM::ConstantOp>(
       loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
@@ -1350,9 +1350,9 @@ LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
     pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
     pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
   }
-  Type pType = op.getRowPos().getType().cast<MemRefType>().getElementType();
-  Type iType = op.getColIdxs().getType().cast<MemRefType>().getElementType();
-  Type dType = op.getValues().getType().cast<MemRefType>().getElementType();
+  Type pType = llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
+  Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
+  Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
   auto pw = rewriter.create<LLVM::ConstantOp>(
       loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth());
   auto iw = rewriter.create<LLVM::ConstantOp>(
index cf0d506..aac6e60 100644 (file)
@@ -405,7 +405,7 @@ LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) {
     return failure();
   if (!(*converted)) // Conversion to default is 0.
     return 0;
-  if (auto explicitSpace = converted->dyn_cast_or_null<IntegerAttr>())
+  if (auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted))
     return explicitSpace.getInt();
   return failure();
 }
index 24ea1a6..f82a86c 100644 (file)
@@ -671,7 +671,7 @@ struct GlobalMemrefOpLowering
 
     Attribute initialValue = nullptr;
     if (!global.isExternal() && !global.isUninitialized()) {
-      auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
+      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
       initialValue = elementsAttr;
 
       // For scalar memrefs, the global variable created is of the element type,
index 8cd180d..b9a1cc9 100644 (file)
@@ -412,10 +412,10 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
     auto *ans = cast<TypeAnswer>(answer);
     if (isa<pdl::RangeType>(val.getType()))
       builder.create<pdl_interp::CheckTypesOp>(
-          loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
+          loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
     else
       builder.create<pdl_interp::CheckTypeOp>(
-          loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
+          loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
     break;
   }
   case Predicates::AttributeQuestion: {
index 2faf7f1..9e0cccf 100644 (file)
@@ -300,7 +300,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
 
   // tosa::ErfOp
-  if (isa<tosa::ErfOp>(op) && elementTy.isa<FloatType>())
+  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
     return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
 
   // tosa::GreaterOp
@@ -1885,7 +1885,7 @@ public:
 
     auto addDynamicDimension = [&](Value source, int64_t dim) {
       auto dynamicDim = tensor::createDimValue(builder, loc, source, dim);
-      if (auto dimValue = dynamicDim.value().dyn_cast<Value>())
+      if (auto dimValue = llvm::dyn_cast_if_present<Value>(dynamicDim.value()))
         results.push_back(dimValue);
     };
 
index 6cbd7f3..744a038 100644 (file)
@@ -121,11 +121,11 @@ void mlirDebuggerCursorSelectParentIRUnit() {
     return;
   }
   IRUnit *unit = &state.cursor;
-  if (auto *op = unit->dyn_cast<Operation *>()) {
+  if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
     state.cursor = op->getBlock();
-  } else if (auto *region = unit->dyn_cast<Region *>()) {
+  } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
     state.cursor = region->getParentOp();
-  } else if (auto *block = unit->dyn_cast<Block *>()) {
+  } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
     state.cursor = block->getParent();
   } else {
     llvm::outs() << "Current cursor is not a valid IRUnit";
@@ -142,14 +142,14 @@ void mlirDebuggerCursorSelectChildIRUnit(int index) {
     return;
   }
   IRUnit *unit = &state.cursor;
-  if (auto *op = unit->dyn_cast<Operation *>()) {
+  if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
     if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
       llvm::outs() << "Index invalid, op has " << op->getNumRegions()
                    << " but got " << index << "\n";
       return;
     }
     state.cursor = &op->getRegion(index);
-  } else if (auto *region = unit->dyn_cast<Region *>()) {
+  } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
     auto block = region->begin();
     int count = 0;
     while (block != region->end() && count != index) {
@@ -163,7 +163,7 @@ void mlirDebuggerCursorSelectChildIRUnit(int index) {
       return;
     }
     state.cursor = &*block;
-  } else if (auto *block = unit->dyn_cast<Block *>()) {
+  } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
     auto op = block->begin();
     int count = 0;
     while (op != block->end() && count != index) {
@@ -192,14 +192,14 @@ void mlirDebuggerCursorSelectPreviousIRUnit() {
     return;
   }
   IRUnit *unit = &state.cursor;
-  if (auto *op = unit->dyn_cast<Operation *>()) {
+  if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
     Operation *previous = op->getPrevNode();
     if (!previous) {
       llvm::outs() << "No previous operation in the current block\n";
       return;
     }
     state.cursor = previous;
-  } else if (auto *region = unit->dyn_cast<Region *>()) {
+  } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
     llvm::outs() << "Has region\n";
     Operation *parent = region->getParentOp();
     if (!parent) {
@@ -212,7 +212,7 @@ void mlirDebuggerCursorSelectPreviousIRUnit() {
     }
     state.cursor =
         &region->getParentOp()->getRegion(region->getRegionNumber() - 1);
-  } else if (auto *block = unit->dyn_cast<Block *>()) {
+  } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
     Block *previous = block->getPrevNode();
     if (!previous) {
       llvm::outs() << "No previous block in the current region\n";
@@ -234,14 +234,14 @@ void mlirDebuggerCursorSelectNextIRUnit() {
     return;
   }
   IRUnit *unit = &state.cursor;
-  if (auto *op = unit->dyn_cast<Operation *>()) {
+  if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
     Operation *next = op->getNextNode();
     if (!next) {
       llvm::outs() << "No next operation in the current block\n";
       return;
     }
     state.cursor = next;
-  } else if (auto *region = unit->dyn_cast<Region *>()) {
+  } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
     Operation *parent = region->getParentOp();
     if (!parent) {
       llvm::outs() << "No parent operation for the current region\n";
@@ -253,7 +253,7 @@ void mlirDebuggerCursorSelectNextIRUnit() {
     }
     state.cursor =
         &region->getParentOp()->getRegion(region->getRegionNumber() + 1);
-  } else if (auto *block = unit->dyn_cast<Block *>()) {
+  } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
     Block *next = block->getNextNode();
     if (!next) {
       llvm::outs() << "No next block in the current region\n";
index 2009c39..9153686 100644 (file)
@@ -1212,7 +1212,7 @@ static void materializeConstants(OpBuilder &b, Location loc,
   actualValues.reserve(values.size());
   auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
   for (OpFoldResult ofr : values) {
-    if (auto value = ofr.dyn_cast<Value>()) {
+    if (auto value = llvm::dyn_cast_if_present<Value>(ofr)) {
       actualValues.push_back(value);
       continue;
     }
@@ -4599,7 +4599,7 @@ void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
         if (staticDim.has_value())
           return builder.create<arith::ConstantIndexOp>(result.location,
                                                         *staticDim);
-        return ofr.dyn_cast<Value>();
+        return llvm::dyn_cast_if_present<Value>(ofr);
       });
   result.addOperands(basisValues);
 }
index d0d83c2..e0dd2d6 100644 (file)
@@ -808,7 +808,7 @@ OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();
   /// or(x, <all ones>) -> <all ones>
-  if (auto rhsAttr = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
+  if (auto rhsAttr = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()))
     if (rhsAttr.getValue().isAllOnes())
       return rhsAttr;
 
@@ -1249,7 +1249,7 @@ LogicalResult arith::ExtSIOp::verify() {
 
 /// Always fold extension of FP constants.
 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
-  auto constOperand = adaptor.getIn().dyn_cast_or_null<FloatAttr>();
+  auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
   if (!constOperand)
     return {};
 
@@ -1702,7 +1702,7 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
 
   // We are moving constants to the right side; So if lhs is constant rhs is
   // guaranteed to be a constant.
-  if (auto lhs = adaptor.getLhs().dyn_cast_or_null<TypedAttr>()) {
+  if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
     return constFoldBinaryOp<IntegerAttr>(
         adaptor.getOperands(), getI1SameShape(lhs.getType()),
         [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
@@ -1772,8 +1772,8 @@ bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
 }
 
 OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
-  auto lhs = adaptor.getLhs().dyn_cast_or_null<FloatAttr>();
-  auto rhs = adaptor.getRhs().dyn_cast_or_null<FloatAttr>();
+  auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
+  auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
 
   // If one operand is NaN, making them both NaN does not change the result.
   if (lhs && lhs.getValue().isNaN())
@@ -2193,11 +2193,11 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
   // Constant-fold constant operands over non-splat constant condition.
   // select %cst_vec, %cst0, %cst1 => %cst2
   if (auto cond =
-          adaptor.getCondition().dyn_cast_or_null<DenseElementsAttr>()) {
+          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
     if (auto lhs =
-            adaptor.getTrueValue().dyn_cast_or_null<DenseElementsAttr>()) {
+            llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
       if (auto rhs =
-              adaptor.getFalseValue().dyn_cast_or_null<DenseElementsAttr>()) {
+              llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
         SmallVector<Attribute> results;
         results.reserve(static_cast<size_t>(cond.getNumElements()));
         auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
index 85e0725..61ec365 100644 (file)
@@ -184,7 +184,7 @@ struct SelectOpInterface
 
     // If the buffers have different types, they differ only in their layout
     // map.
-    auto memrefType = trueType->cast<MemRefType>();
+    auto memrefType = llvm::cast<MemRefType>(*trueType);
     return getMemRefTypeWithFullyDynamicLayout(
         RankedTensorType::get(memrefType.getShape(),
                               memrefType.getElementType()),
index fb363c8..965ef11 100644 (file)
@@ -33,8 +33,8 @@ LogicalResult mlir::foldDynamicIndexList(Builder &b,
     if (ofr.is<Attribute>())
       continue;
     // Newly static, move from Value to constant.
-    if (auto cstOp =
-            ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>()) {
+    if (auto cstOp = llvm::dyn_cast_if_present<Value>(ofr)
+                         .getDefiningOp<arith::ConstantIndexOp>()) {
       ofr = b.getIndexAttr(cstOp.value());
       valuesChanged = true;
     }
@@ -56,9 +56,9 @@ llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
 
 Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                             OpFoldResult ofr) {
-  if (auto value = ofr.dyn_cast<Value>())
+  if (auto value = llvm::dyn_cast_if_present<Value>(ofr))
     return value;
-  auto attr = dyn_cast<IntegerAttr>(ofr.dyn_cast<Attribute>());
+  auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(ofr));
   assert(attr && "expect the op fold result casts to an integer attribute");
   return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
 }
index eeba571..fdd6f3d 100644 (file)
@@ -179,7 +179,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
     populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
   }
   FailureOr<Value> alloc = options.createAlloc(
-      rewriter, loc, allocType->cast<MemRefType>(), dynamicDims);
+      rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
   if (failed(alloc))
     return failure();
 
index be80f30..016ec2b 100644 (file)
@@ -59,7 +59,8 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
 
 /// Return the func::FuncOp called by `callOp`.
 static func::FuncOp getCalledFunction(CallOpInterface callOp) {
-  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+  SymbolRefAttr sym =
+      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<func::FuncOp>(
index f73efc1..89904db 100644 (file)
@@ -80,7 +80,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
 
 /// Return the FuncOp called by `callOp`.
 static FuncOp getCalledFunction(CallOpInterface callOp) {
-  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+  SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<FuncOp>(
index 5e35196..34959aa 100644 (file)
@@ -995,7 +995,7 @@ static void annotateOpsWithAliasSets(Operation *op,
   op->walk([&](Operation *op) {
     SmallVector<Attribute> aliasSets;
     for (OpResult opResult : op->getOpResults()) {
-      if (opResult.getType().isa<TensorType>()) {
+      if (llvm::isa<TensorType>(opResult.getType())) {
         SmallVector<Attribute> aliases;
         state.applyOnAliases(opResult, [&](Value alias) {
           std::string buffer;
index d0af1c2..417f457 100644 (file)
@@ -238,7 +238,7 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
 
 /// Return the func::FuncOp called by `callOp`.
 static func::FuncOp getCalledFunction(func::CallOp callOp) {
-  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+  SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<func::FuncOp>(
index 02c0e16..f2d1a96 100644 (file)
@@ -90,7 +90,8 @@ OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
-  ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
+  ArrayAttr arrayAttr =
+      llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
   if (arrayAttr && arrayAttr.size() == 2)
     return arrayAttr[1];
   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
@@ -103,7 +104,8 @@ OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
-  ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
+  ArrayAttr arrayAttr =
+      llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
   if (arrayAttr && arrayAttr.size() == 2)
     return arrayAttr[0];
   if (auto createOp = getOperand().getDefiningOp<CreateOp>())
index 3970c9c..aba9e7d 100644 (file)
@@ -94,7 +94,7 @@ DataLayoutEntryAttr DataLayoutEntryAttr::parse(AsmParser &parser) {
 
 void DataLayoutEntryAttr::print(AsmPrinter &os) const {
   os << DataLayoutEntryAttr::kAttrKeyword << "<";
-  if (auto type = getKey().dyn_cast<Type>())
+  if (auto type = llvm::dyn_cast_if_present<Type>(getKey()))
     os << type;
   else
     os << "\"" << getKey().get<StringAttr>().strref() << "\"";
@@ -151,7 +151,7 @@ DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   DenseSet<Type> types;
   DenseSet<StringAttr> ids;
   for (DataLayoutEntryInterface entry : entries) {
-    if (auto type = entry.getKey().dyn_cast<Type>()) {
+    if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
       if (!types.insert(type).second)
         return emitError() << "repeated layout entry key: " << type;
     } else {
index 65a20a0..06e8d79 100644 (file)
@@ -493,7 +493,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
     // error. All other canonicalization is done in the fold method.
     bool requiresConst = !rawConstantIndices.empty() &&
                          currType.isa_and_nonnull<LLVMStructType>();
-    if (Value val = iter.dyn_cast<Value>()) {
+    if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
       APInt intC;
       if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
           intC.isSignedIntN(kGEPConstantBitWidth)) {
@@ -598,7 +598,7 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
   llvm::interleaveComma(
       GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
       [&](PointerUnion<IntegerAttr, Value> cst) {
-        if (Value val = cst.dyn_cast<Value>())
+        if (Value val = llvm::dyn_cast_if_present<Value>(cst))
           printer.printOperand(val);
         else
           printer << cst.get<IntegerAttr>().getInt();
@@ -2495,7 +2495,7 @@ OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
         !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {
 
       PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
-      if (Value val = existing.dyn_cast<Value>())
+      if (Value val = llvm::dyn_cast_if_present<Value>(existing))
         gepArgs.emplace_back(val);
       else
         gepArgs.emplace_back(existing.get<IntegerAttr>().getInt());
index 89693ec..b36ed78 100644 (file)
@@ -261,7 +261,7 @@ DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
 
 static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
   return llvm::all_of(gepOp.getIndices(), [](auto index) {
-    auto indexAttr = index.template dyn_cast<IntegerAttr>();
+    auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
     return indexAttr && indexAttr.getValue() == 0;
   });
 }
@@ -289,7 +289,7 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) {
   // Ensures all indices are static and fetches them.
   SmallVector<IntegerAttr> indices;
   for (auto index : gep.getIndices()) {
-    IntegerAttr indexInt = index.dyn_cast<IntegerAttr>();
+    IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
     if (!indexInt)
       return {};
     indices.push_back(indexInt);
@@ -310,7 +310,7 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) {
   for (IntegerAttr index : llvm::drop_begin(indices)) {
     // Ensure the structure of the type being indexed can be reasoned about.
     // This includes rejecting any potential typed pointer.
-    auto destructurable = selectedType.dyn_cast<DestructurableTypeInterface>();
+    auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
     if (!destructurable)
       return {};
 
@@ -343,7 +343,7 @@ LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
 bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
                             SmallPtrSetImpl<Attribute> &usedIndices,
                             SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
-  auto basePtrType = getBase().getType().dyn_cast<LLVM::LLVMPointerType>();
+  auto basePtrType = llvm::dyn_cast<LLVM::LLVMPointerType>(getBase().getType());
   if (!basePtrType)
     return false;
 
@@ -359,7 +359,7 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
     return false;
   auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);
   assert(slot.elementPtrs.contains(firstLevelIndex));
-  if (!slot.elementPtrs.at(firstLevelIndex).isa<LLVM::LLVMPointerType>())
+  if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
     return false;
   mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
   usedIndices.insert(firstLevelIndex);
@@ -369,7 +369,7 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
 DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
                                  DenseMap<Attribute, MemorySlot> &subslots,
                                  RewriterBase &rewriter) {
-  IntegerAttr firstLevelIndex = getIndices()[1].dyn_cast<IntegerAttr>();
+  IntegerAttr firstLevelIndex = llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
   const MemorySlot &newSlot = subslots.at(firstLevelIndex);
 
   ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
@@ -414,7 +414,7 @@ LLVM::LLVMStructType::getSubelementIndexMap() {
 }
 
 Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) {
-  auto indexAttr = index.dyn_cast<IntegerAttr>();
+  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
   if (!indexAttr || !indexAttr.getType().isInteger(32))
     return {};
   int32_t indexInt = indexAttr.getInt();
@@ -439,7 +439,7 @@ LLVM::LLVMArrayType::getSubelementIndexMap() const {
 }
 
 Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
-  auto indexAttr = index.dyn_cast<IntegerAttr>();
+  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
   if (!indexAttr || !indexAttr.getType().isInteger(32))
     return {};
   int32_t indexInt = indexAttr.getInt();
index dcbfbf3..be129ff 100644 (file)
@@ -354,7 +354,7 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
     auto newType = llvm::cast<LLVMPointerType>(newEntry.getKey().get<Type>());
     const auto *it =
         llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
-          if (auto type = entry.getKey().dyn_cast<Type>()) {
+          if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
             return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
                    newType.getAddressSpace();
           }
@@ -362,7 +362,7 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
         });
     if (it == oldLayout.end()) {
       llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
-        if (auto type = entry.getKey().dyn_cast<Type>()) {
+        if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
           return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
         }
         return false;
index a63f664..52699db 100644 (file)
@@ -2368,7 +2368,7 @@ transform::TileOp::apply(TransformResults &transformResults,
         sizes.reserve(tileSizes.size());
         unsigned dynamicIdx = 0;
         for (OpFoldResult ofr : getMixedSizes()) {
-          if (auto attr = ofr.dyn_cast<Attribute>()) {
+          if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
             sizes.push_back(b.create<arith::ConstantIndexOp>(
                 getLoc(), cast<IntegerAttr>(attr).getInt()));
             continue;
@@ -2794,7 +2794,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
             sizes.reserve(tileSizes.size());
             unsigned dynamicIdx = 0;
             for (OpFoldResult ofr : getMixedSizes()) {
-              if (auto attr = ofr.dyn_cast<Attribute>()) {
+              if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
                 sizes.push_back(b.create<arith::ConstantIndexOp>(
                     getLoc(), cast<IntegerAttr>(attr).getInt()));
               } else {
index 33ff4a3..6a80057 100644 (file)
@@ -1447,7 +1447,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
       cast<LinalgOp>(genericOp.getOperation())
           .createLoopRanges(rewriter, genericOp.getLoc());
   auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
-    if (auto attr = ofr.dyn_cast<Attribute>())
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
       return cast<IntegerAttr>(attr).getInt() == value;
     llvm::APInt actual;
     return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
index d39cd0e..e952f94 100644 (file)
@@ -229,7 +229,7 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
     // to look for the bound.
     LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
     Value size;
-    if (auto attr = rangeValue.size.dyn_cast<Attribute>()) {
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
       size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
     } else {
       Value materializedSize =
index 203ae43..bbe3a54 100644 (file)
@@ -92,7 +92,7 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
       rewriter, op.getLoc(), d0 + d1 - d2,
       {iterationSpace[dimension].offset, iterationSpace[dimension].size,
        minSplitPoint});
-  if (auto attr = remainingSize.dyn_cast<Attribute>()) {
+  if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) {
     if (cast<IntegerAttr>(attr).getValue().isZero())
       return {op, TilingInterface()};
   }
index 1293d03..5ef34b1 100644 (file)
@@ -48,7 +48,7 @@ using namespace mlir::scf;
 static bool isZero(OpFoldResult v) {
   if (!v)
     return false;
-  if (auto attr = v.dyn_cast<Attribute>()) {
+  if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
     IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
     return intAttr && intAttr.getValue().isZero();
   }
@@ -104,7 +104,7 @@ void mlir::linalg::transformIndexOps(
 /// checked at runtime.
 static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
                                          OpFoldResult value) {
-  if (auto attr = value.dyn_cast<Attribute>()) {
+  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
     assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
            "expected strictly positive tile size and divisor");
     return;
index 6f932bd..4dceab3 100644 (file)
@@ -1135,7 +1135,7 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                         PatternRewriter &rewriter) const {
   // Given an OpFoldResult, return an index-typed value.
   auto getIdxValue = [&](OpFoldResult ofr) {
-    if (auto val = ofr.dyn_cast<Value>())
+    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
       return val;
     return rewriter
         .create<arith::ConstantIndexOp>(
index aae3603..d081e1a 100644 (file)
@@ -1646,7 +1646,7 @@ static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
                                            ArrayRef<OpFoldResult> ofrs) {
   SmallVector<Value> result;
   for (auto o : ofrs) {
-    if (auto val = o.template dyn_cast<Value>()) {
+    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
       result.push_back(val);
     } else {
       result.push_back(rewriter.create<arith::ConstantIndexOp>(
@@ -1954,8 +1954,8 @@ struct PadOpVectorizationWithTransferWritePattern
         continue;
 
       // Other cases: Take a deeper look at defining ops of values.
-      auto v1 = size1.dyn_cast<Value>();
-      auto v2 = size2.dyn_cast<Value>();
+      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
+      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
       if (!v1 || !v2)
         return false;
 
index ef31668..d5eea24 100644 (file)
@@ -970,7 +970,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
     auto dim = it.index();
     auto size = it.value();
     curr.push_back(dim);
-    auto attr = size.dyn_cast<Attribute>();
+    auto attr = llvm::dyn_cast_if_present<Attribute>(size);
     if (attr && cast<IntegerAttr>(attr).getInt() == 1)
       continue;
     reassociation.emplace_back(ReassociationIndices{});
index 0474d19..faffe9a 100644 (file)
@@ -64,7 +64,7 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
 //===----------------------------------------------------------------------===//
 
 static bool isSupportedElementType(Type type) {
-  return type.isa<MemRefType>() ||
+  return llvm::isa<MemRefType>(type) ||
          OpBuilder(type.getContext()).getZeroAttr(type);
 }
 
@@ -110,7 +110,7 @@ void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
 SmallVector<DestructurableMemorySlot>
 memref::AllocaOp::getDestructurableSlots() {
   MemRefType memrefType = getType();
-  auto destructurable = memrefType.dyn_cast<DestructurableTypeInterface>();
+  auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
   if (!destructurable)
     return {};
 
@@ -134,7 +134,7 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
 
   DenseMap<Attribute, MemorySlot> slotMap;
 
-  auto memrefType = getType().cast<DestructurableTypeInterface>();
+  auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
   for (Attribute usedIndex : usedIndices) {
     Type elemType = memrefType.getTypeAtIndex(usedIndex);
     MemRefType elemPtr = MemRefType::get({}, elemType);
@@ -281,7 +281,7 @@ struct MemRefDestructurableTypeExternalModel
           MemRefDestructurableTypeExternalModel, MemRefType> {
   std::optional<DenseMap<Attribute, Type>>
   getSubelementIndexMap(Type type) const {
-    auto memrefType = type.cast<MemRefType>();
+    auto memrefType = llvm::cast<MemRefType>(type);
     constexpr int64_t maxMemrefSizeForDestructuring = 16;
     if (!memrefType.hasStaticShape() ||
         memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
@@ -298,15 +298,15 @@ struct MemRefDestructurableTypeExternalModel
   }
 
   Type getTypeAtIndex(Type type, Attribute index) const {
-    auto memrefType = type.cast<MemRefType>();
-    auto coordArrAttr = index.dyn_cast<ArrayAttr>();
+    auto memrefType = llvm::cast<MemRefType>(type);
+    auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
     if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
       return {};
 
     Type indexType = IndexType::get(memrefType.getContext());
     for (const auto &[coordAttr, dimSize] :
          llvm::zip(coordArrAttr, memrefType.getShape())) {
-      auto coord = coordAttr.dyn_cast<IntegerAttr>();
+      auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
       if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
           coord.getInt() >= dimSize)
         return {};
index 3beda2c..8ed790c 100644 (file)
@@ -970,7 +970,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
     return unusedDims;
 
   for (const auto &dim : llvm::enumerate(sizes))
-    if (auto attr = dim.value().dyn_cast<Attribute>())
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
       if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
         unusedDims.set(dim.index());
 
@@ -1042,7 +1042,7 @@ llvm::SmallBitVector SubViewOp::getDroppedDims() {
 
 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   // All forms of folding require a known index.
-  auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
+  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
   if (!index)
     return {};
 
index 9b1d85b..431d270 100644 (file)
@@ -56,7 +56,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
     // Because we only support input strides of 1, the output stride is also
     // always 1.
     if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
-          Attribute attr = valueOrAttr.dyn_cast<Attribute>();
+          Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
           return attr && cast<IntegerAttr>(attr).getInt() == 1;
         })) {
       strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
@@ -86,8 +86,9 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
       }
 
       sizes.push_back(opSize);
-      Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
-                sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();
+      Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
+                sourceOffsetAttr =
+                    llvm::dyn_cast_if_present<Attribute>(sourceOffset);
 
       if (opOffsetAttr && sourceOffsetAttr) {
         // If both offsets are static we can simply calculate the combined
@@ -101,7 +102,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
         AffineExpr expr = rewriter.getAffineConstantExpr(0);
         SmallVector<Value> affineApplyOperands;
         for (auto valueOrAttr : {opOffset, sourceOffset}) {
-          if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+          if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
             expr = expr + cast<IntegerAttr>(attr).getInt();
           } else {
             expr =
index 15be4d5..fab270a 100644 (file)
@@ -520,7 +520,7 @@ checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
              << operandName << " operand appears more than once";
 
     mlir::Type varType = operand.getType();
-    auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
+    auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
     auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
     if (!decl)
       return op->emitOpError()
index 65cca0e..4e3d899 100644 (file)
@@ -802,10 +802,10 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange map_operands,
   for (const auto &mapTypeOp : *map_types) {
     int64_t mapTypeBits = 0x00;
 
-    if (!mapTypeOp.isa<mlir::IntegerAttr>())
+    if (!llvm::isa<mlir::IntegerAttr>(mapTypeOp))
       return failure();
 
-    mapTypeBits = mapTypeOp.cast<mlir::IntegerAttr>().getInt();
+    mapTypeBits = llvm::cast<mlir::IntegerAttr>(mapTypeOp).getInt();
 
     bool to =
         bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
index 4b0d0e4..b6cb8c7 100644 (file)
@@ -381,7 +381,7 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   // map.
   auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
 #ifndef NDEBUG
-  auto iterRanked = initArgBufferType->cast<MemRefType>();
+  auto iterRanked = llvm::cast<MemRefType>(*initArgBufferType);
   assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
          "expected same shape");
   assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() &&
@@ -802,7 +802,7 @@ struct WhileOpInterface
           if (!isa<TensorType>(bbArg.getType()))
             return bbArg.getType();
           // TODO: error handling
-          return bufferization::getBufferType(bbArg, options)->cast<Type>();
+          return llvm::cast<Type>(*bufferization::getBufferType(bbArg, options));
         }));
 
     // Construct a new scf.while op with memref instead of tensor values.
index 11df319..89dfb61 100644 (file)
@@ -88,10 +88,10 @@ LogicalResult scf::addLoopRangeConstraints(FlatAffineValueConstraints &cstr,
     return failure();
 
   unsigned dimIv = cstr.appendDimVar(iv);
-  auto lbv = lb.dyn_cast<Value>();
+  auto lbv = llvm::dyn_cast_if_present<Value>(lb);
   unsigned symLb =
       lbv ? cstr.appendSymbolVar(lbv) : cstr.appendSymbolVar(/*num=*/1);
-  auto ubv = ub.dyn_cast<Value>();
+  auto ubv = llvm::dyn_cast_if_present<Value>(ub);
   unsigned symUb =
       ubv ? cstr.appendSymbolVar(ubv) : cstr.appendSymbolVar(/*num=*/1);
 
index 164e2e0..cb4ae4e 100644 (file)
@@ -152,7 +152,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
     auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
     if (getIndices().size() == 1 &&
         constructOp.getConstituents().size() == type.getNumElements()) {
-      auto i = getIndices().begin()->cast<IntegerAttr>();
+      auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
       return constructOp.getConstituents()[i.getValue().getSExtValue()];
     }
   }
index 1673756..6747f75 100644 (file)
@@ -1562,8 +1562,8 @@ LogicalResult spirv::BitcastOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::ConvertPtrToUOp::verify() {
-  auto operandType = getPointer().getType().cast<spirv::PointerType>();
-  auto resultType = getResult().getType().cast<spirv::ScalarType>();
+  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
+  auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
   if (!resultType || !resultType.isSignlessInteger())
     return emitError("result must be a scalar type of unsigned integer");
   auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
@@ -1583,8 +1583,8 @@ LogicalResult spirv::ConvertPtrToUOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::ConvertUToPtrOp::verify() {
-  auto operandType = getOperand().getType().cast<spirv::ScalarType>();
-  auto resultType = getResult().getType().cast<spirv::PointerType>();
+  auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
+  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
   if (!operandType || !operandType.isSignlessInteger())
     return emitError("result must be a scalar type of unsigned integer");
   auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
index fbcc5d8..30fc3e1 100644 (file)
@@ -125,23 +125,23 @@ Type CompositeType::getElementType(unsigned index) const {
 }
 
 unsigned CompositeType::getNumElements() const {
-  if (auto arrayType = dyn_cast<ArrayType>())
+  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
     return arrayType.getNumElements();
-  if (auto matrixType = dyn_cast<MatrixType>())
+  if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
     return matrixType.getNumColumns();
-  if (auto structType = dyn_cast<StructType>())
+  if (auto structType = llvm::dyn_cast<StructType>(*this))
     return structType.getNumElements();
-  if (auto vectorType = dyn_cast<VectorType>())
+  if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
     return vectorType.getNumElements();
-  if (isa<CooperativeMatrixNVType>()) {
+  if (llvm::isa<CooperativeMatrixNVType>(*this)) {
     llvm_unreachable(
         "invalid to query number of elements of spirv::CooperativeMatrix type");
   }
-  if (isa<JointMatrixINTELType>()) {
+  if (llvm::isa<JointMatrixINTELType>(*this)) {
     llvm_unreachable(
         "invalid to query number of elements of spirv::JointMatrix type");
   }
-  if (isa<RuntimeArrayType>()) {
+  if (llvm::isa<RuntimeArrayType>(*this)) {
     llvm_unreachable(
         "invalid to query number of elements of spirv::RuntimeArray type");
   }
@@ -149,8 +149,8 @@ unsigned CompositeType::getNumElements() const {
 }
 
 bool CompositeType::hasCompileTimeKnownNumElements() const {
-  return !isa<CooperativeMatrixNVType, JointMatrixINTELType,
-              RuntimeArrayType>();
+  return !llvm::isa<CooperativeMatrixNVType, JointMatrixINTELType,
+              RuntimeArrayType>(*this);
 }
 
 void CompositeType::getExtensions(
@@ -188,11 +188,11 @@ void CompositeType::getCapabilities(
 }
 
 std::optional<int64_t> CompositeType::getSizeInBytes() {
-  if (auto arrayType = dyn_cast<ArrayType>())
+  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
     return arrayType.getSizeInBytes();
-  if (auto structType = dyn_cast<StructType>())
+  if (auto structType = llvm::dyn_cast<StructType>(*this))
     return structType.getSizeInBytes();
-  if (auto vectorType = dyn_cast<VectorType>()) {
+  if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
     std::optional<int64_t> elementSize =
         llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
     if (!elementSize)
@@ -680,7 +680,7 @@ void ScalarType::getCapabilities(
     capabilities.push_back(ref);                                               \
   } break
 
-  if (auto intType = dyn_cast<IntegerType>()) {
+  if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
     switch (bitwidth) {
       WIDTH_CASE(Int, 8);
       WIDTH_CASE(Int, 16);
@@ -692,7 +692,7 @@ void ScalarType::getCapabilities(
       llvm_unreachable("invalid bitwidth to getCapabilities");
     }
   } else {
-    assert(isa<FloatType>());
+    assert(llvm::isa<FloatType>(*this));
     switch (bitwidth) {
       WIDTH_CASE(Float, 16);
       WIDTH_CASE(Float, 64);
@@ -735,22 +735,22 @@ bool SPIRVType::classof(Type type) {
 }
 
 bool SPIRVType::isScalarOrVector() {
-  return isIntOrFloat() || isa<VectorType>();
+  return isIntOrFloat() || llvm::isa<VectorType>(*this);
 }
 
 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                               std::optional<StorageClass> storage) {
-  if (auto scalarType = dyn_cast<ScalarType>()) {
+  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
     scalarType.getExtensions(extensions, storage);
-  } else if (auto compositeType = dyn_cast<CompositeType>()) {
+  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
     compositeType.getExtensions(extensions, storage);
-  } else if (auto imageType = dyn_cast<ImageType>()) {
+  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
     imageType.getExtensions(extensions, storage);
-  } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
+  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
     sampledImageType.getExtensions(extensions, storage);
-  } else if (auto matrixType = dyn_cast<MatrixType>()) {
+  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
     matrixType.getExtensions(extensions, storage);
-  } else if (auto ptrType = dyn_cast<PointerType>()) {
+  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
     ptrType.getExtensions(extensions, storage);
   } else {
     llvm_unreachable("invalid SPIR-V Type to getExtensions");
@@ -760,17 +760,17 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
 void SPIRVType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     std::optional<StorageClass> storage) {
-  if (auto scalarType = dyn_cast<ScalarType>()) {
+  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
     scalarType.getCapabilities(capabilities, storage);
-  } else if (auto compositeType = dyn_cast<CompositeType>()) {
+  } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
     compositeType.getCapabilities(capabilities, storage);
-  } else if (auto imageType = dyn_cast<ImageType>()) {
+  } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
     imageType.getCapabilities(capabilities, storage);
-  } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
+  } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
     sampledImageType.getCapabilities(capabilities, storage);
-  } else if (auto matrixType = dyn_cast<MatrixType>()) {
+  } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
     matrixType.getCapabilities(capabilities, storage);
-  } else if (auto ptrType = dyn_cast<PointerType>()) {
+  } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
     ptrType.getCapabilities(capabilities, storage);
   } else {
     llvm_unreachable("invalid SPIR-V Type to getCapabilities");
@@ -778,9 +778,9 @@ void SPIRVType::getCapabilities(
 }
 
 std::optional<int64_t> SPIRVType::getSizeInBytes() {
-  if (auto scalarType = dyn_cast<ScalarType>())
+  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
     return scalarType.getSizeInBytes();
-  if (auto compositeType = dyn_cast<CompositeType>())
+  if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
     return compositeType.getSizeInBytes();
   return std::nullopt;
 }
index 58d0e6a..b1dffbf 100644 (file)
@@ -856,9 +856,9 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
   if (!adaptor.getLhs() || !adaptor.getRhs())
     return nullptr;
   auto lhsShape = llvm::to_vector<6>(
-      adaptor.getLhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
+      llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
   auto rhsShape = llvm::to_vector<6>(
-      adaptor.getRhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
+      llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
   SmallVector<int64_t, 6> resultShape;
   resultShape.append(lhsShape.begin(), lhsShape.end());
   resultShape.append(rhsShape.begin(), rhsShape.end());
@@ -989,7 +989,7 @@ OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
           if (!operand)
             return false;
           extents.push_back(llvm::to_vector<6>(
-              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
+              llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
         }
         return OpTrait::util::staticallyKnownBroadcastable(extents);
       }())
@@ -1132,10 +1132,10 @@ LogicalResult mlir::shape::DimOp::verify() {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
-  auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
+  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
   if (!lhs)
     return nullptr;
-  auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
+  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
   if (!rhs)
     return nullptr;
 
@@ -1346,7 +1346,7 @@ std::optional<int64_t> GetExtentOp::getConstantDim() {
 }
 
 OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
-  auto elements = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
+  auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
   if (!elements)
     return nullptr;
   std::optional<int64_t> dim = getConstantDim();
@@ -1490,7 +1490,7 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
-  auto shape = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
+  auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
   if (!shape)
     return {};
   int64_t rank = shape.getNumElements();
@@ -1671,10 +1671,10 @@ bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
-  auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
+  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
   if (!lhs)
     return nullptr;
-  auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
+  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
   if (!rhs)
     return nullptr;
   APInt folded = lhs.getValue() * rhs.getValue();
@@ -1864,9 +1864,9 @@ LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
   if (!adaptor.getOperand() || !adaptor.getIndex())
     return failure();
   auto shapeVec = llvm::to_vector<6>(
-      adaptor.getOperand().cast<DenseIntElementsAttr>().getValues<int64_t>());
+      llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
   auto shape = llvm::ArrayRef(shapeVec);
-  auto splitPoint = adaptor.getIndex().cast<IntegerAttr>().getInt();
+  auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
   // Verify that the split point is in the correct range.
   // TODO: Constant fold to an "error".
   int64_t rank = shape.size();
@@ -1889,7 +1889,7 @@ OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
     return OpFoldResult();
   Builder builder(getContext());
   auto shape = llvm::to_vector<6>(
-      adaptor.getInput().cast<DenseIntElementsAttr>().getValues<int64_t>());
+      llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                     builder.getIndexType());
   return DenseIntElementsAttr::get(type, shape);
index 0ecc77f..3175e95 100644 (file)
@@ -815,7 +815,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
   Level cooStartLvl = getCOOStart(stt.getEncoding());
   if (cooStartLvl < stt.getLvlRank()) {
     // We only supports trailing COO for now, must be the last input.
-    auto cooTp = lvlTps.back().cast<ShapedType>();
+    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
     // The coordinates should be in shape of <? x rank>
     unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
     if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
@@ -844,7 +844,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
       inputTp = lvlTps[idx++];
     }
     // The input element type and expected element type should match.
-    Type inpElemTp = inputTp.cast<TensorType>().getElementType();
+    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
     Type expElemTp = getFieldElemType(stt, fKind);
     if (inpElemTp != expElemTp) {
       misMatch = true;
index 5a1615e..246e5d9 100644 (file)
@@ -188,7 +188,7 @@ static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
 /// Generates a memref from tensor operation.
 static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
                                Value tensor) {
-  auto tensorType = tensor.getType().cast<ShapedType>();
+  auto tensorType = llvm::cast<ShapedType>(tensor.getType());
   auto memrefType =
       MemRefType::get(tensorType.getShape(), tensorType.getElementType());
   return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
index f6405d2..20d0c5e 100644 (file)
@@ -414,7 +414,7 @@ public:
   /// TODO: better unord/not-unique; also generalize, optimize, specialize!
   SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
                                        OpBuilder &builder, Location loc) {
-    const SparseTensorType stt(rtp.cast<RankedTensorType>());
+    const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
     const Level lvlRank = stt.getLvlRank();
     // Extract fields and coordinates from args.
     SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
@@ -466,7 +466,7 @@ public:
     // The mangled name of the function has this format:
     //   <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
     constexpr const char kInsertFuncNamePrefix[] = "_insert_";
-    const SparseTensorType stt(rtp.cast<RankedTensorType>());
+    const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
 
     SmallString<32> nameBuffer;
     llvm::raw_svector_ostream nameOstream(nameBuffer);
@@ -541,14 +541,14 @@ static void genEndInsert(OpBuilder &builder, Location loc,
 
 static TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
                                               Value tensor) {
-  auto tTp = tensor.getType().cast<TensorType>();
+  auto tTp = llvm::cast<TensorType>(tensor.getType());
   auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
   return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
       .getResult();
 }
 
 Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz) {
-  auto elemTp = mem.getType().cast<MemRefType>().getElementType();
+  auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
   return builder
       .create<memref::SubViewOp>(
           loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
index 8a52082..a964d91 100644 (file)
@@ -180,7 +180,7 @@ struct ReifyPadOp
       AffineExpr expr = b.getAffineDimExpr(0);
       unsigned numSymbols = 0;
       auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
-        if (Value v = valueOrAttr.dyn_cast<Value>()) {
+        if (Value v = llvm::dyn_cast_if_present<Value>(valueOrAttr)) {
           expr = expr + b.getAffineSymbolExpr(numSymbols++);
           mapOperands.push_back(v);
           return;
index eab64b5..1adb9c7 100644 (file)
@@ -501,7 +501,7 @@ Speculation::Speculatability DimOp::getSpeculatability() {
 
 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   // All forms of folding require a known index.
-  auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
+  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
   if (!index)
     return {};
 
@@ -764,7 +764,7 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
       OpFoldResult currDim = std::get<1>(it);
       // Case 1: The empty tensor dim is static. Check that the tensor cast
       // result dim matches.
-      if (auto attr = currDim.dyn_cast<Attribute>()) {
+      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
         if (ShapedType::isDynamic(newDim) ||
             newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
           // Something is off, the cast result shape cannot be more dynamic
@@ -2106,7 +2106,7 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
 }
 
 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
-  if (auto splat = adaptor.getSource().dyn_cast_or_null<SplatElementsAttr>()) {
+  if (auto splat = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
     auto resultType = llvm::cast<ShapedType>(getResult().getType());
     if (resultType.hasStaticShape())
       return splat.resizeSplat(resultType);
@@ -3558,7 +3558,7 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
   SmallVector<int64_t> result;
   for (auto o : ofrs) {
     // Have to do this first, as getConstantIntValue special-cases constants.
-    if (o.dyn_cast<Value>())
+    if (llvm::dyn_cast_if_present<Value>(o))
       result.push_back(ShapedType::kDynamic);
     else
       result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
index 935a1b9..545a9d0 100644 (file)
@@ -76,7 +76,7 @@ struct CastOpInterface
     auto rankedResultType = cast<RankedTensorType>(castOp.getType());
     return MemRefType::get(
         rankedResultType.getShape(), rankedResultType.getElementType(),
-        maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace);
+        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -139,7 +139,7 @@ struct CollapseShapeOpInterface
         collapseShapeOp.getSrc(), options, fixedTypes);
     if (failed(maybeSrcBufferType))
       return failure();
-    auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
+    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
     bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
         srcBufferType, collapseShapeOp.getReassociationIndices());
 
@@ -303,7 +303,7 @@ struct ExpandShapeOpInterface
         expandShapeOp.getSrc(), options, fixedTypes);
     if (failed(maybeSrcBufferType))
       return failure();
-    auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
+    auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
     auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
         srcBufferType, expandShapeOp.getResultType().getShape(),
         expandShapeOp.getReassociationIndices());
@@ -369,7 +369,7 @@ struct ExtractSliceOpInterface
     if (failed(resultMemrefType))
       return failure();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, resultMemrefType->cast<MemRefType>(), *srcMemref, mixedOffsets,
+        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
         mixedSizes, mixedStrides);
 
     replaceOpWithBufferizedValues(rewriter, op, subView);
@@ -389,7 +389,7 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
     return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
-        extractSliceOp.getType().getShape(), srcMemrefType->cast<MemRefType>(),
+        extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
         mixedOffsets, mixedSizes, mixedStrides));
   }
 };
index e95e628..fb3b934 100644 (file)
@@ -548,8 +548,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
     return getInput1();
@@ -573,8 +573,8 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
   if (lhsAttr && lhsAttr.isSplat()) {
     if (llvm::isa<IntegerType>(resultETy) &&
         lhsAttr.getSplatValue<APInt>().isZero())
@@ -642,8 +642,8 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
   if (rhsTy == resultTy) {
@@ -670,8 +670,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto resultETy = resultTy.getElementType();
-  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
     return getInput1();
@@ -713,8 +713,8 @@ struct APIntFoldGreaterEqual {
 
 OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
-  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   if (!lhsAttr || !rhsAttr)
     return {};
@@ -725,8 +725,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
-  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
   if (!lhsAttr || !rhsAttr)
     return {};
@@ -738,8 +738,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
-  auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
-  auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+  auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+  auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
   Value lhs = getInput1();
   Value rhs = getInput2();
   auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
@@ -763,7 +763,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
   if (getInput().getType() == getType())
     return getInput();
 
-  auto operand = adaptor.getInput().dyn_cast_or_null<ElementsAttr>();
+  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
   if (!operand)
     return {};
 
@@ -852,7 +852,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (inputTy == outputTy)
     return getInput1();
 
-  auto operand = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+  auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
   if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
     return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
   }
@@ -863,7 +863,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
 OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
   // If the pad is all zeros we can fold this operation away.
   if (adaptor.getPadding()) {
-    auto densePad = adaptor.getPadding().cast<DenseElementsAttr>();
+    auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
     if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
       return getInput1();
     }
@@ -907,7 +907,7 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
   auto operand = getInput();
   auto operandTy = llvm::cast<ShapedType>(operand.getType());
   auto axis = getAxis();
-  auto operandAttr = adaptor.getInput().dyn_cast_or_null<SplatElementsAttr>();
+  auto operandAttr = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
   if (operandAttr)
     return operandAttr;
 
@@ -936,7 +936,7 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
       !outputTy.getElementType().isIntOrIndexOrFloat())
     return {};
 
-  auto operand = adaptor.getInput().cast<ElementsAttr>();
+  auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
   if (operand.isSplat() && outputTy.hasStaticShape()) {
     return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
   }
@@ -955,7 +955,7 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
   if (getOnTrue() == getOnFalse())
     return getOnTrue();
 
-  auto predicate = adaptor.getPred().dyn_cast_or_null<DenseIntElementsAttr>();
+  auto predicate = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
   if (!predicate)
     return {};
 
@@ -977,7 +977,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
   auto resultTy = llvm::cast<ShapedType>(getType());
 
   // Transposing splat values just means reshaping.
-  if (auto input = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>()) {
+  if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
     if (input.isSplat() && resultTy.hasStaticShape() &&
         inputTy.getElementType() == resultTy.getElementType())
       return input.reshape(resultTy);
index 829db2a..94cbb0a 100644 (file)
@@ -63,9 +63,9 @@ LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
   // Verify the rank agrees with the output type if the output type is ranked.
   if (outputType) {
     if (outputType.getRank() !=
-            input1_copy.getType().cast<RankedTensorType>().getRank() ||
+            llvm::cast<RankedTensorType>(input1_copy.getType()).getRank() ||
         outputType.getRank() !=
-            input2_copy.getType().cast<RankedTensorType>().getRank())
+            llvm::cast<RankedTensorType>(input2_copy.getType()).getRank())
       return rewriter.notifyMatchFailure(
           loc, "the reshaped type doesn't agrees with the ranked output type");
   }
index 8f84a06..d260c93 100644 (file)
@@ -103,8 +103,8 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
 
 LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
                                         Value &input1, Value &input2) {
-  auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
-  auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
+  auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
+  auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
 
   if (!input1Ty || !input2Ty) {
     return failure();
@@ -126,9 +126,9 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
   }
 
   ArrayRef<int64_t> higherRankShape =
-      higherTensorValue.getType().cast<RankedTensorType>().getShape();
+      llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
   ArrayRef<int64_t> lowerRankShape =
-      lowerTensorValue.getType().cast<RankedTensorType>().getShape();
+      llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
 
   SmallVector<int64_t, 4> reshapeOutputShape;
 
@@ -136,7 +136,8 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
           .failed())
     return failure();
 
-  auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
+  auto reshapeInputType =
+      llvm::cast<RankedTensorType>(lowerTensorValue.getType());
   auto reshapeOutputType = RankedTensorType::get(
       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
 
index b3176b1..695b4a3 100644 (file)
@@ -118,7 +118,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
     SmallVector<Operation *> operations;
     operations.reserve(values.size());
     for (transform::MappedValue value : values) {
-      if (auto *op = value.dyn_cast<Operation *>()) {
+      if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
         operations.push_back(op);
         continue;
       }
@@ -135,7 +135,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
     SmallVector<Value> payloadValues;
     payloadValues.reserve(values.size());
     for (transform::MappedValue value : values) {
-      if (auto v = value.dyn_cast<Value>()) {
+      if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
         payloadValues.push_back(v);
         continue;
       }
@@ -152,7 +152,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
   SmallVector<transform::Param> parameters;
   parameters.reserve(values.size());
   for (transform::MappedValue value : values) {
-    if (auto attr = value.dyn_cast<Attribute>()) {
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
       parameters.push_back(attr);
       continue;
     }
index 09137d3..75d2dce 100644 (file)
@@ -18,7 +18,7 @@ namespace mlir {
 bool isZeroIndex(OpFoldResult v) {
   if (!v)
     return false;
-  if (auto attr = v.dyn_cast<Attribute>()) {
+  if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
     IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
     return intAttr && intAttr.getValue().isZero();
   }
@@ -51,7 +51,7 @@ getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
 void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                SmallVectorImpl<Value> &dynamicVec,
                                SmallVectorImpl<int64_t> &staticVec) {
-  auto v = ofr.dyn_cast<Value>();
+  auto v = llvm::dyn_cast_if_present<Value>(ofr);
   if (!v) {
     APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
     staticVec.push_back(apInt.getSExtValue());
@@ -116,14 +116,14 @@ SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   // Case 1: Check for Constant integer.
-  if (auto val = ofr.dyn_cast<Value>()) {
+  if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
     APSInt intVal;
     if (matchPattern(val, m_ConstantInt(&intVal)))
       return intVal.getSExtValue();
     return std::nullopt;
   }
   // Case 2: Check for IntegerAttr.
-  Attribute attr = ofr.dyn_cast<Attribute>();
+  Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
   if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
     return intAttr.getValue().getSExtValue();
   return std::nullopt;
@@ -143,7 +143,8 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
   auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
   if (cst1 && cst2 && *cst1 == *cst2)
     return true;
-  auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
+  auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
+       v2 = llvm::dyn_cast_if_present<Value>(ofr2);
   return v1 && v1 == v2;
 }
 
index aac6777..20c088c 100644 (file)
@@ -1154,7 +1154,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                             OpaqueProperties properties, RegionRange,
                             SmallVectorImpl<Type> &inferredReturnTypes) {
   ExtractOp::Adaptor op(operands, attributes, properties);
-  auto vectorType = op.getVector().getType().cast<VectorType>();
+  auto vectorType = llvm::cast<VectorType>(op.getVector().getType());
   if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
     inferredReturnTypes.push_back(vectorType.getElementType());
   } else {
@@ -2003,9 +2003,9 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   if (!adaptor.getSource())
     return {};
   auto vectorType = getResultVectorType();
-  if (adaptor.getSource().isa<IntegerAttr, FloatAttr>())
+  if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
     return DenseElementsAttr::get(vectorType, adaptor.getSource());
-  if (auto attr = adaptor.getSource().dyn_cast<SplatElementsAttr>())
+  if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
     return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
   return {};
 }
@@ -2090,7 +2090,7 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                             OpaqueProperties properties, RegionRange,
                             SmallVectorImpl<Type> &inferredReturnTypes) {
   ShuffleOp::Adaptor op(operands, attributes, properties);
-  auto v1Type = op.getV1().getType().cast<VectorType>();
+  auto v1Type = llvm::cast<VectorType>(op.getV1().getType());
   auto v1Rank = v1Type.getRank();
   // Construct resulting type: leading dimension matches mask
   // length, all trailing dimensions match the operands.
@@ -4951,7 +4951,7 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
 
 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
-  if (auto attr = adaptor.getVector().dyn_cast_or_null<DenseElementsAttr>())
+  if (auto attr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
     if (attr.isSplat())
       return attr.reshape(getResultVectorType());
 
index a75bc58..ef0bf75 100644 (file)
@@ -3642,7 +3642,7 @@ void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
   if (auto *op = getDefiningOp())
     return op->print(os, flags);
   // TODO: Improve BlockArgument print'ing.
-  BlockArgument arg = this->cast<BlockArgument>();
+  BlockArgument arg = llvm::cast<BlockArgument>(*this);
   os << "<block argument> of type '" << arg.getType()
      << "' at index: " << arg.getArgNumber();
 }
@@ -3656,7 +3656,7 @@ void Value::print(raw_ostream &os, AsmState &state) {
     return op->print(os, state);
 
   // TODO: Improve BlockArgument print'ing.
-  BlockArgument arg = this->cast<BlockArgument>();
+  BlockArgument arg = llvm::cast<BlockArgument>(*this);
   os << "<block argument> of type '" << arg.getType()
      << "' at index: " << arg.getArgNumber();
 }
@@ -3693,10 +3693,10 @@ static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
 
 void Value::printAsOperand(raw_ostream &os, const OpPrintingFlags &flags) {
   Operation *op;
-  if (auto result = dyn_cast<OpResult>()) {
+  if (auto result = llvm::dyn_cast<OpResult>(*this)) {
     op = result.getOwner();
   } else {
-    op = cast<BlockArgument>().getOwner()->getParentOp();
+    op = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp();
     if (!op) {
       os << "<<UNKNOWN SSA VALUE>>";
       return;
index dc4ec14..069be73 100644 (file)
@@ -347,14 +347,14 @@ BlockRange::BlockRange(SuccessorRange successors)
 
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 BlockRange::OwnerT BlockRange::offset_base(OwnerT object, ptrdiff_t index) {
-  if (auto *operand = object.dyn_cast<BlockOperand *>())
+  if (auto *operand = llvm::dyn_cast_if_present<BlockOperand *>(object))
     return {operand + index};
-  return {object.dyn_cast<Block *const *>() + index};
+  return {llvm::dyn_cast_if_present<Block *const *>(object) + index};
 }
 
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 Block *BlockRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
-  if (const auto *operand = object.dyn_cast<BlockOperand *>())
+  if (const auto *operand = llvm::dyn_cast_if_present<BlockOperand *>(object))
     return operand[index].get();
-  return object.dyn_cast<Block *const *>()[index];
+  return llvm::dyn_cast_if_present<Block *const *>(object)[index];
 }
index c4fad9c..22abd52 100644 (file)
@@ -483,7 +483,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
     Type expectedType = std::get<1>(it);
 
     // Normal values get pushed back directly.
-    if (auto value = std::get<0>(it).dyn_cast<Value>()) {
+    if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
       if (value.getType() != expectedType)
         return cleanupFailure();
 
index de26c8e..9a18f07 100644 (file)
@@ -1247,12 +1247,12 @@ DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
 DenseElementsAttr
 DenseElementsAttr::mapValues(Type newElementType,
                              function_ref<APInt(const APInt &)> mapping) const {
-  return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
+  return llvm::cast<DenseIntElementsAttr>(*this).mapValues(newElementType, mapping);
 }
 
 DenseElementsAttr DenseElementsAttr::mapValues(
     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
-  return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
+  return llvm::cast<DenseFPElementsAttr>(*this).mapValues(newElementType, mapping);
 }
 
 ShapedType DenseElementsAttr::getType() const {
index c816e4a..eea07ed 100644 (file)
@@ -88,45 +88,45 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
-  if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
-          Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
+  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+          Float8E4M3FNUZType, Float8E4M3B11FNUZType>(*this))
     return 8;
-  if (isa<Float16Type, BFloat16Type>())
+  if (llvm::isa<Float16Type, BFloat16Type>(*this))
     return 16;
-  if (isa<Float32Type>())
+  if (llvm::isa<Float32Type>(*this))
     return 32;
-  if (isa<Float64Type>())
+  if (llvm::isa<Float64Type>(*this))
     return 64;
-  if (isa<Float80Type>())
+  if (llvm::isa<Float80Type>(*this))
     return 80;
-  if (isa<Float128Type>())
+  if (llvm::isa<Float128Type>(*this))
     return 128;
   llvm_unreachable("unexpected float type");
 }
 
 /// Returns the floating semantics for the given type.
 const llvm::fltSemantics &FloatType::getFloatSemantics() {
-  if (isa<Float8E5M2Type>())
+  if (llvm::isa<Float8E5M2Type>(*this))
     return APFloat::Float8E5M2();
-  if (isa<Float8E4M3FNType>())
+  if (llvm::isa<Float8E4M3FNType>(*this))
     return APFloat::Float8E4M3FN();
-  if (isa<Float8E5M2FNUZType>())
+  if (llvm::isa<Float8E5M2FNUZType>(*this))
     return APFloat::Float8E5M2FNUZ();
-  if (isa<Float8E4M3FNUZType>())
+  if (llvm::isa<Float8E4M3FNUZType>(*this))
     return APFloat::Float8E4M3FNUZ();
-  if (isa<Float8E4M3B11FNUZType>())
+  if (llvm::isa<Float8E4M3B11FNUZType>(*this))
     return APFloat::Float8E4M3B11FNUZ();
-  if (isa<BFloat16Type>())
+  if (llvm::isa<BFloat16Type>(*this))
     return APFloat::BFloat();
-  if (isa<Float16Type>())
+  if (llvm::isa<Float16Type>(*this))
     return APFloat::IEEEhalf();
-  if (isa<Float32Type>())
+  if (llvm::isa<Float32Type>(*this))
     return APFloat::IEEEsingle();
-  if (isa<Float64Type>())
+  if (llvm::isa<Float64Type>(*this))
     return APFloat::IEEEdouble();
-  if (isa<Float80Type>())
+  if (llvm::isa<Float80Type>(*this))
     return APFloat::x87DoubleExtended();
-  if (isa<Float128Type>())
+  if (llvm::isa<Float128Type>(*this))
     return APFloat::IEEEquad();
   llvm_unreachable("non-floating point type used");
 }
@@ -269,21 +269,21 @@ Type TensorType::getElementType() const {
           [](auto type) { return type.getElementType(); });
 }
 
-bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
+bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
 
 ArrayRef<int64_t> TensorType::getShape() const {
-  return cast<RankedTensorType>().getShape();
+  return llvm::cast<RankedTensorType>(*this).getShape();
 }
 
 TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                  Type elementType) const {
-  if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
+  if (auto unrankedTy = llvm::dyn_cast<UnrankedTensorType>(*this)) {
     if (shape)
       return RankedTensorType::get(*shape, elementType);
     return UnrankedTensorType::get(elementType);
   }
 
-  auto rankedTy = cast<RankedTensorType>();
+  auto rankedTy = llvm::cast<RankedTensorType>(*this);
   if (!shape)
     return RankedTensorType::get(rankedTy.getShape(), elementType,
                                  rankedTy.getEncoding());
@@ -356,15 +356,15 @@ Type BaseMemRefType::getElementType() const {
           [](auto type) { return type.getElementType(); });
 }
 
-bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
+bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
 
 ArrayRef<int64_t> BaseMemRefType::getShape() const {
-  return cast<MemRefType>().getShape();
+  return llvm::cast<MemRefType>(*this).getShape();
 }
 
 BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                          Type elementType) const {
-  if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
+  if (auto unrankedTy = llvm::dyn_cast<UnrankedMemRefType>(*this)) {
     if (!shape)
       return UnrankedMemRefType::get(elementType, getMemorySpace());
     MemRefType::Builder builder(*shape, elementType);
@@ -372,7 +372,7 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
     return builder;
   }
 
-  MemRefType::Builder builder(cast<MemRefType>());
+  MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
   if (shape)
     builder.setShape(*shape);
   builder.setElementType(elementType);
@@ -389,15 +389,15 @@ MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
 }
 
 Attribute BaseMemRefType::getMemorySpace() const {
-  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
+  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
     return rankedMemRefTy.getMemorySpace();
-  return cast<UnrankedMemRefType>().getMemorySpace();
+  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
 }
 
 unsigned BaseMemRefType::getMemorySpaceAsInt() const {
-  if (auto rankedMemRefTy = dyn_cast<MemRefType>())
+  if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
     return rankedMemRefTy.getMemorySpaceAsInt();
-  return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
+  return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
 }
 
 //===----------------------------------------------------------------------===//
index 0a4a19d..c353188 100644 (file)
@@ -626,17 +626,17 @@ ValueRange::ValueRange(ResultRange values)
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
                                            ptrdiff_t index) {
-  if (const auto *value = owner.dyn_cast<const Value *>())
+  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
     return {value + index};
-  if (auto *operand = owner.dyn_cast<OpOperand *>())
+  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
     return {operand + index};
   return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
 }
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
-  if (const auto *value = owner.dyn_cast<const Value *>())
+  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
     return value[index];
-  if (auto *operand = owner.dyn_cast<OpOperand *>())
+  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
     return operand[index].get();
   return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
 }
index 2b84a93..e1caa89 100644 (file)
@@ -267,18 +267,18 @@ RegionRange::RegionRange(ArrayRef<Region *> regions)
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 RegionRange::OwnerT RegionRange::offset_base(const OwnerT &owner,
                                              ptrdiff_t index) {
-  if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
+  if (auto *region = llvm::dyn_cast_if_present<const std::unique_ptr<Region> *>(owner))
     return region + index;
-  if (auto **region = owner.dyn_cast<Region **>())
+  if (auto **region = llvm::dyn_cast_if_present<Region **>(owner))
     return region + index;
   return &owner.get<Region *>()[index];
 }
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 Region *RegionRange::dereference_iterator(const OwnerT &owner,
                                           ptrdiff_t index) {
-  if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
+  if (auto *region = llvm::dyn_cast_if_present<const std::unique_ptr<Region> *>(owner))
     return region[index].get();
-  if (auto **region = owner.dyn_cast<Region **>())
+  if (auto **region = llvm::dyn_cast_if_present<Region **>(owner))
     return region[index];
   return &owner.get<Region *>()[index];
 }
index c03f4dd..2494cb7 100644 (file)
@@ -551,7 +551,7 @@ struct SymbolScope {
                 typename llvm::function_traits<CallbackT>::result_t,
                 void>::value> * = nullptr>
   std::optional<WalkResult> walk(CallbackT cback) {
-    if (Region *region = limit.dyn_cast<Region *>())
+    if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
       return walkSymbolUses(*region, cback);
     return walkSymbolUses(limit.get<Operation *>(), cback);
   }
@@ -571,7 +571,7 @@ struct SymbolScope {
   /// traversing into any nested symbol tables.
   template <typename CallbackT>
   std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
-    if (Region *region = limit.dyn_cast<Region *>())
+    if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
       return ::walkSymbolTable(*region, cback);
     return ::walkSymbolTable(limit.get<Operation *>(), cback);
   }
index 2e2121a..c05c0ce 100644 (file)
@@ -27,9 +27,9 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
   if (count == 0)
     return;
   ValueRange::OwnerT owner = values.begin().getBase();
-  if (auto *result = owner.dyn_cast<detail::OpResultImpl *>())
+  if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(owner))
     this->base = result;
-  else if (auto *operand = owner.dyn_cast<OpOperand *>())
+  else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
     this->base = operand;
   else
     this->base = owner.get<const Value *>();
@@ -37,22 +37,22 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
 
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
-  if (const auto *value = object.dyn_cast<const Value *>())
+  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
     return {value + index};
-  if (auto *operand = object.dyn_cast<OpOperand *>())
+  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
     return {operand + index};
-  if (auto *result = object.dyn_cast<detail::OpResultImpl *>())
+  if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
     return {result->getNextResultAtOffset(index)};
-  return {object.dyn_cast<const Type *>() + index};
+  return {llvm::dyn_cast_if_present<const Type *>(object) + index};
 }
 
 /// See `llvm::detail::indexed_accessor_range_base` for details.
 Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
-  if (const auto *value = object.dyn_cast<const Value *>())
+  if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
     return (value + index)->getType();
-  if (auto *operand = object.dyn_cast<OpOperand *>())
+  if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
     return (operand + index)->get().getType();
-  if (auto *result = object.dyn_cast<detail::OpResultImpl *>())
+  if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
     return result->getNextResultAtOffset(index)->getType();
-  return object.dyn_cast<const Type *>()[index];
+  return llvm::dyn_cast_if_present<const Type *>(object)[index];
 }
index d3d1d86..e376a5f 100644 (file)
@@ -34,84 +34,94 @@ Type AbstractType::replaceImmediateSubElements(Type type,
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
-bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
-bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
-bool Type::isFloat8E5M2FNUZ() const { return isa<Float8E5M2FNUZType>(); }
-bool Type::isFloat8E4M3FNUZ() const { return isa<Float8E4M3FNUZType>(); }
-bool Type::isFloat8E4M3B11FNUZ() const { return isa<Float8E4M3B11FNUZType>(); }
-bool Type::isBF16() const { return isa<BFloat16Type>(); }
-bool Type::isF16() const { return isa<Float16Type>(); }
-bool Type::isF32() const { return isa<Float32Type>(); }
-bool Type::isF64() const { return isa<Float64Type>(); }
-bool Type::isF80() const { return isa<Float80Type>(); }
-bool Type::isF128() const { return isa<Float128Type>(); }
-
-bool Type::isIndex() const { return isa<IndexType>(); }
+bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
+bool Type::isFloat8E4M3FN() const { return llvm::isa<Float8E4M3FNType>(*this); }
+bool Type::isFloat8E5M2FNUZ() const {
+  return llvm::isa<Float8E5M2FNUZType>(*this);
+}
+bool Type::isFloat8E4M3FNUZ() const {
+  return llvm::isa<Float8E4M3FNUZType>(*this);
+}
+bool Type::isFloat8E4M3B11FNUZ() const {
+  return llvm::isa<Float8E4M3B11FNUZType>(*this);
+}
+bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
+bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
+bool Type::isF32() const { return llvm::isa<Float32Type>(*this); }
+bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
+bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
+bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
+
+bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }
 
 /// Return true if this is an integer type with the specified width.
 bool Type::isInteger(unsigned width) const {
-  if (auto intTy = dyn_cast<IntegerType>())
+  if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
     return intTy.getWidth() == width;
   return false;
 }
 
 bool Type::isSignlessInteger() const {
-  if (auto intTy = dyn_cast<IntegerType>())
+  if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
     return intTy.isSignless();
   return false;
 }
 
 bool Type::isSignlessInteger(unsigned width) const {
-  if (auto intTy = dyn_cast<IntegerType>())
+  if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
     return intTy.isSignless() && intTy.getWidth() == width;
   return false;
 }
 
 bool Type::isSignedInteger() const {
-  if (auto intTy = dyn_cast<IntegerType>())
+  if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
     return intTy.isSigned();
   return false;
 }
 
 bool Type::isSignedInteger(unsigned width) const {
-  if (auto intTy = dyn_cast<IntegerType>())
+  if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
     return intTy.isSigned() && intTy.getWidth() == width;
   return false;
 }
 
 bool Type::isUnsignedInteger() const {
-  if (auto intTy = dyn_cast<IntegerType>())
+  if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
     return intTy.isUnsigned();
   return false;
 }
 
 bool Type::isUnsignedInteger(unsigned width) const {
-  if (auto intTy = dyn_cast<IntegerType>())
+  if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
     return intTy.isUnsigned() && intTy.getWidth() == width;
   return false;
 }
 
 bool Type::isSignlessIntOrIndex() const {
-  return isSignlessInteger() || isa<IndexType>();
+  return isSignlessInteger() || llvm::isa<IndexType>(*this);
 }
 
 bool Type::isSignlessIntOrIndexOrFloat() const {
-  return isSignlessInteger() || isa<IndexType, FloatType>();
+  return isSignlessInteger() || llvm::isa<IndexType, FloatType>(*this);
 }
 
 bool Type::isSignlessIntOrFloat() const {
-  return isSignlessInteger() || isa<FloatType>();
+  return isSignlessInteger() || llvm::isa<FloatType>(*this);
 }
 
-bool Type::isIntOrIndex() const { return isa<IntegerType>() || isIndex(); }
+bool Type::isIntOrIndex() const {
+  return llvm::isa<IntegerType>(*this) || isIndex();
+}
 
-bool Type::isIntOrFloat() const { return isa<IntegerType, FloatType>(); }
+bool Type::isIntOrFloat() const {
+  return llvm::isa<IntegerType, FloatType>(*this);
+}
 
 bool Type::isIntOrIndexOrFloat() const { return isIntOrFloat() || isIndex(); }
 
 unsigned Type::getIntOrFloatBitWidth() const {
   assert(isIntOrFloat() && "only integers and floats have a bitwidth");
-  if (auto intType = dyn_cast<IntegerType>())
+  if (auto intType = llvm::dyn_cast<IntegerType>(*this))
     return intType.getWidth();
-  return cast<FloatType>().getWidth();
+  return llvm::cast<FloatType>(*this).getWidth();
 }
index 7da714f..c109d38 100644 (file)
@@ -48,11 +48,11 @@ static void printBlock(llvm::raw_ostream &os, Block *block,
 }
 
 void mlir::IRUnit::print(llvm::raw_ostream &os, OpPrintingFlags flags) const {
-  if (auto *op = this->dyn_cast<Operation *>())
+  if (auto *op = llvm::dyn_cast_if_present<Operation *>(*this))
     return printOp(os, op, flags);
-  if (auto *region = this->dyn_cast<Region *>())
+  if (auto *region = llvm::dyn_cast_if_present<Region *>(*this))
     return printRegion(os, region, flags);
-  if (auto *block = this->dyn_cast<Block *>())
+  if (auto *block = llvm::dyn_cast_if_present<Block *>(*this))
     return printBlock(os, block, flags);
   llvm_unreachable("unknown IRUnit");
 }
index 86b9cde..6b5195d 100644 (file)
@@ -18,7 +18,7 @@ using namespace mlir::detail;
 /// If this value is the result of an Operation, return the operation that
 /// defines it.
 Operation *Value::getDefiningOp() const {
-  if (auto result = dyn_cast<OpResult>())
+  if (auto result = llvm::dyn_cast<OpResult>(*this))
     return result.getOwner();
   return nullptr;
 }
@@ -27,28 +27,28 @@ Location Value::getLoc() const {
   if (auto *op = getDefiningOp())
     return op->getLoc();
 
-  return cast<BlockArgument>().getLoc();
+  return llvm::cast<BlockArgument>(*this).getLoc();
 }
 
 void Value::setLoc(Location loc) {
   if (auto *op = getDefiningOp())
     return op->setLoc(loc);
 
-  return cast<BlockArgument>().setLoc(loc);
+  return llvm::cast<BlockArgument>(*this).setLoc(loc);
 }
 
 /// Return the Region in which this Value is defined.
 Region *Value::getParentRegion() {
   if (auto *op = getDefiningOp())
     return op->getParentRegion();
-  return cast<BlockArgument>().getOwner()->getParent();
+  return llvm::cast<BlockArgument>(*this).getOwner()->getParent();
 }
 
 /// Return the Block in which this Value is defined.
 Block *Value::getParentBlock() {
   if (Operation *op = getDefiningOp())
     return op->getBlock();
-  return cast<BlockArgument>().getOwner();
+  return llvm::cast<BlockArgument>(*this).getOwner();
 }
 
 //===----------------------------------------------------------------------===//
index e335e15..e460fe1 100644 (file)
@@ -241,7 +241,7 @@ mlir::detail::filterEntriesForType(DataLayoutEntryListRef entries,
                                    TypeID typeID) {
   return llvm::to_vector<4>(llvm::make_filter_range(
       entries, [typeID](DataLayoutEntryInterface entry) {
-        auto type = entry.getKey().dyn_cast<Type>();
+        auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
         return type && type.getTypeID() == typeID;
       }));
 }
@@ -521,7 +521,7 @@ void DataLayoutSpecInterface::bucketEntriesByType(
     DenseMap<TypeID, DataLayoutEntryList> &types,
     DenseMap<StringAttr, DataLayoutEntryInterface> &ids) {
   for (DataLayoutEntryInterface entry : getEntries()) {
-    if (auto type = entry.getKey().dyn_cast<Type>())
+    if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
       types[type.getTypeID()].push_back(entry);
     else
       ids[entry.getKey().get<StringAttr>()] = entry;
index aaa1e1b..00d1c51 100644 (file)
@@ -68,7 +68,7 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
 bool ShapeAdaptor::hasRank() const {
   if (val.isNull())
     return false;
-  if (auto t = val.dyn_cast<Type>())
+  if (auto t = llvm::dyn_cast_if_present<Type>(val))
     return cast<ShapedType>(t).hasRank();
   if (val.is<Attribute>())
     return true;
@@ -78,7 +78,7 @@ bool ShapeAdaptor::hasRank() const {
 Type ShapeAdaptor::getElementType() const {
   if (val.isNull())
     return nullptr;
-  if (auto t = val.dyn_cast<Type>())
+  if (auto t = llvm::dyn_cast_if_present<Type>(val))
     return cast<ShapedType>(t).getElementType();
   if (val.is<Attribute>())
     return nullptr;
@@ -87,10 +87,10 @@ Type ShapeAdaptor::getElementType() const {
 
 void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
   assert(hasRank());
-  if (auto t = val.dyn_cast<Type>()) {
+  if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
     ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
     res.assign(vals.begin(), vals.end());
-  } else if (auto attr = val.dyn_cast<Attribute>()) {
+  } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
     auto dattr = cast<DenseIntElementsAttr>(attr);
     res.clear();
     res.reserve(dattr.size());
@@ -110,9 +110,9 @@ void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
 
 int64_t ShapeAdaptor::getDimSize(int index) const {
   assert(hasRank());
-  if (auto t = val.dyn_cast<Type>())
+  if (auto t = llvm::dyn_cast_if_present<Type>(val))
     return cast<ShapedType>(t).getDimSize(index);
-  if (auto attr = val.dyn_cast<Attribute>())
+  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
     return cast<DenseIntElementsAttr>(attr)
         .getValues<APInt>()[index]
         .getSExtValue();
@@ -122,9 +122,9 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
 
 int64_t ShapeAdaptor::getRank() const {
   assert(hasRank());
-  if (auto t = val.dyn_cast<Type>())
+  if (auto t = llvm::dyn_cast_if_present<Type>(val))
     return cast<ShapedType>(t).getRank();
-  if (auto attr = val.dyn_cast<Attribute>())
+  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
     return cast<DenseIntElementsAttr>(attr).size();
   return val.get<ShapedTypeComponents *>()->getDims().size();
 }
@@ -133,9 +133,9 @@ bool ShapeAdaptor::hasStaticShape() const {
   if (!hasRank())
     return false;
 
-  if (auto t = val.dyn_cast<Type>())
+  if (auto t = llvm::dyn_cast_if_present<Type>(val))
     return cast<ShapedType>(t).hasStaticShape();
-  if (auto attr = val.dyn_cast<Attribute>()) {
+  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
     auto dattr = cast<DenseIntElementsAttr>(attr);
     for (auto index : dattr.getValues<APInt>())
       if (ShapedType::isDynamic(index.getSExtValue()))
@@ -149,10 +149,10 @@ bool ShapeAdaptor::hasStaticShape() const {
 int64_t ShapeAdaptor::getNumElements() const {
   assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
 
-  if (auto t = val.dyn_cast<Type>())
+  if (auto t = llvm::dyn_cast_if_present<Type>(val))
     return cast<ShapedType>(t).getNumElements();
 
-  if (auto attr = val.dyn_cast<Attribute>()) {
+  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
     auto dattr = cast<DenseIntElementsAttr>(attr);
     int64_t num = 1;
     for (auto index : dattr.getValues<APInt>()) {
index bc7d6b4..3fab2a3 100644 (file)
@@ -26,14 +26,14 @@ namespace mlir {
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   // Case 1: Check for Constant integer.
-  if (auto val = ofr.dyn_cast<Value>()) {
+  if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
     APSInt intVal;
     if (matchPattern(val, m_ConstantInt(&intVal)))
       return intVal.getSExtValue();
     return std::nullopt;
   }
   // Case 2: Check for IntegerAttr.
-  Attribute attr = ofr.dyn_cast<Attribute>();
+  Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
   if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
     return intAttr.getValue().getSExtValue();
   return std::nullopt;
@@ -99,7 +99,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
 }
 
 AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
-  if (Value value = ofr.dyn_cast<Value>())
+  if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
     return getExpr(value, /*dim=*/std::nullopt);
   auto constInt = getConstantIntValue(ofr);
   assert(constInt.has_value() && "expected Integer constant");
index 35f1edd..7276071 100644 (file)
@@ -26,7 +26,8 @@ struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
   const Pass &getPass() const { return pass; }
   Operation *getOp() const {
     ArrayRef<IRUnit> irUnits = getContextIRUnits();
-    return irUnits.empty() ? nullptr : irUnits[0].dyn_cast<Operation *>();
+    return irUnits.empty() ? nullptr
+                           : llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
   }
 
 public:
index 03557d9..fe9cae3 100644 (file)
@@ -384,7 +384,7 @@ void Operator::populateTypeInferenceInfo(
   if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
     // Check for a non-variable length operand to use as the type anchor.
     auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
-      NamedTypeConstraint *operand = arg.dyn_cast<NamedTypeConstraint *>();
+      NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
       return operand && !operand->isVariableLength();
     });
     if (operandI == arguments.end())
@@ -824,7 +824,7 @@ StringRef Operator::getAssemblyFormat() const {
 void Operator::print(llvm::raw_ostream &os) const {
   os << "op '" << getOperationName() << "'\n";
   for (Argument arg : arguments) {
-    if (auto *attr = arg.dyn_cast<NamedAttribute *>())
+    if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(arg))
       os << "[attribute] " << attr->name << '\n';
     else
       os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
index 8783be7..96c727e 100644 (file)
@@ -131,7 +131,7 @@ convertBranchWeights(std::optional<ElementsAttr> weights,
     return nullptr;
   SmallVector<uint32_t> weightValues;
   weightValues.reserve(weights->size());
-  for (APInt weight : weights->cast<DenseIntElementsAttr>())
+  for (APInt weight : llvm::cast<DenseIntElementsAttr>(*weights))
     weightValues.push_back(weight.getLimitedValue());
   return llvm::MDBuilder(moduleTranslation.getLLVMContext())
       .createBranchWeights(weightValues);
@@ -330,7 +330,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     auto *ty = llvm::cast<llvm::IntegerType>(
         moduleTranslation.convertType(switchOp.getValue().getType()));
     for (auto i :
-         llvm::zip(switchOp.getCaseValues()->cast<DenseIntElementsAttr>(),
+         llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()),
                    switchOp.getCaseDestinations()))
       switchInst->addCase(
           llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
index 2e67db5..05d6b78 100644 (file)
@@ -730,8 +730,8 @@ Attribute ModuleImport::getConstantAsAttr(llvm::Constant *constant) {
 
   // Returns the static shape of the provided type if possible.
   auto getConstantShape = [&](llvm::Type *type) {
-    return getBuiltinTypeForAttr(convertType(type))
-        .dyn_cast_or_null<ShapedType>();
+    return llvm::dyn_cast_if_present<ShapedType>(getBuiltinTypeForAttr(convertType(type))
+        );
   };
 
   // Convert one-dimensional constant arrays or vectors that store 1/2/4/8-byte
@@ -798,8 +798,8 @@ Attribute ModuleImport::getConstantAsAttr(llvm::Constant *constant) {
 
   // Convert zero aggregates.
   if (auto *constZero = dyn_cast<llvm::ConstantAggregateZero>(constant)) {
-    auto shape = getBuiltinTypeForAttr(convertType(constZero->getType()))
-                     .dyn_cast_or_null<ShapedType>();
+    auto shape = llvm::dyn_cast_if_present<ShapedType>(getBuiltinTypeForAttr(convertType(constZero->getType()))
+                     );
     if (!shape)
       return {};
     // Convert zero aggregates with a static shape to splat elements attributes.
index d1d23b6..772721e 100644 (file)
@@ -69,7 +69,7 @@ translateDataLayout(DataLayoutSpecInterface attribute,
   std::string llvmDataLayout;
   llvm::raw_string_ostream layoutStream(llvmDataLayout);
   for (DataLayoutEntryInterface entry : attribute.getEntries()) {
-    auto key = entry.getKey().dyn_cast<StringAttr>();
+    auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey());
     if (!key)
       continue;
     if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
@@ -108,7 +108,7 @@ translateDataLayout(DataLayoutSpecInterface attribute,
   // specified in entries. Where possible, data layout queries are used instead
   // of directly inspecting the entries.
   for (DataLayoutEntryInterface entry : attribute.getEntries()) {
-    auto type = entry.getKey().dyn_cast<Type>();
+    auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
     if (!type)
       continue;
     // Data layout for the index type is irrelevant at this point.
index a0b8faa..b84d1d9 100644 (file)
@@ -285,7 +285,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
         static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
     auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
         linkageName, linkageTypeAttr);
-    decorations[words[0]].set(symbol, linkageAttr.dyn_cast<Attribute>());
+    decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
     break;
   }
   case spirv::Decoration::Aliased:
index 582f02f..44538c3 100644 (file)
@@ -639,7 +639,7 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
   if (values) {
     for (auto &intVal : values.getValue()) {
       operands.push_back(static_cast<uint32_t>(
-          intVal.cast<IntegerAttr>().getValue().getZExtValue()));
+          llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
     }
   }
   encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
index 9089ebc..f32f6e8 100644 (file)
@@ -222,7 +222,7 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
   case spirv::Decoration::LinkageAttributes: {
     // Get the value of the Linkage Attributes
     // e.g., LinkageAttributes=["linkageName", linkageType].
-    auto linkageAttr = attr.getValue().dyn_cast<spirv::LinkageAttributesAttr>();
+    auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
     auto linkageName = linkageAttr.getLinkageName();
     auto linkageType = linkageAttr.getLinkageType().getValue();
     // Encode the Linkage Name (string literal to uint32_t).
index 4c6aaf3..8efc614 100644 (file)
@@ -136,7 +136,7 @@ struct PDLIndexSymbol {
 
   /// Return the location of the definition of this symbol.
   SMRange getDefLoc() const {
-    if (const ast::Decl *decl = definition.dyn_cast<const ast::Decl *>()) {
+    if (const ast::Decl *decl = llvm::dyn_cast_if_present<const ast::Decl *>(definition)) {
       const ast::Name *declName = decl->getName();
       return declName ? declName->getLoc() : decl->getLoc();
     }
@@ -465,7 +465,7 @@ PDLDocument::findHover(const lsp::URIForFile &uri,
     return std::nullopt;
 
   // Add hover for operation names.
-  if (const auto *op = symbol->definition.dyn_cast<const ods::Operation *>())
+  if (const auto *op = llvm::dyn_cast_if_present<const ods::Operation *>(symbol->definition))
     return buildHoverForOpName(op, hoverRange);
   const auto *decl = symbol->definition.get<const ast::Decl *>();
   return findHover(decl, hoverRange);
index 57ccb3b..a7dcd2b 100644 (file)
@@ -373,7 +373,7 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
 
 #ifndef NDEBUG
 static std::string getNodeName(CallOpInterface op) {
-  if (auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
+  if (auto sym = llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
     return debugString(op);
   return "_unnamed_callee_";
 }
index 827c0ad..e9e59cf 100644 (file)
@@ -272,7 +272,7 @@ OperationFolder::processFoldResults(Operation *op,
     assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
 
     // Check if the result was an SSA value.
-    if (auto repl = foldResults[i].dyn_cast<Value>()) {
+    if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
       if (repl.getType() != op->getResult(i).getType()) {
         results.clear();
         return failure();
index 46126b0..f076869 100644 (file)
@@ -266,7 +266,7 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
 
   // Remap the locations of the inlined operations if a valid source location
   // was provided.
-  if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
+  if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc))
     remapInlinedLocations(newBlocks, *inlineLoc);
 
   // If the blocks were moved in-place, make sure to remap any necessary
index 5a66f19..ed361b5 100644 (file)
@@ -115,11 +115,11 @@ LogicalResult FooAnalysis::initialize(Operation *top) {
 }
 
 LogicalResult FooAnalysis::visit(ProgramPoint point) {
-  if (auto *op = point.dyn_cast<Operation *>()) {
+  if (auto *op = llvm::dyn_cast_if_present<Operation *>(point)) {
     visitOperation(op);
     return success();
   }
-  if (auto *block = point.dyn_cast<Block *>()) {
+  if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
     visitBlock(block);
     return success();
   }
index db3b9a1..ad017ce 100644 (file)
@@ -161,7 +161,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
       }
 
       // Replace the op with the reified bound.
-      if (auto val = reified->dyn_cast<Value>()) {
+      if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
         rewriter.replaceOp(op, val);
         return WalkResult::skip();
       }
index 31b7504..98e88fe 100644 (file)
@@ -1134,7 +1134,7 @@ void OpEmitter::genPropertiesSupport() {
 )decl";
   for (const auto &attrOrProp : attrOrProperties) {
     if (const auto *namedProperty =
-            attrOrProp.dyn_cast<const NamedProperty *>()) {
+            llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
       StringRef name = namedProperty->name;
       auto &prop = namedProperty->prop;
       FmtContext fctx;
@@ -1145,7 +1145,7 @@ void OpEmitter::genPropertiesSupport() {
                                           .addSubst("_diag", propertyDiag)),
                                name);
     } else {
-      const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+      const auto *namedAttr = llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
       StringRef name = namedAttr->attrName;
       setPropMethod << formatv(R"decl(
   {{
@@ -1187,7 +1187,7 @@ void OpEmitter::genPropertiesSupport() {
 )decl";
   for (const auto &attrOrProp : attrOrProperties) {
     if (const auto *namedProperty =
-            attrOrProp.dyn_cast<const NamedProperty *>()) {
+            llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
       StringRef name = namedProperty->name;
       auto &prop = namedProperty->prop;
       FmtContext fctx;
@@ -1198,7 +1198,7 @@ void OpEmitter::genPropertiesSupport() {
                      .addSubst("_storage", propertyStorage)));
       continue;
     }
-    const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+    const auto *namedAttr = llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
     StringRef name = namedAttr->attrName;
     getPropMethod << formatv(R"decl(
     {{
@@ -1225,7 +1225,7 @@ void OpEmitter::genPropertiesSupport() {
 )decl";
   for (const auto &attrOrProp : attrOrProperties) {
     if (const auto *namedProperty =
-            attrOrProp.dyn_cast<const NamedProperty *>()) {
+            llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
       StringRef name = namedProperty->name;
       auto &prop = namedProperty->prop;
       FmtContext fctx;
@@ -1238,13 +1238,13 @@ void OpEmitter::genPropertiesSupport() {
   llvm::interleaveComma(
       attrOrProperties, hashMethod, [&](const ConstArgument &attrOrProp) {
         if (const auto *namedProperty =
-                attrOrProp.dyn_cast<const NamedProperty *>()) {
+                llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
           hashMethod << "\n    hash_" << namedProperty->name << "(prop."
                      << namedProperty->name << ")";
           return;
         }
         const auto *namedAttr =
-            attrOrProp.dyn_cast<const AttributeMetadata *>();
+            llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
         StringRef name = namedAttr->attrName;
         hashMethod << "\n    llvm::hash_value(prop." << name
                    << ".getAsOpaquePointer())";
@@ -1266,7 +1266,7 @@ void OpEmitter::genPropertiesSupport() {
 )decl";
   for (const auto &attrOrProp : attrOrProperties) {
     if (const auto *namedAttr =
-            attrOrProp.dyn_cast<const AttributeMetadata *>()) {
+            llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp)) {
       StringRef name = namedAttr->attrName;
       getInherentAttrMethod << formatv(getInherentAttrMethodFmt, name);
       setInherentAttrMethod << formatv(setInherentAttrMethodFmt, name);
@@ -1281,7 +1281,7 @@ void OpEmitter::genPropertiesSupport() {
   // syntax. This method verifies the constraint on the properties attributes
   // before they are set, since dyn_cast<> will silently omit failures.
   for (const auto &attrOrProp : attrOrProperties) {
-    const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+    const auto *namedAttr = llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
     if (!namedAttr || !namedAttr->constraint)
       continue;
     Attribute attr = *namedAttr->constraint;
@@ -2472,7 +2472,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
     // Calculate the start index from which we can attach default values in the
     // builder declaration.
     for (int i = op.getNumArgs() - 1; i >= 0; --i) {
-      auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
+      auto *namedAttr = llvm::dyn_cast_if_present<tblgen::NamedAttribute *>(op.getArg(i));
       if (!namedAttr || !namedAttr->attr.hasDefaultValue())
         break;
 
@@ -2502,7 +2502,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
 
   for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
     Argument arg = op.getArg(i);
-    if (const auto *operand = arg.dyn_cast<NamedTypeConstraint *>()) {
+    if (const auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) {
       StringRef type;
       if (operand->isVariadicOfVariadic())
         type = "::llvm::ArrayRef<::mlir::ValueRange>";
@@ -2515,7 +2515,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                              operand->isOptional());
       continue;
     }
-    if (const auto *operand = arg.dyn_cast<NamedProperty *>()) {
+    if (const auto *operand = llvm::dyn_cast_if_present<NamedProperty *>(arg)) {
       // TODO
       continue;
     }
@@ -3442,7 +3442,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
     llvm::raw_string_ostream comparatorOs(comparator);
     for (const auto &attrOrProp : attrOrProperties) {
       if (const auto *namedProperty =
-              attrOrProp.dyn_cast<const NamedProperty *>()) {
+              llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
         StringRef name = namedProperty->name;
         if (name.empty())
           report_fatal_error("missing name for property");
@@ -3476,7 +3476,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
                                  .addSubst("_storage", propertyStorage)));
         continue;
       }
-      const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+      const auto *namedAttr = llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
       const Attribute *attr = nullptr;
       if (namedAttr->constraint)
         attr = &*namedAttr->constraint;
index 9f8fa43..3e6db51 100644 (file)
@@ -265,11 +265,11 @@ struct OperationFormat {
 
     /// Get the variable this type is resolved to, or nullptr.
     const NamedTypeConstraint *getVariable() const {
-      return resolver.dyn_cast<const NamedTypeConstraint *>();
+      return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver);
     }
     /// Get the attribute this type is resolved to, or nullptr.
     const NamedAttribute *getAttribute() const {
-      return resolver.dyn_cast<const NamedAttribute *>();
+      return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver);
     }
     /// Get the transformer for the type of the variable, or std::nullopt.
     std::optional<StringRef> getVarTransformer() const {
index 330268d..6e4a9e3 100644 (file)
@@ -674,7 +674,7 @@ populateBuilderLinesAttr(const Operator &op,
   builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
     Argument arg = op.getArg(i);
-    auto *attribute = arg.dyn_cast<NamedAttribute *>();
+    auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg);
     if (!attribute)
       continue;
 
@@ -914,9 +914,9 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
     // - default-valued named attributes
     // - optional operands
     Argument a = op.getArg(builderArgIndex - numResultArgs);
-    if (auto *nattr = a.dyn_cast<NamedAttribute *>())
+    if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a))
       return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
-    if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
+    if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a))
       return ntype->isOptional();
     return false;
   };
index 8a04cc9..9463c4f 100644 (file)
@@ -595,7 +595,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
         ++opArgIdx;
         continue;
       }
-      if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
+      if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
         if (operand->isVariableLength()) {
           auto error = formatv("use nested DAG construct to match op {0}'s "
                                "variadic operand #{1} unsupported now",
@@ -1524,7 +1524,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
   int valueIndex = 0; // An index for uniquing local variable names.
   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
     const auto *operand =
-        resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
+        llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(argIndex));
     // We do not need special handling for attributes.
     if (!operand)
       continue;
@@ -1579,7 +1579,7 @@ void PatternEmitter::supplyValuesForOpArgs(
 
     Argument opArg = resultOp.getArg(argIndex);
     // Handle the case of operand first.
-    if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
+    if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
       if (!operand->name.empty())
         os << "/*" << operand->name << "=*/";
       os << childNodeNames.lookup(argIndex);
index 4caaf1d..7bf7755 100644 (file)
@@ -926,7 +926,7 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
   // Process operands/attributes
   for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
     auto argument = op.getArg(i);
-    if (auto *valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
+    if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) {
       if (valueArg->isVariableLength()) {
         if (i != e - 1) {
           PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
index 6601f32..e9ba28e 100644 (file)
@@ -159,7 +159,7 @@ struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
     // Handle built-in types that are not handled by the default process.
     if (auto iType = dyn_cast<IntegerType>(type)) {
       for (DataLayoutEntryInterface entry : params)
-        if (entry.getKey().dyn_cast<Type>() == type)
+        if (llvm::dyn_cast_if_present<Type>(entry.getKey()) == type)
           return 8 *
                  cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
       return 8 * iType.getIntOrFloatBitWidth();