From d2a559ffc0dc61b9d7426064bd5076b66d2f96d6 Mon Sep 17 00:00:00 2001
From: Md Abdullah Shahneous Bari
Date: Fri, 31 Mar 2023 14:02:25 -0700
Subject: [PATCH] [mlir][spirv] Add OpExtension "SPV_INTEL_bfloat16_conversion"

Add the Intel-specific "SPV_INTEL_bfloat16_conversion" extension and
capability (Bfloat16ConversionINTEL), and the two ops introduced by this
extension (OpConvertFToBF16INTEL, OpConvertBF16ToFINTEL). These ops convert
between bfloat16 and 32-bit floating-point values.

Reference Specification:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D147087
---
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td    |  19 +++-
 .../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td      | 126 +++++++++++++++++++++
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td     |   1 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp             |  40 +++++++
 mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir      |  71 ++++++++++++
 mlir/test/Target/SPIRV/intel-ext-ops.mlir          |  31 +++++
 6 files changed, 285 insertions(+), 3 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
 create mode 100644 mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
 create mode 100644 mlir/test/Target/SPIRV/intel-ext-ops.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 7ca32d9..43c6c3e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -399,6 +399,7 @@ def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp
 def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
 def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
 def SPV_INTEL_joint_matrix : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>;
+def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
 
 def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
 def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -457,7 +458,7 @@ def SPIRV_ExtensionAttr :
       SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
       SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
       SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_joint_matrix,
-      SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
+      SPV_INTEL_bfloat16_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
       SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
       SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
       SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
@@ -1413,6 +1414,12 @@ def SPIRV_C_JointMatrixINTEL : I32EnumAttrCase<"JointMat
   ];
 }
 
+def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
+  list<Availability> availability = [
+    Extension<[SPV_INTEL_bfloat16_conversion]>
+  ];
+}
+
 def SPIRV_CapabilityAttr :
     SPIRV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
       SPIRV_C_Matrix, SPIRV_C_Addresses, SPIRV_C_Linkage, SPIRV_C_Kernel, SPIRV_C_Float16,
@@ -1504,7 +1511,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_UniformTexelBufferArrayNonUniformIndexing,
      SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
       SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
-      SPIRV_C_ShaderStereoViewNV, SPIRV_C_JointMatrixINTEL
+      SPIRV_C_ShaderStereoViewNV, SPIRV_C_JointMatrixINTEL, SPIRV_C_Bfloat16ConversionINTEL
     ]>;
 
 def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -4079,6 +4086,7 @@ def SPIRV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
 def SPIRV_Void : TypeAlias<NoneType, "void">;
 def SPIRV_Bool : TypeAlias<I1, "bool">;
 def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_Int16 : TypeAlias<I16, "Int16">;
 def SPIRV_Int32 : TypeAlias<I32, "Int32">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
@@ -4407,6 +4415,9 @@ def SPIRV_OC_OpJointMatrixStoreINTEL : I32EnumAttrCase<"OpJointMatrixStoreI
 def SPIRV_OC_OpJointMatrixMadINTEL : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>;
 def SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>;
 
+def SPIRV_OC_OpConvertFToBF16INTEL : I32EnumAttrCase<"OpConvertFToBF16INTEL", 6116>;
+def SPIRV_OC_OpConvertBF16ToFINTEL : I32EnumAttrCase<"OpConvertBF16ToFINTEL", 6117>;
+
 def SPIRV_OpcodeAttr :
     SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
       SPIRV_OC_OpNop, SPIRV_OC_OpUndef, SPIRV_OC_OpSourceContinued,
@@ -4492,7 +4503,9 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpTypeJointMatrixINTEL, SPIRV_OC_OpJointMatrixLoadINTEL,
       SPIRV_OC_OpJointMatrixStoreINTEL, SPIRV_OC_OpJointMatrixMadINTEL,
-      SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL
+      SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL,
+
+      SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
new file mode 100644
index 0000000..55753b3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -0,0 +1,126 @@
+//===- SPIRVIntelExtOps.td - Intel SPIR-V extensions ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of Intel-specific SPIR-V extensions. These
+// extensions are not part of the Khronos specification but are publicly
+// available at https://github.com/intel/llvm.
+//
+// Supported extensions:
+// * SPV_INTEL_bfloat16_conversion
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
+#define MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
+
+// -----
+
+def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
+  let summary = "See extension SPV_INTEL_bfloat16_conversion";
+
+  let description = [{
+    Convert value numerically from 32-bit floating point to bfloat16, which is
+    represented as a 16-bit unsigned integer.
+
+    Result Type must be a scalar or vector of integer type. The component
+    width must be 16 bits. The bit pattern in the Result represents a bfloat16
+    value.
+
+    Float Value must be a scalar or vector of floating-point type. It must
+    have the same number of components as Result Type. The component width
+    must be 32 bits.
+
+    Results are computed per component.
+
+    ```
+    convert-f-to-bf16-op ::= ssa-id `=` `spirv.INTEL.ConvertFToBF16` ssa-use
+                             `:` operand-type `to` result-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %1 = spirv.INTEL.ConvertFToBF16 %0 : f32 to i16
+    %3 = spirv.INTEL.ConvertFToBF16 %2 : vector<3xf32> to vector<3xi16>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_INTEL_bfloat16_conversion]>,
+    Capability<[SPIRV_C_Bfloat16ConversionINTEL]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$result
+  );
+
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
+
+// -----
+
+def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
+  let summary = "See extension SPV_INTEL_bfloat16_conversion";
+
+  let description = [{
+    Interpret a 16-bit integer as bfloat16 and convert the value numerically
+    to 32-bit floating point.
+
+    Result Type must be a scalar or vector of floating-point type. The
+    component width must be 32 bits.
+
+    Bfloat16 Value must be a scalar or vector of integer type, which is
+    interpreted as a bfloat16 type. The type must have the same number of
+    components as Result Type. The component width must be 16 bits.
+
+    Results are computed per component.
+
+    ```
+    convert-bf16-to-f-op ::= ssa-id `=` `spirv.INTEL.ConvertBF16ToF` ssa-use
+                             `:` operand-type `to` result-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %1 = spirv.INTEL.ConvertBF16ToF %0 : i16 to f32
+    %3 = spirv.INTEL.ConvertBF16ToF %2 : vector<3xi16> to vector<3xf32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_INTEL_bfloat16_conversion]>,
+    Capability<[SPIRV_C_Bfloat16ConversionINTEL]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$operand
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
+  );
+
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
+
+// -----
+
+#endif // MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 767e939..13533d1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -31,6 +31,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index bb3ad91..181c9e0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2202,6 +2202,46 @@ LogicalResult spirv::ConvertUToFOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
+// spirv.INTELConvertBF16ToFOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::INTELConvertBF16ToFOp::verify() {
+  auto operandType = getOperand().getType();
+  auto resultType = getResult().getType();
+  // ODS checks that vector result type and vector operand type have the same
+  // shape.
+  if (auto vectorType = operandType.dyn_cast<VectorType>()) {
+    unsigned operandNumElements = vectorType.getNumElements();
+    unsigned resultNumElements = resultType.cast<VectorType>().getNumElements();
+    if (operandNumElements != resultNumElements) {
+      return emitOpError(
+          "operand and result must have same number of elements");
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INTELConvertFToBF16Op
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::INTELConvertFToBF16Op::verify() {
+  auto operandType = getOperand().getType();
+  auto resultType = getResult().getType();
+  // ODS checks that vector result type and vector operand type have the same
+  // shape.
+  if (auto vectorType = operandType.dyn_cast<VectorType>()) {
+    unsigned operandNumElements = vectorType.getNumElements();
+    unsigned resultNumElements = resultType.cast<VectorType>().getNumElements();
+    if (operandNumElements != resultNumElements) {
+      return emitOpError(
+          "operand and result must have same number of elements");
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // spirv.EntryPoint
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
new file mode 100644
index 0000000..53a1015
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.ConvertFToBF16
+//===----------------------------------------------------------------------===//
+
+spirv.func @f32_to_bf16(%arg0 : f32) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : f32 to i16
+  %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f32 to i16
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : vector<2xf32> to vector<2xi16>
+  %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<2xi16>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
+  // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+  %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" {
+  // expected-error @+1 {{operand and result must have same number of elements}}
+  %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16>
+  spirv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.ConvertBF16ToF
+//===----------------------------------------------------------------------===//
+
+spirv.func @bf16_to_f32(%arg0 : i16) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : i16 to f32
+  %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f32
+  spirv.Return
+}
+
+// -----
+
+spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : vector<2xi16> to vector<2xf32>
+  %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<2xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
+  // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+  %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
+  spirv.Return
+}
+
+// -----
+
+spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
+  // expected-error @+1 {{operand and result must have same number of elements}}
+  %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
+  spirv.Return
+}
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
new file mode 100644
index 0000000..fe86fd2
--- /dev/null
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL], [SPV_INTEL_bfloat16_conversion]> {
+  // CHECK-LABEL: @f32_to_bf16
+  spirv.func @f32_to_bf16(%arg0 : f32) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : f32 to i16
+    %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f32 to i16
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @f32_to_bf16_vec
+  spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : vector<2xf32> to vector<2xi16>
+    %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<2xi16>
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @bf16_to_f32
+  spirv.func @bf16_to_f32(%arg0 : i16) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : i16 to f32
+    %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f32
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @bf16_to_f32_vec
+  spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : vector<2xi16> to vector<2xf32>
+    %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<2xf32>
+    spirv.Return
+  }
+}
-- 
2.7.4
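
Editor's note (not part of the patch): the two ops only specify a numeric conversion between a 32-bit float and a bfloat16 value carried as the bit pattern of an i16. The standalone C++ sketch below illustrates one plausible reference semantics; the round-to-nearest-even rounding in the f32 to bf16 direction and the helper names are assumptions of this illustration, since the extension text only requires the conversion to be numeric.

```cpp
// Illustrative sketch only. Shows one possible numeric semantics for the two
// ops; the round-to-nearest-even choice below is an assumption.
#include <cstdint>
#include <cstdio>
#include <cstring>

// What OpConvertFToBF16INTEL describes: f32 -> bfloat16 bit pattern in an i16.
// Round-to-nearest-even on the 16 discarded mantissa bits (NaNs not handled).
static uint16_t convertFToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t roundingBias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + roundingBias) >> 16);
}

// What OpConvertBF16ToFINTEL describes: bfloat16 bit pattern in an i16 -> f32.
// This direction is exact: shift the 16 bits into the high half of an f32.
static float convertBF16ToF(uint16_t bf16) {
  uint32_t bits = static_cast<uint32_t>(bf16) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  float x = 3.1415926f;
  uint16_t b = convertFToBF16(x); // 0x4049 for pi
  std::printf("bf16 bits: 0x%04X, back to f32: %f\n", static_cast<unsigned>(b),
              convertBF16ToF(b));
  return 0;
}
```

The bf16 to f32 direction is exact, since every bfloat16 value is exactly representable as an f32; only the f32 to bf16 direction involves rounding.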