[spirv] Add a skeleton to translate standard ops into SPIR-V dialect
authorMahesh Ravishankar <ravishankarm@google.com>
Tue, 11 Jun 2019 17:47:06 +0000 (10:47 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 20 Jun 2019 05:58:26 +0000 (22:58 -0700)
PiperOrigin-RevId: 252651994

mlir/include/mlir/SPIRV/CMakeLists.txt
mlir/include/mlir/SPIRV/Passes.h [new file with mode: 0644]
mlir/include/mlir/SPIRV/Transforms/CMakeLists.txt [new file with mode: 0644]
mlir/include/mlir/SPIRV/Transforms/StdOpsToSPIRVConversion.td [new file with mode: 0644]
mlir/lib/SPIRV/CMakeLists.txt
mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp [new file with mode: 0644]
mlir/test/SPIRV/standard_ops_to_spirv.mlir [new file with mode: 0644]

index 72fb6b9..b646aa5 100644 (file)
@@ -7,3 +7,5 @@ set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
 mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
 mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
+
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/SPIRV/Passes.h b/mlir/include/mlir/SPIRV/Passes.h
new file mode 100644 (file)
index 0000000..cfe5c91
--- /dev/null
@@ -0,0 +1,35 @@
+//===- Passes.h - SPIR-V pass entry points ----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SPIRV_PASSES_H_
+#define MLIR_SPIRV_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace spirv {
+
+FunctionPassBase *createStdOpsToSPIRVConversionPass();
+
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_SPIRV_PASSES_H_
diff --git a/mlir/include/mlir/SPIRV/Transforms/CMakeLists.txt b/mlir/include/mlir/SPIRV/Transforms/CMakeLists.txt
new file mode 100644 (file)
index 0000000..84adc39
--- /dev/null
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS StdOpsToSPIRVConversion.td)
+mlir_tablegen(StdOpsToSPIRVConversion.cpp.inc -gen-rewriters)
+add_public_tablegen_target(MLIRStdOpsToSPIRVConversionIncGen)
diff --git a/mlir/include/mlir/SPIRV/Transforms/StdOpsToSPIRVConversion.td b/mlir/include/mlir/SPIRV/Transforms/StdOpsToSPIRVConversion.td
new file mode 100644 (file)
index 0000000..7b94eb9
--- /dev/null
@@ -0,0 +1,48 @@
+//==- StdOpsToSPIRVConversion.td - Std Ops to SPIR-V Patterns *- tablegen -*==//
+
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines Patterns to lower standard ops to SPIR-V
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef STANDARD_OPS_TO_SPIRV
+#else
+#define STANDARD_OPS_TO_SPIRV
+
+#ifdef STANDARD_OPS
+#else
+include "mlir/StandardOps/Ops.td"
+#endif // STANDARD_OPS
+
+#ifdef SPIRV_OPS
+#else
+include "mlir/SPIRV/SPIRVOps.td"
+#endif // SPIRV_OPS
+
+def IsScalar : TypeConstraint<CPred<"!($_self.isa<ShapedType>())">, "scalar">;
+
+class IsVectorLengthPred<int vecLength> :
+      CPred<"($_self.cast<VectorType>().getShape().size() == 1 && " #
+            "$_self.cast<VectorType>().getShape()[0] == " # vecLength # ")">;
+
+class IsVectorOfLength<int vecLength>:
+    TypeConstraint<And<[IsVectorTypePred, IsVectorLengthPred<vecLength>]>,
+                   vecLength # "-element vector">;
+
+multiclass BinaryOpPattern<Op src, SPV_Op tgt> {
+  def : Pat<(src IsScalar:$l, IsScalar:$r), (tgt $l, $r)>;
+  foreach vecLength = [2, 3, 4] in {
+    def : Pat<(src IsVectorOfLength<vecLength>:$l,
+                   IsVectorOfLength<vecLength>:$r),
+              (tgt $l, $r)>;
+  }
+}
+
+defm : BinaryOpPattern<MulFOp, SPV_FMulOp>;
+
+#endif // STANDARD_OPS_TO_SPIRV
\ No newline at end of file
index a597341..e19b5ae 100644 (file)
@@ -3,6 +3,7 @@ add_llvm_library(MLIRSPIRV
   SPIRVDialect.cpp
   SPIRVOps.cpp
   SPIRVTypes.cpp
+  Transforms/StdOpsToSPIRVConversion.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/SPIRV
@@ -10,6 +11,7 @@ add_llvm_library(MLIRSPIRV
 
 add_dependencies(MLIRSPIRV
   MLIRSPIRVOpsIncGen
-  MLIRSPIRVEnumsIncGen)
+  MLIRSPIRVEnumsIncGen
+  MLIRStdOpsToSPIRVConversionIncGen)
 
 target_link_libraries(MLIRSPIRV MLIRIR MLIRSupport)
diff --git a/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp
new file mode 100644 (file)
index 0000000..1a8d79c
--- /dev/null
@@ -0,0 +1,56 @@
+//===- StdOpsToSPIRVLowering.cpp - Std Ops to SPIR-V dialect conversion ---===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to convert MLIR standard ops into the SPIR-V
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/SPIRV/Passes.h"
+#include "mlir/SPIRV/SPIRVOps.h"
+
+namespace mlir {
+#include "mlir/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// A pass converting MLIR Standard operations into the SPIR-V dialect.
+class StdOpsToSPIRVConversionPass
+    : public FunctionPass<StdOpsToSPIRVConversionPass> {
+  void runOnFunction() override;
+};
+} // namespace
+
+void StdOpsToSPIRVConversionPass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto &func = getFunction();
+
+  populateWithGenerated(func.getContext(), &patterns);
+  applyPatternsGreedily(func, std::move(patterns));
+}
+
+FunctionPassBase *mlir::spirv::createStdOpsToSPIRVConversionPass() {
+  return new StdOpsToSPIRVConversionPass();
+}
+
+static PassRegistration<StdOpsToSPIRVConversionPass>
+    pass("std-to-spirv", "Convert Standard Ops to SPIR-V dialect");
diff --git a/mlir/test/SPIRV/standard_ops_to_spirv.mlir b/mlir/test/SPIRV/standard_ops_to_spirv.mlir
new file mode 100644 (file)
index 0000000..fc59d68
--- /dev/null
@@ -0,0 +1,46 @@
+// RUN: mlir-opt -std-to-spirv %s -o - | FileCheck %s
+
+// CHECK-LABEL: @fmul_scalar
+func @fmul_scalar(%arg: f32) -> f32 {
+  // CHECK: spv.FMul
+  %0 = mulf %arg, %arg : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @fmul_vector2
+func @fmul_vector2(%arg: vector<2xf32>) -> vector<2xf32> {
+  // CHECK: spv.FMul
+  %0 = mulf %arg, %arg : vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @fmul_vector3
+func @fmul_vector3(%arg: vector<3xf32>) -> vector<3xf32> {
+  // CHECK: spv.FMul
+  %0 = mulf %arg, %arg : vector<3xf32>
+  return %0 : vector<3xf32>
+}
+
+// CHECK-LABEL: @fmul_vector4
+func @fmul_vector4(%arg: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: spv.FMul
+  %0 = mulf %arg, %arg : vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @fmul_vector5
+func @fmul_vector5(%arg: vector<5xf32>) -> vector<5xf32> {
+  // Vector length of only 2, 3, and 4 is valid for SPIR-V
+  // CHECK: mulf
+  %0 = mulf %arg, %arg : vector<5xf32>
+  return %0 : vector<5xf32>
+}
+
+// CHECK-LABEL: @fmul_tensor
+func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
+  // For tensors mulf cannot be lowered directly to spv.FMul
+  // CHECK: mulf
+  %0 = mulf %arg, %arg : tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+