[mlir][NFC] Remove a few op builders that simply swap parameter order
authorRiver Riddle <riddleriver@gmail.com>
Sun, 6 Feb 2022 20:32:47 +0000 (12:32 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 8 Feb 2022 03:03:57 +0000 (19:03 -0800)
Differential Revision: https://reviews.llvm.org/D119093

12 files changed:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp

index bef9178..b278de5 100644 (file)
@@ -73,12 +73,6 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
     DeclareOpInterfaceMethods<CastOpInterface>]>,
     Arguments<(ins From:$in)>,
     Results<(outs To:$out)> {
-  let builders = [
-    OpBuilder<(ins "Value":$source, "Type":$destType), [{
-      impl::buildCastOp($_builder, $_state, source, destType);
-    }]>
-  ];
-
   let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
 }
 
index 81839df..79ad1ed 100644 (file)
@@ -374,11 +374,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
   let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
   let results = (outs AnyRankedOrUnrankedMemRef:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
-  let builders = [
-    OpBuilder<(ins "Value":$source, "Type":$destType), [{
-       impl::buildCastOp($_builder, $_state, source, destType);
-    }]>
-  ];
 
   let extraClassDeclaration = [{
     /// Fold the given CastOp into consumer op.
index 80f50a3..56ee70e 100644 (file)
@@ -1003,11 +1003,11 @@ private:
       switch (conversion) {
       case PrintConversion::ZeroExt64:
         value = rewriter.create<arith::ExtUIOp>(
-            loc, value, IntegerType::get(rewriter.getContext(), 64));
+            loc, IntegerType::get(rewriter.getContext(), 64), value);
         break;
       case PrintConversion::SignExt64:
         value = rewriter.create<arith::ExtSIOp>(
-            loc, value, IntegerType::get(rewriter.getContext(), 64));
+            loc, IntegerType::get(rewriter.getContext(), 64), value);
         break;
       case PrintConversion::None:
         break;
index c4e1632..0252626 100644 (file)
@@ -94,8 +94,8 @@ struct IndexCastOpInterface
         getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
                       layout, sourceType.getMemorySpace());
 
-    replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source,
-                                                     resultType);
+    replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
+                                                     source);
     return success();
   }
 };
index 74d6d42..fd98f31 100644 (file)
@@ -835,15 +835,15 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                                 scalingFactor);
     }
     Value numWorkersIndex =
-        b.create<arith::IndexCastOp>(numWorkerThreadsVal, b.getI32Type());
+        b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
     Value numWorkersFloat =
-        b.create<arith::SIToFPOp>(numWorkersIndex, b.getF32Type());
+        b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
     Value scaledNumWorkers =
         b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
     Value scaledNumInt =
-        b.create<arith::FPToSIOp>(scaledNumWorkers, b.getI32Type());
+        b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
     Value scaledWorkers =
-        b.create<arith::IndexCastOp>(scaledNumInt, b.getIndexType());
+        b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);
 
     Value maxComputeBlocks = b.create<arith::MaxSIOp>(
         b.create<arith::ConstantIndexOp>(1), scaledWorkers);
index d47d6ea..ed9170c 100644 (file)
@@ -887,7 +887,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
   auto i32Vec = broadcast(builder.getI32Type(), shape);
 
   // exp2(k)
-  Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
+  Value k = builder.create<arith::FPToSIOp>(i32Vec, kF32);
   Value exp2KValue = exp2I32(builder, k);
 
   // exp(x) = exp(y) * exp2(k)
@@ -1042,7 +1042,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
 
   auto i32Vec = broadcast(builder.getI32Type(), shape);
   auto fPToSingedInteger = [&](Value a) -> Value {
-    return builder.create<arith::FPToSIOp>(a, i32Vec);
+    return builder.create<arith::FPToSIOp>(i32Vec, a);
   };
 
   auto modulo4 = [&](Value a) -> Value {
index 21cb8d6..da672dc 100644 (file)
@@ -165,7 +165,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
         alloc.alignmentAttr());
     // Insert a cast so we have the same type as the old alloc.
     auto resultCast =
-        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
+        rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
 
     rewriter.replaceOp(alloc, {resultCast});
     return success();
@@ -2156,8 +2156,8 @@ public:
       rewriter.replaceOp(subViewOp, subViewOp.source());
       return success();
     }
-    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.source(),
-                                        subViewOp.getType());
+    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
+                                        subViewOp.source());
     return success();
   }
 };
@@ -2177,7 +2177,7 @@ struct SubViewReturnTypeCanonicalizer {
 /// A canonicalizer wrapper to replace SubViewOps.
 struct SubViewCanonicalizer {
   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
-    rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
+    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
   }
 };
 
@@ -2422,7 +2422,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
                                              viewOp.getOperand(0),
                                              viewOp.byte_shift(), newOperands);
     // Insert a cast so we have the same type as the old memref type.
-    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
+    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
     return success();
   }
 };
index 2a83977..ca490a2 100644 (file)
@@ -101,8 +101,8 @@ public:
         Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
         size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
         if (!size.getType().isa<IndexType>())
-          size = rewriter.create<arith::IndexCastOp>(loc, size,
-                                                     rewriter.getIndexType());
+          size = rewriter.create<arith::IndexCastOp>(
+              loc, rewriter.getIndexType(), size);
         sizes[i] = size;
       } else {
         sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
index 94e87b3..07875fc 100644 (file)
@@ -309,7 +309,7 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
     Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
                                                    ValueRange{ivs[0], idx});
     val =
-        rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
+        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), val);
     rewriter.create<memref::StoreOp>(loc, val, ind, idx);
   }
   return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
index 72e70dd..be707b1 100644 (file)
@@ -831,11 +831,11 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
     if (!etp.isa<IndexType>()) {
       if (etp.getIntOrFloatBitWidth() < 32)
         vload = rewriter.create<arith::ExtUIOp>(
-            loc, vload, vectorType(codegen, rewriter.getI32Type()));
+            loc, vectorType(codegen, rewriter.getI32Type()), vload);
       else if (etp.getIntOrFloatBitWidth() < 64 &&
                !codegen.options.enableSIMDIndex32)
         vload = rewriter.create<arith::ExtUIOp>(
-            loc, vload, vectorType(codegen, rewriter.getI64Type()));
+            loc, vectorType(codegen, rewriter.getI64Type()), vload);
     }
     return vload;
   }
@@ -846,9 +846,9 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
   Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
   if (!load.getType().isa<IndexType>()) {
     if (load.getType().getIntOrFloatBitWidth() < 64)
-      load = rewriter.create<arith::ExtUIOp>(loc, load, rewriter.getI64Type());
+      load = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), load);
     load =
-        rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
+        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), load);
   }
   return load;
 }
@@ -868,7 +868,7 @@ static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
   Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
   if (auto vtp = i.getType().dyn_cast<VectorType>()) {
     Value inv =
-        rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
+        rewriter.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
     mul = genVectorInvariantValue(codegen, rewriter, inv);
   }
   return rewriter.create<arith::AddIOp>(loc, mul, i);
index 31e7fb5..37e077a 100644 (file)
@@ -671,25 +671,25 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                                            rewriter.getZeroAttr(v0.getType())),
         v0);
   case kTruncF:
-    return rewriter.create<arith::TruncFOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
   case kExtF:
-    return rewriter.create<arith::ExtFOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
   case kCastFS:
-    return rewriter.create<arith::FPToSIOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
   case kCastFU:
-    return rewriter.create<arith::FPToUIOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
   case kCastSF:
-    return rewriter.create<arith::SIToFPOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
   case kCastUF:
-    return rewriter.create<arith::UIToFPOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
   case kCastS:
-    return rewriter.create<arith::ExtSIOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
   case kCastU:
-    return rewriter.create<arith::ExtUIOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
   case kTruncI:
-    return rewriter.create<arith::TruncIOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
   case kBitCast:
-    return rewriter.create<arith::BitcastOp>(loc, v0, inferType(e, v0));
+    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
   // Binary ops.
   case kMulF:
     return rewriter.create<arith::MulFOp>(loc, v0, v1);
index 1ceebf2..f574713 100644 (file)
@@ -255,7 +255,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
       [&](OpBuilder &b, Location loc) {
         Value res = memref;
         if (compatibleMemRefType != xferOp.getShapedType())
-          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
         scf::ValueVector viewAndIndices{res};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                               xferOp.indices().end());
@@ -271,7 +271,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
             alloc);
         b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
         Value casted =
-            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
         scf::ValueVector viewAndIndices{casted};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                               zero);
@@ -309,7 +309,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
       [&](OpBuilder &b, Location loc) {
         Value res = memref;
         if (compatibleMemRefType != xferOp.getShapedType())
-          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
         scf::ValueVector viewAndIndices{res};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
                               xferOp.indices().end());
@@ -324,7 +324,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
                 loc, MemRefType::get({}, vector.getType()), alloc));
 
         Value casted =
-            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
         scf::ValueVector viewAndIndices{casted};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                               zero);
@@ -360,7 +360,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
           [&](OpBuilder &b, Location loc) {
             Value res = memref;
             if (compatibleMemRefType != xferOp.getShapedType())
-              res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
             scf::ValueVector viewAndIndices{res};
             viewAndIndices.insert(viewAndIndices.end(),
                                   xferOp.indices().begin(),
@@ -369,7 +369,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
           },
           [&](OpBuilder &b, Location loc) {
             Value casted =
-                b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
             scf::ValueVector viewAndIndices{casted};
             viewAndIndices.insert(viewAndIndices.end(),
                                   xferOp.getTransferRank(), zero);