[mlir][vector] Add fold pattern for InsertOp(Constant into Constant)
authorJakub Kuderski <kubak@google.com>
Sat, 26 Nov 2022 04:00:48 +0000 (23:00 -0500)
committerJakub Kuderski <kubak@google.com>
Sat, 26 Nov 2022 04:01:29 +0000 (23:01 -0500)
This pattern comes with vector size threshold to make sure we do not
introduce too many large constants.

This help clean up code created by the Wide Integer Emulation pass.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D138733

mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index eebf590..2f9bca6 100644 (file)
@@ -1635,11 +1635,10 @@ public:
     if (!dense || dense.isSplat())
       return failure();
 
-    // Calculate the linearized position of the continous chunk of elements to
+    // Calculate the linearized position of the continuous chunk of elements to
     // extract.
     llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
-    llvm::copy(getI64SubArray(extractOp.getPosition()),
-               completePositions.begin());
+    copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
     int64_t elemBeginPosition =
         linearize(completePositions, computeStrides(vecTy.getShape()));
     auto denseValuesBegin = dense.value_begin<Attribute>() + elemBeginPosition;
@@ -2084,11 +2083,68 @@ public:
   }
 };
 
+// Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
+class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  // Do not create constants with more than `vectorSizeFoldThreashold` elements,
+  // unless the source vector constant has a single use.
+  static constexpr int64_t vectorSizeFoldThreshold = 256;
+
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    // Return if 'InsertOp' operand is not defined by a compatible vector
+    // ConstantOp.
+    TypedValue<VectorType> destVector = op.getDest();
+    Attribute vectorDestCst;
+    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
+      return failure();
+
+    VectorType destTy = destVector.getType();
+    if (destTy.isScalable())
+      return failure();
+
+    // Make sure we do not create too many large constants.
+    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
+        !destVector.hasOneUse())
+      return failure();
+
+    auto denseDest = vectorDestCst.cast<DenseElementsAttr>();
+
+    Value sourceValue = op.getSource();
+    Attribute sourceCst;
+    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
+      return failure();
+
+    // Calculate the linearized position of the continuous chunk of elements to
+    // insert.
+    llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+    copy(getI64SubArray(op.getPosition()), completePositions.begin());
+    int64_t insertBeginPosition =
+        linearize(completePositions, computeStrides(destTy.getShape()));
+
+    SmallVector<Attribute> insertedValues;
+    if (auto denseSource = sourceCst.dyn_cast<DenseElementsAttr>())
+      llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
+    else
+      insertedValues.push_back(sourceCst);
+
+    auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
+    copy(insertedValues, allValues.begin() + insertBeginPosition);
+    auto newAttr = DenseElementsAttr::get(destTy, allValues);
+
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
+    return success();
+  }
+};
+
 } // namespace
 
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
+              InsertOpConstantFolder>(context);
 }
 
 // Eliminates insert operations that produce values identical to their source
@@ -2744,13 +2800,12 @@ public:
 
     // Expand offsets and sizes to match the vector rank.
     SmallVector<int64_t, 4> offsets(sliceRank, 0);
-    llvm::copy(getI64SubArray(extractStridedSliceOp.getOffsets()),
-               offsets.begin());
+    copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
 
     SmallVector<int64_t, 4> sizes(sourceShape.begin(), sourceShape.end());
-    llvm::copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
+    copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
 
-    // Calcualte the slice elements by enumerating all slice positions and
+    // Calculate the slice elements by enumerating all slice positions and
     // linearizing them. The enumeration order is lexicographic which yields a
     // sequence of monotonically increasing linearized position indices.
     auto denseValuesBegin = dense.value_begin<Attribute>();
index 19a06af..7aabcec 100644 (file)
@@ -1795,6 +1795,64 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
 
 // -----
 
+// CHECK-LABEL: func.func @insert_1d_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>
+//   CHECK-DAG: %[[CCST:.*]] = arith.constant dense<[0, 1, 9]> : vector<3xi32>
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<3xi32>, vector<3xi32>, vector<3xi32>
+func.func @insert_1d_constant() -> (vector<3xi32>, vector<3xi32>, vector<3xi32>) {
+  %vcst = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+  %icst = arith.constant 9 : i32
+  %a = vector.insert %icst, %vcst[0] : i32 into vector<3xi32>
+  %b = vector.insert %icst, %vcst[1] : i32 into vector<3xi32>
+  %c = vector.insert %icst, %vcst[2] : i32 into vector<3xi32>
+  return %a, %b, %c : vector<3xi32>, vector<3xi32>, vector<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @insert_2d_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<{{\[\[99, 1, 2\], \[3, 4, 5\]\]}}> : vector<2x3xi32>
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[0, 1, 2\], \[3, 4, 99\]\]}}> : vector<2x3xi32>
+//   CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[90, 91, 92\], \[3, 4, 5\]\]}}> : vector<2x3xi32>
+//   CHECK-DAG: %[[DCST:.*]] = arith.constant dense<{{\[\[0, 1, 2\], \[90, 91, 92\]\]}}> : vector<2x3xi32>
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]]
+func.func @insert_2d_constant() -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>) {
+  %vcst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32>
+  %cst_scalar = arith.constant 99 : i32
+  %cst_1d = arith.constant dense<[90, 91, 92]> : vector<3xi32>
+  %a = vector.insert %cst_scalar, %vcst[0, 0] : i32 into vector<2x3xi32>
+  %b = vector.insert %cst_scalar, %vcst[1, 2] : i32 into vector<2x3xi32>
+  %c = vector.insert %cst_1d, %vcst[0] : vector<3xi32> into vector<2x3xi32>
+  %d = vector.insert %cst_1d, %vcst[1] : vector<3xi32> into vector<2x3xi32>
+  return %a, %b, %c, %d : vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @insert_2d_splat_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<2x3xi32>
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[99, 0, 0\], \[0, 0, 0\]\]}}> : vector<2x3xi32>
+//   CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[0, 0, 0\], \[0, 99, 0\]\]}}> : vector<2x3xi32>
+//   CHECK-DAG: %[[DCST:.*]] = arith.constant dense<{{\[\[33, 33, 33\], \[0, 0, 0\]\]}}> : vector<2x3xi32>
+//   CHECK-DAG: %[[ECST:.*]] = arith.constant dense<{{\[\[0, 0, 0\], \[33, 33, 33\]\]}}> : vector<2x3xi32>
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]], %[[ECST]]
+func.func @insert_2d_splat_constant()
+  -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>) {
+  %vcst = arith.constant dense<0> : vector<2x3xi32>
+  %cst_zero = arith.constant 0 : i32
+  %cst_scalar = arith.constant 99 : i32
+  %cst_1d = arith.constant dense<33> : vector<3xi32>
+  %a = vector.insert %cst_zero, %vcst[0, 0] : i32 into vector<2x3xi32>
+  %b = vector.insert %cst_scalar, %vcst[0, 0] : i32 into vector<2x3xi32>
+  %c = vector.insert %cst_scalar, %vcst[1, 1] : i32 into vector<2x3xi32>
+  %d = vector.insert %cst_1d, %vcst[0] : vector<3xi32> into vector<2x3xi32>
+  %e = vector.insert %cst_1d, %vcst[1] : vector<3xi32> into vector<2x3xi32>
+  return %a, %b, %c, %d, %e : vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_element_fold
 //       CHECK:   %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32>
 //       CHECK:   return %[[V]]