Introduce attributes that specify the final ABI for a spirv::ModuleOp.
authorMahesh Ravishankar <ravishankarm@google.com>
Mon, 25 Nov 2019 18:38:31 +0000 (10:38 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 25 Nov 2019 19:19:56 +0000 (11:19 -0800)
To simplify the lowering into SPIR-V, while still respecting the ABI
requirements of SPIR-V/Vulkan, split the process into two
1) While lowering a function to SPIR-V (when the function is an entry
   point function), allow specifying attributes on arguments and
   function itself that describe the ABI of the function.
2) Add a pass that materializes the ABI described in the function.

Two attributes are needed.
1) Attribute on arguments of the entry point function that describe
   the descriptor_set, binding, storage class, etc, of the
   spv.globalVariable this argument will be replaced by
2) Attribute on function that specifies workgroup size, etc. (for now
   only workgroup size).

Add the pass -spirv-lower-abi-attrs to materialize the ABI described
by the attributes.

This change makes the SPIRVBasicTypeConverter class unnecessary and is
removed, further simplifying the SPIR-V lowering path.

PiperOrigin-RevId: 282387587

21 files changed:
mlir/include/mlir/Dialect/SPIRV/LayoutUtils.h
mlir/include/mlir/Dialect/SPIRV/Passes.h
mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td [new file with mode: 0644]
mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp [new file with mode: 0644]
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/Conversion/GPUToSPIRV/load-store.mlir
mlir/test/Conversion/GPUToSPIRV/simple.mlir
mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir [new file with mode: 0644]
mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir [new file with mode: 0644]

index e21aa43..a65ee41 100644 (file)
@@ -56,10 +56,10 @@ class VulkanLayoutUtils {
 public:
   using Size = uint64_t;
 
-  /// Returns a new type with layout info. Assigns the type size in bytes to the
-  /// `size`. Assigns the type alignment in bytes to the `alignment`.
-  static Type decorateType(spirv::StructType structType, Size &size,
-                           Size &alignment);
+  /// Returns a new StructType with layout info. Assigns the type size in bytes
+  /// to the `size`. Assigns the type alignment in bytes to the `alignment`.
+  static spirv::StructType decorateType(spirv::StructType structType,
+                                        Size &size, Size &alignment);
   /// Checks whether a type is legal in terms of Vulkan layout info
   /// decoration. A type is dynamically illegal if it's a composite type in the
   /// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant Storage
index d245ef0..fe029ff 100644 (file)
 namespace mlir {
 namespace spirv {
 
-// Creates a module pass that converts composite types used by objects in the
-// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage
-// classes with layout information.
-//
-// Right now this pass only supports Vulkan layout rules.
+class ModuleOp;
+/// Creates a module pass that converts composite types used by objects in the
+/// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage
+/// classes with layout information.
+/// Right now this pass only supports Vulkan layout rules.
 std::unique_ptr<OpPassBase<mlir::ModuleOp>>
 createDecorateSPIRVCompositeTypeLayoutPass();
 
+/// Creates a module pass that lowers the ABI attributes specified during SPIR-V
+/// Lowering. Specifically,
+/// 1) Creates the global variables for arguments of entry point function using
+/// the specification in the ABI attributes for each argument.
+/// 2) Inserts the EntryPointOp and the ExecutionModeOp for entry point
+/// functions using the specification in the EntryPointAttr.
+std::unique_ptr<OpPassBase<spirv::ModuleOp>> createLowerABIAttributesPass();
+
 } // namespace spirv
 } // namespace mlir
 
index 417d0e1..93bcc4a 100644 (file)
 #define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
 
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/Support/StringExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SetVector.h"
 
 namespace mlir {
 
-/// Type conversion from Standard Types to SPIR-V Types.
-class SPIRVBasicTypeConverter : public TypeConverter {
-public:
-  /// Converts types to SPIR-V supported types.
-  virtual Type convertType(Type t);
-};
-
 /// Converts a function type according to the requirements of a SPIR-V entry
 /// function. The arguments need to be converted to spv.GlobalVariables of
 /// spv.ptr types so that they could be bound by the runtime.
 class SPIRVTypeConverter final : public TypeConverter {
 public:
-  explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
-      : basicTypeConverter(basicTypeConverter) {}
+  using TypeConverter::TypeConverter;
 
   /// Converts types to SPIR-V types using the basic type converter.
   Type convertType(Type t) override;
-
-  /// Gets the basic type converter.
-  Type convertBasicType(Type t) { return basicTypeConverter->convertType(t); }
-
-private:
-  SPIRVBasicTypeConverter *basicTypeConverter;
 };
 
 /// Base class to define a conversion pattern to translate Ops into SPIR-V.
@@ -70,6 +57,8 @@ protected:
 private:
 };
 
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h.inc"
+
 namespace spirv {
 /// Returns a value that represents a builtin variable value within the SPIR-V
 /// module.
@@ -77,14 +66,26 @@ Value *getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin,
                                OpBuilder &builder);
 
 /// Legalizes a function as an entry function.
-LogicalResult lowerAsEntryFunction(FuncOp funcOp,
-                                   SPIRVTypeConverter *typeConverter,
-                                   ConversionPatternRewriter &rewriter,
-                                   FuncOp &newFuncOp);
-
-/// Finalizes entry function legalization. Inserts the spv.EntryPoint and
-/// spv.ExecutionMode ops.
-LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder);
+FuncOp lowerAsEntryFunction(FuncOp funcOp, SPIRVTypeConverter &typeConverter,
+                            ConversionPatternRewriter &rewriter,
+                            ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo,
+                            spirv::EntryPointABIAttr entryPointInfo);
+
+/// Attribute name for specifying argument ABI information.
+StringRef getInterfaceVarABIAttrName();
+
+/// Get the InterfaceVarABIAttr given its fields.
+InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet,
+                                           unsigned binding,
+                                           spirv::StorageClass storageClass,
+                                           MLIRContext *context);
+
+/// Attribute name for specifying entry point information.
+StringRef getEntryPointABIAttrName();
+
+/// Get the EntryPointABIAttr given its fields.
+EntryPointABIAttr getEntryPointABIAttr(ArrayRef<int32_t> localSize,
+                                       MLIRContext *context);
 
 } // namespace spirv
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.td
new file mode 100644 (file)
index 0000000..d9cf0a7
--- /dev/null
@@ -0,0 +1,55 @@
+//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the base file for supporting lowering to SPIR-V dialect. This
+// file defines SPIR-V attributes used for specifying the shader
+// interface or ABI. This is because SPIR-V module is expected to work in
+// an execution environment as specified by a client API. A SPIR-V module
+// needs to "link" correctly with the execution environment regarding the
+// resources that are used in the SPIR-V module and get populated with
+// data via the client API. The shader interface (or ABI) is passed into
+// SPIR-V lowering path via attributes defined in this file. A
+// compilation flow targeting SPIR-V is expected to attach such
+// attributes to resources and other suitable places.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_LOWERING
+#define SPIRV_LOWERING
+
+include "mlir/Dialect/SPIRV/SPIRVBase.td"
+
+// For arguments that eventually map to spv.globalVariable for the
+// shader interface, this attribute specifies the information regarding
+// the global variable :
+// 1) Descriptor Set.
+// 2) Binding number.
+// 3) Storage class.
+def SPV_InterfaceVarABIAttr:
+    StructAttr<"InterfaceVarABIAttr", SPV_Dialect,
+               [StructFieldAttr<"descriptor_set", I32Attr>,
+                StructFieldAttr<"binding", I32Attr>,
+                StructFieldAttr<"storage_class", SPV_StorageClassAttr>]>;
+
+// For entry functions, this attribute specifies information related to entry
+// points in the generated SPIR-V module:
+// 1) WorkGroup Size.
+def SPV_EntryPointABIAttr:
+    StructAttr<"EntryPointABIAttr", SPV_Dialect,
+               [StructFieldAttr<"local_size", I32ElementsAttr>]>;
+
+#endif // SPIRV_LOWERING
index e6b4dac..1ec825a 100644 (file)
@@ -243,7 +243,10 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope, Symbol]> {
       "TypeAttr type, ArrayRef<NamedAttribute> namedAttrs", [{
       state.addAttribute("type", type);
       state.addAttributes(namedAttrs);
-    }]>
+    }]>,
+    OpBuilder<[{Builder *builder, OperationState &state,
+                Type type, StringRef name, unsigned descriptorSet,
+                unsigned binding}]>
   ];
 
   let results = (outs);
index 136c836..bfe80e0 100644 (file)
@@ -1230,6 +1230,10 @@ class ArrayMinCount<int n> : AttrConstraint<
     CPred<"$_self.cast<ArrayAttr>().size() >= " # n>,
     "with at least " # n # " elements">;
 
+class ArrayCount<int n> : AttrConstraint<
+    CPred<"$_self.cast<ArrayAttr>().size() == " #n>,
+    "with exactly " # n # " elements">;
+
 class IntArrayNthElemEq<int index, int value> : AttrConstraint<
     And<[
       CPred<"$_self.cast<ArrayAttr>().size() > " # index>,
index bf3fda4..23e7b91 100644 (file)
@@ -163,15 +163,26 @@ PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
 PatternMatchResult
 KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
                                     ConversionPatternRewriter &rewriter) const {
-  FuncOp newFuncOp;
   if (!gpu::GPUDialect::isKernel(funcOp)) {
     return matchFailure();
   }
 
-  if (failed(spirv::lowerAsEntryFunction(funcOp, &typeConverter, rewriter,
-                                         newFuncOp))) {
+  SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
+  for (auto argNum : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+    argABI.push_back(spirv::getInterfaceVarABIAttr(
+        0, argNum, spirv::StorageClass::StorageBuffer, rewriter.getContext()));
+  }
+  // TODO(ravishankarm) : For now set this to {32, 1, 1}. This is incorrect. The
+  // actual workgroup size needs to be plumbed through.
+  auto context = rewriter.getContext();
+  auto entryPointAttr = spirv::getEntryPointABIAttr({32, 1, 1}, context);
+  FuncOp newFuncOp = spirv::lowerAsEntryFunction(
+      funcOp, typeConverter, rewriter, argABI, entryPointAttr);
+  if (!newFuncOp) {
     return matchFailure();
   }
+  newFuncOp.removeAttr(Identifier::get(gpu::GPUDialect::getKernelFuncAttrName(),
+                                       rewriter.getContext()));
   return matchSuccess();
 }
 
index 5fa5106..df38068 100644 (file)
@@ -54,7 +54,7 @@ void GPUToSPIRVPass::runOnModule() {
     if (!gpu::GPUDialect::isKernel(funcOp)) {
       return;
     }
-    OpBuilder builder(module.getBodyRegion());
+    OpBuilder builder(funcOp.getOperation());
     // Create a new spirv::ModuleOp for this function, and clone the
     // function into it.
     // TODO : Generalize this to account for different extensions,
@@ -77,45 +77,20 @@ void GPUToSPIRVPass::runOnModule() {
   });
 
   /// Dialect conversion to lower the functions with the spirv::ModuleOps.
-  SPIRVBasicTypeConverter basicTypeConverter;
-  SPIRVTypeConverter typeConverter(&basicTypeConverter);
+  SPIRVTypeConverter typeConverter;
   OwningRewritePatternList patterns;
   populateGPUToSPIRVPatterns(context, typeConverter, patterns);
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
 
   ConversionTarget target(*context);
   target.addLegalDialect<spirv::SPIRVDialect>();
-  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
-    // TODO(ravishankarm) : Currently lowering does not support handling
-    // function conversion of non-kernel functions. This is to be added.
-
-    // For kernel functions, verify that the signature is void(void).
-    return gpu::GPUDialect::isKernel(op) && op.getNumResults() == 0 &&
-           op.getNumArguments() == 0;
-  });
+  target.addDynamicallyLegalOp<FuncOp>(
+      [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
 
   if (failed(applyFullConversion(spirvModules, target, patterns,
                                  &typeConverter))) {
     return signalPassFailure();
   }
-
-  // After the SPIR-V modules have been generated, some finalization is needed
-  // for the entry functions. For example, adding spv.EntryPoint op,
-  // spv.ExecutionMode op, etc.
-  for (auto *spvModule : spirvModules) {
-    for (auto op :
-         cast<spirv::ModuleOp>(spvModule).getBlock().getOps<FuncOp>()) {
-      if (gpu::GPUDialect::isKernel(op)) {
-        OpBuilder builder(op.getContext());
-        builder.setInsertionPointAfter(op);
-        if (failed(spirv::finalizeEntryFunction(op, builder))) {
-          return signalPassFailure();
-        }
-        op.getOperation()->removeAttr(Identifier::get(
-            gpu::GPUDialect::getKernelFuncAttrName(), op.getContext()));
-      }
-    }
-  }
 }
 
 OpPassBase<ModuleOp> *createConvertGPUToSPIRVPass() {
index 74d1352..2157f5a 100644 (file)
@@ -63,7 +63,7 @@ public:
       return matchFailure();
     }
     auto spirvConstType =
-        typeConverter.convertBasicType(constIndexOp.getResult()->getType());
+        typeConverter.convertType(constIndexOp.getResult()->getType());
     auto spirvConstVal =
         rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
@@ -120,7 +120,7 @@ public:
   matchAndRewrite(StdOp operation, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
     auto resultType =
-        this->typeConverter.convertBasicType(operation.getResult()->getType());
+        this->typeConverter.convertType(operation.getResult()->getType());
     rewriter.template replaceOpWithNewOp<SPIRVOp>(
         operation, resultType, operands, ArrayRef<NamedAttribute>());
     return this->matchSuccess();
index 9b2b2d3..7133a33 100644 (file)
@@ -40,8 +40,7 @@ void ConvertStandardToSPIRVPass::runOnModule() {
   OwningRewritePatternList patterns;
   auto module = getModule();
 
-  SPIRVBasicTypeConverter basicTypeConverter;
-  SPIRVTypeConverter typeConverter(&basicTypeConverter);
+  SPIRVTypeConverter typeConverter;
   populateStandardToSPIRVPatterns(module.getContext(), typeConverter, patterns);
   ConversionTarget target(*(module.getContext()));
   target.addLegalDialect<spirv::SPIRVDialect>();
index eee01f1..e2d5332 100644 (file)
 
 using namespace mlir;
 
-Type VulkanLayoutUtils::decorateType(spirv::StructType structType,
-                                     VulkanLayoutUtils::Size &size,
-                                     VulkanLayoutUtils::Size &alignment) {
+spirv::StructType
+VulkanLayoutUtils::decorateType(spirv::StructType structType,
+                                VulkanLayoutUtils::Size &size,
+                                VulkanLayoutUtils::Size &alignment) {
   if (structType.getNumElements() == 0) {
     return structType;
   }
index 01a4733..1f736fe 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
-
 #include "mlir/Dialect/SPIRV/LayoutUtils.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "llvm/ADT/Sequence.h"
 
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
+// Attributes for ABI
+//===----------------------------------------------------------------------===//
+
+// Pull in the attributes needed for lowering.
+namespace mlir {
+#include "mlir/Dialect/SPIRV/SPIRVLowering.cpp.inc"
+}
+
+StringRef mlir::spirv::getInterfaceVarABIAttrName() {
+  return "spirv.interface_var_abi";
+}
+
+mlir::spirv::InterfaceVarABIAttr
+mlir::spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
+                                    spirv::StorageClass storageClass,
+                                    MLIRContext *context) {
+  Type i32Type = IntegerType::get(32, context);
+  return mlir::spirv::InterfaceVarABIAttr::get(
+      IntegerAttr::get(i32Type, descriptorSet),
+      IntegerAttr::get(i32Type, binding),
+      IntegerAttr::get(i32Type, static_cast<int64_t>(storageClass)), context);
+}
+
+StringRef mlir::spirv::getEntryPointABIAttrName() {
+  return "spirv.entry_point_abi";
+}
+
+mlir::spirv::EntryPointABIAttr
+mlir::spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize,
+                                  MLIRContext *context) {
+  assert(localSize.size() == 3);
+  return mlir::spirv::EntryPointABIAttr::get(
+      DenseElementsAttr::get<int32_t>(
+          VectorType::get(3, IntegerType::get(32, context)), localSize)
+          .cast<DenseIntElementsAttr>(),
+      context);
+}
+
+//===----------------------------------------------------------------------===//
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
@@ -41,68 +80,35 @@ Type convertIndexType(MLIRContext *context) {
   return IntegerType::get(32, context);
 }
 
-Type convertIndexType(IndexType t) { return convertIndexType(t.getContext()); }
-
-Type basicTypeConversion(Type t) {
+Type typeConversionImpl(Type t) {
   // Check if the type is SPIR-V supported. If so return the type.
   if (spirv::SPIRVDialect::isValidType(t)) {
     return t;
   }
 
   if (auto indexType = t.dyn_cast<IndexType>()) {
-    return convertIndexType(indexType);
+    return convertIndexType(t.getContext());
   }
 
   if (auto memRefType = t.dyn_cast<MemRefType>()) {
     auto elementType = memRefType.getElementType();
-    if (memRefType.hasStaticShape()) {
+    // TODO(ravishankarm) : Handle dynamic shapes and memref with strides.
+    if (memRefType.hasStaticShape() && memRefType.getAffineMaps().empty()) {
       // Convert to a multi-dimensional spv.array if size is known.
       for (auto size : reverse(memRefType.getShape())) {
         elementType = spirv::ArrayType::get(elementType, size);
       }
+      // For now initialize the storage class to StorageBuffer. This will be
+      // updated later based on whats passed in w.r.t to the ABI attributes.
       return spirv::PointerType::get(elementType,
                                      spirv::StorageClass::StorageBuffer);
-    } else {
-      // Vulkan SPIR-V validation rules require runtime array type to be the
-      // last member of a struct.
-      return spirv::PointerType::get(spirv::RuntimeArrayType::get(elementType),
-                                     spirv::StorageClass::StorageBuffer);
     }
   }
   return Type();
 }
-
-Type getLayoutDecoratedType(spirv::StructType type) {
-  VulkanLayoutUtils::Size size = 0, alignment = 0;
-  return VulkanLayoutUtils::decorateType(type, size, alignment);
-}
-
-/// Generates the type of variable given the type of object.
-static Type getGlobalVarTypeForEntryFnArg(Type t) {
-  auto convertedType = basicTypeConversion(t);
-  if (auto ptrType = convertedType.dyn_cast<spirv::PointerType>()) {
-    if (!ptrType.getPointeeType().isa<spirv::StructType>()) {
-      return spirv::PointerType::get(
-          getLayoutDecoratedType(
-              spirv::StructType::get(ptrType.getPointeeType())),
-          ptrType.getStorageClass());
-    }
-  } else {
-    return spirv::PointerType::get(
-        getLayoutDecoratedType(spirv::StructType::get(convertedType)),
-        spirv::StorageClass::StorageBuffer);
-  }
-  return convertedType;
-}
 } // namespace
 
-Type SPIRVBasicTypeConverter::convertType(Type t) {
-  return basicTypeConversion(t);
-}
-
-Type SPIRVTypeConverter::convertType(Type t) {
-  return getGlobalVarTypeForEntryFnArg(t);
-}
+Type SPIRVTypeConverter::convertType(Type t) { return typeConversionImpl(t); }
 
 //===----------------------------------------------------------------------===//
 // Builtin Variables
@@ -193,148 +199,44 @@ Value *mlir::spirv::getBuiltinVariableValue(Operation *op,
 // Entry Function signature Conversion
 //===----------------------------------------------------------------------===//
 
-namespace {
-/// Computes the replacement value for an argument of an entry function. It
-/// allocates a global variable for this argument and adds statements in the
-/// entry block to get a replacement value within function scope.
-Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter,
-                                           size_t origArgNum, Value *origArg) {
-  // Create a global variable for this argument.
-  auto insertionOp = rewriter.getInsertionBlock()->getParent();
-  auto module = insertionOp->getParentOfType<spirv::ModuleOp>();
-  if (!module) {
-    return nullptr;
-  }
-  auto funcOp = insertionOp->getParentOfType<FuncOp>();
-  spirv::GlobalVariableOp var;
-  {
-    OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
-    rewriter.setInsertionPoint(funcOp.getOperation());
-    std::string varName =
-        funcOp.getName().str() + "_arg_" + std::to_string(origArgNum);
-    var = rewriter.create<spirv::GlobalVariableOp>(
-        funcOp.getLoc(),
-        TypeAttr::get(getGlobalVarTypeForEntryFnArg(origArg->getType())),
-        rewriter.getStringAttr(varName), nullptr);
-    var.setAttr(
-        spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
-        rewriter.getI32IntegerAttr(0));
-    var.setAttr(
-        spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
-        rewriter.getI32IntegerAttr(origArgNum));
-  }
-  // Insert the addressOf and load instructions, to get back the converted value
-  // type.
-  auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
-  auto indexType = convertIndexType(funcOp.getContext());
-  auto zero = rewriter.create<spirv::ConstantOp>(
-      funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
-  auto accessChain = rewriter.create<spirv::AccessChainOp>(
-      funcOp.getLoc(), addressOf.pointer(), zero.constant());
-  // If the original argument is a tensor/memref type, the value is not
-  // loaded. Instead the pointer value is returned to allow its use in access
-  // chain ops.
-  auto origArgType = origArg->getType();
-  if (origArgType.isa<MemRefType>()) {
-    return accessChain;
-  }
-  return rewriter.create<spirv::LoadOp>(
-      funcOp.getLoc(), accessChain.component_ptr(), /*memory_access=*/nullptr,
-      /*alignment=*/nullptr);
-}
-
-FuncOp applySignatureConversion(
-    FuncOp funcOp, ConversionPatternRewriter &rewriter,
-    TypeConverter::SignatureConversion &signatureConverter) {
-  // Create a new function with an updated signature.
-  auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
-  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
-                              newFuncOp.end());
-  newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
-                                      llvm::None, funcOp.getContext()));
-
-  // Tell the rewriter to convert the region signature.
-  rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
-  rewriter.replaceOp(funcOp.getOperation(), llvm::None);
-  return newFuncOp;
-}
-
-/// Gets the global variables that need to be specified as interface variable
-/// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
-LogicalResult getInterfaceVariables(FuncOp funcOp,
-                                    SmallVectorImpl<Attribute> &interfaceVars) {
-  auto module = funcOp.getParentOfType<spirv::ModuleOp>();
-  if (!module) {
-    return failure();
-  }
-  llvm::SetVector<Operation *> interfaceVarSet;
-  for (auto &block : funcOp) {
-    // TODO(ravishankarm) : This should in reality traverse the entry function
-    // call graph and collect all the interfaces. For now, just traverse the
-    // instructions in this function.
-    for (auto op : block.getOps<spirv::AddressOfOp>()) {
-      auto var = module.lookupSymbol<spirv::GlobalVariableOp>(op.variable());
-      if (var.type().cast<spirv::PointerType>().getStorageClass() ==
-          spirv::StorageClass::StorageBuffer) {
-        continue;
-      }
-      interfaceVarSet.insert(var.getOperation());
-    }
-  }
-  for (auto &var : interfaceVarSet) {
-    interfaceVars.push_back(SymbolRefAttr::get(
-        cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
-  }
-  return success();
-}
-} // namespace
-
-LogicalResult mlir::spirv::lowerAsEntryFunction(
-    FuncOp funcOp, SPIRVTypeConverter *typeConverter,
-    ConversionPatternRewriter &rewriter, FuncOp &newFuncOp) {
+FuncOp mlir::spirv::lowerAsEntryFunction(
+    FuncOp funcOp, SPIRVTypeConverter &typeConverter,
+    ConversionPatternRewriter &rewriter,
+    ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo,
+    spirv::EntryPointABIAttr entryPointInfo) {
   auto fnType = funcOp.getType();
   if (fnType.getNumResults()) {
-    return funcOp.emitError("SPIR-V lowering only supports functions with no "
-                            "return values right now");
+    funcOp.emitError("SPIR-V lowering only supports entry functions"
+                     "with no return values right now");
+    return nullptr;
+  }
+  if (fnType.getNumInputs() != argABIInfo.size()) {
+    funcOp.emitError(
+        "lowering as entry functions requires ABI info for all arguments");
+    return nullptr;
   }
   // For entry functions need to make the signature void(void). Compute the
   // replacement value for all arguments and replace all uses.
   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
   {
-    OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
-    rewriter.setInsertionPointToStart(&funcOp.front());
-    for (auto origArg : enumerate(funcOp.getArguments())) {
-      auto replacement = createAndLoadGlobalVarForEntryFnArg(
-          rewriter, origArg.index(), origArg.value());
-      signatureConverter.remapInput(origArg.index(), replacement);
+    for (auto argType : enumerate(funcOp.getType().getInputs())) {
+      auto convertedType = typeConverter.convertType(argType.value());
+      signatureConverter.addInputs(argType.index(), convertedType);
     }
   }
-  newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
-  return success();
-}
+  auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                              newFuncOp.end());
+  newFuncOp.setType(rewriter.getFunctionType(
+      signatureConverter.getConvertedTypes(), llvm::None));
+  rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+  rewriter.replaceOp(funcOp.getOperation(), llvm::None);
 
-LogicalResult mlir::spirv::finalizeEntryFunction(FuncOp newFuncOp,
-                                                 OpBuilder &builder) {
-  // Add the spv.EntryPointOp after collecting all the interface variables
-  // needed.
-  SmallVector<Attribute, 1> interfaceVars;
-  if (failed(getInterfaceVariables(newFuncOp, interfaceVars))) {
-    return failure();
+  // Set the attributes for argument and the function.
+  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
+  for (auto argIndex : llvm::seq<unsigned>(0, newFuncOp.getNumArguments())) {
+    newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
   }
-  builder.create<spirv::EntryPointOp>(newFuncOp.getLoc(),
-                                      spirv::ExecutionModel::GLCompute,
-                                      newFuncOp, interfaceVars);
-  // Specify the spv.ExecutionModeOp.
-
-  /// TODO(ravishankarm): Vulkan environment for SPIR-V requires "either a
-  /// LocalSize execution mode or an object decorated with the WorkgroupSize
-  /// decoration must be specified." Better approach is to use the
-  /// WorkgroupSize GlobalVariable with initializer being a specialization
-  /// constant. But current support for specialization constant does not allow
-  /// for this. So for now use the execution mode. Hard-wiring this to {1, 1,
-  /// 1} for now. To be fixed ASAP.
-  builder.create<spirv::ExecutionModeOp>(newFuncOp.getLoc(), newFuncOp,
-                                         spirv::ExecutionMode::LocalSize,
-                                         ArrayRef<int32_t>{1, 1, 1});
-  return success();
+  newFuncOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
+  return newFuncOp;
 }
index 55199d3..6bb052d 100644 (file)
@@ -1421,6 +1421,19 @@ Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
 // spv.globalVariable
 //===----------------------------------------------------------------------===//
 
+void spirv::GlobalVariableOp::build(Builder *builder, OperationState &state,
+                                    Type type, StringRef name,
+                                    unsigned descriptorSet, unsigned binding) {
+  build(builder, state, TypeAttr::get(type), builder->getStringAttr(name),
+        nullptr);
+  state.addAttribute(
+      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
+      builder->getI32IntegerAttr(descriptorSet));
+  state.addAttribute(
+      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
+      builder->getI32IntegerAttr(binding));
+}
+
 static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
                                          OperationState &state) {
   // Parse variable name.
index b4f49b7..15621aa 100644 (file)
@@ -20,6 +20,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Identifier.h"
 #include "mlir/IR/StandardTypes.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
index 42be4f8..8316d54 100644 (file)
@@ -1,5 +1,6 @@
 add_llvm_library(MLIRSPIRVTransforms
   DecorateSPIRVCompositeTypeLayoutPass.cpp
+  LowerABIAttributesPass.cpp
   )
 
 target_link_libraries(MLIRSPIRVTransforms
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
new file mode 100644 (file)
index 0000000..203d793
--- /dev/null
@@ -0,0 +1,264 @@
+//===- LowerABIAttributesPass.cpp - Decorate composite type ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to lower attributes that specify the shader ABI
+// for the functions in the generated SPIR-V module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/LayoutUtils.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+/// Checks if the `type` is a scalar or vector type. It is assumed that they are
+/// valid for SPIR-V dialect already.
+static bool isScalarOrVectorType(Type type) {
+  return spirv::SPIRVDialect::isValidScalarType(type) || type.isa<VectorType>();
+}
+
+/// Creates a global variable for an argument based on the ABI info.
+static spirv::GlobalVariableOp
+createGlobalVariableForArg(FuncOp funcOp, OpBuilder &builder, unsigned argNum,
+                           spirv::InterfaceVarABIAttr abiInfo) {
+  auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>();
+  if (!spirvModule) {
+    return nullptr;
+  }
+  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
+  builder.setInsertionPoint(funcOp.getOperation());
+  std::string varName =
+      funcOp.getName().str() + "_arg_" + std::to_string(argNum);
+
+  // Get the type of variable. If this is a scalar/vector type and has an ABI
+  // info create a variable of type !spv.ptr<!spv.struct<elementTYpe>>. If not
+  // it must already be a !spv.ptr<!spv.struct<...>>.
+  auto varType = funcOp.getType().getInput(argNum);
+  auto storageClass =
+      static_cast<spirv::StorageClass>(abiInfo.storage_class().getInt());
+  if (isScalarOrVectorType(varType)) {
+    varType =
+        spirv::PointerType::get(spirv::StructType::get(varType), storageClass);
+  } else {
+    auto varPtrType = varType.cast<spirv::PointerType>();
+    varType = spirv::PointerType::get(
+        spirv::StructType::get(varPtrType.getPointeeType()), storageClass);
+  }
+  auto varPtrType = varType.cast<spirv::PointerType>();
+  auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
+
+  // Set the offset information.
+  VulkanLayoutUtils::Size size = 0, alignment = 0;
+  varPointeeType =
+      VulkanLayoutUtils::decorateType(varPointeeType, size, alignment)
+          .cast<spirv::StructType>();
+  varType =
+      spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
+
+  return builder.create<spirv::GlobalVariableOp>(
+      funcOp.getLoc(), varType, varName, abiInfo.descriptor_set().getInt(),
+      abiInfo.binding().getInt());
+}
+
+/// Gets the global variables that need to be specified as interface variable
+/// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
+static LogicalResult
+getInterfaceVariables(FuncOp funcOp,
+                      SmallVectorImpl<Attribute> &interfaceVars) {
+  auto module = funcOp.getParentOfType<spirv::ModuleOp>();
+  if (!module) {
+    return failure();
+  }
+  llvm::SetVector<Operation *> interfaceVarSet;
+
+  // TODO(ravishankarm) : This should in reality traverse the entry function
+  // call graph and collect all the interfaces. For now, just traverse the
+  // instructions in this function.
+  funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
+    auto var =
+        module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
+    if (var.type().cast<spirv::PointerType>().getStorageClass() !=
+        spirv::StorageClass::StorageBuffer) {
+      interfaceVarSet.insert(var.getOperation());
+    }
+  });
+  for (auto &var : interfaceVarSet) {
+    interfaceVars.push_back(SymbolRefAttr::get(
+        cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
+  }
+  return success();
+}
+
+/// Lowers the entry point attribute.
+static LogicalResult lowerEntryPointABIAttr(FuncOp funcOp, OpBuilder &builder) {
+  auto entryPointAttrName = spirv::getEntryPointABIAttrName();
+  auto entryPointAttr =
+      funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName);
+  if (!entryPointAttr) {
+    return failure();
+  }
+
+  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
+  auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>();
+  builder.setInsertionPoint(spirvModule.body().front().getTerminator());
+
+  // Adds the spv.EntryPointOp after collecting all the interface variables
+  // needed.
+  SmallVector<Attribute, 1> interfaceVars;
+  if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
+    return failure();
+  }
+  builder.create<spirv::EntryPointOp>(
+      funcOp.getLoc(), spirv::ExecutionModel::GLCompute, funcOp, interfaceVars);
+  // Specifies the spv.ExecutionModeOp.
+  auto localSizeAttr = entryPointAttr.local_size();
+  SmallVector<int32_t, 3> localSize(localSizeAttr.getValues<int32_t>());
+  builder.create<spirv::ExecutionModeOp>(
+      funcOp.getLoc(), funcOp, spirv::ExecutionMode::LocalSize, localSize);
+  funcOp.removeAttr(entryPointAttrName);
+  return success();
+}
+
+namespace {
+/// Pattern rewriter for changing function signature to match the ABI specified
+/// in attributes.
+class FuncOpLowering final : public SPIRVOpLowering<FuncOp> {
+public:
+  using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
+  PatternMatchResult
+  matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Pass to implement the ABI information specified as attributes.
+class LowerABIAttributesPass final
+    : public OperationPass<LowerABIAttributesPass, spirv::ModuleOp> {
+private:
+  void runOnOperation() override;
+};
+} // namespace
+
+PatternMatchResult
+FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
+                                ConversionPatternRewriter &rewriter) const {
+  if (!funcOp.getAttrOfType<spirv::EntryPointABIAttr>(
+          spirv::getEntryPointABIAttrName())) {
+    // TODO(ravishankarm) : Non-entry point functions are not handled.
+    return matchFailure();
+  }
+  TypeConverter::SignatureConversion signatureConverter(
+      funcOp.getType().getNumInputs());
+
+  auto attrName = spirv::getInterfaceVarABIAttrName();
+  for (auto argType : llvm::enumerate(funcOp.getType().getInputs())) {
+    auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
+        argType.index(), attrName);
+    if (!abiInfo) {
+      // TODO(ravishankarm) : For non-entry point functions, it should be legal
+      // to pass around scalar/vector values and return a scalar/vector. For now
+      // non-entry point functions are not handled in this ABI lowering and will
+      // produce an error.
+      return matchFailure();
+    }
+    auto var =
+        createGlobalVariableForArg(funcOp, rewriter, argType.index(), abiInfo);
+
+    OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
+    rewriter.setInsertionPointToStart(&funcOp.front());
+    // Inserts spirv::AddressOf and spirv::AccessChain operations.
+    auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
+    auto indexType =
+        typeConverter.convertType(IndexType::get(funcOp.getContext()));
+    auto zero = rewriter.create<spirv::ConstantOp>(
+        funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
+    Value *replacement = rewriter.create<spirv::AccessChainOp>(
+        funcOp.getLoc(), addressOf.pointer(), zero.constant());
+    // Check if the arg is a scalar or vector type. In that case, the value
+    // needs to be loaded into registers.
+    // TODO(ravishankarm) : This is loading value of the scalar into registers
+    // at the start of the function. It is probably better to do the load just
+    // before the use. There might be multiple loads and currently there is no
+    // easy way to replace all uses with a sequence of operations.
+    if (isScalarOrVectorType(argType.value())) {
+      replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), replacement,
+                                                   /*memory_access=*/nullptr,
+                                                   /*alignment=*/nullptr);
+    }
+    signatureConverter.remapInput(argType.index(), replacement);
+  }
+
+  // Creates a new function with the update signature.
+  auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                              newFuncOp.end());
+  newFuncOp.setType(rewriter.getFunctionType(
+      signatureConverter.getConvertedTypes(), llvm::None));
+  rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+  rewriter.eraseOp(funcOp.getOperation());
+  return matchSuccess();
+}
+
+void LowerABIAttributesPass::runOnOperation() {
+  // Uses the signature conversion methodology of the dialect conversion
+  // framework to implement the conversion.
+  spirv::ModuleOp module = getOperation();
+  MLIRContext *context = &getContext();
+
+  SPIRVTypeConverter typeConverter;
+  OwningRewritePatternList patterns;
+  patterns.insert<FuncOpLowering>(context, typeConverter);
+
+  ConversionTarget target(*context);
+  target.addLegalDialect<spirv::SPIRVDialect>();
+  auto entryPointAttrName = spirv::getEntryPointABIAttrName();
+  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+    return op.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName) &&
+           op.getNumResults() == 0 && op.getNumArguments() == 0;
+  });
+  target.addLegalOp<ReturnOp>();
+  if (failed(
+          applyPartialConversion(module, target, patterns, &typeConverter))) {
+    return signalPassFailure();
+  }
+
+  // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
+  // attributes.
+  OpBuilder builder(context);
+  SmallVector<FuncOp, 1> entryPointFns;
+  module.walk([&](FuncOp funcOp) {
+    if (funcOp.getAttrOfType<spirv::EntryPointABIAttr>(entryPointAttrName)) {
+      entryPointFns.push_back(funcOp);
+    }
+  });
+  for (auto fn : entryPointFns) {
+    if (failed(lowerEntryPointABIAttr(fn, builder))) {
+      return signalPassFailure();
+    }
+  }
+}
+
+std::unique_ptr<OpPassBase<spirv::ModuleOp>>
+mlir::spirv::createLowerABIAttributesPass() {
+  return std::make_unique<LowerABIAttributesPass>();
+}
+
+static PassRegistration<LowerABIAttributesPass>
+    pass("spirv-lower-abi-attrs", "Lower SPIR-V ABI Attributes");
index 7931932..b1feea6 100644 (file)
@@ -259,6 +259,13 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
 
       // Handle the case of a 1->0 value mapping.
       if (!argInfo) {
+        // If a replacement value was given for this argument, use that to
+        // replace all uses.
+        auto argReplacementValue = mapping.lookupOrDefault(origArg);
+        if (argReplacementValue != origArg) {
+          origArg->replaceAllUsesWith(argReplacementValue);
+          continue;
+        }
         // If there are any dangling uses then replace the argument with one
         // generated by the type converter. This is necessary as the cast must
         // persist in the IR after conversion.
index daa975b..6657f77 100644 (file)
@@ -21,41 +21,16 @@ module attributes {gpu.container_module} {
     // CHECK-DAG: spv.globalVariable [[NUMWORKGROUPSVAR:@.*]] built_in("NumWorkgroups") : !spv.ptr<vector<3xi32>, Input>
     // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
     // CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
-    // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
-    // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
-    // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
-    // CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
-    // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
-    // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
-    // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
-    // CHECK: func [[FN:@.*]]()
+    // CHECK-LABEL:    func @load_store_kernel
+    // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer> {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer> {spirv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG3:%.*]]: i32 {spirv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG4:%.*]]: i32 {spirv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG5:%.*]]: i32 {spirv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG6:%.*]]: i32 {spirv.interface_var_abi = {binding = 6 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
     func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index)
       attributes  {gpu.kernel} {
-      // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
-      // CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
-      // CHECK: [[ARG0:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
-      // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
-      // CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
-      // CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
-      // CHECK: [[ADDRESSARG2:%.*]] = spv._address_of [[VAR2]]
-      // CHECK: [[CONST2:%.*]] = spv.constant 0 : i32
-      // CHECK: [[ARG2:%.*]] = spv.AccessChain [[ADDRESSARG2]]{{\[}}[[CONST2]]
-      // CHECK: [[ADDRESSARG3:%.*]] = spv._address_of [[VAR3]]
-      // CHECK: [[CONST3:%.*]] = spv.constant 0 : i32
-      // CHECK: [[ARG3PTR:%.*]] = spv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
-      // CHECK: [[ARG3:%.*]] = spv.Load "StorageBuffer" [[ARG3PTR]]
-      // CHECK: [[ADDRESSARG4:%.*]] = spv._address_of [[VAR4]]
-      // CHECK: [[CONST4:%.*]] = spv.constant 0 : i32
-      // CHECK: [[ARG4PTR:%.*]] = spv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
-      // CHECK: [[ARG4:%.*]] = spv.Load "StorageBuffer" [[ARG4PTR]]
-      // CHECK: [[ADDRESSARG5:%.*]] = spv._address_of [[VAR5]]
-      // CHECK: [[CONST5:%.*]] = spv.constant 0 : i32
-      // CHECK: [[ARG5PTR:%.*]] = spv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
-      // CHECK: [[ARG5:%.*]] = spv.Load "StorageBuffer" [[ARG5PTR]]
-      // CHECK: [[ADDRESSARG6:%.*]] = spv._address_of [[VAR6]]
-      // CHECK: [[CONST6:%.*]] = spv.constant 0 : i32
-      // CHECK: [[ARG6PTR:%.*]] = spv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
-      // CHECK: [[ARG6:%.*]] = spv.Load "StorageBuffer" [[ARG6PTR]]
       // CHECK: [[ADDRESSWORKGROUPID:%.*]] = spv._address_of [[WORKGROUPIDVAR]]
       // CHECK: [[WORKGROUPID:%.*]] = spv.Load "Input" [[ADDRESSWORKGROUPID]]
       // CHECK: [[WORKGROUPIDX:%.*]] = spv.CompositeExtract [[WORKGROUPID]]{{\[}}0 : i32{{\]}}
index 61bc6ea..771680b 100644 (file)
@@ -4,21 +4,12 @@ module attributes {gpu.container_module} {
 
   module @kernels attributes {gpu.kernel_module} {
     // CHECK:       spv.module "Logical" "GLSL450" {
-    // CHECK-DAG:    spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
-    // CHECK-DAG:    spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
-    // CHECK:    func [[FN:@.*]]()
+    // CHECK-LABEL: func @kernel_1
+    // CHECK-SAME: {{%.*}}: f32 {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.array<12 x f32>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
     func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32, 1>)
         attributes { gpu.kernel } {
-      // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
-      // CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
-      // CHECK: [[ARG0PTR:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
-      // CHECK: [[ARG0:%.*]] = spv.Load "StorageBuffer" [[ARG0PTR]]
-      // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
-      // CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
-      // CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
-      // CHECK-NEXT: spv.Return
-      // CHECK: spv.EntryPoint "GLCompute" [[FN]]
-      // CHECK: spv.ExecutionMode [[FN]] "LocalSize"
+      // CHECK: spv.Return
       return
     }
   }
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
new file mode 100644 (file)
index 0000000..84a3db0
--- /dev/null
@@ -0,0 +1,129 @@
+// RUN: mlir-opt -spirv-lower-abi-attrs -verify-diagnostics %s -o - | FileCheck %s
+
+// CHECK-LABEL: spv.module
+spv.module "Logical" "GLSL450" {
+  // CHECK-DAG: spv.globalVariable [[WORKGROUPSIZE:@.*]] built_in("WorkgroupSize")
+  spv.globalVariable @__builtin_var_WorkgroupSize__ built_in("WorkgroupSize") : !spv.ptr<vector<3xi32>, Input>
+  // CHECK-DAG: spv.globalVariable [[NUMWORKGROUPS:@.*]] built_in("NumWorkgroups")
+  spv.globalVariable @__builtin_var_NumWorkgroups__ built_in("NumWorkgroups") : !spv.ptr<vector<3xi32>, Input>
+  // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONID:@.*]] built_in("LocalInvocationId")
+  spv.globalVariable @__builtin_var_LocalInvocationId__ built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+  // CHECK-DAG: spv.globalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId")
+  spv.globalVariable @__builtin_var_WorkgroupId__ built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
+  // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
+  // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
+  // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer>
+  // CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+  // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+  // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+  // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+  // CHECK: func [[FN:@.*]]()
+  func @load_store_kernel(%arg0: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>
+                          {spirv.interface_var_abi = {binding = 0 : i32,
+                                                      descriptor_set = 0 : i32,
+                                                      storage_class = 12 : i32}},
+                          %arg1: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>
+                          {spirv.interface_var_abi = {binding = 1 : i32,
+                                                      descriptor_set = 0 : i32,
+                                                      storage_class = 12 : i32}},
+                          %arg2: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>
+                          {spirv.interface_var_abi = {binding = 2 : i32,
+                                                      descriptor_set = 0 : i32,
+                                                      storage_class = 12 : i32}},
+                          %arg3: i32
+                          {spirv.interface_var_abi = {binding = 3 : i32,
+                                                      descriptor_set = 0 : i32,
+                                                      storage_class = 12 : i32}},
+                          %arg4: i32
+                          {spirv.interface_var_abi = {binding = 4 : i32,
+                                                      descriptor_set = 0 : i32,
+                                                      storage_class = 12 : i32}},
+                          %arg5: i32
+                          {spirv.interface_var_abi = {binding = 5 : i32,
+                                                      descriptor_set = 0 : i32,
+                                                      storage_class = 12 : i32}},
+                          %arg6: i32
+                          {spirv.interface_var_abi = {binding = 6 : i32,
+                                                      descriptor_set = 0 : i32,
+                                                      storage_class = 12 : i32}})
+  attributes  {spirv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+    // CHECK: [[ADDRESSARG6:%.*]] = spv._address_of [[VAR6]]
+    // CHECK: [[CONST6:%.*]] = spv.constant 0 : i32
+    // CHECK: [[ARG6PTR:%.*]] = spv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
+    // CHECK: {{%.*}} = spv.Load "StorageBuffer" [[ARG6PTR]]
+    // CHECK: [[ADDRESSARG5:%.*]] = spv._address_of [[VAR5]]
+    // CHECK: [[CONST5:%.*]] = spv.constant 0 : i32
+    // CHECK: [[ARG5PTR:%.*]] = spv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
+    // CHECK: {{%.*}} = spv.Load "StorageBuffer" [[ARG5PTR]]
+    // CHECK: [[ADDRESSARG4:%.*]] = spv._address_of [[VAR4]]
+    // CHECK: [[CONST4:%.*]] = spv.constant 0 : i32
+    // CHECK: [[ARG4PTR:%.*]] = spv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
+    // CHECK: [[ARG4:%.*]] = spv.Load "StorageBuffer" [[ARG4PTR]]
+    // CHECK: [[ADDRESSARG3:%.*]] = spv._address_of [[VAR3]]
+    // CHECK: [[CONST3:%.*]] = spv.constant 0 : i32
+    // CHECK: [[ARG3PTR:%.*]] = spv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
+    // CHECK: [[ARG3:%.*]] = spv.Load "StorageBuffer" [[ARG3PTR]]
+    // CHECK: [[ADDRESSARG2:%.*]] = spv._address_of [[VAR2]]
+    // CHECK: [[CONST2:%.*]] = spv.constant 0 : i32
+    // CHECK: [[ARG2:%.*]] = spv.AccessChain [[ADDRESSARG2]]{{\[}}[[CONST2]]
+    // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
+    // CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
+    // CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
+    // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
+    // CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
+    // CHECK: [[ARG0:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
+    %0 = spv._address_of @__builtin_var_WorkgroupId__ : !spv.ptr<vector<3xi32>, Input>
+    %1 = spv.Load "Input" %0 : vector<3xi32>
+    %2 = spv.CompositeExtract %1[0 : i32] : vector<3xi32>
+    %3 = spv._address_of @__builtin_var_WorkgroupId__ : !spv.ptr<vector<3xi32>, Input>
+    %4 = spv.Load "Input" %3 : vector<3xi32>
+    %5 = spv.CompositeExtract %4[1 : i32] : vector<3xi32>
+    %6 = spv._address_of @__builtin_var_WorkgroupId__ : !spv.ptr<vector<3xi32>, Input>
+    %7 = spv.Load "Input" %6 : vector<3xi32>
+    %8 = spv.CompositeExtract %7[2 : i32] : vector<3xi32>
+    %9 = spv._address_of @__builtin_var_LocalInvocationId__ : !spv.ptr<vector<3xi32>, Input>
+    %10 = spv.Load "Input" %9 : vector<3xi32>
+    %11 = spv.CompositeExtract %10[0 : i32] : vector<3xi32>
+    %12 = spv._address_of @__builtin_var_LocalInvocationId__ : !spv.ptr<vector<3xi32>, Input>
+    %13 = spv.Load "Input" %12 : vector<3xi32>
+    %14 = spv.CompositeExtract %13[1 : i32] : vector<3xi32>
+    %15 = spv._address_of @__builtin_var_LocalInvocationId__ : !spv.ptr<vector<3xi32>, Input>
+    %16 = spv.Load "Input" %15 : vector<3xi32>
+    %17 = spv.CompositeExtract %16[2 : i32] : vector<3xi32>
+    %18 = spv._address_of @__builtin_var_NumWorkgroups__ : !spv.ptr<vector<3xi32>, Input>
+    %19 = spv.Load "Input" %18 : vector<3xi32>
+    %20 = spv.CompositeExtract %19[0 : i32] : vector<3xi32>
+    %21 = spv._address_of @__builtin_var_NumWorkgroups__ : !spv.ptr<vector<3xi32>, Input>
+    %22 = spv.Load "Input" %21 : vector<3xi32>
+    %23 = spv.CompositeExtract %22[1 : i32] : vector<3xi32>
+    %24 = spv._address_of @__builtin_var_NumWorkgroups__ : !spv.ptr<vector<3xi32>, Input>
+    %25 = spv.Load "Input" %24 : vector<3xi32>
+    %26 = spv.CompositeExtract %25[2 : i32] : vector<3xi32>
+    %27 = spv._address_of @__builtin_var_WorkgroupSize__ : !spv.ptr<vector<3xi32>, Input>
+    %28 = spv.Load "Input" %27 : vector<3xi32>
+    %29 = spv.CompositeExtract %28[0 : i32] : vector<3xi32>
+    %30 = spv._address_of @__builtin_var_WorkgroupSize__ : !spv.ptr<vector<3xi32>, Input>
+    %31 = spv.Load "Input" %30 : vector<3xi32>
+    %32 = spv.CompositeExtract %31[1 : i32] : vector<3xi32>
+    %33 = spv._address_of @__builtin_var_WorkgroupSize__ : !spv.ptr<vector<3xi32>, Input>
+    %34 = spv.Load "Input" %33 : vector<3xi32>
+    %35 = spv.CompositeExtract %34[2 : i32] : vector<3xi32>
+    // CHECK: spv.IAdd [[ARG3]]
+    %36 = spv.IAdd %arg3, %2 : i32
+    // CHECK: spv.IAdd [[ARG4]]
+    %37 = spv.IAdd %arg4, %11 : i32
+    // CHECK: spv.AccessChain [[ARG0]]
+    %38 = spv.AccessChain %arg0[%36, %37] : !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>
+    %39 = spv.Load "StorageBuffer" %38 : f32
+    // CHECK: spv.AccessChain [[ARG1]]
+    %40 = spv.AccessChain %arg1[%36, %37] : !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>
+    %41 = spv.Load "StorageBuffer" %40 : f32
+    %42 = spv.FAdd %39, %41 : f32
+    // CHECK: spv.AccessChain [[ARG2]]
+    %43 = spv.AccessChain %arg2[%36, %37] : !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>
+    spv.Store "StorageBuffer" %43, %42 : f32
+    spv.Return
+  }
+  // CHECK: spv.EntryPoint "GLCompute" [[FN]], [[WORKGROUPID]], [[LOCALINVOCATIONID]], [[NUMWORKGROUPS]], [[WORKGROUPSIZE]]
+  // CHECK-NEXT: spv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
+} attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir
new file mode 100644 (file)
index 0000000..51dfbdb
--- /dev/null
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -spirv-lower-abi-attrs -verify-diagnostics %s -o - | FileCheck %s
+
+// CHECK-LABEL: spv.module
+spv.module "Logical" "GLSL450" {
+  // CHECK-DAG:    spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
+  // CHECK-DAG:    spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+  // CHECK:    func [[FN:@.*]]()
+  func @kernel_1(%arg0: f32
+                {spirv.interface_var_abi = {binding = 0 : i32,
+                                            descriptor_set = 0 : i32,
+                                            storage_class = 12 : i32}},
+                 %arg1: !spv.ptr<!spv.array<12 x f32>, StorageBuffer>
+                 {spirv.interface_var_abi = {binding = 1 : i32,
+                                             descriptor_set = 0 : i32,
+                                             storage_class = 12 : i32}})
+  attributes  {spirv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+    // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
+    // CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
+    // CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
+    // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
+    // CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
+    // CHECK: [[ARG0PTR:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
+    // CHECK: [[ARG0:%.*]] = spv.Load "StorageBuffer" [[ARG0PTR]]
+    // CHECK: spv.Return
+    spv.Return
+  }
+  // CHECK: spv.EntryPoint "GLCompute" [[FN]]
+  // CHECK: spv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
+} attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}