def Vector_ReductionOp :
Vector_Op<"reduction", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
- TCresVTEtIsSameAsOpBase<0, 0>>]>,
+ TCresVTEtIsSameAsOpBase<0, 0>>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface,
+ ["getShapeForUnroll"]>]>,
Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVector:$vector,
Optional<AnyType>:$acc)>,
Results<(outs AnyType:$dest)> {
return nullptr;
}
+Optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getVectorType().getShape());
+}
+
//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//
}
};
+struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
+ UnrollReductionPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options)
+ : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
+ PatternRewriter &rewriter) const override {
+ Optional<SmallVector<int64_t, 4>> targetShape =
+ getTargetShape(options, reductionOp);
+ if (!targetShape)
+ return failure();
+ SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
+ int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
+
+ // Create unrolled vector reduction.
+ Location loc = reductionOp.getLoc();
+ Value accumulator = nullptr;
+ for (int64_t i = 0; i < ratio; ++i) {
+ SmallVector<int64_t> offsets =
+ getVectorOffset(originalSize, *targetShape, i);
+ SmallVector<int64_t> strides(offsets.size(), 1);
+ Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.vector(), offsets, *targetShape, strides);
+ Operation *newOp = cloneOpWithOperandsAndTypes(
+ rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
+ Value result = newOp->getResult(0);
+
+ if (!accumulator) {
+ // This is the first reduction.
+ accumulator = result;
+ } else {
+ // On subsequent reduction, combine with the accumulator.
+ accumulator = makeArithReduction(rewriter, loc, reductionOp.kind(),
+ accumulator, result);
+ }
+ }
+
+ rewriter.replaceOp(reductionOp, accumulator);
+ return success();
+ }
+
+private:
+ const vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options) {
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
UnrollContractionPattern, UnrollElementwisePattern,
- UnrollMultiReductionPattern>(patterns.getContext(), options);
+ UnrollReductionPattern, UnrollMultiReductionPattern>(
+ patterns.getContext(), options);
}
void mlir::vector::populatePropagateVectorDistributionPatterns(
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[A1]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: return %[[V2]] : vector<4xf32>
+
+// CHECK-LABEL: func @vector_reduction(
+// CHECK-SAME: %[[v:.*]]: vector<8xf32>
+// CHECK: %[[s0:.*]] = vector.extract_strided_slice %[[v]] {offsets = [0], sizes = [2]
+// CHECK: %[[r0:.*]] = vector.reduction <add>, %[[s0]]
+// CHECK: %[[s1:.*]] = vector.extract_strided_slice %[[v]] {offsets = [2], sizes = [2]
+// CHECK: %[[r1:.*]] = vector.reduction <add>, %[[s1]]
+// CHECK: %[[add1:.*]] = arith.addf %[[r0]], %[[r1]]
+// CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[v]] {offsets = [4], sizes = [2]
+// CHECK: %[[r2:.*]] = vector.reduction <add>, %[[s2]]
+// CHECK: %[[add2:.*]] = arith.addf %[[add1]], %[[r2]]
+// CHECK: %[[s3:.*]] = vector.extract_strided_slice %[[v]] {offsets = [6], sizes = [2]
+// CHECK: %[[r3:.*]] = vector.reduction <add>, %[[s3]]
+// CHECK: %[[add3:.*]] = arith.addf %[[add2]], %[[r3]]
+// CHECK: return %[[add3]]
+func @vector_reduction(%v : vector<8xf32>) -> f32 {
+ %0 = vector.reduction <add>, %v : vector<8xf32> into f32
+ return %0 : f32
+}
+
return success(isa<arith::AddFOp, vector::FMAOp,
vector::MultiDimReductionOp>(op));
}));
+ populateVectorUnrollPatterns(
+ patterns, UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{2})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::ReductionOp>(op));
+ }));
if (unrollBasedOnType) {
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =