[mlir][spirv] Convert math.ctlz to spv.GLSL.FindUMsb
authorLei Zhang <antiagainst@google.com>
Mon, 13 Jun 2022 17:01:53 +0000 (13:01 -0400)
committerLei Zhang <antiagainst@google.com>
Mon, 13 Jun 2022 17:02:37 +0000 (13:02 -0400)
Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D127582

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
mlir/test/Target/SPIRV/glsl-ops.mlir

index 5c7e22fb33d7e050642bffed7b04ba0ffdd447ca..1c66d13e0ae103d6c8a3f7404a202a5b468e0ba9 100644 (file)
@@ -1221,4 +1221,20 @@ def SPV_GLSLFMixOp :
   let hasVerifier = 0;
 }
 
+def SPV_GLSLFindUMsbOp : SPV_GLSLUnaryArithmeticOp<"FindUMsb", 75, SPV_Int32> {
+  let summary = "Unsigned-integer most-significant bit";
+
+  let description = [{
+    Results in the bit number of the most-significant 1-bit in the binary
+    representation of Value. If Value is 0, the result is -1.
+
+    Result Type and the type of Value must both be integer scalar or
+    integer vector types. Result Type and operand types must have the
+    same number of components with the same component width. Results are
+    computed per component.
+
+    This instruction is currently limited to 32-bit width components.
+  }];
+}
+
 #endif // MLIR_DIALECT_SPIRV_IR_GLSL_OPS
index b65042dd46c7735d2e413375eb2202fdf0397758..8fd07bfbd3c560369313c519875fa5aa47ba8a23 100644 (file)
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "math-to-spirv-pattern"
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
+/// given type is not a 32-bit scalar/vector type.
+static Value getScalarOrVectorI32Constant(Type type, int value,
+                                          OpBuilder &builder, Location loc) {
+  if (auto vectorType = type.dyn_cast<VectorType>()) {
+    if (!vectorType.getElementType().isInteger(32))
+      return nullptr;
+    SmallVector<int> values(vectorType.getNumElements(), value);
+    return builder.create<spirv::ConstantOp>(loc, type,
+                                             builder.getI32VectorAttr(values));
+  }
+  if (type.isInteger(32))
+    return builder.create<spirv::ConstantOp>(loc, type,
+                                             builder.getI32IntegerAttr(value));
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -92,6 +115,42 @@ class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
   }
 };
 
+/// Converts math.ctlz to SPIR-V ops.
+///
+/// SPIR-V does not have a direct operations for counting leading zeros. If
+/// Shader capability is supported, we can leverage GLSL FindUMsb to calculate
+/// it.
+class CountLeadingZerosPattern final
+    : public OpConversionPattern<math::CountLeadingZerosOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto type = getTypeConverter()->convertType(countOp.getType());
+    if (!type)
+      return failure();
+
+    // We can only support 32-bit integer types for now.
+    unsigned bitwidth = 0;
+    if (type.isa<IntegerType>())
+      bitwidth = type.getIntOrFloatBitWidth();
+    if (auto vectorType = type.dyn_cast<VectorType>())
+      bitwidth = vectorType.getElementTypeBitWidth();
+    if (bitwidth != 32)
+      return failure();
+
+    Location loc = countOp.getLoc();
+    Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
+    Value msb =
+        rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand());
+    // We need to subtract from 31 given that the index is from the least
+    // significant bit.
+    rewriter.replaceOpWithNewOp<spirv::ISubOp>(countOp, val31, msb);
+    return success();
+  }
+};
+
 /// Converts math.expm1 to SPIR-V ops.
 ///
 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
@@ -148,7 +207,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
 
   // GLSL patterns
   patterns
-      .add<Log1pOpPattern<spirv::GLSLLogOp>, ExpM1OpPattern<spirv::GLSLExpOp>,
+      .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>,
+           ExpM1OpPattern<spirv::GLSLExpOp>,
            spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
            spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
            spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
index 18fbd2eefa260d5337d502835305f9cee5376b77..0817ba7560d952f850cca80c975ac26e6d2273e8 100644 (file)
@@ -36,6 +36,17 @@ void ConvertMathToSPIRVPass::runOnOperation() {
 
   SPIRVTypeConverter typeConverter(targetAttr);
 
+  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+  // in patterns for other dialects.
+  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                              Location loc) {
+    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return Optional<Value>(cast.getResult(0));
+  };
+  typeConverter.addSourceMaterialization(addUnrealizedCast);
+  typeConverter.addTargetMaterialization(addUnrealizedCast);
+  target->addLegalOp<UnrealizedConversionCastOp>();
+
   RewritePatternSet patterns(context);
   populateMathToSPIRVPatterns(typeConverter, patterns);
 
index d481886bd6ac226160b82f5aabea0c2ec094e0c1..8b179b22a7bd711f9e0311b936bcb920deca9425 100644 (file)
@@ -1,6 +1,8 @@
 // RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
 
-module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>> } {
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>>
+} {
 
 // CHECK-LABEL: @float32_unary_scalar
 func.func @float32_unary_scalar(%arg0: f32) {
@@ -91,4 +93,56 @@ func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
   return
 }
 
+// CHECK-LABEL: @ctlz_scalar
+//  CHECK-SAME: (%[[VAL:.+]]: i32)
+func.func @ctlz_scalar(%val: i32) -> i32 {
+  // CHECK: %[[V31:.+]] = spv.Constant 31 : i32
+  // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : i32
+  // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
+  // CHECK: return %[[SUB]]
+  %0 = math.ctlz %val : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @ctlz_vector1
+func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
+  // CHECK: spv.GLSL.FindUMsb
+  // CHECK: spv.ISub
+  %0 = math.ctlz %val : vector<1xi32>
+  return %0 : vector<1xi32>
+}
+
+// CHECK-LABEL: @ctlz_vector2
+//  CHECK-SAME: (%[[VAL:.+]]: vector<2xi32>)
+func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
+  // CHECK-DAG: %[[V31:.+]] = spv.Constant dense<31> : vector<2xi32>
+  // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : vector<2xi32>
+  // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
+  // CHECK: return %[[SUB]]
+  %0 = math.ctlz %val : vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+} // end module
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader, Int64, Int16], []>, #spv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @ctlz_scalar
+func.func @ctlz_scalar(%val: i64) -> i64 {
+  // CHECK: math.ctlz
+  %0 = math.ctlz %val : i64
+  return %0 : i64
+}
+
+// CHECK-LABEL: @ctlz_vector2
+func.func @ctlz_vector2(%val: vector<2xi16>) -> vector<2xi16> {
+  // CHECK: math.ctlz
+  %0 = math.ctlz %val : vector<2xi16>
+  return %0 : vector<2xi16>
+}
+
 } // end module
index e253a1a74a7054dc1c09e355d72969d289b68f48..3f5db8ee95e84af4d88d8382b300ae03da6389d9 100644 (file)
@@ -494,10 +494,34 @@ func.func @fmix(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> () {
   return
 }
 
-// -----
 func.func @fmix_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>, %arg2 : vector<3xf32>) -> () {
   // CHECK: {{%.*}} = spv.GLSL.FMix {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xf32> -> vector<3xf32>
   %0 = spv.GLSL.FMix %arg0 : vector<3xf32>, %arg1 : vector<3xf32>, %arg2 : vector<3xf32> -> vector<3xf32>
   return
 }
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.Exp
+//===----------------------------------------------------------------------===//
+
+func.func @findumsb(%arg0 : i32) -> () {
+  // CHECK: spv.GLSL.FindUMsb {{%.*}} : i32
+  %2 = spv.GLSL.FindUMsb %arg0 : i32
+  return
+}
+
+func.func @findumsb_vector(%arg0 : vector<3xi32>) -> () {
+  // CHECK: spv.GLSL.FindUMsb {{%.*}} : vector<3xi32>
+  %2 = spv.GLSL.FindUMsb %arg0 : vector<3xi32>
+  return
+}
+
+// -----
+
+func.func @findumsb(%arg0 : i64) -> () {
+  // expected-error @+1 {{operand #0 must be Int32 or vector of Int32}}
+  %2 = spv.GLSL.FindUMsb %arg0 : i64
+  return
+}
index 914aa1ffc40e8627d37d72b78290b43338306dcc..1e3c3bbeb1d837874f441a35587fa91d8411eb8f 100644 (file)
@@ -75,4 +75,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %13 = spv.GLSL.Fma %arg0, %arg1, %arg2 : f32
     spv.Return
   }
+
+  spv.func @findumsb(%arg0 : i32) "None" {
+    // CHECK: spv.GLSL.FindUMsb {{%.*}} : i32
+    %2 = spv.GLSL.FindUMsb %arg0 : i32
+    spv.Return
+  }
 }