From adca3c2edcdd1375d8c421816ec53044537ccd64 Mon Sep 17 00:00:00 2001
From: River Riddle
Date: Sat, 11 May 2019 17:57:32 -0700
Subject: [PATCH] Replace Operation::cast with llvm::cast.

--

PiperOrigin-RevId: 247785983
---
 mlir/examples/Linalg/Linalg1/lib/Analysis.cpp      |  6 +--
 mlir/examples/Linalg/Linalg1/lib/Common.cpp        |  4 +-
 .../Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp    |  6 +--
 mlir/examples/Linalg/Linalg1/lib/Utils.cpp         |  2 +-
 mlir/examples/Linalg/Linalg2/lib/Transforms.cpp    | 13 +++---
 .../Linalg/Linalg3/include/linalg3/TensorOps-inl.h |  2 +-
 .../Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp    |  2 +-
 mlir/examples/Linalg/Linalg3/lib/Transforms.cpp    | 16 +++----
 mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp  |  2 +-
 mlir/examples/toy/Ch4/mlir/ToyCombine.cpp          |  8 ++--
 mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp       |  2 +-
 mlir/examples/toy/Ch5/mlir/LateLowering.cpp        | 10 ++---
 mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp  |  2 +-
 mlir/examples/toy/Ch5/mlir/ToyCombine.cpp          | 10 ++---
 mlir/include/mlir/IR/OpDefinition.h                | 12 ++---
 mlir/include/mlir/IR/Operation.h                   |  8 ----
 mlir/lib/AffineOps/AffineOps.cpp                   |  4 +-
 mlir/lib/Analysis/LoopAnalysis.cpp                 |  2 +-
 mlir/lib/Analysis/Utils.cpp                        |  2 +-
 mlir/lib/Analysis/VectorAnalysis.cpp               |  4 +-
 .../FxpMathOps/Transforms/LowerUniformRealMath.cpp |  6 +--
 .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 18 ++++----
 mlir/lib/Linalg/IR/LinalgOps.cpp                   |  5 +--
 mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp  | 12 ++---
 mlir/lib/Linalg/Transforms/Tiling.cpp              |  8 ++--
 mlir/lib/Linalg/Utils/Utils.cpp                    |  6 +--
 mlir/lib/Quantization/IR/QuantOps.cpp              |  8 ++--
 mlir/lib/Quantization/Transforms/ConvertConst.cpp  |  2 +-
 .../Quantization/Transforms/ConvertSimQuant.cpp    |  2 +-
 mlir/lib/StandardOps/Ops.cpp                       | 16 +++----
 mlir/lib/Transforms/LoopFusion.cpp                 | 52 +++++++++++-----------
 mlir/lib/Transforms/LoopUnroll.cpp                 |  2 +-
 mlir/lib/Transforms/LoopUnrollAndJam.cpp           |  2 +-
 mlir/lib/Transforms/LowerAffine.cpp                |  2 +-
 mlir/lib/Transforms/LowerVectorTransfers.cpp       |  4 +-
 mlir/lib/Transforms/MaterializeVectors.cpp         |  2 +-
 mlir/lib/Transforms/MemRefDataFlowOpt.cpp          |  2 +-
 mlir/lib/Transforms/PipelineDataTransfer.cpp       |  6 +--
 mlir/lib/Transforms/Utils/LoopUtils.cpp            |  4 +-
 .../Vectorization/VectorizerTestPass.cpp           |  2 +-
 mlir/lib/Transforms/Vectorize.cpp                  |  8 ++--
 41 files changed, 139 insertions(+), 147 deletions(-)

diff --git a/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp b/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp
index a7fba17..092b83a 100644
--- a/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp
+++ b/mlir/examples/Linalg/Linalg1/lib/Analysis.cpp
@@ -35,7 +35,7 @@ ViewOp linalg::getViewBaseViewOp(Value *view) {
     view = slice.getParentView();
     assert(viewType.isa<ViewType>() && "expected a ViewType");
   }
-  return view->getDefiningOp()->cast<ViewOp>();
+  return cast<ViewOp>(view->getDefiningOp());
 }
 
 Value *linalg::getViewSupportingMemRef(Value *view) {
@@ -51,12 +51,12 @@ std::pair<mlir::Value *, unsigned> linalg::getViewRootIndexing(Value *view,
   if (auto viewOp = dyn_cast<ViewOp>(view->getDefiningOp()))
     return std::make_pair(viewOp.getIndexing(dim), dim);
 
-  auto sliceOp = view->getDefiningOp()->cast<SliceOp>();
+  auto sliceOp = cast<SliceOp>(view->getDefiningOp());
   auto *parentView = sliceOp.getParentView();
   unsigned sliceDim = sliceOp.getSlicingDim();
   auto *indexing = sliceOp.getIndexing();
   if (indexing->getDefiningOp()) {
-    if (auto rangeOp = indexing->getDefiningOp()->cast<RangeOp>()) {
+    if (auto rangeOp = cast<RangeOp>(indexing->getDefiningOp())) {
       // If I sliced with a range and I sliced at this dim, then I'm it.
       if (dim == sliceDim) {
         return std::make_pair(rangeOp.getResult(), dim);
diff --git a/mlir/examples/Linalg/Linalg1/lib/Common.cpp b/mlir/examples/Linalg/Linalg1/lib/Common.cpp
index 278f9c5..1e211bf 100644
--- a/mlir/examples/Linalg/Linalg1/lib/Common.cpp
+++ b/mlir/examples/Linalg/Linalg1/lib/Common.cpp
@@ -47,8 +47,8 @@ linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder(
     auto lb = rangeOp.getMin();
     auto ub = rangeOp.getMax();
     // This must be a constexpr index until we relax the affine.for constraint
-    auto step =
-        rangeOp.getStep()->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+    auto step = llvm::cast<ConstantIndexOp>(rangeOp.getStep()->getDefiningOp())
+                    .getValue();
     loops.emplace_back(ivs[i], ValueHandle(lb), ValueHandle(ub), step);
   }
 }
diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
index 6097240..48884b1 100644
--- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
+++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
@@ -155,7 +155,7 @@ public:
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto rangeOp = op->cast<linalg::RangeOp>();
+    auto rangeOp = cast<linalg::RangeOp>(op);
     auto rangeDescriptorType =
         linalg::convertLinalgType(rangeOp.getResult()->getType());
@@ -187,7 +187,7 @@ public:
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto viewOp = op->cast<linalg::ViewOp>();
+    auto viewOp = cast<linalg::ViewOp>(op);
     auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType());
     auto memrefType =
         viewOp.getSupportingMemRef()->getType().cast<MemRefType>();
@@ -319,7 +319,7 @@ public:
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto sliceOp = op->cast<linalg::SliceOp>();
+    auto sliceOp = cast<linalg::SliceOp>(op);
     auto newViewDescriptorType =
         linalg::convertLinalgType(sliceOp.getViewType());
     auto elementType = rewriter.getType(
diff --git a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp
index 5bcebc7..05070a9 100644
--- a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp
+++ b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp
@@ -35,7 +35,7 @@ unsigned linalg::getViewRank(Value *view) {
   assert(view->getType().isa<ViewType>() && "expected a ViewType");
   if (auto viewOp = dyn_cast<ViewOp>(view->getDefiningOp()))
     return viewOp.getRank();
-  return view->getDefiningOp()->cast<SliceOp>().getRank();
+  return cast<SliceOp>(view->getDefiningOp()).getRank();
 }
 
 ViewOp linalg::emitAndReturnViewOpFromMemRef(Value *memRef) {
diff --git a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp
index 83fd9ad..9df0af8 100644
--- a/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp
+++ b/mlir/examples/Linalg/Linalg2/lib/Transforms.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/StandardTypes.h"
 
 using llvm::ArrayRef;
+using llvm::cast;
 using llvm::SmallVector;
 using mlir::FuncBuilder;
 using mlir::MemRefType;
@@ -49,7 +50,7 @@ static SmallVector<mlir::Value *, 8> getViewChain(mlir::Value *v) {
   SmallVector<mlir::Value *, 8> tmp;
   do {
-    auto sliceOp = v->getDefiningOp()->cast<SliceOp>(); // must be a slice op
+    auto sliceOp = cast<SliceOp>(v->getDefiningOp()); // must be a slice op
     tmp.push_back(v);
     v = sliceOp.getParentView();
   } while (!v->getType().isa<ViewType>());
@@ -62,15 +63,15 @@ static mlir::Value *createFullyComposedIndexing(unsigned dim,
                                                 ArrayRef<mlir::Value *> chain) {
   using namespace mlir::edsc::op;
   assert(chain.front()->getType().isa<ViewType>() && "must be a ViewType");
-  auto viewOp = chain.front()->getDefiningOp()->cast<ViewOp>();
+  auto viewOp = cast<ViewOp>(chain.front()->getDefiningOp());
   auto *indexing = viewOp.getIndexing(dim);
   if (!indexing->getType().isa<RangeType>())
     return indexing;
-  auto rangeOp = indexing->getDefiningOp()->cast<RangeOp>();
+  auto rangeOp = cast<RangeOp>(indexing->getDefiningOp());
   Value *min = rangeOp.getMin(), *max = rangeOp.getMax(),
         *step = rangeOp.getStep();
   for (auto *v : chain.drop_front(1)) {
-    auto slice = v->getDefiningOp()->cast<SliceOp>();
+    auto slice = cast<SliceOp>(v->getDefiningOp());
     if (slice.getRank() != slice.getParentRank()) {
       // Rank-reducing slice.
       if (slice.getSlicingDim() == dim) {
@@ -82,7 +83,7 @@ static mlir::Value *createFullyComposedIndexing(unsigned dim,
       dim = (slice.getSlicingDim() < dim) ? dim - 1 : dim;
     } else { // not a rank-reducing slice.
       if (slice.getSlicingDim() == dim) {
-        auto range = slice.getIndexing()->getDefiningOp()->cast<RangeOp>();
+        auto range = cast<RangeOp>(slice.getIndexing()->getDefiningOp());
         auto oldMin = min;
         min = ValueHandle(min) + ValueHandle(range.getMin());
         // ideally: max = min(oldMin + ValueHandle(range.getMax()), oldMax);
@@ -110,5 +111,5 @@ ViewOp linalg::emitAndReturnFullyComposedView(Value *v) {
   for (unsigned idx = 0; idx < rank; ++idx) {
     ranges.push_back(createFullyComposedIndexing(idx, chain));
   }
-  return view(memRef, ranges).getOperation()->cast<ViewOp>();
+  return cast<ViewOp>(view(memRef, ranges).getOperation());
 }
diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h
index 3090f29..2c47541 100644
--- a/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h
+++ b/mlir/examples/Linalg/Linalg3/include/linalg3/TensorOps-inl.h
@@ -94,7 +94,7 @@ extractRangesFromViewOrSliceOp(mlir::Value *view) {
   if (auto viewOp = llvm::dyn_cast<linalg::ViewOp>(view->getDefiningOp()))
     return viewOp.getRanges();
 
-  auto sliceOp = view->getDefiningOp()->cast<linalg::SliceOp>();
+  auto sliceOp = llvm::cast<linalg::SliceOp>(view->getDefiningOp());
   unsigned slicingDim = sliceOp.getSlicingDim();
   auto *indexing = *(sliceOp.getIndexings().begin());
   bool isRankReducing = indexing->getType().isa<mlir::IndexType>();
diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
index f1bb90d..22feb66 100644
--- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
+++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
@@ -71,7 +71,7 @@ public:
   // a getelementptr.
   Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
                        ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
-    auto loadOp = op->cast<Op>();
+    auto loadOp = cast<Op>(op);
     auto elementType =
         loadOp.getViewType().template cast<linalg::ViewType>().getElementType();
     auto *llvmPtrType = linalg::convertLinalgType(elementType)
diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
index bce7f58..6309300 100644
--- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
+++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp
@@ -95,7 +95,7 @@ extractFromRanges(ArrayRef<Value *> ranges,
   SmallVector<Value *, 4> res;
   res.reserve(ranges.size());
   for (auto *v : ranges) {
-    auto r = v->getDefiningOp()->cast<RangeOp>();
+    auto r = cast<RangeOp>(v->getDefiningOp());
     res.push_back(extract(r));
   }
   return res;
@@ -149,9 +149,9 @@ linalg::makeGenericLoopRanges(AffineMap operandRangesToLoopMaps,
   for (auto z : llvm::zip(res.steps, tileSizes)) {
     auto *step = std::get<0>(z);
     auto tileSize = std::get<1>(z);
-    auto stepValue = step->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+    auto stepValue = cast<ConstantIndexOp>(step->getDefiningOp()).getValue();
     auto tileSizeValue =
-        tileSize->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+        cast<ConstantIndexOp>(tileSize->getDefiningOp()).getValue();
     assert(stepValue > 0);
     tiledSteps.push_back(constant_index(stepValue * tileSizeValue));
   }
@@ -236,7 +236,7 @@ emitAndReturnLoadStoreOperands(LoadOrStoreOp loadOrStoreOp, ViewOp viewOp) {
       operands.push_back(indexing);
       continue;
     }
-    RangeOp range = indexing->getDefiningOp()->cast<RangeOp>();
+    RangeOp range = cast<RangeOp>(indexing->getDefiningOp());
     ValueHandle min(range.getMin());
     Value *storeIndex = *(loadOrStoreOp.getIndices().begin() + storeDim++);
     using edsc::op::operator+;
@@ -275,10 +275,10 @@ template <>
 PatternMatchResult
 Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op,
                                           PatternRewriter &rewriter) const {
-  auto load = op->cast<linalg::LoadOp>();
+  auto load = cast<linalg::LoadOp>(op);
   SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
   ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
-                      : load.getView()->getDefiningOp()->cast<ViewOp>();
+                      : cast<ViewOp>(load.getView()->getDefiningOp());
   ScopedContext scope(FuncBuilder(load), load.getLoc());
   auto *memRef = view.getSupportingMemRef();
   auto operands = emitAndReturnLoadStoreOperands(load, view);
@@ -290,10 +290,10 @@ template <>
 PatternMatchResult
 Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op,
                                            PatternRewriter &rewriter) const {
-  auto store = op->cast<linalg::StoreOp>();
+  auto store = cast<linalg::StoreOp>(op);
   SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
   ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
-                      : store.getView()->getDefiningOp()->cast<ViewOp>();
+                      : cast<ViewOp>(store.getView()->getDefiningOp());
   ScopedContext scope(FuncBuilder(store), store.getLoc());
   auto *valueToStore = store.getValueToStore();
   auto *memRef = view.getSupportingMemRef();
diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
index c9f98e7..5f024ea 100644
--- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
@@ -350,7 +350,7 @@ public:
     // Finally, update the return type of the function based on the argument to
    // the return operation.
     for (auto &block : f->getBlocks()) {
-      auto ret = block.getTerminator()->cast<ReturnOp>();
+      auto ret = llvm::cast<ReturnOp>(block.getTerminator());
       if (!ret)
         continue;
       if (ret.getNumOperands() &&
diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
index 942ce86..4175fc2 100644
--- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
@@ -49,7 +49,7 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern {
                   mlir::PatternRewriter &rewriter) const override {
     // We can directly cast the current operation as this will only get invoked
     // on TransposeOp.
-    TransposeOp transpose = op->cast<TransposeOp>();
+    TransposeOp transpose = llvm::cast<TransposeOp>(op);
     // Look through the input of the current transpose.
     mlir::Value *transposeInput = transpose.getOperand();
     TransposeOp transposeInputOp =
@@ -73,7 +73,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = op->cast<ReshapeOp>();
+    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     // Look through the input of the current reshape.
     ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
         reshape.getOperand()->getDefiningOp());
@@ -120,7 +120,7 @@ struct SimplifyReshapeReshape : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = op->cast<ReshapeOp>();
+    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     // Look through the input of the current reshape.
     mlir::Value *reshapeInput = reshape.getOperand();
     // If the input is defined by another reshape, bingo!
@@ -142,7 +142,7 @@ struct SimplifyNullReshape : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = op->cast<ReshapeOp>();
+    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     if (reshape.getOperand()->getType() != reshape.getResult()->getType())
       return matchFailure();
     rewriter.replaceOp(reshape, {reshape.getOperand()});
diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
index db6ba73..3e640de 100644
--- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
+++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp
@@ -92,7 +92,7 @@ public:
     using intrinsics::constant_index;
     using linalg::intrinsics::range;
     using linalg::intrinsics::view;
-    toy::MulOp mul = op->cast<toy::MulOp>();
+    toy::MulOp mul = cast<toy::MulOp>(op);
     auto loc = mul.getLoc();
     Value *result = memRefTypeCast(
         rewriter, rewriter.create<toy::AllocOp>(loc, mul.getResult()->getType())
diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
index 534b5cb..0a2ff1d 100644
--- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp
@@ -93,7 +93,7 @@ public:
   /// number must match the number of result of `op`.
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto add = op->cast<toy::AddOp>();
+    auto add = cast<toy::AddOp>(op);
     auto loc = add.getLoc();
     // Create a `toy.alloc` operation to allocate the output buffer for this op.
     Value *result = memRefTypeCast(
@@ -135,7 +135,7 @@ public:
     // Get or create the declaration of the printf function in the module.
     Function *printfFunc = getPrintf(*op->getFunction()->getModule());
 
-    auto print = op->cast<toy::PrintOp>();
+    auto print = cast<toy::PrintOp>(op);
     auto loc = print.getLoc();
     // We will operate on a MemRef abstraction, we use a type.cast to get one
     // if our operand is still a Toy array.
@@ -234,7 +234,7 @@ public:
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    toy::ConstantOp cstOp = op->cast<toy::ConstantOp>();
+    toy::ConstantOp cstOp = cast<toy::ConstantOp>(op);
     auto loc = cstOp.getLoc();
     auto retTy = cstOp.getResult()->getType().cast<toy::ToyArrayType>();
     auto shape = retTy.getShape();
@@ -277,7 +277,7 @@ public:
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto transpose = op->cast<toy::TransposeOp>();
+    auto transpose = cast<toy::TransposeOp>(op);
     auto loc = transpose.getLoc();
     Value *result = memRefTypeCast(
         rewriter,
@@ -309,7 +309,7 @@ public:
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto retOp = op->cast<toy::ReturnOp>();
+    auto retOp = cast<toy::ReturnOp>(op);
     using namespace edsc;
     auto loc = retOp.getLoc();
     // Argument is optional, handle both cases.
diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
index 4e17b23..ab99019 100644
--- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
@@ -357,7 +357,7 @@ public:
     // Finally, update the return type of the function based on the argument to
     // the return operation.
     for (auto &block : f->getBlocks()) {
-      auto ret = block.getTerminator()->cast<ReturnOp>();
+      auto ret = llvm::cast<ReturnOp>(block.getTerminator());
       if (!ret)
         continue;
       if (ret.getNumOperands() &&
diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
index 8d6aed6..260f6a6 100644
--- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp
@@ -49,7 +49,7 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern {
                   mlir::PatternRewriter &rewriter) const override {
     // We can directly cast the current operation as this will only get invoked
     // on TransposeOp.
-    TransposeOp transpose = op->cast<TransposeOp>();
+    TransposeOp transpose = llvm::cast<TransposeOp>(op);
     // look through the input to the current transpose
     mlir::Value *transposeInput = transpose.getOperand();
     mlir::Operation *transposeInputInst = transposeInput->getDefiningOp();
@@ -74,7 +74,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = op->cast<ReshapeOp>();
+    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     // look through the input to the current reshape
     mlir::Value *reshapeInput = reshape.getOperand();
     mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
@@ -125,7 +125,7 @@ struct SimplifyReshapeReshape : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = op->cast<ReshapeOp>();
+    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     // look through the input to the current reshape
     mlir::Value *reshapeInput = reshape.getOperand();
     mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
@@ -150,7 +150,7 @@ struct SimplifyNullReshape : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    ReshapeOp reshape = op->cast<ReshapeOp>();
+    ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
     if (reshape.getOperand()->getType() != reshape.getResult()->getType())
       return matchFailure();
     rewriter.replaceOp(reshape, {reshape.getOperand()});
@@ -185,7 +185,7 @@ struct SimplifyIdentityTypeCast : public mlir::RewritePattern {
   mlir::PatternMatchResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-    TypeCastOp typeCast = op->cast<TypeCastOp>();
+    TypeCastOp typeCast = llvm::cast<TypeCastOp>(op);
     auto resTy = typeCast.getResult()->getType();
     auto *candidateOp = op;
     while (candidateOp && candidateOp->isa<TypeCastOp>()) {
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 2eff412..250fb94 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -183,8 +183,8 @@ public:
   static LogicalResult constantFoldHook(Operation *op,
                                         ArrayRef<Attribute> operands,
                                         SmallVectorImpl<Attribute> &results) {
-    return op->cast<ConcreteType>().constantFold(operands, results,
-                                                 op->getContext());
+    return cast<ConcreteType>(op).constantFold(operands, results,
+                                               op->getContext());
   }
 
   /// Op implementations can implement this hook. It should attempt to constant
@@ -205,7 +205,7 @@ public:
   /// This is an implementation detail of the folder hook for AbstractOperation.
   static LogicalResult foldHook(Operation *op,
                                 SmallVectorImpl<Value *> &results) {
-    return op->cast<ConcreteType>().fold(results);
+    return cast<ConcreteType>(op).fold(results);
   }
 
   /// This hook implements a generalized folder for this operation. Operations
@@ -253,7 +253,7 @@ public:
                                         ArrayRef<Attribute> operands,
                                         SmallVectorImpl<Attribute> &results) {
     auto result =
-        op->cast<ConcreteType>().constantFold(operands, op->getContext());
+        cast<ConcreteType>(op).constantFold(operands, op->getContext());
     if (!result)
       return failure();
 
@@ -277,7 +277,7 @@ public:
   /// This is an implementation detail of the folder hook for AbstractOperation.
   static LogicalResult foldHook(Operation *op,
                                 SmallVectorImpl<Value *> &results) {
-    auto *result = op->cast<ConcreteType>().fold();
+    auto *result = cast<ConcreteType>(op).fold();
     if (!result)
       return failure();
     if (result != op->getResult(0))
@@ -808,7 +808,7 @@ public:
   static LogicalResult verifyInvariants(Operation *op) {
     return failure(
         failed(BaseVerifier<Traits<ConcreteType>...>::verifyTrait(op)) ||
-        failed(op->cast<ConcreteType>().verify()));
+        failed(cast<ConcreteType>(op).verify()));
   }
 
   // Returns the properties of an operation by combining the properties of the
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 31ec8ea..088a4e4 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -389,14 +389,6 @@ public:
   // Conversions to declared operations like DimOp
   //===--------------------------------------------------------------------===//
 
-  /// The cast methods perform a cast from an Operation to a typed Op like
-  /// DimOp. This aborts if the parameter to the template isn't an instance of
-  /// the template type argument.
-  template <typename OpClass> OpClass cast() {
-    assert(isa<OpClass>() && "cast<OpClass>() argument of incompatible type!");
-    return OpClass(this);
-  }
-
   /// The is methods return true if the operation is a typed op (like DimOp) of
   /// the given class.
   template <typename OpClass> bool isa() { return OpClass::classof(this); }
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp
index 2dfed93..f551afb 100644
--- a/mlir/lib/AffineOps/AffineOps.cpp
+++ b/mlir/lib/AffineOps/AffineOps.cpp
@@ -661,7 +661,7 @@ struct SimplifyAffineApply : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto apply = op->cast<AffineApplyOp>();
+    auto apply = cast<AffineApplyOp>(op);
     auto map = apply.getAffineMap();
 
     AffineMap oldMap = map;
@@ -1010,7 +1010,7 @@ struct AffineForLoopBoundFolder : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto forOp = op->cast<AffineForOp>();
+    auto forOp = cast<AffineForOp>(op);
     auto foldLowerOrUpperBound = [&forOp](bool lower) {
       // Check to see if each of the operands is the result of a constant.  If
       // so, get the value.  If not, ignore it.
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index 60f2b14..3d984c5 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -192,7 +192,7 @@ bool mlir::isAccessInvariant(Value *iv, Value *index) {
     return false;
   }
 
-  auto composeOp = affineApplyOps[0]->cast<AffineApplyOp>();
+  auto composeOp = cast<AffineApplyOp>(affineApplyOps[0]);
   // We need yet another level of indirection because the `dim` index of the
   // access may not correspond to the `dim` index of composeOp.
   return !(AffineValueMap(composeOp).isFunctionOf(0, iv));
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 8d963e4..cc46d65 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -603,7 +603,7 @@ mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst,
   auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
   FuncBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
   auto sliceLoopNest =
-      b.clone(*srcLoopIVs[0].getOperation())->cast<AffineForOp>();
+      cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
 
   Operation *sliceInst =
       getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp
index 8fecf05..627ca7a 100644
--- a/mlir/lib/Analysis/VectorAnalysis.cpp
+++ b/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -123,7 +123,7 @@ static AffineMap makePermutationMap(
   for (auto kvp : enclosingLoopToVectorDim) {
     assert(kvp.second < perm.size());
     auto invariants = getInvariantAccesses(
-        kvp.first->cast<AffineForOp>().getInductionVar(), indices);
+        cast<AffineForOp>(kvp.first).getInductionVar(), indices);
     unsigned numIndices = indices.size();
     unsigned countInvariantIndices = 0;
     for (unsigned dim = 0; dim < numIndices; ++dim) {
@@ -181,7 +181,7 @@ AffineMap mlir::makePermutationMap(
     return ::makePermutationMap(load.getIndices(), enclosingLoopToVectorDim);
   }
 
-  auto store = op->cast<StoreOp>();
+  auto store = cast<StoreOp>(op);
   return ::makePermutationMap(store.getIndices(), enclosingLoopToVectorDim);
 }
diff --git a/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp
index afd8152..8bfcfa5 100644
--- a/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp
+++ b/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -124,7 +124,7 @@ struct UniformDequantizePattern : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const {
-    auto dcastOp = op->cast<DequantizeCastOp>();
+    auto dcastOp = cast<DequantizeCastOp>(op);
     Type inputType = dcastOp.arg()->getType();
     Type outputType = dcastOp.getResult()->getType();
 
@@ -328,7 +328,7 @@ struct UniformRealAddEwPattern : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const {
-    auto addOp = op->cast<RealAddEwOp>();
+    auto addOp = cast<RealAddEwOp>(op);
     const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
                                    addOp.clamp_min(), addOp.clamp_max());
     if (!info.isValid()) {
@@ -350,7 +350,7 @@ struct UniformRealMulEwPattern : public RewritePattern {
 
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const {
-    auto mulOp = op->cast<RealMulEwOp>();
+    auto mulOp = cast<RealMulEwOp>(op);
     const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
                                    mulOp.clamp_min(), mulOp.clamp_max());
     if (!info.isValid()) {
diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
index 0d9025b..e9aee95 100644
--- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
+++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
@@ -414,14 +414,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
 
   PatternMatchResult match(Operation *op) const override {
     if (!LLVMLegalizationPattern<AllocOp>::match(op))
       return matchFailure();
-    auto allocOp = op->cast<AllocOp>();
+    auto allocOp = cast<AllocOp>(op);
     MemRefType type = allocOp.getType();
     return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
   }
 
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto allocOp = op->cast<AllocOp>();
+    auto allocOp = cast<AllocOp>(op);
     MemRefType type = allocOp.getType();
 
     // Get actual sizes of the memref as values: static sizes are constant
@@ -557,7 +557,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
 
   PatternMatchResult match(Operation *op) const override {
     if (!LLVMLegalizationPattern<MemRefCastOp>::match(op))
       return matchFailure();
-    auto memRefCastOp = op->cast<MemRefCastOp>();
+    auto memRefCastOp = cast<MemRefCastOp>(op);
     MemRefType sourceType =
         memRefCastOp.getOperand()->getType().cast<MemRefType>();
     MemRefType targetType = memRefCastOp.getType();
@@ -569,7 +569,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto memRefCastOp = op->cast<MemRefCastOp>();
+    auto memRefCastOp = cast<MemRefCastOp>(op);
     auto targetType = memRefCastOp.getType();
     auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
 
@@ -636,7 +636,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
 
   PatternMatchResult match(Operation *op) const override {
     if (!LLVMLegalizationPattern<DimOp>::match(op))
       return this->matchFailure();
-    auto dimOp = op->cast<DimOp>();
+    auto dimOp = cast<DimOp>(op);
     MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
     return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
   }
@@ -644,7 +644,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
     assert(operands.size() == 1 && "expected exactly one operand");
-    auto dimOp = op->cast<DimOp>();
+    auto dimOp = cast<DimOp>(op);
     MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
 
     SmallVector<Value *, 4> results;
@@ -683,7 +683,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
 
   PatternMatchResult match(Operation *op) const override {
     if (!LLVMLegalizationPattern<Derived>::match(op))
       return this->matchFailure();
-    auto loadOp = op->cast<Derived>();
+    auto loadOp = cast<Derived>(op);
     MemRefType type = loadOp.getMemRefType();
     return isSupportedMemRefType(type) ? this->matchSuccess()
                                        : this->matchFailure();
   }
@@ -794,7 +794,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto loadOp = op->cast<LoadOp>();
+    auto loadOp = cast<LoadOp>(op);
     auto type = loadOp.getMemRefType();
 
     Value *dataPtr = getDataPtr(op->getLoc(), type, operands.front(),
@@ -815,7 +815,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto storeOp = op->cast<StoreOp>();
+    auto storeOp = cast<StoreOp>(op);
     auto type = storeOp.getMemRefType();
 
     Value *dataPtr = getDataPtr(op->getLoc(), type, operands[1],
diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp
index 6998da5..8ea45df 100644
--- a/mlir/lib/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Linalg/IR/LinalgOps.cpp
@@ -320,7 +320,7 @@ void mlir::linalg::SliceOp::print(OpAsmPrinter *p) {
 }
 
 ViewOp mlir::linalg::SliceOp::getBaseViewOp() {
-  return getOperand(0)->getDefiningOp()->cast<ViewOp>();
+  return cast<ViewOp>(getOperand(0)->getDefiningOp());
 }
 
 ViewType mlir::linalg::SliceOp::getBaseViewType() {
@@ -505,8 +505,7 @@ ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result);
 /// ```
 void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) {
   assert(op->getAbstractOperation() && "unregistered operation");
-  *p << op->cast<BufferSizeOp>().getOperationName() << " "
-     << *op->getOperand(0);
+  *p << cast<BufferSizeOp>(op).getOperationName() << " " << *op->getOperand(0);
   p->printOptionalAttrDict(op->getAttrs());
   *p << " : " << op->getOperand(0)->getType();
 }
diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
index 90111a8..2d1f5f2 100644
--- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -181,7 +181,7 @@ public:
     }
 
     // Get MLIR types for injecting element pointer.
-    auto allocOp = op->cast<BufferAllocOp>();
+    auto allocOp = cast<BufferAllocOp>(op);
     auto elementType = allocOp.getElementType();
     uint64_t elementSize = 0;
     if (auto vectorType = elementType.dyn_cast<VectorType>())
@@ -239,7 +239,7 @@ public:
     }
 
     // Get MLIR types for extracting element pointer.
-    auto deallocOp = op->cast<BufferDeallocOp>();
+    auto deallocOp = cast<BufferDeallocOp>(op);
     auto elementPtrTy = rewriter.getType<LLVM::LLVMType>(getPtrToElementType(
         deallocOp.getOperand()->getType().cast<BufferType>(), lowering));
@@ -283,7 +283,7 @@ public:
   // a getelementptr. This must be called under an edsc::ScopedContext.
   Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
                        ArrayRef<Value *> indices, FuncBuilder &rewriter) const {
-    auto loadOp = op->cast<Op>();
+    auto loadOp = cast<Op>(op);
     auto elementTy = rewriter.getType<LLVM::LLVMType>(
         getPtrToElementType(loadOp.getViewType(), lowering));
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
@@ -329,7 +329,7 @@ public:
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto rangeOp = op->cast<RangeOp>();
+    auto rangeOp = cast<RangeOp>(op);
     auto rangeDescriptorTy =
         convertLinalgType(rangeOp.getResult()->getType(), lowering);
 
@@ -355,7 +355,7 @@ public:
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto sliceOp = op->cast<SliceOp>();
+    auto sliceOp = cast<SliceOp>(op);
     auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
     auto viewType = sliceOp.getBaseViewType();
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
@@ -453,7 +453,7 @@ public:
   SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
                                   FuncBuilder &rewriter) const override {
-    auto viewOp = op->cast<ViewOp>();
+    auto viewOp = cast<ViewOp>(op);
     auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
     auto elementTy = rewriter.getType<LLVM::LLVMType>(
         getPtrToElementType(viewOp.getViewType(), lowering));
diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp
index 6e20542a..e1fa74d 100644
--- a/mlir/lib/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Linalg/Transforms/Tiling.cpp
@@ -115,8 +115,8 @@ static SmallVector<Value *, 4> applyMapToRangePart(FuncBuilder *b, Location loc,
 }
 
 static bool isZero(Value *v) {
-  return v->getDefiningOp() && v->getDefiningOp()->isa<ConstantIndexOp>() &&
-         v->getDefiningOp()->cast<ConstantIndexOp>().getValue() == 0;
+  return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) &&
+         cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
 }
 
 /// Returns a map that can be used to filter the zero values out of tileSizes.
@@ -176,8 +176,8 @@ makeTiledLoopRanges(FuncBuilder *b, Location loc, AffineMap map,
     // Steps must be constant for now to abide by affine.for semantics.
     auto *newStep = state.getOrCreate(
-        step->getDefiningOp()->cast<ConstantIndexOp>().getValue() *
-        tileSize->getDefiningOp()->cast<ConstantIndexOp>().getValue());
+        cast<ConstantIndexOp>(step->getDefiningOp()).getValue() *
+        cast<ConstantIndexOp>(tileSize->getDefiningOp()).getValue());
     res.push_back(b->create<RangeOp>(loc, mins[idx], maxes[idx], newStep));
     // clang-format on
   }
diff --git a/mlir/lib/Linalg/Utils/Utils.cpp b/mlir/lib/Linalg/Utils/Utils.cpp
index 98cf4b7..6732fa1 100644
--- a/mlir/lib/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Linalg/Utils/Utils.cpp
@@ -42,12 +42,12 @@ mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder(
     assert(ranges[i].getType() && "expected !linalg.range type");
     assert(ranges[i].getValue()->getDefiningOp() &&
            "need operations to extract range parts");
-    auto rangeOp = ranges[i].getValue()->getDefiningOp()->cast<RangeOp>();
+    auto rangeOp = cast<RangeOp>(ranges[i].getValue()->getDefiningOp());
     auto lb = rangeOp.min();
     auto ub = rangeOp.max();
     // This must be a constexpr index until we relax the affine.for constraint
     auto step =
-        rangeOp.step()->getDefiningOp()->cast<ConstantIndexOp>().getValue();
+        cast<ConstantIndexOp>(rangeOp.step()->getDefiningOp()).getValue();
     loops.emplace_back(ivs[i], ValueHandle(lb), ValueHandle(ub), step);
   }
   assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
@@ -106,7 +106,7 @@ Value *mlir::createOrReturnView(FuncBuilder *b, Location loc,
       return view.getResult();
     return b->create(loc, view.getResult(), ranges);
   }
-  auto slice = viewDefiningOp->cast<SliceOp>();
+  auto slice = cast<SliceOp>(viewDefiningOp);
   unsigned idxRange = 0;
   SmallVector<Value *, 4> newIndexings;
   bool elide = true;
diff --git a/mlir/lib/Quantization/IR/QuantOps.cpp b/mlir/lib/Quantization/IR/QuantOps.cpp
index 046ad85..ab6d97f 100644
--- a/mlir/lib/Quantization/IR/QuantOps.cpp
+++ b/mlir/lib/Quantization/IR/QuantOps.cpp
@@ -43,9 +43,9 @@ public:
       : RewritePattern(StorageCastOp::getOperationName(), 1, context) {}
 
   PatternMatchResult match(Operation *op) const override {
-    auto scastOp = op->cast<StorageCastOp>();
+    auto scastOp = cast<StorageCastOp>(op);
     if (matchPattern(scastOp.arg(), m_Op<StorageCastOp>())) {
-      auto srcScastOp = scastOp.arg()->getDefiningOp()->cast<StorageCastOp>();
+      auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
       if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) {
         return matchSuccess();
       }
@@ -54,8 +54,8 @@ public:
   }
 
   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
-    auto scastOp = op->cast<StorageCastOp>();
-    auto srcScastOp = scastOp.arg()->getDefiningOp()->cast<StorageCastOp>();
+    auto scastOp = cast<StorageCastOp>(op);
+    auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
     rewriter.replaceOp(op, srcScastOp.arg());
   }
 };
diff --git a/mlir/lib/Quantization/Transforms/ConvertConst.cpp b/mlir/lib/Quantization/Transforms/ConvertConst.cpp
index 21a0de2..ad41f8f 100644
--- a/mlir/lib/Quantization/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Quantization/Transforms/ConvertConst.cpp
@@ -59,7 +59,7 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
   State state;
 
   // Is the operand a constant?
-  auto qbarrier = op->cast<QuantizeCastOp>();
+  auto qbarrier = cast<QuantizeCastOp>(op);
   if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
     return matchFailure();
   }
diff --git a/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp
index 4df7b88..c62adc8 100644
--- a/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp
@@ -59,7 +59,7 @@ public:
   }
 
   bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
-    auto fqOp = op->cast<ConstFakeQuant>();
+    auto fqOp = cast<ConstFakeQuant>(op);
 
     auto converter =
         ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType());
diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp
index bc68a78..59c1400 100644
--- a/mlir/lib/StandardOps/Ops.cpp
+++ b/mlir/lib/StandardOps/Ops.cpp
@@ -283,7 +283,7 @@ struct SimplifyAllocConst : public RewritePattern {
       : RewritePattern(AllocOp::getOperationName(), 1, context) {}
 
   PatternMatchResult match(Operation *op) const override {
-    auto alloc = op->cast<AllocOp>();
+    auto alloc = cast<AllocOp>(op);
 
     // Check to see if any dimensions operands are constants.  If so, we can
     // substitute and drop them.
@@ -294,7 +294,7 @@ struct SimplifyAllocConst : public RewritePattern {
   }
 
   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
-    auto allocOp = op->cast<AllocOp>();
+    auto allocOp = cast<AllocOp>(op);
     auto memrefType = allocOp.getType();
 
     // Ok, we have one or more constant operands.  Collect the non-constant ones
@@ -352,7 +352,7 @@ struct SimplifyDeadAlloc : public RewritePattern {
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
     // Check if the alloc'ed value has any uses.
-    auto alloc = op->cast<AllocOp>();
+    auto alloc = cast<AllocOp>(op);
     if (!alloc.use_empty())
       return matchFailure();
 
@@ -468,7 +468,7 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto indirectCall = op->cast<CallIndirectOp>();
+    auto indirectCall = cast<CallIndirectOp>(op);
 
     // Check that the callee is a constant operation.
     Attribute callee;
@@ -978,7 +978,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern {
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto condbr = op->cast<CondBranchOp>();
+    auto condbr = cast<CondBranchOp>(op);
 
     // Check that the condition is a constant.
     if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
@@ -1222,7 +1222,7 @@ struct SimplifyDeadDealloc : public RewritePattern {
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto dealloc = op->cast<DeallocOp>();
+    auto dealloc = cast<DeallocOp>(op);
 
     // Check that the memref operand's defining operation is an AllocOp.
     Value *memref = dealloc.memref();
@@ -2107,7 +2107,7 @@ struct SimplifyXMinusX : public RewritePattern {
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto subi = op->cast<SubIOp>();
+    auto subi = cast<SubIOp>(op);
     if (subi.getOperand(0) != subi.getOperand(1))
       return matchFailure();
 
@@ -2192,7 +2192,7 @@ struct SimplifyXXOrX : public RewritePattern {
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
-    auto xorOp = op->cast<XOrOp>();
+    auto xorOp = cast<XOrOp>(op);
     if (xorOp.lhs() != xorOp.rhs())
       return matchFailure();
 
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 1c4a4d1..d430c5d 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -128,7 +128,7 @@ struct LoopNestStateCollector {
   void collect(Operation *opToWalk) {
     opToWalk->walk([&](Operation *op) {
       if (op->isa<AffineForOp>())
-        forOps.push_back(op->cast<AffineForOp>());
+        forOps.push_back(cast<AffineForOp>(op));
       else if (op->getNumRegions() != 0)
         hasNonForRegion = true;
       else if (op->isa<LoadOp>())
@@ -172,7 +172,7 @@ public:
   unsigned getLoadOpCount(Value *memref) {
     unsigned loadOpCount = 0;
     for (auto *loadOpInst : loads) {
-      if (memref == loadOpInst->cast<LoadOp>().getMemRef())
+      if (memref == cast<LoadOp>(loadOpInst).getMemRef())
         ++loadOpCount;
     }
     return loadOpCount;
@@ -182,7 +182,7 @@ public:
   unsigned getStoreOpCount(Value *memref) {
     unsigned storeOpCount = 0;
     for (auto *storeOpInst : stores) {
-      if (memref == storeOpInst->cast<StoreOp>().getMemRef())
+      if (memref == cast<StoreOp>(storeOpInst).getMemRef())
         ++storeOpCount;
     }
     return storeOpCount;
@@ -192,7 +192,7 @@ public:
   void getStoreOpsForMemref(Value *memref,
                             SmallVectorImpl<Operation *> *storeOps) {
     for (auto *storeOpInst : stores) {
-      if (memref == storeOpInst->cast<StoreOp>().getMemRef())
+      if (memref == cast<StoreOp>(storeOpInst).getMemRef())
         storeOps->push_back(storeOpInst);
     }
   }
@@ -201,7 +201,7 @@ public:
   void getLoadOpsForMemref(Value *memref,
                            SmallVectorImpl<Operation *> *loadOps) {
     for (auto *loadOpInst : loads) {
-      if (memref == loadOpInst->cast<LoadOp>().getMemRef())
+      if (memref == cast<LoadOp>(loadOpInst).getMemRef())
         loadOps->push_back(loadOpInst);
     }
   }
@@ -211,10 +211,10 @@ public:
   void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
     llvm::SmallDenseSet<Value *, 2> loadMemrefs;
     for (auto *loadOpInst : loads) {
-      loadMemrefs.insert(loadOpInst->cast<LoadOp>().getMemRef());
+      loadMemrefs.insert(cast<LoadOp>(loadOpInst).getMemRef());
     }
     for (auto *storeOpInst : stores) {
-      auto *memref = storeOpInst->cast<StoreOp>().getMemRef();
+      auto *memref = cast<StoreOp>(storeOpInst).getMemRef();
       if (loadMemrefs.count(memref) > 0)
         loadAndStoreMemrefSet->insert(memref);
     }
@@ -306,7 +306,7 @@ public:
   bool writesToLiveInOrEscapingMemrefs(unsigned id) {
     Node *node = getNode(id);
     for (auto *storeOpInst : node->stores) {
-      auto *memref = storeOpInst->cast<StoreOp>().getMemRef();
+      auto *memref = cast<StoreOp>(storeOpInst).getMemRef();
       auto *op = memref->getDefiningOp();
       // Return true if 'memref' is a block argument.
       if (!op)
@@ -331,7 +331,7 @@ public:
     Node *node = getNode(id);
     for (auto *storeOpInst : node->stores) {
      // Return false if there exist out edges from 'id' on 'memref'.
-      if (getOutEdgeCount(id, storeOpInst->cast<StoreOp>().getMemRef()) > 0)
+      if (getOutEdgeCount(id, cast<StoreOp>(storeOpInst).getMemRef()) > 0)
         return false;
     }
     return true;
@@ -656,12 +656,12 @@ bool MemRefDependenceGraph::init(Function &f) {
       Node node(nextNodeId++, &op);
       for (auto *opInst : collector.loadOpInsts) {
         node.loads.push_back(opInst);
-        auto *memref = opInst->cast<LoadOp>().getMemRef();
+        auto *memref = cast<LoadOp>(opInst).getMemRef();
         memrefAccesses[memref].insert(node.id);
       }
       for (auto *opInst : collector.storeOpInsts) {
         node.stores.push_back(opInst);
-        auto *memref = opInst->cast<StoreOp>().getMemRef();
+        auto *memref = cast<StoreOp>(opInst).getMemRef();
         memrefAccesses[memref].insert(node.id);
       }
       forToNodeMap[&op] = node.id;
@@ -670,14 +670,14 @@ bool MemRefDependenceGraph::init(Function &f) {
       // Create graph node for top-level load op.
       Node node(nextNodeId++, &op);
       node.loads.push_back(&op);
-      auto *memref = op.cast<LoadOp>().getMemRef();
+      auto *memref = cast<LoadOp>(op).getMemRef();
       memrefAccesses[memref].insert(node.id);
       nodes.insert({node.id, node});
     } else if (auto storeOp = dyn_cast<StoreOp>(op)) {
       // Create graph node for top-level store op.
       Node node(nextNodeId++, &op);
       node.stores.push_back(&op);
-      auto *memref = op.cast<StoreOp>().getMemRef();
+      auto *memref = cast<StoreOp>(op).getMemRef();
       memrefAccesses[memref].insert(node.id);
       nodes.insert({node.id, node});
     } else if (op.getNumRegions() != 0) {
@@ -887,7 +887,7 @@ static void moveLoadsAccessingMemrefTo(Value *memref,
   dstLoads->clear();
   SmallVector<Operation *, 4> srcLoadsToKeep;
   for (auto *load : *srcLoads) {
-    if (load->cast<LoadOp>().getMemRef() == memref)
+    if (cast<LoadOp>(load).getMemRef() == memref)
       dstLoads->push_back(load);
     else
       srcLoadsToKeep.push_back(load);
@@ -1051,7 +1051,7 @@ computeLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
 static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
   assert(node->op->isa<AffineForOp>());
   SmallVector<AffineForOp, 4> loops;
-  AffineForOp curr = node->op->cast<AffineForOp>();
+  AffineForOp curr = cast<AffineForOp>(node->op);
   getPerfectlyNestedLoops(loops, curr);
   if (loops.size() < 2)
     return;
@@ -1107,7 +1107,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   // Builder to create constants at the top level.
   FuncBuilder top(forInst->getFunction());
   // Create new memref type based on slice bounds.
-  auto *oldMemRef = srcStoreOpInst->cast<StoreOp>().getMemRef();
+  auto *oldMemRef = cast<StoreOp>(srcStoreOpInst).getMemRef();
   auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
   unsigned rank = oldMemRefType.getRank();
 
@@ -1233,7 +1233,7 @@ static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
   // Gather all memrefs from 'srcNode' store ops.
   DenseSet<Value *> storeMemrefs;
   for (auto *storeOpInst : srcNode->stores) {
-    storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
+    storeMemrefs.insert(cast<StoreOp>(storeOpInst).getMemRef());
   }
   // Return false if any of the following are true:
   // *) 'srcNode' writes to a live in/out memref other than 'memref'.
@@ -1842,7 +1842,7 @@ public:
     DenseSet<Value *> visitedMemrefs;
     while (!loads.empty()) {
       // Get memref of load on top of the stack.
-      auto *memref = loads.back()->cast<LoadOp>().getMemRef();
+      auto *memref = cast<LoadOp>(loads.back()).getMemRef();
       if (visitedMemrefs.count(memref) > 0)
         continue;
       visitedMemrefs.insert(memref);
@@ -1898,7 +1898,7 @@ public:
           // Gather 'dstNode' store ops to 'memref'.
           SmallVector<Operation *, 2> dstStoreOpInsts;
           for (auto *storeOpInst : dstNode->stores)
-            if (storeOpInst->cast<StoreOp>().getMemRef() == memref)
+            if (cast<StoreOp>(storeOpInst).getMemRef() == memref)
               dstStoreOpInsts.push_back(storeOpInst);
 
           unsigned bestDstLoopDepth;
@@ -1916,7 +1916,7 @@ public:
             LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
                                     << *sliceLoopNest.getOperation() << "\n");
             // Move 'dstAffineForOp' before 'insertPointInst' if needed.
-            auto dstAffineForOp = dstNode->op->cast<AffineForOp>();
+            auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
             if (insertPointInst != dstAffineForOp.getOperation()) {
               dstAffineForOp.getOperation()->moveBefore(insertPointInst);
             }
@@ -1934,7 +1934,7 @@ public:
             // Create private memref for 'memref' in 'dstAffineForOp'.
             SmallVector<Operation *, 4> storesForMemref;
             for (auto *storeOpInst : sliceCollector.storeOpInsts) {
-              if (storeOpInst->cast<StoreOp>().getMemRef() == memref)
+              if (cast<StoreOp>(storeOpInst).getMemRef() == memref)
                 storesForMemref.push_back(storeOpInst);
             }
             assert(storesForMemref.size() == 1);
@@ -1956,7 +1956,7 @@ public:
             // Add new load ops to current Node load op list 'loads' to
            // continue fusing based on new operands.
             for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
-              auto *loadMemRef = loadOpInst->cast<LoadOp>().getMemRef();
+              auto *loadMemRef = cast<LoadOp>(loadOpInst).getMemRef();
               if (visitedMemrefs.count(loadMemRef) == 0)
                 loads.push_back(loadOpInst);
             }
@@ -2072,7 +2072,7 @@ public:
       auto sliceLoopNest = mlir::insertBackwardComputationSlice(
          sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
       if (sliceLoopNest != nullptr) {
-        auto dstForInst = dstNode->op->cast<AffineForOp>();
+        auto dstForInst = cast<AffineForOp>(dstNode->op);
         // Update operation position of fused loop nest (if needed).
         if (insertPointInst != dstForInst.getOperation()) {
           dstForInst.getOperation()->moveBefore(insertPointInst);
         }
@@ -2114,7 +2114,7 @@ public:
       // Check that all stores are to the same memref.
       DenseSet<Value *> storeMemrefs;
       for (auto *storeOpInst : sibNode->stores) {
-        storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
+        storeMemrefs.insert(cast<StoreOp>(storeOpInst).getMemRef());
       }
       if (storeMemrefs.size() != 1)
         return false;
@@ -2214,7 +2214,7 @@ public:
     }
 
     // Collect dst loop stats after memref privatization transformation.
-    auto dstForInst = dstNode->op->cast<AffineForOp>();
+    auto dstForInst = cast<AffineForOp>(dstNode->op);
     LoopNestStateCollector dstLoopCollector;
     dstLoopCollector.collect(dstForInst.getOperation());
     // Clear and add back loads and stores
@@ -2226,7 +2226,7 @@ public:
     // function.
     if (mdg->getOutEdgeCount(sibNode->id) == 0) {
       mdg->removeNode(sibNode->id);
-      sibNode->op->cast<AffineForOp>().erase();
+      sibNode->op->erase();
     }
   }
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 236ef81..1707f78 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -113,7 +113,7 @@ void LoopUnroll::runOnFunction() {
         hasInnerLoops |= walkPostOrder(block.begin(), block.end());
       if (opInst->isa<AffineForOp>()) {
         if (!hasInnerLoops)
-          loops.push_back(opInst->cast<AffineForOp>());
+          loops.push_back(cast<AffineForOp>(opInst));
         return true;
       }
       return hasInnerLoops;
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index 0a23295..43e8f4a 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -187,7 +187,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
     // Insert the cleanup loop right after 'forOp'.
     FuncBuilder builder(forInst->getBlock(),
                         std::next(Block::iterator(forInst)));
-    auto cleanupAffineForOp = builder.clone(*forInst)->cast<AffineForOp>();
+    auto cleanupAffineForOp = cast<AffineForOp>(builder.clone(*forInst));
     // Adjust the lower bound of the cleanup loop; its upper bound is the same
     // as the original loop's upper bound.
     AffineMap cleanupMap;
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index 1ffe5e3..6f0162e 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -626,7 +626,7 @@ void LowerAffinePass::runOnFunction() {
     } else if (auto forOp = dyn_cast<AffineForOp>(op)) {
       if (lowerAffineFor(forOp))
         return signalPassFailure();
-    } else if (lowerAffineApply(op->cast<AffineApplyOp>())) {
+    } else if (lowerAffineApply(cast<AffineApplyOp>(op))) {
      return signalPassFailure();
     }
   }
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index f7352d6..657169a 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -264,7 +264,7 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
   using namespace mlir::edsc::op;
   using namespace mlir::edsc::intrinsics;
 
-  VectorTransferReadOp transfer = op->cast<VectorTransferReadOp>();
+  VectorTransferReadOp transfer = cast<VectorTransferReadOp>(op);
 
   // 1. Setup all the captures.
   ScopedContext scope(FuncBuilder(op), transfer.getLoc());
@@ -323,7 +323,7 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
   using namespace mlir::edsc::op;
   using namespace mlir::edsc::intrinsics;
 
-  VectorTransferWriteOp transfer = op->cast<VectorTransferWriteOp>();
+  VectorTransferWriteOp transfer = cast<VectorTransferWriteOp>(op);
 
   // 1. Setup all the captures.
   ScopedContext scope(FuncBuilder(op), transfer.getLoc());
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index 28dfb22..206ae53b 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -679,7 +679,7 @@ static bool materialize(Function *f, const SetVector<Operation *> &terminators,
       continue;
     }
 
-    auto terminator = term->cast<VectorTransferWriteOp>();
+    auto terminator = cast<VectorTransferWriteOp>(term);
     LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *term);
 
     // Get the transitive use-defs starting from terminator, limited to the
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index 94df936..118efe5 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -201,7 +201,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) {
     return;
 
   // Perform the actual store to load forwarding.
-  Value *storeVal = lastWriteStoreOp->cast<StoreOp>().getValueToStore();
+  Value *storeVal = cast<StoreOp>(lastWriteStoreOp).getValueToStore();
   loadOp.getResult()->replaceAllUsesWith(storeVal);
   // Record the memref for a later sweep to optimize away.
   memrefsToErase.insert(loadOp.getMemRef());
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 0da97f7..272972d 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -234,8 +234,8 @@ static void findMatchingStartFinishInsts(
   // For each start operation, we look for a matching finish operation.
   for (auto *dmaStartInst : dmaStartInsts) {
     for (auto *dmaFinishInst : dmaFinishInsts) {
-      if (checkTagMatch(dmaStartInst->cast<DmaStartOp>(),
-                        dmaFinishInst->cast<DmaWaitOp>())) {
+      if (checkTagMatch(cast<DmaStartOp>(dmaStartInst),
+                        cast<DmaWaitOp>(dmaFinishInst))) {
         startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
         break;
       }
@@ -273,7 +273,7 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
   for (auto &pair : startWaitPairs) {
     auto *dmaStartInst = pair.first;
     Value *oldMemRef = dmaStartInst->getOperand(
-        dmaStartInst->cast<DmaStartOp>().getFasterMemPos());
+        cast<DmaStartOp>(dmaStartInst).getFasterMemPos());
     if (!doubleBuffer(oldMemRef, forOp)) {
       // Normally, double buffering should not fail because we already checked
      // that there are no uses outside.
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 7fbb48e..1ae75b4 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -426,7 +426,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp,
   Operation *op = forOp.getOperation();
   if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
     FuncBuilder builder(op->getBlock(), ++Block::iterator(op));
-    auto cleanupForInst = builder.clone(*op)->cast<AffineForOp>();
+    auto cleanupForInst = cast<AffineForOp>(builder.clone(*op));
     AffineMap cleanupMap;
     SmallVector<Value *, 4> cleanupOperands;
     getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands,
@@ -512,7 +512,7 @@ void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) {
 void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) {
   for (unsigned i = 0; i < loopDepth; ++i) {
     assert(forOp.getBody()->front().isa<AffineForOp>());
-    AffineForOp nextForOp = forOp.getBody()->front().cast<AffineForOp>();
+    AffineForOp nextForOp = cast<AffineForOp>(forOp.getBody()->front());
     interchangeLoops(forOp, nextForOp);
   }
 }
diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
index b64dc53..20138d5 100644
--- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
+++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
@@ -253,7 +253,7 @@ void VectorizerTestPass::testNormalizeMaps() {
     SmallVector<NestedMatch, 8> matches;
     pattern.match(f, &matches);
     for (auto m : matches) {
-      auto app = m.getMatchedOperation()->cast<AffineApplyOp>();
+      auto app = cast<AffineApplyOp>(m.getMatchedOperation());
       FuncBuilder b(m.getMatchedOperation());
       SmallVector<Value *, 8> operands(app.getOperands());
       makeComposedAffineApply(&b, app.getLoc(), app.getAffineMap(), operands);
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index 9b8768a..4a58b15 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -859,7 +859,7 @@ static FilterFunctionType
 isVectorizableLoopPtrFactory(const llvm::DenseSet<Operation *> &parallelLoops,
                              int fastestVaryingMemRefDimension) {
   return [&parallelLoops, fastestVaryingMemRefDimension](Operation &forOp) {
-    auto loop = forOp.cast<AffineForOp>();
+    auto loop = cast<AffineForOp>(forOp);
     auto parallelIt = parallelLoops.find(loop);
     if (parallelIt == parallelLoops.end())
       return false;
@@ -879,7 +879,7 @@ static LogicalResult
 vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch,
                                   VectorizationState *state) {
   auto *loopInst = oneMatch.getMatchedOperation();
-  auto loop = loopInst->cast<AffineForOp>();
+  auto loop = cast<AffineForOp>(loopInst);
   auto childrenMatches = oneMatch.getMatchedChildren();
 
   // 1. DFS postorder recursion, if any of my children fails, I fail too.
@@ -1118,7 +1118,7 @@ static LogicalResult vectorizeNonTerminals(VectorizationState *state) {
 /// anything below it fails.
 static LogicalResult vectorizeRootMatch(NestedMatch m,
                                         VectorizationStrategy *strategy) {
-  auto loop = m.getMatchedOperation()->cast<AffineForOp>();
+  auto loop = cast<AffineForOp>(m.getMatchedOperation());
   VectorizationState state;
   state.strategy = strategy;
 
@@ -1139,7 +1139,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m,
   /// RAII.
   auto *loopInst = loop.getOperation();
   FuncBuilder builder(loopInst);
-  auto clonedLoop = builder.clone(*loopInst)->cast<AffineForOp>();
+  auto clonedLoop = cast<AffineForOp>(builder.clone(*loopInst));
   struct Guard {
     LogicalResult failure() {
       loop.getInductionVar()->replaceAllUsesWith(clonedLoop.getInductionVar());
-- 
2.7.4
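Note on the pattern this patch applies: every hunk rewrites the member-style
`op->cast<FooOp>()` into the free function `cast<FooOp>(op)`, which lets the
redundant `Operation::cast` member template (deleted in
mlir/include/mlir/IR/Operation.h above) go away and gives `Operation` the same
`isa`/`cast`/`dyn_cast` vocabulary used everywhere else in LLVM. A minimal,
self-contained sketch of how LLVM-style casting resolves through a static
`classof` hook follows; the `Expr`/`AddExpr`/`MulExpr` names are illustrative
only, not MLIR's, and unlike MLIR op classes (which are value types wrapping an
`Operation *`) this sketch uses plain pointers for brevity.

// cast_sketch.cpp -- illustrative only; build against LLVM headers.
#include "llvm/Support/Casting.h"
#include <cstdio>

struct Expr {
  enum Kind { KAdd, KMul };
  Kind kind;
  explicit Expr(Kind k) : kind(k) {}
};

struct AddExpr : Expr {
  AddExpr() : Expr(KAdd) {}
  // classof is the hook that llvm::isa/cast/dyn_cast consult.
  static bool classof(const Expr *e) { return e->kind == KAdd; }
};

struct MulExpr : Expr {
  MulExpr() : Expr(KMul) {}
  static bool classof(const Expr *e) { return e->kind == KMul; }
};

int main() {
  AddExpr add;
  Expr *e = &add;
  if (llvm::isa<AddExpr>(e))                 // type query via classof
    std::puts("e is an AddExpr");
  AddExpr *a = llvm::cast<AddExpr>(e);       // asserts on a type mismatch
  (void)a;
  if (auto *m = llvm::dyn_cast<MulExpr>(e))  // returns nullptr on mismatch
    (void)m;
  return 0;
}

The design payoff is uniformity: generic code that already uses the
`cast<T>(x)` idiom on LLVM IR values can use the identical idiom on MLIR
operations, instead of remembering a second, member-function spelling.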