[mlir][LinAlg][Transform][GPU] Add GPU memory hierarchy to the transform.promote op
authorAmir Mohammad Tavakkoli <tavakkoli.amirmohammad@gmail.com>
Mon, 27 Feb 2023 15:28:54 +0000 (16:28 +0100)
committerAlex Zinenko <zinenko@google.com>
Mon, 27 Feb 2023 15:33:58 +0000 (16:33 +0100)
In this patch we are adding the support of copying a a `memref.subview` to the shared or private memory in GPU. The global to shared memory copy is adopted from codes implemented in IREE (https://github.com/iree-org/iree), but the private memory copy part has not been implemented in IREE. This patch enables transferring a subview from `global->shared`, `global->private`, and `shared->private`.

Our final aim is to provide a copy layout as an affine map to the `transform.promote` op to support transpose memory copy. This map is a permutation of the original affine index map. Although this has been implemented and user can copy data to arbitrary layout , this attempt is not included in this patch since we have still problem with `linalg.generic` operations to change their index map to the transformed index map. You can find more in following links ([[ https://github.com/tavakkoliamirmohammad/iree-llvm-fork/commit/4fd5f93355951ad0fb338858393ff409bd9c62f8 | Initial attempt to support layout map in promote op in transform dialect ]]) ([[ https://github.com/tavakkoliamirmohammad/iree-llvm-fork/commit/9062b5849f91d4defb84996392b71087dadf7a8c | Fix data transpose in shared memory ]])

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D144666

mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/test/Dialect/Linalg/promote.mlir

index c03bdb6..3b261ac 100644 (file)
@@ -85,4 +85,23 @@ def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [
   }];
 }
 
+
+def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space", [
+  DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
+  let parameters = (ins
+    EnumParameter<GPU_AddressSpaceEnum>:$address_space
+  );
+  let assemblyFormat = "`<` params `>`";
+  let description = [{
+    An attribute that allows defining memory hierarchy for GPU devices.
+
+    GPU Memory has three memory space, global, workgroup, and private. The global memory
+    is visible to all workitems and workgroups, the workgroup memory is only available for workitems
+    within a workgroup, and private memory is only visible to a single workitem. This attribute indicates
+    that using memory hiearchy is desired. It can be consumed by lowering to
+    move data to a specific address space in GPU code.
+  }];
+}
+
+
 #endif // GPU_DEVICE_MAPPING_ATTR
index c534978..41c5daf 100644 (file)
@@ -765,6 +765,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
 // PromoteOp
 //===----------------------------------------------------------------------===//
 
+
 def PromoteOp : Op<Transform_Dialect, "structured.promote",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
     TransformOpInterface, TransformEachOpTrait]> {
@@ -791,6 +792,7 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
                        DefaultValuedAttr<BoolArrayAttr, "{}">:$use_full_tile_buffers,
                        UnitAttr:$use_full_tiles_by_default,
                        UnitAttr:$use_alloca,
+                       OptionalAttr<DeviceMappingArrayAttr>:$mapping,
                        OptionalAttr<I64Attr>:$alignment);
   let results = (outs PDL_Operation:$transformed);
 
index 8e77b54..ea645e8 100644 (file)
@@ -393,6 +393,32 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
 FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                     const LinalgPromotionOptions &options);
 
+/// Allocate the subview in the GPU workgroup memory.
+Optional<Value> allocateWorkgroupMemory(OpBuilder &builder,
+                                        memref::SubViewOp subview,
+                                        ArrayRef<Value> sizeBounds,
+                                        DataLayout &);
+
+/// In case of GPU group memory there is no need to deallocate.
+LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value /*buffer*/);
+
+/// Create Memref copy operations and add gpu barrier guards before and after
+/// the copy operation to ensure data integrity.
+LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst);
+
+/// Allocate the subview in the GPU private memory.
+Optional<Value> allocateGPUPrivateMemory(OpBuilder &builder,
+                                         memref::SubViewOp subview,
+                                         ArrayRef<Value> sizeBounds,
+                                         DataLayout &);
+
+/// Normal copy to between src and dst.
+LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst);
+
+/// In case of GPU private memory there is no need to deallocate since the
+/// memory is freed when going outside of the scope.
+LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
+
 /// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
 /// are used to vectorize this operation. `inputVectorSizes` must match the rank
 /// of the iteration space of the operation and the sizes must be smaller or
index 159892e..0ec5877 100644 (file)
@@ -50,6 +50,10 @@ int64_t GPUThreadMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getThread());
 }
 
+int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getAddressSpace());
+}
+
 //===----------------------------------------------------------------------===//
 // MMAMatrixType
 //===----------------------------------------------------------------------===//
index 2d10238..5a8f981 100644 (file)
@@ -1802,6 +1802,35 @@ transform::PromoteOp::applyToOne(LinalgOp target,
   if (getAlignment().has_value())
     promotionOptions = promotionOptions.setAlignment(*getAlignment());
 
+  if (getMapping().has_value()) {
+    // The mapping should only contain an element
+    auto mapping = *getMapping();
+    if (mapping.size() > 1)
+      return emitDefaultDefiniteFailure(target);
+
+    auto addressSpace = mapping[0].cast<gpu::GPUMemorySpaceMappingAttr>();
+
+    if (addressSpace.getAddressSpace() ==
+        gpu::GPUDialect::getWorkgroupAddressSpace()) {
+      promotionOptions =
+          promotionOptions
+              .setAllocationDeallocationFns(allocateWorkgroupMemory,
+                                            deallocateWorkgroupMemory)
+              .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
+              .setUseFullTileBuffers({false, false});
+    } else if (addressSpace.getAddressSpace() ==
+               gpu::GPUDialect::getPrivateAddressSpace()) {
+      promotionOptions =
+          promotionOptions
+              .setAllocationDeallocationFns(allocateGPUPrivateMemory,
+                                            deallocateGPUPrivateMemory)
+              .setCopyInOutFns(copyToGPUPrivateMemory, copyToGPUPrivateMemory)
+              .setUseFullTileBuffers({false, false});
+    } else {
+      return emitDefaultDefiniteFailure(target);
+    }
+  }
+
   if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
     return emitDefaultDefiniteFailure(target);
 
index fd41ed3..abc5e00 100644 (file)
@@ -13,6 +13,8 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -397,3 +399,87 @@ mlir::linalg::promoteSubViews(OpBuilder &builder, LinalgOp linalgOp,
     return failure();
   return res;
 }
+
+/// Allocate the given subview to a memory address space in GPU by creating a
+/// allocation operation and setting the memref type address space to desired
+/// address space.
+static Optional<Value> allocateSubviewGPUMemoryInAddressSpace(
+    OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
+    gpu::AddressSpace addressSpace) {
+  OpBuilder::InsertionGuard guard(builder);
+
+  func::FuncOp funcOp = subview->getParentOfType<func::FuncOp>();
+  if (!funcOp)
+    return std::nullopt;
+
+  // The subview size bounds are expected to be constant; they specify the shape
+  // of the allocation.
+  SmallVector<int64_t> shape;
+  for (Value bound : sizeBounds) {
+    APInt value;
+    if (!matchPattern(bound, m_ConstantInt(&value)))
+      return std::nullopt;
+    shape.push_back(value.getSExtValue());
+  }
+
+  builder.setInsertionPoint(&funcOp.front(), funcOp.front().begin());
+  auto type = MemRefType::get(
+      shape, subview.getType().getElementType(), MemRefLayoutAttrInterface{},
+      gpu::AddressSpaceAttr::get(builder.getContext(), addressSpace));
+  Value buffer;
+  if (addressSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) {
+    buffer = builder.create<memref::AllocOp>(funcOp.getLoc(), type);
+  } else if (addressSpace == gpu::GPUDialect::getPrivateAddressSpace()) {
+    buffer = builder.create<memref::AllocaOp>(funcOp.getLoc(), type);
+  } else {
+    return std::nullopt;
+  }
+  return buffer;
+}
+
+/// Allocate the subview in the GPU workgroup memory.
+Optional<Value> mlir::linalg::allocateWorkgroupMemory(
+    OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
+    DataLayout &) {
+  return allocateSubviewGPUMemoryInAddressSpace(
+      builder, subview, sizeBounds,
+      gpu::GPUDialect::getWorkgroupAddressSpace());
+}
+
+/// In case of GPU group memory there is no need to deallocate.
+LogicalResult mlir::linalg::deallocateWorkgroupMemory(OpBuilder &,
+                                                      Value /*buffer*/) {
+  return success();
+}
+
+/// Create Memref copy operations and add gpu barrier guards before and after
+/// the copy operation to ensure data integrity.
+LogicalResult mlir::linalg::copyToWorkgroupMemory(OpBuilder &b, Value src,
+                                                  Value dst) {
+  b.create<gpu::BarrierOp>(src.getLoc());
+  Operation *copyOp = b.create<memref::CopyOp>(src.getLoc(), src, dst);
+  b.create<gpu::BarrierOp>(copyOp->getLoc());
+  return success();
+}
+
+/// Allocate the subview in the GPU private memory.
+Optional<Value> mlir::linalg::allocateGPUPrivateMemory(
+    OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
+    DataLayout &) {
+  return allocateSubviewGPUMemoryInAddressSpace(
+      builder, subview, sizeBounds, gpu::GPUDialect::getPrivateAddressSpace());
+}
+
+/// Normal copy to between src and dst.
+LogicalResult mlir::linalg::copyToGPUPrivateMemory(OpBuilder &b, Value src,
+                                                   Value dst) {
+  Operation *copyOp = b.create<memref::CopyOp>(src.getLoc(), src, dst);
+  return success();
+}
+
+/// In case of GPU private memory there is no need to deallocate since the
+/// memory is freed when going outside of the scope.
+LogicalResult mlir::linalg::deallocateGPUPrivateMemory(OpBuilder &,
+                                                       Value /*buffer*/) {
+  return success();
+}
\ No newline at end of file
index 085c1b7..b34a86e 100644 (file)
@@ -143,6 +143,94 @@ transform.sequence failures(propagate) {
 }
 
 // -----
+func.func @gemm_shared(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
+{
+   linalg.matmul ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+               outs(%c: memref<?x?xf32>)
+   return
+}
+
+// CHECK: func @gemm_shared
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK: %[[alloc_A:.*]] = memref.alloc() : memref<16x16xf32, #gpu.address_space<workgroup>>
+// CHECK: %[[alloc_B:.*]] = memref.alloc() : memref<16x16xf32, #gpu.address_space<workgroup>>
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK:         %[[subview_A:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK:         %[[subview_B:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK:         %[[subview_C:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+
+// CHECK:         %[[shared_A:.*]] = memref.subview %[[alloc_B]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x16xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>
+// CHECK:         %[[shared_B:.*]] = memref.subview %[[alloc_A]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x16xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>
+
+// CHECK-NEXT:    gpu.barrier
+// CHECK-NEXT:    memref.copy %[[subview_A]], %[[shared_A]] :  memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>
+// CHECK-NEXT:    gpu.barrier
+
+// CHECK-NEXT:    gpu.barrier
+// CHECK-NEXT:    memref.copy %[[subview_B]], %[[shared_B]] :  memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>
+// CHECK-NEXT:    gpu.barrier
+
+// CHECK:         linalg.matmul ins(%[[shared_A]], %[[shared_B]]{{.*}} outs(%[[subview_C]]
+
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1, %loops:3 = transform.structured.tile %0 [16, 16, 16] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
+  %2 = transform.structured.promote %1 { operands_to_promote = [0, 1], mapping = [#gpu.memory_space<workgroup>] }
+}
+
+
+// -----
+
+func.func @gemm_private(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
+{
+   linalg.matmul ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+               outs(%c: memref<?x?xf32>)
+   return
+}
+
+// CHECK: func @gemm_private
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK: %[[alloc_A:.*]] = memref.alloca() : memref<16x16xf32, #gpu.address_space<private>>
+// CHECK: %[[alloc_B:.*]] = memref.alloca() : memref<16x16xf32, #gpu.address_space<private>>
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK:     scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK:         %[[subview_A:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK:         %[[subview_B:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK:         %[[subview_C:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+
+// CHECK:         %[[private_A:.*]] = memref.subview %[[alloc_B]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x16xf32, #gpu.address_space<private>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>
+// CHECK:         %[[private_B:.*]] = memref.subview %[[alloc_A]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x16xf32, #gpu.address_space<private>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>
+
+// CHECK-NEXT:    memref.copy %[[subview_A]], %[[private_A]] :  memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>
+// CHECK-NEXT:    memref.copy %[[subview_B]], %[[private_B]] :  memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>
+
+// CHECK:         linalg.matmul ins(%[[private_A]], %[[private_B]]{{.*}} outs(%[[subview_C]]
+
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1, %loops:3 = transform.structured.tile %0 [16, 16, 16] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
+  %2 = transform.structured.promote %1 { operands_to_promote = [0, 1], mapping = [#gpu.memory_space<private>] }
+}
+
+
+// -----
 
 #map6 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map7 = affine_map<(d0, d1, d2) -> (d1, d2)>