From 96267b6b88405c9222e69aadb669461533bc1352 Mon Sep 17 00:00:00 2001 From: Jake Hall Date: Mon, 13 Feb 2023 14:10:20 +0000 Subject: [PATCH] [mlir] Add Float8E5M2FNUZ and Float8E4M3FNUZ types to MLIR Float8E5M2FNUZ and Float8E4M3FNUZ have been added to APFloat in D141863. This change adds these types as MLIR builtin types alongside Float8E5M2 and Float8E4M3FN (added in D133823 and D138075). Reviewed By: krzysz00 Differential Revision: https://reviews.llvm.org/D143744 --- mlir/include/mlir-c/BuiltinTypes.h | 14 +++++++++ mlir/include/mlir/IR/Builders.h | 2 ++ mlir/include/mlir/IR/BuiltinTypes.h | 15 +++++++-- mlir/include/mlir/IR/BuiltinTypes.td | 44 +++++++++++++++++++++++++++ mlir/include/mlir/IR/OpBase.td | 4 +++ mlir/include/mlir/IR/Types.h | 2 ++ mlir/lib/AsmParser/TokenKinds.def | 2 ++ mlir/lib/AsmParser/TypeParser.cpp | 8 +++++ mlir/lib/Bindings/Python/IRTypes.cpp | 38 +++++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 16 ++++++++++ mlir/lib/IR/AsmPrinter.cpp | 2 ++ mlir/lib/IR/Builders.cpp | 8 +++++ mlir/lib/IR/BuiltinTypes.cpp | 7 ++++- mlir/lib/IR/MLIRContext.cpp | 10 ++++++ mlir/lib/IR/Types.cpp | 2 ++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 16 ++++++++++ mlir/test/IR/attribute.mlir | 8 +++++ mlir/test/python/ir/builtin_types.py | 4 +++ mlir/utils/lldb-scripts/mlirDataFormatters.py | 2 ++ 19 files changed, 201 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 0003975..8b855d8 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -81,6 +81,20 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx); +/// Checks whether the given type is an f8E5M2FNUZ type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type); + +/// Creates an f8E5M2FNUZ type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx); + +/// Checks whether the given type is an f8E4M3FNUZ type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type); + +/// Creates an f8E4M3FNUZ type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index e0d33dd..14df7b0 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -62,6 +62,8 @@ public: // Types. FloatType getFloat8E5M2Type(); FloatType getFloat8E4M3FNType(); + FloatType getFloat8E5M2FNUZType(); + FloatType getFloat8E4M3FNUZType(); FloatType getBF16Type(); FloatType getF16Type(); FloatType getF32Type(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 135fa9b..33995f3 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -47,6 +47,8 @@ public: static FloatType getF128(MLIRContext *ctx); static FloatType getFloat8E5M2(MLIRContext *ctx); static FloatType getFloat8E4M3FN(MLIRContext *ctx); + static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx); + static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx); /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(Type type); @@ -374,8 +376,9 @@ inline bool BaseMemRefType::isValidElementType(Type type) { } inline bool FloatType::classof(Type type) { - return type.isa(); + return type.isa(); } inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) { @@ -386,6 +389,14 @@ inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) { return Float8E4M3FNType::get(ctx); } +inline FloatType FloatType::getFloat8E5M2FNUZ(MLIRContext *ctx) { + return Float8E5M2FNUZType::get(ctx); +} + +inline FloatType FloatType::getFloat8E4M3FNUZ(MLIRContext *ctx) { + return Float8E4M3FNUZType::get(ctx); +} + inline FloatType FloatType::getBF16(MLIRContext *ctx) { return BFloat16Type::get(ctx); } diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 5f9141d..a8d1fae 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -119,6 +119,50 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> { } //===----------------------------------------------------------------------===// +// Float8E5M2FNUZType + +def Builtin_Float8E5M2FNUZ : Builtin_FloatType<"Float8E5M2FNUZ"> { + let summary = "8-bit floating point with 2 bit mantissa"; + let description = [{ + An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits + mantissa. This is not a standard type as defined by IEEE-754, but it follows + similar conventions, with the exception that there are no infinity values, + no negative zero, and only one NaN representation. This type has the + following characteristics: + + * bit encoding: S1E5M2 + * exponent bias: 16 + * infinities: Not supported + * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s + * denormals when exponent is 0 + + Described in: https://arxiv.org/abs/2206.02915 + }]; +} + +//===----------------------------------------------------------------------===// +// Float8E4M3FNUZType + +def Builtin_Float8E4M3FNUZ : Builtin_FloatType<"Float8E4M3FNUZ"> { + let summary = "8-bit floating point with 3 bit mantissa"; + let description = [{ + An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits + mantissa. This is not a standard type as defined by IEEE-754, but it follows + similar conventions, with the exception that there are no infinity values, + no negative zero, and only one NaN representation. This type has the + following characteristics: + + * bit encoding: S1E4M3 + * exponent bias: 8 + * infinities: Not supported + * NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s + * denormals when exponent is 0 + + Described in: https://arxiv.org/abs/2209.05433 + }]; +} + +//===----------------------------------------------------------------------===// // BFloat16Type def Builtin_BFloat16 : Builtin_FloatType<"BFloat16"> { diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index cd888ac..527ccc0 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -488,6 +488,10 @@ def F8E4M3FN : Type, "f8E4M3FN type">, BuildableType<"$_builder.getFloat8E4M3FNType()">; def F8E5M2 : Type, "f8E5M2 type">, BuildableType<"$_builder.getFloat8E5M2Type()">; +def F8E4M3FNUZ : Type, "f8E4M3FNUZ type">, + BuildableType<"$_builder.getFloat8E4M3FNUZType()">; +def F8E5M2FNUZ : Type, "f8E5M2FNUZ type">, + BuildableType<"$_builder.getFloat8E5M2FNUZType()">; def AnyComplex : Type()">, "complex-type", "::mlir::ComplexType">; diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index 2a0586c..9f30ce1 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -122,6 +122,8 @@ public: bool isIndex() const; bool isFloat8E5M2() const; bool isFloat8E4M3FN() const; + bool isFloat8E5M2FNUZ() const; + bool isFloat8E4M3FNUZ() const; bool isBF16() const; bool isF16() const; bool isF32() const; diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def index 9bd7b60..0e666c7 100644 --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -95,6 +95,8 @@ TOK_KEYWORD(f64) TOK_KEYWORD(f80) TOK_KEYWORD(f8E5M2) TOK_KEYWORD(f8E4M3FN) +TOK_KEYWORD(f8E5M2FNUZ) +TOK_KEYWORD(f8E4M3FNUZ) TOK_KEYWORD(f128) TOK_KEYWORD(false) TOK_KEYWORD(floordiv) diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp index fab7244..47078c1 100644 --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -33,6 +33,8 @@ OptionalParseResult Parser::parseOptionalType(Type &type) { case Token::inttype: case Token::kw_f8E5M2: case Token::kw_f8E4M3FN: + case Token::kw_f8E5M2FNUZ: + case Token::kw_f8E4M3FNUZ: case Token::kw_bf16: case Token::kw_f16: case Token::kw_f32: @@ -295,6 +297,12 @@ Type Parser::parseNonFunctionType() { case Token::kw_f8E4M3FN: consumeToken(Token::kw_f8E4M3FN); return builder.getFloat8E4M3FNType(); + case Token::kw_f8E5M2FNUZ: + consumeToken(Token::kw_f8E5M2FNUZ); + return builder.getFloat8E5M2FNUZType(); + case Token::kw_f8E4M3FNUZ: + consumeToken(Token::kw_f8E4M3FNUZ); + return builder.getFloat8E4M3FNUZType(); case Token::kw_bf16: consumeToken(Token::kw_bf16); return builder.getBF16Type(); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 3cc226d..87ffe59 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -139,6 +139,42 @@ public: } }; +/// Floating Point Type subclass - Float8E4M3FNUZ. +class PyFloat8E4M3FNUZType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; + static constexpr const char *pyClassName = "Float8E4M3FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); + return PyFloat8E4M3FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3fnuz type."); + } +}; + +/// Floating Point Type subclass - Float8E5M2FNUZ. +class PyFloat8E5M2FNUZType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; + static constexpr const char *pyClassName = "Float8E5M2FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); + return PyFloat8E5M2FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e5m2fnuz type."); + } +}; + /// Floating Point Type subclass - BF16Type. class PyBF16Type : public PyConcreteType { public: @@ -700,6 +736,8 @@ void mlir::python::populateIRTypes(py::module &m) { PyIndexType::bind(m); PyFloat8E4M3FNType::bind(m); PyFloat8E5M2Type::bind(m); + PyFloat8E4M3FNUZType::bind(m); + PyFloat8E5M2FNUZType::bind(m); PyBF16Type::bind(m); PyF16Type::bind(m); PyF32Type::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 73a3ec4..aea1221 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -84,6 +84,22 @@ MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx))); } +bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { + return unwrap(type).isFloat8E5M2FNUZ(); +} + +MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx))); +} + +bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { + return unwrap(type).isFloat8E4M3FNUZ(); +} + +MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx))); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 1ce617c..8c5bb30 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2410,6 +2410,8 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { .Case([&](Type) { os << "index"; }) .Case([&](Type) { os << "f8E5M2"; }) .Case([&](Type) { os << "f8E4M3FN"; }) + .Case([&](Type) { os << "f8E5M2FNUZ"; }) + .Case([&](Type) { os << "f8E4M3FNUZ"; }) .Case([&](Type) { os << "bf16"; }) .Case([&](Type) { os << "f16"; }) .Case([&](Type) { os << "f32"; }) diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 38f0501..d36791f 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -41,6 +41,14 @@ FloatType Builder::getFloat8E4M3FNType() { return FloatType::getFloat8E4M3FN(context); } +FloatType Builder::getFloat8E5M2FNUZType() { + return FloatType::getFloat8E5M2FNUZ(context); +} + +FloatType Builder::getFloat8E4M3FNUZType() { + return FloatType::getFloat8E4M3FNUZ(context); +} + FloatType Builder::getBF16Type() { return FloatType::getBF16(context); } FloatType Builder::getF16Type() { return FloatType::getF16(context); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 238b5bb..6e6c6b9 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -88,7 +88,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { //===----------------------------------------------------------------------===// unsigned FloatType::getWidth() { - if (isa()) + if (isa()) return 8; if (isa()) return 16; @@ -109,6 +110,10 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() { return APFloat::Float8E5M2(); if (isa()) return APFloat::Float8E4M3FN(); + if (isa()) + return APFloat::Float8E5M2FNUZ(); + if (isa()) + return APFloat::Float8E4M3FNUZ(); if (isa()) return APFloat::BFloat(); if (isa()) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 5bbedda..176a8ab 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -209,6 +209,8 @@ public: /// Cached Type Instances. Float8E5M2Type f8E5M2Ty; Float8E4M3FNType f8E4M3FNTy; + Float8E5M2FNUZType f8E5M2FNUZTy; + Float8E4M3FNUZType f8E4M3FNUZTy; BFloat16Type bf16Ty; Float16Type f16Ty; Float32Type f32Ty; @@ -281,6 +283,8 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) /// Floating-point Types. impl->f8E5M2Ty = TypeUniquer::get(this); impl->f8E4M3FNTy = TypeUniquer::get(this); + impl->f8E5M2FNUZTy = TypeUniquer::get(this); + impl->f8E4M3FNUZTy = TypeUniquer::get(this); impl->bf16Ty = TypeUniquer::get(this); impl->f16Ty = TypeUniquer::get(this); impl->f32Ty = TypeUniquer::get(this); @@ -870,6 +874,12 @@ Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) { Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) { return context->getImpl().f8E4M3FNTy; } +Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) { + return context->getImpl().f8E5M2FNUZTy; +} +Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) { + return context->getImpl().f8E4M3FNUZTy; +} BFloat16Type BFloat16Type::get(MLIRContext *context) { return context->getImpl().bf16Ty; } diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index 070ed4b..e739786 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -36,6 +36,8 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); } bool Type::isFloat8E5M2() const { return isa(); } bool Type::isFloat8E4M3FN() const { return isa(); } +bool Type::isFloat8E5M2FNUZ() const { return isa(); } +bool Type::isFloat8E4M3FNUZ() const { return isa(); } bool Type::isBF16() const { return isa(); } bool Type::isF16() const { return isa(); } bool Type::isF32() const { return isa(); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 63a3125..7d5ff23 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -52,6 +52,8 @@ __all__ = [ "DictAttr", "Float8E4M3FNType", "Float8E5M2Type", + "Float8E4M3FNUZType", + "Float8E5M2FNUZType", "F16Type", "F32Type", "F64Type", @@ -593,6 +595,20 @@ class Float8E5M2Type(Type): @staticmethod def isinstance(arg: Any) -> bool: ... +class Float8E4M3FNUZType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> Float8E4M3FNUZType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + +class Float8E5M2FNUZType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> Float8E5M2FNUZType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + # TODO: Auto-generated. Audit and fix. class F16Type(Type): def __init__(self, cast_from_type: Type) -> None: ... diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir index d494824e..de840f95 100644 --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -45,6 +45,14 @@ func.func @float_attrs_pass() { float_attr = 2. : f8E4M3FN } : () -> () "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f8E5M2FNUZ + float_attr = 2. : f8E5M2FNUZ + } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ + float_attr = 2. : f8E4M3FNUZ + } : () -> () + "test.float_attrs"() { // CHECK: float_attr = 2.000000e+00 : f16 float_attr = 2. : f16 } : () -> () diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py index e160216..7af8185 100644 --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -197,6 +197,10 @@ def testFloatType(): print("float:", Float8E4M3FNType.get()) # CHECK: float: f8E5M2 print("float:", Float8E5M2Type.get()) + # CHECK: float: f8E5M2FNUZ + print("float:", Float8E5M2FNUZType.get()) + # CHECK: float: f8E4M3FNUZ + print("float:", Float8E4M3FNUZType.get()) # CHECK: float: bf16 print("float:", BF16Type.get()) # CHECK: float: f16 diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py index f04516b..908a734 100644 --- a/mlir/utils/lldb-scripts/mlirDataFormatters.py +++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py @@ -52,6 +52,8 @@ builtin_attr_type_mnemonics = { "mlir::UnknownLoc": '"loc(unknown)"', "mlir::Float8E5M2Type": '"f8E5M2"', "mlir::Float8E4M3FNType": '"f8E4M3FN"', + "mlir::Float8E5M2FNUZType": '"f8E5M2FNUZ"', + "mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"', "mlir::BFloat16Type": '"bf16"', "mlir::Float16Type": '"f16"', "mlir::Float32Type": '"f32"', -- 2.7.4