From 397336dcab81dd0bb95e50e95c737c3e77ee7985 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 10 Feb 2021 15:57:02 -0800 Subject: [PATCH] [mlir][vector] Add missing support for contract of integer lowering. Some of the lowering of vector.contract didn't support integer case. Since reduction of integer cannot accumulate we always break up the reduction op, it should be merged by a separate canonicalization if possible. Differential Revision: https://reviews.llvm.org/D96461 --- mlir/lib/Dialect/Vector/VectorTransforms.cpp | 33 +++++++++-- .../Dialect/Vector/vector-contract-transforms.mlir | 65 ++++++++++++++++++---- 2 files changed, 82 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index 0a6c88d..200eb55 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1716,6 +1716,24 @@ private: } // namespace +/// Creates an AddIOp if `isInt` is true otherwise create an AddFOp using +/// operands `x` and `y`. +static Value createAdd(Location loc, Value x, Value y, bool isInt, + PatternRewriter &rewriter) { + if (isInt) + return rewriter.create(loc, x, y); + return rewriter.create(loc, x, y); +} + +/// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using +/// operands `x and `y`. +static Value createMul(Location loc, Value x, Value y, bool isInt, + PatternRewriter &rewriter) { + if (isInt) + return rewriter.create(loc, x, y); + return rewriter.create(loc, x, y); +} + namespace mlir { /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul @@ -2003,13 +2021,14 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, // ExtractOp does not allow dynamic indexing, we must unroll explicitly. Value res = rewriter.create(loc, dstType, rewriter.getZeroAttr(dstType)); + bool isInt = dstType.getElementType().isa(); for (unsigned r = 0; r < dstRows; ++r) { Value a = rewriter.create(op.getLoc(), lhs, r); for (unsigned c = 0; c < dstColumns; ++c) { Value b = rank == 1 ? rhs : rewriter.create(op.getLoc(), rhs, c); - Value m = rewriter.create(op.getLoc(), a, b); + Value m = createMul(op.getLoc(), a, b, isInt, rewriter); Value reduced = rewriter.create( op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"), m, ValueRange{}); @@ -2020,7 +2039,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, } } if (auto acc = op.acc()) - res = rewriter.create(op.getLoc(), res, acc); + res = createAdd(op.getLoc(), res, acc, isInt, rewriter); rewriter.replaceOp(op, res); return success(); } @@ -2176,6 +2195,7 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, VectorType rhsType = op.getRhsType(); Type resType = op.getResultType(); assert(!resType.isa()); + bool isInt = resType.isa(); // Use iterator index 0. int64_t iterIndex = 0; SmallVector iMap = op.getIndexingMaps(); @@ -2190,10 +2210,13 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, // Base case. if (lhsType.getRank() == 1) { assert(rhsType.getRank() == 1 && "corrupt contraction"); - Value m = rewriter.create(loc, op.lhs(), op.rhs()); + Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter); StringAttr kind = rewriter.getStringAttr("add"); - return rewriter.create(loc, resType, kind, m, - op.acc()); + Value res = rewriter.create(loc, resType, kind, m, + ValueRange{}); + if (auto acc = op.acc()) + res = createAdd(op.getLoc(), res, acc, isInt, rewriter); + return res; } // Construct new iterator types and affine map array attribute. std::array lowIndexingMaps = { diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir index 2c3ac0f..3adb18c 100644 --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -18,8 +18,9 @@ // CHECK-SAME: %[[B:.*1]]: vector<4xf32>, // CHECK-SAME: %[[C:.*2]]: f32 // CHECK: %[[F:.*]] = mulf %[[A]], %[[B]] : vector<4xf32> -// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]], %[[C]] : vector<4xf32> into f32 -// CHECK: return %[[R]] : f32 +// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xf32> into f32 +// CHECK: %[[ACC:.*]] = addf %[[R]], %[[C]] : f32 +// CHECK: return %[[ACC]] : f32 func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 { %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 @@ -27,6 +28,21 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) return %0 : f32 } +// CHECK-LABEL: func @extract_contract1_int +// CHECK-SAME: %[[A:.*0]]: vector<4xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<4xi32>, +// CHECK-SAME: %[[C:.*2]]: i32 +// CHECK: %[[F:.*]] = muli %[[A]], %[[B]] : vector<4xi32> +// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]] : vector<4xi32> into i32 +// CHECK: %[[ACC:.*]] = addi %[[R]], %[[C]] : i32 +// CHECK: return %[[ACC]] : i32 + +func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 { + %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 + : vector<4xi32>, vector<4xi32> into i32 + return %0 : i32 +} + #matvec_accesses = [ affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (j)>, @@ -61,6 +77,29 @@ func @extract_contract2(%arg0: vector<2x3xf32>, return %0 : vector<2xf32> } +// CHECK-LABEL: func @extract_contract2_int +// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, +// CHECK-SAME: %[[C:.*2]]: vector<2xi32> +// CHECK: %[[R:.*]] = constant dense<0> : vector<2xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32> +// CHECK: %[[T2:.*]] = muli %[[T0]], %[[B]] : vector<3xi32> +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xi32> into i32 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32> +// CHECK: %[[T7:.*]] = muli %[[T5]], %[[B]] : vector<3xi32> +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xi32> into i32 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32> +// CHECK: %[[T10:.*]] = addi %[[T9]], %[[C]] : vector<2xi32> +// CHECK: return %[[T10]] : vector<2xi32> +func @extract_contract2_int(%arg0: vector<2x3xi32>, + %arg1: vector<3xi32>, + %arg2: vector<2xi32>) -> vector<2xi32> { + %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xi32>, vector<3xi32> into vector<2xi32> + return %0 : vector<2xi32> +} + #vecmat_accesses = [ affine_map<(i, j) -> (j)>, affine_map<(i, j) -> (i, j)>, @@ -162,12 +201,14 @@ func @extract_contract4(%arg0: vector<2x2xf32>, // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> // CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> // CHECK: %[[T2:.*]] = mulf %[[T0]], %[[T1]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[C]] : vector<3xf32> into f32 -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> -// CHECK: %[[T6:.*]] = mulf %[[T4]], %[[T5]] : vector<3xf32> -// CHECK: %[[T7:.*]] = vector.reduction "add", %[[T6]], %[[T3]] : vector<3xf32> into f32 -// CHECK: return %[[T7]] : f32 +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32 +// CHECK: %[[T4:.*]] = addf %[[T3]], %[[C]] : f32 +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> +// CHECK: %[[T7:.*]] = mulf %[[T5]], %[[T6]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32 +// CHECK: %[[T9:.*]] = addf %[[T8]], %[[T4]] : f32 +// CHECK: return %[[T9]] : f32 func @full_contract1(%arg0: vector<2x3xf32>, %arg1: vector<2x3xf32>, @@ -200,7 +241,8 @@ func @full_contract1(%arg0: vector<2x3xf32>, // CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32> // CHECK: %[[T10:.*]] = mulf %[[T0]], %[[T9]] : vector<3xf32> -// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32 +// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]] : vector<3xf32> into f32 +// CHECK: %[[ACC0:.*]] = addf %[[T11]], %[[C]] : f32 // // CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> // CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf @@ -210,8 +252,9 @@ func @full_contract1(%arg0: vector<2x3xf32>, // CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32> // CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32> // CHECK: %[[T22:.*]] = mulf %[[T12]], %[[T21]] : vector<3xf32> -// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32 -// CHECK: return %[[T23]] : f32 +// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]] : vector<3xf32> into f32 +// CHECK: %[[ACC1:.*]] = addf %[[T23]], %[[ACC0]] : f32 +// CHECK: return %[[ACC1]] : f32 func @full_contract2(%arg0: vector<2x3xf32>, %arg1: vector<3x2xf32>, -- 2.7.4