Type i64Type = rewriter.getIntegerType(64);
MemRefType memRefType = xferOp.getMemRefType();
+ if (auto memrefVectorElementType =
+ memRefType.getElementType().dyn_cast<VectorType>()) {
+ // Memref has vector element type: the lowering only supports the case
+ // where the element types agree; bail out (pattern failure) otherwise.
+ if (memrefVectorElementType.getElementType() !=
+ xferOp.getVectorType().getElementType())
+ return failure();
+ // Check that memref vector type is a suffix of 'vectorType'.
+ unsigned memrefVecEltRank = memrefVectorElementType.getRank();
+ unsigned resultVecRank = xferOp.getVectorType().getRank();
+ assert(memrefVecEltRank <= resultVecRank);
+ // TODO: Move this to isSuffix in Vector/Utils.h.
+ unsigned rankOffset = resultVecRank - memrefVecEltRank;
+ auto memrefVecEltShape = memrefVectorElementType.getShape();
+ auto resultVecShape = xferOp.getVectorType().getShape();
+ // The op verifier already rejected mismatching suffixes, so equality is
+ // an invariant here; assert it (the original `!=` was inverted and would
+ // fire exactly when the shapes DO match).
+ for (unsigned i = 0; i < memrefVecEltRank; ++i)
+ assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
+ "memref vector element shape should match suffix of vector "
+ "result shape.");
+ }
+
+
// 1. Get the source/dst address as an LLVM vector pointer.
// The vector pointer would always be on address space 0, therefore
// addrspacecast shall be used when source/dst memrefs are not on
if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
// Memref has vector element type.
- // Check that 'memrefVectorElementType' and vector element types match.
- if (memrefVectorElementType.getElementType() != vectorType.getElementType())
+ unsigned memrefVecSize = memrefVectorElementType.getElementTypeBitWidth() *
+ memrefVectorElementType.getShape().back();
+ unsigned resultVecSize =
+ vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
+ if (resultVecSize % memrefVecSize != 0)
return op->emitOpError(
- "requires memref and vector types of the same elemental type");
+ "requires the bitwidth of the minor 1-D vector to be an integral "
+ "multiple of the bitwidth of the minor 1-D vector of the memref");
- // Check that memref vector type is a suffix of 'vectorType.
unsigned memrefVecEltRank = memrefVectorElementType.getRank();
unsigned resultVecRank = vectorType.getRank();
if (memrefVecEltRank > resultVecRank)
return op->emitOpError(
"requires memref vector element and vector result ranks to match.");
- // TODO: Move this to isSuffix in Vector/Utils.h.
unsigned rankOffset = resultVecRank - memrefVecEltRank;
- auto memrefVecEltShape = memrefVectorElementType.getShape();
- auto resultVecShape = vectorType.getShape();
- for (unsigned i = 0; i < memrefVecEltRank; ++i)
- if (memrefVecEltShape[i] != resultVecShape[rankOffset + i])
- return op->emitOpError(
- "requires memref vector element shape to match suffix of "
- "vector result shape.");
// Check that permutation map results match 'rankOffset' of vector type.
if (permutationMap.getNumResults() != rankOffset)
return op->emitOpError("requires a permutation_map with result dims of "
"the same rank as the vector type");
} else {
// Memref has scalar element type.
-
- // Check that memref and vector element types match.
- if (memrefType.getElementType() != vectorType.getElementType())
+ unsigned resultVecSize =
+ vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
+ if (resultVecSize % memrefElementType.getIntOrFloatBitWidth() != 0)
return op->emitOpError(
- "requires memref and vector types of the same elemental type");
+ "requires the bitwidth of the minor 1-D vector to be an integral "
+ "multiple of the bitwidth of the memref element type");
// Check that permutation map results match rank of vector type.
if (permutationMap.getNumResults() != vectorType.getRank())
VectorType vector, Value memref, ValueRange indices,
AffineMap permutationMap,
ArrayRef<bool> maybeMasked) {
- Type elemType = vector.cast<VectorType>().getElementType();
+ Type elemType = memref.getType().cast<MemRefType>().getElementType();
Value padding = builder.create<ConstantOp>(result.location, elemType,
builder.getZeroAttr(elemType));
if (maybeMasked.empty())
return op.emitOpError("requires valid padding vector elemental type");
// Check that padding type and vector element types match.
- if (paddingType != vectorType.getElementType())
+ if (paddingType != memrefElementType)
return op.emitOpError(
- "requires formal padding and vector of the same elemental type");
+ "requires formal padding and memref of the same elemental type");
}
return verifyPermutationMap(permutationMap,
// 2. Rewrite as a load.
// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm.ptr<vec<17 x float>>
+// Reading vector<12xi8> (minor dim: 12 x i8 = 96 bits) from memref<?xi32>
+// (32-bit elements): 96 is an integral multiple of 32, so the transfer is
+// legal and lowers via a pointer bitcast (see CHECK lines below) rather
+// than requiring matching element types.
+func @transfer_read_1d_cast(%A : memref<?xi32>, %base: index) -> vector<12xi8> {
+ %c0 = constant 0: i32
+ %v = vector.transfer_read %A[%base], %c0 {masked = [false]} :
+ memref<?xi32>, vector<12xi8>
+ return %v: vector<12xi8>
+}
+// CHECK-LABEL: func @transfer_read_1d_cast
+// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm.vec<12 x i8>
+//
+// 1. Bitcast to vector form.
+// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+// CHECK-SAME: (!llvm.ptr<i32>, !llvm.i64) -> !llvm.ptr<i32>
+// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+// CHECK-SAME: !llvm.ptr<i32> to !llvm.ptr<vec<12 x i8>>
+//
+// 2. Rewrite as a load.
+// CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm.ptr<vec<12 x i8>>
+
+
func @genbool_1d() -> vector<8xi1> {
%0 = vector.constant_mask [4] : vector<8xi1>
return %0 : vector<8xi1>
%c3 = constant 3 : index
%f0 = constant 0.0 : f32
%vf0 = splat %f0 : vector<4x3xf32>
- // expected-error@+1 {{requires memref and vector types of the same elemental type}}
- %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xi32>
-}
-
-// -----
-
-func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
- %c3 = constant 3 : index
- %f0 = constant 0.0 : f32
- %vf0 = splat %f0 : vector<4x3xf32>
// expected-error@+1 {{requires memref vector element and vector result ranks to match}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
}
// -----
-func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
+func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
  %c3 = constant 3 : index
  %f0 = constant 0.0 : f32
- %vf0 = splat %f0 : vector<4x3xf32>
- // expected-error@+1 {{ requires memref vector element shape to match suffix of vector result shape}}
- %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x2x3xf32>
+ %vf0 = splat %f0 : vector<6xf32>
+ // Result minor vector is 3 x f32 = 96 bits; the memref's is 6 x f32 =
+ // 192 bits. 96 is not an integral multiple of 192, so the verifier
+ // rejects the op. (Comment placed above expected-error to keep @+1 valid.)
+ // expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the memref}}
+ %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
}
// -----
// CHECK-LABEL: func @vector_transfer_ops(
func @vector_transfer_ops(%arg0: memref<?x?xf32>,
- %arg1 : memref<?x?xvector<4x3xf32>>) {
+ %arg1 : memref<?x?xvector<4x3xf32>>,
+ %arg2 : memref<?x?xvector<4x3xi32>>) {
  // CHECK: %[[C3:.*]] = constant 3 : index
  %c3 = constant 3 : index
  %cst = constant 3.0 : f32
  %f0 = constant 0.0 : f32
+ %c0 = constant 0 : i32
  %vf0 = splat %f0 : vector<4x3xf32>
+ %v0 = splat %c0 : vector<4x3xi32>
  //
  // CHECK: vector.transfer_read
  %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
  // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
  %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {masked = [true, false]} : memref<?x?xvector<4x3xf32>>, vector<1x1x4x3xf32>
+ // Bitcast-style read: result minor vector 24 x i8 = 192 bits is an
+ // integral multiple of the memref's minor vector 3 x i32 = 96 bits.
+ // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
+ %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref<?x?xvector<4x3xi32>>, vector<5x24xi8>
+
  // CHECK: vector.transfer_write
  vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref<?x?xf32>
  vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
  // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
  vector.transfer_write %5, %arg1[%c3, %c3] {masked = [true, true]} : vector<1x1x4x3xf32>, memref<?x?xvector<4x3xf32>>
+ // Matching bitcast-style write of the value read above.
+ // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
+ vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref<?x?xvector<4x3xi32>>
  return
}