Add FxpMathOps real_matmul and real_matmul_bias.
authorStella Laurenzo <laurenzo@google.com>
Thu, 2 May 2019 19:41:34 +0000 (12:41 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:25:27 +0000 (08:25 -0700)
    Also:
      - cleans up some operand names for consistency
      - remove the broadcast_dims attribute as it isn't used
      - adds an IsNullAttr predicate which is needed to match optional clamp attributes on these kind of ops (needed to simplify some out of tree transforms on the new matmul op)

--

PiperOrigin-RevId: 246370576

mlir/include/mlir/FxpMathOps/FxpMathOps.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp

index 708b17c..fc4062c 100644 (file)
@@ -104,7 +104,7 @@ def fxpmath_ClampISOp : fxpmath_Op<"clampis", [NoSideEffect, SameValueType]> {
     Element-wise equivalent to:
       r = std::min(clamp_max, std::max(e, clamp_min))
   }];
-  let arguments = (ins IntegerLike:$arg,
+  let arguments = (ins IntegerLike:$operand,
                        APIntAttr:$clamp_min,
                        APIntAttr:$clamp_max);
   let results = (outs IntegerLike);
@@ -119,7 +119,7 @@ def fxpmath_ConvertISOp :
     Similar to an element-wise static_cast in C++, from a one signed integer
     element type to another.
   }];
-  let arguments = (ins IntegerLike:$arg);
+  let arguments = (ins IntegerLike:$operand);
   let results = (outs IntegerLike);
 }
 
@@ -133,7 +133,7 @@ def fxpmath_ConvertISToFOp :
     element type to a floating point element type, rounding to the nearest
     floating point value.
   }];
-  let arguments = (ins IntegerLike:$arg);
+  let arguments = (ins IntegerLike:$operand);
   let results = (outs FloatLike);
 }
 
@@ -161,8 +161,8 @@ def fxpmath_RoundingDivideByPotISOp :
     Also known as a rounding arithmetic right shift. See
     gemmlowp::RoundingDivideByPOT for a reference implementation.
   }];
-  let arguments = (ins IntegerLike:$x, APIntAttr:$exponent);
-  let results = (outs IntegerLike:$y);
+  let arguments = (ins IntegerLike:$operand, APIntAttr:$exponent);
+  let results = (outs IntegerLike:$res);
   let verifier = [{
     auto verifyExponent = exponent().getSExtValue();
     if (verifyExponent < 0 || verifyExponent > 31) {
@@ -215,25 +215,17 @@ class fxpmath_RealMathOp<string mnemonic, list<OpTrait> traits = [], dag args> :
 // Element wise binary real math ops.
 //===----------------------------------------------------------------------===//
 
-// The broadcasting dimensions correspond to a tuple that describes how a
-// smaller rank shape is broadcast into a larger rank shape. For example,
-// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
-// matching the matrix to dimensions 1 and 2 of the cuboid.
-def fxpmath_BroadcastDimAttr : OptionalAttr<ElementsAttr>;
-
 class fxpmath_RealBinaryOp<string mnemonic, list<OpTrait> traits = []> :
     fxpmath_RealMathOp<mnemonic, traits,
-                     (ins quant_RealValueType:$x,
-                      quant_RealValueType:$y,
-                      fxpmath_BroadcastDimAttr:$broadcast_dimensions
-                     )>,
-    Results<(outs quant_RealValueType:$r)>;
+                     (ins quant_RealValueType:$lhs,
+                      quant_RealValueType:$rhs)>,
+    Results<(outs quant_RealValueType:$res)>;
 
 class fxpmath_RealBinaryBiasOp<string mnemonic, list<OpTrait> traits = []> :
     fxpmath_RealMathOp<mnemonic, traits,
-                     (ins quant_RealValueType:$x, quant_RealValueType:$y,
+                     (ins quant_RealValueType:$lhs, quant_RealValueType:$rhs,
                           quant_RealValueType:$bias)>,
-    Results<(outs quant_RealValueType:$r)>;
+    Results<(outs quant_RealValueType:$res)>;
 
 def fxpmath_RealAddEwOp :
     fxpmath_RealBinaryOp<"real_add_ew", [NoSideEffect]>;
@@ -253,16 +245,46 @@ def fxpmath_RealDivEwOp :
 
 def fxpmath_RealUnaryEwOp :
     fxpmath_RealMathOp<"real_unary_ew", [NoSideEffect],
-        (ins quant_RealValueType:$x, fxpmath_EwUnaryFnAttr:$fn)>,
-    Results<(outs quant_RealValueType:$r)>;
+        (ins quant_RealValueType:$operand, fxpmath_EwUnaryFnAttr:$fn)>,
+    Results<(outs quant_RealValueType:$res)>;
 
 def fxpmath_RealCompareZeroEwOp : fxpmath_Op<"compare", [NoSideEffect]>,
-    Arguments<(ins quant_RealValueType:$x, fxpmath_CompareFnAttr:$fn)>,
-    Results<(outs I1Tensor:$r)> {
+    Arguments<(ins quant_RealValueType:$operand, fxpmath_CompareFnAttr:$fn)>,
+    Results<(outs I1Tensor:$res)> {
   let description = [{
     Compares a real value to zero, returning an I1 (boolean) tensor with the
     result of applying the comparison function.
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Dot op with fused bias addition.
+//===----------------------------------------------------------------------===//
+
+def fxpmath_RealMatMulOp :
+    fxpmath_RealBinaryOp<"real_matmul", [NoSideEffect]> {
+  let summary = "Matmul";
+  let description = [{
+    A matrix multiply of [m, k] and [k, n] -> [m, n] where the bias vector is
+    of shape [n]. Also accepts rank 3 or more input tensors, in which case
+    the leading dimensions are batch dims.
+
+    Many real systems have specific library calls optimized for this precise
+    operation, which is why it is handled explicitly versus purely as a
+    generalized tensor contraction.
+  }];
+}
+
+def fxpmath_RealMatMulBiasOp :
+    fxpmath_RealBinaryBiasOp<"real_matmul_bias", [NoSideEffect]> {
+  let summary = "Matmul with bias";
+  let description = [{
+    A specialization of a RealMatMulOp that also accepts an [n] dimension
+    bias vector.
+
+    In addition, there is often special support for a fused bias and clamp,
+    which is why they are included.
+  }];
+}
+
 #endif  // FXPMATH_OPS
index fd84701..81eeb13 100644 (file)
@@ -735,6 +735,9 @@ class IntArrayNthElemMinValue<int index, int min> : AttrConstraint<
         ]>,
     "whose " # index # "-th element must be at least " # min>;
 
+def IsNullAttr : AttrConstraint<
+    CPred<"!$_self">, "empty attribute (for optional attributes)">;
+
 //===----------------------------------------------------------------------===//
 // OpTrait definitions
 //===----------------------------------------------------------------------===//
index 2ee39c9..3e8a47a 100644 (file)
@@ -331,8 +331,8 @@ struct UniformRealAddEwPattern : public RewritePattern {
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const {
     auto addOp = op->cast<RealAddEwOp>();
-    const UniformBinaryOpInfo info(op, addOp.x(), addOp.y(), addOp.clamp_min(),
-                                   addOp.clamp_max());
+    const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
+                                   addOp.clamp_min(), addOp.clamp_max());
     if (!info.isValid()) {
       return matchFailure();
     }
@@ -353,8 +353,8 @@ struct UniformRealMulEwPattern : public RewritePattern {
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const {
     auto mulOp = op->cast<RealMulEwOp>();
-    const UniformBinaryOpInfo info(op, mulOp.x(), mulOp.y(), mulOp.clamp_min(),
-                                   mulOp.clamp_max());
+    const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
+                                   mulOp.clamp_min(), mulOp.clamp_max());
     if (!info.isValid()) {
       return matchFailure();
     }