From d3ddbe153e4ce1377653c8fb2936334bf9d105cf Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 31 Jan 2023 00:05:19 +0000 Subject: [PATCH] [mlir][vector] Clarify vector.contract promotion behavior This commit updates vector.contract documentation to clarify the promotion behavior if operands and the result have different bitwidths. It also adds a check to disable signed/unsigned integer types and only allow signless integers. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D142915 --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 5 +++++ mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 13 +++++++++---- mlir/test/Dialect/Vector/invalid.mlir | 11 +++++++++++ 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index e0711a4..3dc3123 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -110,6 +110,11 @@ def Vector_ContractionOp : num_batch_dims (see dimension type descriptions below)). For K = 0 (no free or batch dimensions), the accumulator and output are a scalar. + If operands and the result have types of different bitwidths, operands are + promoted to have the same bitwidth as the result before performing the + contraction. For integer types, only signless integer types are supported, + and the promotion happens via sign extension. + Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp) specify the dynamic dimension sizes of valid data within the lhs/rhs vector arguments. diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 974854e..cefd629 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -803,10 +803,15 @@ static LogicalResult verifyOutputShape( } LogicalResult ContractionOp::verify() { - auto lhsType = getLhsType(); - auto rhsType = getRhsType(); - auto accType = getAccType(); - auto resType = getResultType(); + VectorType lhsType = getLhsType(); + VectorType rhsType = getRhsType(); + Type accType = getAccType(); + Type resType = getResultType(); + + if (lhsType.getElementType().isa()) { + if (!lhsType.getElementType().isSignlessInteger()) + return emitOpError("only supports signless integer types"); + } // Verify that an indexing map was specified for each vector operand. if (getIndexingMapsArray().size() != 3) diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index e0e8ed3..5132fa8 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1652,3 +1652,14 @@ func.func @vector_scalable_extract_unaligned(%vec: vector<[16]xf32>) { // expected-error@+1 {{op failed to verify that position is a multiple of the result length.}} %0 = vector.scalable.extract %vec[5] : vector<4xf32> from vector<[16]xf32> } + +// ----- + +func.func @integer_vector_contract(%arg0: vector<16x32xsi8>, %arg1: vector<32x16xsi8>, %arg2: vector<16x16xsi32>) -> vector<16x16xsi32> { + // expected-error@+1 {{op only supports signless integer types}} + %0 = vector.contract { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind + } %arg0, %arg1, %arg2 : vector<16x32xsi8>, vector<32x16xsi8> into vector<16x16xsi32> + return %0: vector<16x16xsi32> +} -- 2.7.4