Add DivOp to the Shape dialect
authorJing Pu <jingpu@google.com>
Fri, 19 Feb 2021 00:58:47 +0000 (16:58 -0800)
committerJacques Pienaar <jpienaar@google.com>
Fri, 19 Feb 2021 00:58:47 +0000 (16:58 -0800)
Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D96907

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir

index 30e8ca1..ecd7f3a 100644 (file)
@@ -138,6 +138,36 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
   let hasFolder = 1;
 }
 
+def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
+  let summary = "Division of sizes and indices";
+  let description = [{
+    Divides two sizes or indices. If either operand is an error it will be
+    propagated to the result. The operands can be of type `size` or `index`.
+    If at least one of the operands can hold an error, i.e. if it is of type
+    `size`, the result must be of type `size`. If error propagation is not
+    possible because both operands are of type `index` then the result may be
+    of type  `size` or `index`. If both operands and result are of type `index`,
+    their runtime values could be negative. The result is rounded toward
+    negative infinity, i.e. floor(lhs / rhs), such that
+
+        div(lhs, rhs) * rhs + mod(lhs, rhs) = lhs
+
+    always holds. If any of the values is of type `size`, the behavior for
+    negative value is undefined.
+  }];
+
+  let arguments = (ins Shape_SizeOrIndexType:$lhs,
+                       Shape_SizeOrIndexType:$rhs);
+  let results = (outs Shape_SizeOrIndexType:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
+
+  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
+  let hasFolder = 1;
+}
+
 def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
   let summary = "Returns whether the input shapes or extent tensors are equal";
   let description = [{
index b1199fb..6146239 100644 (file)
@@ -601,6 +601,30 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
 }
 
 //===----------------------------------------------------------------------===//
+// DivOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+  if (!lhs)
+    return nullptr;
+  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (!rhs)
+    return nullptr;
+
+  // Division in APInt does not follow floor(lhs, rhs) when the result is
+  // negative. Rather, APInt rounds toward zero.
+  APInt quotient, remainder;
+  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
+  if (quotient.isNegative() && !remainder.isNullValue()) {
+    quotient -= 1;
+  }
+
+  Type indexTy = IndexType::get(getContext());
+  return IntegerAttr::get(indexTy, quotient);
+}
+
+//===----------------------------------------------------------------------===//
 // ShapeEqOp
 //===----------------------------------------------------------------------===//
 
index 3a04c16..d5bf3f7 100644 (file)
@@ -950,6 +950,71 @@ func @fold_mul_mixed() -> !shape.size {
 
 // -----
 
+// Fold `div` for constant sizes.
+// CHECK-LABEL: @fold_div_size
+func @fold_div_size() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 3
+  // CHECK: return %[[RESULT]] : !shape.size
+  %c2 = shape.const_size 10
+  %c3 = shape.const_size 3
+  %result = shape.div %c2, %c3 : !shape.size, !shape.size -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
+// Fold `div` for constant indices.
+// CHECK-LABEL: @fold_div_index
+func @fold_div_index() -> index {
+  // CHECK: %[[RESULT:.*]] = constant 2 : index
+  // CHECK: return %[[RESULT]] : index
+  %c2 = constant 10 : index
+  %c3 = constant 4 : index
+  %result = shape.div %c2, %c3 : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Fold `div` for constant indices and lhs is negative.
+// CHECK-LABEL: @fold_div_index_neg_lhs
+func @fold_div_index_neg_lhs() -> index {
+  // CHECK: %[[RESULT:.*]] = constant -3 : index
+  // CHECK: return %[[RESULT]] : index
+  %c2 = constant -10 : index
+  %c3 = constant 4 : index
+  %result = shape.div %c2, %c3 : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Fold `div` for constant indices and rhs is negative.
+// CHECK-LABEL: @fold_div_index_neg_rhs
+func @fold_div_index_neg_rhs() -> index {
+  // CHECK: %[[RESULT:.*]] = constant -3 : index
+  // CHECK: return %[[RESULT]] : index
+  %c2 = constant 10 : index
+  %c3 = constant -4 : index
+  %result = shape.div %c2, %c3 : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Fold `div` for mixed constants.
+// CHECK-LABEL: @fold_div_mixed
+func @fold_div_mixed() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 4
+  // CHECK: return %[[RESULT]] : !shape.size
+  %c2 = shape.const_size 12
+  %c3 = constant 3 : index
+  %result = shape.div %c2, %c3 : !shape.size, index -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
 // Fold index_cast when already on index.
 // CHECK-LABEL: @fold_index_cast_on_index
 func @fold_index_cast_on_index(%arg: index) -> index {
index 57195ba..1a5735c 100644 (file)
@@ -129,6 +129,15 @@ func @mul(%size_arg : !shape.size, %index_arg : index) {
   return
 }
 
+func @div(%size_arg : !shape.size, %index_arg : index) {
+  %size_div = shape.div %size_arg, %size_arg
+      : !shape.size, !shape.size -> !shape.size
+  %index_div = shape.div %index_arg, %index_arg : index, index -> index
+  %mixed_div = shape.div %size_arg, %index_arg
+      : !shape.size, index -> !shape.size
+  return
+}
+
 func @add(%size_arg : !shape.size, %index_arg : index) {
   %size_sum = shape.add %size_arg, %size_arg
       : !shape.size, !shape.size -> !shape.size