OpFoldResult warpIdOfr = makeComposedFoldedAffineApply(
rewriter, loc, d0.floorDiv(kWarpSize), {linearId});
Value warpId = getValueOrCreateConstantIndexOp(rewriter, loc, warpIdOfr);
+ // Sizes in [x, y, z] -> [z, y, x] order to properly compute strides in
+ // "row-major" order.
SmallVector<int64_t> reverseBasisSizes(
llvm::reverse(this->availableMappingSizes));
SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
SmallVector<Value> ids;
- for (AffineExpr e : delinearizingExprs)
+ // Reverse back to be in [x, y, z] order.
+ for (AffineExpr e : llvm::reverse(delinearizingExprs))
ids.push_back(makeComposedAffineApply(rewriter, loc, e, warpId));
// clang-format off
// Build the linear thread id and decompose it in the basis of
// `forallMappingSizes`.
Value linearId = buildLinearThreadId(rewriter, loc, this->blockDimsOfr);
+ // Sizes in [x, y, z] -> [z, y, x] order to properly compute strides in
+ // "row-major" order.
SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
AffineExpr d0;
bindDims(rewriter.getContext(), d0);
SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
SmallVector<Value> ids;
- for (AffineExpr e : delinearizingExprs)
+ // Reverse back to be in [x, y, z] order.
+ for (AffineExpr e : llvm::reverse(delinearizingExprs))
ids.push_back(makeComposedAffineApply(rewriter, loc, e, linearId));
// clang-format off
!type = memref<2 x 32 x f32>
!type1d = memref<32 x f32>
-// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 12) floordiv 32) floordiv 4)>
-// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<(d0, d1) -> ((((d0 + d1 * 12) floordiv 32) mod 4) floordiv 2)>
+// CHECK-DAG: #[[$MAPWX:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 12) floordiv 32) mod 3)>
+// CHECK-DAG: #[[$MAPWY:.*]] = affine_map<(d0, d1) -> ((((d0 + d1 * 12) floordiv 32) mod 6) floordiv 3)>
// CHECK-DAG: #[[$MAPLIN:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 12)>
-// CHECK-DAG: #[[$MAPLY:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 12) floordiv 20)>
-// CHECK-DAG: #[[$MAPLX:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 12) mod 20) floordiv 10)>
+// CHECK-DAG: #[[$MAPLX:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 12) mod 10)>
+// CHECK-DAG: #[[$MAPLY:.*]] = affine_map<(d0, d1) -> (((d0 + d1 * 12) mod 20) floordiv 10)>
// CHECK-LABEL: func.func @map_multi_level(
func.func @map_multi_level(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
memref.store %6, %y[%i, %j] : !type
} { mapping = [#gpu.thread<y>, #gpu.thread<x>]}
- // CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]](%[[TIDX]], %[[TIDY]])
// CHECK-DAG: %[[WIDX:.*]] = affine.apply #[[$MAPWX]](%[[TIDX]], %[[TIDY]])
+ // CHECK-DAG: %[[WIDY:.*]] = affine.apply #[[$MAPWY]](%[[TIDX]], %[[TIDY]])
// CHECK-DAG: %[[CMPX:.*]] = arith.cmpi ult, %[[WIDX]], %[[C1]] : index
// CHECK-DAG: %[[CMPY:.*]] = arith.cmpi ult, %[[WIDY]], %[[C1]] : index
- // CHECK: %[[COND:.*]] = arith.andi %[[CMPY]], %[[CMPX]] : i1
+ // CHECK: %[[COND:.*]] = arith.andi %[[CMPX]], %[[CMPY]] : i1
// CHECK: scf.if %[[COND]]
scf.forall (%i) in (%c1) {
%7 = memref.load %t[%i] : !type1d
} {mapping = [#gpu.warp<x>] }
// CHECK-DAG: %[[LIN:.*]] = affine.apply #[[$MAPLIN]](%[[TIDX]], %[[TIDY]])
+ // CHECK-DAG: %[[LIDX:.*]] = affine.apply #[[$MAPLX]](%[[TIDX]], %[[TIDY]])
// CHECK-DAG: %[[LIDY:.*]] = affine.apply #[[$MAPLY]](%[[TIDX]], %[[TIDY]])
- // CHECK-DAG: %[[LIDZ:.*]] = affine.apply #[[$MAPLX]](%[[TIDX]], %[[TIDY]])
// CHECK-DAG: %[[COND:.*]] = arith.cmpi ult, %[[LIN]], %[[C20]] : index
// CHECK: scf.if %[[COND]]
+ // CHECK: memref.load %{{.*}}[%[[LIDX]]] : memref<32xf32>
+ // CHECK: memref.store %{{.*}}[%[[LIDY]]] : memref<32xf32>
scf.forall (%i, %j) in (%c10, %c2) {
%7 = memref.load %t[%i] : !type1d
%8 = arith.addf %alpha, %7 : f32
^bb1(%arg0: !pdl.operation):
%funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
transform.gpu.map_nested_forall_to_threads %funcop
- block_dims = [12, 11, 1] warp_dims = [2, 2, 1]
+ block_dims = [12, 11, 1] warp_dims = [3, 2, 1]
: (!pdl.operation) -> ()
}