--- /dev/null
+//===- ConvertStandardToSPIRV.h - Convert to SPIR-V dialect -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides type converters and patterns to convert from standard types/ops to
+// SPIR-V types and operations. Also provides utilities and base classes to use
+// while targeting SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
+#define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+namespace spirv {
+class SPIRVDialect;
+}
+
+/// Type conversion from Standard Types to SPIR-V Types.
+class SPIRVTypeConverter : public TypeConverter {
+public:
+ explicit SPIRVTypeConverter(MLIRContext *context);
+
+ /// Converts types to SPIR-V supported types.
+ Type convertType(Type t) override;
+
+protected:
+ spirv::SPIRVDialect *spirvDialect;
+};
+
+/// Converts a function type according to the requirements of a SPIR-V entry
+/// function. The arguments need to be converted to spv.Variables of spv.ptr
+/// types so that they could be bound by the runtime.
+class SPIRVEntryFnTypeConverter final : public SPIRVTypeConverter {
+public:
+ using SPIRVTypeConverter::SPIRVTypeConverter;
+
+ /// Method to convert argument of a function. The `type` is converted to
+ /// spv.ptr<type, Uniform>.
+ // TODO(ravishankarm) : Support other storage classes.
+ LogicalResult convertSignatureArg(unsigned inputNo, Type type,
+ SignatureConversion &result) override;
+};
+
+/// Base class to define a conversion pattern to translate Ops into SPIR-V.
+template <typename OpTy> class SPIRVOpLowering : public ConversionPattern {
+public:
+ SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter,
+ SPIRVEntryFnTypeConverter &entryFnConverter)
+ : ConversionPattern(OpTy::getOperationName(), 1, context),
+ typeConverter(typeConverter), entryFnConverter(entryFnConverter) {}
+
+protected:
+ // Type lowering class.
+ SPIRVTypeConverter &typeConverter;
+
+ // Entry function signature converter.
+ SPIRVEntryFnTypeConverter &entryFnConverter;
+};
+
+/// Base Class for legalize a FuncOp within a spv.module. This class can be
+/// extended to implement a ConversionPattern to lower a FuncOp. It provides
+/// hooks to legalize a FuncOp as a simple function, or as an entry function.
+class SPIRVFnLowering : public SPIRVOpLowering<FuncOp> {
+public:
+ using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
+
+protected:
+ /// Method to legalize the function as a non-entry function.
+ LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter,
+ FuncOp &newFuncOp) const;
+
+ /// Method to legalize the function as an entry function.
+ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter,
+ FuncOp &newFuncOp) const;
+};
+
+/// Appends to a pattern list additional patterns for translating StandardOps to
+/// SPIR-V ops.
+void populateStandardToSPIRVPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
+++ /dev/null
-//===- StdOpsToSPIRVConversion.h - Convert StandardOps to SPIR-V *- C++ -*-===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file defines utility function to import patterns to convert StandardOps
-// to SPIR-V ops
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef STANDARD_OPS_TO_SPIRV_H_
-#define STANDARD_OPS_TO_SPIRV_H_
-
-#include "mlir/IR/PatternMatch.h"
-
-namespace mlir {
-/// Method to append to a pattern list additional patterns for translating
-/// StandardOps to SPIR-V ops.
-void populateStdOpsToSPIRVPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
-} // namespace mlir
-
-#endif // STANDARD_OPS_TO_SPIRV_H_
namespace mlir {
namespace spirv {
-FunctionPassBase *createStdOpsToSPIRVConversionPass();
+ModulePassBase *createConvertStandardToSPIRVPass();
} // namespace spirv
} // namespace mlir
/// Prints a type registered to this dialect.
void printType(Type type, llvm::raw_ostream &os) const override;
+
+ /// Checks if a type is valid in SPIR-V dialect.
+ bool isValidSPIRVType(Type t) const;
};
} // end namespace spirv
#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/SPIRVOps.h.inc"
+/// Following methods are auto-generated.
+///
+/// Get the name used in the Op to refer to an enum value of the given
+/// `EnumClass`.
+/// template <typename EnumClass> StringRef attributeName();
+///
+/// Get the function that can be used to symbolize an enum value.
+/// template <typename EnumClass>
+/// llvm::Optional<EnumClass> (*)(StringRef) symbolizeEnum();
+#include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
+
} // end namespace spirv
} // end namespace mlir
}];
let arguments = (ins
+ SPV_AddressingModelAttr:$addressing_model,
+ SPV_MemoryModelAttr:$memory_model,
OptionalAttr<StrArrayAttr>:$capabilities,
OptionalAttr<StrArrayAttr>:$extensions,
- OptionalAttr<StrArrayAttr>:$extended_instruction_sets,
- SPV_AddressingModelAttr:$addressing_model,
- SPV_MemoryModelAttr:$memory_model
+ OptionalAttr<StrArrayAttr>:$extended_instruction_sets
);
let results = (outs);
let regions = (region SizedRegion<1>:$body);
- let builders = [OpBuilder<"Builder *, OperationState *state">];
+ let builders = [OpBuilder<"Builder *, OperationState *state">,
+ OpBuilder<[{Builder *, OperationState *state,
+ IntegerAttr addressing_model,
+ IntegerAttr memory_model,
+ /*optional*/ArrayAttr capabilities = nullptr,
+ /*optional*/ArrayAttr extensions = nullptr,
+ /*optional*/ArrayAttr extended_instruction_sets = nullptr}]>];
// We need to ensure the block inside the region is properly terminated;
// the auto-generated builders do not guarantee that.
add_subdirectory(ControlFlowToCFG)
add_subdirectory(GPUToCUDA)
add_subdirectory(GPUToNVVM)
+add_subdirectory(GPUToSPIRV)
add_subdirectory(StandardToLLVM)
add_subdirectory(StandardToSPIRV)
--- /dev/null
+add_llvm_library(MLIRGPUtoSPIRVTransforms
+ GPUToSPIRV.cpp
+ )
+
+target_link_libraries(MLIRGPUtoSPIRVTransforms
+ MLIRGPU
+ MLIRIR
+ MLIRPass
+ MLIRSPIRV
+ MLIRStandardOps
+ MLIRSPIRVConversion
+ MLIRTransforms
+ )
--- /dev/null
+//===- GPUToSPIRV.cp - MLIR SPIR-V lowering passes ------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert a kernel function in the GPU Dialect
+// into a spv.module operation
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
+/// attribute gpu.kernel) within a spv.module.
+class KernelFnConversion final : public SPIRVFnLowering {
+public:
+ using SPIRVFnLowering::SPIRVFnLowering;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+PatternMatchResult
+KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const {
+ auto funcOp = cast<FuncOp>(op);
+ FuncOp newFuncOp;
+ if (!gpu::GPUDialect::isKernel(funcOp)) {
+ return succeeded(lowerFunction(funcOp, operands, rewriter, newFuncOp))
+ ? matchSuccess()
+ : matchFailure();
+ }
+
+ if (failed(lowerAsEntryFunction(funcOp, operands, rewriter, newFuncOp))) {
+ return matchFailure();
+ }
+ newFuncOp.getOperation()->removeAttr(Identifier::get(
+ gpu::GPUDialect::getKernelFuncAttrName(), op->getContext()));
+ return matchSuccess();
+}
+
+namespace {
+/// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions
+/// that have the "gpu.kernel" attribute, i.e. those functions that are
+/// referenced in gpu::LaunchKernelOp operations. For each such function
+///
+/// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp
+/// (the original function is still needed by the gpu::LaunchKernelOp, so cannot
+/// replace it).
+///
+/// 2) Lower the body of the spirv::ModuleOp.
+class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
+ void runOnModule() override;
+};
+} // namespace
+
+void GPUToSPIRVPass::runOnModule() {
+ auto context = &getContext();
+ auto module = getModule();
+
+ SmallVector<Operation *, 4> spirvModules;
+ for (auto funcOp : module.getOps<FuncOp>()) {
+ if (gpu::GPUDialect::isKernel(funcOp)) {
+ OpBuilder builder(module.getBodyRegion());
+ // Create a new spirv::ModuleOp for this function, and clone the
+ // function into it.
+ // TODO : Generalize this to account for different extensions,
+ // capabilities, extended_instruction_sets, other addressing models
+ // and memory models.
+ auto spvModule = builder.create<spirv::ModuleOp>(
+ funcOp.getLoc(),
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(spirv::AddressingModel::Logical)),
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(spirv::MemoryModel::VulkanKHR)));
+ OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
+ moduleBuilder.clone(*funcOp.getOperation());
+ spirvModules.push_back(spvModule);
+ }
+ }
+
+ /// Dialect conversion to lower the functions with the spirv::ModuleOps.
+ SPIRVTypeConverter typeConverter(context);
+ SPIRVEntryFnTypeConverter entryFnConverter(context);
+ OwningRewritePatternList patterns;
+ RewriteListBuilder<KernelFnConversion>::build(
+ patterns, context, typeConverter, entryFnConverter);
+ populateStandardToSPIRVPatterns(context, patterns);
+
+ ConversionTarget target(*context);
+ target.addLegalDialect<spirv::SPIRVDialect>();
+ target.addDynamicallyLegalOp<FuncOp>(
+ [&](FuncOp Op) { return typeConverter.isSignatureLegal(Op.getType()); });
+
+ if (failed(applyFullConversion(spirvModules, target, std::move(patterns),
+ &typeConverter))) {
+ return signalPassFailure();
+ }
+}
+
+ModulePassBase *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); }
+
+static PassRegistration<GPUToSPIRVPass>
+ pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");
-set(LLVM_TARGET_DEFINITIONS StdOpsToSPIRVConversion.td)
-mlir_tablegen(StdOpsToSPIRVConversion.cpp.inc -gen-rewriters)
-add_public_tablegen_target(MLIRStdOpsToSPIRVConversionIncGen)
+set(LLVM_TARGET_DEFINITIONS StandardToSPIRV.td)
+mlir_tablegen(StandardToSPIRV.cpp.inc -gen-rewriters)
+add_public_tablegen_target(MLIRStandardToSPIRVIncGen)
add_llvm_library(MLIRSPIRVConversion
- StdOpsToSPIRVConversion.cpp
+ ConvertStandardToSPIRV.cpp
+ ConvertStandardToSPIRVPass.cpp
ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
)
add_dependencies(MLIRSPIRVConversion
- MLIRStdOpsToSPIRVConversionIncGen)
+ MLIRStandardToSPIRVIncGen)
target_link_libraries(MLIRSPIRVConversion
MLIRIR
--- /dev/null
+//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert MLIR standard and builtin dialects
+// into the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Type Conversion
+//===----------------------------------------------------------------------===//
+
+SPIRVTypeConverter::SPIRVTypeConverter(MLIRContext *context)
+ : spirvDialect(context->getRegisteredDialect<spirv::SPIRVDialect>()) {}
+
+Type SPIRVTypeConverter::convertType(Type t) {
+ // Check if the type is SPIR-V supported. If so return the type.
+ if (spirvDialect->isValidSPIRVType(t)) {
+ return t;
+ }
+
+ if (auto memRefType = t.dyn_cast<MemRefType>()) {
+ if (memRefType.hasStaticShape()) {
+ // Convert MemrefType to spv.array if size is known.
+ // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
+ // to support other Storage Classes.
+ return spirv::PointerType::get(
+ spirv::ArrayType::get(memRefType.getElementType(),
+ memRefType.getNumElements()),
+ spirv::StorageClass::StorageBuffer);
+ }
+ }
+ return Type();
+}
+
+//===----------------------------------------------------------------------===//
+// Entry Function signature Conversion
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+SPIRVEntryFnTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
+ SignatureConversion &result) {
+ // Try to convert the given input type.
+ auto convertedType = convertType(type);
+ // TODO(ravishankarm) : Vulkan spec requires these to be a
+ // spirv::StructType. This is not a SPIR-V requirement, so just making this a
+ // pointer type for now.
+ if (!convertedType)
+ return failure();
+ // For arguments to entry functions, convert the type into a pointer type if
+ // it is already not one.
+ if (!convertedType.isa<spirv::PointerType>()) {
+ // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
+ // to support other Storage classes.
+ convertedType = spirv::PointerType::get(convertedType,
+ spirv::StorageClass::StorageBuffer);
+ }
+
+ // Add the new inputs.
+ result.addInputs(inputNo, convertedType);
+ return success();
+}
+
+template <typename Converter>
+static LogicalResult
+lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter, Converter &typeConverter,
+ TypeConverter::SignatureConversion &signatureConverter,
+ FuncOp &newFuncOp) {
+ auto fnType = funcOp.getType();
+
+ if (fnType.getNumResults()) {
+ return funcOp.emitError("SPIR-V dialect only supports functions with no "
+ "return values right now");
+ }
+
+ for (auto &argType : enumerate(fnType.getInputs())) {
+ // Get the type of the argument
+ if (failed(typeConverter.convertSignatureArg(
+ argType.index(), argType.value(), signatureConverter))) {
+ return funcOp.emitError("unable to convert argument type ")
+ << argType.value() << " to SPIR-V type";
+ }
+ }
+
+ // Create a new function with an updated signature.
+ newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
+ llvm::None, funcOp.getContext()));
+
+ // Tell the rewriter to convert the region signature.
+ rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+ rewriter.replaceOp(funcOp.getOperation(), llvm::None);
+ return success();
+}
+
+LogicalResult
+SPIRVFnLowering::lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter,
+ FuncOp &newFuncOp) const {
+ auto fnType = funcOp.getType();
+ TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+ return lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
+ signatureConverter, newFuncOp);
+}
+
+LogicalResult
+SPIRVFnLowering::lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter,
+ FuncOp &newFuncOp) const {
+ auto fnType = funcOp.getType();
+ TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
+ if (failed(lowerFunctionImpl(funcOp, operands, rewriter, entryFnConverter,
+ signatureConverter, newFuncOp))) {
+ return failure();
+ }
+ // Create spv.Variable ops for each of the arguments. These need to be bound
+ // by the runtime. For now use descriptor_set 0, and arg number as the binding
+ // number.
+ auto module = funcOp.getParentOfType<spirv::ModuleOp>();
+ if (!module) {
+ return funcOp.emitError("expected op to be within a spv.module");
+ }
+ OpBuilder builder(module.getOperation()->getRegion(0));
+ SmallVector<Value *, 4> interface;
+ for (auto &convertedArgType :
+ llvm::enumerate(signatureConverter.getConvertedTypes())) {
+ auto variableOp = builder.create<spirv::VariableOp>(
+ funcOp.getLoc(), convertedArgType.value(),
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(spirv::StorageClass::StorageBuffer)),
+ llvm::None);
+ variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0));
+ variableOp.setAttr("binding",
+ builder.getI32IntegerAttr(convertedArgType.index()));
+ interface.push_back(variableOp.getResult());
+ }
+ // Create an entry point instruction for this function.
+ // TODO(ravishankarm) : Add execution mode for the entry function
+ builder.setInsertionPoint(&(module.getBlock().back()));
+ builder.create<spirv::EntryPointOp>(
+ funcOp.getLoc(),
+ builder.getI32IntegerAttr(
+ static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
+ builder.getSymbolRefAttr(newFuncOp.getName()), interface);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Operation conversion
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Convert return -> spv.Return.
+class ReturnToSPIRVConversion : public ConversionPattern {
+public:
+ ReturnToSPIRVConversion(MLIRContext *context)
+ : ConversionPattern(ReturnOp::getOperationName(), 1, context) {}
+ virtual PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (op->getNumOperands()) {
+ return matchFailure();
+ }
+ rewriter.replaceOpWithNewOp<spirv::ReturnOp>(op);
+ return matchSuccess();
+ }
+};
+
+} // namespace
+
+namespace {
+/// Import the Standard Ops to SPIR-V Patterns.
+#include "StandardToSPIRV.cpp.inc"
+} // namespace
+
+namespace mlir {
+void populateStandardToSPIRVPatterns(MLIRContext *context,
+ OwningRewritePatternList &patterns) {
+ populateWithGenerated(context, &patterns);
+ // Add the return op conversion.
+ RewriteListBuilder<ReturnToSPIRVConversion>::build(patterns, context);
+}
+} // namespace mlir
--- /dev/null
+//===- ConvertStandardToSPIRVPass.cpp - Convert Std Ops to SPIR-V Ops -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert MLIR standard ops into the SPIR-V
+// ops. It does not legalize FuncOps.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+
+using namespace mlir;
+
+namespace {
+/// A pass converting MLIR Standard operations into the SPIR-V dialect.
+class ConvertStandardToSPIRVPass
+ : public ModulePass<ConvertStandardToSPIRVPass> {
+ void runOnModule() override;
+};
+} // namespace
+
+void ConvertStandardToSPIRVPass::runOnModule() {
+ OwningRewritePatternList patterns;
+ auto module = getModule();
+
+ populateStandardToSPIRVPatterns(module.getContext(), patterns);
+ ConversionTarget target(*(module.getContext()));
+ target.addLegalDialect<spirv::SPIRVDialect>();
+ target.addLegalOp<FuncOp>();
+
+ if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+}
+
+ModulePassBase *mlir::spirv::createConvertStandardToSPIRVPass() {
+ return new ConvertStandardToSPIRVPass();
+}
+
+static PassRegistration<ConvertStandardToSPIRVPass>
+ pass("convert-std-to-spirv", "Convert Standard Ops to SPIR-V dialect");
-//==- StdOpsToSPIRVConversion.td - Std Ops to SPIR-V Patterns *- tablegen -*==//
+//==- StandardToSPIRV.td - Standard Ops to SPIR-V Patterns ---*- tablegen -*==//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
//
//===----------------------------------------------------------------------===//
//
-// Defines Patterns to lower standard ops to SPIR-V
+// Defines Patterns to lower standard ops to SPIR-V.
//
//===----------------------------------------------------------------------===//
-#ifdef STANDARD_OPS_TO_SPIRV
+#ifdef MLIR_CONVERSION_STANDARDTOSPIRV_TD
#else
-#define STANDARD_OPS_TO_SPIRV
+#define MLIR_CONVERSION_STANDARDTOSPIRV_TD
#ifdef STANDARD_OPS
#else
defm : BinaryOpPattern<MulFOp, SPV_FMulOp>;
-#endif // STANDARD_OPS_TO_SPIRV
+#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD
+++ /dev/null
-//===- StdOpsToSPIRVLowering.cpp - Std Ops to SPIR-V dialect conversion ---===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements a pass to convert MLIR standard ops into the SPIR-V
-// dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/StandardToSPIRV/StdOpsToSPIRVConversion.h"
-#include "mlir/Dialect/SPIRV/Passes.h"
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/StandardOps/Ops.h"
-
-using namespace mlir;
-
-namespace {
-/// A pass converting MLIR Standard operations into the SPIR-V dialect.
-class StdOpsToSPIRVConversionPass
- : public FunctionPass<StdOpsToSPIRVConversionPass> {
- void runOnFunction() override;
-};
-
-#include "StdOpsToSPIRVConversion.cpp.inc"
-} // namespace
-
-namespace mlir {
-void populateStdOpsToSPIRVPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
- populateWithGenerated(context, &patterns);
-}
-} // namespace mlir
-
-void StdOpsToSPIRVConversionPass::runOnFunction() {
- OwningRewritePatternList patterns;
- auto func = getFunction();
-
- populateStdOpsToSPIRVPatterns(func.getContext(), patterns);
- applyPatternsGreedily(func, std::move(patterns));
-}
-
-FunctionPassBase *mlir::spirv::createStdOpsToSPIRVConversionPass() {
- return new StdOpsToSPIRVConversionPass();
-}
-
-static PassRegistration<StdOpsToSPIRVConversionPass>
- pass("std-to-spirv", "Convert Standard Ops to SPIR-V dialect");
return true;
}
+static bool isValidSPIRVScalarType(Type type) {
+ if (type.isa<FloatType>()) {
+ return !type.isBF16();
+ }
+ if (auto intType = type.dyn_cast<IntegerType>()) {
+ return llvm::is_contained(llvm::ArrayRef<unsigned>({1, 8, 16, 32, 64}),
+ intType.getWidth());
+ }
+ return false;
+}
+
+bool SPIRVDialect::isValidSPIRVType(Type type) const {
+ // Allow SPIR-V dialect types
+ if (&type.getDialect() == this) {
+ return true;
+ }
+ if (isValidSPIRVScalarType(type)) {
+ return true;
+ }
+ if (auto vectorType = type.dyn_cast<VectorType>()) {
+ return (isValidSPIRVScalarType(vectorType.getElementType()) &&
+ vectorType.getNumElements() >= 2 &&
+ vectorType.getNumElements() <= 4);
+ }
+ return false;
+}
+
static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec,
Location loc) {
spec = spec.trim();
emitError(loc, "only 1-D vector allowed but found ") << t;
return Type();
}
+ if (t.getNumElements() > 4) {
+ emitError(loc,
+ "vector length has to be less than or equal to 4 but found ")
+ << t.getNumElements();
+ return Type();
+ }
} else {
emitError(loc, "cannot use ") << type << " to compose SPIR-V types";
return Type();
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
-namespace mlir {
-namespace spirv {
-#include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
-} // namespace spirv
-} // namespace mlir
-
using namespace mlir;
// TODO(antiagainst): generate these strings using ODS.
return entryPointOp.emitOpError("interface operands to entry point must "
"be generated from a variable op");
}
- // Before version 1.4 the variables can only have storage_class of Input or
- // Output.
- // TODO: Add versioning so that this can be avoided for 1.4
- auto storageClass =
- interface->getType().cast<spirv::PointerType>().getStorageClass();
- switch (storageClass) {
- case spirv::StorageClass::Input:
- case spirv::StorageClass::Output:
- break;
- default:
- return entryPointOp.emitOpError("invalid storage class '")
- << stringifyStorageClass(storageClass)
- << "' for interface variables";
- }
+ // TODO: Before version 1.4 the variables can only have storage_class of
+ // Input or Output. That needs to be verified.
}
return success();
}
ensureModuleEnd(state->addRegion(), *builder, state->location);
}
+void spirv::ModuleOp::build(Builder *builder, OperationState *state,
+ IntegerAttr addressing_model,
+ IntegerAttr memory_model, ArrayAttr capabilities,
+ ArrayAttr extensions,
+ ArrayAttr extended_instruction_sets) {
+ state->addAttribute("addressing_model", addressing_model);
+ state->addAttribute("memory_model", memory_model);
+ if (capabilities)
+ state->addAttribute("capabilities", capabilities);
+ if (extensions)
+ state->addAttribute("extensions", extensions);
+ if (extended_instruction_sets)
+ state->addAttribute("extended_instruction_sets", extended_instruction_sets);
+ ensureModuleEnd(state->addRegion(), *builder, state->location);
+}
+
static ParseResult parseModuleOp(OpAsmParser *parser, OperationState *state) {
Region *body = state->addRegion();
--- /dev/null
+// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
+
+// CHECK: spv.module "Logical" "VulkanKHR" {
+// CHECK-NEXT: [[VAR1:%.*]] = spv.Variable bind(0, 0) : !spv.ptr<f32, StorageBuffer>
+// CHECK-NEXT: [[VAR2:%.*]] = spv.Variable bind(0, 1) : !spv.ptr<!spv.array<12 x f32>, StorageBuffer>
+// CHECK-NEXT: func @kernel_1
+// CHECK-NEXT: spv.Return
+// CHECK: spv.EntryPoint "GLCompute" @kernel_1, [[VAR1]], [[VAR2]]
+func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32, 1>)
+ attributes { gpu.kernel } {
+ return
+}
+
+func @foo() {
+ %0 = "op"() : () -> (f32)
+ %1 = "op"() : () -> (memref<12xf32, 1>)
+ %cst = constant 1 : index
+ "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = @kernel_1 }
+ : (index, index, index, index, index, index, f32, memref<12xf32, 1>) -> ()
+ return
+}
\ No newline at end of file
// -----
-spv.module "Logical" "VulkanKHR" {
- %2 = spv.Variable : !spv.ptr<f32, Workgroup>
- func @do_nothing() -> () {
- spv.Return
- }
- // expected-error @+1 {{'spv.EntryPoint' op invalid storage class 'Workgroup'}}
- spv.EntryPoint "GLCompute" @do_nothing, %2 : !spv.ptr<f32, Workgroup>
-}
-
-// -----
-
//===----------------------------------------------------------------------===//
// spv.ExecutionMode
//===----------------------------------------------------------------------===//
-// RUN: mlir-opt -std-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -convert-std-to-spirv %s -o - | FileCheck %s
// CHECK-LABEL: @fmul_scalar
func @fmul_scalar(%arg: f32) -> f32 {
llvm::emitSourceFileHeader("SPIR-V Op Utilites", os);
auto defs = recordKeeper.getAllDerivedDefinitions("I32EnumAttr");
+ os << "#ifndef SPIRV_OP_UTILS_H_\n";
+ os << "#define SPIRV_OP_UTILS_H_\n";
emitEnumGetAttrNameFnDecl(os);
emitEnumGetSymbolizeFnDecl(os);
for (const auto *def : defs) {
emitEnumGetAttrNameFnDefn(enumAttr, os);
emitEnumGetSymbolizeFnDefn(enumAttr, os);
}
+ os << "#endif // SPIRV_OP_UTILS_H\n";
return false;
}