let hasVerifier = 0;
}
+def SPV_GLSLFindUMsbOp : SPV_GLSLUnaryArithmeticOp<"FindUMsb", 75, SPV_Int32> {
+ let summary = "Unsigned-integer most-significant bit";
+
+ let description = [{
+ Results in the bit number of the most-significant 1-bit in the binary
+ representation of Value. If Value is 0, the result is -1.
+
+ Result Type and the type of Value must both be integer scalar or
+ integer vector types. Result Type and operand types must have the
+ same number of components with the same component width. Results are
+ computed per component.
+
+ This instruction is currently limited to 32-bit width components.
+ }];
+}
+
#endif // MLIR_DIALECT_SPIRV_IR_GLSL_OPS
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "math-to-spirv-pattern"
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
+/// given type is not a 32-bit scalar/vector type.
+static Value getScalarOrVectorI32Constant(Type type, int value,
+ OpBuilder &builder, Location loc) {
+ if (auto vectorType = type.dyn_cast<VectorType>()) {
+ if (!vectorType.getElementType().isInteger(32))
+ return nullptr;
+ SmallVector<int> values(vectorType.getNumElements(), value);
+ return builder.create<spirv::ConstantOp>(loc, type,
+ builder.getI32VectorAttr(values));
+ }
+ if (type.isInteger(32))
+ return builder.create<spirv::ConstantOp>(loc, type,
+ builder.getI32IntegerAttr(value));
+
+ return nullptr;
+}
+
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
}
};
+/// Converts math.ctlz to SPIR-V ops.
+///
+/// SPIR-V does not have a direct operations for counting leading zeros. If
+/// Shader capability is supported, we can leverage GLSL FindUMsb to calculate
+/// it.
+class CountLeadingZerosPattern final
+ : public OpConversionPattern<math::CountLeadingZerosOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto type = getTypeConverter()->convertType(countOp.getType());
+ if (!type)
+ return failure();
+
+ // We can only support 32-bit integer types for now.
+ unsigned bitwidth = 0;
+ if (type.isa<IntegerType>())
+ bitwidth = type.getIntOrFloatBitWidth();
+ if (auto vectorType = type.dyn_cast<VectorType>())
+ bitwidth = vectorType.getElementTypeBitWidth();
+ if (bitwidth != 32)
+ return failure();
+
+ Location loc = countOp.getLoc();
+ Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
+ Value msb =
+ rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand());
+ // We need to subtract from 31 given that the index is from the least
+ // significant bit.
+ rewriter.replaceOpWithNewOp<spirv::ISubOp>(countOp, val31, msb);
+ return success();
+ }
+};
+
/// Converts math.expm1 to SPIR-V ops.
///
/// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
// GLSL patterns
patterns
- .add<Log1pOpPattern<spirv::GLSLLogOp>, ExpM1OpPattern<spirv::GLSLExpOp>,
+ .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>,
+ ExpM1OpPattern<spirv::GLSLExpOp>,
spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
SPIRVTypeConverter typeConverter(targetAttr);
+ // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+ // in patterns for other dialects.
+ auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) {
+ auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+ return Optional<Value>(cast.getResult(0));
+ };
+ typeConverter.addSourceMaterialization(addUnrealizedCast);
+ typeConverter.addTargetMaterialization(addUnrealizedCast);
+ target->addLegalOp<UnrealizedConversionCastOp>();
+
RewritePatternSet patterns(context);
populateMathToSPIRVPatterns(typeConverter, patterns);
// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
-module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>> } {
+module attributes {
+ spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>>
+} {
// CHECK-LABEL: @float32_unary_scalar
func.func @float32_unary_scalar(%arg0: f32) {
return
}
+// CHECK-LABEL: @ctlz_scalar
+// CHECK-SAME: (%[[VAL:.+]]: i32)
+func.func @ctlz_scalar(%val: i32) -> i32 {
+ // CHECK: %[[V31:.+]] = spv.Constant 31 : i32
+ // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : i32
+ // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
+ // CHECK: return %[[SUB]]
+ %0 = math.ctlz %val : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: @ctlz_vector1
+func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
+ // CHECK: spv.GLSL.FindUMsb
+ // CHECK: spv.ISub
+ %0 = math.ctlz %val : vector<1xi32>
+ return %0 : vector<1xi32>
+}
+
+// CHECK-LABEL: @ctlz_vector2
+// CHECK-SAME: (%[[VAL:.+]]: vector<2xi32>)
+func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
+ // CHECK-DAG: %[[V31:.+]] = spv.Constant dense<31> : vector<2xi32>
+ // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : vector<2xi32>
+ // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
+ // CHECK: return %[[SUB]]
+ %0 = math.ctlz %val : vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+} // end module
+
+// -----
+
+module attributes {
+ spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader, Int64, Int16], []>, #spv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @ctlz_scalar
+func.func @ctlz_scalar(%val: i64) -> i64 {
+ // CHECK: math.ctlz
+ %0 = math.ctlz %val : i64
+ return %0 : i64
+}
+
+// CHECK-LABEL: @ctlz_vector2
+func.func @ctlz_vector2(%val: vector<2xi16>) -> vector<2xi16> {
+ // CHECK: math.ctlz
+ %0 = math.ctlz %val : vector<2xi16>
+ return %0 : vector<2xi16>
+}
+
} // end module
return
}
-// -----
func.func @fmix_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>, %arg2 : vector<3xf32>) -> () {
// CHECK: {{%.*}} = spv.GLSL.FMix {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xf32> -> vector<3xf32>
%0 = spv.GLSL.FMix %arg0 : vector<3xf32>, %arg1 : vector<3xf32>, %arg2 : vector<3xf32> -> vector<3xf32>
return
}
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.Exp
+//===----------------------------------------------------------------------===//
+
+func.func @findumsb(%arg0 : i32) -> () {
+ // CHECK: spv.GLSL.FindUMsb {{%.*}} : i32
+ %2 = spv.GLSL.FindUMsb %arg0 : i32
+ return
+}
+
+func.func @findumsb_vector(%arg0 : vector<3xi32>) -> () {
+ // CHECK: spv.GLSL.FindUMsb {{%.*}} : vector<3xi32>
+ %2 = spv.GLSL.FindUMsb %arg0 : vector<3xi32>
+ return
+}
+
+// -----
+
+func.func @findumsb(%arg0 : i64) -> () {
+ // expected-error @+1 {{operand #0 must be Int32 or vector of Int32}}
+ %2 = spv.GLSL.FindUMsb %arg0 : i64
+ return
+}
%13 = spv.GLSL.Fma %arg0, %arg1, %arg2 : f32
spv.Return
}
+
+ spv.func @findumsb(%arg0 : i32) "None" {
+ // CHECK: spv.GLSL.FindUMsb {{%.*}} : i32
+ %2 = spv.GLSL.FindUMsb %arg0 : i32
+ spv.Return
+ }
}