From 13caf8b3109028d0be6da1c5b13c20c451920610 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 2 May 2019 12:41:34 -0700 Subject: [PATCH] Add FxpMathOps real_matmul and real_matmul_bias. Also: - cleans up some operand names for consistency - remove the broadcast_dims attribute as it isn't used - adds an IsNullAttr predicate which is needed to match optional clamp attributes on these kind of ops (needed to simplify some out of tree transforms on the new matmul op) -- PiperOrigin-RevId: 246370576 --- mlir/include/mlir/FxpMathOps/FxpMathOps.td | 66 ++++++++++++++-------- mlir/include/mlir/IR/OpBase.td | 3 + .../FxpMathOps/Transforms/LowerUniformRealMath.cpp | 8 +-- 3 files changed, 51 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/FxpMathOps/FxpMathOps.td b/mlir/include/mlir/FxpMathOps/FxpMathOps.td index 708b17c..fc4062c 100644 --- a/mlir/include/mlir/FxpMathOps/FxpMathOps.td +++ b/mlir/include/mlir/FxpMathOps/FxpMathOps.td @@ -104,7 +104,7 @@ def fxpmath_ClampISOp : fxpmath_Op<"clampis", [NoSideEffect, SameValueType]> { Element-wise equivalent to: r = std::min(clamp_max, std::max(e, clamp_min)) }]; - let arguments = (ins IntegerLike:$arg, + let arguments = (ins IntegerLike:$operand, APIntAttr:$clamp_min, APIntAttr:$clamp_max); let results = (outs IntegerLike); @@ -119,7 +119,7 @@ def fxpmath_ConvertISOp : Similar to an element-wise static_cast in C++, from a one signed integer element type to another. }]; - let arguments = (ins IntegerLike:$arg); + let arguments = (ins IntegerLike:$operand); let results = (outs IntegerLike); } @@ -133,7 +133,7 @@ def fxpmath_ConvertISToFOp : element type to a floating point element type, rounding to the nearest floating point value. }]; - let arguments = (ins IntegerLike:$arg); + let arguments = (ins IntegerLike:$operand); let results = (outs FloatLike); } @@ -161,8 +161,8 @@ def fxpmath_RoundingDivideByPotISOp : Also known as a rounding arithmetic right shift. 
See gemmlowp::RoundingDivideByPOT for a reference implementation. }]; - let arguments = (ins IntegerLike:$x, APIntAttr:$exponent); - let results = (outs IntegerLike:$y); + let arguments = (ins IntegerLike:$operand, APIntAttr:$exponent); + let results = (outs IntegerLike:$res); let verifier = [{ auto verifyExponent = exponent().getSExtValue(); if (verifyExponent < 0 || verifyExponent > 31) { @@ -215,25 +215,17 @@ class fxpmath_RealMathOp traits = [], dag args> : // Element wise binary real math ops. //===----------------------------------------------------------------------===// -// The broadcasting dimensions correspond to a tuple that describes how a -// smaller rank shape is broadcast into a larger rank shape. For example, -// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means -// matching the matrix to dimensions 1 and 2 of the cuboid. -def fxpmath_BroadcastDimAttr : OptionalAttr; - class fxpmath_RealBinaryOp traits = []> : fxpmath_RealMathOp, - Results<(outs quant_RealValueType:$r)>; + (ins quant_RealValueType:$lhs, + quant_RealValueType:$rhs)>, + Results<(outs quant_RealValueType:$res)>; class fxpmath_RealBinaryBiasOp traits = []> : fxpmath_RealMathOp, - Results<(outs quant_RealValueType:$r)>; + Results<(outs quant_RealValueType:$res)>; def fxpmath_RealAddEwOp : fxpmath_RealBinaryOp<"real_add_ew", [NoSideEffect]>; @@ -253,16 +245,46 @@ def fxpmath_RealDivEwOp : def fxpmath_RealUnaryEwOp : fxpmath_RealMathOp<"real_unary_ew", [NoSideEffect], - (ins quant_RealValueType:$x, fxpmath_EwUnaryFnAttr:$fn)>, - Results<(outs quant_RealValueType:$r)>; + (ins quant_RealValueType:$operand, fxpmath_EwUnaryFnAttr:$fn)>, + Results<(outs quant_RealValueType:$res)>; def fxpmath_RealCompareZeroEwOp : fxpmath_Op<"compare", [NoSideEffect]>, - Arguments<(ins quant_RealValueType:$x, fxpmath_CompareFnAttr:$fn)>, - Results<(outs I1Tensor:$r)> { + Arguments<(ins quant_RealValueType:$operand, fxpmath_CompareFnAttr:$fn)>, + Results<(outs I1Tensor:$res)> { let 
description = [{ Compares a real value to zero, returning an I1 (boolean) tensor with the result of applying the comparison function. }]; } +//===----------------------------------------------------------------------===// +// Matmul ops, with and without fused bias addition. +//===----------------------------------------------------------------------===// + +def fxpmath_RealMatMulOp : + fxpmath_RealBinaryOp<"real_matmul", [NoSideEffect]> { + let summary = "Matmul"; + let description = [{ + A matrix multiply of [m, k] and [k, n] -> [m, n], with no bias operand + (see real_matmul_bias for the fused-bias form). Also accepts rank 3 or + more input tensors, in which case the leading dimensions are batch dims. + + Many real systems have specific library calls optimized for this precise + operation, which is why it is handled explicitly versus purely as a + generalized tensor contraction. + }]; +} + +def fxpmath_RealMatMulBiasOp : + fxpmath_RealBinaryBiasOp<"real_matmul_bias", [NoSideEffect]> { + let summary = "Matmul with bias"; + let description = [{ + A specialization of a RealMatMulOp that also accepts an [n] dimension + bias vector. + + In addition, there is often special support for a fused bias and clamp, + which is why they are included. 
+ }]; +} + #endif // FXPMATH_OPS diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index fd84701..81eeb13 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -735,6 +735,9 @@ class IntArrayNthElemMinValue : AttrConstraint< ]>, "whose " # index # "-th element must be at least " # min>; +def IsNullAttr : AttrConstraint< + CPred<"!$_self">, "empty attribute (for optional attributes)">; + //===----------------------------------------------------------------------===// // OpTrait definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp index 2ee39c9..3e8a47a 100644 --- a/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -331,8 +331,8 @@ struct UniformRealAddEwPattern : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto addOp = op->cast(); - const UniformBinaryOpInfo info(op, addOp.x(), addOp.y(), addOp.clamp_min(), - addOp.clamp_max()); + const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(), + addOp.clamp_min(), addOp.clamp_max()); if (!info.isValid()) { return matchFailure(); } @@ -353,8 +353,8 @@ struct UniformRealMulEwPattern : public RewritePattern { PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto mulOp = op->cast(); - const UniformBinaryOpInfo info(op, mulOp.x(), mulOp.y(), mulOp.clamp_min(), - mulOp.clamp_max()); + const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(), + mulOp.clamp_min(), mulOp.clamp_max()); if (!info.isValid()) { return matchFailure(); } -- 2.7.4