[MLIR][SPIRVToLLVM] Additional conversions for spirv-runner
authorGeorge Mitenkov <georgemitenk0v@gmail.com>
Tue, 18 Aug 2020 15:42:23 +0000 (18:42 +0300)
committerGeorge Mitenkov <georgemitenk0v@gmail.com>
Tue, 18 Aug 2020 16:09:59 +0000 (19:09 +0300)
This patch adds more op/type conversion support
necessary for `spirv-runner`:
- EntryPoint/ExecutionMode: currently removed since we assume
having only one kernel function in the kernel module.
- StorageBuffer storage class is now supported. We are not
concerned with multithreading so this is fine for now.
- Type conversion enhanced, now regular offsets and strides
for structs and arrays are supported (based on
`VulkanLayoutUtils`).
- Support of `spc.AccessChain` that is modelled with GEP op
in LLVM dialect.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D86109

mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir
mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir

index e7c5b3c..9c2ba26 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SPIRV/LayoutUtils.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -179,6 +180,22 @@ static Value processCountOrOffset(Location loc, Value value, Type srcType,
   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
 }
 
+/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
+/// offset to LLVM struct. Otherwise, the conversion is not supported.
+static Optional<Type>
+convertStructTypeWithOffset(spirv::StructType type,
+                            LLVMTypeConverter &converter) {
+  if (type != VulkanLayoutUtils::decorateType(type))
+    return llvm::None;
+
+  auto elementsVector = llvm::to_vector<8>(
+      llvm::map_range(type.getElementTypes(), [&](Type elementType) {
+        return converter.convertType(elementType).cast<LLVM::LLVMType>();
+      }));
+  return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
+                                     /*isPacked=*/false);
+}
+
 /// Converts SPIR-V struct with no offset to packed LLVM struct.
 static Type convertStructTypePacked(spirv::StructType type,
                                     LLVMTypeConverter &converter) {
@@ -223,16 +240,22 @@ static LogicalResult replaceWithLoadOrStore(Operation *op,
 // Type conversion
 //===----------------------------------------------------------------------===//
 
-/// Converts SPIR-V array type to LLVM array. There is no modelling of array
-/// stride at the moment.
+/// Converts SPIR-V array type to LLVM array. Natural stride (according to
+/// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
+/// when converting ops that manipulate array types.
 static Optional<Type> convertArrayType(spirv::ArrayType type,
                                        TypeConverter &converter) {
-  if (type.getArrayStride() != 0)
+  unsigned stride = type.getArrayStride();
+  Type elementType = type.getElementType();
+  auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
+  if (stride != 0 &&
+      !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
     return llvm::None;
-  auto elementType =
-      converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
+
+  auto llvmElementType =
+      converter.convertType(elementType).cast<LLVM::LLVMType>();
   unsigned numElements = type.getNumElements();
-  return LLVM::LLVMType::getArrayTy(elementType, numElements);
+  return LLVM::LLVMType::getArrayTy(llvmElementType, numElements);
 }
 
 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
@@ -257,13 +280,15 @@ static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
 }
 
 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
-/// member decorations or with offset.
+/// member decorations. Also, only natural offset is supported.
 static Optional<Type> convertStructType(spirv::StructType type,
                                         LLVMTypeConverter &converter) {
   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
   type.getMemberDecorations(memberDecorations);
-  if (type.hasOffset() || !memberDecorations.empty())
+  if (!memberDecorations.empty())
     return llvm::None;
+  if (type.hasOffset())
+    return convertStructTypeWithOffset(type, converter);
   return convertStructTypePacked(type, converter);
 }
 
@@ -273,6 +298,31 @@ static Optional<Type> convertStructType(spirv::StructType type,
 
 namespace {
 
+class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = typeConverter.convertType(op.component_ptr().getType());
+    if (!dstType)
+      return failure();
+    // To use GEP we need to add a first 0 index to go through the pointer.
+    auto indices = llvm::to_vector<4>(op.indices());
+    Type indexType = op.indices().front().getType();
+    auto llvmIndexType = typeConverter.convertType(indexType);
+    if (!llvmIndexType)
+      return failure();
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
+    indices.insert(indices.begin(), zero);
+    rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(),
+                                             indices);
+    return success();
+  }
+};
+
 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
 public:
   using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
@@ -545,11 +595,14 @@ public:
     if (!dstType)
       return failure();
 
-    // Limit conversion to the current invocation only for now.
+    // Limit conversion to the current invocation only or `StorageBuffer`
+    // required by SPIR-V runner.
+    // This is okay because multiple invocations are not supported yet.
     auto storageClass = srcType.getStorageClass();
     if (storageClass != spirv::StorageClass::Input &&
         storageClass != spirv::StorageClass::Private &&
-        storageClass != spirv::StorageClass::Output) {
+        storageClass != spirv::StorageClass::Output &&
+        storageClass != spirv::StorageClass::StorageBuffer) {
       return failure();
     }
 
@@ -757,6 +810,20 @@ public:
   }
 };
 
+/// A template pattern that erases the given `SPIRVOp`.
+template <typename SPIRVOp>
+class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
+public:
+  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
 public:
   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
@@ -875,18 +942,6 @@ public:
   }
 };
 
-class MergePattern : public SPIRVToLLVMConversion<spirv::MergeOp> {
-public:
-  using SPIRVToLLVMConversion<spirv::MergeOp>::SPIRVToLLVMConversion;
-
-  LogicalResult
-  matchAndRewrite(spirv::MergeOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
 /// Converts `spv.selection` with `spv.BranchConditional` in its header block.
 /// All blocks within selection should be reachable for conversion to succeed.
 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
@@ -1266,11 +1321,18 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       ConstantScalarAndVectorPattern,
 
       // Control Flow ops
-      BranchConversionPattern, BranchConditionalConversionPattern, LoopPattern,
-      SelectionPattern, MergePattern,
+      BranchConversionPattern, BranchConditionalConversionPattern,
+      FunctionCallPattern, LoopPattern, SelectionPattern,
+      ErasePattern<spirv::MergeOp>,
+
+      // Entry points and execution mode
+      // Module generated from SPIR-V could have other "internal" functions, so
+      // having entry point and execution mode metadat can be useful. For now,
+      // simply remove them.
+      // TODO: Support EntryPoint/ExecutionMode properly.
+      ErasePattern<spirv::EntryPointOp>, ErasePattern<spirv::ExecutionModeOp>,
 
       // Function Call op
-      FunctionCallPattern,
 
       // GLSL extended instruction set ops
       DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
@@ -1295,8 +1357,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
       NotPattern<spirv::LogicalNotOp>,
 
       // Memory ops
-      AddressOfPattern, GlobalVariablePattern, LoadStorePattern<spirv::LoadOp>,
-      LoadStorePattern<spirv::StoreOp>, VariablePattern,
+      AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
+      LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
+      VariablePattern,
 
       // Miscellaneous ops
       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
index 51a734c..4402a51 100644 (file)
@@ -1,6 +1,31 @@
 // RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
+// spv.AccessChain
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @access_chain
+func @access_chain() -> () {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+  %0 = spv.constant 1: i32
+  %1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+  // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %[[ONE]], %[[ONE]]] : (!llvm.ptr<struct<packed (float, array<4 x float>)>>, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm.ptr<float>
+  %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>, i32, i32
+  return
+}
+
+// CHECK-LABEL: @access_chain_array
+func @access_chain_array(%arg0 : i32) -> () {
+  %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+  // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %{{.*}}] : (!llvm.ptr<array<4 x array<4 x float>>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<array<4 x float>>
+  %1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32
+  %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
 // spv.globalVariable and spv._address_of
 //===----------------------------------------------------------------------===//
 
index 2e74485..d54b916 100644 (file)
@@ -21,6 +21,23 @@ func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>) {
 }
 
 //===----------------------------------------------------------------------===//
+// spv.EntryPoint and spv.ExecutionMode
+//===----------------------------------------------------------------------===//
+
+//      CHECK: module {
+// CHECK-NEXT:   llvm.func @empty
+// CHECK-NEXT:     llvm.return
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+spv.module Logical GLSL450 {
+  spv.func @empty() -> () "None" {
+    spv.Return
+  }
+  spv.EntryPoint "GLCompute" @empty
+  spv.ExecutionMode @empty "LocalSize", 1, 1, 1
+}
+
+//===----------------------------------------------------------------------===//
 // spv.Undef
 //===----------------------------------------------------------------------===//
 
index 96fb9f4..87f0bd8 100644 (file)
@@ -1,21 +1,14 @@
 // RUN: mlir-opt %s -convert-spirv-to-llvm -verify-diagnostics -split-input-file
 
 // expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
-spv.func @array_with_stride(%arg: !spv.array<4 x f32, stride=4>) -> () "None" {
+spv.func @array_with_unnatural_stride(%arg: !spv.array<4 x f32, stride=8>) -> () "None" {
   spv.Return
 }
 
 // -----
 
 // expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
-spv.func @struct_with_offset1(%arg: !spv.struct<i32[0], i32[4]>) -> () "None" {
-  spv.Return
-}
-
-// -----
-
-// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}}
-spv.func @struct_with_offset2(%arg: !spv.struct<i32[0], i32[8]>) -> () "None" {
+spv.func @struct_with_unnatural_offset(%arg: !spv.struct<i32[0], i32[8]>) -> () "None" {
   spv.Return
 }
 
index d6618a7..454b5b3 100644 (file)
@@ -5,7 +5,10 @@
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @array(!llvm.array<16 x float>, !llvm.array<32 x vec<4 x float>>)
-func @array(!spv.array<16xf32>, !spv.array< 32 x vector<4xf32> >) -> ()
+func @array(!spv.array<16 x f32>, !spv.array< 32 x vector<4xf32> >) -> ()
+
+// CHECK-LABEL: @array_with_natural_stride(!llvm.array<16 x float>)
+func @array_with_natural_stride(!spv.array<16 x f32, stride=4>) -> ()
 
 //===----------------------------------------------------------------------===//
 // Pointer type
@@ -36,3 +39,6 @@ func @struct(!spv.struct<f64>) -> ()
 
 // CHECK-LABEL: @struct_nested(!llvm.struct<packed (i32, struct<packed (i64, i32)>)>)
 func @struct_nested(!spv.struct<i32, !spv.struct<i64, i32>>)
+
+// CHECK-LABEL: @struct_with_natural_offset(!llvm.struct<(i8, i32)>)
+func @struct_with_natural_offset(!spv.struct<i8[0], i32[4]>) -> ()