[mlir][tosa] Add remaining tosa comparison folders
authorRob Suderman <suderman@google.com>
Thu, 1 Sep 2022 21:10:19 +0000 (14:10 -0700)
committerRob Suderman <suderman@google.com>
Thu, 1 Sep 2022 21:48:46 +0000 (14:48 -0700)
Added numerical splat folders for comparison operations and
equal of two identical int values.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D133138

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/constant-op-fold.mlir

index c19c9e4..c50fee2 100644 (file)
@@ -1143,6 +1143,8 @@ def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
     /// InferTypeOpInterface.
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1191,6 +1193,8 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
   let results = (outs
     I1Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
index 385e406..abe33f8 100644 (file)
@@ -675,6 +675,11 @@ struct APIntFoldGreater {
   APIntFoldGreater() {}
   APInt operator()(APInt l, APInt r) { return APInt(1, l.sgt(r)); }
 };
+
+struct APIntFoldGreaterEqual {
+  APIntFoldGreaterEqual() {}
+  APInt operator()(APInt l, APInt r) { return APInt(1, l.sge(r)); }
+};
 } // namespace
 
 OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
@@ -689,6 +694,42 @@ OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
       lhsAttr, rhsAttr, resultTy);
 }
 
+OpFoldResult GreaterEqualOp::fold(ArrayRef<Attribute> operands) {
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+
+  if (!lhsAttr || !rhsAttr)
+    return {};
+
+  return BinaryFolder<APIntFoldGreaterEqual,
+                      ComparisonFold<std::greater_equal<APFloat>>>(
+      lhsAttr, rhsAttr, resultTy);
+}
+
+OpFoldResult EqualOp::fold(ArrayRef<Attribute> operands) {
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  Value lhs = getInput1();
+  Value rhs = getInput2();
+  auto lhsTy = lhs.getType().cast<ShapedType>();
+
+  // If we are comparing an integer value to itself it is always true. We can
+  // not do this with float due to float values.
+  if (lhsTy.getElementType().isa<IntegerType>() && resultTy.hasStaticShape() &&
+      lhs == rhs) {
+    return DenseElementsAttr::get(resultTy, true);
+  }
+
+  if (!lhsAttr || !rhsAttr)
+    return {};
+
+  return BinaryFolder<ComparisonFold<std::equal_to<APInt>>,
+                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
+                                                              resultTy);
+}
+
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();
index 43f2196..28be1ff 100644 (file)
@@ -350,50 +350,108 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> {
 
 // -----
 
-// CHECK-LABEL: @fold_greater_splat_f32_true
-func.func @fold_greater_splat_f32_true() -> tensor<10xi1> {
-  %one = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %add = "tosa.greater"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
-  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
-  // CHECK: return %[[BOOL]]
-  return %add : tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_splat_f32_false
-func.func @fold_greater_splat_f32_false() -> tensor<10xi1> {
-  %one = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %add = "tosa.greater"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
-  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
-  // CHECK: return %[[BOOL]]
-  return %add : tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_splat_i32_false
-func.func @fold_greater_splat_i32_false() -> tensor<10xi1> {
-  %one = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %two = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
-  %add = "tosa.greater"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
-  // CHECK: return %[[BOOL]]
-  return %add : tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_splat_i32_true
-func.func @fold_greater_splat_i32_true() -> tensor<10xi1> {
-  %one = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %two = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
-  %add = "tosa.greater"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
-  // CHECK: return %[[BOOL]]
-  return %add : tensor<10xi1>
+// CHECK-LABEL: @fold_greater_splat_f32
+func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %1 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %true = "tosa.greater"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %false = "tosa.greater"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_i32
+func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %3 = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
+  %false = "tosa.greater"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  %true = "tosa.greater"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK: return %[[FALSE]], %[[TRUE]]
+  return %false, %true : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_eq_splat_f32
+func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %1 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %true = "tosa.greater_equal"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %false = "tosa.greater_equal"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_eq_splat_i32
+func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %3 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %true = "tosa.greater_equal"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  %false = "tosa.greater_equal"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_eq_splat_f32
+func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %1 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %true = "tosa.equal"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %false = "tosa.equal"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_eq_splat_i32
+func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %3 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %true = "tosa.equal"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  %false = "tosa.equal"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_eq_i32
+func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) {
+  // CHECK: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  %0 = "tosa.equal"(%arg0, %arg0) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK: return %[[TRUE]]
+  return %0 : tensor<10xi1>
 }
 
 // -----