let description = [{
The `vector.transfer_read` op performs a blocking read from a slice within
- a scalar [MemRef](../LangRef.md#memref-type) supplied as its first operand
- into a [vector](../LangRef.md#vector-type) of the same elemental type. The
- slice is further defined by a full-rank index within the MemRef, supplied as
- the operands `2 .. 1 + rank(memref)`. The permutation_map
+ a [MemRef](../LangRef.md#memref-type) supplied as its first operand
+ into a [vector](../LangRef.md#vector-type) of the same base elemental type.
+
+ A vector memref operand must have its vector element type match a suffix
+ (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>,
+ vector<1x1x4x3xf32>).
+
+ The slice is further defined by a full-rank index within the MemRef,
+ supplied as the operands `2 .. 1 + rank(memref)`. The permutation_map
[attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
slice to match the vector shape. The size of the slice is specified by the
memref<?x?xf32>, vector<128xf32>
}
}
+
+ // Read from a memref with vector element type.
+ %4 = vector.transfer_read %arg1[%c3, %c3], %vf0
+ {permutation_map = (d0, d1)->(d0, d1)}
+ : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
```
}];
let description = [{
The `vector.transfer_write` performs a blocking write from a
[vector](../LangRef.md#vector-type), supplied as its first operand, into a
- slice within a scalar [MemRef](../LangRef.md#memref-type) of the same
- elemental type, supplied as its second operand. The slice is further defined
- by a full-rank index within the MemRef, supplied as the operands
- `3 .. 2 + rank(memref)`.
+ slice within a [MemRef](../LangRef.md#memref-type) of the same base
+ elemental type, supplied as its second operand.
+
+ A vector memref operand must have its vector element type match a suffix
+ (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>,
+ vector<1x1x4x3xf32>).
+
+ The slice is further defined by a full-rank index within the MemRef,
+ supplied as the operands `3 .. 2 + rank(memref)`.
The permutation_map [attribute](../LangRef.md#attributes) is an
[affine-map](Affine.md#affine-maps) which specifies the transposition on the
slice to match the vector shape. The size of the slice is specified by the
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
vector<16x32x64xf32>, memref<?x?x?x?xf32>
}}}}
+
+ // write to a memref with vector element type.
+ vector.transfer_write %4, %arg1[%c3, %c3]
+ {permutation_map = (d0, d1)->(d0, d1)}
+ : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
```
}];
Note that this operation is used during the vector op unrolling
transformation and should be removed before lowering to lower-level
dialects.
-
+
Examples:
```
return success();
}
+static LogicalResult verifyTransferOp(Operation *op, MemRefType memrefType,
+ VectorType vectorType,
+ AffineMap permutationMap) {
+ auto memrefElementType = memrefType.getElementType();
+ if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
+ // Memref has vector element type.
+
+ // Check that 'memrefVectorElementType' and vector element types match.
+ if (memrefVectorElementType.getElementType() != vectorType.getElementType())
+ return op->emitOpError(
+ "requires memref and vector types of the same elemental type");
+
+ // Check that memref vector type is a suffix of 'vectorType.
+ unsigned memrefVecEltRank = memrefVectorElementType.getRank();
+ unsigned resultVecRank = vectorType.getRank();
+ if (memrefVecEltRank > resultVecRank)
+ return op->emitOpError(
+ "requires memref vector element and vector result ranks to match.");
+ // TODO(b/146516564) Move this to isSuffix in VectorOps/Utils.h.
+ unsigned rankOffset = resultVecRank - memrefVecEltRank;
+ auto memrefVecEltShape = memrefVectorElementType.getShape();
+ auto resultVecShape = vectorType.getShape();
+ for (unsigned i = 0; i < memrefVecEltRank; ++i)
+ if (memrefVecEltShape[i] != resultVecShape[rankOffset + i])
+ return op->emitOpError(
+ "requires memref vector element shape to match suffix of "
+ "vector result shape.");
+ // Check that permutation map results match 'rankOffset' of vector type.
+ if (permutationMap.getNumResults() != rankOffset)
+ return op->emitOpError("requires a permutation_map with result dims of "
+ "the same rank as the vector type");
+ } else {
+ // Memref has scalar element type.
+
+ // Check that memref and vector element types match.
+ if (memrefType.getElementType() != vectorType.getElementType())
+ return op->emitOpError(
+ "requires memref and vector types of the same elemental type");
+
+ // Check that permutation map results match rank of vector type.
+ if (permutationMap.getNumResults() != vectorType.getRank())
+ return op->emitOpError("requires a permutation_map with result dims of "
+ "the same rank as the vector type");
+ }
+
+ if (permutationMap.getNumSymbols() != 0)
+ return op->emitOpError("requires permutation_map without symbols");
+ if (permutationMap.getNumInputs() != memrefType.getRank())
+ return op->emitOpError("requires a permutation_map with input dims of the "
+ "same rank as the memref type");
+ return success();
+}
+
static void print(OpAsmPrinter &p, TransferReadOp op) {
p << op.getOperationName() << " " << op.memref() << "[" << op.indices()
<< "], " << op.padding() << " ";
// Consistency of elemental types in memref and vector.
MemRefType memrefType = op.getMemRefType();
VectorType vectorType = op.getVectorType();
- if (memrefType.getElementType() != vectorType.getElementType())
- return op.emitOpError(
- "requires memref and vector types of the same elemental type");
- auto elementalType = op.padding()->getType();
- if (!VectorType::isValidElementType(elementalType))
- return op.emitOpError("requires valid padding vector elemental type");
- if (elementalType != vectorType.getElementType())
- return op.emitOpError(
- "requires formal padding and vector of the same elemental type");
- if (llvm::size(op.indices()) != memrefType.getRank())
- return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+ auto paddingType = op.padding()->getType();
auto permutationMap = op.permutation_map();
- if (permutationMap.getNumSymbols() != 0)
- return op.emitOpError("requires permutation_map without symbols");
- if (permutationMap.getNumInputs() != memrefType.getRank())
- return op.emitOpError("requires a permutation_map with input dims of the "
- "same rank as the memref type");
- if (permutationMap.getNumResults() != vectorType.getRank())
- return op.emitOpError("requires a permutation_map with result dims of the "
- "same rank as the vector type");
+ auto memrefElementType = memrefType.getElementType();
+
+ if (static_cast<int64_t>(op.indices().size()) != memrefType.getRank())
+ return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+
+ if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
+ permutationMap)))
+ return failure();
+
+ if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
+ // Memref has vector element type.
+ // Check that 'memrefVectorElementType' and 'paddingType' types match.
+ if (memrefVectorElementType != paddingType)
+ return op.emitOpError(
+ "requires memref element type and padding type to match.");
+
+ } else {
+ // Check that 'paddingType' is valid to store in a vector type.
+ if (!VectorType::isValidElementType(paddingType))
+ return op.emitOpError("requires valid padding vector elemental type");
+
+ // Check that padding type and vector element types match.
+ if (paddingType != vectorType.getElementType())
+ return op.emitOpError(
+ "requires formal padding and vector of the same elemental type");
+ }
+
return verifyPermutationMap(permutationMap,
[&op](Twine t) { return op.emitOpError(t); });
}
// Consistency of elemental types in memref and vector.
MemRefType memrefType = op.getMemRefType();
VectorType vectorType = op.getVectorType();
- if (memrefType.getElementType() != vectorType.getElementType())
- return op.emitOpError(
- "requires memref and vector types of the same elemental type");
+ auto permutationMap = op.permutation_map();
+
if (llvm::size(op.indices()) != memrefType.getRank())
return op.emitOpError("requires ") << memrefType.getRank() << " indices";
- // Consistency of AffineMap attribute.
- auto permutationMap = op.permutation_map();
- if (permutationMap.getNumSymbols() != 0)
- return op.emitOpError("requires a symbol-less permutation_map");
- if (permutationMap.getNumInputs() != memrefType.getRank())
- return op.emitOpError("requires a permutation_map with input dims of the "
- "same rank as the memref type: ")
- << permutationMap.getNumInputs() << " vs " << memrefType;
- if (permutationMap.getNumResults() != vectorType.getRank())
- return op.emitOpError("requires a permutation_map with result dims of the "
- "same rank as the vector type.")
- << permutationMap.getNumResults() << " vs " << vectorType;
+ if (failed(verifyTransferOp(op.getOperation(), memrefType, vectorType,
+ permutationMap)))
+ return failure();
+
return verifyPermutationMap(permutationMap,
[&op](Twine t) { return op.emitOpError(t); });
}
// -----
+func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
+ %c3 = constant 3 : index
+ %f0 = constant 0.0 : f32
+ %vf0 = splat %f0 : vector<4x3xf32>
+ // expected-error@+1 {{requires memref and vector types of the same elemental type}}
+ %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xi32>
+}
+
+// -----
+
+func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
+ %c3 = constant 3 : index
+ %f0 = constant 0.0 : f32
+ %vf0 = splat %f0 : vector<4x3xf32>
+ // expected-error@+1 {{requires memref vector element and vector result ranks to match}}
+ %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
+}
+
+// -----
+
+func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
+ %c3 = constant 3 : index
+ %f0 = constant 0.0 : f32
+ %vf0 = splat %f0 : vector<4x3xf32>
+ // expected-error@+1 {{ requires memref vector element shape to match suffix of vector result shape}}
+ %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x2x3xf32>
+}
+
+// -----
+
func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = constant 3 : index
%cst = constant dense<3.0> : vector<128 x f32>
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// CHECK-DAG: #[[MAP0:map[0-9]+]] = (d0, d1) -> (d0, d1)
+
// CHECK-LABEL: func @vector_transfer_ops(
-func @vector_transfer_ops(%arg0: memref<?x?xf32>) {
+func @vector_transfer_ops(%arg0: memref<?x?xf32>,
+ %arg1 : memref<?x?xvector<4x3xf32>>) {
+ // CHECK: %[[C3:.*]] = constant 3 : index
%c3 = constant 3 : index
%cst = constant 3.0 : f32
%f0 = constant 0.0 : f32
+ %vf0 = splat %f0 : vector<4x3xf32>
+
//
- // CHECK: %0 = vector.transfer_read
+ // CHECK: vector.transfer_read
%0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = (d0, d1)->(d0)} : memref<?x?xf32>, vector<128xf32>
- // CHECK: %1 = vector.transfer_read
+ // CHECK: vector.transfer_read
%1 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = (d0, d1)->(d1, d0)} : memref<?x?xf32>, vector<3x7xf32>
// CHECK: vector.transfer_read
%2 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = (d0, d1)->(d0)} : memref<?x?xf32>, vector<128xf32>
// CHECK: vector.transfer_read
%3 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = (d0, d1)->(d1)} : memref<?x?xf32>, vector<128xf32>
- //
+ // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {permutation_map = #[[MAP0]]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+ %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+
// CHECK: vector.transfer_write
vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>
// CHECK: vector.transfer_write
vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = (d0, d1)->(d1, d0)} : vector<3x7xf32>, memref<?x?xf32>
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] {permutation_map = #[[MAP0]]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+ vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = (d0, d1)->(d0, d1)} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+
return
}