[mlir][shape] Further operand and result type generalization
authorJacques Pienaar <jpienaar@google.com>
Sun, 26 Jul 2020 04:37:15 +0000 (21:37 -0700)
committerJacques Pienaar <jpienaar@google.com>
Sun, 26 Jul 2020 04:41:31 +0000 (21:41 -0700)
Previous changes generalized some of the operands and results. Complete
a larger group of those to simplify progressive lowering. Also update
some of the declarative asm form due to generalization. Tried to keep it
mostly mechanical.

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir

index 3c50a4f..7b676a2 100644 (file)
@@ -86,11 +86,12 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
     broadcastable output shape possible for the given inputs.
   }];
 
-  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs,
-                   OptionalAttr<StrAttr>:$error);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
+                       Shape_ShapeOrExtentTensorType:$rhs,
+                       OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict";
+  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
 
   let hasFolder = 1;
 }
@@ -220,10 +221,10 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
     If the shape represents an error, this op's behavior is undefined.
   }];
 
-  let arguments = (ins Shape_ShapeType:$input);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$input);
   let results = (outs IndexTensor:$result);
 
-  let assemblyFormat = "attr-dict $input `:` type($result)";
+  let assemblyFormat = "attr-dict $input `:` type($input) `->` type($result)";
 
   let hasFolder = 1;
 }
@@ -342,6 +343,10 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
   let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
   let results = (outs Shape_SizeOrIndexType:$result);
 
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result, Value shape">,
+  ];
+
   let assemblyFormat = "$shape `:` type($shape) `->` type($result) attr-dict";
 
   let hasFolder = 1;
@@ -412,23 +417,28 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
 
   let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict";
 
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result, Value arg">
+  ];
+
   let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
 def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
   let summary = "Casts between index types of the shape and standard dialect";
   let description = [{
-    Converts a `shape.size` to a standard index.
-    This operation and its inverse, `index_to_size`, facilitate index conversion
-    between the standard and the shape dialect.
-    The behavior is undefined for unknown and invalid arguments.
+    Converts a `shape.size` to a standard index. This operation and its
+    inverse, `index_to_size`, facilitate index conversion between the standard
+    and the shape dialect. The behavior is undefined for unknown and invalid
+    arguments.
   }];
 
-  let arguments = (ins Shape_SizeType:$arg);
+  let arguments = (outs Shape_SizeOrIndexType:$arg);
   let results = (outs Index:$result);
 
-  let assemblyFormat = "$arg attr-dict";
+  let assemblyFormat = "$arg attr-dict `:` type($arg)";
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
@@ -490,7 +500,7 @@ def Shape_SplitAtOp : Shape_Op<"split_at", []> {
     - `index` is in the range [-rank(operand),rank(operand)]
   }];
 
-  let arguments = (ins Shape_ShapeType:$operand, I32:$index);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$operand, I32:$index);
   let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
   let hasFolder = 1;
 }
@@ -520,8 +530,7 @@ def Shape_ConcatOp : Shape_Op<"concat", []> {
 
 // TODO: Move the code below and witnesses to a different file.
 def Shape_AnyOp : Shape_Op<"any", [Commutative,
-                                   NoSideEffect,
-                                   SameOperandsAndResultType]> {
+                                   NoSideEffect]> {
   let summary = "Return any combination of the input shapes";
   let description = [{
     This operation takes multiple input shapes or extent tensors and returns
@@ -541,7 +550,6 @@ def Shape_AnyOp : Shape_Op<"any", [Commutative,
   let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
   let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
-  let assemblyFormat = "$inputs `:` type($result)  attr-dict";
   let hasFolder = 1;
 }
 
index 104ab46..4887c87 100644 (file)
@@ -674,6 +674,16 @@ OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
   return builder.getIndexAttr(product.getLimitedValue());
 }
 
+void NumElementsOp::build(OpBuilder &builder, OperationState &result,
+                          Value shape) {
+  if (shape.getType().isa<ShapedType>()) {
+    auto type = builder.getIndexType();
+    return build(builder, result, type, shape);
+  }
+  auto type = SizeType::get(builder.getContext());
+  return build(builder, result, type, shape);
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//
@@ -702,6 +712,38 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
   return builder.getIndexTensorAttr(type.getShape());
 }
 
+void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
+  if (arg.getType().isa<ShapedType>()) {
+    auto type = RankedTensorType::get({ShapedType::kDynamicSize},
+                                      builder.getIndexType());
+    return ShapeOfOp::build(builder, result, type, arg);
+  }
+  auto type = ShapeType::get(builder.getContext());
+  return ShapeOfOp::build(builder, result, type, arg);
+}
+
+namespace {
+struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.arg().getType().isa<ShapedType>())
+      return failure();
+    if (op.getType().isa<ShapedType>())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(), op.arg());
+    return success();
+  }
+};
+} // namespace
+
+void ShapeOfOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                            MLIRContext *context) {
+  patterns.insert<ShapeOfWithTensor>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // SizeToIndexOp
 //===----------------------------------------------------------------------===//
index 8236c6f..9336402 100644 (file)
@@ -50,7 +50,6 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
   // CHECK-DAG: %[[C2:.*]] = constant 2 : index
   // CHECK-DAG: %[[C3:.*]] = constant 3 : index
   // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
   %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
   return
 }
@@ -66,7 +65,6 @@ func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
   // CHECK-DAG: %[[C2:.*]] = constant 2 : index
   // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
   // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
-  // CHECK-DAG: %[[SHAPE:.*]] = tensor_cast %[[SHAPE_UNCASTED]] : tensor<3xindex> to tensor<?xindex>
   %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
   return
 }
@@ -120,7 +118,7 @@ func @any_of_three(%a : tensor<?xindex>,
                    %b : tensor<?xindex>,
                    %c : tensor<?xindex>) -> tensor<?xindex> {
   // CHECK: return %[[A]] : tensor<?xindex>
-  %result = shape.any %a, %b, %c : tensor<?xindex>
+  %result = "shape.any"(%a, %b, %c) : (tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
   return %result : tensor<?xindex>
 }
 
@@ -131,7 +129,7 @@ func @any_of_three(%a : tensor<?xindex>,
 // CHECK-SAME:  (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
 func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
   // CHECK: return %[[A]] : tensor<?xindex>
-  %result = shape.any %a : tensor<?xindex>
+  %result = "shape.any"(%a) : (tensor<?xindex>) -> tensor<?xindex>
   return %result : tensor<?xindex>
 }
 
index e147fbe..5fe2ac1 100644 (file)
@@ -54,7 +54,7 @@ func @f() -> !shape.shape {
   // CHECK: shape.const_shape [7, 2] : !shape.shape
   %0 = shape.const_shape [1, 2] : !shape.shape
   %1 = shape.const_shape [7, 1] : !shape.shape
-  %2 = shape.broadcast %0, %1
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
   return %2 : !shape.shape
 }
 
@@ -65,7 +65,7 @@ func @f() -> !shape.shape {
 func @f(%arg0 : !shape.shape) -> !shape.shape {
   // CHECK: return %arg0
   %0 = shape.const_shape [] : !shape.shape
-  %1 = shape.broadcast %arg0, %0
+  %1 = shape.broadcast %arg0, %0 : !shape.shape, !shape.shape
   return %1 : !shape.shape
 }
 
@@ -76,7 +76,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
 func @f(%arg0 : !shape.shape) -> !shape.shape {
   // CHECK: return %arg0
   %0 = shape.const_shape [] : !shape.shape
-  %1 = shape.broadcast %0, %arg0
+  %1 = shape.broadcast %0, %arg0 : !shape.shape, !shape.shape
   return %1 : !shape.shape
 }
 
@@ -89,7 +89,7 @@ func @f() -> !shape.shape {
   // CHECK: return %[[CST]]
   %0 = shape.const_shape [] : !shape.shape
   %1 = shape.const_shape [1, 2, 3] : !shape.shape
-  %2 = shape.broadcast %0, %1
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
   return %2 : !shape.shape
 }
 
@@ -101,7 +101,7 @@ func @f() -> !shape.shape {
   // CHECK: shape.broadcast
   %0 = shape.const_shape [2] : !shape.shape
   %1 = shape.const_shape [7] : !shape.shape
-  %2 = shape.broadcast %0, %1
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
   return %2 : !shape.shape
 }
 
@@ -124,7 +124,7 @@ func @f() -> !shape.shape {
 func @f() -> tensor<2xindex> {
   // CHECK: constant dense<[0, 1]> : tensor<2xindex>
   %cs = shape.const_shape [0, 1] : !shape.shape
-  %0 = shape.to_extent_tensor %cs : tensor<2xindex>
+  %0 = shape.to_extent_tensor %cs : !shape.shape -> tensor<2xindex>
   return %0 : tensor<2xindex>
 }
 
@@ -159,7 +159,7 @@ func @const_size_to_index() -> index {
   // CHECK-NOT: shape.index_cast
   %cs = shape.const_size 123
   // CHECK: constant 123 : index
-  %ci = shape.size_to_index %cs
+  %ci = shape.size_to_index %cs : !shape.size
   return %ci : index
 }
 
@@ -185,7 +185,7 @@ func @const_index_to_size_to_index() -> index {
   %cs0 = shape.index_to_size %ci0
   // CHECK: %[[CI:.*]] = constant 123 : index
   // CHECK-NEXT: return %[[CI]] : index
-  %ci1 = shape.size_to_index %cs0
+  %ci1 = shape.size_to_index %cs0 : !shape.size
   return %ci1 : index
 }
 
@@ -195,7 +195,7 @@ func @const_index_to_size_to_index() -> index {
 // CHECK-LABEL: func @nonfoldable_size_to_index
 func @nonfoldable_size_to_index(%cs : !shape.size) -> index {
   // CHECK: shape.size_to_index
-  %ci = shape.size_to_index %cs
+  %ci = shape.size_to_index %cs : !shape.size
   return %ci : index
 }
 
@@ -403,7 +403,7 @@ func @f(%arg : !shape.shape) -> !shape.shape {
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape
   // CHECK-NEXT: return %[[CS]]
   %0 = shape.const_shape [2, 3, 4] : !shape.shape
-  %1 = shape.any %0, %arg : !shape.shape
+  %1 = "shape.any"(%0, %arg) : (!shape.shape, !shape.shape) -> !shape.shape
   return %1 : !shape.shape
 }
 
@@ -415,7 +415,7 @@ func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
   // CHECK-NEXT: return %[[CS]] : tensor<?xindex>
   %0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
-  %1 = shape.any %0, %arg : tensor<?xindex>
+  %1 = "shape.any"(%0, %arg) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
   return %1 : tensor<?xindex>
 }
 
@@ -424,9 +424,9 @@ func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
 // Folding of any with partially constant operands is not yet implemented.
 // CHECK-LABEL: func @f
 func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
-  // CHECK-NEXT: %[[CS:.*]] = shape.any
+  // CHECK-NEXT: %[[CS:.*]] = "shape.any"
   // CHECK-NEXT: return %[[CS]]
-  %1 = shape.any %arg0, %arg1 : !shape.shape
+  %1 = "shape.any"(%arg0, %arg1) : (!shape.shape, !shape.shape) -> !shape.shape
   return %1 : !shape.shape
 }
 
@@ -619,7 +619,7 @@ func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> index {
 func @index_to_size_to_index(%index : index) -> index {
   // CHECK: return %[[IDX]] : index
   %size = shape.index_to_size %index
-  %result = shape.size_to_index %size
+  %result = shape.size_to_index %size : !shape.size
   return %result : index
 }
 
@@ -630,7 +630,7 @@ func @index_to_size_to_index(%index : index) -> index {
 // CHECK-SAME: (%[[SIZE:.*]]: !shape.size) -> !shape.size
 func @size_to_index_to_size(%size : !shape.size) -> !shape.size {
   // CHECK: return %[[SIZE]] : !shape.size
-  %idx = shape.size_to_index %size
+  %idx = shape.size_to_index %size : !shape.size
   %result = shape.index_to_size %idx
   return %result : !shape.size
 }
index f578260..87af623 100644 (file)
@@ -49,7 +49,7 @@ func @test_shape_num_elements_fixed() {
 func @test_broadcast_fixed() {
   %0 = shape.const_shape [10, 1, 57, 92] : !shape.shape
   %1 = shape.const_shape [4, 57, 92] : !shape.shape
-  %2 = shape.broadcast %0, %1
+  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
@@ -99,7 +99,7 @@ func @test_constraints() {
   %w3 = shape.const_witness false
   %w4 = shape.assuming_all %w0, %w1, %w2, %w3
   shape.assuming %w4 -> !shape.shape {
-    %2 = shape.any %0, %1 : !shape.shape
+    %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
     shape.assuming_yield %2 : !shape.shape
   }
   return
@@ -131,7 +131,7 @@ func @const_size() {
 }
 
 func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> {
-  %0 = shape.to_extent_tensor %arg : tensor<3xindex>
+  %0 = shape.to_extent_tensor %arg : !shape.shape -> tensor<3xindex>
   return %0 : tensor<3xindex>
 }
 
@@ -188,10 +188,10 @@ func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
 func @any() {
   %0 = shape.const_shape [1, 2, 3] : !shape.shape
   %1 = shape.const_shape [4, 5, 6] : !shape.shape
-  %2 = shape.any %0, %1 : !shape.shape
+  %2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
   %4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
-  %5 = shape.any %3, %4 : tensor<?xindex>
+  %5 = "shape.any"(%3, %4) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
   return
 }