Initial implementation to translate kernel fn in GPU Dialect to SPIR-V Dialect
authorMahesh Ravishankar <ravishankarm@google.com>
Tue, 30 Jul 2019 18:29:48 +0000 (11:29 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 30 Jul 2019 18:55:55 +0000 (11:55 -0700)
This CL adds an initial implementation for translation of kernel
function in GPU Dialect (used with a gpu.launch_kernel) op to a
spv.Module. The original function is translated into an entry
function.
Most of the heavy lifting is done by adding TypeConversion and other
utility functions/classes that provide most of the functionality to
translate from Standard Dialect to SPIR-V Dialect. These are intended
to be reusable in implementation of different dialect conversion
pipelines.
Note : Some of the files for have been renamed to be consistent with
the norm used by the other Conversion frameworks.
PiperOrigin-RevId: 260759165

20 files changed:
mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h [new file with mode: 0644]
mlir/include/mlir/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.h [deleted file]
mlir/include/mlir/Dialect/SPIRV/Passes.h
mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp [new file with mode: 0644]
mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp [new file with mode: 0644]
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp [new file with mode: 0644]
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td [moved from mlir/lib/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.td with 84% similarity]
mlir/lib/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.cpp [deleted file]
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Conversion/GPUToSPIRV/simple.mlir [new file with mode: 0644]
mlir/test/Dialect/SPIRV/ops.mlir
mlir/test/Dialect/SPIRV/standard_ops_to_spirv.mlir
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
new file mode 100644 (file)
index 0000000..21c2842
--- /dev/null
@@ -0,0 +1,103 @@
+//===- ConvertStandardToSPIRV.h - Convert to SPIR-V dialect -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides type converters and patterns to convert from standard types/ops to
+// SPIR-V types and operations. Also provides utilities and base classes to use
+// while targeting SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
+#define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+namespace spirv {
+class SPIRVDialect;
+}
+
+/// Type conversion from Standard Types to SPIR-V Types.
+class SPIRVTypeConverter : public TypeConverter {
+public:
+  explicit SPIRVTypeConverter(MLIRContext *context);
+
+  /// Converts types to SPIR-V supported types.
+  Type convertType(Type t) override;
+
+protected:
+  spirv::SPIRVDialect *spirvDialect;
+};
+
+/// Converts a function type according to the requirements of a SPIR-V entry
+/// function. The arguments need to be converted to spv.Variables of spv.ptr
+/// types so that they could be bound by the runtime.
+class SPIRVEntryFnTypeConverter final : public SPIRVTypeConverter {
+public:
+  using SPIRVTypeConverter::SPIRVTypeConverter;
+
+  /// Method to convert argument of a function. The `type` is converted to
+  /// spv.ptr<type, Uniform>.
+  // TODO(ravishankarm) : Support other storage classes.
+  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
+                                    SignatureConversion &result) override;
+};
+
+/// Base class to define a conversion pattern to translate Ops into SPIR-V.
+template <typename OpTy> class SPIRVOpLowering : public ConversionPattern {
+public:
+  SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter,
+                  SPIRVEntryFnTypeConverter &entryFnConverter)
+      : ConversionPattern(OpTy::getOperationName(), 1, context),
+        typeConverter(typeConverter), entryFnConverter(entryFnConverter) {}
+
+protected:
+  // Type lowering class.
+  SPIRVTypeConverter &typeConverter;
+
+  // Entry function signature converter.
+  SPIRVEntryFnTypeConverter &entryFnConverter;
+};
+
+/// Base Class for legalize a FuncOp within a spv.module. This class can be
+/// extended to implement a ConversionPattern to lower a FuncOp. It provides
+/// hooks to legalize a FuncOp as a simple function, or as an entry function.
+class SPIRVFnLowering : public SPIRVOpLowering<FuncOp> {
+public:
+  using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
+
+protected:
+  /// Method to legalize the function as a non-entry function.
+  LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+                              ConversionPatternRewriter &rewriter,
+                              FuncOp &newFuncOp) const;
+
+  /// Method to legalize the function as an entry function.
+  LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+                                     ConversionPatternRewriter &rewriter,
+                                     FuncOp &newFuncOp) const;
+};
+
+/// Appends to a pattern list additional patterns for translating StandardOps to
+/// SPIR-V ops.
+void populateStandardToSPIRVPatterns(MLIRContext *context,
+                                     OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.h b/mlir/include/mlir/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.h
deleted file mode 100644 (file)
index 7e75430..0000000
+++ /dev/null
@@ -1,35 +0,0 @@
-//===- StdOpsToSPIRVConversion.h - Convert StandardOps to SPIR-V *- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines utility function to import patterns to convert StandardOps
-// to SPIR-V ops
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef STANDARD_OPS_TO_SPIRV_H_
-#define STANDARD_OPS_TO_SPIRV_H_
-
-#include "mlir/IR/PatternMatch.h"
-
-namespace mlir {
-/// Method to append to a pattern list additional patterns for translating
-/// StandardOps to SPIR-V ops.
-void populateStdOpsToSPIRVPatterns(MLIRContext *context,
-                                   OwningRewritePatternList &patterns);
-} // namespace mlir
-
-#endif // STANDARD_OPS_TO_SPIRV_H_
index 72eb866..e896da7 100644 (file)
@@ -27,7 +27,7 @@
 namespace mlir {
 namespace spirv {
 
-FunctionPassBase *createStdOpsToSPIRVConversionPass();
+ModulePassBase *createConvertStandardToSPIRVPass();
 
 } // namespace spirv
 } // namespace mlir
index abe3efb..494adc1 100644 (file)
@@ -38,6 +38,9 @@ public:
 
   /// Prints a type registered to this dialect.
   void printType(Type type, llvm::raw_ostream &os) const override;
+
+  /// Checks if a type is valid in SPIR-V dialect.
+  bool isValidSPIRVType(Type t) const;
 };
 
 } // end namespace spirv
index 273fd82..104a479 100644 (file)
@@ -31,6 +31,17 @@ namespace spirv {
 #define GET_OP_CLASSES
 #include "mlir/Dialect/SPIRV/SPIRVOps.h.inc"
 
+/// Following methods are auto-generated.
+///
+/// Get the name used in the Op to refer to an enum value of the given
+/// `EnumClass`.
+/// template <typename EnumClass> StringRef attributeName();
+///
+/// Get the function that can be used to symbolize an enum value.
+/// template <typename EnumClass>
+/// llvm::Optional<EnumClass> (*)(StringRef) symbolizeEnum();
+#include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
+
 } // end namespace spirv
 } // end namespace mlir
 
index 509aa27..5cf8e13 100644 (file)
@@ -76,18 +76,24 @@ def SPV_ModuleOp : SPV_Op<"module", []> {
   }];
 
   let arguments = (ins
+    SPV_AddressingModelAttr:$addressing_model,
+    SPV_MemoryModelAttr:$memory_model,
     OptionalAttr<StrArrayAttr>:$capabilities,
     OptionalAttr<StrArrayAttr>:$extensions,
-    OptionalAttr<StrArrayAttr>:$extended_instruction_sets,
-    SPV_AddressingModelAttr:$addressing_model,
-    SPV_MemoryModelAttr:$memory_model
+    OptionalAttr<StrArrayAttr>:$extended_instruction_sets
   );
 
   let results = (outs);
 
   let regions = (region SizedRegion<1>:$body);
 
-  let builders = [OpBuilder<"Builder *, OperationState *state">];
+  let builders = [OpBuilder<"Builder *, OperationState *state">,
+                  OpBuilder<[{Builder *, OperationState *state,
+                              IntegerAttr addressing_model,
+                              IntegerAttr memory_model,
+                              /*optional*/ArrayAttr capabilities = nullptr,
+                              /*optional*/ArrayAttr extensions = nullptr,
+                              /*optional*/ArrayAttr extended_instruction_sets = nullptr}]>];
 
   // We need to ensure the block inside the region is properly terminated;
   // the auto-generated builders do not guarantee that.
index 0238172..1ddd103 100644 (file)
@@ -2,5 +2,6 @@ add_subdirectory(LoopsToGPU)
 add_subdirectory(ControlFlowToCFG)
 add_subdirectory(GPUToCUDA)
 add_subdirectory(GPUToNVVM)
+add_subdirectory(GPUToSPIRV)
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
diff --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
new file mode 100644 (file)
index 0000000..8426420
--- /dev/null
@@ -0,0 +1,13 @@
+add_llvm_library(MLIRGPUtoSPIRVTransforms
+  GPUToSPIRV.cpp
+  )
+
+target_link_libraries(MLIRGPUtoSPIRVTransforms
+  MLIRGPU
+  MLIRIR
+  MLIRPass
+  MLIRSPIRV
+  MLIRStandardOps
+  MLIRSPIRVConversion
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
new file mode 100644 (file)
index 0000000..4eadb87
--- /dev/null
@@ -0,0 +1,125 @@
+//===- GPUToSPIRV.cp - MLIR SPIR-V lowering passes ------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert a kernel function in the GPU Dialect
+// into a spv.module operation
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
+/// attribute gpu.kernel) within a spv.module.
+class KernelFnConversion final : public SPIRVFnLowering {
+public:
+  using SPIRVFnLowering::SPIRVFnLowering;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+PatternMatchResult
+KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+  auto funcOp = cast<FuncOp>(op);
+  FuncOp newFuncOp;
+  if (!gpu::GPUDialect::isKernel(funcOp)) {
+    return succeeded(lowerFunction(funcOp, operands, rewriter, newFuncOp))
+               ? matchSuccess()
+               : matchFailure();
+  }
+
+  if (failed(lowerAsEntryFunction(funcOp, operands, rewriter, newFuncOp))) {
+    return matchFailure();
+  }
+  newFuncOp.getOperation()->removeAttr(Identifier::get(
+      gpu::GPUDialect::getKernelFuncAttrName(), op->getContext()));
+  return matchSuccess();
+}
+
+namespace {
+/// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions
+/// that have the "gpu.kernel" attribute, i.e. those functions that are
+/// referenced in gpu::LaunchKernelOp operations. For each such function
+///
+/// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp
+/// (the original function is still needed by the gpu::LaunchKernelOp, so cannot
+/// replace it).
+///
+/// 2) Lower the body of the spirv::ModuleOp.
+class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
+  void runOnModule() override;
+};
+} // namespace
+
+void GPUToSPIRVPass::runOnModule() {
+  auto context = &getContext();
+  auto module = getModule();
+
+  SmallVector<Operation *, 4> spirvModules;
+  for (auto funcOp : module.getOps<FuncOp>()) {
+    if (gpu::GPUDialect::isKernel(funcOp)) {
+      OpBuilder builder(module.getBodyRegion());
+      // Create a new spirv::ModuleOp for this function, and clone the
+      // function into it.
+      // TODO : Generalize this to account for different extensions,
+      // capabilities, extended_instruction_sets, other addressing models
+      // and memory models.
+      auto spvModule = builder.create<spirv::ModuleOp>(
+          funcOp.getLoc(),
+          builder.getI32IntegerAttr(
+              static_cast<int32_t>(spirv::AddressingModel::Logical)),
+          builder.getI32IntegerAttr(
+              static_cast<int32_t>(spirv::MemoryModel::VulkanKHR)));
+      OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
+      moduleBuilder.clone(*funcOp.getOperation());
+      spirvModules.push_back(spvModule);
+    }
+  }
+
+  /// Dialect conversion to lower the functions with the spirv::ModuleOps.
+  SPIRVTypeConverter typeConverter(context);
+  SPIRVEntryFnTypeConverter entryFnConverter(context);
+  OwningRewritePatternList patterns;
+  RewriteListBuilder<KernelFnConversion>::build(
+      patterns, context, typeConverter, entryFnConverter);
+  populateStandardToSPIRVPatterns(context, patterns);
+
+  ConversionTarget target(*context);
+  target.addLegalDialect<spirv::SPIRVDialect>();
+  target.addDynamicallyLegalOp<FuncOp>(
+      [&](FuncOp Op) { return typeConverter.isSignatureLegal(Op.getType()); });
+
+  if (failed(applyFullConversion(spirvModules, target, std::move(patterns),
+                                 &typeConverter))) {
+    return signalPassFailure();
+  }
+}
+
+ModulePassBase *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); }
+
+static PassRegistration<GPUToSPIRVPass>
+    pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
index ea04d56..be53112 100644 (file)
@@ -1,16 +1,18 @@
-set(LLVM_TARGET_DEFINITIONS StdOpsToSPIRVConversion.td)
-mlir_tablegen(StdOpsToSPIRVConversion.cpp.inc -gen-rewriters)
-add_public_tablegen_target(MLIRStdOpsToSPIRVConversionIncGen)
+set(LLVM_TARGET_DEFINITIONS StandardToSPIRV.td)
+mlir_tablegen(StandardToSPIRV.cpp.inc -gen-rewriters)
+add_public_tablegen_target(MLIRStandardToSPIRVIncGen)
 
 add_llvm_library(MLIRSPIRVConversion
-  StdOpsToSPIRVConversion.cpp
+  ConvertStandardToSPIRV.cpp
+  ConvertStandardToSPIRVPass.cpp
 
   ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
   )
 
 add_dependencies(MLIRSPIRVConversion
-  MLIRStdOpsToSPIRVConversionIncGen)
+  MLIRStandardToSPIRVIncGen)
 
 target_link_libraries(MLIRSPIRVConversion
   MLIRIR
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
new file mode 100644 (file)
index 0000000..d32d866
--- /dev/null
@@ -0,0 +1,206 @@
+//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert MLIR standard and builtin dialects
+// into the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Type Conversion
+//===----------------------------------------------------------------------===//
+
+SPIRVTypeConverter::SPIRVTypeConverter(MLIRContext *context)
+    : spirvDialect(context->getRegisteredDialect<spirv::SPIRVDialect>()) {}
+
+Type SPIRVTypeConverter::convertType(Type t) {
+  // Check if the type is SPIR-V supported. If so return the type.
+  if (spirvDialect->isValidSPIRVType(t)) {
+    return t;
+  }
+
+  if (auto memRefType = t.dyn_cast<MemRefType>()) {
+    if (memRefType.hasStaticShape()) {
+      // Convert MemrefType to spv.array if size is known.
+      // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
+      // to support other Storage Classes.
+      return spirv::PointerType::get(
+          spirv::ArrayType::get(memRefType.getElementType(),
+                                memRefType.getNumElements()),
+          spirv::StorageClass::StorageBuffer);
+    }
+  }
+  return Type();
+}
+
+//===----------------------------------------------------------------------===//
+// Entry Function signature Conversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+SPIRVEntryFnTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
+                                               SignatureConversion &result) {
+  // Try to convert the given input type.
+  auto convertedType = convertType(type);
+  // TODO(ravishankarm) : Vulkan spec requires these to be a
+  // spirv::StructType. This is not a SPIR-V requirement, so just making this a
+  // pointer type for now.
+  if (!convertedType)
+    return failure();
+  // For arguments to entry functions, convert the type into a pointer type if
+  // it is already not one.
+  if (!convertedType.isa<spirv::PointerType>()) {
+    // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
+    // to support other Storage classes.
+    convertedType = spirv::PointerType::get(convertedType,
+                                            spirv::StorageClass::StorageBuffer);
+  }
+
+  // Add the new inputs.
+  result.addInputs(inputNo, convertedType);
+  return success();
+}
+
+template <typename Converter>
+static LogicalResult
+lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter, Converter &typeConverter,
+                  TypeConverter::SignatureConversion &signatureConverter,
+                  FuncOp &newFuncOp) {
+  auto fnType = funcOp.getType();
+
+  if (fnType.getNumResults()) {
+    return funcOp.emitError("SPIR-V dialect only supports functions with no "
+                            "return values right now");
+  }
+
+  for (auto &argType : enumerate(fnType.getInputs())) {
+    // Get the type of the argument
+    if (failed(typeConverter.convertSignatureArg(
+            argType.index(), argType.value(), signatureConverter))) {
+      return funcOp.emitError("unable to convert argument type ")
+             << argType.value() << " to SPIR-V type";
+    }
+  }
+
+  // Create a new function with an updated signature.
+  newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                              newFuncOp.end());
+  newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
+                                      llvm::None, funcOp.getContext()));
+
+  // Tell the rewriter to convert the region signature.
+  rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+  rewriter.replaceOp(funcOp.getOperation(), llvm::None);
+  return success();
+}
+
+LogicalResult
+SPIRVFnLowering::lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+                               ConversionPatternRewriter &rewriter,
+                               FuncOp &newFuncOp) const {
+  auto fnType = funcOp.getType();
+  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+  return lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
+                           signatureConverter, newFuncOp);
+}
+
+LogicalResult
+SPIRVFnLowering::lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+                                      ConversionPatternRewriter &rewriter,
+                                      FuncOp &newFuncOp) const {
+  auto fnType = funcOp.getType();
+  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+  if (failed(lowerFunctionImpl(funcOp, operands, rewriter, entryFnConverter,
+                               signatureConverter, newFuncOp))) {
+    return failure();
+  }
+  // Create spv.Variable ops for each of the arguments. These need to be bound
+  // by the runtime. For now use descriptor_set 0, and arg number as the binding
+  // number.
+  auto module = funcOp.getParentOfType<spirv::ModuleOp>();
+  if (!module) {
+    return funcOp.emitError("expected op to be within a spv.module");
+  }
+  OpBuilder builder(module.getOperation()->getRegion(0));
+  SmallVector<Value *, 4> interface;
+  for (auto &convertedArgType :
+       llvm::enumerate(signatureConverter.getConvertedTypes())) {
+    auto variableOp = builder.create<spirv::VariableOp>(
+        funcOp.getLoc(), convertedArgType.value(),
+        builder.getI32IntegerAttr(
+            static_cast<int32_t>(spirv::StorageClass::StorageBuffer)),
+        llvm::None);
+    variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0));
+    variableOp.setAttr("binding",
+                       builder.getI32IntegerAttr(convertedArgType.index()));
+    interface.push_back(variableOp.getResult());
+  }
+  // Create an entry point instruction for this function.
+  // TODO(ravishankarm) : Add execution mode for the entry function
+  builder.setInsertionPoint(&(module.getBlock().back()));
+  builder.create<spirv::EntryPointOp>(
+      funcOp.getLoc(),
+      builder.getI32IntegerAttr(
+          static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
+      builder.getSymbolRefAttr(newFuncOp.getName()), interface);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Operation conversion
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Convert return -> spv.Return.
+class ReturnToSPIRVConversion : public ConversionPattern {
+public:
+  ReturnToSPIRVConversion(MLIRContext *context)
+      : ConversionPattern(ReturnOp::getOperationName(), 1, context) {}
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op->getNumOperands()) {
+      return matchFailure();
+    }
+    rewriter.replaceOpWithNewOp<spirv::ReturnOp>(op);
+    return matchSuccess();
+  }
+};
+
+} // namespace
+
+namespace {
+/// Import the Standard Ops to SPIR-V Patterns.
+#include "StandardToSPIRV.cpp.inc"
+} // namespace
+
+namespace mlir {
+void populateStandardToSPIRVPatterns(MLIRContext *context,
+                                     OwningRewritePatternList &patterns) {
+  populateWithGenerated(context, &patterns);
+  // Add the return op conversion.
+  RewriteListBuilder<ReturnToSPIRVConversion>::build(patterns, context);
+}
+} // namespace mlir
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
new file mode 100644 (file)
index 0000000..c2652be
--- /dev/null
@@ -0,0 +1,56 @@
+//===- ConvertStandardToSPIRVPass.cpp - Convert Std Ops to SPIR-V Ops -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert MLIR standard ops into the SPIR-V
+// ops. It does not legalize FuncOps.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+
+using namespace mlir;
+
+namespace {
+/// A pass converting MLIR Standard operations into the SPIR-V dialect.
+class ConvertStandardToSPIRVPass
+    : public ModulePass<ConvertStandardToSPIRVPass> {
+  void runOnModule() override;
+};
+} // namespace
+
+void ConvertStandardToSPIRVPass::runOnModule() {
+  OwningRewritePatternList patterns;
+  auto module = getModule();
+
+  populateStandardToSPIRVPatterns(module.getContext(), patterns);
+  ConversionTarget target(*(module.getContext()));
+  target.addLegalDialect<spirv::SPIRVDialect>();
+  target.addLegalOp<FuncOp>();
+
+  if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
+    return signalPassFailure();
+  }
+}
+
+ModulePassBase *mlir::spirv::createConvertStandardToSPIRVPass() {
+  return new ConvertStandardToSPIRVPass();
+}
+
+static PassRegistration<ConvertStandardToSPIRVPass>
+    pass("convert-std-to-spirv", "Convert Standard Ops to SPIR-V dialect");
@@ -1,4 +1,4 @@
-//==- StdOpsToSPIRVConversion.td - Std Ops to SPIR-V Patterns *- tablegen -*==//
+//==- StandardToSPIRV.td - Standard Ops to SPIR-V Patterns ---*- tablegen -*==//
 
 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,13 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Defines Patterns to lower standard ops to SPIR-V
+// Defines Patterns to lower standard ops to SPIR-V.
 //
 //===----------------------------------------------------------------------===//
 
-#ifdef STANDARD_OPS_TO_SPIRV
+#ifdef MLIR_CONVERSION_STANDARDTOSPIRV_TD
 #else
-#define STANDARD_OPS_TO_SPIRV
+#define MLIR_CONVERSION_STANDARDTOSPIRV_TD
 
 #ifdef STANDARD_OPS
 #else
@@ -45,4 +45,4 @@ multiclass BinaryOpPattern<Op src, SPV_Op tgt> {
 
 defm : BinaryOpPattern<MulFOp, SPV_FMulOp>;
 
-#endif // STANDARD_OPS_TO_SPIRV
+#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.cpp b/mlir/lib/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.cpp
deleted file mode 100644 (file)
index 45213bb..0000000
+++ /dev/null
@@ -1,62 +0,0 @@
-//===- StdOpsToSPIRVLowering.cpp - Std Ops to SPIR-V dialect conversion ---===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements a pass to convert MLIR standard ops into the SPIR-V
-// dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.h"
-#include "mlir/Dialect/SPIRV/Passes.h"
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/StandardOps/Ops.h"
-
-using namespace mlir;
-
-namespace {
-/// A pass converting MLIR Standard operations into the SPIR-V dialect.
-class StdOpsToSPIRVConversionPass
-    : public FunctionPass<StdOpsToSPIRVConversionPass> {
-  void runOnFunction() override;
-};
-
-#include "StdOpsToSPIRVConversion.cpp.inc"
-} // namespace
-
-namespace mlir {
-void populateStdOpsToSPIRVPatterns(MLIRContext *context,
-                                   OwningRewritePatternList &patterns) {
-  populateWithGenerated(context, &patterns);
-}
-} // namespace mlir
-
-void StdOpsToSPIRVConversionPass::runOnFunction() {
-  OwningRewritePatternList patterns;
-  auto func = getFunction();
-
-  populateStdOpsToSPIRVPatterns(func.getContext(), patterns);
-  applyPatternsGreedily(func, std::move(patterns));
-}
-
-FunctionPassBase *mlir::spirv::createStdOpsToSPIRVConversionPass() {
-  return new StdOpsToSPIRVConversionPass();
-}
-
-static PassRegistration<StdOpsToSPIRVConversionPass>
-    pass("std-to-spirv", "Convert Standard Ops to SPIR-V dialect");
index f9ddc47..622bb22 100644 (file)
@@ -72,6 +72,33 @@ static bool parseNumberX(StringRef &spec, int64_t &number) {
   return true;
 }
 
+static bool isValidSPIRVScalarType(Type type) {
+  if (type.isa<FloatType>()) {
+    return !type.isBF16();
+  }
+  if (auto intType = type.dyn_cast<IntegerType>()) {
+    return llvm::is_contained(llvm::ArrayRef<unsigned>({1, 8, 16, 32, 64}),
+                              intType.getWidth());
+  }
+  return false;
+}
+
+bool SPIRVDialect::isValidSPIRVType(Type type) const {
+  // Allow SPIR-V dialect types
+  if (&type.getDialect() == this) {
+    return true;
+  }
+  if (isValidSPIRVScalarType(type)) {
+    return true;
+  }
+  if (auto vectorType = type.dyn_cast<VectorType>()) {
+    return (isValidSPIRVScalarType(vectorType.getElementType()) &&
+            vectorType.getNumElements() >= 2 &&
+            vectorType.getNumElements() <= 4);
+  }
+  return false;
+}
+
 static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec,
                                Location loc) {
   spec = spec.trim();
@@ -104,6 +131,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec,
       emitError(loc, "only 1-D vector allowed but found ") << t;
       return Type();
     }
+    if (t.getNumElements() > 4) {
+      emitError(loc,
+                "vector length has to be less than or equal to 4 but found ")
+          << t.getNumElements();
+      return Type();
+    }
   } else {
     emitError(loc, "cannot use ") << type << " to compose SPIR-V types";
     return Type();
index 76b26e0..ae5752a 100644 (file)
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
 
-namespace mlir {
-namespace spirv {
-#include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
-} // namespace spirv
-} // namespace mlir
-
 using namespace mlir;
 
 // TODO(antiagainst): generate these strings using ODS.
@@ -550,20 +544,8 @@ static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
       return entryPointOp.emitOpError("interface operands to entry point must "
                                       "be generated from a variable op");
     }
-    // Before version 1.4 the variables can only have storage_class of Input or
-    // Output.
-    // TODO: Add versioning so that this can be avoided for 1.4
-    auto storageClass =
-        interface->getType().cast<spirv::PointerType>().getStorageClass();
-    switch (storageClass) {
-    case spirv::StorageClass::Input:
-    case spirv::StorageClass::Output:
-      break;
-    default:
-      return entryPointOp.emitOpError("invalid storage class '")
-             << stringifyStorageClass(storageClass)
-             << "' for interface variables";
-    }
+    // TODO:  Before version 1.4 the variables can only have storage_class of
+    // Input or Output. That needs to be verified.
   }
   return success();
 }
@@ -674,6 +656,22 @@ void spirv::ModuleOp::build(Builder *builder, OperationState *state) {
   ensureModuleEnd(state->addRegion(), *builder, state->location);
 }
 
+void spirv::ModuleOp::build(Builder *builder, OperationState *state,
+                            IntegerAttr addressing_model,
+                            IntegerAttr memory_model, ArrayAttr capabilities,
+                            ArrayAttr extensions,
+                            ArrayAttr extended_instruction_sets) {
+  state->addAttribute("addressing_model", addressing_model);
+  state->addAttribute("memory_model", memory_model);
+  if (capabilities)
+    state->addAttribute("capabilities", capabilities);
+  if (extensions)
+    state->addAttribute("extensions", extensions);
+  if (extended_instruction_sets)
+    state->addAttribute("extended_instruction_sets", extended_instruction_sets);
+  ensureModuleEnd(state->addRegion(), *builder, state->location);
+}
+
 static ParseResult parseModuleOp(OpAsmParser *parser, OperationState *state) {
   Region *body = state->addRegion();
 
diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
new file mode 100644 (file)
index 0000000..671a38b
--- /dev/null
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
+
+// CHECK:       spv.module "Logical" "VulkanKHR" {
+// CHECK-NEXT:    [[VAR1:%.*]] = spv.Variable bind(0, 0) : !spv.ptr<f32, StorageBuffer>
+// CHECK-NEXT:    [[VAR2:%.*]] = spv.Variable bind(0, 1) : !spv.ptr<!spv.array<12 x f32>, StorageBuffer>
+// CHECK-NEXT:    func @kernel_1
+// CHECK-NEXT:      spv.Return
+// CHECK:       spv.EntryPoint "GLCompute" @kernel_1, [[VAR1]], [[VAR2]]
+func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32, 1>)
+    attributes { gpu.kernel } {
+  return
+}
+
+func @foo() {
+  %0 = "op"() : () -> (f32)
+  %1 = "op"() : () -> (memref<12xf32, 1>)
+  %cst = constant 1 : index
+  "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = @kernel_1 }
+      : (index, index, index, index, index, index, f32, memref<12xf32, 1>) -> ()
+  return
+}
\ No newline at end of file
index c8771ec..ebc5b6c 100644 (file)
@@ -345,17 +345,6 @@ spv.module "Logical" "VulkanKHR" {
 
 // -----
 
-spv.module "Logical" "VulkanKHR" {
-   %2 = spv.Variable : !spv.ptr<f32, Workgroup>
-   func @do_nothing() -> () {
-     spv.Return
-   }
-   // expected-error @+1 {{'spv.EntryPoint' op invalid storage class 'Workgroup'}}
-   spv.EntryPoint "GLCompute" @do_nothing, %2 : !spv.ptr<f32, Workgroup>
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // spv.ExecutionMode
 //===----------------------------------------------------------------------===//
index fc59d68..4b55551 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -std-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -convert-std-to-spirv %s -o - | FileCheck %s
 
 // CHECK-LABEL: @fmul_scalar
 func @fmul_scalar(%arg: f32) -> f32 {
index 80b5499..0c17720 100644 (file)
@@ -420,6 +420,8 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
   llvm::emitSourceFileHeader("SPIR-V Op Utilites", os);
 
   auto defs = recordKeeper.getAllDerivedDefinitions("I32EnumAttr");
+  os << "#ifndef SPIRV_OP_UTILS_H_\n";
+  os << "#define SPIRV_OP_UTILS_H_\n";
   emitEnumGetAttrNameFnDecl(os);
   emitEnumGetSymbolizeFnDecl(os);
   for (const auto *def : defs) {
@@ -427,6 +429,7 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
     emitEnumGetAttrNameFnDefn(enumAttr, os);
     emitEnumGetSymbolizeFnDefn(enumAttr, os);
   }
+  os << "#endif // SPIRV_OP_UTILS_H\n";
   return false;
 }