[MLIR][Shape] Add custom assembly format for `shape.any`
authorFrederik Gossen <frgossen@google.com>
Fri, 14 Aug 2020 08:33:58 +0000 (08:33 +0000)
committerFrederik Gossen <frgossen@google.com>
Fri, 14 Aug 2020 09:15:15 +0000 (09:15 +0000)
Add custom assembly format for `shape.any` with variadic operands.

Differential Revision: https://reviews.llvm.org/D85306

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir

index ac07743..2e8f032 100644 (file)
@@ -581,6 +581,8 @@ def Shape_AnyOp : Shape_Op<"any", [Commutative,
   let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
   let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
+  let assemblyFormat = "$inputs attr-dict `:` type($inputs) `->` type($result)";
+
   let hasFolder = 1;
 }
 
index 8a38b42..670d207 100644 (file)
@@ -428,7 +428,7 @@ func @f(%arg : !shape.shape) -> !shape.shape {
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape
   // CHECK-NEXT: return %[[CS]]
   %0 = shape.const_shape [2, 3, 4] : !shape.shape
-  %1 = "shape.any"(%0, %arg) : (!shape.shape, !shape.shape) -> !shape.shape
+  %1 = shape.any %0, %arg : !shape.shape, !shape.shape -> !shape.shape
   return %1 : !shape.shape
 }
 
@@ -440,7 +440,7 @@ func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
   // CHECK-NEXT: return %[[CS]] : tensor<?xindex>
   %0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
-  %1 = "shape.any"(%0, %arg) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
+  %1 = shape.any %0, %arg : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
   return %1 : tensor<?xindex>
 }
 
@@ -449,9 +449,9 @@ func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
 // Folding of any with partially constant operands is not yet implemented.
 // CHECK-LABEL: func @f
 func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
-  // CHECK-NEXT: %[[CS:.*]] = "shape.any"
+  // CHECK-NEXT: %[[CS:.*]] = shape.any
   // CHECK-NEXT: return %[[CS]]
-  %1 = "shape.any"(%arg0, %arg1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %1 = shape.any %arg0, %arg1 : !shape.shape, !shape.shape -> !shape.shape
   return %1 : !shape.shape
 }
 
index 172835a..c13f89d 100644 (file)
@@ -235,3 +235,26 @@ func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !sha
   %2 = call @shape_equal_shapes(%a, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape
   return %2 : !shape.shape
 }
+
+func @any_on_shape(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
+    -> !shape.shape {
+  %result = shape.any %a, %b, %c
+      : !shape.shape, !shape.shape, !shape.shape -> !shape.shape
+  return %result : !shape.shape
+}
+
+func @any_on_mixed(%a : tensor<?xindex>,
+                   %b : tensor<?xindex>,
+                   %c : !shape.shape) -> !shape.shape {
+  %result = shape.any %a, %b, %c
+      : tensor<?xindex>, tensor<?xindex>, !shape.shape -> !shape.shape
+  return %result : !shape.shape
+}
+
+func @any_on_extent_tensors(%a : tensor<?xindex>,
+                            %b : tensor<?xindex>,
+                            %c : tensor<?xindex>) -> tensor<?xindex> {
+  %result = shape.any %a, %b, %c
+      : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  return %result : tensor<?xindex>
+}