if (!dense || dense.isSplat())
return failure();
- // Calculate the linearized position of the continous chunk of elements to
+ // Calculate the linearized position of the continuous chunk of elements to
// extract.
llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
- llvm::copy(getI64SubArray(extractOp.getPosition()),
- completePositions.begin());
+ copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
int64_t elemBeginPosition =
linearize(completePositions, computeStrides(vecTy.getShape()));
auto denseValuesBegin = dense.value_begin<Attribute>() + elemBeginPosition;
}
};
+// Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
+class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ // Do not create constants with more than `vectorSizeFoldThreashold` elements,
+ // unless the source vector constant has a single use.
+ static constexpr int64_t vectorSizeFoldThreshold = 256;
+
+ LogicalResult matchAndRewrite(InsertOp op,
+ PatternRewriter &rewriter) const override {
+ // Return if 'InsertOp' operand is not defined by a compatible vector
+ // ConstantOp.
+ TypedValue<VectorType> destVector = op.getDest();
+ Attribute vectorDestCst;
+ if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
+ return failure();
+
+ VectorType destTy = destVector.getType();
+ if (destTy.isScalable())
+ return failure();
+
+ // Make sure we do not create too many large constants.
+ if (destTy.getNumElements() > vectorSizeFoldThreshold &&
+ !destVector.hasOneUse())
+ return failure();
+
+ auto denseDest = vectorDestCst.cast<DenseElementsAttr>();
+
+ Value sourceValue = op.getSource();
+ Attribute sourceCst;
+ if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
+ return failure();
+
+ // Calculate the linearized position of the continuous chunk of elements to
+ // insert.
+ llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+ copy(getI64SubArray(op.getPosition()), completePositions.begin());
+ int64_t insertBeginPosition =
+ linearize(completePositions, computeStrides(destTy.getShape()));
+
+ SmallVector<Attribute> insertedValues;
+ if (auto denseSource = sourceCst.dyn_cast<DenseElementsAttr>())
+ llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
+ else
+ insertedValues.push_back(sourceCst);
+
+ auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
+ copy(insertedValues, allValues.begin() + insertBeginPosition);
+ auto newAttr = DenseElementsAttr::get(destTy, allValues);
+
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
+ return success();
+ }
+};
+
} // namespace
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
+ results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
+ InsertOpConstantFolder>(context);
}
// Eliminates insert operations that produce values identical to their source
// Expand offsets and sizes to match the vector rank.
SmallVector<int64_t, 4> offsets(sliceRank, 0);
- llvm::copy(getI64SubArray(extractStridedSliceOp.getOffsets()),
- offsets.begin());
+ copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
SmallVector<int64_t, 4> sizes(sourceShape.begin(), sourceShape.end());
- llvm::copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
+ copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
- // Calcualte the slice elements by enumerating all slice positions and
+ // Calculate the slice elements by enumerating all slice positions and
// linearizing them. The enumeration order is lexicographic which yields a
// sequence of monotonically increasing linearized position indices.
auto denseValuesBegin = dense.value_begin<Attribute>();
// -----
+// CHECK-LABEL: func.func @insert_1d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
+// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>
+// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<[0, 1, 9]> : vector<3xi32>
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<3xi32>, vector<3xi32>, vector<3xi32>
+func.func @insert_1d_constant() -> (vector<3xi32>, vector<3xi32>, vector<3xi32>) {
+ %vcst = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+ %icst = arith.constant 9 : i32
+ %a = vector.insert %icst, %vcst[0] : i32 into vector<3xi32>
+ %b = vector.insert %icst, %vcst[1] : i32 into vector<3xi32>
+ %c = vector.insert %icst, %vcst[2] : i32 into vector<3xi32>
+ return %a, %b, %c : vector<3xi32>, vector<3xi32>, vector<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @insert_2d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<{{\[\[99, 1, 2\], \[3, 4, 5\]\]}}> : vector<2x3xi32>
+// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[0, 1, 2\], \[3, 4, 99\]\]}}> : vector<2x3xi32>
+// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[90, 91, 92\], \[3, 4, 5\]\]}}> : vector<2x3xi32>
+// CHECK-DAG: %[[DCST:.*]] = arith.constant dense<{{\[\[0, 1, 2\], \[90, 91, 92\]\]}}> : vector<2x3xi32>
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]]
+func.func @insert_2d_constant() -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>) {
+ %vcst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32>
+ %cst_scalar = arith.constant 99 : i32
+ %cst_1d = arith.constant dense<[90, 91, 92]> : vector<3xi32>
+ %a = vector.insert %cst_scalar, %vcst[0, 0] : i32 into vector<2x3xi32>
+ %b = vector.insert %cst_scalar, %vcst[1, 2] : i32 into vector<2x3xi32>
+ %c = vector.insert %cst_1d, %vcst[0] : vector<3xi32> into vector<2x3xi32>
+ %d = vector.insert %cst_1d, %vcst[1] : vector<3xi32> into vector<2x3xi32>
+ return %a, %b, %c, %d : vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @insert_2d_splat_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<2x3xi32>
+// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[99, 0, 0\], \[0, 0, 0\]\]}}> : vector<2x3xi32>
+// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[0, 0, 0\], \[0, 99, 0\]\]}}> : vector<2x3xi32>
+// CHECK-DAG: %[[DCST:.*]] = arith.constant dense<{{\[\[33, 33, 33\], \[0, 0, 0\]\]}}> : vector<2x3xi32>
+// CHECK-DAG: %[[ECST:.*]] = arith.constant dense<{{\[\[0, 0, 0\], \[33, 33, 33\]\]}}> : vector<2x3xi32>
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]], %[[ECST]]
+func.func @insert_2d_splat_constant()
+ -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>) {
+ %vcst = arith.constant dense<0> : vector<2x3xi32>
+ %cst_zero = arith.constant 0 : i32
+ %cst_scalar = arith.constant 99 : i32
+ %cst_1d = arith.constant dense<33> : vector<3xi32>
+ %a = vector.insert %cst_zero, %vcst[0, 0] : i32 into vector<2x3xi32>
+ %b = vector.insert %cst_scalar, %vcst[0, 0] : i32 into vector<2x3xi32>
+ %c = vector.insert %cst_scalar, %vcst[1, 1] : i32 into vector<2x3xi32>
+ %d = vector.insert %cst_1d, %vcst[0] : vector<3xi32> into vector<2x3xi32>
+ %e = vector.insert %cst_1d, %vcst[1] : vector<3xi32> into vector<2x3xi32>
+ return %a, %b, %c, %d, %e : vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @insert_element_fold
// CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32>
// CHECK: return %[[V]]