Add ArithBuilder::sub, make add, mul work with IndexTypes.
authorJohannes Reifferscheid <jreiffers@google.com>
Mon, 5 Sep 2022 09:25:58 +0000 (11:25 +0200)
committerJohannes Reifferscheid <jreiffers@google.com>
Mon, 5 Sep 2022 10:44:19 +0000 (12:44 +0200)
sgt and slt already worked with IndexTypes, the others did not.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D133285

mlir/include/mlir/Dialect/Arithmetic/Utils/Utils.h
mlir/lib/Dialect/Arithmetic/Utils/Utils.cpp

index 924de08..fd11bd3 100644 (file)
@@ -99,6 +99,7 @@ struct ArithBuilder {
 
   Value _and(Value lhs, Value rhs);
   Value add(Value lhs, Value rhs);
+  Value sub(Value lhs, Value rhs);
   Value mul(Value lhs, Value rhs);
   Value select(Value cmp, Value lhs, Value rhs);
   Value sgt(Value lhs, Value rhs);
index b568891..ca61669 100644 (file)
@@ -93,24 +93,29 @@ Value ArithBuilder::_and(Value lhs, Value rhs) {
   return b.create<arith::AndIOp>(loc, lhs, rhs);
 }
 Value ArithBuilder::add(Value lhs, Value rhs) {
-  if (lhs.getType().isa<IntegerType>())
-    return b.create<arith::AddIOp>(loc, lhs, rhs);
-  return b.create<arith::AddFOp>(loc, lhs, rhs);
+  if (lhs.getType().isa<FloatType>())
+    return b.create<arith::AddFOp>(loc, lhs, rhs);
+  return b.create<arith::AddIOp>(loc, lhs, rhs);
+}
+Value ArithBuilder::sub(Value lhs, Value rhs) {
+  if (lhs.getType().isa<FloatType>())
+    return b.create<arith::SubFOp>(loc, lhs, rhs);
+  return b.create<arith::SubIOp>(loc, lhs, rhs);
 }
 Value ArithBuilder::mul(Value lhs, Value rhs) {
-  if (lhs.getType().isa<IntegerType>())
-    return b.create<arith::MulIOp>(loc, lhs, rhs);
-  return b.create<arith::MulFOp>(loc, lhs, rhs);
+  if (lhs.getType().isa<FloatType>())
+    return b.create<arith::MulFOp>(loc, lhs, rhs);
+  return b.create<arith::MulIOp>(loc, lhs, rhs);
 }
 Value ArithBuilder::sgt(Value lhs, Value rhs) {
-  if (lhs.getType().isa<IndexType, IntegerType>())
-    return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
-  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
+  if (lhs.getType().isa<FloatType>())
+    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
+  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
 }
 Value ArithBuilder::slt(Value lhs, Value rhs) {
-  if (lhs.getType().isa<IndexType, IntegerType>())
-    return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
-  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
+  if (lhs.getType().isa<FloatType>())
+    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
+  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
 }
 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
   return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);