Update the FxpMathOps to better reflect what is needed to legalize from XLA.
authorStella Laurenzo <laurenzo@google.com>
Wed, 10 Apr 2019 19:37:45 +0000 (12:37 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 11 Apr 2019 17:52:51 +0000 (10:52 -0700)
--

PiperOrigin-RevId: 242919924

mlir/include/mlir/FxpMathOps/FxpMathOps.td

index 24d5e6f..083f3d9 100644 (file)
@@ -15,7 +15,7 @@
 // limitations under the License.
 // =============================================================================
 //
-// This is the operation definition file for fixed point ops (and real 
+// This is the operation definition file for fixed point ops (and real
 // equivalents).
 //
 //===----------------------------------------------------------------------===//
 #ifdef OP_BASE
 #else
 include "mlir/IR/OpBase.td"
-include "mlir/Quantization/QuantPredicates.td"
 #endif // OP_BASE
 
+include "mlir/Quantization/QuantPredicates.td"
+
 //===----------------------------------------------------------------------===//
 // Attributes
 //===----------------------------------------------------------------------===//
@@ -45,20 +46,40 @@ def fxpmath_EwUnaryFnAttr :
 }
 
 class fxpmath_ConstEwUnaryFn<string val> : ConstantAttr<fxpmath_EwUnaryFnAttr, val>;
-def fxpmath_EwUnaryFn_Identity: fxpmath_ConstEwUnaryFn<"IDENTITY">;
-def fxpmath_EwUnaryFn_Tanh    : fxpmath_ConstEwUnaryFn<"TANH">;
-def fxpmath_EwUnaryFn_Sigmoid : fxpmath_ConstEwUnaryFn<"SIGMOID">;
+def fxpmath_EwUnaryFn_Abs     : fxpmath_ConstEwUnaryFn<"ABS">;
 def fxpmath_EwUnaryFn_Exp     : fxpmath_ConstEwUnaryFn<"EXP">;
+def fxpmath_EwUnaryFn_Identity: fxpmath_ConstEwUnaryFn<"IDENTITY">;
 def fxpmath_EwUnaryFn_Log     : fxpmath_ConstEwUnaryFn<"LOG">;
 def fxpmath_EwUnaryFn_Neg     : fxpmath_ConstEwUnaryFn<"NEG">;
 def fxpmath_EwUnaryFn_Rsqrt   : fxpmath_ConstEwUnaryFn<"RSQRT">;
+def fxpmath_EwUnaryFn_Sigmoid : fxpmath_ConstEwUnaryFn<"SIGMOID">;
+def fxpmath_EwUnaryFn_Sign    : fxpmath_ConstEwUnaryFn<"SIGN">;
 def fxpmath_EwUnaryFn_Sin     : fxpmath_ConstEwUnaryFn<"SIN">;
-def fxpmath_EwUnaryFn_Square  : fxpmath_ConstEwUnaryFn<"SQUARE">;
 def fxpmath_EwUnaryFn_Sqrt    : fxpmath_ConstEwUnaryFn<"SQRT">;
-def fxpmath_EwUnaryFn_CmpZ    : fxpmath_ConstEwUnaryFn<"CMPZ">;
-def fxpmath_EwUnaryFn_CmpNZ   : fxpmath_ConstEwUnaryFn<"CMPNZ">;
-def fxpmath_EwUnaryFn_CmpLZ   : fxpmath_ConstEwUnaryFn<"CMPLZ">;
-def fxpmath_EwUnaryFn_CmpGZ   : fxpmath_ConstEwUnaryFn<"CMPGZ">;
+def fxpmath_EwUnaryFn_Square  : fxpmath_ConstEwUnaryFn<"SQUARE">;
+def fxpmath_EwUnaryFn_Tanh    : fxpmath_ConstEwUnaryFn<"TANH">;
+
+//===----------------------------------------------------------------------===//
+// Comparison functions (compares relative to zero on a subtraction result).
+//===----------------------------------------------------------------------===//
+
+def fxpmath_CompareZ    : EnumAttrCase<"CMPZ">;
+def fxpmath_CompareNZ   : EnumAttrCase<"CMPNZ">;
+def fxpmath_CompareLZ   : EnumAttrCase<"CMPLZ">;
+def fxpmath_CompareLZE  : EnumAttrCase<"CMPLZE">;
+def fxpmath_CompareGZ   : EnumAttrCase<"CMPGZ">;
+def fxpmath_CompareGZE  : EnumAttrCase<"CMPGZE">;
+
+def fxpmath_CompareFnAttr : EnumAttr<"ComparisonFn",
+    "Type of subtraction-result comparison to perform.",
+    [
+      fxpmath_CompareZ,
+      fxpmath_CompareNZ,
+      fxpmath_CompareLZ,
+      fxpmath_CompareLZE,
+      fxpmath_CompareGZ,
+      fxpmath_CompareGZE
+    ]>;
 
 //===----------------------------------------------------------------------===//
 // Base classes
@@ -148,9 +169,18 @@ class fxpmath_RealMathOp<string mnemonic, list<OpTrait> traits = [], dag args> :
 // Element wise binary real math ops.
 //===----------------------------------------------------------------------===//
 
+// The broadcasting dimensions correspond to a tuple that describes how a
+// smaller rank shape is broadcast into a larger rank shape. For example,
+// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
+// matching the matrix to dimensions 1 and 2 of the cuboid.
+def fxpmath_BroadcastDimAttr : OptionalAttr<ElementsAttr>;
+
 class fxpmath_RealBinaryOp<string mnemonic, list<OpTrait> traits = []> :
     fxpmath_RealMathOp<mnemonic, traits,
-                     (ins quant_RealValueType:$x, quant_RealValueType:$y)>,
+                     (ins quant_RealValueType:$x,
+                      quant_RealValueType:$y,
+                      fxpmath_BroadcastDimAttr:$broadcast_dimensions
+                     )>,
     Results<(outs quant_RealValueType:$r)>;
 
 class fxpmath_RealBinaryBiasOp<string mnemonic, list<OpTrait> traits = []> :
@@ -180,4 +210,13 @@ def fxpmath_RealUnaryEwOp :
         (ins quant_RealValueType:$x, fxpmath_EwUnaryFnAttr:$fn)>,
     Results<(outs quant_RealValueType:$r)>;
 
+def fxpmath_RealCompareZeroEwOp : fxpmath_Op<"compare", [NoSideEffect]>,
+    Arguments<(ins quant_RealValueType:$x, fxpmath_CompareFnAttr:$fn)>,
+    Results<(outs I1Tensor:$r)> {
+  let description = [{
+    Compares a real value to zero, returning an I1 (boolean) tensor with the
+    result of applying the comparison function.
+  }];
+}
+
 #endif  // FXPMATH_OPS