--- /dev/null
+//===- ConvertGPUToSPIRV.h - GPU Ops to SPIR-V dialect patterns ----C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides patterns for lowering GPU Ops to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRV_H
+#define MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRV_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class SPIRVTypeConverter;
+/// Appends to a pattern list additional patterns for translating GPU Ops to
+/// SPIR-V ops.
+void populateGPUToSPIRVPatterns(MLIRContext *context,
+ SPIRVTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRV_H
--- /dev/null
+//===- ConvertGPUToSPIRVPass.h - GPU to SPIR-V conversion pass --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides a pass to convert GPU ops to SPIRV ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H
+#define MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H
+
+#include <memory>
+
+namespace mlir {
+
+class ModuleOp;
+template <typename T> class OpPassBase;
+
+/// Pass to convert GPU Ops to SPIR-V ops.
+std::unique_ptr<OpPassBase<ModuleOp>> createConvertGPUToSPIRVPass();
+
+} // namespace mlir
+#endif // MLIR_CONVERSION_GPUTOSPIRV_CONVERTGPUTOSPIRVPASS_H
// limitations under the License.
// =============================================================================
//
-// Provides type converters and patterns to convert from standard types/ops to
-// SPIR-V types and operations. Also provides utilities and base classes to use
-// while targeting SPIR-V dialect.
+// Provides patterns to lower StandardOps to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
#define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/Support/StringExtras.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
-
-class LoadOp;
-class ReturnOp;
-class StoreOp;
-
-/// Type conversion from Standard Types to SPIR-V Types.
-class SPIRVBasicTypeConverter : public TypeConverter {
-public:
- /// Converts types to SPIR-V supported types.
- virtual Type convertType(Type t);
-};
-
-/// Converts a function type according to the requirements of a SPIR-V entry
-/// function. The arguments need to be converted to spv.Variables of spv.ptr
-/// types so that they could be bound by the runtime.
-class SPIRVTypeConverter final : public TypeConverter {
-public:
- explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
- : basicTypeConverter(basicTypeConverter) {}
-
- /// Converts types to SPIR-V types using the basic type converter.
- Type convertType(Type t) override;
-
- /// Gets the basic type converter.
- SPIRVBasicTypeConverter *getBasicTypeConverter() const {
- return basicTypeConverter;
- }
-
-private:
- SPIRVBasicTypeConverter *basicTypeConverter;
-};
-
-/// Base class to define a conversion pattern to translate Ops into SPIR-V.
-template <typename OpTy> class SPIRVOpLowering : public ConversionPattern {
-public:
- SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter)
- : ConversionPattern(OpTy::getOperationName(), 1, context),
- typeConverter(typeConverter) {}
-
-protected:
- /// Gets the global variable associated with a builtin and add
- /// it if it doesnt exist.
- Value *loadFromBuiltinVariable(Operation *op, spirv::BuiltIn builtin,
- ConversionPatternRewriter &rewriter) const {
- auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
- if (!moduleOp) {
- op->emitError("expected operation to be within a SPIR-V module");
- return nullptr;
- }
- auto varOp =
- getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, rewriter);
- auto ptr = rewriter
- .create<spirv::AddressOfOp>(op->getLoc(), varOp.type(),
- rewriter.getSymbolRefAttr(varOp))
- .pointer();
- return rewriter.create<spirv::LoadOp>(
- op->getLoc(),
- ptr->getType().template cast<spirv::PointerType>().getPointeeType(),
- ptr, /*memory_access =*/nullptr, /*alignment =*/nullptr);
- }
-
- /// Type lowering class.
- SPIRVTypeConverter &typeConverter;
-
-private:
- /// Look through all global variables in `moduleOp` and check if there is a
- /// spv.globalVariable that has the same `builtin` attribute.
- spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
- spirv::BuiltIn builtin) const {
- for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
- if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(convertToSnakeCase(
- stringifyDecoration(spirv::Decoration::BuiltIn)))) {
- auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
- if (varBuiltIn && varBuiltIn.getValue() == builtin) {
- return varOp;
- }
- }
- }
- return nullptr;
- }
-
- /// Gets name of global variable for a buitlin.
- std::string getBuiltinVarName(spirv::BuiltIn builtin) const {
- return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() +
- "__";
- }
-
- /// Gets or inserts a global variable for a builtin within a module.
- spirv::GlobalVariableOp
- getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
- spirv::BuiltIn builtin,
- ConversionPatternRewriter &builder) const {
- if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
- return varOp;
- }
- auto ip = builder.saveInsertionPoint();
- builder.setInsertionPointToStart(&moduleOp.getBlock());
- auto name = getBuiltinVarName(builtin);
- spirv::GlobalVariableOp newVarOp;
- switch (builtin) {
- case spirv::BuiltIn::NumWorkgroups:
- case spirv::BuiltIn::WorkgroupSize:
- case spirv::BuiltIn::WorkgroupId:
- case spirv::BuiltIn::LocalInvocationId:
- case spirv::BuiltIn::GlobalInvocationId: {
- auto ptrType = spirv::PointerType::get(
- VectorType::get({3}, builder.getIntegerType(32)),
- spirv::StorageClass::Input);
- newVarOp = builder.create<spirv::GlobalVariableOp>(
- loc, TypeAttr::get(ptrType), builder.getStringAttr(name), nullptr);
- newVarOp.setAttr(
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)),
- builder.getStringAttr(stringifyBuiltIn(builtin)));
- break;
- }
- default:
- emitError(loc, "unimplemented builtin variable generation for ")
- << stringifyBuiltIn(builtin);
- }
- builder.restoreInsertionPoint(ip);
- return newVarOp;
- }
-};
-
-/// Legalizes a function as a non-entry function.
-LogicalResult lowerFunction(FuncOp funcOp, SPIRVTypeConverter *typeConverter,
- ConversionPatternRewriter &rewriter,
- FuncOp &newFuncOp);
-
-/// Legalizes a function as an entry function.
-LogicalResult lowerAsEntryFunction(FuncOp funcOp,
- SPIRVTypeConverter *typeConverter,
- ConversionPatternRewriter &rewriter,
- FuncOp &newFuncOp);
-
-/// Finalizes entry function legalization. Inserts the spv.EntryPoint and
-/// spv.ExecutionMode ops.
-LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder);
-
+class SPIRVTypeConverter;
/// Appends to a pattern list additional patterns for translating StandardOps to
/// SPIR-V ops.
void populateStandardToSPIRVPatterns(MLIRContext *context,
+ SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
} // namespace mlir
--- /dev/null
+//===- ConvertStandardToSPIRVPass.h - StdOps to SPIR-V pass -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides a pass to lower from StandardOps to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRVPASS_H
+#define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRVPASS_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+/// Pass to convert StandardOps to SPIR-V ops.
+std::unique_ptr<OpPassBase<ModuleOp>> createConvertStandardToSPIRVPass();
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRVPASS_H
namespace mlir {
namespace spirv {
-// Creates a module pass that converts standard ops to SPIR-V ops.
-std::unique_ptr<OpPassBase<mlir::ModuleOp>> createConvertStandardToSPIRVPass();
-
// Creates a module pass that converts composite types used by objects in the
// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage
// classes with layout information.
--- /dev/null
+//===- SPIRVLowering.h - SPIR-V lowering utilities -------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines, utilities and base classes to use while targeting SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
+#define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Support/StringExtras.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir {
+
+/// Type conversion from Standard Types to SPIR-V Types.
+class SPIRVBasicTypeConverter : public TypeConverter {
+public:
+ /// Converts types to SPIR-V supported types.
+ virtual Type convertType(Type t);
+};
+
+/// Converts a function type according to the requirements of a SPIR-V entry
+/// function. The arguments need to be converted to spv.GlobalVariables of
+/// spv.ptr types so that they could be bound by the runtime.
+class SPIRVTypeConverter final : public TypeConverter {
+public:
+ explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
+ : basicTypeConverter(basicTypeConverter) {}
+
+ /// Converts types to SPIR-V types using the basic type converter.
+ Type convertType(Type t) override;
+
+ /// Gets the basic type converter.
+ Type convertBasicType(Type t) { return basicTypeConverter->convertType(t); }
+
+private:
+ SPIRVBasicTypeConverter *basicTypeConverter;
+};
+
+/// Base class to define a conversion pattern to translate Ops into SPIR-V.
+template <typename SourceOp>
+class SPIRVOpLowering : public OpConversionPattern<SourceOp> {
+public:
+ SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<SourceOp>(context, benefit),
+ typeConverter(typeConverter) {}
+
+protected:
+ /// Type lowering class.
+ SPIRVTypeConverter &typeConverter;
+
+private:
+};
+
+namespace spirv {
+/// Returns a value that represents a builtin variable value within the SPIR-V
+/// module.
+Value *getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin,
+ OpBuilder &builder);
+
+/// Legalizes a function as an entry function.
+LogicalResult lowerAsEntryFunction(FuncOp funcOp,
+ SPIRVTypeConverter *typeConverter,
+ ConversionPatternRewriter &rewriter,
+ FuncOp &newFuncOp);
+
+/// Finalizes entry function legalization. Inserts the spv.EntryPoint and
+/// spv.ExecutionMode ops.
+LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder);
+
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
add_llvm_library(MLIRGPUtoSPIRVTransforms
- GPUToSPIRV.cpp
+ ConvertGPUToSPIRV.cpp
+ ConvertGPUToSPIRVPass.cpp
)
target_link_libraries(MLIRGPUtoSPIRVTransforms
MLIRPass
MLIRSPIRV
MLIRStandardOps
- MLIRSPIRVConversion
+ MLIRStandardToSPIRVTransforms
MLIRTransforms
)
-//===- GPUToSPIRV.cpp - MLIR SPIR-V lowering passes -----------------------===//
+//===- ConvertGPUToSPIRV.cpp - Convert GPU ops to SPIR-V dialect ----------===//
//
// Copyright 2019 The MLIR Authors.
//
// limitations under the License.
// =============================================================================
//
-// This file implements a pass to convert a kernel function in the GPU Dialect
-// into a spv.module operation
+// This file implements the conversion patterns from GPU ops to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/Pass/Pass.h"
using namespace mlir;
using SPIRVOpLowering<loop::ForOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override;
};
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builin variables.
-template <typename OpTy, spirv::BuiltIn builtin>
-class LaunchConfigConversion : public SPIRVOpLowering<OpTy> {
+template <typename SourceOp, spirv::BuiltIn builtin>
+class LaunchConfigConversion : public SPIRVOpLowering<SourceOp> {
public:
- using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
+ using SPIRVOpLowering<SourceOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(SourceOp op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override;
};
using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
PatternMatchResult
-ForOpConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
// loop::ForOp can be lowered to the structured control flow represented by
// spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
// latch and the merge block the exit block. The resulting spirv::LoopOp has a
// single back edge from the continue to header block, and a single exit from
// header to merge.
- auto forOp = cast<loop::ForOp>(op);
loop::ForOpOperandAdaptor forOperands(operands);
- auto loc = op->getLoc();
+ auto loc = forOp.getLoc();
auto loopControl = rewriter.getI32IntegerAttr(
static_cast<uint32_t>(spirv::LoopControl::None));
auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
return matchSuccess();
}
-template <typename OpTy, spirv::BuiltIn builtin>
-PatternMatchResult LaunchConfigConversion<OpTy, builtin>::matchAndRewrite(
- Operation *op, ArrayRef<Value *> operands,
+template <typename SourceOp, spirv::BuiltIn builtin>
+PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
+ SourceOp op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
- auto dimAttr = op->getAttrOfType<StringAttr>("dimension");
+ auto dimAttr =
+ op.getOperation()->template getAttrOfType<StringAttr>("dimension");
if (!dimAttr) {
return this->matchFailure();
}
}
// SPIR-V invocation builtin variables are a vector of type <3xi32>
- auto spirvBuiltin = this->loadFromBuiltinVariable(op, builtin, rewriter);
+ auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter);
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
op, rewriter.getIntegerType(32), spirvBuiltin,
rewriter.getI32ArrayAttr({index}));
}
PatternMatchResult
-KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
- auto funcOp = cast<FuncOp>(op);
FuncOp newFuncOp;
if (!gpu::GPUDialect::isKernel(funcOp)) {
- return succeeded(lowerFunction(funcOp, &typeConverter, rewriter, newFuncOp))
- ? matchSuccess()
- : matchFailure();
+ return matchFailure();
}
- if (failed(
- lowerAsEntryFunction(funcOp, &typeConverter, rewriter, newFuncOp))) {
+ if (failed(spirv::lowerAsEntryFunction(funcOp, &typeConverter, rewriter,
+ newFuncOp))) {
return matchFailure();
}
return matchSuccess();
}
-namespace {
-/// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions
-/// that have the "gpu.kernel" attribute, i.e. those functions that are
-/// referenced in gpu::LaunchKernelOp operations. For each such function
-///
-/// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp
-/// (the original function is still needed by the gpu::LaunchKernelOp, so cannot
-/// replace it).
-///
-/// 2) Lower the body of the spirv::ModuleOp.
-class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
- void runOnModule() override;
-};
-} // namespace
-
-void GPUToSPIRVPass::runOnModule() {
- auto context = &getContext();
- auto module = getModule();
-
- SmallVector<Operation *, 4> spirvModules;
- module.walk([&module, &spirvModules](FuncOp funcOp) {
- if (gpu::GPUDialect::isKernel(funcOp)) {
- OpBuilder builder(module.getBodyRegion());
- // Create a new spirv::ModuleOp for this function, and clone the
- // function into it.
- // TODO : Generalize this to account for different extensions,
- // capabilities, extended_instruction_sets, other addressing models
- // and memory models.
- auto spvModule = builder.create<spirv::ModuleOp>(
- funcOp.getLoc(),
- builder.getI32IntegerAttr(
- static_cast<int32_t>(spirv::AddressingModel::Logical)),
- builder.getI32IntegerAttr(
- static_cast<int32_t>(spirv::MemoryModel::GLSL450)),
- builder.getStrArrayAttr(
- spirv::stringifyCapability(spirv::Capability::Shader)),
- builder.getStrArrayAttr(spirv::stringifyExtension(
- spirv::Extension::SPV_KHR_storage_buffer_storage_class)));
- // Hardwire the capability to be Shader.
- OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
- moduleBuilder.clone(*funcOp.getOperation());
- spirvModules.push_back(spvModule);
- }
- });
-
- /// Dialect conversion to lower the functions with the spirv::ModuleOps.
- SPIRVBasicTypeConverter basicTypeConverter;
- SPIRVTypeConverter typeConverter(&basicTypeConverter);
- OwningRewritePatternList patterns;
+namespace mlir {
+void populateGPUToSPIRVPatterns(MLIRContext *context,
+ SPIRVTypeConverter &typeConverter,
+ OwningRewritePatternList &patterns) {
patterns.insert<
ForOpConversion, KernelFnConversion,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
LaunchConfigConversion<gpu::ThreadIdOp,
spirv::BuiltIn::LocalInvocationId>>(context,
typeConverter);
- populateStandardToSPIRVPatterns(context, patterns);
-
- ConversionTarget target(*context);
- target.addLegalDialect<spirv::SPIRVDialect>();
- target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
- // TODO(ravishankarm) : Currently lowering does not support handling
- // function conversion of non-kernel functions. This is to be added.
-
- // For kernel functions, verify that the signature is void(void).
- return gpu::GPUDialect::isKernel(op) && op.getNumResults() == 0 &&
- op.getNumArguments() == 0;
- });
-
- if (failed(applyFullConversion(spirvModules, target, patterns,
- &typeConverter))) {
- return signalPassFailure();
- }
-
- // After the SPIR-V modules have been generated, some finalization is needed
- // for the entry functions. For example, adding spv.EntryPoint op,
- // spv.ExecutionMode op, etc.
- for (auto *spvModule : spirvModules) {
- for (auto op :
- cast<spirv::ModuleOp>(spvModule).getBlock().getOps<FuncOp>()) {
- if (gpu::GPUDialect::isKernel(op)) {
- OpBuilder builder(op.getContext());
- builder.setInsertionPointAfter(op);
- if (failed(finalizeEntryFunction(op, builder))) {
- return signalPassFailure();
- }
- op.getOperation()->removeAttr(Identifier::get(
- gpu::GPUDialect::getKernelFuncAttrName(), op.getContext()));
- }
- }
- }
}
-
-OpPassBase<ModuleOp> *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); }
-
-static PassRegistration<GPUToSPIRVPass>
- pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
+} // namespace mlir
--- /dev/null
+//===- ConvertGPUToSPIRVPass.cpp - GPU to SPIR-V dialect lowering passes --===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert a kernel function in the GPU Dialect
+// into a spv.module operation
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions
+/// that have the "gpu.kernel" attribute, i.e. those functions that are
+/// referenced in gpu::LaunchKernelOp operations. For each such function
+///
+/// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp
+/// (the original function is still needed by the gpu::LaunchKernelOp, so cannot
+/// replace it).
+///
+/// 2) Lower the body of the spirv::ModuleOp.
+class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
+ void runOnModule() override;
+};
+} // namespace
+
+void GPUToSPIRVPass::runOnModule() {
+ auto context = &getContext();
+ auto module = getModule();
+
+ SmallVector<Operation *, 4> spirvModules;
+ module.walk([&module, &spirvModules](FuncOp funcOp) {
+ if (!gpu::GPUDialect::isKernel(funcOp)) {
+ return;
+ }
+ OpBuilder builder(module.getBodyRegion());
+ // Create a new spirv::ModuleOp for this function, and clone the
+ // function into it.
+ // TODO : Generalize this to account for different extensions,
+ // capabilities, extended_instruction_sets, other addressing models
+ // and memory models.
+ auto spvModule = builder.create<spirv::ModuleOp>(
+ funcOp.getLoc(),
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(spirv::AddressingModel::Logical)),
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(spirv::MemoryModel::GLSL450)),
+ builder.getStrArrayAttr(
+ spirv::stringifyCapability(spirv::Capability::Shader)),
+ builder.getStrArrayAttr(spirv::stringifyExtension(
+ spirv::Extension::SPV_KHR_storage_buffer_storage_class)));
+ // Hardwire the capability to be Shader.
+ OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
+ moduleBuilder.clone(*funcOp.getOperation());
+ spirvModules.push_back(spvModule);
+ });
+
+ /// Dialect conversion to lower the functions with the spirv::ModuleOps.
+ SPIRVBasicTypeConverter basicTypeConverter;
+ SPIRVTypeConverter typeConverter(&basicTypeConverter);
+ OwningRewritePatternList patterns;
+ populateGPUToSPIRVPatterns(context, typeConverter, patterns);
+ populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+
+ ConversionTarget target(*context);
+ target.addLegalDialect<spirv::SPIRVDialect>();
+ target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+ // TODO(ravishankarm) : Currently lowering does not support handling
+ // function conversion of non-kernel functions. This is to be added.
+
+ // For kernel functions, verify that the signature is void(void).
+ return gpu::GPUDialect::isKernel(op) && op.getNumResults() == 0 &&
+ op.getNumArguments() == 0;
+ });
+
+ if (failed(applyFullConversion(spirvModules, target, patterns,
+ &typeConverter))) {
+ return signalPassFailure();
+ }
+
+ // After the SPIR-V modules have been generated, some finalization is needed
+ // for the entry functions. For example, adding spv.EntryPoint op,
+ // spv.ExecutionMode op, etc.
+ for (auto *spvModule : spirvModules) {
+ for (auto op :
+ cast<spirv::ModuleOp>(spvModule).getBlock().getOps<FuncOp>()) {
+ if (gpu::GPUDialect::isKernel(op)) {
+ OpBuilder builder(op.getContext());
+ builder.setInsertionPointAfter(op);
+ if (failed(spirv::finalizeEntryFunction(op, builder))) {
+ return signalPassFailure();
+ }
+ op.getOperation()->removeAttr(Identifier::get(
+ gpu::GPUDialect::getKernelFuncAttrName(), op.getContext()));
+ }
+ }
+ }
+}
+
+OpPassBase<ModuleOp> *createConvertGPUToSPIRVPass() {
+ return new GPUToSPIRVPass();
+}
+
+static PassRegistration<GPUToSPIRVPass>
+ pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
mlir_tablegen(StandardToSPIRV.cpp.inc -gen-rewriters)
add_public_tablegen_target(MLIRStandardToSPIRVIncGen)
-add_llvm_library(MLIRSPIRVConversion
+add_llvm_library(MLIRStandardToSPIRVTransforms
ConvertStandardToSPIRV.cpp
ConvertStandardToSPIRVPass.cpp
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
)
-add_dependencies(MLIRSPIRVConversion
+add_dependencies(MLIRStandardToSPIRVTransforms
MLIRStandardToSPIRVIncGen)
-target_link_libraries(MLIRSPIRVConversion
+target_link_libraries(MLIRStandardToSPIRVTransforms
MLIRIR
MLIRPass
MLIRSPIRV
// limitations under the License.
// =============================================================================
//
-// This file implements a pass to convert MLIR standard and builtin dialects
-// into the SPIR-V dialect.
+// This file implements patterns to convert Standard Ops to the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
-// Type Conversion
-//===----------------------------------------------------------------------===//
-
-static Type convertIndexType(MLIRContext *context) {
- // Convert to 32-bit integers for now. Might need a way to control this in
- // future.
- // TODO(ravishankarm): It is porbably better to make it 64-bit integers. To
- // this some support is needed in SPIR-V dialect for Conversion
- // instructions. The Vulkan spec requires the builtins like
- // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
- // SExtended to 64-bit for index computations.
- return IntegerType::get(32, context);
-}
-
-static Type convertIndexType(IndexType t) {
- return convertIndexType(t.getContext());
-}
-
-static Type basicTypeConversion(Type t) {
- // Check if the type is SPIR-V supported. If so return the type.
- if (spirv::SPIRVDialect::isValidType(t)) {
- return t;
- }
-
- if (auto indexType = t.dyn_cast<IndexType>()) {
- return convertIndexType(indexType);
- }
-
- if (auto memRefType = t.dyn_cast<MemRefType>()) {
- auto elementType = memRefType.getElementType();
- if (memRefType.hasStaticShape()) {
- // Convert to a multi-dimensional spv.array if size is known.
- for (auto size : reverse(memRefType.getShape())) {
- elementType = spirv::ArrayType::get(elementType, size);
- }
- return spirv::PointerType::get(elementType,
- spirv::StorageClass::StorageBuffer);
- } else {
- // Vulkan SPIR-V validation rules require runtime array type to be the
- // last member of a struct.
- return spirv::PointerType::get(spirv::RuntimeArrayType::get(elementType),
- spirv::StorageClass::StorageBuffer);
- }
- }
- return Type();
-}
-
-Type SPIRVBasicTypeConverter::convertType(Type t) {
- return basicTypeConversion(t);
-}
-
-//===----------------------------------------------------------------------===//
-// Entry Function signature Conversion
-//===----------------------------------------------------------------------===//
-
-Type getLayoutDecoratedType(spirv::StructType type) {
- VulkanLayoutUtils::Size size = 0, alignment = 0;
- return VulkanLayoutUtils::decorateType(type, size, alignment);
-}
-
-/// Generates the type of variable given the type of object.
-static Type getGlobalVarTypeForEntryFnArg(Type t) {
- auto convertedType = basicTypeConversion(t);
- if (auto ptrType = convertedType.dyn_cast<spirv::PointerType>()) {
- if (!ptrType.getPointeeType().isa<spirv::StructType>()) {
- return spirv::PointerType::get(
- getLayoutDecoratedType(
- spirv::StructType::get(ptrType.getPointeeType())),
- ptrType.getStorageClass());
- }
- } else {
- return spirv::PointerType::get(
- getLayoutDecoratedType(spirv::StructType::get(convertedType)),
- spirv::StorageClass::StorageBuffer);
- }
- return convertedType;
-}
-
-Type SPIRVTypeConverter::convertType(Type t) {
- return getGlobalVarTypeForEntryFnArg(t);
-}
-
-/// Computes the replacement value for an argument of an entry function. It
-/// allocates a global variable for this argument and adds statements in the
-/// entry block to get a replacement value within function scope.
-static Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter,
- size_t origArgNum,
- Value *origArg) {
- // Create a global variable for this argument.
- auto insertionOp = rewriter.getInsertionBlock()->getParent();
- auto module = insertionOp->getParentOfType<spirv::ModuleOp>();
- if (!module) {
- return nullptr;
- }
- auto funcOp = insertionOp->getParentOfType<FuncOp>();
- spirv::GlobalVariableOp var;
- {
- OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
- rewriter.setInsertionPoint(funcOp.getOperation());
- std::string varName =
- funcOp.getName().str() + "_arg_" + std::to_string(origArgNum);
- var = rewriter.create<spirv::GlobalVariableOp>(
- funcOp.getLoc(),
- TypeAttr::get(getGlobalVarTypeForEntryFnArg(origArg->getType())),
- rewriter.getStringAttr(varName), nullptr);
- var.setAttr(
- spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
- rewriter.getI32IntegerAttr(0));
- var.setAttr(
- spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
- rewriter.getI32IntegerAttr(origArgNum));
- }
- // Insert the addressOf and load instructions, to get back the converted value
- // type.
- auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
- auto indexType = convertIndexType(funcOp.getContext());
- auto zero = rewriter.create<spirv::ConstantOp>(
- funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
- auto accessChain = rewriter.create<spirv::AccessChainOp>(
- funcOp.getLoc(), addressOf.pointer(), zero.constant());
- // If the original argument is a tensor/memref type, the value is not
- // loaded. Instead the pointer value is returned to allow its use in access
- // chain ops.
- auto origArgType = origArg->getType();
- if (origArgType.isa<MemRefType>()) {
- return accessChain;
- }
- return rewriter.create<spirv::LoadOp>(
- funcOp.getLoc(), accessChain.component_ptr(), /*memory_access=*/nullptr,
- /*alignment=*/nullptr);
-}
-
-static FuncOp applySignatureConversion(
- FuncOp funcOp, ConversionPatternRewriter &rewriter,
- TypeConverter::SignatureConversion &signatureConverter) {
- // Create a new function with an updated signature.
- auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
- llvm::None, funcOp.getContext()));
-
- // Tell the rewriter to convert the region signature.
- rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
- rewriter.replaceOp(funcOp.getOperation(), llvm::None);
- return newFuncOp;
-}
-
-/// Gets the global variables that need to be specified as interface variable
-/// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
-LogicalResult getInterfaceVariables(FuncOp funcOp,
- SmallVectorImpl<Attribute> &interfaceVars) {
- auto module = funcOp.getParentOfType<spirv::ModuleOp>();
- if (!module) {
- return failure();
- }
- llvm::SetVector<Operation *> interfaceVarSet;
- for (auto &block : funcOp) {
- // TODO(ravishankarm) : This should in reality traverse the entry function
- // call graph and collect all the interfaces. For now, just traverse the
- // instructions in this function.
- auto callOps = block.getOps<CallOp>();
- if (std::distance(callOps.begin(), callOps.end())) {
- return funcOp.emitError("Collecting interface variables through function "
- "calls unimplemented");
- }
- for (auto op : block.getOps<spirv::AddressOfOp>()) {
- auto var = module.lookupSymbol<spirv::GlobalVariableOp>(op.variable());
- if (var.type().cast<spirv::PointerType>().getStorageClass() ==
- spirv::StorageClass::StorageBuffer) {
- continue;
- }
- interfaceVarSet.insert(var.getOperation());
- }
- }
- for (auto &var : interfaceVarSet) {
- interfaceVars.push_back(SymbolRefAttr::get(
- cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
- }
- return success();
-}
-
-namespace mlir {
-LogicalResult lowerFunction(FuncOp funcOp, SPIRVTypeConverter *typeConverter,
- ConversionPatternRewriter &rewriter,
- FuncOp &newFuncOp) {
- auto fnType = funcOp.getType();
- if (fnType.getNumResults()) {
- return funcOp.emitError("SPIR-V lowering only supports functions with no "
- "return values right now");
- }
- TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
- auto basicTypeConverter = typeConverter->getBasicTypeConverter();
- for (auto origArgType : enumerate(fnType.getInputs())) {
- auto convertedType = basicTypeConverter->convertType(origArgType.value());
- if (!convertedType) {
- return funcOp.emitError("unable to convert argument of type '")
- << convertedType << "'";
- }
- signatureConverter.addInputs(origArgType.index(), convertedType);
- }
- newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
- return success();
-}
-
-LogicalResult lowerAsEntryFunction(FuncOp funcOp,
- SPIRVTypeConverter *typeConverter,
- ConversionPatternRewriter &rewriter,
- FuncOp &newFuncOp) {
- auto fnType = funcOp.getType();
- if (fnType.getNumResults()) {
- return funcOp.emitError("SPIR-V lowering only supports functions with no "
- "return values right now");
- }
- // For entry functions need to make the signature void(void). Compute the
- // replacement value for all arguments and replace all uses.
- TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
- {
- OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
- rewriter.setInsertionPointToStart(&funcOp.front());
- for (auto origArg : enumerate(funcOp.getArguments())) {
- auto replacement = createAndLoadGlobalVarForEntryFnArg(
- rewriter, origArg.index(), origArg.value());
- signatureConverter.remapInput(origArg.index(), replacement);
- }
- }
- newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
- return success();
-}
-
-LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder) {
- // Add the spv.EntryPointOp after collecting all the interface variables
- // needed.
- SmallVector<Attribute, 1> interfaceVars;
- if (failed(getInterfaceVariables(newFuncOp, interfaceVars))) {
- return failure();
- }
- builder.create<spirv::EntryPointOp>(newFuncOp.getLoc(),
- spirv::ExecutionModel::GLCompute,
- newFuncOp, interfaceVars);
- // Specify the spv.ExecutionModeOp.
-
- /// TODO(ravishankarm): Vulkan environment for SPIR-V requires "either a
- /// LocalSize execution mode or an object decorated with the WorkgroupSize
- /// decoration must be specified." Better approach is to use the
- /// WorkgroupSize GlobalVariable with initializer being a specialization
- /// constant. But current support for specialization constant does not allow
- /// for this. So for now use the execution mode. Hard-wiring this to {1, 1,
- /// 1} for now. To be fixed ASAP.
- builder.create<spirv::ExecutionModeOp>(newFuncOp.getLoc(), newFuncOp,
- spirv::ExecutionMode::LocalSize,
- ArrayRef<int32_t>{1, 1, 1});
- return success();
-}
-} // namespace mlir
-
-//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
/// operation. Since IndexType is not used within SPIR-V dialect, this needs
/// special handling to make sure the result type and the type of the value
/// attribute are consistent.
-class ConstantIndexOpConversion final : public ConversionPattern {
+class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
public:
- ConstantIndexOpConversion(MLIRContext *context)
- : ConversionPattern(ConstantOp::getOperationName(), 1, context) {}
+ using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
- auto constIndexOp = cast<ConstantOp>(op);
if (!constIndexOp.getResult()->getType().isa<IndexType>()) {
return matchFailure();
}
- // The attribute has index type. Get the integer value and create a new
- // IntegerAttr.
+ // The attribute has index type which is not directly supported in
+ // SPIR-V. Get the integer value and create a new IntegerAttr.
auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
if (!constAttr) {
return matchFailure();
if (!constValType) {
return matchFailure();
}
- auto spirvConstType = convertIndexType(constValType);
+ auto spirvConstType =
+ typeConverter.convertBasicType(constIndexOp.getResult()->getType());
auto spirvConstVal =
rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
- auto spirvConstantOp = rewriter.create<spirv::ConstantOp>(
- op->getLoc(), spirvConstType, spirvConstVal);
- rewriter.replaceOp(op, spirvConstantOp.constant(), {});
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
+ spirvConstVal);
return matchSuccess();
}
};
/// Convert compare operation to SPIR-V dialect.
-class CmpIOpConversion final : public ConversionPattern {
+class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
public:
- CmpIOpConversion(MLIRContext *context)
- : ConversionPattern(CmpIOp::getOperationName(), 1, context) {}
+ using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
- auto cmpIOp = cast<CmpIOp>(op);
CmpIOpOperandAdaptor cmpIOpOperands(operands);
switch (cmpIOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
- rewriter.replaceOpWithNewOp<spirvOp>(op, op->getResult(0)->getType(), \
- cmpIOpOperands.lhs(), \
- cmpIOpOperands.rhs()); \
+ rewriter.replaceOpWithNewOp<spirvOp>( \
+ cmpIOp, cmpIOp.getResult()->getType(), cmpIOpOperands.lhs(), \
+ cmpIOpOperands.rhs()); \
return matchSuccess();
DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp);
/// that of the replaced operation. This is not handled in tablegen-based
/// pattern specification.
template <typename StdOp, typename SPIRVOp>
-class IntegerOpConversion final : public ConversionPattern {
+class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
public:
- IntegerOpConversion(MLIRContext *context)
- : ConversionPattern(StdOp::getOperationName(), 1, context) {}
+ using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(StdOp operation, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
+ auto resultType =
+ this->typeConverter.convertBasicType(operation.getResult()->getType());
rewriter.template replaceOpWithNewOp<SPIRVOp>(
- op, operands[0]->getType(), operands, ArrayRef<NamedAttribute>());
+ operation, resultType, operands, ArrayRef<NamedAttribute>());
return this->matchSuccess();
}
};
/// not supported in tablegen based pattern specification.
// TODO(ravishankarm) : These could potentially be templated on the operation
// being converted, since the same logic should work for linalg.load.
-class LoadOpConversion final : public ConversionPattern {
+class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
public:
- LoadOpConversion(MLIRContext *context)
- : ConversionPattern(LoadOp::getOperationName(), 1, context) {}
+ using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(LoadOp loadOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
LoadOpOperandAdaptor loadOperands(operands);
auto basePtr = loadOperands.memref();
return matchFailure();
}
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
- op->getLoc(), basePtr, loadOperands.indices());
+ loadOp.getLoc(), basePtr, loadOperands.indices());
auto loadPtrType = loadPtr.getType().cast<spirv::PointerType>();
rewriter.replaceOpWithNewOp<spirv::LoadOp>(
- op, loadPtrType.getPointeeType(), loadPtr, /*memory_access =*/nullptr,
+ loadOp, loadPtrType.getPointeeType(), loadPtr,
+ /*memory_access =*/nullptr,
/*alignment =*/nullptr);
return matchSuccess();
}
};
/// Convert return -> spv.Return.
-class ReturnToSPIRVConversion : public ConversionPattern {
+class ReturnToSPIRVConversion final : public SPIRVOpLowering<ReturnOp> {
public:
- ReturnToSPIRVConversion(MLIRContext *context)
- : ConversionPattern(ReturnOp::getOperationName(), 1, context) {}
+ using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
+
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(ReturnOp returnOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
- if (op->getNumOperands()) {
+ if (returnOp.getNumOperands()) {
return matchFailure();
}
- rewriter.replaceOpWithNewOp<spirv::ReturnOp>(op);
+ rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
return matchSuccess();
}
};
/// Convert select -> spv.Select
-class SelectOpConversion : public ConversionPattern {
+class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
public:
- SelectOpConversion(MLIRContext *context)
- : ConversionPattern(SelectOp::getOperationName(), 1, context) {}
+ using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(SelectOp op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
SelectOpOperandAdaptor selectOperands(operands);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
/// is not supported in tablegen based pattern specification.
// TODO(ravishankarm) : These could potentially be templated on the operation
// being converted, since the same logic should work for linalg.store.
-class StoreOpConversion final : public ConversionPattern {
+class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
public:
- StoreOpConversion(MLIRContext *context)
- : ConversionPattern(StoreOp::getOperationName(), 1, context) {}
+ using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ matchAndRewrite(StoreOp storeOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
StoreOpOperandAdaptor storeOperands(operands);
auto value = storeOperands.value();
return matchFailure();
}
auto storePtr = rewriter.create<spirv::AccessChainOp>(
- op->getLoc(), basePtr, storeOperands.indices());
- rewriter.replaceOpWithNewOp<spirv::StoreOp>(op, storePtr, value,
+ storeOp.getLoc(), basePtr, storeOperands.indices());
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr, value,
/*memory_access =*/nullptr,
/*alignment =*/nullptr);
return matchSuccess();
namespace mlir {
void populateStandardToSPIRVPatterns(MLIRContext *context,
+ SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
populateWithGenerated(context, &patterns);
// Add the return op conversion.
IntegerOpConversion<RemISOp, spirv::SModOp>,
IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
ReturnToSPIRVConversion, SelectOpConversion, StoreOpConversion>(
- context);
+ context, typeConverter);
}
} // namespace mlir
// =============================================================================
//
// This file implements a pass to convert MLIR standard ops into the SPIR-V
-// ops. It does not legalize FuncOps.
+// ops.
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
-#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+#include "mlir/Pass/Pass.h"
using namespace mlir;
OwningRewritePatternList patterns;
auto module = getModule();
- populateStandardToSPIRVPatterns(module.getContext(), patterns);
+ SPIRVBasicTypeConverter basicTypeConverter;
+ SPIRVTypeConverter typeConverter(&basicTypeConverter);
+ populateStandardToSPIRVPatterns(module.getContext(), typeConverter, patterns);
ConversionTarget target(*(module.getContext()));
target.addLegalDialect<spirv::SPIRVDialect>();
target.addLegalOp<FuncOp>();
}
}
-std::unique_ptr<OpPassBase<ModuleOp>>
-mlir::spirv::createConvertStandardToSPIRVPass() {
+std::unique_ptr<OpPassBase<ModuleOp>> mlir::createConvertStandardToSPIRVPass() {
return std::make_unique<ConvertStandardToSPIRVPass>();
}
add_llvm_library(MLIRSPIRV
DialectRegistration.cpp
+ LayoutUtils.cpp
SPIRVDialect.cpp
SPIRVOps.cpp
+ SPIRVLowering.cpp
SPIRVTypes.cpp
- LayoutUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
target_link_libraries(MLIRSPIRV
MLIRIR
MLIRParser
- MLIRSupport)
+ MLIRSupport
+ MLIRTransforms)
add_subdirectory(Serialization)
add_subdirectory(Transforms)
--- /dev/null
+//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements utilities used to lower to SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
+
+#include "mlir/Dialect/SPIRV/LayoutUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Type Conversion
+//===----------------------------------------------------------------------===//
+
+namespace {
+Type convertIndexType(MLIRContext *context) {
+ // Convert to 32-bit integers for now. Might need a way to control this in
+ // future.
+ // TODO(ravishankarm): It is porbably better to make it 64-bit integers. To
+ // this some support is needed in SPIR-V dialect for Conversion
+ // instructions. The Vulkan spec requires the builtins like
+ // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
+ // SExtended to 64-bit for index computations.
+ return IntegerType::get(32, context);
+}
+
+Type convertIndexType(IndexType t) { return convertIndexType(t.getContext()); }
+
+Type basicTypeConversion(Type t) {
+ // Check if the type is SPIR-V supported. If so return the type.
+ if (spirv::SPIRVDialect::isValidType(t)) {
+ return t;
+ }
+
+ if (auto indexType = t.dyn_cast<IndexType>()) {
+ return convertIndexType(indexType);
+ }
+
+ if (auto memRefType = t.dyn_cast<MemRefType>()) {
+ auto elementType = memRefType.getElementType();
+ if (memRefType.hasStaticShape()) {
+ // Convert to a multi-dimensional spv.array if size is known.
+ for (auto size : reverse(memRefType.getShape())) {
+ elementType = spirv::ArrayType::get(elementType, size);
+ }
+ return spirv::PointerType::get(elementType,
+ spirv::StorageClass::StorageBuffer);
+ } else {
+ // Vulkan SPIR-V validation rules require runtime array type to be the
+ // last member of a struct.
+ return spirv::PointerType::get(spirv::RuntimeArrayType::get(elementType),
+ spirv::StorageClass::StorageBuffer);
+ }
+ }
+ return Type();
+}
+
+Type getLayoutDecoratedType(spirv::StructType type) {
+ VulkanLayoutUtils::Size size = 0, alignment = 0;
+ return VulkanLayoutUtils::decorateType(type, size, alignment);
+}
+
+/// Generates the type of variable given the type of object.
+static Type getGlobalVarTypeForEntryFnArg(Type t) {
+ auto convertedType = basicTypeConversion(t);
+ if (auto ptrType = convertedType.dyn_cast<spirv::PointerType>()) {
+ if (!ptrType.getPointeeType().isa<spirv::StructType>()) {
+ return spirv::PointerType::get(
+ getLayoutDecoratedType(
+ spirv::StructType::get(ptrType.getPointeeType())),
+ ptrType.getStorageClass());
+ }
+ } else {
+ return spirv::PointerType::get(
+ getLayoutDecoratedType(spirv::StructType::get(convertedType)),
+ spirv::StorageClass::StorageBuffer);
+ }
+ return convertedType;
+}
+} // namespace
+
+Type SPIRVBasicTypeConverter::convertType(Type t) {
+ return basicTypeConversion(t);
+}
+
+Type SPIRVTypeConverter::convertType(Type t) {
+ return getGlobalVarTypeForEntryFnArg(t);
+}
+
+//===----------------------------------------------------------------------===//
+// Builtin Variables
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Look through all global variables in `moduleOp` and check if there is a
+/// spv.globalVariable that has the same `builtin` attribute.
+spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
+ spirv::BuiltIn builtin) {
+ for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
+ if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(convertToSnakeCase(
+ stringifyDecoration(spirv::Decoration::BuiltIn)))) {
+ auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
+ if (varBuiltIn && varBuiltIn.getValue() == builtin) {
+ return varOp;
+ }
+ }
+ }
+ return nullptr;
+}
+
+/// Gets name of global variable for a buitlin.
+std::string getBuiltinVarName(spirv::BuiltIn builtin) {
+ return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
+}
+
+/// Gets or inserts a global variable for a builtin within a module.
+spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp,
+ Location loc,
+ spirv::BuiltIn builtin,
+ OpBuilder &builder) {
+ if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
+ return varOp;
+ }
+ auto ip = builder.saveInsertionPoint();
+ builder.setInsertionPointToStart(&moduleOp.getBlock());
+ auto name = getBuiltinVarName(builtin);
+ spirv::GlobalVariableOp newVarOp;
+ switch (builtin) {
+ case spirv::BuiltIn::NumWorkgroups:
+ case spirv::BuiltIn::WorkgroupSize:
+ case spirv::BuiltIn::WorkgroupId:
+ case spirv::BuiltIn::LocalInvocationId:
+ case spirv::BuiltIn::GlobalInvocationId: {
+ auto ptrType = spirv::PointerType::get(
+ VectorType::get({3}, builder.getIntegerType(32)),
+ spirv::StorageClass::Input);
+ newVarOp = builder.create<spirv::GlobalVariableOp>(
+ loc, TypeAttr::get(ptrType), builder.getStringAttr(name), nullptr);
+ newVarOp.setAttr(
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)),
+ builder.getStringAttr(stringifyBuiltIn(builtin)));
+ break;
+ }
+ default:
+ emitError(loc, "unimplemented builtin variable generation for ")
+ << stringifyBuiltIn(builtin);
+ }
+ builder.restoreInsertionPoint(ip);
+ return newVarOp;
+}
+} // namespace
+
+/// Gets the global variable associated with a builtin and add
+/// it if it doesnt exist.
+Value *mlir::spirv::getBuiltinVariableValue(Operation *op,
+ spirv::BuiltIn builtin,
+ OpBuilder &builder) {
+ auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
+ if (!moduleOp) {
+ op->emitError("expected operation to be within a SPIR-V module");
+ return nullptr;
+ }
+ auto varOp =
+ getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, builder);
+ auto ptr = builder
+ .create<spirv::AddressOfOp>(op->getLoc(), varOp.type(),
+ builder.getSymbolRefAttr(varOp))
+ .pointer();
+ return builder.create<spirv::LoadOp>(
+ op->getLoc(),
+ ptr->getType().template cast<spirv::PointerType>().getPointeeType(), ptr,
+ /*memory_access =*/nullptr, /*alignment =*/nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// Entry Function signature Conversion
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Computes the replacement value for an argument of an entry function. It
+/// allocates a global variable for this argument and adds statements in the
+/// entry block to get a replacement value within function scope.
+Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter,
+ size_t origArgNum, Value *origArg) {
+ // Create a global variable for this argument.
+ auto insertionOp = rewriter.getInsertionBlock()->getParent();
+ auto module = insertionOp->getParentOfType<spirv::ModuleOp>();
+ if (!module) {
+ return nullptr;
+ }
+ auto funcOp = insertionOp->getParentOfType<FuncOp>();
+ spirv::GlobalVariableOp var;
+ {
+ OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
+ rewriter.setInsertionPoint(funcOp.getOperation());
+ std::string varName =
+ funcOp.getName().str() + "_arg_" + std::to_string(origArgNum);
+ var = rewriter.create<spirv::GlobalVariableOp>(
+ funcOp.getLoc(),
+ TypeAttr::get(getGlobalVarTypeForEntryFnArg(origArg->getType())),
+ rewriter.getStringAttr(varName), nullptr);
+ var.setAttr(
+ spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
+ rewriter.getI32IntegerAttr(0));
+ var.setAttr(
+ spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
+ rewriter.getI32IntegerAttr(origArgNum));
+ }
+ // Insert the addressOf and load instructions, to get back the converted value
+ // type.
+ auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
+ auto indexType = convertIndexType(funcOp.getContext());
+ auto zero = rewriter.create<spirv::ConstantOp>(
+ funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
+ auto accessChain = rewriter.create<spirv::AccessChainOp>(
+ funcOp.getLoc(), addressOf.pointer(), zero.constant());
+ // If the original argument is a tensor/memref type, the value is not
+ // loaded. Instead the pointer value is returned to allow its use in access
+ // chain ops.
+ auto origArgType = origArg->getType();
+ if (origArgType.isa<MemRefType>()) {
+ return accessChain;
+ }
+ return rewriter.create<spirv::LoadOp>(
+ funcOp.getLoc(), accessChain.component_ptr(), /*memory_access=*/nullptr,
+ /*alignment=*/nullptr);
+}
+
+FuncOp applySignatureConversion(
+ FuncOp funcOp, ConversionPatternRewriter &rewriter,
+ TypeConverter::SignatureConversion &signatureConverter) {
+ // Create a new function with an updated signature.
+ auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
+ llvm::None, funcOp.getContext()));
+
+ // Tell the rewriter to convert the region signature.
+ rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+ rewriter.replaceOp(funcOp.getOperation(), llvm::None);
+ return newFuncOp;
+}
+
+/// Gets the global variables that need to be specified as interface variable
+/// with an spv.EntryPointOp. Traverses the body of a entry function to do so.
+LogicalResult getInterfaceVariables(FuncOp funcOp,
+ SmallVectorImpl<Attribute> &interfaceVars) {
+ auto module = funcOp.getParentOfType<spirv::ModuleOp>();
+ if (!module) {
+ return failure();
+ }
+ llvm::SetVector<Operation *> interfaceVarSet;
+ for (auto &block : funcOp) {
+ // TODO(ravishankarm) : This should in reality traverse the entry function
+ // call graph and collect all the interfaces. For now, just traverse the
+ // instructions in this function.
+ for (auto op : block.getOps<spirv::AddressOfOp>()) {
+ auto var = module.lookupSymbol<spirv::GlobalVariableOp>(op.variable());
+ if (var.type().cast<spirv::PointerType>().getStorageClass() ==
+ spirv::StorageClass::StorageBuffer) {
+ continue;
+ }
+ interfaceVarSet.insert(var.getOperation());
+ }
+ }
+ for (auto &var : interfaceVarSet) {
+ interfaceVars.push_back(SymbolRefAttr::get(
+ cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
+ }
+ return success();
+}
+} // namespace
+
+LogicalResult mlir::spirv::lowerAsEntryFunction(
+ FuncOp funcOp, SPIRVTypeConverter *typeConverter,
+ ConversionPatternRewriter &rewriter, FuncOp &newFuncOp) {
+ auto fnType = funcOp.getType();
+ if (fnType.getNumResults()) {
+ return funcOp.emitError("SPIR-V lowering only supports functions with no "
+ "return values right now");
+ }
+ // For entry functions need to make the signature void(void). Compute the
+ // replacement value for all arguments and replace all uses.
+ TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+ {
+ OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
+ rewriter.setInsertionPointToStart(&funcOp.front());
+ for (auto origArg : enumerate(funcOp.getArguments())) {
+ auto replacement = createAndLoadGlobalVarForEntryFnArg(
+ rewriter, origArg.index(), origArg.value());
+ signatureConverter.remapInput(origArg.index(), replacement);
+ }
+ }
+ newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
+ return success();
+}
+
+LogicalResult mlir::spirv::finalizeEntryFunction(FuncOp newFuncOp,
+ OpBuilder &builder) {
+ // Add the spv.EntryPointOp after collecting all the interface variables
+ // needed.
+ SmallVector<Attribute, 1> interfaceVars;
+ if (failed(getInterfaceVariables(newFuncOp, interfaceVars))) {
+ return failure();
+ }
+ builder.create<spirv::EntryPointOp>(newFuncOp.getLoc(),
+ spirv::ExecutionModel::GLCompute,
+ newFuncOp, interfaceVars);
+ // Specify the spv.ExecutionModeOp.
+
+ /// TODO(ravishankarm): Vulkan environment for SPIR-V requires "either a
+ /// LocalSize execution mode or an object decorated with the WorkgroupSize
+ /// decoration must be specified." Better approach is to use the
+ /// WorkgroupSize GlobalVariable with initializer being a specialization
+ /// constant. But current support for specialization constant does not allow
+ /// for this. So for now use the execution mode. Hard-wiring this to {1, 1,
+ /// 1} for now. To be fixed ASAP.
+ builder.create<spirv::ExecutionModeOp>(newFuncOp.getLoc(), newFuncOp,
+ spirv::ExecutionMode::LocalSize,
+ ArrayRef<int32_t>{1, 1, 1});
+ return success();
+}
MLIRQuantOps
MLIRROCDLIR
MLIRSPIRV
- MLIRSPIRVConversion
+ MLIRStandardToSPIRVTransforms
MLIRSPIRVTransforms
MLIRStandardOps
MLIRStandardToLLVM