From 03e6bf5f564c440ffbbac3a7a30015b6ca779afe Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Tue, 6 Dec 2022 20:17:40 -0500 Subject: [PATCH] [mlir][spirv] Define `spirv.*Dot` integer dot product ops This covers `SDot`, `SUDot`, and `UDot`. The `*AccSat` version will be added in a follow-up revision. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D139242 --- .../mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 142 +++++++-------- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td | 35 +++- .../Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td | 190 +++++++++++++++++++++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td | 1 + mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 131 +++++++++++++- mlir/test/Dialect/SPIRV/IR/availability.mlir | 94 ++++++++++ .../Dialect/SPIRV/IR/integer-dot-product-ops.mlir | 144 ++++++++++++++++ mlir/test/Dialect/SPIRV/IR/target-env.mlir | 115 +++++++++++++ mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp | 22 ++- mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 3 +- mlir/utils/spirv/define_enum.sh | 2 +- 11 files changed, 799 insertions(+), 80 deletions(-) create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td create mode 100644 mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td index 1d6c98d..f18796b 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -76,7 +76,7 @@ class SPIRV_ArithmeticExtendedBinaryOp ]; - // These op require a custom verifier. + // These ops require a custom verifier. let hasVerifier = 1; } @@ -423,75 +423,6 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul", // ----- -def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended", - [Pure, Commutative]> { - let summary = [{ - Result is the full value of the signed integer multiplication of Operand - 1 and Operand 2. - }]; - - let description = [{ - Result Type must be from OpTypeStruct. The struct must have two - members, and the two members must be the same type. The member type - must be a scalar or vector of integer type. - - Operand 1 and Operand 2 must have the same type as the members of Result - Type. These are consumed as signed integers. - - Results are computed per component. - - Member 0 of the result gets the low-order bits of the multiplication. - - Member 1 of the result gets the high-order bits of the multiplication. - - - - #### Example: - - ```mlir - %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(i32, i32)> - %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)> - ``` - }]; -} - -// ----- - -def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended", - [Pure, Commutative]> { - let summary = [{ - Result is the full value of the unsigned integer multiplication of - Operand 1 and Operand 2. - }]; - - let description = [{ - Result Type must be from OpTypeStruct. The struct must have two - members, and the two members must be the same type. The member type - must be a scalar or vector of integer type, whose Signedness operand is - 0. - - Operand 1 and Operand 2 must have the same type as the members of Result - Type. These are consumed as unsigned integers. - - Results are computed per component. - - Member 0 of the result gets the low-order bits of the multiplication. - - Member 1 of the result gets the high-order bits of the multiplication. - - - - #### Example: - - ```mlir - %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(i32, i32)> - %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)> - ``` - }]; -} - -// ----- - def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub", SPIRV_Integer, [UsableInSpecConstantOp]> { @@ -646,6 +577,40 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod", // ----- +def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended", + [Pure, Commutative]> { + let summary = [{ + Result is the full value of the signed integer multiplication of Operand + 1 and Operand 2. + }]; + + let description = [{ + Result Type must be from OpTypeStruct. The struct must have two + members, and the two members must be the same type. The member type + must be a scalar or vector of integer type. + + Operand 1 and Operand 2 must have the same type as the members of Result + Type. These are consumed as signed integers. + + Results are computed per component. + + Member 0 of the result gets the low-order bits of the multiplication. + + Member 1 of the result gets the high-order bits of the multiplication. + + + + #### Example: + + ```mlir + %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(i32, i32)> + %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)> + ``` + }]; +} + +// ----- + def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate", SPIRV_Integer, [UsableInSpecConstantOp]> { @@ -654,7 +619,7 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate", let description = [{ Result Type must be a scalar or vector of integer type. - Operand’s type must be a scalar or vector of integer type. It must + Operand's type must be a scalar or vector of integer type. It must have the same number of components as Result Type. The component width must equal the component width in Result Type. @@ -746,6 +711,41 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv", // ----- +def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended", + [Pure, Commutative]> { + let summary = [{ + Result is the full value of the unsigned integer multiplication of + Operand 1 and Operand 2. + }]; + + let description = [{ + Result Type must be from OpTypeStruct. The struct must have two + members, and the two members must be the same type. The member type + must be a scalar or vector of integer type, whose Signedness operand is + 0. + + Operand 1 and Operand 2 must have the same type as the members of Result + Type. These are consumed as unsigned integers. + + Results are computed per component. + + Member 0 of the result gets the low-order bits of the multiplication. + + Member 1 of the result gets the high-order bits of the multiplication. + + + + #### Example: + + ```mlir + %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(i32, i32)> + %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)> + ``` + }]; +} + +// ----- + def SPIRV_VectorTimesScalarOp : SPIRV_Op<"VectorTimesScalar", [Pure]> { let summary = "Scale a floating-point vector."; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 3c993cb..4be10e6 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3969,6 +3969,18 @@ def SPIRV_StorageClassAttr : SPIRV_SC_CodeSectionINTEL, SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL ]>; +def SPIRV_PVF_PackedVectorFormat4x8Bit : I32EnumAttrCase<"PackedVectorFormat4x8Bit", 0> { + list availability = [ + MinVersion, + Extension<[SPV_KHR_integer_dot_product]> + ]; +} + +def SPIRV_PackedVectorFormatAttr : + SPIRV_I32EnumAttr<"PackedVectorFormat", "valid SPIR-V PackedVectorFormat", "packed_vector_format", [ + SPIRV_PVF_PackedVectorFormat4x8Bit + ]>; + // End enum section. Generated from SPIR-V spec; DO NOT MODIFY! // Enums added manually that are not part of SPIR-V spec @@ -4365,6 +4377,12 @@ def SPIRV_OC_OpGroupNonUniformSMax : I32EnumAttrCase<"OpGroupNonUniformSM def SPIRV_OC_OpGroupNonUniformUMax : I32EnumAttrCase<"OpGroupNonUniformUMax", 357>; def SPIRV_OC_OpGroupNonUniformFMax : I32EnumAttrCase<"OpGroupNonUniformFMax", 358>; def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; +def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>; +def SPIRV_OC_OpUDot : I32EnumAttrCase<"OpUDot", 4451>; +def SPIRV_OC_OpSUDot : I32EnumAttrCase<"OpSUDot", 4452>; +def SPIRV_OC_OpSDotAccSat : I32EnumAttrCase<"OpSDotAccSat", 4453>; +def SPIRV_OC_OpUDotAccSat : I32EnumAttrCase<"OpUDotAccSat", 4454>; +def SPIRV_OC_OpSUDotAccSat : I32EnumAttrCase<"OpSUDotAccSat", 4455>; def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>; def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>; def SPIRV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>; @@ -4457,7 +4475,9 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpGroupNonUniformSMin, SPIRV_OC_OpGroupNonUniformUMin, SPIRV_OC_OpGroupNonUniformFMin, SPIRV_OC_OpGroupNonUniformSMax, SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax, - SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpTypeCooperativeMatrixNV, + SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, + SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat, + SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV, SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL, @@ -4494,6 +4514,19 @@ class SPIRV_Op traits = []> : Capability<[]> ]; + // Controls whether to auto-generate this op's availability specification. + // If set, generates the following methods: + // + // ```c++ + // SmallVector, 1> OpTy::getCapabilities(); + // SmallVector, 1> OpTy::getExtensions(); + // Optional OpTy::getMinVersion(); + // Optional OpTy::getMaxVersion(); + // ``` + // + // When not set, manual implementation of these methods is required. + bit autogenAvailability = 1; + // For each SPIR-V op, the following static functions need to be defined // in SPIRVOps.cpp: // diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td new file mode 100644 index 0000000..451aeb2 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td @@ -0,0 +1,190 @@ +//===-- SPIRVIntegerDotProductOps.td - MLIR SPIR-V IDP Ops -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains arithmetic ops for the SPIR-V dialect. It corresponds +// to instructions defined by the "SPV_KHR_integer_dot_product" SPIR-V +// extension. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_IR_INTEGER_DOT_PRODUCT_OPS +#define MLIR_DIALECT_SPIRV_IR_INTEGER_DOT_PRODUCT_OPS + +include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +class SPIRV_IntegerDotProductOp traits = []> : + SPIRV_Op { + let results = (outs + SPIRV_Integer:$result + ); + + let assemblyFormat = [{ + operands attr-dict `:` `(` type(operands) `)` `->` type($result) + }]; + + // These ops require dynamic availability specification based on operand and + // result types. + bit autogenAvailability = 0; + + // These ops require a custom verifier. + let hasVerifier = 1; +} + +class SPIRV_IntegerDotProductBinaryOp traits = []> : + SPIRV_IntegerDotProductOp { + let arguments = (ins + SPIRV_ScalarOrVectorOf:$vector1, + SPIRV_ScalarOrVectorOf:$vector2, + OptionalAttr:$format + ); +} + +class SPIRV_IntegerDotProductTernaryOp traits = []> : + SPIRV_IntegerDotProductOp { + let arguments = (ins + SPIRV_ScalarOrVectorOf:$vector1, + SPIRV_ScalarOrVectorOf:$vector2, + SPIRV_Integer:$accumulator, + OptionalAttr:$format + ); +} + +// ----- + +def SPIRV_SDotOp : SPIRV_IntegerDotProductBinaryOp<"SDot", + [SignedOp, Commutative]> { + let summary = "Signed integer dot product of Vector 1 and Vector 2."; + + let description = [{ + Result Type must be an integer type whose Width must be greater than or + equal to that of the components of Vector 1 and Vector 2. + + Vector 1 and Vector 2 must have the same type. + + Vector 1 and Vector 2 must be either 32-bit integers (enabled by the + DotProductInput4x8BitPacked capability) or vectors of integer type + (enabled by the DotProductInput4x8Bit or DotProductInputAll capability). + + When Vector 1 and Vector 2 are scalar integer types, Packed Vector + Format must be specified to select how the integers are to be + interpreted as vectors. + + All components of the input vectors are sign-extended to the bit width + of the result's type. The sign-extended input vectors are then + multiplied component-wise and all components of the vector resulting + from the component-wise multiplication are added together. The resulting + value will equal the low-order N bits of the correct result R, where N + is the result width and R is computed with enough precision to avoid + overflow and underflow. + + + + #### Example: + + ```mlir + %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i32 + %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i64 + %r = spirv.SDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32 + ``` + }]; +} + +// ----- + +def SPIRV_SUDotOp : SPIRV_IntegerDotProductBinaryOp<"SUDot", + [SignedOp, UnsignedOp]> { + let summary = [{ + Mixed-signedness integer dot product of Vector 1 and Vector 2. + Components of Vector 1 are treated as signed, components of Vector 2 are + treated as unsigned. + }]; + + let description = [{ + Result Type must be an integer type whose Width must be greater than or + equal to that of the components of Vector 1 and Vector 2. + + Vector 1 and Vector 2 must be either 32-bit integers (enabled by the + DotProductInput4x8BitPacked capability) or vectors of integer type with + the same number of components and same component Width (enabled by the + DotProductInput4x8Bit or DotProductInputAll capability). When Vector 1 + and Vector 2 are vectors, the components of Vector 2 must have a + Signedness of 0. + + When Vector 1 and Vector 2 are scalar integer types, Packed Vector + Format must be specified to select how the integers are to be + interpreted as vectors. + + All components of Vector 1 are sign-extended to the bit width of the + result's type. All components of Vector 2 are zero-extended to the bit + width of the result's type. The sign- or zero-extended input vectors are + then multiplied component-wise and all components of the vector + resulting from the component-wise multiplication are added together. The + resulting value will equal the low-order N bits of the correct result R, + where N is the result width and R is computed with enough precision to + avoid overflow and underflow. + + + + #### Example: + + ```mlir + %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i32 + %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i64 + %r = spirv.SUDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32 + ``` + }]; +} + +// ----- + +def SPIRV_UDotOp : SPIRV_IntegerDotProductBinaryOp<"UDot", + [UnsignedOp, Commutative]> { + let summary = "Unsigned integer dot product of Vector 1 and Vector 2."; + + let description = [{ + Result Type must be an integer type with Signedness of 0 whose Width + must be greater than or equal to that of the components of Vector 1 and + Vector 2. + + Vector 1 and Vector 2 must have the same type. + + Vector 1 and Vector 2 must be either 32-bit integers (enabled by the + DotProductInput4x8BitPacked capability) or vectors of integer type with + Signedness of 0 (enabled by the DotProductInput4x8Bit or + DotProductInputAll capability). + + When Vector 1 and Vector 2 are scalar integer types, Packed Vector + Format must be specified to select how the integers are to be + interpreted as vectors. + + All components of the input vectors are zero-extended to the bit width + of the result's type. The zero-extended input vectors are then + multiplied component-wise and all components of the vector resulting + from the component-wise multiplication are added together. The resulting + value will equal the low-order N bits of the correct result R, where N + is the result width and R is computed with enough precision to avoid + overflow and underflow. + + + + #### Example: + + ```mlir + %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i32 + %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i64 + %r = spirv.UDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32 + ``` + }]; +} + +#endif // MLIR_DIALECT_SPIRV_IR_INTEGER_DOT_PRODUCT_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td index 5e8e5e4..767e939 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td @@ -34,6 +34,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td" +include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td" diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 1a93882..888a756 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -29,33 +29,35 @@ #include "mlir/Interfaces/CallInterfaces.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/bit.h" +#include "llvm/Support/FormatVariadic.h" #include #include using namespace mlir; // TODO: generate these strings using ODS. -constexpr char kMemoryAccessAttrName[] = "memory_access"; -constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access"; constexpr char kAlignmentAttrName[] = "alignment"; -constexpr char kSourceAlignmentAttrName[] = "source_alignment"; constexpr char kBranchWeightAttrName[] = "branch_weights"; constexpr char kCallee[] = "callee"; constexpr char kClusterSize[] = "cluster_size"; constexpr char kControl[] = "control"; constexpr char kDefaultValueAttrName[] = "default_value"; -constexpr char kExecutionScopeAttrName[] = "execution_scope"; constexpr char kEqualSemanticsAttrName[] = "equal_semantics"; +constexpr char kExecutionScopeAttrName[] = "execution_scope"; constexpr char kFnNameAttrName[] = "fn"; constexpr char kGroupOperationAttrName[] = "group_operation"; constexpr char kIndicesAttrName[] = "indices"; constexpr char kInitializerAttrName[] = "initializer"; constexpr char kInterfaceAttrName[] = "interface"; +constexpr char kMemoryAccessAttrName[] = "memory_access"; constexpr char kMemoryScopeAttrName[] = "memory_scope"; +constexpr char kPackedVectorFormatAttrName[] = "format"; constexpr char kSemanticsAttrName[] = "semantics"; +constexpr char kSourceAlignmentAttrName[] = "source_alignment"; +constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access"; constexpr char kSpecIdAttrName[] = "spec_id"; constexpr char kTypeAttrName[] = "type"; constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics"; @@ -4791,6 +4793,125 @@ LogicalResult spirv::GroupIMulKHROp::verify() { return verifyGroupOp(*this); } LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); } +//===----------------------------------------------------------------------===// +// Integer Dot Product ops +//===----------------------------------------------------------------------===// + +static LogicalResult verifyIntegerDotProduct(Operation *op) { + assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) && + "Not an integer dot product op?"); + assert(op->getNumResults() == 1 && "Expected a single result"); + + Type factorTy = op->getOperand(0).getType(); + if (op->getOperand(1).getType() != factorTy) + return op->emitOpError("requires the same type for both vector operands"); + + if (auto intTy = factorTy.dyn_cast()) { + auto packedVectorFormat = + op->getAttr(kPackedVectorFormatAttrName) + .dyn_cast_or_null(); + if (!packedVectorFormat) + return op->emitOpError("requires Packed Vector Format attribute for " + "integer vector operands"); + + assert(packedVectorFormat.getValue() == + spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && + "unknown Packed Vector format"); + if (intTy.getWidth() != 32) + return op->emitOpError( + llvm::formatv("with specified Packed Vector Format ({0}) requires " + "integer vector operands to be 32-bits wide", + packedVectorFormat.getValue())); + } + + if (op->getAttrs().size() > 1) + return op->emitError( + "op only supports the 'format' #spirv.packed_vector_format attribute"); + + Type resultTy = op->getResultTypes().front(); + unsigned factorBitWidth = getBitWidth(factorTy); + unsigned resultBitWidth = getBitWidth(resultTy); + if (factorBitWidth > resultBitWidth) + return op->emitOpError( + llvm::formatv("result type has insufficient bit-width ({0} bits) " + "for the specified vector operand type ({1} bits)", + resultBitWidth, factorBitWidth)); + + return success(); +} + +static Optional getIntegerDotProductMinVersion() { + return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0. +} + +static Optional getIntegerDotProductMaxVersion() { + return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6. +} + +static SmallVector, 1> +getIntegerDotProductExtensions() { + // Requires the SPV_KHR_integer_dot_product extension, specified either + // explicitly or implied by target env's SPIR-V version >= 1.6. + static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product; + return {extension}; +} + +static SmallVector, 1> +getIntegerDotProductCapabilities(Operation *op) { + // Requires the the DotProduct capability and capabilities that depend on + // exact op types. + static const auto dotProductCap = spirv::Capability::DotProduct; + static const auto dotProductInput4x8BitPackedCap = + spirv::Capability::DotProductInput4x8BitPacked; + static const auto dotProductInput4x8BitCap = + spirv::Capability::DotProductInput4x8Bit; + static const auto dotProductInputAllCap = + spirv::Capability::DotProductInputAll; + + SmallVector, 1> capabilities = {dotProductCap}; + + Type factorTy = op->getOperand(0).getType(); + if (auto intTy = factorTy.dyn_cast()) { + auto formatAttr = op->getAttr(kPackedVectorFormatAttrName) + .cast(); + if (formatAttr.getValue() == + spirv::PackedVectorFormat::PackedVectorFormat4x8Bit) + capabilities.push_back(dotProductInput4x8BitPackedCap); + + return capabilities; + } + + auto vecTy = factorTy.cast(); + if (vecTy.getElementTypeBitWidth() == 8) { + capabilities.push_back(dotProductInput4x8BitCap); + return capabilities; + } + + capabilities.push_back(dotProductInputAllCap); + return capabilities; +} + +#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \ + LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \ + SmallVector, 1> OpName::getExtensions() { \ + return getIntegerDotProductExtensions(); \ + } \ + SmallVector, 1> OpName::getCapabilities() { \ + return getIntegerDotProductCapabilities(*this); \ + } \ + Optional OpName::getMinVersion() { \ + return getIntegerDotProductMinVersion(); \ + } \ + Optional OpName::getMaxVersion() { \ + return getIntegerDotProductMaxVersion(); \ + } + +SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp) +SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp) +SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp) + +#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP + // TableGen'erated operation interfaces for querying versions, extensions, and // capabilities. #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc" diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir index 290e07d..5cd7253 100644 --- a/mlir/test/Dialect/SPIRV/IR/availability.mlir +++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir @@ -49,3 +49,97 @@ func.func @module_physical_storage_buffer64_vulkan() { spirv.module PhysicalStorageBuffer64 Vulkan { } return } + +//===----------------------------------------------------------------------===// +// Integer Dot Product ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: sdot_scalar_i32_i32 +func.func @sdot_scalar_i32_i32(%a: i32) -> i32 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] + %r = spirv.SDot %a, %a {format = #spirv.packed_vector_format}: (i32, i32) -> i32 + return %r: i32 +} + +// CHECK-LABEL: sdot_vector_4xi8_i64 +func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] + %r = spirv.SDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64 + return %r: i64 +} + +// CHECK-LABEL: sdot_vector_4xi16_i64 +func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] + %r = spirv.SDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64 + return %r: i64 +} + +// CHECK-LABEL: sudot_scalar_i32_i32 +func.func @sudot_scalar_i32_i32(%a: i32) -> i32 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] + %r = spirv.SUDot %a, %a {format = #spirv.packed_vector_format}: (i32, i32) -> i32 + return %r: i32 +} + +// CHECK-LABEL: sudot_vector_4xi8_i64 +func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] + %r = spirv.SUDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64 + return %r: i64 +} + +// CHECK-LABEL: sudot_vector_4xi16_i64 +func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] + %r = spirv.SUDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64 + return %r: i64 +} + +// CHECK-LABEL: udot_scalar_i32_i32 +func.func @udot_scalar_i32_i32(%a: i32) -> i32 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ] + %r = spirv.UDot %a, %a {format = #spirv.packed_vector_format}: (i32, i32) -> i32 + return %r: i32 +} + +// CHECK-LABEL: udot_vector_4xi8_i64 +func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ] + %r = spirv.UDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64 + return %r: i64 +} + +// CHECK-LABEL: udot_vector_4xi16_i64 +func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ] + // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ] + %r = spirv.UDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64 + return %r: i64 +} diff --git a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir new file mode 100644 index 0000000..c0c5cf3 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir @@ -0,0 +1,144 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +// This test covers the Integer Dot Product ops defined in the +// SPV_KHR_integer_dot_product extension. + +//===----------------------------------------------------------------------===// +// spirv.SDot +//===----------------------------------------------------------------------===// + +// CHECK: @sdot_scalar_i32 +func.func @sdot_scalar_i32(%a: i32, %b: i32) -> i32 { + // CHECK-NEXT: spirv.SDot + %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i32 + return %r : i32 +} + +// CHECK: @sdot_scalar_i64 +func.func @sdot_scalar_i64(%a: i32, %b: i32) -> i64 { + // CHECK-NEXT: spirv.SDot + %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i64 + return %r : i64 +} + +// CHECK: @sdot_vector_4xi8 +func.func @sdot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 { + // CHECK-NEXT: spirv.SDot + %r = spirv.SDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32 + return %r : i32 +} + +// CHECK: @sdot_vector_4xi16 +func.func @sdot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 { + // CHECK-NEXT: spirv.SDot + %r = spirv.SDot %a, %b : (vector<4xi16>, vector<4xi16>) -> i64 + return %r : i64 +} + +// CHECK: @sdot_vector_8xi8 +func.func @sdot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 { + // CHECK-NEXT: spirv.SDot + %r = spirv.SDot %a, %b : (vector<8xi8>, vector<8xi8>) -> i64 + return %r : i64 +} + +// ----- + +func.func @sdot_scalar_bad_types(%a: i32, %b: i64) -> i32 { + // expected-error @+1 {{op requires the same type for both vector operands}} + %r = spirv.SDot %a, %b : (i32, i64) -> i32 + return %r : i32 +} + +// ----- + +func.func @sdot_scalar_i32_bad_attr(%a: i32, %b: i32) -> i32 { + // expected-error @+1 {{op only supports the 'format' #spirv.packed_vector_format attribute}} + %r = spirv.SDot %a, %b {volatile = #spirv.decoration, + format = #spirv.packed_vector_format}: (i32, i32) -> i32 + return %r : i32 +} + +// ----- + +func.func @sdot_scalar_bad_types(%a: i32, %b: i32) -> i16 { + // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}} + %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i16 + return %r : i16 +} + +// ----- + +func.func @sdot_scalar_bad_types(%a: i64, %b: i64) -> i64 { + // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}} + %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format}: (i64, i64) -> i64 + return %r : i64 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.SUDot +//===----------------------------------------------------------------------===// + +// CHECK: @sudot_scalar_i32 +func.func @sudot_scalar_i32(%a: i32, %b: i32) -> i32 { + // CHECK-NEXT: spirv.SUDot + %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i32 + return %r : i32 +} + +// CHECK: @sudot_scalar_i64 +func.func @sudot_scalar_i64(%a: i32, %b: i32) -> i64 { + // CHECK-NEXT: spirv.SUDot + %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i64 + return %r : i64 +} + +// CHECK: @sudot_vector_4xi8 +func.func @sudot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 { + // CHECK-NEXT: spirv.SUDot + %r = spirv.SUDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32 + return %r : i32 +} + +// CHECK: @sudot_vector_4xi16 +func.func @sudot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 { + // CHECK-NEXT: spirv.SUDot + %r = spirv.SUDot %a, %b : (vector<4xi16>, vector<4xi16>) -> i64 + return %r : i64 +} + +// CHECK: @sudot_vector_8xi8 +func.func @sudot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 { + // CHECK-NEXT: spirv.SUDot + %r = spirv.SUDot %a, %b : (vector<8xi8>, vector<8xi8>) -> i64 + return %r : i64 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.UDot +//===----------------------------------------------------------------------===// + +// CHECK: @udot_scalar_i32 +func.func @udot_scalar_i32(%a: i32, %b: i32) -> i32 { + // CHECK-NEXT: spirv.UDot + %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i32 + return %r : i32 +} + +// CHECK: @udot_scalar_i64 +func.func @udot_scalar_i64(%a: i32, %b: i32) -> i64 { + // CHECK-NEXT: spirv.UDot + %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format}: (i32, i32) -> i64 + return %r : i64 +} + +// CHECK: @udot_vector_4xi8 +func.func @udot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 { + // CHECK-NEXT: spirv.UDot + %r = spirv.UDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32 + return %r : i32 +} diff --git a/mlir/test/Dialect/SPIRV/IR/target-env.mlir b/mlir/test/Dialect/SPIRV/IR/target-env.mlir index ecf8767..91ffdf2 100644 --- a/mlir/test/Dialect/SPIRV/IR/target-env.mlir +++ b/mlir/test/Dialect/SPIRV/IR/target-env.mlir @@ -19,6 +19,9 @@ // spirv.KHR.SubgroupBallot is available under in all SPIR-V versions under // SubgroupBallotKHR capability and SPV_KHR_shader_ballot extension. +// Integer Dot Product ops (spirv.*Dot*) require the +// SPV_KHR_integer_dot_product extension and a number of related capabilities. + // The GeometryPointSize capability implies the Geometry capability, which // implies the Shader capability. @@ -122,6 +125,96 @@ func.func @bit_reverse_recursively_implied_capability(%operand: i32) -> i32 attr return %0: i32 } +// CHECK-LABEL: @sdot_scalar_i32_i32_capabilities +func.func @sdot_scalar_i32_i32_capabilities(%operand: i32) -> i32 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: spirv.SDot + %0 = "test.convert_to_sdot_op"(%operand, %operand) {format = #spirv.packed_vector_format}: (i32, i32) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @sdot_scalar_i32_i32_missing_capability1 +func.func @sdot_scalar_i32_i32_missing_capability1(%operand: i32) -> i32 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: test.convert_to_sdot_op + %0 = "test.convert_to_sdot_op"(%operand, %operand) {format = #spirv.packed_vector_format}: (i32, i32) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @sdot_scalar_i32_i32_missing_capability2 +func.func @sdot_scalar_i32_i32_missing_capability2(%operand: i32) -> i32 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: test.convert_to_sdot_op + %0 = "test.convert_to_sdot_op"(%operand, %operand) {format = #spirv.packed_vector_format}: (i32, i32) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @sudot_vector_4xi8_i32_capabilities +func.func @sudot_vector_4xi8_i32_capabilities(%operand: vector<4xi8>) -> i32 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: spirv.SUDot + %0 = "test.convert_to_sudot_op"(%operand, %operand): (vector<4xi8>, vector<4xi8>) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @sudot_vector_4xi8_i32_missing_capability1 +func.func @sudot_vector_4xi8_i32_missing_capability1(%operand: vector<4xi8>) -> i32 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: test.convert_to_sudot_op + %0 = "test.convert_to_sudot_op"(%operand, %operand): (vector<4xi8>, vector<4xi8>) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @sudot_vector_4xi8_i32_missing_capability2 +func.func @sudot_vector_4xi8_i32_missing_capability2(%operand: vector<4xi8>) -> i32 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: test.convert_to_sudot_op + %0 = "test.convert_to_sudot_op"(%operand, %operand): (vector<4xi8>, vector<4xi8>) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @udot_vector_4xi16_i64_capabilities +func.func @udot_vector_4xi16_i64_capabilities(%operand: vector<4xi16>) -> i64 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: spirv.UDot + %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64) + return %0: i64 +} + +// CHECK-LABEL: @udot_vector_4xi16_i64_missing_capability1 +func.func @udot_vector_4xi16_i64_missing_capability1(%operand: vector<4xi16>) -> i64 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: test.convert_to_udot_op + %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64) + return %0: i64 +} + +// CHECK-LABEL: @udot_vector_4xi16_i64_missing_capability2 +func.func @udot_vector_4xi16_i64_missing_capability2(%operand: vector<4xi16>) -> i64 attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: test.convert_to_udot_op + %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64) + return %0: i64 +} + //===----------------------------------------------------------------------===// // Extension //===----------------------------------------------------------------------===// @@ -189,3 +282,25 @@ func.func @module_implied_extension() attributes { "test.convert_to_module_op"() : () -> () return } + +// CHECK-LABEL: @udot_vector_4xi16_i64_implied_extension +func.func @udot_vector_4xi16_i64_implied_extension(%operand: vector<4xi16>) -> i64 attributes { + // Version 1.6 implies SPV_KHR_integer_to_product. + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: spirv.UDot + %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64) + return %0: i64 +} + +// CHECK-LABEL: @udot_vector_4xi16_i64_missing_extension +func.func @udot_vector_4xi16_i64_missing_extension(%operand: vector<4xi16>) -> i64 attributes { + // Version 1.5 does not imply SPV_KHR_integer_to_product. + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + // CHECK: test.convert_to_udot_op + %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64) + return %0: i64 +} diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp index e29d167..13c35ca 100644 --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -191,6 +191,19 @@ struct ConvertToSubgroupBallot : RewritePattern { return success(); } }; + +template +struct ConvertToIntegerDotProd : RewritePattern { + ConvertToIntegerDotProd(MLIRContext *context) + : RewritePattern(TestOpName, 1, context, {SPIRVOp::getOperationName()}) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + op->getOperands(), op->getAttrs()); + return success(); + } +}; } // namespace void ConvertToTargetEnv::runOnOperation() { @@ -207,10 +220,17 @@ void ConvertToTargetEnv::runOnOperation() { auto target = SPIRVConversionTarget::get(targetEnv); + static constexpr char sDotTestOpName[] = "test.convert_to_sdot_op"; + static constexpr char suDotTestOpName[] = "test.convert_to_sudot_op"; + static constexpr char uDotTestOpName[] = "test.convert_to_udot_op"; + RewritePatternSet patterns(context); patterns.add(context); + ConvertToSubgroupBallot, + ConvertToIntegerDotProd, + ConvertToIntegerDotProd, + ConvertToIntegerDotProd>(context); if (failed(applyPartialConversion(fn, *target, std::move(patterns)))) return signalPassFailure(); diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 014c947..dad9056 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -1395,7 +1395,8 @@ static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper, auto defs = recordKeeper.getAllDerivedDefinitions("SPIRV_Op"); for (const auto *def : defs) { Operator op(def); - emitAvailabilityImpl(op, os); + if (def->getValueAsBit("autogenAvailability")) + emitAvailabilityImpl(op, os); } return false; } diff --git a/mlir/utils/spirv/define_enum.sh b/mlir/utils/spirv/define_enum.sh index 496f90c..ca9d864 100755 --- a/mlir/utils/spirv/define_enum.sh +++ b/mlir/utils/spirv/define_enum.sh @@ -12,7 +12,7 @@ # The 'operand_kinds' dict of spirv.core.grammar.json contains all supported # SPIR-V enum classes. # -# If is missing, this script updates existing ones. +# If is missing, this script updates existing ones. set -e -- 2.7.4