From 1423e8bf5dda75877c0414dd26d024fd770d71fb Mon Sep 17 00:00:00 2001 From: Michal Terepeta Date: Fri, 3 Dec 2021 08:55:52 +0000 Subject: [PATCH] [mlir][Vector] Support 0-D vectors in `BitCastOp` The implementation only allows to bit-cast between two 0-D vectors. We could probably support casting from/to vectors like `vector<1xf32>`, but I wasn't convinced that this would be important and it would require breaking the invariant that `BitCastOp` works only on vectors with equal rank. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D114854 --- mlir/include/mlir/Dialect/Vector/VectorOps.td | 12 ++++++++---- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 6 +++--- mlir/lib/Dialect/Vector/VectorOps.cpp | 16 ++++++++++++---- mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 14 ++++++++++++++ mlir/test/Dialect/Vector/invalid.mlir | 14 ++++++++++++++ mlir/test/Dialect/Vector/ops.mlir | 10 +++++++--- .../Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir | 15 +++++++++++++++ 7 files changed, 73 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 8eaf785..74edc5f 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -675,7 +675,7 @@ def Vector_InsertElementOp : position and inserts the source into the destination at the proper position. Note that this instruction resembles vector.insert, but is restricted to 0-D - and 1-D vectors and relaxed to dynamic indices. + and 1-D vectors and relaxed to dynamic indices. It is meant to be closer to LLVM's version: https://llvm.org/docs/LangRef.html#insertelement-instruction @@ -2025,13 +2025,14 @@ def Vector_ShapeCastOp : def Vector_BitCastOp : Vector_Op<"bitcast", [NoSideEffect, AllRanksMatch<["source", "result"]>]>, - Arguments<(ins AnyVector:$source)>, - Results<(outs AnyVector:$result)>{ + Arguments<(ins AnyVectorOfAnyRank:$source)>, + Results<(outs AnyVectorOfAnyRank:$result)>{ let summary = "bitcast casts between vectors"; let description = [{ The bitcast operation casts between vectors of the same rank, the minor 1-D vector size is casted to a vector with a different element type but same - bitwidth. + bitwidth. In case of 0-D vectors, the bitwidth of element types must be + equal. Example: @@ -2044,6 +2045,9 @@ def Vector_BitCastOp : // Example casting to an element type of the same size. %5 = vector.bitcast %4 : vector<5x1x4x3xf32> to vector<5x1x4x3xi32> + + // Example casting of 0-D vectors. + %7 = vector.bitcast %6 : vector to vector ``` }]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index bc42922..9b4dce4 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -121,9 +121,9 @@ public: LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Only 1-D vectors can be lowered to LLVM. - VectorType resultTy = bitCastOp.getType(); - if (resultTy.getRank() != 1) + // Only 0-D and 1-D vectors can be lowered to LLVM. + VectorType resultTy = bitCastOp.getResultVectorType(); + if (resultTy.getRank() > 1) return failure(); Type newResultTy = typeConverter->convertType(resultTy); rewriter.replaceOpWithNewOp(bitCastOp, newResultTy, diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 859067b..76fcb97 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -3702,12 +3702,20 @@ static LogicalResult verify(BitCastOp op) { } DataLayout dataLayout = DataLayout::closest(op); - if (dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()) * - sourceVectorType.getShape().back() != - dataLayout.getTypeSizeInBits(resultVectorType.getElementType()) * - resultVectorType.getShape().back()) + auto sourceElementBits = + dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()); + auto resultElementBits = + dataLayout.getTypeSizeInBits(resultVectorType.getElementType()); + + if (sourceVectorType.getRank() == 0) { + if (sourceElementBits != resultElementBits) + return op.emitOpError("source/result bitwidth of the 0-D vector element " + "types must be equal"); + } else if (sourceElementBits * sourceVectorType.getShape().back() != + resultElementBits * resultVectorType.getShape().back()) { return op.emitOpError( "source/result bitwidth of the minor 1-D vectors must be equal"); + } return success(); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index ce81c4e..0c21cf9e 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1,6 +1,20 @@ // RUN: mlir-opt %s -convert-vector-to-llvm -split-input-file | FileCheck %s +func @bitcast_f32_to_i32_vector_0d(%input: vector) -> vector { + %0 = vector.bitcast %input : vector to vector + return %0 : vector +} + +// CHECK-LABEL: @bitcast_f32_to_i32_vector_0d +// CHECK-SAME: %[[input:.*]]: vector +// CHECK: %[[vec_f32_1d:.*]] = builtin.unrealized_conversion_cast %[[input]] : vector to vector<1xf32> +// CHECK: %[[vec_i32_1d:.*]] = llvm.bitcast %[[vec_f32_1d]] : vector<1xf32> to vector<1xi32> +// CHECK: %[[vec_i32_0d:.*]] = builtin.unrealized_conversion_cast %[[vec_i32_1d]] : vector<1xi32> to vector +// CHECK: return %[[vec_i32_0d]] : vector + +// ----- + func @bitcast_f32_to_i32_vector(%input: vector<16xf32>) -> vector<16xi32> { %0 = vector.bitcast %input : vector<16xf32> to vector<16xi32> return %0 : vector<16xi32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 7902976..fb69798 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1014,6 +1014,20 @@ func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) { // ----- +func @bitcast_rank_mismatch_to_0d(%arg0 : vector<1xf32>) { + // expected-error@+1 {{op failed to verify that all of {source, result} have same rank}} + %0 = vector.bitcast %arg0 : vector<1xf32> to vector +} + +// ----- + +func @bitcast_rank_mismatch_from_0d(%arg0 : vector) { + // expected-error@+1 {{op failed to verify that all of {source, result} have same rank}} + %0 = vector.bitcast %arg0 : vector to vector<1xf32> +} + +// ----- + func @bitcast_rank_mismatch(%arg0 : vector<5x1x3x2xf32>) { // expected-error@+1 {{op failed to verify that all of {source, result} have same rank}} %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x3x2xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 11bc141..2bd0e13 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -432,8 +432,9 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>, func @bitcast(%arg0 : vector<5x1x3x2xf32>, %arg1 : vector<8x1xi32>, %arg2 : vector<16x1x8xi8>, - %arg3 : vector<8x2x1xindex>) - -> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>) { + %arg3 : vector<8x2x1xindex>, + %arg4 : vector) + -> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>, vector) { // CHECK: vector.bitcast %{{.*}} : vector<5x1x3x2xf32> to vector<5x1x3x4xf16> %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x4xf16> @@ -459,7 +460,10 @@ func @bitcast(%arg0 : vector<5x1x3x2xf32>, // CHECK-NEXT: vector.bitcast %{{.*}} : vector<8x2x1xindex> to vector<8x2x2xf32> %7 = vector.bitcast %arg3 : vector<8x2x1xindex> to vector<8x2x2xf32> - return %0, %1, %2, %3, %4, %5, %6, %7 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32> + // CHECK: vector.bitcast %{{.*}} : vector to vector + %8 = vector.bitcast %arg4 : vector to vector + + return %0, %1, %2, %3, %4, %5, %6, %7, %8 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>, vector } // CHECK-LABEL: @vector_fma diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir index 8e69d65..74bbada 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir @@ -55,6 +55,19 @@ func @broadcast_0d(%a: f32) { return } +func @bitcast_0d() { + %0 = arith.constant 42 : i32 + %1 = arith.constant dense<0> : vector + %2 = vector.insertelement %0, %1[] : vector + %3 = vector.bitcast %2 : vector to vector + %4 = vector.extractelement %3[] : vector + %5 = arith.bitcast %4 : f32 to i32 + // CHECK: 42 + vector.print %5: i32 + return +} + + func @entry() { %0 = arith.constant 42.0 : f32 %1 = arith.constant dense<0.0> : vector @@ -68,5 +81,7 @@ func @entry() { call @splat_0d(%4) : (f32) -> () call @broadcast_0d(%4) : (f32) -> () + call @bitcast_0d() : () -> () + return } -- 2.7.4