[mlir][spirv][math] Fix crash on unsupported types in math-to-spirv
authorJakub Kuderski <kubak@google.com>
Thu, 17 Nov 2022 18:45:01 +0000 (13:45 -0500)
committerJakub Kuderski <kubak@google.com>
Thu, 17 Nov 2022 18:45:36 +0000 (13:45 -0500)
Fail to match conversion patterns when source op has unsupported types.

Fixes: https://github.com/llvm/llvm-project/issues/58749

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D138178

mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/lib/Conversion/SPIRVCommon/Pattern.h
mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir
mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir

index 2420004..5bd06c9 100644 (file)
@@ -18,7 +18,9 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
 
 #define DEBUG_TYPE "math-to-spirv-pattern"
 
@@ -46,6 +48,48 @@ static Value getScalarOrVectorI32Constant(Type type, int value,
   return nullptr;
 }
 
+/// Check if the type is supported by math-to-spirv conversion. We expect to
+/// only see scalars and vectors at this point, with higher-level types already
+/// lowered.
+static bool isSupportedSourceType(Type originalType) {
+  if (originalType.isIntOrIndexOrFloat())
+    return true;
+
+  if (auto vecTy = originalType.dyn_cast<VectorType>()) {
+    if (!vecTy.getElementType().isIntOrIndexOrFloat())
+      return false;
+    if (vecTy.isScalable())
+      return false;
+    if (vecTy.getRank() > 1)
+      return false;
+
+    return true;
+  }
+
+  return false;
+}
+
+/// Check if all `sourceOp` types are supported by math-to-spirv conversion.
+/// Notify of a match failure othwerise and return a `failure` result.
+/// This is intended to simplify type checks in `OpConversionPattern`s.
+static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter,
+                                        Operation *sourceOp) {
+  auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
+  llvm::append_range(allTypes, sourceOp->getResultTypes());
+
+  for (Type ty : allTypes) {
+    if (!isSupportedSourceType(ty)) {
+      return rewriter.notifyMatchFailure(
+          sourceOp,
+          llvm::formatv(
+              "unsupported source type for Math to SPIR-V conversion: {0}",
+              ty));
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -55,14 +99,36 @@ static Value getScalarOrVectorI32Constant(Type type, int value,
 // normal RewritePattern.
 
 namespace {
+/// Converts elementwise unary, binary, and ternary standard operations to
+/// SPIR-V operations. Checks that source `Op` types are supported.
+template <typename Op, typename SPIRVOp>
+struct CheckedElementwiseOpPattern final
+    : public spirv::ElementwiseOpPattern<Op, SPIRVOp> {
+  using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
+  using BasePattern::BasePattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
+      return res;
+
+    return BasePattern::matchAndRewrite(op, adaptor, rewriter);
+  }
+};
+
 /// Converts math.copysign to SPIR-V ops.
-class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
+struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto type = getTypeConverter()->convertType(copySignOp.getType());
+    if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp);
+        failed(res))
+      return res;
+
+    Type type = getTypeConverter()->convertType(copySignOp.getType());
     if (!type)
       return failure();
 
@@ -121,14 +187,17 @@ class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
 /// SPIR-V does not have a direct operations for counting leading zeros. If
 /// Shader capability is supported, we can leverage GL FindUMsb to calculate
 /// it.
-class CountLeadingZerosPattern final
+struct CountLeadingZerosPattern final
     : public OpConversionPattern<math::CountLeadingZerosOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto type = getTypeConverter()->convertType(countOp.getType());
+    if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res))
+      return res;
+
+    Type type = getTypeConverter()->convertType(countOp.getType());
     if (!type)
       return failure();
 
@@ -177,9 +246,16 @@ struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
   matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     assert(adaptor.getOperands().size() == 1);
+    if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
+        failed(res))
+      return res;
+
     Location loc = operation.getLoc();
-    auto type = this->getTypeConverter()->convertType(operation.getType());
-    auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
+    Type type = this->getTypeConverter()->convertType(operation.getType());
+    if (!type)
+      return failure();
+
+    Value exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
     auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
     rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
     return success();
@@ -198,10 +274,17 @@ struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     assert(adaptor.getOperands().size() == 1);
+    if (LogicalResult res = checkSourceOpTypes(rewriter, operation);
+        failed(res))
+      return res;
+
     Location loc = operation.getLoc();
-    auto type = this->getTypeConverter()->convertType(operation.getType());
+    Type type = this->getTypeConverter()->convertType(operation.getType());
+    if (!type)
+      return failure();
+
     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
-    auto onePlus =
+    Value onePlus =
         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
     rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
     return success();
@@ -215,7 +298,10 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
   LogicalResult
   matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto dstType = getTypeConverter()->convertType(powfOp.getType());
+    if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res))
+      return res;
+
+    Type dstType = getTypeConverter()->convertType(powfOp.getType());
     if (!dstType)
       return failure();
 
@@ -241,10 +327,13 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
   LogicalResult
   matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res))
+      return res;
+
     Location loc = roundOp.getLoc();
-    auto operand = roundOp.getOperand();
-    auto ty = operand.getType();
-    auto ety = getElementTypeOrSelf(ty);
+    Value operand = roundOp.getOperand();
+    Type ty = operand.getType();
+    Type ety = getElementTypeOrSelf(ty);
 
     auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
     auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
@@ -287,38 +376,38 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   patterns
       .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
            ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
-           spirv::ElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
-           spirv::ElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
-           spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
-           spirv::ElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
-           spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
-           spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
-           spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
-           spirv::ElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
-           spirv::ElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
-           spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
-           spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
-           spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
-           spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
+           CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
+           CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
+           CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
+           CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
+           CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
+           CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
+           CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
+           CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
+           CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
+           CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
+           CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
+           CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
+           CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
           typeConverter, patterns.getContext());
 
   // OpenCL patterns
   patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
-               spirv::ElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
-               spirv::ElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
-               spirv::ElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
-               spirv::ElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
-               spirv::ElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
-               spirv::ElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
-               spirv::ElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
-               spirv::ElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
-               spirv::ElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
-               spirv::ElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
-               spirv::ElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
-               spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
-               spirv::ElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
-               spirv::ElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
-               spirv::ElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
+               CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
+               CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
+               CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
+               CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
+               CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
+               CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
+               CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
+               CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
+               CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
+               CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
+               CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
+               CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
+               CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
+               CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
+               CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
       typeConverter, patterns.getContext());
 }
 
index ed859a8..4da3e19 100644 (file)
@@ -19,8 +19,7 @@ namespace spirv {
 /// Converts elementwise unary, binary and ternary standard operations to SPIR-V
 /// operations.
 template <typename Op, typename SPIRVOp>
-class ElementwiseOpPattern final : public OpConversionPattern<Op> {
-public:
+struct ElementwiseOpPattern : public OpConversionPattern<Op> {
   using OpConversionPattern<Op>::OpConversionPattern;
 
   LogicalResult
index a9ea026..e84b9b0 100644 (file)
@@ -41,3 +41,27 @@ func.func @copy_sign_vector(%value: vector<3xf16>, %sign: vector<3xf16>) -> vect
 //       CHECK:   %[[OR:.+]] = spirv.BitwiseOr %[[VAND]], %[[SAND]] : vector<3xi16>
 //       CHECK:   %[[RESULT:.+]] = spirv.Bitcast %[[OR]] : vector<3xi16> to vector<3xf16>
 //       CHECK:   return %[[RESULT]]
+
+// -----
+
+// 2-D vectors are not supported.
+func.func @copy_sign_2d_vector(%value: vector<3x3xf32>, %sign: vector<3x3xf32>) -> vector<3x3xf32> {
+  %0 = math.copysign %value, %sign : vector<3x3xf32>
+  return %0: vector<3x3xf32>
+}
+
+// CHECK-LABEL: func @copy_sign_2d_vector
+// CHECK-NEXT:    math.copysign {{%.+}}, {{%.+}} : vector<3x3xf32>
+// CHECK-NEXT:    return
+
+// -----
+
+// Tensors are not supported.
+func.func @copy_sign_tensor(%value: tensor<3x3xf32>, %sign: tensor<3x3xf32>) -> tensor<3x3xf32> {
+  %0 = math.copysign %value, %sign : tensor<3x3xf32>
+  return %0: tensor<3x3xf32>
+}
+
+// CHECK-LABEL: func @copy_sign_tensor
+// CHECK-NEXT:    math.copysign {{%.+}}, {{%.+}} : tensor<3x3xf32>
+// CHECK-NEXT:    return
index a29b18b..125478e 100644 (file)
@@ -211,3 +211,51 @@ func.func @ctlz_vector2(%val: vector<2xi16>) -> vector<2xi16> {
 }
 
 } // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
+} {
+
+// 2-D vectors are not supported.
+
+// CHECK-LABEL: @vector_2d
+func.func @vector_2d(%arg0: vector<2x2xf32>) {
+  // CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
+  %0 = math.cos %arg0 : vector<2x2xf32>
+  // CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
+  %1 = math.exp %arg0 : vector<2x2xf32>
+  // CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
+  %2 = math.absf %arg0 : vector<2x2xf32>
+  // CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
+  %3 = math.ceil %arg0 : vector<2x2xf32>
+  // CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
+  %4 = math.floor %arg0 : vector<2x2xf32>
+  // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
+  %5 = math.powf %arg0, %arg0 : vector<2x2xf32>
+  // CHECK-NEXT: return
+  return
+}
+
+// Tensors are not supported.
+
+// CHECK-LABEL: @tensor_1d
+func.func @tensor_1d(%arg0: tensor<2xf32>) {
+  // CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
+  %0 = math.cos %arg0 : tensor<2xf32>
+  // CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
+  %1 = math.exp %arg0 : tensor<2xf32>
+  // CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
+  %2 = math.absf %arg0 : tensor<2xf32>
+  // CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
+  %3 = math.ceil %arg0 : tensor<2xf32>
+  // CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
+  %4 = math.floor %arg0 : tensor<2xf32>
+  // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
+  %5 = math.powf %arg0, %arg0 : tensor<2xf32>
+  // CHECK-NEXT: return
+  return
+}
+
+} // end module
index 6897cfd..da02e62 100644 (file)
@@ -100,3 +100,51 @@ func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
 }
 
 } // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
+} {
+
+// 2-D vectors are not supported.
+
+// CHECK-LABEL: @vector_2d
+func.func @vector_2d(%arg0: vector<2x2xf32>) {
+  // CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
+  %0 = math.cos %arg0 : vector<2x2xf32>
+  // CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
+  %1 = math.exp %arg0 : vector<2x2xf32>
+  // CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
+  %2 = math.absf %arg0 : vector<2x2xf32>
+  // CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
+  %3 = math.ceil %arg0 : vector<2x2xf32>
+  // CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
+  %4 = math.floor %arg0 : vector<2x2xf32>
+  // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
+  %5 = math.powf %arg0, %arg0 : vector<2x2xf32>
+  // CHECK-NEXT: return
+  return
+}
+
+// Tensors are not supported.
+
+// CHECK-LABEL: @tensor_1d
+func.func @tensor_1d(%arg0: tensor<2xf32>) {
+  // CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
+  %0 = math.cos %arg0 : tensor<2xf32>
+  // CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
+  %1 = math.exp %arg0 : tensor<2xf32>
+  // CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
+  %2 = math.absf %arg0 : tensor<2xf32>
+  // CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
+  %3 = math.ceil %arg0 : tensor<2xf32>
+  // CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
+  %4 = math.floor %arg0 : tensor<2xf32>
+  // CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
+  %5 = math.powf %arg0, %arg0 : tensor<2xf32>
+  // CHECK-NEXT: return
+  return
+}
+
+} // end module