From caf89c0db679f79ca6c9a75c5acc6151dd380f26 Mon Sep 17 00:00:00 2001 From: Michal Terepeta Date: Mon, 6 Dec 2021 07:59:49 +0000 Subject: [PATCH] [mlir][Vector] Support 0-D vectors in `ConstantMaskOp` To support creating both a mask with just a single `true` and `false` values, I had to relax the restriction in the verifier that the rank is always equal to the length of the attribute array, in other words, we now allow: - `vector.constant_mask [0] : vector` which gets lowered to `arith.constant dense : vector` - `vector.constant_mask [1] : vector` which gets lowered to `arith.constant dense : vector` (the attribute list for the 0-D case must be a singleton containing either `0` or `1`) Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D115023 --- mlir/include/mlir/Dialect/Vector/VectorOps.td | 2 +- mlir/lib/Dialect/Vector/VectorOps.cpp | 13 ++++++++++++- mlir/lib/Dialect/Vector/VectorTransforms.cpp | 15 ++++++++++++++- .../test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 20 ++++++++++++++++++++ mlir/test/Dialect/Vector/invalid.mlir | 14 ++++++++++++++ mlir/test/Dialect/Vector/ops.mlir | 9 +++++++++ .../Dialect/Vector/CPU/test-0-d-vectors.mlir | 15 ++++++++++++++- 7 files changed, 84 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 74edc5f..14afb47 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -2111,7 +2111,7 @@ def Vector_TypeCastOp : def Vector_ConstantMaskOp : Vector_Op<"constant_mask", [NoSideEffect]>, Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>, - Results<(outs VectorOf<[I1]>)> { + Results<(outs VectorOfAnyRankOf<[I1]>)> { let summary = "creates a constant vector mask"; let description = [{ Creates and returns a vector mask where elements of the result vector diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 76fcb97..1b18b19 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -3924,8 +3924,19 @@ void vector::TransposeOp::getTransp(SmallVectorImpl &results) { //===----------------------------------------------------------------------===// static LogicalResult verify(ConstantMaskOp &op) { - // Verify that array attr size matches the rank of the vector result. auto resultType = op.getResult().getType().cast(); + // Check the corner case of 0-D vectors first. + if (resultType.getRank() == 0) { + if (op.mask_dim_sizes().size() != 1) + return op->emitError("array attr must have length 1 for 0-D vectors"); + auto dim = op.mask_dim_sizes()[0].cast().getInt(); + if (dim != 0 && dim != 1) + return op->emitError( + "mask dim size must be either 0 or 1 for 0-D vectors"); + return success(); + } + + // Verify that array attr size matches the rank of the vector result. if (static_cast(op.mask_dim_sizes().size()) != resultType.getRank()) return op.emitOpError( "must specify array attr of size equal vector result rank"); diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index 876f8ae..6d50838 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -960,7 +960,20 @@ public: auto dstType = op.getType(); auto eltType = dstType.getElementType(); auto dimSizes = op.mask_dim_sizes(); - int64_t rank = dimSizes.size(); + int64_t rank = dstType.getRank(); + + if (rank == 0) { + assert(dimSizes.size() == 1 && + "Expected exactly one dim size for a 0-D vector"); + bool value = dimSizes[0].cast().getInt() == 1; + rewriter.replaceOpWithNewOp( + op, dstType, + DenseIntElementsAttr::get( + VectorType::get(ArrayRef{}, rewriter.getI1Type()), + ArrayRef{value})); + return success(); + } + int64_t trueDim = std::min(dstType.getDimSize(0), dimSizes[0].cast().getInt()); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 0c21cf9e..ee3c756 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1396,6 +1396,26 @@ func @transfer_read_1d_mask(%A : memref, %base : index) -> vector<5xf32> // ----- +func @genbool_0d_f() -> vector { + %0 = vector.constant_mask [0] : vector + return %0 : vector +} +// CHECK-LABEL: func @genbool_0d_f +// CHECK: %[[VAL_0:.*]] = arith.constant dense : vector +// CHECK: return %[[VAL_0]] : vector + +// ----- + +func @genbool_0d_t() -> vector { + %0 = vector.constant_mask [1] : vector + return %0 : vector +} +// CHECK-LABEL: func @genbool_0d_t +// CHECK: %[[VAL_0:.*]] = arith.constant dense : vector +// CHECK: return %[[VAL_0]] : vector + +// ----- + func @genbool_1d() -> vector<8xi1> { %0 = vector.constant_mask [4] : vector<8xi1> return %0 : vector<8xi1> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index fb69798..63e30cf 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -884,6 +884,20 @@ func @create_mask() { // ----- +func @constant_mask_0d_no_attr() { + // expected-error@+1 {{array attr must have length 1 for 0-D vectors}} + %0 = vector.constant_mask [] : vector +} + +// ----- + +func @constant_mask_0d_bad_attr() { + // expected-error@+1 {{mask dim size must be either 0 or 1 for 0-D vectors}} + %0 = vector.constant_mask [2] : vector +} + +// ----- + func @constant_mask() { // expected-error@+1 {{must specify array attr of size equal vector result rank}} %0 = vector.constant_mask [3, 2, 7] : vector<4x3xi1> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 2bd0e13..43c5abd 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -376,6 +376,15 @@ func @create_vector_mask() { return } +// CHECK-LABEL: @constant_vector_mask_0d +func @constant_vector_mask_0d() { + // CHECK: vector.constant_mask [0] : vector + %0 = vector.constant_mask [0] : vector + // CHECK: vector.constant_mask [1] : vector + %1 = vector.constant_mask [1] : vector + return +} + // CHECK-LABEL: @constant_vector_mask func @constant_vector_mask() { // CHECK: vector.constant_mask [3, 2] : vector<4x3xi1> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir index 74bbada..a0d4c3d 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir @@ -68,6 +68,16 @@ func @bitcast_0d() { } +func @constant_mask_0d() { + %1 = vector.constant_mask [0] : vector + // CHECK: ( 0 ) + vector.print %1: vector + %2 = vector.constant_mask [1] : vector + // CHECK: ( 1 ) + vector.print %2: vector + return +} + func @entry() { %0 = arith.constant 42.0 : f32 %1 = arith.constant dense<0.0> : vector @@ -78,10 +88,13 @@ func @entry() { call @print_vector_0d(%3) : (vector) -> () %4 = arith.constant 42.0 : f32 + + // Warning: these must be called in their textual order of definition in the + // file to not mess up FileCheck. call @splat_0d(%4) : (f32) -> () call @broadcast_0d(%4) : (f32) -> () - call @bitcast_0d() : () -> () + call @constant_mask_0d() : () -> () return } -- 2.7.4