From d3a601ce331a1bfa5dc882bc34f2e7aebf029f9c Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Tue, 11 Jun 2019 10:47:06 -0700 Subject: [PATCH] [spirv] Add a skeleton to translate standard ops into SPIR-V dialect PiperOrigin-RevId: 252651994 --- mlir/include/mlir/SPIRV/CMakeLists.txt | 2 + mlir/include/mlir/SPIRV/Passes.h | 35 ++++++++++++ .../mlir/SPIRV/Transforms/CMakeLists.txt | 3 + .../Transforms/StdOpsToSPIRVConversion.td | 48 ++++++++++++++++ mlir/lib/SPIRV/CMakeLists.txt | 4 +- .../Transforms/StdOpsToSPIRVConversion.cpp | 56 +++++++++++++++++++ mlir/test/SPIRV/standard_ops_to_spirv.mlir | 46 +++++++++++++++ 7 files changed, 193 insertions(+), 1 deletion(-) create mode 100644 mlir/include/mlir/SPIRV/Passes.h create mode 100644 mlir/include/mlir/SPIRV/Transforms/CMakeLists.txt create mode 100644 mlir/include/mlir/SPIRV/Transforms/StdOpsToSPIRVConversion.td create mode 100644 mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp create mode 100644 mlir/test/SPIRV/standard_ops_to_spirv.mlir diff --git a/mlir/include/mlir/SPIRV/CMakeLists.txt b/mlir/include/mlir/SPIRV/CMakeLists.txt index 72fb6b93b31a..b646aa58c82b 100644 --- a/mlir/include/mlir/SPIRV/CMakeLists.txt +++ b/mlir/include/mlir/SPIRV/CMakeLists.txt @@ -7,3 +7,5 @@ set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls) mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRSPIRVEnumsIncGen) + +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/SPIRV/Passes.h b/mlir/include/mlir/SPIRV/Passes.h new file mode 100644 index 000000000000..cfe5c919f175 --- /dev/null +++ b/mlir/include/mlir/SPIRV/Passes.h @@ -0,0 +1,35 @@ +//===- Passes.h - SPIR-V pass entry points ----------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This header file defines prototypes that expose pass constructors. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_SPIRV_PASSES_H_ +#define MLIR_SPIRV_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace spirv { + +FunctionPassBase *createStdOpsToSPIRVConversionPass(); + +} // namespace spirv +} // namespace mlir + +#endif // MLIR_SPIRV_PASSES_H_ diff --git a/mlir/include/mlir/SPIRV/Transforms/CMakeLists.txt b/mlir/include/mlir/SPIRV/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..84adc3907f18 --- /dev/null +++ b/mlir/include/mlir/SPIRV/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS StdOpsToSPIRVConversion.td) +mlir_tablegen(StdOpsToSPIRVConversion.cpp.inc -gen-rewriters) +add_public_tablegen_target(MLIRStdOpsToSPIRVConversionIncGen) diff --git a/mlir/include/mlir/SPIRV/Transforms/StdOpsToSPIRVConversion.td b/mlir/include/mlir/SPIRV/Transforms/StdOpsToSPIRVConversion.td new file mode 100644 index 000000000000..7b94eb9f8699 --- /dev/null +++ b/mlir/include/mlir/SPIRV/Transforms/StdOpsToSPIRVConversion.td @@ -0,0 +1,48 @@ +//==- StdOpsToSPIRVConversion.td - Std Ops to SPIR-V Patterns *- tablegen -*==// + +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines Patterns to lower standard ops to SPIR-V +// +//===----------------------------------------------------------------------===// + +#ifdef STANDARD_OPS_TO_SPIRV +#else +#define STANDARD_OPS_TO_SPIRV + +#ifdef STANDARD_OPS +#else +include "mlir/StandardOps/Ops.td" +#endif // STANDARD_OPS + +#ifdef SPIRV_OPS +#else +include "mlir/SPIRV/SPIRVOps.td" +#endif // SPIRV_OPS + +def IsScalar : TypeConstraint())">, "scalar">; + +class IsVectorLengthPred : + CPred<"($_self.cast().getShape().size() == 1 && " # + "$_self.cast().getShape()[0] == " # vecLength # ")">; + +class IsVectorOfLength: + TypeConstraint]>, + vecLength # "-element vector">; + +multiclass BinaryOpPattern { + def : Pat<(src IsScalar:$l, IsScalar:$r), (tgt $l, $r)>; + foreach vecLength = [2, 3, 4] in { + def : Pat<(src IsVectorOfLength:$l, + IsVectorOfLength:$r), + (tgt $l, $r)>; + } +} + +defm : BinaryOpPattern; + +#endif // STANDARD_OPS_TO_SPIRV \ No newline at end of file diff --git a/mlir/lib/SPIRV/CMakeLists.txt b/mlir/lib/SPIRV/CMakeLists.txt index a5973413e15e..e19b5ae887eb 100644 --- a/mlir/lib/SPIRV/CMakeLists.txt +++ b/mlir/lib/SPIRV/CMakeLists.txt @@ -3,6 +3,7 @@ add_llvm_library(MLIRSPIRV SPIRVDialect.cpp SPIRVOps.cpp SPIRVTypes.cpp + Transforms/StdOpsToSPIRVConversion.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV @@ -10,6 +11,7 @@ add_llvm_library(MLIRSPIRV add_dependencies(MLIRSPIRV MLIRSPIRVOpsIncGen - MLIRSPIRVEnumsIncGen) + MLIRSPIRVEnumsIncGen + MLIRStdOpsToSPIRVConversionIncGen) target_link_libraries(MLIRSPIRV MLIRIR MLIRSupport) diff --git a/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp new file mode 100644 index 000000000000..1a8d79c17909 --- /dev/null +++ b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp @@ -0,0 +1,56 @@ +//===- StdOpsToSPIRVLowering.cpp - Std Ops to SPIR-V dialect conversion ---===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a pass to convert MLIR standard ops into the SPIR-V +// dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/SPIRV/Passes.h" +#include "mlir/SPIRV/SPIRVOps.h" + +namespace mlir { +#include "mlir/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +/// A pass converting MLIR Standard operations into the SPIR-V dialect. +class StdOpsToSPIRVConversionPass + : public FunctionPass { + void runOnFunction() override; +}; +} // namespace + +void StdOpsToSPIRVConversionPass::runOnFunction() { + OwningRewritePatternList patterns; + auto &func = getFunction(); + + populateWithGenerated(func.getContext(), &patterns); + applyPatternsGreedily(func, std::move(patterns)); +} + +FunctionPassBase *mlir::spirv::createStdOpsToSPIRVConversionPass() { + return new StdOpsToSPIRVConversionPass(); +} + +static PassRegistration + pass("std-to-spirv", "Convert Standard Ops to SPIR-V dialect"); diff --git a/mlir/test/SPIRV/standard_ops_to_spirv.mlir b/mlir/test/SPIRV/standard_ops_to_spirv.mlir new file mode 100644 index 000000000000..fc59d6863a5a --- /dev/null +++ b/mlir/test/SPIRV/standard_ops_to_spirv.mlir @@ -0,0 +1,46 @@ +// RUN: mlir-opt -std-to-spirv %s -o - | FileCheck %s + +// CHECK-LABEL: @fmul_scalar +func @fmul_scalar(%arg: f32) -> f32 { + // CHECK: spv.FMul + %0 = mulf %arg, %arg : f32 + return %0 : f32 +} + +// CHECK-LABEL: @fmul_vector2 +func @fmul_vector2(%arg: vector<2xf32>) -> vector<2xf32> { + // CHECK: spv.FMul + %0 = mulf %arg, %arg : vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: @fmul_vector3 +func @fmul_vector3(%arg: vector<3xf32>) -> vector<3xf32> { + // CHECK: spv.FMul + %0 = mulf %arg, %arg : vector<3xf32> + return %0 : vector<3xf32> +} + +// CHECK-LABEL: @fmul_vector4 +func @fmul_vector4(%arg: vector<4xf32>) -> vector<4xf32> { + // CHECK: spv.FMul + %0 = mulf %arg, %arg : vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: @fmul_vector5 +func @fmul_vector5(%arg: vector<5xf32>) -> vector<5xf32> { + // Vector length of only 2, 3, and 4 is valid for SPIR-V + // CHECK: mulf + %0 = mulf %arg, %arg : vector<5xf32> + return %0 : vector<5xf32> +} + +// CHECK-LABEL: @fmul_tensor +func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> { + // For tensors mulf cannot be lowered directly to spv.FMul + // CHECK: mulf + %0 = mulf %arg, %arg : tensor<4xf32> + return %0 : tensor<4xf32> +} + -- 2.34.1