Enhance GPU To SPIR-V conversion to support builtins and load/store ops.
authorMahesh Ravishankar <ravishankarm@google.com>
Tue, 27 Aug 2019 17:49:53 +0000 (10:49 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Aug 2019 17:50:23 +0000 (10:50 -0700)
To support a conversion of a simple load-compute-store kernel from GPU
dialect to SPIR-V dialect, the conversion of operations like
"gpu.block_dim", "gpu.thread_id" which allow threads to get the launch
conversion is needed. In SPIR-V these are specified as global
variables with builin attributes. This CL adds support to specify
builtin variables in SPIR-V conversion framework. This is used to
convert the relevant operations from GPU dialect to SPIR-V dialect.
Also add support for conversion of load/store operation in Standard
dialect to SPIR-V dialect.
To simplify the conversion add a method to build a spv.AccessChain
operation that automatically determines the return type based on the
base pointer type and the indices provided.

PiperOrigin-RevId: 265718525

mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Conversion/GPUToSPIRV/builtins.mlir [new file with mode: 0644]
mlir/test/Conversion/GPUToSPIRV/load_store.mlir [new file with mode: 0644]

index adfd83b..25a710f 100644 (file)
 #ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
 #define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
 
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Support/StringExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 
+class LoadOp;
+class ReturnOp;
+class StoreOp;
 namespace spirv {
 class SPIRVDialect;
 }
@@ -63,7 +68,7 @@ public:
   LogicalResult convertSignatureArg(unsigned inputNo, Type type,
                                     SignatureConversion &result) override;
 
-  /// Get the basic type converter.
+  /// Gets the basic type converter.
   SPIRVBasicTypeConverter *getBasicTypeConverter() const {
     return basicTypeConverter;
   }
@@ -80,17 +85,98 @@ public:
         typeConverter(typeConverter) {}
 
 protected:
-  // Type lowering class.
+  /// Gets the global variable associated with a builtin and add
+  /// it if it doesnt exist.
+  Value *loadFromBuiltinVariable(Operation *op, spirv::BuiltIn builtin,
+                                 ConversionPatternRewriter &rewriter) const {
+    auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
+    if (!moduleOp) {
+      op->emitError("expected operation to be within a SPIR-V module");
+      return nullptr;
+    }
+    auto varOp =
+        getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, rewriter);
+    auto ptr = rewriter
+                   .create<spirv::AddressOfOp>(op->getLoc(), varOp.type(),
+                                               rewriter.getSymbolRefAttr(varOp))
+                   .pointer();
+    return rewriter.create<spirv::LoadOp>(
+        op->getLoc(),
+        ptr->getType().template cast<spirv::PointerType>().getPointeeType(),
+        ptr, /*memory_access =*/nullptr, /*alignment =*/nullptr);
+  }
+
+  /// Type lowering class.
   SPIRVTypeConverter &typeConverter;
+
+private:
+  /// Look through all global variables in `moduleOp` and check if there is a
+  /// spv.globalVariable that has the same `builtin` attribute.
+  spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
+                                             spirv::BuiltIn builtin) const {
+    for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
+      if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(convertToSnakeCase(
+              stringifyDecoration(spirv::Decoration::BuiltIn)))) {
+        auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
+        if (varBuiltIn && varBuiltIn.getValue() == builtin) {
+          return varOp;
+        }
+      }
+    }
+    return nullptr;
+  }
+
+  /// Gets name of global variable for a buitlin.
+  std::string getBuiltinVarName(spirv::BuiltIn builtin) const {
+    return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() +
+           "__";
+  }
+
+  /// Gets or inserts a global variable for a builtin within a module.
+  spirv::GlobalVariableOp
+  getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
+                             spirv::BuiltIn builtin,
+                             ConversionPatternRewriter &builder) const {
+    if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
+      return varOp;
+    }
+    auto ip = builder.saveInsertionPoint();
+    builder.setInsertionPointToStart(&moduleOp.getBlock());
+    auto name = getBuiltinVarName(builtin);
+    spirv::GlobalVariableOp newVarOp;
+    switch (builtin) {
+    case spirv::BuiltIn::NumWorkgroups:
+    case spirv::BuiltIn::WorkgroupSize:
+    case spirv::BuiltIn::WorkgroupId:
+    case spirv::BuiltIn::LocalInvocationId:
+    case spirv::BuiltIn::GlobalInvocationId: {
+      auto ptrType = spirv::PointerType::get(
+          builder.getVectorType({3}, builder.getIntegerType(32)),
+          spirv::StorageClass::Input);
+      newVarOp = builder.create<spirv::GlobalVariableOp>(
+          loc, builder.getTypeAttr(ptrType), builder.getStringAttr(name),
+          nullptr);
+      newVarOp.setAttr(
+          convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)),
+          builder.getStringAttr(stringifyBuiltIn(builtin)));
+      break;
+    }
+    default:
+      emitError(loc, "unimplemented builtin variable generation for ")
+          << stringifyBuiltIn(builtin);
+    }
+    builder.restoreInsertionPoint(ip);
+    return newVarOp;
+  }
 };
 
-/// Method to legalize a function as a non-entry function.
+/// Legalizes a function as a non-entry function.
 LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
                             SPIRVTypeConverter *typeConverter,
                             ConversionPatternRewriter &rewriter,
                             FuncOp &newFuncOp);
 
-/// Method to legalize a function as an entry function.
+/// Legalizes a function as an entry function.
 LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
                                    SPIRVTypeConverter *typeConverter,
                                    ConversionPatternRewriter &rewriter,
index 5fccf1b..6aad600 100644 (file)
@@ -113,6 +113,9 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
   let results = (outs
     SPV_AnyPtr:$component_ptr
   );
+
+  let builders = [OpBuilder<[{Builder *builder, OperationState *state,
+                              Value *basePtr, ArrayRef<Value *> indices}]>];
 }
 
 // -----
index ff6af83..06b2498 100644 (file)
@@ -1,4 +1,4 @@
-//===- GPUToSPIRV.cp - MLIR SPIR-V lowering passes ------------------------===//
+//===- GPUToSPIRV.cpp - MLIR SPIR-V lowering passes -----------------------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -29,6 +29,18 @@ using namespace mlir;
 
 namespace {
 
+/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
+/// builin variables.
+template <typename OpTy, spirv::BuiltIn builtin>
+class LaunchConfigConversion : public SPIRVOpLowering<OpTy> {
+public:
+  using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
 /// attribute gpu.kernel) within a spv.module.
 class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
@@ -41,6 +53,33 @@ public:
 };
 } // namespace
 
+template <typename OpTy, spirv::BuiltIn builtin>
+PatternMatchResult LaunchConfigConversion<OpTy, builtin>::matchAndRewrite(
+    Operation *op, ArrayRef<Value *> operands,
+    ConversionPatternRewriter &rewriter) const {
+  auto dimAttr = op->getAttrOfType<StringAttr>("dimension");
+  if (!dimAttr) {
+    return this->matchFailure();
+  }
+  int32_t index = 0;
+  if (dimAttr.getValue() == "x") {
+    index = 0;
+  } else if (dimAttr.getValue() == "y") {
+    index = 1;
+  } else if (dimAttr.getValue() == "z") {
+    index = 2;
+  } else {
+    return this->matchFailure();
+  }
+
+  // SPIR-V invocation builtin variables are a vector of type <3xi32>
+  auto spirvBuiltin = this->loadFromBuiltinVariable(op, builtin, rewriter);
+  rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+      op, rewriter.getIntegerType(32), spirvBuiltin,
+      rewriter.getI32ArrayAttr({index}));
+  return this->matchSuccess();
+}
+
 PatternMatchResult
 KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                                     ConversionPatternRewriter &rewriter) const {
@@ -106,7 +145,13 @@ void GPUToSPIRVPass::runOnModule() {
   SPIRVBasicTypeConverter basicTypeConverter(context);
   SPIRVTypeConverter typeConverter(&basicTypeConverter);
   OwningRewritePatternList patterns;
-  patterns.insert<KernelFnConversion>(context, typeConverter);
+  patterns.insert<
+      KernelFnConversion,
+      LaunchConfigConversion<gpu::BlockDim, spirv::BuiltIn::WorkgroupSize>,
+      LaunchConfigConversion<gpu::BlockId, spirv::BuiltIn::WorkgroupId>,
+      LaunchConfigConversion<gpu::GridDim, spirv::BuiltIn::NumWorkgroups>,
+      LaunchConfigConversion<gpu::ThreadId, spirv::BuiltIn::LocalInvocationId>>(
+      context, typeConverter);
   populateStandardToSPIRVPatterns(context, patterns);
 
   ConversionTarget target(*context);
index b7dfff4..e3bcc04 100644 (file)
@@ -39,15 +39,22 @@ Type SPIRVBasicTypeConverter::convertType(Type t) {
     return t;
   }
 
+  if (auto indexType = t.dyn_cast<IndexType>()) {
+    // Return I32 for index types.
+    return IntegerType::get(32, t.getContext());
+  }
+
   if (auto memRefType = t.dyn_cast<MemRefType>()) {
     if (memRefType.hasStaticShape()) {
-      // Convert MemrefType to spv.array if size is known.
+      // Convert MemrefType to a multi-dimensional spv.array if size is known.
+      auto elementType = memRefType.getElementType();
+      for (auto size : reverse(memRefType.getShape())) {
+        elementType = spirv::ArrayType::get(elementType, size);
+      }
       // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
       // to support other Storage Classes.
-      return spirv::PointerType::get(
-          spirv::ArrayType::get(memRefType.getElementType(),
-                                memRefType.getNumElements()),
-          spirv::StorageClass::StorageBuffer);
+      return spirv::PointerType::get(elementType,
+                                     spirv::StorageClass::StorageBuffer);
     }
   }
   return Type();
@@ -68,8 +75,12 @@ SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
   if (!convertedType)
     return failure();
   // For arguments to entry functions, convert the type into a pointer type if
-  // it is already not one.
-  if (!convertedType.isa<spirv::PointerType>()) {
+  // it is already not one, unless the original type was an index type.
+  // TODO(ravishankarm): For arguments that are of index type, keep the
+  // arguments as the scalar converted type, i.e. i32. These are still not
+  // handled effectively. These are potentially best handled as specialization
+  // constants.
+  if (!convertedType.isa<spirv::PointerType>() && !type.isa<IndexType>()) {
     // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
     // to support other Storage classes.
     convertedType = spirv::PointerType::get(convertedType,
@@ -143,29 +154,40 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
   if (!module) {
     return funcOp.emitError("expected op to be within a spv.module");
   }
-  OpBuilder builder(module.getOperation()->getRegion(0));
+  auto ip = rewriter.saveInsertionPoint();
+  rewriter.setInsertionPointToStart(&module.getBlock());
   SmallVector<Attribute, 4> interface;
   for (auto &convertedArgType :
        llvm::enumerate(signatureConverter.getConvertedTypes())) {
+    // TODO(ravishankarm) : The arguments to the converted function are either
+    // spirv::PointerType or i32 type, the latter due to conversion of index
+    // type to i32. Eventually entry function should be of signature
+    // void(void). Arguments converted to spirv::PointerType, will be made
+    // variables and those converted to i32 will be made specialization
+    // constants. Latter is not implemented.
+    if (!convertedArgType.value().isa<spirv::PointerType>()) {
+      continue;
+    }
     std::string varName = funcOp.getName().str() + "_arg_" +
                           std::to_string(convertedArgType.index());
-    auto variableOp = builder.create<spirv::GlobalVariableOp>(
-        funcOp.getLoc(), builder.getTypeAttr(convertedArgType.value()),
-        builder.getStringAttr(varName), nullptr);
-    variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0));
+    auto variableOp = rewriter.create<spirv::GlobalVariableOp>(
+        funcOp.getLoc(), rewriter.getTypeAttr(convertedArgType.value()),
+        rewriter.getStringAttr(varName), nullptr);
+    variableOp.setAttr("descriptor_set", rewriter.getI32IntegerAttr(0));
     variableOp.setAttr("binding",
-                       builder.getI32IntegerAttr(convertedArgType.index()));
-    interface.push_back(builder.getSymbolRefAttr(variableOp.sym_name()));
+                       rewriter.getI32IntegerAttr(convertedArgType.index()));
+    interface.push_back(rewriter.getSymbolRefAttr(variableOp.sym_name()));
   }
   // Create an entry point instruction for this function.
   // TODO(ravishankarm) : Add execution mode for the entry function
-  builder.setInsertionPoint(&(module.getBlock().back()));
-  builder.create<spirv::EntryPointOp>(
+  rewriter.setInsertionPoint(&(module.getBlock().back()));
+  rewriter.create<spirv::EntryPointOp>(
       funcOp.getLoc(),
-      builder.getI32IntegerAttr(
+      rewriter.getI32IntegerAttr(
           static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
-      builder.getSymbolRefAttr(newFuncOp.getName()),
-      builder.getArrayAttr(interface));
+      rewriter.getSymbolRefAttr(newFuncOp.getName()),
+      rewriter.getArrayAttr(interface));
+  rewriter.restoreInsertionPoint(ip);
   return success();
 }
 } // namespace mlir
@@ -175,6 +197,56 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
 //===----------------------------------------------------------------------===//
 
 namespace {
+
+/// Convert integer binary operations to SPIR-V operations. Cannot use tablegen
+/// for this. If the integer operation is on variables of IndexType, the type of
+/// the return value of the replacement operation differs from that of the
+/// replaced operation. This is not handled in tablegen-based pattern
+/// specification.
+template <typename StdOp, typename SPIRVOp>
+class IntegerOpConversion final : public ConversionPattern {
+public:
+  IntegerOpConversion(MLIRContext *context)
+      : ConversionPattern(StdOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.template replaceOpWithNewOp<SPIRVOp>(
+        op, operands[0]->getType(), operands, ArrayRef<NamedAttribute>());
+    return this->matchSuccess();
+  }
+};
+
+/// Convert load -> spv.LoadOp. The operands of the replaced operation are of
+/// IndexType while that of the replacement operation are of type i32. This is
+/// not suppored in tablegen based pattern specification.
+// TODO(ravishankarm) : These could potentially be templated on the operation
+// being converted, since the same logic should work for linalg.load.
+class LoadOpConversion final : public ConversionPattern {
+public:
+  LoadOpConversion(MLIRContext *context)
+      : ConversionPattern(LoadOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    LoadOpOperandAdaptor loadOperands(operands);
+    auto basePtr = loadOperands.memref();
+    auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>();
+    if (!ptrType) {
+      return matchFailure();
+    }
+    auto loadPtr = rewriter.create<spirv::AccessChainOp>(
+        op->getLoc(), basePtr, loadOperands.indices());
+    auto loadPtrType = loadPtr.getType().cast<spirv::PointerType>();
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(
+        op, loadPtrType.getPointeeType(), loadPtr, /*memory_access =*/nullptr,
+        /*alignment =*/nullptr);
+    return matchSuccess();
+  }
+};
+
 /// Convert return -> spv.Return.
 class ReturnToSPIRVConversion : public ConversionPattern {
 public:
@@ -191,6 +263,35 @@ public:
   }
 };
 
+/// Convert store -> spv.StoreOp. The operands of the replaced operation are of
+/// IndexType while that of the replacement operation are of type i32. This is
+/// not suppored in tablegen based pattern specification.
+// TODO(ravishankarm) : These could potentially be templated on the operation
+// being converted, since the same logic should work for linalg.store.
+class StoreOpConversion final : public ConversionPattern {
+public:
+  StoreOpConversion(MLIRContext *context)
+      : ConversionPattern(StoreOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    StoreOpOperandAdaptor storeOperands(operands);
+    auto value = storeOperands.value();
+    auto basePtr = storeOperands.memref();
+    auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>();
+    if (!ptrType) {
+      return matchFailure();
+    }
+    auto storePtr = rewriter.create<spirv::AccessChainOp>(
+        op->getLoc(), basePtr, storeOperands.indices());
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(op, storePtr, value,
+                                                /*memory_access =*/nullptr,
+                                                /*alignment =*/nullptr);
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -203,6 +304,8 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      OwningRewritePatternList &patterns) {
   populateWithGenerated(context, &patterns);
   // Add the return op conversion.
-  patterns.insert<ReturnToSPIRVConversion>(context);
+  patterns.insert<IntegerOpConversion<AddIOp, spirv::IAddOp>,
+                  IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion,
+                  ReturnToSPIRVConversion, StoreOpConversion>(context);
 }
 } // namespace mlir
index 4cfd559..b37eee8 100644 (file)
@@ -43,6 +43,7 @@ multiclass BinaryOpPattern<Op src, SPV_Op tgt> {
   }
 }
 
+defm : BinaryOpPattern<AddFOp, SPV_FAddOp>;
 defm : BinaryOpPattern<MulFOp, SPV_FMulOp>;
 
 #endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD
index fef9c0b..aaa7ed5 100644 (file)
@@ -316,7 +316,7 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter *printer,
 
 static Type getElementPtrType(Type type, ArrayRef<Value *> indices,
                               Location baseLoc) {
-  if (!indices.size()) {
+  if (indices.empty()) {
     emitError(baseLoc, "'spv.AccessChain' op expected at least "
                        "one index ");
     return nullptr;
@@ -372,6 +372,13 @@ static Type getElementPtrType(Type type, ArrayRef<Value *> indices,
   return spirv::PointerType::get(resultType, resultStorageClass);
 }
 
+void spirv::AccessChainOp::build(Builder *builder, OperationState *state,
+                                 Value *basePtr, ArrayRef<Value *> indices) {
+  auto type = getElementPtrType(basePtr->getType(), indices, state->location);
+  assert(type && "Unable to deduce return type based on basePtr and indices");
+  build(builder, state, type, basePtr, indices);
+}
+
 static ParseResult parseAccessChainOp(OpAsmParser *parser,
                                       OperationState *state) {
   OpAsmParser::OperandType ptrInfo;
diff --git a/mlir/test/Conversion/GPUToSPIRV/builtins.mlir b/mlir/test/Conversion/GPUToSPIRV/builtins.mlir
new file mode 100644 (file)
index 0000000..ce9421e
--- /dev/null
@@ -0,0 +1,113 @@
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv %s -o - | FileCheck %s
+
+func @builtin() {
+  %c0 = constant 1 : index
+  "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_workgroup_id_x} : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL:  spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId")
+func @builtin_workgroup_id_x()
+  attributes {gpu.kernel} {
+  // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPID]]
+  // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+  // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+  %0 = "gpu.block_id"() {dimension = "x"} : () -> index
+  return
+}
+
+// -----
+
+func @builtin() {
+  %c0 = constant 1 : index
+  "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_workgroup_id_y} : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL:  spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId")
+func @builtin_workgroup_id_y()
+  attributes {gpu.kernel} {
+  // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPID]]
+  // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+  // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}1 : i32{{\]}}
+  %0 = "gpu.block_id"() {dimension = "y"} : () -> index
+  return
+}
+
+// -----
+
+func @builtin() {
+  %c0 = constant 1 : index
+  "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_workgroup_id_z} : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL:  spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId")
+func @builtin_workgroup_id_z()
+  attributes {gpu.kernel} {
+  // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPID]]
+  // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+  // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}2 : i32{{\]}}
+  %0 = "gpu.block_id"() {dimension = "z"} : () -> index
+  return
+}
+
+// -----
+
+func @builtin() {
+  %c0 = constant 1 : index
+  "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_workgroup_size_x} : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL:  spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[WORKGROUPSIZE:@.*]] built_in("WorkgroupSize")
+func @builtin_workgroup_size_x()
+  attributes {gpu.kernel} {
+  // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPSIZE]]
+  // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+  // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+  %0 = "gpu.block_dim"() {dimension = "x"} : () -> index
+  return
+}
+
+// -----
+
+func @builtin() {
+  %c0 = constant 1 : index
+  "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_local_id_x} : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL:  spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[LOCALINVOCATIONID:@.*]] built_in("LocalInvocationId")
+func @builtin_local_id_x()
+  attributes {gpu.kernel} {
+  // CHECK: [[ADDRESS:%.*]] = spv._address_of [[LOCALINVOCATIONID]]
+  // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+  // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+  %0 = "gpu.thread_id"() {dimension = "x"} : () -> index
+  return
+}
+
+// -----
+
+func @builtin() {
+  %c0 = constant 1 : index
+  "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = @builtin_num_workgroups_x} : (index, index, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL:  spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable [[NUMWORKGROUPS:@.*]] built_in("NumWorkgroups")
+func @builtin_num_workgroups_x()
+  attributes {gpu.kernel} {
+  // CHECK: [[ADDRESS:%.*]] = spv._address_of [[NUMWORKGROUPS]]
+  // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]]
+  // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}}
+  %0 = "gpu.grid_dim"() {dimension = "x"} : () -> index
+  return
+}
diff --git a/mlir/test/Conversion/GPUToSPIRV/load_store.mlir b/mlir/test/Conversion/GPUToSPIRV/load_store.mlir
new file mode 100644 (file)
index 0000000..cc8ed07
--- /dev/null
@@ -0,0 +1,52 @@
+// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
+
+func @load_store(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) {
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %0 = subi %c12, %c0 : index
+  %c1 = constant 1 : index
+  %c0_0 = constant 0 : index
+  %c4 = constant 4 : index
+  %1 = subi %c4, %c0_0 : index
+  %c1_1 = constant 1 : index
+  %c1_2 = constant 1 : index
+  "gpu.launch_func"(%0, %c1_2, %c1_2, %1, %c1_2, %c1_2, %arg0, %arg1, %arg2, %c0, %c0_0, %c1, %c1_1) {kernel = @load_store_kernel} : (index, index, index, index, index, index, memref<12x4xf32>, memref<12x4xf32>, memref<12x4xf32>, index, index, index, index) -> ()
+  return
+}
+
+// CHECK-LABEL: spv.module "Logical" "VulkanKHR"
+// CHECK: spv.globalVariable {{@.*}} bind(0, 0) : [[TYPE1:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
+// CHECK-NEXT: spv.globalVariable {{@.*}} bind(0, 1) : [[TYPE2:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
+// CHECK-NEXT: spv.globalVariable {{@.*}} bind(0, 2) : [[TYPE3:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
+// CHECK: func @load_store_kernel([[ARG0:%.*]]: [[TYPE1]], [[ARG1:%.*]]: [[TYPE2]], [[ARG2:%.*]]: [[TYPE3]], [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32, [[ARG5:%.*]]: i32, [[ARG6:%.*]]: i32)
+func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index)
+  attributes  {gpu.kernel} {
+  %0 = "gpu.block_id"() {dimension = "x"} : () -> index
+  %1 = "gpu.block_id"() {dimension = "y"} : () -> index
+  %2 = "gpu.block_id"() {dimension = "z"} : () -> index
+  %3 = "gpu.thread_id"() {dimension = "x"} : () -> index
+  %4 = "gpu.thread_id"() {dimension = "y"} : () -> index
+  %5 = "gpu.thread_id"() {dimension = "z"} : () -> index
+  %6 = "gpu.grid_dim"() {dimension = "x"} : () -> index
+  %7 = "gpu.grid_dim"() {dimension = "y"} : () -> index
+  %8 = "gpu.grid_dim"() {dimension = "z"} : () -> index
+  %9 = "gpu.block_dim"() {dimension = "x"} : () -> index
+  %10 = "gpu.block_dim"() {dimension = "y"} : () -> index
+  %11 = "gpu.block_dim"() {dimension = "z"} : () -> index
+  // CHECK: [[INDEX1:%.*]] = spv.IAdd [[ARG3]], {{%.*}}
+  %12 = addi %arg3, %0 : index
+  // CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], {{%.*}}
+  %13 = addi %arg4, %3 : index
+  // CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+  // CHECK-NEXT: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[PTR1]]
+  %14 = load %arg0[%12, %13] : memref<12x4xf32>
+  // CHECK: [[PTR2:%.*]] = spv.AccessChain [[ARG1]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+  // CHECK-NEXT: [[VAL2:%.*]] = spv.Load "StorageBuffer" [[PTR2]]
+  %15 = load %arg1[%12, %13] : memref<12x4xf32>
+  // CHECK: [[VAL3:%.*]] = spv.FAdd [[VAL1]], [[VAL2]]
+  %16 = addf %14, %15 : f32
+  // CHECK: [[PTR3:%.*]] = spv.AccessChain [[ARG2]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
+  // CHECK-NEXT: spv.Store "StorageBuffer" [[PTR3]], [[VAL3]]
+  store %16, %arg2[%12, %13] : memref<12x4xf32>
+  return
+}
\ No newline at end of file