Convert MemRefType to a linearized array in SPIR-V lowering.
authorMahesh Ravishankar <ravishankarm@google.com>
Tue, 3 Dec 2019 18:20:37 +0000 (10:20 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 3 Dec 2019 18:21:16 +0000 (10:21 -0800)
The SPIR-V lowering used nested !spv.arrays to represented
multi-dimensional arrays, with the hope that in-conjunction with the
layout annotations, the shape and layout of memref can be represented
directly. It is unclear though how portable this representation will
end up being. It will rely on driver compilers implementing complex
index computations faithfully. A more portable approach is to use
linearized arrays to represent memrefs and explicitly instantiate all
the index computation in SPIR-V. This gives added benefit that we can
further optimize the generated code in MLIR before generating the
SPIR-V binary.

PiperOrigin-RevId: 283571167

mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/test/Conversion/GPUToSPIRV/load-store.mlir
mlir/test/Conversion/GPUToSPIRV/loop.mlir

index 4a3d25f..ee2dfed 100644 (file)
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
+// Utility functions for operation conversion
+//===----------------------------------------------------------------------===//
+
+/// Performs the index computation to get to the element pointed to by
+/// `indices` using the layout map of `baseType`.
+
+// TODO(ravishankarm) : This method assumes that the `origBaseType` is a
+// MemRefType with AffineMap that has static strides. Handle dynamic strides
+spirv::AccessChainOp getElementPtr(OpBuilder &builder,
+                                   SPIRVTypeConverter &typeConverter,
+                                   Location loc, MemRefType origBaseType,
+                                   Value *basePtr, ArrayRef<Value *> indices) {
+  // Get base and offset of the MemRefType and verify they are static.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(origBaseType, strides, offset)) ||
+      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+    return nullptr;
+  }
+
+  auto indexType = typeConverter.getIndexType(builder.getContext());
+
+  Value *ptrLoc = nullptr;
+  assert(indices.size() == strides.size());
+  for (auto index : enumerate(indices)) {
+    Value *strideVal = builder.create<spirv::ConstantOp>(
+        loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
+    Value *update =
+        builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+    ptrLoc =
+        (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
+                : update);
+  }
+  SmallVector<Value *, 2> linearizedIndices;
+  // Add a '0' at the start to index into the struct.
+  linearizedIndices.push_back(builder.create<spirv::ConstantOp>(
+      loc, indexType, IntegerAttr::get(indexType, 0)));
+  linearizedIndices.push_back(ptrLoc);
+  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
+}
+
+//===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
 
@@ -38,6 +80,7 @@ namespace {
 /// operation. Since IndexType is not used within SPIR-V dialect, this needs
 /// special handling to make sure the result type and the type of the value
 /// attribute are consistent.
+// TODO(ravishankarm) : This should be moved into DRR.
 class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
 public:
   using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
@@ -112,6 +155,7 @@ public:
 /// the type of the return value of the replacement operation differs from
 /// that of the replaced operation. This is not handled in tablegen-based
 /// pattern specification.
+// TODO(ravishankarm) : This should be moved into DRR.
 template <typename StdOp, typename SPIRVOp>
 class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
 public:
@@ -128,36 +172,10 @@ public:
   }
 };
 
-// If 'basePtr' is the result of lowering a value of MemRefType, and 'indices'
-// are the indices used to index into the original value (for load/store),
-// perform the equivalent address calculation in SPIR-V.
-spirv::AccessChainOp getElementPtr(OpBuilder &builder, Location loc,
-                                   Value *basePtr, ArrayRef<Value *> indices,
-                                   SPIRVTypeConverter &typeConverter) {
-  // MemRefType is converted to a
-  // spirv::StructType<spirv::ArrayType<spirv:ArrayType...>>>
-  auto ptrType = basePtr->getType().cast<spirv::PointerType>();
-  (void)ptrType;
-  auto structType = ptrType.getPointeeType().cast<spirv::StructType>();
-  (void)structType;
-  assert(structType.getNumElements() == 1);
-  auto indexType = typeConverter.getIndexType(builder.getContext());
-
-  // Need to add a '0' at the beginning of the index list for accessing into the
-  // struct that wraps the nested array types.
-  Value *zero = spirv::ConstantOp::getZero(indexType, loc, &builder);
-  SmallVector<Value *, 4> accessIndices;
-  accessIndices.reserve(1 + indices.size());
-  accessIndices.push_back(zero);
-  accessIndices.append(indices.begin(), indices.end());
-  return builder.create<spirv::AccessChainOp>(loc, basePtr, accessIndices);
-}
-
 /// Convert load -> spv.LoadOp. The operands of the replaced operation are of
 /// IndexType while that of the replacement operation are of type i32. This is
 /// not supported in tablegen based pattern specification.
-// TODO(ravishankarm) : These could potentially be templated on the operation
-// being converted, since the same logic should work for linalg.load.
+// TODO(ravishankarm) : This should be moved into DRR.
 class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
 public:
   using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
@@ -166,9 +184,9 @@ public:
   matchAndRewrite(LoadOp loadOp, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
     LoadOpOperandAdaptor loadOperands(operands);
-    auto basePtr = loadOperands.memref();
-    auto loadPtr = getElementPtr(rewriter, loadOp.getLoc(), basePtr,
-                                 loadOperands.indices(), typeConverter);
+    auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(),
+                                 loadOp.memref()->getType().cast<MemRefType>(),
+                                 loadOperands.memref(), loadOperands.indices());
     rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr,
                                                /*memory_access =*/nullptr,
                                                /*alignment =*/nullptr);
@@ -177,6 +195,7 @@ public:
 };
 
 /// Convert return -> spv.Return.
+// TODO(ravishankarm) : This should be moved into DRR.
 class ReturnToSPIRVConversion final : public SPIRVOpLowering<ReturnOp> {
 public:
   using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
@@ -193,6 +212,7 @@ public:
 };
 
 /// Convert select -> spv.Select
+// TODO(ravishankarm) : This should be moved into DRR.
 class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
 public:
   using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
@@ -210,8 +230,7 @@ public:
 /// Convert store -> spv.StoreOp. The operands of the replaced operation are
 /// of IndexType while that of the replacement operation are of type i32. This
 /// is not supported in tablegen based pattern specification.
-// TODO(ravishankarm) : These could potentially be templated on the operation
-// being converted, since the same logic should work for linalg.store.
+// TODO(ravishankarm) : This should be moved into DRR.
 class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
 public:
   using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
@@ -220,11 +239,12 @@ public:
   matchAndRewrite(StoreOp storeOp, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
     StoreOpOperandAdaptor storeOperands(operands);
-    auto value = storeOperands.value();
-    auto basePtr = storeOperands.memref();
-    auto storePtr = getElementPtr(rewriter, storeOp.getLoc(), basePtr,
-                                  storeOperands.indices(), typeConverter);
-    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr, value,
+    auto storePtr =
+        getElementPtr(rewriter, typeConverter, storeOp.getLoc(),
+                      storeOp.memref()->getType().cast<MemRefType>(),
+                      storeOperands.memref(), storeOperands.indices());
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
+                                                storeOperands.value(),
                                                 /*memory_access =*/nullptr,
                                                 /*alignment =*/nullptr);
     return matchSuccess();
index baa9ed3..e3b5502 100644 (file)
@@ -86,6 +86,33 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
     return integerType.getWidth() / 8;
   } else if (auto floatType = t.dyn_cast<FloatType>()) {
     return floatType.getWidth() / 8;
+  } else if (auto memRefType = t.dyn_cast<MemRefType>()) {
+    // TODO: Layout should also be controlled by the ABI attributes. For now
+    // using the layout from MemRef.
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    if (!memRefType.hasStaticShape() ||
+        failed(getStridesAndOffset(memRefType, strides, offset))) {
+      return llvm::None;
+    }
+    // To get the size of the memref object in memory, the total size is the
+    // max(stride * dimension-size) computed for all dimensions times the size
+    // of the element.
+    auto elementSize = getTypeNumBytes(memRefType.getElementType());
+    if (!elementSize) {
+      return llvm::None;
+    }
+    auto dims = memRefType.getShape();
+    if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
+        offset == MemRefType::getDynamicStrideOrOffset() ||
+        llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+      return llvm::None;
+    }
+    int64_t memrefSize = -1;
+    for (auto shape : enumerate(dims)) {
+      memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
+    }
+    return (offset + memrefSize) * elementSize.getValue();
   }
   // TODO: Add size computation for other types.
   return llvm::None;
@@ -120,40 +147,21 @@ static Type convertStdType(Type type) {
     if (!elementSize) {
       return Type();
     }
-
-    if (!memRefType.hasStaticShape()) {
-      // TODO(ravishankarm) : Handle dynamic shapes.
-      return Type();
-    }
-
-    // Get the strides and offset.
-    int64_t offset;
-    SmallVector<int64_t, 4> strides;
-    if (failed(getStridesAndOffset(memRefType, strides, offset)) ||
-        offset == MemRefType::getDynamicStrideOrOffset() ||
-        llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
-      // TODO(ravishankarm) : Handle dynamic strides and offsets.
-      return Type();
-    }
-
-    // Convert to a multi-dimensional spv.array if size is known.
-    auto shape = memRefType.getShape();
-    assert(shape.size() == strides.size());
-    Type arrayType = elementType;
-    // TODO(antiagainst): Introduce layout as part of the shader ABI to have
-    // better separate of concerns.
-    for (int i = shape.size(); i > 0; --i) {
-      arrayType = spirv::ArrayType::get(
-          arrayType, shape[i - 1], strides[i - 1] * elementSize.getValue());
+    // TODO(ravishankarm) : Handle dynamic shapes.
+    if (memRefType.hasStaticShape()) {
+      auto arraySize = getTypeNumBytes(memRefType);
+      if (!arraySize) {
+        return Type();
+      }
+      auto arrayType = spirv::ArrayType::get(
+          elementType, arraySize.getValue() / elementSize.getValue(),
+          elementSize.getValue());
+      auto structType = spirv::StructType::get(arrayType, 0);
+      // For now initialize the storage class to StorageBuffer. This will be
+      // updated later based on whats passed in w.r.t to the ABI attributes.
+      return spirv::PointerType::get(structType,
+                                     spirv::StorageClass::StorageBuffer);
     }
-
-    // For the offset, need to wrap the array in a struct.
-    auto structType =
-        spirv::StructType::get(arrayType, offset * elementSize.getValue());
-    // For now initialize the storage class to StorageBuffer. This will be
-    // updated later based on whats passed in w.r.t to the ABI attributes.
-    return spirv::PointerType::get(structType,
-                                   spirv::StorageClass::StorageBuffer);
   }
 
   return Type();
index 2f76c5b..786a16b 100644 (file)
@@ -22,9 +22,9 @@ module attributes {gpu.container_module} {
     // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
     // CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
     // CHECK-LABEL:    func @load_store_kernel
-    // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
     // CHECK-SAME: [[ARG3:%.*]]: i32 {spirv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
     // CHECK-SAME: [[ARG4:%.*]]: i32 {spirv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
     // CHECK-SAME: [[ARG5:%.*]]: i32 {spirv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
@@ -53,18 +53,21 @@ module attributes {gpu.container_module} {
       %12 = addi %arg3, %0 : index
       // CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], [[LOCALINVOCATIONIDX]]
       %13 = addi %arg4, %3 : index
+      // CHECK: [[STRIDE1_1:%.*]] = spv.constant 4 : i32
+      // CHECK: [[OFFSET1_1:%.*]] = spv.IMul [[STRIDE1_1]], [[INDEX1]] : i32
+      // CHECK: [[STRIDE1_2:%.*]] = spv.constant 1 : i32
+      // CHECK: [[UPDATE1_2:%.*]] = spv.IMul [[STRIDE1_2]], [[INDEX2]] : i32
+      // CHECK: [[OFFSET1_2:%.*]] = spv.IAdd [[OFFSET1_1]], [[UPDATE1_2]] : i32
       // CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
-      // CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[ZERO1]], [[INDEX1]], [[INDEX2]]{{\]}}
+      // CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[ZERO1]], [[OFFSET1_2]]{{\]}}
       // CHECK-NEXT: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[PTR1]]
       %14 = load %arg0[%12, %13] : memref<12x4xf32>
-      // CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32
-      // CHECK: [[PTR2:%.*]] = spv.AccessChain [[ARG1]]{{\[}}[[ZERO2]], [[INDEX1]], [[INDEX2]]{{\]}}
+      // CHECK: [[PTR2:%.*]] = spv.AccessChain [[ARG1]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
       // CHECK-NEXT: [[VAL2:%.*]] = spv.Load "StorageBuffer" [[PTR2]]
       %15 = load %arg1[%12, %13] : memref<12x4xf32>
       // CHECK: [[VAL3:%.*]] = spv.FAdd [[VAL1]], [[VAL2]]
       %16 = addf %14, %15 : f32
-      // CHECK: [[ZERO3:%.*]] = spv.constant 0 : i32
-      // CHECK: [[PTR3:%.*]] = spv.AccessChain [[ARG2]]{{\[}}[[ZERO3]], [[INDEX1]], [[INDEX2]]{{\]}}
+      // CHECK: [[PTR3:%.*]] = spv.AccessChain [[ARG2]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
       // CHECK-NEXT: spv.Store "StorageBuffer" [[PTR3]], [[VAL3]]
       store %16, %arg2[%12, %13] : memref<12x4xf32>
       return
index ba6ee07..43a6b3e 100644 (file)
@@ -22,8 +22,12 @@ module attributes {gpu.container_module} {
       // CHECK:        [[CMP:%.*]] = spv.SLessThan [[INDVAR]], [[UB]] : i32
       // CHECK:        spv.BranchConditional [[CMP]], [[BODY:\^.*]], [[MERGE:\^.*]]
       // CHECK:      [[BODY]]:
-      // CHECK:        spv.AccessChain {{%.*}}{{\[}}{{%.*}}, [[INDVAR]]{{\]}} : {{.*}}
-      // CHECK:        spv.AccessChain {{%.*}}{{\[}}{{%.*}}, [[INDVAR]]{{\]}} : {{.*}}
+      // CHECK:        [[STRIDE1:%.*]] = spv.constant 1 : i32
+      // CHECK:        [[OFFSET1:%.*]] = spv.IMul [[STRIDE1]], [[INDVAR]] : i32
+      // CHECK:        spv.AccessChain {{%.*}}{{\[}}{{%.*}}, [[OFFSET1]]{{\]}} : {{.*}}
+      // CHECK:        [[STRIDE2:%.*]] = spv.constant 1 : i32
+      // CHECK:        [[OFFSET2:%.*]] = spv.IMul [[STRIDE2]], [[INDVAR]] : i32
+      // CHECK:        spv.AccessChain {{%.*}}{{\[}}{{%.*}}, [[OFFSET2]]{{\]}} : {{.*}}
       // CHECK:        [[INCREMENT:%.*]] = spv.IAdd [[INDVAR]], [[STEP]] : i32
       // CHECK:        spv.Branch [[HEADER]]([[INCREMENT]] : i32)
       // CHECK:      [[MERGE]]