AffineMap map = op.permutation_map();
if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
return failure();
-
AffineMap permutationMap =
map.getPermutationMap(permutation, op.getContext());
if (permutationMap.isIdentity())
return failure();
- if (op.mask())
- return failure();
+
// Caluclate the map of the new read by applying the inverse permutation.
permutationMap = inversePermutation(permutationMap);
AffineMap newMap = permutationMap.compose(map);
for (auto pos : llvm::enumerate(permutation)) {
newVectorShape[pos.value()] = originalShape[pos.index()];
}
+
+ Value newMask;
+ if (op.mask()) {
+ // Remove unused dims from the permutation map. E.g.:
+ // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
+ // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
+ auto comp = compressUnusedDims(map);
+ // Get positions of remaining result dims.
+ // E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0)
+ // maskTransposeIndices = [ 2, 1, 0]
+ SmallVector<int64_t> maskTransposeIndices;
+ for (unsigned i = 0; i < comp.getNumResults(); ++i) {
+ if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
+ maskTransposeIndices.push_back(expr.getPosition());
+ }
+
+ newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.mask(),
+ maskTransposeIndices);
+ }
+
VectorType newReadType =
VectorType::get(newVectorShape, op.getVectorType().getElementType());
Value newRead = rewriter.create<vector::TransferReadOp>(
op.getLoc(), newReadType, op.source(), op.indices(), newMap,
- op.padding(), op.in_bounds() ? *op.in_bounds() : ArrayAttr());
+ op.padding(), newMask, op.in_bounds() ? *op.in_bounds() : ArrayAttr());
+
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
transposePerm);
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
- if (op.mask())
- return failure();
AffineMap map = op.permutation_map();
unsigned numLeadingBroadcast = 0;
for (auto expr : map.getResults()) {
: ArrayAttr();
Value newRead = rewriter.create<vector::TransferReadOp>(
op.getLoc(), newReadType, op.source(), op.indices(), newMap,
- op.padding(), newInBounds);
+ op.padding(), op.mask(), newInBounds);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
newRead);
return success();
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
%cst = constant 0.000000e+00 : f32
%c0 = constant 0 : index
+ %m = constant 1 : i1
- %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
-// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
+ %mask0 = splat %m : vector<7x14xi1>
+ %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+// CHECK: %[[MASK0:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
+// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
// CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
- %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
-// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
+ %mask1 = splat %m : vector<14x16xi1>
+ %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask1 {permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+// CHECK: %[[MASK1:.*]] = vector.transpose {{.*}} : vector<14x16xi1> to vector<16x14xi1>
+// CHECK: vector.transfer_read {{.*}} %[[MASK1]] {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
- %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, false, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
-// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, false, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
+ %mask2 = splat %m : vector<7x14xi1>
+ %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, true, false, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+// CHECK: %[[MASK2:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
+// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [true, false, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
// CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32>
// CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32>