NFC: Refactor Dialect Conversion targeting SPIR-V.
authorMahesh Ravishankar <ravishankarm@google.com>
Thu, 14 Nov 2019 20:31:32 +0000 (12:31 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 14 Nov 2019 20:34:54 +0000 (12:34 -0800)
Refactoring the conversion from StandardOps/GPU dialect to SPIR-V
dialect:
1) Move the SPIRVTypeConversion and SPIRVOpLowering class into SPIR-V
   dialect.
2) Add header files that expose functions to add patterns for the
   dialects to SPIR-V lowering, as well as a pass that does the
   dialect to SPIR-V lowering.
3) Make SPIRVOpLowering derive from OpLowering class.
PiperOrigin-RevId: 280486871

15 files changed:
mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h [new file with mode: 0644]
mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h [new file with mode: 0644]
mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h [new file with mode: 0644]
mlir/include/mlir/Dialect/SPIRV/Passes.h
mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h [new file with mode: 0644]
mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp [moved from mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp with 54% similarity]
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp [new file with mode: 0644]
mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
mlir/lib/Dialect/SPIRV/CMakeLists.txt
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/CMakeLists.txt

diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h
new file mode 100644 (file)
index 0000000..f617986
--- /dev/null
@@ -0,0 +1,36 @@
+//===- ConvertGPUToSPIRV.h - GPU Ops to SPIR-V dialect patterns ----C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides patterns for lowering GPU Ops to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRV_H
+#define MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRV_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class SPIRVTypeConverter;
+/// Appends to a pattern list additional patterns for translating GPU Ops to
+/// SPIR-V ops.
+void populateGPUToSPIRVPatterns(MLIRContext *context,
+                                SPIRVTypeConverter &typeConverter,
+                                OwningRewritePatternList &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h b/mlir/include/mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h
new file mode 100644 (file)
index 0000000..be8cad2
--- /dev/null
@@ -0,0 +1,36 @@
+//===- ConvertGPUToSPIRVPass.h - GPU to SPIR-V conversion pass --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides a pass to convert GPU ops to SPIRV ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H
+#define MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H
+
+#include <memory>
+
+namespace mlir {
+
+class ModuleOp;
+template <typename T> class OpPassBase;
+
+/// Pass to convert GPU Ops to SPIR-V ops.
+std::unique_ptr<OpPassBase<ModuleOp>> createConvertGPUToSPIRVPass();
+
+} // namespace mlir
+#endif // MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H
index 63e63cf..69db817 100644 (file)
 // limitations under the License.
 // =============================================================================
 //
-// Provides type converters and patterns to convert from standard types/ops to
-// SPIR-V types and operations. Also provides utilities and base classes to use
-// while targeting SPIR-V dialect.
+// Provides patterns to lower StandardOps to SPIR-V dialect.
 //
 //===----------------------------------------------------------------------===//
 
 #ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
 #define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
 
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/Support/StringExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
-
-class LoadOp;
-class ReturnOp;
-class StoreOp;
-
-/// Type conversion from Standard Types to SPIR-V Types.
-class SPIRVBasicTypeConverter : public TypeConverter {
-public:
-  /// Converts types to SPIR-V supported types.
-  virtual Type convertType(Type t);
-};
-
-/// Converts a function type according to the requirements of a SPIR-V entry
-/// function. The arguments need to be converted to spv.Variables of spv.ptr
-/// types so that they could be bound by the runtime.
-class SPIRVTypeConverter final : public TypeConverter {
-public:
-  explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
-      : basicTypeConverter(basicTypeConverter) {}
-
-  /// Converts types to SPIR-V types using the basic type converter.
-  Type convertType(Type t) override;
-
-  /// Gets the basic type converter.
-  SPIRVBasicTypeConverter *getBasicTypeConverter() const {
-    return basicTypeConverter;
-  }
-
-private:
-  SPIRVBasicTypeConverter *basicTypeConverter;
-};
-
-/// Base class to define a conversion pattern to translate Ops into SPIR-V.
-template <typename OpTy> class SPIRVOpLowering : public ConversionPattern {
-public:
-  SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter)
-      : ConversionPattern(OpTy::getOperationName(), 1, context),
-        typeConverter(typeConverter) {}
-
-protected:
-  /// Gets the global variable associated with a builtin and add
-  /// it if it doesnt exist.
-  Value *loadFromBuiltinVariable(Operation *op, spirv::BuiltIn builtin,
-                                 ConversionPatternRewriter &rewriter) const {
-    auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
-    if (!moduleOp) {
-      op->emitError("expected operation to be within a SPIR-V module");
-      return nullptr;
-    }
-    auto varOp =
-        getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, rewriter);
-    auto ptr = rewriter
-                   .create<spirv::AddressOfOp>(op->getLoc(), varOp.type(),
-                                               rewriter.getSymbolRefAttr(varOp))
-                   .pointer();
-    return rewriter.create<spirv::LoadOp>(
-        op->getLoc(),
-        ptr->getType().template cast<spirv::PointerType>().getPointeeType(),
-        ptr, /*memory_access =*/nullptr, /*alignment =*/nullptr);
-  }
-
-  /// Type lowering class.
-  SPIRVTypeConverter &typeConverter;
-
-private:
-  /// Look through all global variables in `moduleOp` and check if there is a
-  /// spv.globalVariable that has the same `builtin` attribute.
-  spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
-                                             spirv::BuiltIn builtin) const {
-    for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
-      if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(convertToSnakeCase(
-              stringifyDecoration(spirv::Decoration::BuiltIn)))) {
-        auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
-        if (varBuiltIn && varBuiltIn.getValue() == builtin) {
-          return varOp;
-        }
-      }
-    }
-    return nullptr;
-  }
-
-  /// Gets name of global variable for a buitlin.
-  std::string getBuiltinVarName(spirv::BuiltIn builtin) const {
-    return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() +
-           "__";
-  }
-
-  /// Gets or inserts a global variable for a builtin within a module.
-  spirv::GlobalVariableOp
-  getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
-                             spirv::BuiltIn builtin,
-                             ConversionPatternRewriter &builder) const {
-    if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
-      return varOp;
-    }
-    auto ip = builder.saveInsertionPoint();
-    builder.setInsertionPointToStart(&moduleOp.getBlock());
-    auto name = getBuiltinVarName(builtin);
-    spirv::GlobalVariableOp newVarOp;
-    switch (builtin) {
-    case spirv::BuiltIn::NumWorkgroups:
-    case spirv::BuiltIn::WorkgroupSize:
-    case spirv::BuiltIn::WorkgroupId:
-    case spirv::BuiltIn::LocalInvocationId:
-    case spirv::BuiltIn::GlobalInvocationId: {
-      auto ptrType = spirv::PointerType::get(
-          VectorType::get({3}, builder.getIntegerType(32)),
-          spirv::StorageClass::Input);
-      newVarOp = builder.create<spirv::GlobalVariableOp>(
-          loc, TypeAttr::get(ptrType), builder.getStringAttr(name), nullptr);
-      newVarOp.setAttr(
-          convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)),
-          builder.getStringAttr(stringifyBuiltIn(builtin)));
-      break;
-    }
-    default:
-      emitError(loc, "unimplemented builtin variable generation for ")
-          << stringifyBuiltIn(builtin);
-    }
-    builder.restoreInsertionPoint(ip);
-    return newVarOp;
-  }
-};
-
-/// Legalizes a function as a non-entry function.
-LogicalResult lowerFunction(FuncOp funcOp, SPIRVTypeConverter *typeConverter,
-                            ConversionPatternRewriter &rewriter,
-                            FuncOp &newFuncOp);
-
-/// Legalizes a function as an entry function.
-LogicalResult lowerAsEntryFunction(FuncOp funcOp,
-                                   SPIRVTypeConverter *typeConverter,
-                                   ConversionPatternRewriter &rewriter,
-                                   FuncOp &newFuncOp);
-
-/// Finalizes entry function legalization. Inserts the spv.EntryPoint and
-/// spv.ExecutionMode ops.
-LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder);
-
+class SPIRVTypeConverter;
 /// Appends to a pattern list additional patterns for translating StandardOps to
 /// SPIR-V ops.
 void populateStandardToSPIRVPatterns(MLIRContext *context,
+                                     SPIRVTypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns);
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h b/mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h
new file mode 100644 (file)
index 0000000..1bf4977
--- /dev/null
@@ -0,0 +1,32 @@
+//===- ConvertStandardToSPIRVPass.h - StdOps to SPIR-V pass -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides a pass to lower from StandardOps to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRVPASS_H
+#define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRVPASS_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+/// Pass to convert StandardOps to SPIR-V ops.
+std::unique_ptr<OpPassBase<ModuleOp>> createConvertStandardToSPIRVPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRVPASS_H
index 4969935..d245ef0 100644 (file)
@@ -27,9 +27,6 @@
 namespace mlir {
 namespace spirv {
 
-// Creates a module pass that converts standard ops to SPIR-V ops.
-std::unique_ptr<OpPassBase<mlir::ModuleOp>> createConvertStandardToSPIRVPass();
-
 // Creates a module pass that converts composite types used by objects in the
 // StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage
 // classes with layout information.
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
new file mode 100644 (file)
index 0000000..56182ae
--- /dev/null
@@ -0,0 +1,92 @@
+//===- SPIRVLowering.h - SPIR-V lowering utilities  -------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines, utilities and base classes to use while targeting SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
+#define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Support/StringExtras.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir {
+
+/// Type conversion from Standard Types to SPIR-V Types.
+class SPIRVBasicTypeConverter : public TypeConverter {
+public:
+  /// Converts types to SPIR-V supported types.
+  virtual Type convertType(Type t);
+};
+
+/// Converts a function type according to the requirements of a SPIR-V entry
+/// function. The arguments need to be converted to spv.GlobalVariables of
+/// spv.ptr types so that they could be bound by the runtime.
+class SPIRVTypeConverter final : public TypeConverter {
+public:
+  explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
+      : basicTypeConverter(basicTypeConverter) {}
+
+  /// Converts types to SPIR-V types using the basic type converter.
+  Type convertType(Type t) override;
+
+  /// Gets the basic type converter.
+  Type convertBasicType(Type t) { return basicTypeConverter->convertType(t); }
+
+private:
+  SPIRVBasicTypeConverter *basicTypeConverter;
+};
+
+/// Base class to define a conversion pattern to translate Ops into SPIR-V.
+template <typename SourceOp>
+class SPIRVOpLowering : public OpConversionPattern<SourceOp> {
+public:
+  SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter,
+                  PatternBenefit benefit = 1)
+      : OpConversionPattern<SourceOp>(context, benefit),
+        typeConverter(typeConverter) {}
+
+protected:
+  /// Type lowering class.
+  SPIRVTypeConverter &typeConverter;
+
+private:
+};
+
+namespace spirv {
+/// Returns a value that represents a builtin variable value within the SPIR-V
+/// module.
+Value *getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin,
+                               OpBuilder &builder);
+
+/// Legalizes a function as an entry function.
+LogicalResult lowerAsEntryFunction(FuncOp funcOp,
+                                   SPIRVTypeConverter *typeConverter,
+                                   ConversionPatternRewriter &rewriter,
+                                   FuncOp &newFuncOp);
+
+/// Finalizes entry function legalization. Inserts the spv.EntryPoint and
+/// spv.ExecutionMode ops.
+LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder);
+
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
index 8426420..a562439 100644 (file)
@@ -1,5 +1,6 @@
 add_llvm_library(MLIRGPUtoSPIRVTransforms
-  GPUToSPIRV.cpp
+  ConvertGPUToSPIRV.cpp
+  ConvertGPUToSPIRVPass.cpp
   )
 
 target_link_libraries(MLIRGPUtoSPIRVTransforms
@@ -8,6 +9,6 @@ target_link_libraries(MLIRGPUtoSPIRVTransforms
   MLIRPass
   MLIRSPIRV
   MLIRStandardOps
-  MLIRSPIRVConversion
+  MLIRStandardToSPIRVTransforms
   MLIRTransforms
   )
@@ -1,4 +1,4 @@
-//===- GPUToSPIRV.cpp - MLIR SPIR-V lowering passes -----------------------===//
+//===- ConvertGPUToSPIRV.cpp - Convert GPU ops to SPIR-V dialect ----------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
 // limitations under the License.
 // =============================================================================
 //
-// This file implements a pass to convert a kernel function in the GPU Dialect
-// into a spv.module operation
+// This file implements the conversion patterns from GPU ops to SPIR-V dialect.
 //
 //===----------------------------------------------------------------------===//
-#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/Pass/Pass.h"
 
 using namespace mlir;
 
@@ -36,19 +34,19 @@ public:
   using SPIRVOpLowering<loop::ForOp>::SPIRVOpLowering;
 
   PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+  matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
 /// builin variables.
-template <typename OpTy, spirv::BuiltIn builtin>
-class LaunchConfigConversion : public SPIRVOpLowering<OpTy> {
+template <typename SourceOp, spirv::BuiltIn builtin>
+class LaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
 public:
-  using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
+  using SPIRVOpLowering<SourceOp>::SPIRVOpLowering;
 
   PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+  matchAndRewrite(SourceOp op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
@@ -59,23 +57,22 @@ public:
   using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
 
   PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+  matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
 } // namespace
 
 PatternMatchResult
-ForOpConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
                                  ConversionPatternRewriter &rewriter) const {
   // loop::ForOp can be lowered to the structured control flow represented by
   // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
   // latch and the merge block the exit block. The resulting spirv::LoopOp has a
   // single back edge from the continue to header block, and a single exit from
   // header to merge.
-  auto forOp = cast<loop::ForOp>(op);
   loop::ForOpOperandAdaptor forOperands(operands);
-  auto loc = op->getLoc();
+  auto loc = forOp.getLoc();
   auto loopControl = rewriter.getI32IntegerAttr(
       static_cast<uint32_t>(spirv::LoopControl::None));
   auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
@@ -135,11 +132,12 @@ ForOpConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   return matchSuccess();
 }
 
-template <typename OpTy, spirv::BuiltIn builtin>
-PatternMatchResult LaunchConfigConversion<OpTy, builtin>::matchAndRewrite(
-    Operation *op, ArrayRef<Value *> operands,
+template <typename SourceOp, spirv::BuiltIn builtin>
+PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
+    SourceOp op, ArrayRef<Value *> operands,
     ConversionPatternRewriter &rewriter) const {
-  auto dimAttr = op->getAttrOfType<StringAttr>("dimension");
+  auto dimAttr =
+      op.getOperation()->template getAttrOfType<StringAttr>("dimension");
   if (!dimAttr) {
     return this->matchFailure();
   }
@@ -155,7 +153,7 @@ PatternMatchResult LaunchConfigConversion<OpTy, builtin>::matchAndRewrite(
   }
 
   // SPIR-V invocation builtin variables are a vector of type <3xi32>
-  auto spirvBuiltin = this->loadFromBuiltinVariable(op, builtin, rewriter);
+  auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
   rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
       op, rewriter.getIntegerType(32), spirvBuiltin,
       rewriter.getI32ArrayAttr({index}));
@@ -163,72 +161,24 @@ PatternMatchResult LaunchConfigConversion<OpTy, builtin>::matchAndRewrite(
 }
 
 PatternMatchResult
-KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
                                     ConversionPatternRewriter &rewriter) const {
-  auto funcOp = cast<FuncOp>(op);
   FuncOp newFuncOp;
   if (!gpu::GPUDialect::isKernel(funcOp)) {
-    return succeeded(lowerFunction(funcOp, &typeConverter, rewriter, newFuncOp))
-               ? matchSuccess()
-               : matchFailure();
+    return matchFailure();
   }
 
-  if (failed(
-          lowerAsEntryFunction(funcOp, &typeConverter, rewriter, newFuncOp))) {
+  if (failed(spirv::lowerAsEntryFunction(funcOp, &typeConverter, rewriter,
+                                         newFuncOp))) {
     return matchFailure();
   }
   return matchSuccess();
 }
 
-namespace {
-/// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions
-/// that have the "gpu.kernel" attribute, i.e. those functions that are
-/// referenced in gpu::LaunchKernelOp operations. For each such function
-///
-/// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp
-/// (the original function is still needed by the gpu::LaunchKernelOp, so cannot
-/// replace it).
-///
-/// 2) Lower the body of the spirv::ModuleOp.
-class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
-  void runOnModule() override;
-};
-} // namespace
-
-void GPUToSPIRVPass::runOnModule() {
-  auto context = &getContext();
-  auto module = getModule();
-
-  SmallVector<Operation *, 4> spirvModules;
-  module.walk([&module, &spirvModules](FuncOp funcOp) {
-    if (gpu::GPUDialect::isKernel(funcOp)) {
-      OpBuilder builder(module.getBodyRegion());
-      // Create a new spirv::ModuleOp for this function, and clone the
-      // function into it.
-      // TODO : Generalize this to account for different extensions,
-      // capabilities, extended_instruction_sets, other addressing models
-      // and memory models.
-      auto spvModule = builder.create<spirv::ModuleOp>(
-          funcOp.getLoc(),
-          builder.getI32IntegerAttr(
-              static_cast<int32_t>(spirv::AddressingModel::Logical)),
-          builder.getI32IntegerAttr(
-              static_cast<int32_t>(spirv::MemoryModel::GLSL450)),
-          builder.getStrArrayAttr(
-              spirv::stringifyCapability(spirv::Capability::Shader)),
-          builder.getStrArrayAttr(spirv::stringifyExtension(
-              spirv::Extension::SPV_KHR_storage_buffer_storage_class)));
-      // Hardwire the capability to be Shader.
-      OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
-      moduleBuilder.clone(*funcOp.getOperation());
-      spirvModules.push_back(spvModule);
-    }
-  });
-
-  /// Dialect conversion to lower the functions with the spirv::ModuleOps.
-  SPIRVBasicTypeConverter basicTypeConverter;
-  SPIRVTypeConverter typeConverter(&basicTypeConverter);
-  OwningRewritePatternList patterns;
+namespace mlir {
+void populateGPUToSPIRVPatterns(MLIRContext *context,
+                                SPIRVTypeConverter &typeConverter,
+                                OwningRewritePatternList &patterns) {
   patterns.insert<
       ForOpConversion, KernelFnConversion,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
@@ -237,44 +187,5 @@ void GPUToSPIRVPass::runOnModule() {
       LaunchConfigConversion<gpu::ThreadIdOp,
                              spirv::BuiltIn::LocalInvocationId>>(context,
                                                                  typeConverter);
-  populateStandardToSPIRVPatterns(context, patterns);
-
-  ConversionTarget target(*context);
-  target.addLegalDialect<spirv::SPIRVDialect>();
-  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
-    // TODO(ravishankarm) : Currently lowering does not support handling
-    // function conversion of non-kernel functions. This is to be added.
-
-    // For kernel functions, verify that the signature is void(void).
-    return gpu::GPUDialect::isKernel(op) && op.getNumResults() == 0 &&
-           op.getNumArguments() == 0;
-  });
-
-  if (failed(applyFullConversion(spirvModules, target, patterns,
-                                 &typeConverter))) {
-    return signalPassFailure();
-  }
-
-  // After the SPIR-V modules have been generated, some finalization is needed
-  // for the entry functions. For example, adding spv.EntryPoint op,
-  // spv.ExecutionMode op, etc.
-  for (auto *spvModule : spirvModules) {
-    for (auto op :
-         cast<spirv::ModuleOp>(spvModule).getBlock().getOps<FuncOp>()) {
-      if (gpu::GPUDialect::isKernel(op)) {
-        OpBuilder builder(op.getContext());
-        builder.setInsertionPointAfter(op);
-        if (failed(finalizeEntryFunction(op, builder))) {
-          return signalPassFailure();
-        }
-        op.getOperation()->removeAttr(Identifier::get(
-            gpu::GPUDialect::getKernelFuncAttrName(), op.getContext()));
-      }
-    }
-  }
 }
-
-OpPassBase<ModuleOp> *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); }
-
-static PassRegistration<GPUToSPIRVPass>
-    pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
+} // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
new file mode 100644 (file)
index 0000000..5fa5106
--- /dev/null
@@ -0,0 +1,126 @@
+//===- ConvertGPUToSPIRVPass.cpp - GPU to SPIR-V dialect lowering passes --===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert a kernel function in the GPU Dialect
+// into a spv.module operation
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions
+/// that have the "gpu.kernel" attribute, i.e. those functions that are
+/// referenced in gpu::LaunchKernelOp operations. For each such function
+///
+/// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp
+/// (the original function is still needed by the gpu::LaunchKernelOp, so cannot
+/// replace it).
+///
+/// 2) Lower the body of the spirv::ModuleOp.
+class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
+  void runOnModule() override;
+};
+} // namespace
+
+void GPUToSPIRVPass::runOnModule() {
+  auto context = &getContext();
+  auto module = getModule();
+
+  SmallVector<Operation *, 4> spirvModules;
+  module.walk([&module, &spirvModules](FuncOp funcOp) {
+    if (!gpu::GPUDialect::isKernel(funcOp)) {
+      return;
+    }
+    OpBuilder builder(module.getBodyRegion());
+    // Create a new spirv::ModuleOp for this function, and clone the
+    // function into it.
+    // TODO : Generalize this to account for different extensions,
+    // capabilities, extended_instruction_sets, other addressing models
+    // and memory models.
+    auto spvModule = builder.create<spirv::ModuleOp>(
+        funcOp.getLoc(),
+        builder.getI32IntegerAttr(
+            static_cast<int32_t>(spirv::AddressingModel::Logical)),
+        builder.getI32IntegerAttr(
+            static_cast<int32_t>(spirv::MemoryModel::GLSL450)),
+        builder.getStrArrayAttr(
+            spirv::stringifyCapability(spirv::Capability::Shader)),
+        builder.getStrArrayAttr(spirv::stringifyExtension(
+            spirv::Extension::SPV_KHR_storage_buffer_storage_class)));
+    // Hardwire the capability to be Shader.
+    OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
+    moduleBuilder.clone(*funcOp.getOperation());
+    spirvModules.push_back(spvModule);
+  });
+
+  /// Dialect conversion to lower the functions with the spirv::ModuleOps.
+  SPIRVBasicTypeConverter basicTypeConverter;
+  SPIRVTypeConverter typeConverter(&basicTypeConverter);
+  OwningRewritePatternList patterns;
+  populateGPUToSPIRVPatterns(context, typeConverter, patterns);
+  populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+
+  ConversionTarget target(*context);
+  target.addLegalDialect<spirv::SPIRVDialect>();
+  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+    // TODO(ravishankarm) : Currently lowering does not support handling
+    // function conversion of non-kernel functions. This is to be added.
+
+    // For kernel functions, verify that the signature is void(void).
+    return gpu::GPUDialect::isKernel(op) && op.getNumResults() == 0 &&
+           op.getNumArguments() == 0;
+  });
+
+  if (failed(applyFullConversion(spirvModules, target, patterns,
+                                 &typeConverter))) {
+    return signalPassFailure();
+  }
+
+  // After the SPIR-V modules have been generated, some finalization is needed
+  // for the entry functions. For example, adding spv.EntryPoint op,
+  // spv.ExecutionMode op, etc.
+  for (auto *spvModule : spirvModules) {
+    for (auto op :
+         cast<spirv::ModuleOp>(spvModule).getBlock().getOps<FuncOp>()) {
+      if (gpu::GPUDialect::isKernel(op)) {
+        OpBuilder builder(op.getContext());
+        builder.setInsertionPointAfter(op);
+        if (failed(spirv::finalizeEntryFunction(op, builder))) {
+          return signalPassFailure();
+        }
+        op.getOperation()->removeAttr(Identifier::get(
+            gpu::GPUDialect::getKernelFuncAttrName(), op.getContext()));
+      }
+    }
+  }
+}
+
+OpPassBase<ModuleOp> *createConvertGPUToSPIRVPass() {
+  return new GPUToSPIRVPass();
+}
+
+static PassRegistration<GPUToSPIRVPass>
+    pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
index be53112..3512162 100644 (file)
@@ -2,7 +2,7 @@ set(LLVM_TARGET_DEFINITIONS StandardToSPIRV.td)
 mlir_tablegen(StandardToSPIRV.cpp.inc -gen-rewriters)
 add_public_tablegen_target(MLIRStandardToSPIRVIncGen)
 
-add_llvm_library(MLIRSPIRVConversion
+add_llvm_library(MLIRStandardToSPIRVTransforms
   ConvertStandardToSPIRV.cpp
   ConvertStandardToSPIRVPass.cpp
 
@@ -11,10 +11,10 @@ add_llvm_library(MLIRSPIRVConversion
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
   )
 
-add_dependencies(MLIRSPIRVConversion
+add_dependencies(MLIRStandardToSPIRVTransforms
   MLIRStandardToSPIRVIncGen)
 
-target_link_libraries(MLIRSPIRVConversion
+target_link_libraries(MLIRStandardToSPIRVTransforms
   MLIRIR
   MLIRPass
   MLIRSPIRV
index 6bb9dea..2fd6e75 100644 (file)
 // limitations under the License.
 // =============================================================================
 //
-// This file implements a pass to convert MLIR standard and builtin dialects
-// into the SPIR-V dialect.
+// This file implements patterns to convert Standard Ops to the SPIR-V dialect.
 //
 //===----------------------------------------------------------------------===//
-#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
 #include "mlir/Dialect/SPIRV/LayoutUtils.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "llvm/ADT/SetVector.h"
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
-// Type Conversion
-//===----------------------------------------------------------------------===//
-
-static Type convertIndexType(MLIRContext *context) {
-  // Convert to 32-bit integers for now. Might need a way to control this in
-  // future.
-  // TODO(ravishankarm): It is porbably better to make it 64-bit integers. To
-  // this some support is needed in SPIR-V dialect for Conversion
-  // instructions. The Vulkan spec requires the builtins like
-  // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
-  // SExtended to 64-bit for index computations.
-  return IntegerType::get(32, context);
-}
-
-static Type convertIndexType(IndexType t) {
-  return convertIndexType(t.getContext());
-}
-
-static Type basicTypeConversion(Type t) {
-  // Check if the type is SPIR-V supported. If so return the type.
-  if (spirv::SPIRVDialect::isValidType(t)) {
-    return t;
-  }
-
-  if (auto indexType = t.dyn_cast<IndexType>()) {
-    return convertIndexType(indexType);
-  }
-
-  if (auto memRefType = t.dyn_cast<MemRefType>()) {
-    auto elementType = memRefType.getElementType();
-    if (memRefType.hasStaticShape()) {
-      // Convert to a multi-dimensional spv.array if size is known.
-      for (auto size : reverse(memRefType.getShape())) {
-        elementType = spirv::ArrayType::get(elementType, size);
-      }
-      return spirv::PointerType::get(elementType,
-                                     spirv::StorageClass::StorageBuffer);
-    } else {
-      // Vulkan SPIR-V validation rules require runtime array type to be the
-      // last member of a struct.
-      return spirv::PointerType::get(spirv::RuntimeArrayType::get(elementType),
-                                     spirv::StorageClass::StorageBuffer);
-    }
-  }
-  return Type();
-}
-
-Type SPIRVBasicTypeConverter::convertType(Type t) {
-  return basicTypeConversion(t);
-}
-
-//===----------------------------------------------------------------------===//
-// Entry Function signature Conversion
-//===----------------------------------------------------------------------===//
-
-Type getLayoutDecoratedType(spirv::StructType type) {
-  VulkanLayoutUtils::Size size = 0, alignment = 0;
-  return VulkanLayoutUtils::decorateType(type, size, alignment);
-}
-
-/// Generates the type of variable given the type of object.
-static Type getGlobalVarTypeForEntryFnArg(Type t) {
-  auto convertedType = basicTypeConversion(t);
-  if (auto ptrType = convertedType.dyn_cast<spirv::PointerType>()) {
-    if (!ptrType.getPointeeType().isa<spirv::StructType>()) {
-      return spirv::PointerType::get(
-          getLayoutDecoratedType(
-              spirv::StructType::get(ptrType.getPointeeType())),
-          ptrType.getStorageClass());
-    }
-  } else {
-    return spirv::PointerType::get(
-        getLayoutDecoratedType(spirv::StructType::get(convertedType)),
-        spirv::StorageClass::StorageBuffer);
-  }
-  return convertedType;
-}
-
-Type SPIRVTypeConverter::convertType(Type t) {
-  return getGlobalVarTypeForEntryFnArg(t);
-}
-
-/// Computes the replacement value for an argument of an entry function. It
-/// allocates a global variable for this argument and adds statements in the
-/// entry block to get a replacement value within function scope.
-static Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter,
-                                                  size_t origArgNum,
-                                                  Value *origArg) {
-  // Create a global variable for this argument.
-  auto insertionOp = rewriter.getInsertionBlock()->getParent();
-  auto module = insertionOp->getParentOfType<spirv::ModuleOp>();
-  if (!module) {
-    return nullptr;
-  }
-  auto funcOp = insertionOp->getParentOfType<FuncOp>();
-  spirv::GlobalVariableOp var;
-  {
-    OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
-    rewriter.setInsertionPoint(funcOp.getOperation());
-    std::string varName =
-        funcOp.getName().str() + "_arg_" + std::to_string(origArgNum);
-    var = rewriter.create<spirv::GlobalVariableOp>(
-        funcOp.getLoc(),
-        TypeAttr::get(getGlobalVarTypeForEntryFnArg(origArg->getType())),
-        rewriter.getStringAttr(varName), nullptr);
-    var.setAttr(
-        spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
-        rewriter.getI32IntegerAttr(0));
-    var.setAttr(
-        spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
-        rewriter.getI32IntegerAttr(origArgNum));
-  }
-  // Insert the addressOf and load instructions, to get back the converted value
-  // type.
-  auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
-  auto indexType = convertIndexType(funcOp.getContext());
-  auto zero = rewriter.create<spirv::ConstantOp>(
-      funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
-  auto accessChain = rewriter.create<spirv::AccessChainOp>(
-      funcOp.getLoc(), addressOf.pointer(), zero.constant());
-  // If the original argument is a tensor/memref type, the value is not
-  // loaded. Instead the pointer value is returned to allow its use in access
-  // chain ops.
-  auto origArgType = origArg->getType();
-  if (origArgType.isa<MemRefType>()) {
-    return accessChain;
-  }
-  return rewriter.create<spirv::LoadOp>(
-      funcOp.getLoc(), accessChain.component_ptr(), /*memory_access=*/nullptr,
-      /*alignment=*/nullptr);
-}
-
-static FuncOp applySignatureConversion(
-    FuncOp funcOp, ConversionPatternRewriter &rewriter,
-    TypeConverter::SignatureConversion &signatureConverter) {
-  // Create a new function with an updated signature.
-  auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
-  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
-                              newFuncOp.end());
-  newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
-                                      llvm::None, funcOp.getContext()));
-
-  // Tell the rewriter to convert the region signature.
-  rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
-  rewriter.replaceOp(funcOp.getOperation(), llvm::None);
-  return newFuncOp;
-}
-
-/// Gets the global variables that need to be specified as interface variable
-/// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
-LogicalResult getInterfaceVariables(FuncOp funcOp,
-                                    SmallVectorImpl<Attribute> &interfaceVars) {
-  auto module = funcOp.getParentOfType<spirv::ModuleOp>();
-  if (!module) {
-    return failure();
-  }
-  llvm::SetVector<Operation *> interfaceVarSet;
-  for (auto &block : funcOp) {
-    // TODO(ravishankarm) : This should in reality traverse the entry function
-    // call graph and collect all the interfaces. For now, just traverse the
-    // instructions in this function.
-    auto callOps = block.getOps<CallOp>();
-    if (std::distance(callOps.begin(), callOps.end())) {
-      return funcOp.emitError("Collecting interface variables through function "
-                              "calls unimplemented");
-    }
-    for (auto op : block.getOps<spirv::AddressOfOp>()) {
-      auto var = module.lookupSymbol<spirv::GlobalVariableOp>(op.variable());
-      if (var.type().cast<spirv::PointerType>().getStorageClass() ==
-          spirv::StorageClass::StorageBuffer) {
-        continue;
-      }
-      interfaceVarSet.insert(var.getOperation());
-    }
-  }
-  for (auto &var : interfaceVarSet) {
-    interfaceVars.push_back(SymbolRefAttr::get(
-        cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
-  }
-  return success();
-}
-
-namespace mlir {
-LogicalResult lowerFunction(FuncOp funcOp, SPIRVTypeConverter *typeConverter,
-                            ConversionPatternRewriter &rewriter,
-                            FuncOp &newFuncOp) {
-  auto fnType = funcOp.getType();
-  if (fnType.getNumResults()) {
-    return funcOp.emitError("SPIR-V lowering only supports functions with no "
-                            "return values right now");
-  }
-  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
-  auto basicTypeConverter = typeConverter->getBasicTypeConverter();
-  for (auto origArgType : enumerate(fnType.getInputs())) {
-    auto convertedType = basicTypeConverter->convertType(origArgType.value());
-    if (!convertedType) {
-      return funcOp.emitError("unable to convert argument of type '")
-             << convertedType << "'";
-    }
-    signatureConverter.addInputs(origArgType.index(), convertedType);
-  }
-  newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
-  return success();
-}
-
-LogicalResult lowerAsEntryFunction(FuncOp funcOp,
-                                   SPIRVTypeConverter *typeConverter,
-                                   ConversionPatternRewriter &rewriter,
-                                   FuncOp &newFuncOp) {
-  auto fnType = funcOp.getType();
-  if (fnType.getNumResults()) {
-    return funcOp.emitError("SPIR-V lowering only supports functions with no "
-                            "return values right now");
-  }
-  // For entry functions need to make the signature void(void). Compute the
-  // replacement value for all arguments and replace all uses.
-  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
-  {
-    OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
-    rewriter.setInsertionPointToStart(&funcOp.front());
-    for (auto origArg : enumerate(funcOp.getArguments())) {
-      auto replacement = createAndLoadGlobalVarForEntryFnArg(
-          rewriter, origArg.index(), origArg.value());
-      signatureConverter.remapInput(origArg.index(), replacement);
-    }
-  }
-  newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
-  return success();
-}
-
-LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder) {
-  // Add the spv.EntryPointOp after collecting all the interface variables
-  // needed.
-  SmallVector<Attribute, 1> interfaceVars;
-  if (failed(getInterfaceVariables(newFuncOp, interfaceVars))) {
-    return failure();
-  }
-  builder.create<spirv::EntryPointOp>(newFuncOp.getLoc(),
-                                      spirv::ExecutionModel::GLCompute,
-                                      newFuncOp, interfaceVars);
-  // Specify the spv.ExecutionModeOp.
-
-  /// TODO(ravishankarm): Vulkan environment for SPIR-V requires "either a
-  /// LocalSize execution mode or an object decorated with the WorkgroupSize
-  /// decoration must be specified." Better approach is to use the
-  /// WorkgroupSize GlobalVariable with initializer being a specialization
-  /// constant. But current support for specialization constant does not allow
-  /// for this. So for now use the execution mode. Hard-wiring this to {1, 1,
-  /// 1} for now. To be fixed ASAP.
-  builder.create<spirv::ExecutionModeOp>(newFuncOp.getLoc(), newFuncOp,
-                                         spirv::ExecutionMode::LocalSize,
-                                         ArrayRef<int32_t>{1, 1, 1});
-  return success();
-}
-} // namespace mlir
-
-//===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
 
@@ -295,20 +37,18 @@ namespace {
 /// operation. Since IndexType is not used within SPIR-V dialect, this needs
 /// special handling to make sure the result type and the type of the value
 /// attribute are consistent.
-class ConstantIndexOpConversion final : public ConversionPattern {
+class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
 public:
-  ConstantIndexOpConversion(MLIRContext *context)
-      : ConversionPattern(ConstantOp::getOperationName(), 1, context) {}
+  using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
 
   PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+  matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto constIndexOp = cast<ConstantOp>(op);
     if (!constIndexOp.getResult()->getType().isa<IndexType>()) {
       return matchFailure();
     }
-    // The attribute has index type. Get the integer value and create a new
-    // IntegerAttr.
+    // The attribute has index type which is not directly supported in
+    // SPIR-V. Get the integer value and create a new IntegerAttr.
     auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
     if (!constAttr) {
       return matchFailure();
@@ -322,34 +62,32 @@ public:
     if (!constValType) {
       return matchFailure();
     }
-    auto spirvConstType = convertIndexType(constValType);
+    auto spirvConstType =
+        typeConverter.convertBasicType(constIndexOp.getResult()->getType());
     auto spirvConstVal =
         rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
-    auto spirvConstantOp = rewriter.create<spirv::ConstantOp>(
-        op->getLoc(), spirvConstType, spirvConstVal);
-    rewriter.replaceOp(op, spirvConstantOp.constant(), {});
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
+                                                   spirvConstVal);
     return matchSuccess();
   }
 };
 
 /// Convert compare operation to SPIR-V dialect.
-class CmpIOpConversion final : public ConversionPattern {
+class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
 public:
-  CmpIOpConversion(MLIRContext *context)
-      : ConversionPattern(CmpIOp::getOperationName(), 1, context) {}
+  using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
 
   PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+  matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto cmpIOp = cast<CmpIOp>(op);
     CmpIOpOperandAdaptor cmpIOpOperands(operands);
 
     switch (cmpIOp.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
-    rewriter.replaceOpWithNewOp<spirvOp>(op, op->getResult(0)->getType(),      \
-                                         cmpIOpOperands.lhs(),                 \
-                                         cmpIOpOperands.rhs());                \
+    rewriter.replaceOpWithNewOp<spirvOp>(                                      \
+        cmpIOp, cmpIOp.getResult()->getType(), cmpIOpOperands.lhs(),           \
+        cmpIOpOperands.rhs());                                                 \
     return matchSuccess();
 
       DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp);
@@ -374,16 +112,17 @@ public:
 /// that of the replaced operation. This is not handled in tablegen-based
 /// pattern specification.
 template <typename StdOp, typename SPIRVOp>
-class IntegerOpConversion final : public ConversionPattern {
+class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
 public:
-  IntegerOpConversion(MLIRContext *context)
-      : ConversionPattern(StdOp::getOperationName(), 1, context) {}
+  using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
 
   PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+  matchAndRewrite(StdOp operation, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
+    auto resultType =
+        this->typeConverter.convertBasicType(operation.getResult()->getType());
     rewriter.template replaceOpWithNewOp<SPIRVOp>(
-        op, operands[0]->getType(), operands, ArrayRef<NamedAttribute>());
+        operation, resultType, operands, ArrayRef<NamedAttribute>());
     return this->matchSuccess();
   }
 };
@@ -393,13 +132,12 @@ public:
 /// not supported in tablegen based pattern specification.
 // TODO(ravishankarm) : These could potentially be templated on the operation
 // being converted, since the same logic should work for linalg.load.
-class LoadOpConversion final : public ConversionPattern {
+class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
 public:
-  LoadOpConversion(MLIRContext *context)
-      : ConversionPattern(LoadOp::getOperationName(), 1, context) {}
+  using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
 
   PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+  matchAndRewrite(LoadOp loadOp, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
     LoadOpOperandAdaptor loadOperands(operands);
     auto basePtr = loadOperands.memref();
@@ -408,38 +146,38 @@ public:
       return matchFailure();
     }
     auto loadPtr = rewriter.create<spirv::AccessChainOp>(
-        op->getLoc(), basePtr, loadOperands.indices());
+        loadOp.getLoc(), basePtr, loadOperands.indices());
     auto loadPtrType = loadPtr.getType().cast<spirv::PointerType>();
     rewriter.replaceOpWithNewOp<spirv::LoadOp>(
-        op, loadPtrType.getPointeeType(), loadPtr, /*memory_access =*/nullptr,
+        loadOp, loadPtrType.getPointeeType(), loadPtr,
+        /*memory_access =*/nullptr,
         /*alignment =*/nullptr);
     return matchSuccess();
   }
 };
 
 /// Convert return -> spv.Return.
-class ReturnToSPIRVConversion : public ConversionPattern {
+class ReturnToSPIRVConversion final : public SPIRVOpLowering<ReturnOp> {
 public:
-  ReturnToSPIRVConversion(MLIRContext *context)
-      : ConversionPattern(ReturnOp::getOperationName(), 1, context) {}
+  using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
+
   PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+  matchAndRewrite(ReturnOp returnOp, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (op->getNumOperands()) {
+    if (returnOp.getNumOperands()) {
       return matchFailure();
     }
-    rewriter.replaceOpWithNewOp<spirv::ReturnOp>(op);
+    rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
     return matchSuccess();
   }
 };
 
 /// Convert select -> spv.Select
-class SelectOpConversion : public ConversionPattern {
+class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
 public:
-  SelectOpConversion(MLIRContext *context)
-      : ConversionPattern(SelectOp::getOperationName(), 1, context) {}
+  using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
   PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+  matchAndRewrite(SelectOp op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
     SelectOpOperandAdaptor selectOperands(operands);
     rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
@@ -454,13 +192,12 @@ public:
 /// is not supported in tablegen based pattern specification.
 // TODO(ravishankarm) : These could potentially be templated on the operation
 // being converted, since the same logic should work for linalg.store.
-class StoreOpConversion final : public ConversionPattern {
+class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
 public:
-  StoreOpConversion(MLIRContext *context)
-      : ConversionPattern(StoreOp::getOperationName(), 1, context) {}
+  using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
 
   PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+  matchAndRewrite(StoreOp storeOp, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
     StoreOpOperandAdaptor storeOperands(operands);
     auto value = storeOperands.value();
@@ -470,8 +207,8 @@ public:
       return matchFailure();
     }
     auto storePtr = rewriter.create<spirv::AccessChainOp>(
-        op->getLoc(), basePtr, storeOperands.indices());
-    rewriter.replaceOpWithNewOp<spirv::StoreOp>(op, storePtr, value,
+        storeOp.getLoc(), basePtr, storeOperands.indices());
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr, value,
                                                 /*memory_access =*/nullptr,
                                                 /*alignment =*/nullptr);
     return matchSuccess();
@@ -487,6 +224,7 @@ namespace {
 
 namespace mlir {
 void populateStandardToSPIRVPatterns(MLIRContext *context,
+                                     SPIRVTypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns) {
   populateWithGenerated(context, &patterns);
   // Add the return op conversion.
@@ -498,6 +236,6 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
               IntegerOpConversion<RemISOp, spirv::SModOp>,
               IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
               ReturnToSPIRVConversion, SelectOpConversion, StoreOpConversion>(
-          context);
+          context, typeConverter);
 }
 } // namespace mlir
index dcecb84..9b2b2d3 100644 (file)
 // =============================================================================
 //
 // This file implements a pass to convert MLIR standard ops into the SPIR-V
-// ops. It does not legalize FuncOps.
+// ops.
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
-#include "mlir/Dialect/SPIRV/Passes.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Pass/Pass.h"
 
 using namespace mlir;
 
@@ -38,7 +40,9 @@ void ConvertStandardToSPIRVPass::runOnModule() {
   OwningRewritePatternList patterns;
   auto module = getModule();
 
-  populateStandardToSPIRVPatterns(module.getContext(), patterns);
+  SPIRVBasicTypeConverter basicTypeConverter;
+  SPIRVTypeConverter typeConverter(&basicTypeConverter);
+  populateStandardToSPIRVPatterns(module.getContext(), typeConverter, patterns);
   ConversionTarget target(*(module.getContext()));
   target.addLegalDialect<spirv::SPIRVDialect>();
   target.addLegalOp<FuncOp>();
@@ -48,8 +52,7 @@ void ConvertStandardToSPIRVPass::runOnModule() {
   }
 }
 
-std::unique_ptr<OpPassBase<ModuleOp>>
-mlir::spirv::createConvertStandardToSPIRVPass() {
+std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertStandardToSPIRVPass() {
   return std::make_unique<ConvertStandardToSPIRVPass>();
 }
 
index dd51f0a..fa37543 100644 (file)
@@ -1,9 +1,10 @@
 add_llvm_library(MLIRSPIRV
   DialectRegistration.cpp
+  LayoutUtils.cpp
   SPIRVDialect.cpp
   SPIRVOps.cpp
+  SPIRVLowering.cpp
   SPIRVTypes.cpp
-  LayoutUtils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
@@ -17,7 +18,8 @@ add_dependencies(MLIRSPIRV
 target_link_libraries(MLIRSPIRV
   MLIRIR
   MLIRParser
-  MLIRSupport)
+  MLIRSupport
+  MLIRTransforms)
 
 add_subdirectory(Serialization)
 add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
new file mode 100644 (file)
index 0000000..01a4733
--- /dev/null
@@ -0,0 +1,340 @@
+//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements utilities used to lower to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+
+#include "mlir/Dialect/SPIRV/LayoutUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Type Conversion
+//===----------------------------------------------------------------------===//
+
+namespace {
+Type convertIndexType(MLIRContext *context) {
+  // Convert to 32-bit integers for now. Might need a way to control this in
+  // future.
+  // TODO(ravishankarm): It is porbably better to make it 64-bit integers. To
+  // this some support is needed in SPIR-V dialect for Conversion
+  // instructions. The Vulkan spec requires the builtins like
+  // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
+  // SExtended to 64-bit for index computations.
+  return IntegerType::get(32, context);
+}
+
+Type convertIndexType(IndexType t) { return convertIndexType(t.getContext()); }
+
+Type basicTypeConversion(Type t) {
+  // Check if the type is SPIR-V supported. If so return the type.
+  if (spirv::SPIRVDialect::isValidType(t)) {
+    return t;
+  }
+
+  if (auto indexType = t.dyn_cast<IndexType>()) {
+    return convertIndexType(indexType);
+  }
+
+  if (auto memRefType = t.dyn_cast<MemRefType>()) {
+    auto elementType = memRefType.getElementType();
+    if (memRefType.hasStaticShape()) {
+      // Convert to a multi-dimensional spv.array if size is known.
+      for (auto size : reverse(memRefType.getShape())) {
+        elementType = spirv::ArrayType::get(elementType, size);
+      }
+      return spirv::PointerType::get(elementType,
+                                     spirv::StorageClass::StorageBuffer);
+    } else {
+      // Vulkan SPIR-V validation rules require runtime array type to be the
+      // last member of a struct.
+      return spirv::PointerType::get(spirv::RuntimeArrayType::get(elementType),
+                                     spirv::StorageClass::StorageBuffer);
+    }
+  }
+  return Type();
+}
+
+Type getLayoutDecoratedType(spirv::StructType type) {
+  VulkanLayoutUtils::Size size = 0, alignment = 0;
+  return VulkanLayoutUtils::decorateType(type, size, alignment);
+}
+
+/// Generates the type of variable given the type of object.
+static Type getGlobalVarTypeForEntryFnArg(Type t) {
+  auto convertedType = basicTypeConversion(t);
+  if (auto ptrType = convertedType.dyn_cast<spirv::PointerType>()) {
+    if (!ptrType.getPointeeType().isa<spirv::StructType>()) {
+      return spirv::PointerType::get(
+          getLayoutDecoratedType(
+              spirv::StructType::get(ptrType.getPointeeType())),
+          ptrType.getStorageClass());
+    }
+  } else {
+    return spirv::PointerType::get(
+        getLayoutDecoratedType(spirv::StructType::get(convertedType)),
+        spirv::StorageClass::StorageBuffer);
+  }
+  return convertedType;
+}
+} // namespace
+
+Type SPIRVBasicTypeConverter::convertType(Type t) {
+  return basicTypeConversion(t);
+}
+
+Type SPIRVTypeConverter::convertType(Type t) {
+  return getGlobalVarTypeForEntryFnArg(t);
+}
+
+//===----------------------------------------------------------------------===//
+// Builtin Variables
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Look through all global variables in `moduleOp` and check if there is a
+/// spv.globalVariable that has the same `builtin` attribute.
+spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
+                                           spirv::BuiltIn builtin) {
+  for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
+    if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(convertToSnakeCase(
+            stringifyDecoration(spirv::Decoration::BuiltIn)))) {
+      auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
+      if (varBuiltIn && varBuiltIn.getValue() == builtin) {
+        return varOp;
+      }
+    }
+  }
+  return nullptr;
+}
+
+/// Gets name of global variable for a buitlin.
+std::string getBuiltinVarName(spirv::BuiltIn builtin) {
+  return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
+}
+
+/// Gets or inserts a global variable for a builtin within a module.
+spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp,
+                                                   Location loc,
+                                                   spirv::BuiltIn builtin,
+                                                   OpBuilder &builder) {
+  if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
+    return varOp;
+  }
+  auto ip = builder.saveInsertionPoint();
+  builder.setInsertionPointToStart(&moduleOp.getBlock());
+  auto name = getBuiltinVarName(builtin);
+  spirv::GlobalVariableOp newVarOp;
+  switch (builtin) {
+  case spirv::BuiltIn::NumWorkgroups:
+  case spirv::BuiltIn::WorkgroupSize:
+  case spirv::BuiltIn::WorkgroupId:
+  case spirv::BuiltIn::LocalInvocationId:
+  case spirv::BuiltIn::GlobalInvocationId: {
+    auto ptrType = spirv::PointerType::get(
+        VectorType::get({3}, builder.getIntegerType(32)),
+        spirv::StorageClass::Input);
+    newVarOp = builder.create<spirv::GlobalVariableOp>(
+        loc, TypeAttr::get(ptrType), builder.getStringAttr(name), nullptr);
+    newVarOp.setAttr(
+        convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)),
+        builder.getStringAttr(stringifyBuiltIn(builtin)));
+    break;
+  }
+  default:
+    emitError(loc, "unimplemented builtin variable generation for ")
+        << stringifyBuiltIn(builtin);
+  }
+  builder.restoreInsertionPoint(ip);
+  return newVarOp;
+}
+} // namespace
+
+/// Gets the global variable associated with a builtin and add
+/// it if it doesnt exist.
+Value *mlir::spirv::getBuiltinVariableValue(Operation *op,
+                                            spirv::BuiltIn builtin,
+                                            OpBuilder &builder) {
+  auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
+  if (!moduleOp) {
+    op->emitError("expected operation to be within a SPIR-V module");
+    return nullptr;
+  }
+  auto varOp =
+      getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, builder);
+  auto ptr = builder
+                 .create<spirv::AddressOfOp>(op->getLoc(), varOp.type(),
+                                             builder.getSymbolRefAttr(varOp))
+                 .pointer();
+  return builder.create<spirv::LoadOp>(
+      op->getLoc(),
+      ptr->getType().template cast<spirv::PointerType>().getPointeeType(), ptr,
+      /*memory_access =*/nullptr, /*alignment =*/nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// Entry Function signature Conversion
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Computes the replacement value for an argument of an entry function. It
+/// allocates a global variable for this argument and adds statements in the
+/// entry block to get a replacement value within function scope.
+Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter,
+                                           size_t origArgNum, Value *origArg) {
+  // Create a global variable for this argument.
+  auto insertionOp = rewriter.getInsertionBlock()->getParent();
+  auto module = insertionOp->getParentOfType<spirv::ModuleOp>();
+  if (!module) {
+    return nullptr;
+  }
+  auto funcOp = insertionOp->getParentOfType<FuncOp>();
+  spirv::GlobalVariableOp var;
+  {
+    OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
+    rewriter.setInsertionPoint(funcOp.getOperation());
+    std::string varName =
+        funcOp.getName().str() + "_arg_" + std::to_string(origArgNum);
+    var = rewriter.create<spirv::GlobalVariableOp>(
+        funcOp.getLoc(),
+        TypeAttr::get(getGlobalVarTypeForEntryFnArg(origArg->getType())),
+        rewriter.getStringAttr(varName), nullptr);
+    var.setAttr(
+        spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
+        rewriter.getI32IntegerAttr(0));
+    var.setAttr(
+        spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
+        rewriter.getI32IntegerAttr(origArgNum));
+  }
+  // Insert the addressOf and load instructions, to get back the converted value
+  // type.
+  auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
+  auto indexType = convertIndexType(funcOp.getContext());
+  auto zero = rewriter.create<spirv::ConstantOp>(
+      funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
+  auto accessChain = rewriter.create<spirv::AccessChainOp>(
+      funcOp.getLoc(), addressOf.pointer(), zero.constant());
+  // If the original argument is a tensor/memref type, the value is not
+  // loaded. Instead the pointer value is returned to allow its use in access
+  // chain ops.
+  auto origArgType = origArg->getType();
+  if (origArgType.isa<MemRefType>()) {
+    return accessChain;
+  }
+  return rewriter.create<spirv::LoadOp>(
+      funcOp.getLoc(), accessChain.component_ptr(), /*memory_access=*/nullptr,
+      /*alignment=*/nullptr);
+}
+
+FuncOp applySignatureConversion(
+    FuncOp funcOp, ConversionPatternRewriter &rewriter,
+    TypeConverter::SignatureConversion &signatureConverter) {
+  // Create a new function with an updated signature.
+  auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                              newFuncOp.end());
+  newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
+                                      llvm::None, funcOp.getContext()));
+
+  // Tell the rewriter to convert the region signature.
+  rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+  rewriter.replaceOp(funcOp.getOperation(), llvm::None);
+  return newFuncOp;
+}
+
+/// Gets the global variables that need to be specified as interface variable
+/// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
+LogicalResult getInterfaceVariables(FuncOp funcOp,
+                                    SmallVectorImpl<Attribute> &interfaceVars) {
+  auto module = funcOp.getParentOfType<spirv::ModuleOp>();
+  if (!module) {
+    return failure();
+  }
+  llvm::SetVector<Operation *> interfaceVarSet;
+  for (auto &block : funcOp) {
+    // TODO(ravishankarm) : This should in reality traverse the entry function
+    // call graph and collect all the interfaces. For now, just traverse the
+    // instructions in this function.
+    for (auto op : block.getOps<spirv::AddressOfOp>()) {
+      auto var = module.lookupSymbol<spirv::GlobalVariableOp>(op.variable());
+      if (var.type().cast<spirv::PointerType>().getStorageClass() ==
+          spirv::StorageClass::StorageBuffer) {
+        continue;
+      }
+      interfaceVarSet.insert(var.getOperation());
+    }
+  }
+  for (auto &var : interfaceVarSet) {
+    interfaceVars.push_back(SymbolRefAttr::get(
+        cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
+  }
+  return success();
+}
+} // namespace
+
+LogicalResult mlir::spirv::lowerAsEntryFunction(
+    FuncOp funcOp, SPIRVTypeConverter *typeConverter,
+    ConversionPatternRewriter &rewriter, FuncOp &newFuncOp) {
+  auto fnType = funcOp.getType();
+  if (fnType.getNumResults()) {
+    return funcOp.emitError("SPIR-V lowering only supports functions with no "
+                            "return values right now");
+  }
+  // For entry functions need to make the signature void(void). Compute the
+  // replacement value for all arguments and replace all uses.
+  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+  {
+    OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
+    rewriter.setInsertionPointToStart(&funcOp.front());
+    for (auto origArg : enumerate(funcOp.getArguments())) {
+      auto replacement = createAndLoadGlobalVarForEntryFnArg(
+          rewriter, origArg.index(), origArg.value());
+      signatureConverter.remapInput(origArg.index(), replacement);
+    }
+  }
+  newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
+  return success();
+}
+
+LogicalResult mlir::spirv::finalizeEntryFunction(FuncOp newFuncOp,
+                                                 OpBuilder &builder) {
+  // Add the spv.EntryPointOp after collecting all the interface variables
+  // needed.
+  SmallVector<Attribute, 1> interfaceVars;
+  if (failed(getInterfaceVariables(newFuncOp, interfaceVars))) {
+    return failure();
+  }
+  builder.create<spirv::EntryPointOp>(newFuncOp.getLoc(),
+                                      spirv::ExecutionModel::GLCompute,
+                                      newFuncOp, interfaceVars);
+  // Specify the spv.ExecutionModeOp.
+
+  /// TODO(ravishankarm): Vulkan environment for SPIR-V requires "either a
+  /// LocalSize execution mode or an object decorated with the WorkgroupSize
+  /// decoration must be specified." Better approach is to use the
+  /// WorkgroupSize GlobalVariable with initializer being a specialization
+  /// constant. But current support for specialization constant does not allow
+  /// for this. So for now use the execution mode. Hard-wiring this to {1, 1,
+  /// 1} for now. To be fixed ASAP.
+  builder.create<spirv::ExecutionModeOp>(newFuncOp.getLoc(), newFuncOp,
+                                         spirv::ExecutionMode::LocalSize,
+                                         ArrayRef<int32_t>{1, 1, 1});
+  return success();
+}
index 23515b4..9eb3f53 100644 (file)
@@ -38,7 +38,7 @@ set(LIBS
   MLIRQuantOps
   MLIRROCDLIR
   MLIRSPIRV
-  MLIRSPIRVConversion
+  MLIRStandardToSPIRVTransforms
   MLIRSPIRVTransforms
   MLIRStandardOps
   MLIRStandardToLLVM