From: Andy Davis Date: Fri, 20 Dec 2019 00:04:59 +0000 (-0800) Subject: [VectorOps] Update vector transfer_read/write ops to operatate on memrefs with vector... X-Git-Tag: llvmorg-11-init~1466^2~20 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8020ad3e396bcca8dba94cea397cece81b76b119;p=platform%2Fupstream%2Fllvm.git [VectorOps] Update vector transfer_read/write ops to operatate on memrefs with vector element type. Update vector transfer_read/write ops to operatate on memrefs with vector element type. This handle cases where the memref vector element type represents the minimal memory transfer unit (or multiple of the minimal memory transfer unit). PiperOrigin-RevId: 286482115 --- diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index 7dcac62..d5e8431 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -746,10 +746,15 @@ def Vector_TransferReadOp : let description = [{ The `vector.transfer_read` op performs a blocking read from a slice within - a scalar [MemRef](../LangRef.md#memref-type) supplied as its first operand - into a [vector](../LangRef.md#vector-type) of the same elemental type. The - slice is further defined by a full-rank index within the MemRef, supplied as - the operands `2 .. 1 + rank(memref)`. The permutation_map + a [MemRef](../LangRef.md#memref-type) supplied as its first operand + into a [vector](../LangRef.md#vector-type) of the same base elemental type. + + A vector memref operand must have its vector element type match a suffix + (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, + vector<1x1x4x3xf32>). + + The slice is further defined by a full-rank index within the MemRef, + supplied as the operands `2 .. 1 + rank(memref)`. The permutation_map [attribute](../LangRef.md#attributes) is an [affine-map](Affine.md#affine-maps) which specifies the transposition on the slice to match the vector shape. The size of the slice is specified by the @@ -854,6 +859,11 @@ def Vector_TransferReadOp : memref, vector<128xf32> } } + + // Read from a memref with vector element type. + %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 + {permutation_map = (d0, d1)->(d0, d1)} + : memref>, vector<1x1x4x3xf32> ``` }]; @@ -878,10 +888,15 @@ def Vector_TransferWriteOp : let description = [{ The `vector.transfer_write` performs a blocking write from a [vector](../LangRef.md#vector-type), supplied as its first operand, into a - slice within a scalar [MemRef](../LangRef.md#memref-type) of the same - elemental type, supplied as its second operand. The slice is further defined - by a full-rank index within the MemRef, supplied as the operands - `3 .. 2 + rank(memref)`. + slice within a [MemRef](../LangRef.md#memref-type) of the same base + elemental type, supplied as its second operand. + + A vector memref operand must have its vector element type match a suffix + (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, + vector<1x1x4x3xf32>). + + The slice is further defined by a full-rank index within the MemRef, + supplied as the operands `3 .. 2 + rank(memref)`. The permutation_map [attribute](../LangRef.md#attributes) is an [affine-map](Affine.md#affine-maps) which specifies the transposition on the slice to match the vector shape. The size of the slice is specified by the @@ -915,6 +930,11 @@ def Vector_TransferWriteOp : {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} : vector<16x32x64xf32>, memref }}}} + + // write to a memref with vector element type. + vector.transfer_write %4, %arg1[%c3, %c3] + {permutation_map = (d0, d1)->(d0, d1)} + : vector<1x1x4x3xf32>, memref> ``` }]; @@ -1048,7 +1068,7 @@ def Vector_TupleOp : Note that this operation is used during the vector op unrolling transformation and should be removed before lowering to lower-level dialects. - + Examples: ``` diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 541b542..8a69467 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -1420,6 +1420,59 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap, return success(); } +static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType, + VectorType vectorType, + AffineMap permutationMap) { + auto memrefElementType = memrefType.getElementType(); + if (auto memrefVectorElementType = memrefElementType.dyn_cast()) { + // Memref has vector element type. + + // Check that 'memrefVectorElementType' and vector element types match. + if (memrefVectorElementType.getElementType() != vectorType.getElementType()) + return op->emitOpError( + "requires memref and vector types of the same elemental type"); + + // Check that memref vector type is a suffix of 'vectorType. + unsigned memrefVecEltRank = memrefVectorElementType.getRank(); + unsigned resultVecRank = vectorType.getRank(); + if (memrefVecEltRank > resultVecRank) + return op->emitOpError( + "requires memref vector element and vector result ranks to match."); + // TODO(b/146516564) Move this to isSuffix in VectorOps/Utils.h. + unsigned rankOffset = resultVecRank - memrefVecEltRank; + auto memrefVecEltShape = memrefVectorElementType.getShape(); + auto resultVecShape = vectorType.getShape(); + for (unsigned i = 0; i < memrefVecEltRank; ++i) + if (memrefVecEltShape[i] != resultVecShape[rankOffset + i]) + return op->emitOpError( + "requires memref vector element shape to match suffix of " + "vector result shape."); + // Check that permutation map results match 'rankOffset' of vector type. + if (permutationMap.getNumResults() != rankOffset) + return op->emitOpError("requires a permutation_map with result dims of " + "the same rank as the vector type"); + } else { + // Memref has scalar element type. + + // Check that memref and vector element types match. + if (memrefType.getElementType() != vectorType.getElementType()) + return op->emitOpError( + "requires memref and vector types of the same elemental type"); + + // Check that permutation map results match rank of vector type. + if (permutationMap.getNumResults() != vectorType.getRank()) + return op->emitOpError("requires a permutation_map with result dims of " + "the same rank as the vector type"); + } + + if (permutationMap.getNumSymbols() != 0) + return op->emitOpError("requires permutation_map without symbols"); + if (permutationMap.getNumInputs() != memrefType.getRank()) + return op->emitOpError("requires a permutation_map with input dims of the " + "same rank as the memref type"); + return success(); +} + static void print(OpAsmPrinter &p, TransferReadOp op) { p << op.getOperationName() << " " << op.memref() << "[" << op.indices() << "], " << op.padding() << " "; @@ -1459,26 +1512,35 @@ static LogicalResult verify(TransferReadOp op) { // Consistency of elemental types in memref and vector. MemRefType memrefType = op.getMemRefType(); VectorType vectorType = op.getVectorType(); - if (memrefType.getElementType() != vectorType.getElementType()) - return op.emitOpError( - "requires memref and vector types of the same elemental type"); - auto elementalType = op.padding()->getType(); - if (!VectorType::isValidElementType(elementalType)) - return op.emitOpError("requires valid padding vector elemental type"); - if (elementalType != vectorType.getElementType()) - return op.emitOpError( - "requires formal padding and vector of the same elemental type"); - if (llvm::size(op.indices()) != memrefType.getRank()) - return op.emitOpError("requires ") << memrefType.getRank() << " indices"; + auto paddingType = op.padding()->getType(); auto permutationMap = op.permutation_map(); - if (permutationMap.getNumSymbols() != 0) - return op.emitOpError("requires permutation_map without symbols"); - if (permutationMap.getNumInputs() != memrefType.getRank()) - return op.emitOpError("requires a permutation_map with input dims of the " - "same rank as the memref type"); - if (permutationMap.getNumResults() != vectorType.getRank()) - return op.emitOpError("requires a permutation_map with result dims of the " - "same rank as the vector type"); + auto memrefElementType = memrefType.getElementType(); + + if (static_cast(op.indices().size()) != memrefType.getRank()) + return op.emitOpError("requires ") << memrefType.getRank() << " indices"; + + if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType, + permutationMap))) + return failure(); + + if (auto memrefVectorElementType = memrefElementType.dyn_cast()) { + // Memref has vector element type. + // Check that 'memrefVectorElementType' and 'paddingType' types match. + if (memrefVectorElementType != paddingType) + return op.emitOpError( + "requires memref element type and padding type to match."); + + } else { + // Check that 'paddingType' is valid to store in a vector type. + if (!VectorType::isValidElementType(paddingType)) + return op.emitOpError("requires valid padding vector elemental type"); + + // Check that padding type and vector element types match. + if (paddingType != vectorType.getElementType()) + return op.emitOpError( + "requires formal padding and vector of the same elemental type"); + } + return verifyPermutationMap(permutationMap, [&op](Twine t) { return op.emitOpError(t); }); } @@ -1519,24 +1581,15 @@ static LogicalResult verify(TransferWriteOp op) { // Consistency of elemental types in memref and vector. MemRefType memrefType = op.getMemRefType(); VectorType vectorType = op.getVectorType(); - if (memrefType.getElementType() != vectorType.getElementType()) - return op.emitOpError( - "requires memref and vector types of the same elemental type"); + auto permutationMap = op.permutation_map(); + if (llvm::size(op.indices()) != memrefType.getRank()) return op.emitOpError("requires ") << memrefType.getRank() << " indices"; - // Consistency of AffineMap attribute. - auto permutationMap = op.permutation_map(); - if (permutationMap.getNumSymbols() != 0) - return op.emitOpError("requires a symbol-less permutation_map"); - if (permutationMap.getNumInputs() != memrefType.getRank()) - return op.emitOpError("requires a permutation_map with input dims of the " - "same rank as the memref type: ") - << permutationMap.getNumInputs() << " vs " << memrefType; - if (permutationMap.getNumResults() != vectorType.getRank()) - return op.emitOpError("requires a permutation_map with result dims of the " - "same rank as the vector type.") - << permutationMap.getNumResults() << " vs " << vectorType; + if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType, + permutationMap))) + return failure(); + return verifyPermutationMap(permutationMap, [&op](Twine t) { return op.emitOpError(t); }); } diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index c208c92..9ef39e2 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -308,6 +308,36 @@ func @test_vector.transfer_read(%arg0: memref) { // ----- +func @test_vector.transfer_read(%arg0: memref>) { + %c3 = constant 3 : index + %f0 = constant 0.0 : f32 + %vf0 = splat %f0 : vector<4x3xf32> + // expected-error@+1 {{requires memref and vector types of the same elemental type}} + %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref>, vector<1x1x4x3xi32> +} + +// ----- + +func @test_vector.transfer_read(%arg0: memref>) { + %c3 = constant 3 : index + %f0 = constant 0.0 : f32 + %vf0 = splat %f0 : vector<4x3xf32> + // expected-error@+1 {{requires memref vector element and vector result ranks to match}} + %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref>, vector<3xf32> +} + +// ----- + +func @test_vector.transfer_read(%arg0: memref>) { + %c3 = constant 3 : index + %f0 = constant 0.0 : f32 + %vf0 = splat %f0 : vector<4x3xf32> + // expected-error@+1 {{ requires memref vector element shape to match suffix of vector result shape}} + %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref>, vector<1x1x2x3xf32> +} + +// ----- + func @test_vector.transfer_write(%arg0: memref) { %c3 = constant 3 : index %cst = constant dense<3.0> : vector<128 x f32> diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index e160799..d99a7df 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -1,24 +1,35 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s +// CHECK-DAG: #[[MAP0:map[0-9]+]] = (d0, d1) -> (d0, d1) + // CHECK-LABEL: func @vector_transfer_ops( -func @vector_transfer_ops(%arg0: memref) { +func @vector_transfer_ops(%arg0: memref, + %arg1 : memref>) { + // CHECK: %[[C3:.*]] = constant 3 : index %c3 = constant 3 : index %cst = constant 3.0 : f32 %f0 = constant 0.0 : f32 + %vf0 = splat %f0 : vector<4x3xf32> + // - // CHECK: %0 = vector.transfer_read + // CHECK: vector.transfer_read %0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = (d0, d1)->(d0)} : memref, vector<128xf32> - // CHECK: %1 = vector.transfer_read + // CHECK: vector.transfer_read %1 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = (d0, d1)->(d1, d0)} : memref, vector<3x7xf32> // CHECK: vector.transfer_read %2 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = (d0, d1)->(d0)} : memref, vector<128xf32> // CHECK: vector.transfer_read %3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = (d0, d1)->(d1)} : memref, vector<128xf32> - // + // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref>, vector<1x1x4x3xf32> + %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref>, vector<1x1x4x3xf32> + // CHECK: vector.transfer_write vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d0)} : vector<128xf32>, memref // CHECK: vector.transfer_write vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d1, d0)} : vector<3x7xf32>, memref + // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] {permutation_map = #[[MAP0]]} : vector<1x1x4x3xf32>, memref> + vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = (d0, d1)->(d0, d1)} : vector<1x1x4x3xf32>, memref> + return }