def Vector_ExtractMapOp :
Vector_Op<"extract_map", [NoSideEffect]>,
- Arguments<(ins AnyVector:$vector, Index:$id)>,
+ Arguments<(ins AnyVector:$vector, Variadic<Index>:$ids)>,
Results<(outs AnyVector)> {
let summary = "vector extract map operation";
let description = [{
- Takes an 1-D vector and extracts a sub-part of the vector starting at id.
+ Takes an N-D vector and extracts a sub-part of the vector starting at id
+ along each dimension.
+
+ The dimensions associated with each element of `ids` used to extract are
+ implicitly deduced from the destination type. For each dimension the
+ multiplicity is the source dimension size divided by the destination
+ dimension size; each dimension with a multiplicity greater than 1 is
+ associated with the next id, following the order of `ids`.
+ For example if the source type is `vector<64x4x32xf32>` and the destination
+ type is `vector<4x4x2xf32>`, the first id maps to dimension 0 and the second
+ id to dimension 2.
Similarly to vector.tuple_get, this operation is used for progressive
lowering and should be folded away before converting to LLVM.
```mlir
%ev = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
+ %ev1 = vector.extract_map %v1[%id1, %id2] : vector<64x4x32xf32>
+ to vector<4x4x2xf32>
```
}];
let builders = [
- OpBuilderDAG<(ins "Value":$vector, "Value":$id, "int64_t":$multiplicity)>];
+ OpBuilderDAG<(ins "Value":$vector, "ValueRange":$ids,
+ "ArrayRef<int64_t>":$multiplicity,
+ "AffineMap":$map)>];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
return vector().getType().cast<VectorType>();
VectorType getResultType() {
return getResult().getType().cast<VectorType>();
}
- int64_t multiplicity() {
- return getSourceVectorType().getNumElements() /
- getResultType().getNumElements();
- }
+ void getMultiplicity(SmallVectorImpl<int64_t> &multiplicity);
+ AffineMap map();
}];
let assemblyFormat = [{
- $vector `[` $id `]` attr-dict `:` type($vector) `to` type(results)
+ $vector `[` $ids `]` attr-dict `:` type($vector) `to` type(results)
}];
let hasFolder = 1;
def Vector_InsertMapOp :
Vector_Op<"insert_map", [NoSideEffect, AllTypesMatch<["dest", "result"]>]>,
- Arguments<(ins AnyVector:$vector, AnyVector:$dest, Index:$id)>,
+ Arguments<(ins AnyVector:$vector, AnyVector:$dest, Variadic<Index>:$ids)>,
Results<(outs AnyVector:$result)> {
let summary = "vector insert map operation";
let description = [{
- Inserts a 1-D vector and within a larger vector starting at id. The new
+ Inserts an N-D vector within a larger vector starting at id. The new
vector created will have the same size as the destination operand vector.
+ The dimension associated to each element of `ids` used to insert is
+ implicitly deduced from the source type (see `ExtractMapOp` for details).
+ For example if source type is `vector<4x4x2xf32>` and the destination type
+ is `vector<64x4x32xf32>`, the first id maps to dimension 0 and the second id
+ to dimension 2.
+
Similarly to vector.tuple_get, this operation is used for progressive
lowering and should be folded away before converting to LLVM.
```mlir
%v = vector.insert_map %ev, %v[%id] : vector<1xf32> into vector<32xf32>
+ %v1 = vector.insert_map %ev1, %v1[%arg0, %arg1] : vector<2x4x1xf32>
+ into vector<64x4x32xf32>
```
}];
let builders = [OpBuilderDAG<(ins "Value":$vector, "Value":$dest,
- "Value":$id, "int64_t":$multiplicity)>];
+ "ValueRange":$ids)>];
let extraClassDeclaration = [{
VectorType getSourceVectorType() {
return vector().getType().cast<VectorType>();
VectorType getResultType() {
return getResult().getType().cast<VectorType>();
}
- int64_t multiplicity() {
- return getResultType().getNumElements() /
- getSourceVectorType().getNumElements();
- }
+ // Return a map indicating the dimension mapping to the given ids.
+ AffineMap map();
}];
let assemblyFormat = [{
- $vector `,` $dest `[` $id `]` attr-dict
+ $vector `,` $dest `[` $ids `]` attr-dict
`:` type($vector) `into` type($result)
}];
}
InsertMapOp insert;
};
-/// Distribute a 1D vector pointwise operation over a range of given IDs taking
+/// Distribute an N-D vector pointwise operation over a range of given ids taking
/// *all* values in [0 .. multiplicity - 1] (e.g. loop induction variable or
/// SPMD id). This transformation only inserts
/// vector.extract_map/vector.insert_map. It is meant to be used with
/// %v = addf %a, %b : vector<32xf32>
/// %ev = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// %nv = vector.insert_map %ev, %v[%id] : vector<1xf32> into vector<32xf32>
-Optional<DistributeOps> distributPointwiseVectorOp(OpBuilder &builder,
- Operation *op, Value id,
- int64_t multiplicity);
+Optional<DistributeOps>
+distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
+ ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
+ const AffineMap &map);
/// Canonicalize an extra element using the result of a pointwise operation.
/// Transforms:
/// %v = addf %a, %b : vector<32xf32>
//===----------------------------------------------------------------------===//
void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
- Value vector, Value id, int64_t multiplicity) {
+ Value vector, ValueRange ids,
+ ArrayRef<int64_t> multiplicity,
+ AffineMap permutationMap) { // Builds the op, deriving the result type by dividing each mapped source dimension by its multiplicity.
+ assert(ids.size() == multiplicity.size() &&
+ ids.size() == permutationMap.getNumResults()); // One id and one multiplicity are expected per map result.
+ assert(permutationMap.isProjectedPermutation()); // NOTE(review): presumably what makes the AffineDimExpr cast below safe -- confirm.
VectorType type = vector.getType().cast<VectorType>();
- VectorType resultType = VectorType::get(type.getNumElements() / multiplicity,
- type.getElementType());
- ExtractMapOp::build(builder, result, resultType, vector, id);
+ SmallVector<int64_t, 4> newShape(type.getShape().begin(),
+ type.getShape().end()); // Start from the source shape; only dimensions named by the map are changed.
+ for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
+ AffineExpr expr = permutationMap.getResult(i);
+ auto dim = expr.cast<AffineDimExpr>();
+ newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i]; // Integer division; divisibility is not checked in this builder.
+ }
+ VectorType resultType = VectorType::get(newShape, type.getElementType());
+ ExtractMapOp::build(builder, result, resultType, vector, ids); // Forward to the overload that takes an explicit result type.
}
static LogicalResult verify(ExtractMapOp op) { // Checks rank equality, per-dimension divisibility, and that the id count matches the number of differing dimensions.
- if (op.getSourceVectorType().getShape().size() != 1 ||
- op.getResultType().getShape().size() != 1)
- return op.emitOpError("expects source and destination vectors of rank 1");
- if (op.getSourceVectorType().getNumElements() %
- op.getResultType().getNumElements() !=
- 0)
+ if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
return op.emitOpError(
- "source vector size must be a multiple of destination vector size");
+ "expected source and destination vectors of same rank");
+ unsigned numId = 0; // Counts dimensions whose source and result sizes differ.
+ for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) {
+ if (op.getSourceVectorType().getDimSize(i) %
+ op.getResultType().getDimSize(i) !=
+ 0) // Each source dimension must be an exact multiple of the matching result dimension.
+ return op.emitOpError("source vector dimensions must be a multiple of "
+ "destination vector dimensions");
+ if (op.getSourceVectorType().getDimSize(i) !=
+ op.getResultType().getDimSize(i))
+ numId++;
+ }
+ if (numId != op.ids().size()) // Exactly one id is required per dimension whose size changes.
+ return op.emitOpError("expected number of ids must match the number of "
+ "dimensions distributed");
return success();
}
OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) { // Folds an extract_map of an insert_map back to the inserted vector when types and ids match.
auto insert = vector().getDefiningOp<vector::InsertMapOp>();
- if (insert == nullptr || multiplicity() != insert.multiplicity() ||
- id() != insert.id())
+ if (insert == nullptr || getType() != insert.vector().getType() ||
+ ids() != insert.ids()) // No fold unless the source is an insert_map whose inserted vector has this op's result type and the same ids.
return {};
return insert.vector();
}
+void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) { // Appends, in dimension order, source size / result size for every dimension whose sizes differ.
+ assert(multiplicity.empty()); // The output vector must be empty on entry.
+ for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
+ if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i)) // Dimensions with equal sizes contribute no entry.
+ multiplicity.push_back(getSourceVectorType().getDimSize(i) /
+ getResultType().getDimSize(i));
+ }
+}
+
+template <typename MapOp>
+AffineMap calculateImplicitMap(MapOp op) { // Builds a map with one dim expression per vector dimension whose source and result sizes differ.
+ SmallVector<AffineExpr, 4> perm;
+ // Check which dimensions have different source and result sizes (i.e. a
+ // multiplicity greater than 1) and associate them with the ids in order.
+ for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
+ if (op.getSourceVectorType().getDimSize(i) !=
+ op.getResultType().getDimSize(i))
+ perm.push_back(getAffineDimExpr(i, op.getContext()));
+ }
+ auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
+ op.getContext()); // Dim count is the source rank; results are the collected dim expressions.
+ return map;
+}
+
+AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); } // Map deduced from the source and result vector types via calculateImplicitMap.
+
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
void InsertMapOp::build(OpBuilder &builder, OperationState &result,
- Value vector, Value dest, Value id,
- int64_t multiplicity) {
- VectorType type = vector.getType().cast<VectorType>();
- VectorType resultType = VectorType::get(type.getNumElements() * multiplicity,
- type.getElementType());
- InsertMapOp::build(builder, result, resultType, vector, dest, id);
+ Value vector, Value dest, ValueRange ids) { // Builds the op using the destination operand's type as the result type.
+ InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids); // Forward to the overload that takes an explicit result type.
}
static LogicalResult verify(InsertMapOp op) { // Checks rank equality, per-dimension divisibility, and that the id count matches the number of differing dimensions.
- if (op.getSourceVectorType().getShape().size() != 1 ||
- op.getResultType().getShape().size() != 1)
- return op.emitOpError("expected source and destination vectors of rank 1");
- if (op.getResultType().getNumElements() %
- op.getSourceVectorType().getNumElements() !=
- 0)
+ if (op.getSourceVectorType().getRank() != op.getResultType().getRank())
return op.emitOpError(
- "destination vector size must be a multiple of source vector size");
+ "expected source and destination vectors of same rank");
+ unsigned numId = 0; // Counts dimensions whose result and source sizes differ.
+ for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) {
+ if (op.getResultType().getDimSize(i) %
+ op.getSourceVectorType().getDimSize(i) !=
+ 0) // Each result dimension must be an exact multiple of the matching source dimension.
+ return op.emitOpError(
+ "destination vector size must be a multiple of source vector size");
+ if (op.getResultType().getDimSize(i) !=
+ op.getSourceVectorType().getDimSize(i))
+ numId++;
+ }
+ if (numId != op.ids().size()) // Exactly one id is required per dimension whose size changes.
+ return op.emitOpError("expected number of ids must match the number of "
+ "dimensions distributed");
return success();
}
+AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); } // Map deduced from the source and result vector types via calculateImplicitMap.
+
//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//
SmallVector<Value, 4> extractOperands;
for (OpOperand &operand : definedOp->getOpOperands())
extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
- loc, operand.get(), extract.id(), extract.multiplicity()));
+ loc, extract.getResultType(), operand.get(), extract.ids()));
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
rewriter.replaceOp(extract, newOp->getResult(0));
return success();
}
-Optional<mlir::vector::DistributeOps>
-mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
- Value id, int64_t multiplicity) {
+Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
+ OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
+ ArrayRef<int64_t> multiplicity, const AffineMap &map) {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointAfter(op);
Location loc = op->getLoc();
return {};
Value result = op->getResult(0);
VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
- // Currently only support distributing 1-D vectors of size multiple of the
- // given multiplicty. To handle more sizes we would need to support masking.
- if (!type || type.getRank() != 1 || type.getNumElements() % multiplicity != 0)
+ if (!type || map.getNumResults() != multiplicity.size())
return {};
+ // For each dimension being distributed check that the size is a multiple of
+ // the multiplicity. To handle more sizes we would need to support masking.
+ unsigned multiplictyCount = 0;
+ for (auto exp : map.getResults()) {
+ auto affinExp = exp.dyn_cast<AffineDimExpr>();
+ if (!affinExp || affinExp.getPosition() >= type.getRank() ||
+ type.getDimSize(affinExp.getPosition()) %
+ multiplicity[multiplictyCount++] !=
+ 0)
+ return {};
+ }
DistributeOps ops;
ops.extract =
- builder.create<vector::ExtractMapOp>(loc, result, id, multiplicity);
- ops.insert = builder.create<vector::InsertMapOp>(loc, ops.extract, result, id,
- multiplicity);
+ builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
+ ops.insert =
+ builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
return ops;
}
using mlir::edsc::op::operator*;
using namespace mlir::edsc::intrinsics;
SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
- indices.back() =
- indices.back() +
- (extract.id() *
- std_constant_index(extract.getResultType().getDimSize(0)));
+ AffineMap map = extract.map();
+ unsigned idCount = 0;
+ for (auto expr : map.getResults()) {
+ unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+ indices[pos] =
+ indices[pos] +
+ extract.ids()[idCount++] *
+ std_constant_index(extract.getResultType().getDimSize(pos));
+ }
Value newRead = vector_transfer_read(extract.getType(), read.memref(),
indices, read.permutation_map(),
- read.padding(), ArrayAttr());
+ read.padding(), read.maskedAttr());
Value dest = rewriter.create<ConstantOp>(
read.getLoc(), read.getType(), rewriter.getZeroAttr(read.getType()));
- newRead = rewriter.create<vector::InsertMapOp>(
- read.getLoc(), newRead, dest, extract.id(), extract.multiplicity());
+ newRead = rewriter.create<vector::InsertMapOp>(read.getLoc(), newRead, dest,
+ extract.ids());
rewriter.replaceOp(read, newRead);
return success();
}
using namespace mlir::edsc::intrinsics;
SmallVector<Value, 4> indices(write.indices().begin(),
write.indices().end());
- indices.back() =
- indices.back() +
- (insert.id() *
- std_constant_index(insert.getSourceVectorType().getDimSize(0)));
+ AffineMap map = insert.map();
+ unsigned idCount = 0;
+ for (auto expr : map.getResults()) {
+ unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+ indices[pos] =
+ indices[pos] +
+ insert.ids()[idCount++] *
+ std_constant_index(insert.getSourceVectorType().getDimSize(pos));
+ }
vector_transfer_write(insert.vector(), write.memref(), indices,
- write.permutation_map(), ArrayAttr());
+ write.permutation_map(), write.maskedAttr());
rewriter.eraseOp(write);
return success();
}
// -----
-func @extract_map_rank(%v: vector<2x32xf32>, %id : index) {
- // expected-error@+1 {{'vector.extract_map' op expects source and destination vectors of rank 1}}
- %0 = vector.extract_map %v[%id] : vector<2x32xf32> to vector<2x1xf32>
+func @extract_map_rank(%v: vector<32xf32>, %id : index) {
+ // expected-error@+1 {{'vector.extract_map' op expected source and destination vectors of same rank}}
+ %0 = vector.extract_map %v[%id] : vector<32xf32> to vector<2x1xf32>
}
// -----
func @extract_map_size(%v: vector<63xf32>, %id : index) {
- // expected-error@+1 {{'vector.extract_map' op source vector size must be a multiple of destination vector size}}
+ // expected-error@+1 {{'vector.extract_map' op source vector dimensions must be a multiple of destination vector dimensions}}
%0 = vector.extract_map %v[%id] : vector<63xf32> to vector<2xf32>
}
// -----
-func @insert_map_rank(%v: vector<2x1xf32>, %v1: vector<2x32xf32>, %id : index) {
- // expected-error@+1 {{'vector.insert_map' op expected source and destination vectors of rank 1}}
- %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<2x32xf32>
+func @extract_map_id(%v: vector<2x32xf32>, %id : index) {
+ // expected-error@+1 {{'vector.extract_map' op expected number of ids must match the number of dimensions distributed}}
+ %0 = vector.extract_map %v[%id] : vector<2x32xf32> to vector<1x1xf32>
+}
+
+// -----
+
+func @insert_map_rank(%v: vector<2x1xf32>, %v1: vector<32xf32>, %id : index) {
+ // expected-error@+1 {{'vector.insert_map' op expected source and destination vectors of same rank}}
+ %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<32xf32>
}
// -----
// expected-error@+1 {{'vector.insert_map' op destination vector size must be a multiple of source vector size}}
%0 = vector.insert_map %v, %v1[%id] : vector<3xf32> into vector<64xf32>
}
+
+// -----
+
+func @insert_map_id(%v: vector<2x1xf32>, %v1: vector<4x32xf32>, %id : index) {
+ // expected-error@+1 {{'vector.insert_map' op expected number of ids must match the number of dimensions distributed}}
+ %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32>
+}
}
// CHECK-LABEL: @extract_insert_map
-func @extract_insert_map(%v: vector<32xf32>, %id : index) -> vector<32xf32> {
+func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>,
+ %id0 : index, %id1 : index) -> (vector<32xf32>, vector<16x32xf32>) {
// CHECK: %[[V:.*]] = vector.extract_map %{{.*}}[%{{.*}}] : vector<32xf32> to vector<2xf32>
- %vd = vector.extract_map %v[%id] : vector<32xf32> to vector<2xf32>
+ %vd = vector.extract_map %v[%id0] : vector<32xf32> to vector<2xf32>
+ // CHECK: %[[V1:.*]] = vector.extract_map %{{.*}}[%{{.*}}, %{{.*}}] : vector<16x32xf32> to vector<4x2xf32>
+ %vd2 = vector.extract_map %v2[%id0, %id1] : vector<16x32xf32> to vector<4x2xf32>
// CHECK: %[[R:.*]] = vector.insert_map %[[V]], %{{.*}}[%{{.*}}] : vector<2xf32> into vector<32xf32>
- %r = vector.insert_map %vd, %v[%id] : vector<2xf32> into vector<32xf32>
- // CHECK: return %[[R]] : vector<32xf32>
- return %r : vector<32xf32>
+ %r = vector.insert_map %vd, %v[%id0] : vector<2xf32> into vector<32xf32>
+ // CHECK: %[[R1:.*]] = vector.insert_map %[[V1]], %{{.*}}[%{{.*}}, %{{.*}}] : vector<4x2xf32> into vector<16x32xf32>
+ %r2 = vector.insert_map %vd2, %v2[%id0, %id1] : vector<4x2xf32> into vector<16x32xf32>
+ // CHECK: return %[[R]], %[[R1]] : vector<32xf32>, vector<16x32xf32>
+ return %r, %r2 : vector<32xf32>, vector<16x32xf32>
}
-// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32 -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s
// CHECK-LABEL: func @distribute_vector_add
// CHECK-SAME: (%[[ID:.*]]: index
// CHECK-NEXT: %[[ADD1:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
// CHECK-NEXT: %[[EXC:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
// CHECK-NEXT: %[[ADD2:.*]] = addf %[[ADD1]], %[[EXC]] : vector<1xf32>
-// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] : vector<1xf32>, memref<32xf32>
+// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] {{.*}} : vector<1xf32>, memref<32xf32>
// CHECK-NEXT: return
func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
%c0 = constant 0 : index
// -----
-// CHECK-DAG: #[[MAP0:map[0-9]*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>
// CHECK: func @vector_add_cycle
// CHECK-SAME: (%[[ID:.*]]: index
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID2]]], %{{.*}} : memref<64xf32>, vector<2xf32>
// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2xf32>
// CHECK-NEXT: %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID]]]
-// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]]] : vector<2xf32>, memref<64xf32>
+// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]]] {{.*}} : vector<2xf32>, memref<64xf32>
// CHECK-NEXT: return
func @vector_add_cycle(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
%c0 = constant 0 : index
return
}
+// -----
+
+// CHECK-LABEL: func @distribute_vector_add_3d
+// CHECK-SAME: (%[[ID0:.*]]: index, %[[ID1:.*]]: index
+// CHECK-NEXT: %[[ADDV:.*]] = addf %{{.*}}, %{{.*}} : vector<64x4x32xf32>
+// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID0]], %[[ID1]]] : vector<64x4x32xf32> to vector<2x4x1xf32>
+// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID0]], %[[ID1]]] : vector<64x4x32xf32> to vector<2x4x1xf32>
+// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
+// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID0]], %[[ID1]]] : vector<2x4x1xf32> into vector<64x4x32xf32>
+// CHECK-NEXT: return %[[INS]] : vector<64x4x32xf32>
+func @distribute_vector_add_3d(%id0 : index, %id1 : index,
+ %A: vector<64x4x32xf32>, %B: vector<64x4x32xf32>) -> vector<64x4x32xf32> {
+ %0 = addf %A, %B : vector<64x4x32xf32>
+ return %0: vector<64x4x32xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>
+
+// CHECK: func @vector_add_transfer_3d
+// CHECK-SAME: (%[[ID_0:.*]]: index, %[[ID_1:.*]]: index
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[ID1:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
+// CHECK-NEXT: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID1]], %[[C0]], %[[ID_1]]], %{{.*}} : memref<64x64x64xf32>, vector<2x4x1xf32>
+// CHECK-NEXT: %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
+// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID2]], %[[C0]], %[[ID_1]]], %{{.*}} : memref<64x64x64xf32>, vector<2x4x1xf32>
+// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
+// CHECK-NEXT: %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
+// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]], %[[C0]], %[[ID_1]]] {{.*}} : vector<2x4x1xf32>, memref<64x64x64xf32>
+// CHECK-NEXT: return
+func @vector_add_transfer_3d(%id0 : index, %id1 : index, %A: memref<64x64x64xf32>,
+ %B: memref<64x64x64xf32>, %C: memref<64x64x64xf32>) {
+ %c0 = constant 0 : index
+ %cf0 = constant 0.0 : f32
+ %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0: memref<64x64x64xf32>, vector<64x4x32xf32>
+ %b = vector.transfer_read %B[%c0, %c0, %c0], %cf0: memref<64x64x64xf32>, vector<64x4x32xf32>
+ %acc = addf %a, %b: vector<64x4x32xf32>
+ vector.transfer_write %acc, %C[%c0, %c0, %c0]: vector<64x4x32xf32>, memref<64x64x64xf32>
+ return
+}
+
registry.insert<VectorDialect>();
registry.insert<AffineDialect>();
}
- Option<int32_t> multiplicity{
- *this, "distribution-multiplicity",
- llvm::cl::desc("Set the multiplicity used for distributing vector"),
- llvm::cl::init(32)};
+ ListOption<int32_t> multiplicity{
+ *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::desc("Set the multiplicity used for distributing vector")};
+
void runOnFunction() override {
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
FuncOp func = getFunction();
func.walk([&](AddFOp op) {
OpBuilder builder(op);
- Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
- builder, op.getOperation(), func.getArgument(0), multiplicity);
- if (ops.hasValue()) {
- SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
- op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
+ if (auto vecType = op.getType().dyn_cast<VectorType>()) {
+ SmallVector<int64_t, 2> mul;
+ SmallVector<AffineExpr, 2> perm;
+ SmallVector<Value, 2> ids;
+ unsigned count = 0;
+ // Remove the multiplicity of 1 and calculate the affine map based on
+ // the multiplicity.
+ SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
+ for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
+ if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) {
+ mul.push_back(m[i]);
+ ids.push_back(func.getArgument(count++));
+ perm.push_back(getAffineDimExpr(i, ctx));
+ }
+ }
+ auto map = AffineMap::get(op.getType().cast<VectorType>().getRank(), 0,
+ perm, ctx);
+ Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
+ builder, op.getOperation(), ids, mul, map);
+ if (ops.hasValue()) {
+ SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
+ op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
+ extractOp);
+ }
}
});
patterns.insert<PointwiseExtractPattern>(ctx);
for (Operation *it : dependentOps) {
it->moveBefore(forOp.getBody()->getTerminator());
}
+ auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
// break up the original op and let the patterns propagate.
Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
- builder, op.getOperation(), forOp.getInductionVar(), multiplicity);
+ builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity},
+ map);
if (ops.hasValue()) {
SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);