From: Thomas Raoux Date: Tue, 6 Oct 2020 18:35:14 +0000 (-0700) Subject: [mlir][spirv] Add Vector to SPIR-V conversion pass X-Git-Tag: llvmorg-13-init~9967 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6e557bc40507cbc5e331179b26f7ae5fe9624294;p=platform%2Fupstream%2Fllvm.git [mlir][spirv] Add Vector to SPIR-V conversion pass Add conversion pass for Vector dialect to SPIR-V dialect and add some simple conversion pattern for vector.broadcast, vector.insert, vector.extract. Differential Revision: https://reviews.llvm.org/D88761 --- diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index b044985..b4418bb 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -29,6 +29,7 @@ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h" namespace mlir { diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 547b952..3661838 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -381,4 +381,15 @@ def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> { let dependentDialects = ["ROCDL::ROCDLDialect"]; } +//===----------------------------------------------------------------------===// +// VectorToSPIRV +//===----------------------------------------------------------------------===// + +def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv", "ModuleOp"> { + let summary = "Lower the operations from the vector dialect into the SPIR-V " + "dialect"; + let constructor = "mlir::createConvertVectorToSPIRVPass()"; + let dependentDialects = ["spirv::SPIRVDialect"]; +} + #endif // MLIR_CONVERSION_PASSES diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h new file mode 100644 index 0000000..de664df --- /dev/null +++ b/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h @@ -0,0 +1,29 @@ +//=- ConvertVectorToSPIRV.h - Vector Ops to SPIR-V dialect patterns - C++ -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides patterns for lowering Vector Ops to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRV_H_ +#define MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRV_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class SPIRVTypeConverter; + +/// Appends to a pattern list additional patterns for translating Vector Ops to +/// SPIR-V ops. +void populateVectorToSPIRVPatterns(MLIRContext *context, + SPIRVTypeConverter &typeConverter, + OwningRewritePatternList &patterns); + +} // namespace mlir + +#endif // MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOSPIRV_CONVERTVECTORTOSPIRV_H_ diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h b/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h new file mode 100644 index 0000000..7d4c7c1 --- /dev/null +++ b/mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h @@ -0,0 +1,25 @@ +//=- ConvertVectorToSPIRVPass.h - Pass converting Vector to SPIRV -*- C++ -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides a pass to convert Vector ops to SPIR-V ops. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_VECTORTOSPIRV_CONVERTGPUTOSPIRVPASS_H +#define MLIR_CONVERSION_VECTORTOSPIRV_CONVERTGPUTOSPIRVPASS_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Pass to convert Vector Ops to SPIR-V ops. +std::unique_ptr> createConvertVectorToSPIRVPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_VECTORTOSPIRV_CONVERTGPUTOSPIRVPASS_H diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td index d6e66a6..c3a8679 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td @@ -161,6 +161,11 @@ def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert", [NoSideEffect]> { let results = (outs SPV_Composite:$result ); + + let builders = [ + OpBuilder<[{OpBuilder &builder, OperationState &state, Value object, + Value composite, ArrayRef indices}]> + ]; } #endif // SPIRV_COMPOSITE_OPS diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index fe2af07..dbb9ed6 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -19,3 +19,4 @@ add_subdirectory(StandardToSPIRV) add_subdirectory(VectorToROCDL) add_subdirectory(VectorToLLVM) add_subdirectory(VectorToSCF) +add_subdirectory(VectorToSPIRV) diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt new file mode 100644 index 0000000..a6e7300 --- /dev/null +++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_conversion_library(MLIRVectorToSPIRV + VectorToSPIRV.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToSPIRV + + DEPENDS + MLIRConversionPassIncGen + intrinsics_gen + + LINK_LIBS PUBLIC + MLIRSPIRV + MLIRVector + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp new file mode 100644 index 0000000..05949fb --- /dev/null +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -0,0 +1,119 @@ +//===------- VectorToSPIRV.cpp - Vector to SPIRV lowering passes ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to generate SPIRV operations for Vector +// operations. +// +//===----------------------------------------------------------------------===// + +#include "../PassDetail.h" +#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h" +#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVLowering.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/SPIRVTypes.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { +struct VectorBroadcastConvert final + : public SPIRVOpLowering { + using SPIRVOpLowering::SPIRVOpLowering; + LogicalResult + matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (broadcastOp.source().getType().isa() || + !spirv::CompositeType::isValid(broadcastOp.getVectorType())) + return failure(); + vector::BroadcastOp::Adaptor adaptor(operands); + SmallVector source(broadcastOp.getVectorType().getNumElements(), + adaptor.source()); + Value construct = rewriter.create( + broadcastOp.getLoc(), broadcastOp.getVectorType(), source); + rewriter.replaceOp(broadcastOp, construct); + return success(); + } +}; + +struct VectorExtractOpConvert final + : public SPIRVOpLowering { + using SPIRVOpLowering::SPIRVOpLowering; + LogicalResult + matchAndRewrite(vector::ExtractOp extractOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (extractOp.getType().isa() || + !spirv::CompositeType::isValid(extractOp.getVectorType())) + return failure(); + vector::ExtractOp::Adaptor adaptor(operands); + int32_t id = extractOp.position().begin()->cast().getInt(); + Value newExtract = rewriter.create( + extractOp.getLoc(), adaptor.vector(), id); + rewriter.replaceOp(extractOp, newExtract); + return success(); + } +}; + +struct VectorInsertOpConvert final : public SPIRVOpLowering { + using SPIRVOpLowering::SPIRVOpLowering; + LogicalResult + matchAndRewrite(vector::InsertOp insertOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (insertOp.getSourceType().isa() || + !spirv::CompositeType::isValid(insertOp.getDestVectorType())) + return failure(); + vector::InsertOp::Adaptor adaptor(operands); + int32_t id = insertOp.position().begin()->cast().getInt(); + Value newInsert = rewriter.create( + insertOp.getLoc(), adaptor.source(), adaptor.dest(), id); + rewriter.replaceOp(insertOp, newInsert); + return success(); + } +}; +} // namespace + +void mlir::populateVectorToSPIRVPatterns(MLIRContext *context, + SPIRVTypeConverter &typeConverter, + OwningRewritePatternList &patterns) { + patterns.insert(context, typeConverter); +} + +namespace { +struct LowerVectorToSPIRVPass + : public ConvertVectorToSPIRVBase { + void runOnOperation() override; +}; +} // namespace + +void LowerVectorToSPIRVPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + auto targetAttr = spirv::lookupTargetEnvOrDefault(module); + std::unique_ptr target = + spirv::SPIRVConversionTarget::get(targetAttr); + + SPIRVTypeConverter typeConverter(targetAttr); + OwningRewritePatternList patterns; + populateVectorToSPIRVPatterns(context, typeConverter, patterns); + + target->addLegalOp(); + target->addLegalOp(); + + if (failed(applyFullConversion(module, *target, patterns))) + return signalPassFailure(); +} + +std::unique_ptr> +mlir::createConvertVectorToSPIRVPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index ad25ecb..c17490c 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1410,6 +1410,13 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) { // spv.CompositeInsert //===----------------------------------------------------------------------===// +void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state, + Value object, Value composite, + ArrayRef indices) { + auto indexAttr = builder.getI32ArrayAttr(indices); + build(builder, state, composite.getType(), object, composite, indexAttr); +} + static ParseResult parseCompositeInsertOp(OpAsmParser &parser, OperationState &state) { SmallVector operands; diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir new file mode 100644 index 0000000..34f1ef5 --- /dev/null +++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt -split-input-file -convert-vector-to-spirv %s -o - | FileCheck %s + +// CHECK-LABEL: broadcast +// CHECK-SAME: %[[A:.*]]: f32 +// CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32> +// CHECK: spv.CompositeConstruct %[[A]], %[[A]] : vector<2xf32> +func @broadcast(%arg0 : f32) { + %0 = vector.broadcast %arg0 : f32 to vector<4xf32> + %1 = vector.broadcast %arg0 : f32 to vector<2xf32> + spv.Return +} + +// ----- + +// CHECK-LABEL: extract_insert +// CHECK-SAME: %[[V:.*]]: vector<4xf32> +// CHECK: %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> +// CHECK: spv.CompositeInsert %[[S]], %[[V]][0 : i32] : f32 into vector<4xf32> +func @extract_insert(%arg0 : vector<4xf32>) { + %0 = vector.extract %arg0[1] : vector<4xf32> + %1 = vector.insert %0, %arg0[0] : f32 into vector<4xf32> + spv.Return +}