[Linalg] Change attribute n_loop_types to iterator
authorJose Ignacio Gomez <jigomez@ucm.es>
Thu, 28 Nov 2019 09:59:22 +0000 (01:59 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 28 Nov 2019 09:59:55 +0000 (01:59 -0800)
This addresses issue tensorflow/mlir#270. Linalg is updated to take the same form
of iterator_types than vector contraction.

Closes tensorflow/mlir#280

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/280 from tetuante:PRissue270 d26d88d090d3765d3b9884bfabdd023143f27287
PiperOrigin-RevId: 282905396

mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
mlir/test/Dialect/Linalg/fusion.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/llvm.mlir
mlir/test/Dialect/Linalg/loops.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/tile.mlir
mlir/test/Dialect/Linalg/transform-patterns.mlir

index 92b325b5943187c534a1428c0c69d25addaf787e..e0070a8da35bbaf03c5439bb67429b585fba8428 100644 (file)
@@ -368,7 +368,7 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
 class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
   let arguments = (ins Variadic<AnyStridedMemRef>:$views,
                    AffineMapArrayAttr:$indexing_maps,
-                   I64ArrayAttr:$n_loop_types,
+                   ArrayAttr:$iterator_types,
                    I64ArrayAttr:$n_views,
                    OptionalAttr<StrAttr>:$doc,
                    OptionalAttr<FlatSymbolRefAttr>:$fun,
@@ -377,7 +377,7 @@ class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
   let extraClassDeclaration = [{
     SmallVector<StringRef, 8> linalgTraitAttrNames() {
       return SmallVector<StringRef, 8>{
-        "doc", "fun", "indexing_maps", "library_call", "n_loop_types", "n_views"
+        "doc", "fun", "indexing_maps", "library_call", "iterator_types", "n_views"
       };
     }
     unsigned getNumInputs() {
@@ -395,26 +395,35 @@ class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
       return val.getZExtValue();
     }
     unsigned getNumParallelLoops() {
-      if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+      if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
         return 0;
-      auto val = n_loop_types().getValue()[0].cast<IntegerAttr>().getValue();
-      assert(val.getSExtValue() >= 0);
-      return val.getZExtValue();
+      unsigned nPar = 0;
+      for (auto ty : iterator_types()) {
+        if (ty.cast<StringAttr>().getValue() == "parallel")
+          nPar++;
+      }
+      return nPar;
     }
     unsigned getNumReductionLoops() {
-      if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+      if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
         return 0;
-      auto val = n_loop_types().getValue()[1].cast<IntegerAttr>().getValue();
-      assert(val.getSExtValue() >= 0);
-      return val.getZExtValue();
-    }
-    unsigned getNumWindowLoops() {
-      if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+      unsigned nRed = 0;
+      for (auto ty : iterator_types()) {
+        if (ty.cast<StringAttr>().getValue() == "reduction")
+          nRed++;
+      }
+      return nRed;
+   }
+   unsigned getNumWindowLoops() {
+      if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
         return 0;
-      auto val = n_loop_types().getValue()[2].cast<IntegerAttr>().getValue();
-      assert(val.getSExtValue() >= 0);
-      return val.getZExtValue();
-    }
+      unsigned nWin = 0;
+      for (auto ty : iterator_types()) {
+        if (ty.cast<StringAttr>().getValue() == "window")
+          nWin++;
+      }
+      return nWin;
+   }
     unsigned getNumLoops() {
       return getNumParallelLoops() + getNumReductionLoops() +
         getNumWindowLoops();
@@ -474,8 +483,9 @@ def GenericOp : GenericOpBase<"generic"> {
         The external library is assumed to be dynamically linked and no strong
         compile-time guarantees are provided. In the absence of such a library
         call, linalg.generic will always lower to loops.
-      - n_loops: a triple of I64Attr representing the number of enclosing
-        [parallel, reduction, window] loops respectively.
+      - iterator_types: an ArrayAttr they type of the enclosing loops; Each element of
+        the list represents and iterator of one of the following types:
+        parallel, reduction, window
       - n_views: a pair of I64Attr representing the number of input (readonly)
         and output (readwrite) views.
 
@@ -498,7 +508,7 @@ def GenericOp : GenericOpBase<"generic"> {
           indexing_maps = #matmul_accesses,
           library_call = "linalg_matmul",
           n_views = [2, 1],
-          n_loop_types = [2, 1, 0]
+          iterator_types = ["parallel", "parallel", "reduction"]
         }
       ```
 
@@ -568,8 +578,9 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
         maps to.  The external library is assumed to be dynamically linked and
         no strong compile-time guarantees are provided. In the absence of such
         a library call, linalg.indexed_generic will always lower to loops.
-      - n_loops: a triple of I64Attr representing the number of enclosing
-        [parallel, reduction, window] loops respectively.
+      - iterator_types: an ArrayAttr they type of the enclosing loops; Each element of
+        the list represents and iterator of one of the following types:
+        parallel, reduction, window
       - n_views: a pair of I64Attr representing the number of input (readonly)
         and output (readwrite) views.
 
@@ -592,7 +603,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
           indexing_maps = #matmul_accesses,
           library_call = "linalg_matmul",
           n_views = [2, 1],
-          n_loop_types = [2, 1, 0]
+          iterator_types = ["parallel", "parallel", "reduction"]
         }
       ```
 
index bd67cc5788c8e31a726029bdecfafbac470b1077..616b30d3d9d7ac1892e901d7dd76c6567a068806 100644 (file)
@@ -306,7 +306,7 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of
 #id_2d = (i, j) -> (i, j)
 #pointwise_2d_trait = {
   indexing_maps = [#id_2d, #id_2d, #id_2d],
-  n_loop_types = [2, 0, 0],
+  iterator_types = ["parallel", "parallel"],
   n_views = [2, 1]
 }
 func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>) {
index 603c67ef0d7a63d433e492466e349af7dc5436f9..9ec345455dced59271b143c58986890618a16fc0 100644 (file)
@@ -60,7 +60,7 @@ func @generic_at_least_2_operands(%arg0: memref<f32>) {
     fun = @foo,
     indexing_maps =  [ () -> (0) ],
     n_views = [1, 1],
-    n_loop_types = [0, 0, 0]
+    iterator_types = []
   } %arg0: memref<f32>
 }
 
@@ -72,7 +72,7 @@ func @generic_exactly_2_views(%arg0: memref<f32>) {
     fun = @foo,
     indexing_maps =  [ () -> (0) ],
     n_views = [1, 1],
-    n_loop_types = [0, 0, 0]
+    iterator_types = []
   } %arg0, %arg0, %arg0: memref<f32>, memref<f32>, memref<f32>
 }
 
@@ -84,7 +84,7 @@ func @generic_undefined_fun(%arg0: memref<f32>) {
     fun = @foo,
     indexing_maps =  [ () -> (0) ],
     n_views = [1, 1],
-    n_loop_types = [0, 0, 0]
+    iterator_types = []
   } %arg0, %arg0: memref<f32>, memref<f32>
 }
 
@@ -98,7 +98,7 @@ func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
     fun = @foo,
     indexing_maps =  [ () -> (0) ],
     n_views = [0, 1],
-    n_loop_types = [0, 0, 0]
+    iterator_types = []
   } %arg0: memref<f32>
 }
 
@@ -112,7 +112,7 @@ func @generic_mismatched_num_returns(%arg0: memref<f32>) {
     fun = @foo,
     indexing_maps =  [ () -> (0) ],
     n_views = [0, 1],
-    n_loop_types = [0, 0, 0]
+    iterator_types = []
   } %arg0: memref<f32>
 }
 
@@ -126,7 +126,7 @@ func @generic_symbol_in_map(%arg0: memref<i32>) {
     fun = @foo,
     indexing_maps =  [ ()[N] -> (0) ],
     n_views = [0, 1],
-    n_loop_types = [1, 0, 0]
+    iterator_types = ["parallel"]
   } %arg0: memref<i32>
 }
 
@@ -140,7 +140,7 @@ func @generic_wrong_dim_in_map(%arg0: memref<i32>) {
     fun = @foo,
     indexing_maps =  [ () -> (0) ],
     n_views = [0, 1],
-    n_loop_types = [1, 0, 0]
+    iterator_types = ["parallel"]
   } %arg0: memref<i32>
 }
 
@@ -154,7 +154,7 @@ func @generic_zero_d_view(%arg0: memref<i32>) {
     fun = @foo,
     indexing_maps =  [ () -> (1) ],
     n_views = [0, 1],
-    n_loop_types = [0, 0, 0]
+    iterator_types = []
   } %arg0: memref<i32>
 }
 
@@ -168,7 +168,7 @@ func @generic_one_d_view(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
     fun = @foo,
     indexing_maps =  [ () -> (0, 0) ],
     n_views = [0, 1],
-    n_loop_types = [0, 0, 0]
+    iterator_types = []
   } %arg0: memref<?xf32, (i)[off]->(off + i)>
 }
 
@@ -185,7 +185,7 @@ func @generic_fun_arg_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>)
     fun = @foo,
     indexing_maps =  [ () -> (0) ],
     n_views = [0, 1],
-    n_loop_types = [0, 0, 0]
+    iterator_types = []
   } %arg0: memref<?xf32, (i)[off]->(off + i)>
 }
 
@@ -202,7 +202,7 @@ func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)
     fun = @foo,
     indexing_maps =  [ () -> (0) ],
     n_views = [0, 1],
-    n_loop_types = [0, 0, 0]
+    iterator_types = []
   } %arg0: memref<?xf32, (i)[off]->(off + i)>
 }
 
@@ -219,7 +219,7 @@ func @generic_singular_maps(%arg0: memref<?xf32, (i)[off]->(off + i)>, %arg1: me
       (i, j) -> (i + j)
     ],
     n_views = [1, 1],
-    n_loop_types = [2, 0, 0]
+    iterator_types = ["parallel","parallel"]
   } %arg0, %arg1: memref<?xf32, (i)[off]->(off + i)>, memref<?xf32, (i)[off]->(off + i)>
 }
 
@@ -234,7 +234,7 @@ func @generic_empty_region(%arg0: memref<f32>) {
   linalg.generic {
     indexing_maps =  [ () -> (0) ],
     n_views = [1, 1],
-    n_loop_types = [0, 0, 0]
+    iterator_types = []
   } %arg0, %arg0 {
     ^bb1:
     ^bb2:
@@ -248,7 +248,7 @@ func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
   linalg.generic {
     indexing_maps =  [ () -> (0) ],
     n_views = [0, 1],
-    n_loop_types = [0, 0, 0]
+    iterator_types = []
   } %arg0 {
     ^bb:
   }: memref<f32>
@@ -261,7 +261,7 @@ func @generic_block_arg_type(%arg0: memref<f32>) {
   linalg.generic {
     indexing_maps =  [ () -> (0) ],
     n_views = [0, 1],
-    n_loop_types = [0, 0, 0]
+    iterator_types = []
   } %arg0 {
     ^bb(%i: i1):
   }: memref<f32>
@@ -274,7 +274,7 @@ func @indexed_generic_block_arg_count(%arg0: memref<f32>) {
   linalg.indexed_generic {
     indexing_maps =  [ (d0) -> (d0) ],
     n_views = [0, 1],
-    n_loop_types = [1, 0, 0]
+    iterator_types = ["parallel"]
   } %arg0 {
     ^bb(%f: f32):
   }: memref<f32>
@@ -287,7 +287,7 @@ func @indexed_generic_block_induction_var_arg_type(%arg0: memref<f32>) {
   linalg.indexed_generic {
     indexing_maps =  [ (d0) -> (d0) ],
     n_views = [0, 1],
-    n_loop_types = [1, 0, 0]
+    iterator_types = ["parallel"]
   } %arg0 {
     ^bb(%i: f64, %f: f32):
   }: memref<f32>
@@ -300,7 +300,7 @@ func @indexed_generic_block_arg_type(%arg0: memref<f32>) {
   linalg.indexed_generic {
     indexing_maps =  [ (d0) -> (d0) ],
     n_views = [0, 1],
-    n_loop_types = [1, 0, 0]
+    iterator_types = ["parallel"]
   } %arg0 {
     ^bb(%i: index, %f: i1):
   }: memref<f32>
@@ -316,7 +316,7 @@ func @indexed_generic_fun_arg_count(%arg0: memref<f32>) {
   linalg.indexed_generic {
     indexing_maps =  [ (d0) -> (d0) ],
     n_views = [0, 1],
-    n_loop_types = [1, 0, 0],
+    iterator_types = ["parallel"],
     fun = @foo
   } %arg0:  memref<f32>
 }
@@ -330,7 +330,7 @@ func @indexed_generic_fun_induction_var_arg_type(%arg0: memref<f32>) {
   // expected-error @+1 {{op expected fun argument 0 to be of IndexType}}
   linalg.indexed_generic {
     n_views = [0, 1],
-    n_loop_types = [1, 0, 0],
+    iterator_types = ["parallel"],
     indexing_maps = [ (i) -> (i) ],
     fun = @foo
   } %arg0 : memref<f32>
@@ -346,7 +346,7 @@ func @indexed_generic_fun_arg_type(%arg0: memref<f32>) {
   linalg.indexed_generic {
     indexing_maps =  [ (d0) -> (d0) ],
     n_views = [0, 1],
-    n_loop_types = [1, 0, 0],
+    iterator_types = ["parallel"],
     fun = @foo
   } %arg0: memref<f32>
 }
@@ -361,7 +361,7 @@ func @indexed_generic_fun_result_count(%arg0: memref<f32>) {
   linalg.indexed_generic {
     indexing_maps =  [ (d0) -> (d0) ],
     n_views = [0, 1],
-    n_loop_types = [1, 0, 0],
+    iterator_types = ["parallel"],
     fun = @foo
   } %arg0: memref<f32>
 }
@@ -377,7 +377,7 @@ func @indexed_generic_fun_result_count(%arg0: memref<i32>) {
   linalg.indexed_generic {
     indexing_maps =  [ (d0) -> (d0) ],
     n_views = [0, 1],
-    n_loop_types = [1, 0, 0],
+    iterator_types = ["parallel"],
     fun = @foo
   } %arg0: memref<i32>
 }
@@ -389,7 +389,7 @@ func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)
   linalg.generic {
     indexing_maps =  [ (i) -> (i) ],
     n_views = [0, 1],
-    n_loop_types = [1, 0, 0]
+    iterator_types = ["parallel"]
   } %arg0 {
     ^bb(%i: f32):
       %0 = constant 0: i1
index dd19d5d82cd35ca029e66b4238b4635c3098eb82..24ce8e36e8a6b77908b50a6ffed55ccd2e43288d 100644 (file)
@@ -139,7 +139,7 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
 ]
 #matmul_trait = {
   n_views = [2, 1],
-  n_loop_types = [2, 1, 0],
+  iterator_types = ["parallel", "parallel", "reduction"],
   indexing_maps = #matmul_accesses,
   library_call = "some_external_function_name_for_vector_outerproduct_matmul"
 }
index 9a1c91d09e0d93bc9a88458019c7bf3707f10a9a..1e1a6270d7e85a736a9547d86d9ec6b9c055c848 100644 (file)
@@ -223,7 +223,7 @@ func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
 ]
 #trait = {
   n_views = [1, 2],
-  n_loop_types = [3, 0, 0],
+  iterator_types = ["parallel", "parallel", "parallel"],
   indexing_maps = #accesses,
   fun = @foo,
   library_call = "some_external_function_name_1",
@@ -248,7 +248,7 @@ func @generic_function(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1
 
 #trait2 = {
   n_views = [1, 2],
-  n_loop_types = [3, 0, 0],
+  iterator_types = ["parallel", "parallel", "parallel"],
   indexing_maps = #accesses,
   library_call = "some_external_function_name_2",
   doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
@@ -281,7 +281,7 @@ func @indexed_foo(%i: index, %j: index, %k: index, %0: f32, %1: f32, %2: f32) ->
 }
 #trait3 = {
   n_views = [1, 2],
-  n_loop_types = [3, 0, 0],
+  iterator_types = ["parallel", "parallel", "parallel"],
   indexing_maps = #accesses,
   fun = @indexed_foo,
   library_call = "some_external_function_name_1",
@@ -311,7 +311,7 @@ func @indexed_generic_function(
 
 #trait4 = {
   n_views = [1, 2],
-  n_loop_types = [3, 0, 0],
+  iterator_types = ["parallel", "parallel", "parallel"],
   indexing_maps = #accesses,
   library_call = "some_external_function_name_2",
   doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"
index 30704b50f44262f946bbd0d778040dc4caf1a604..b53e674368fa3f33997f4ca8184e489c0f34fdf7 100644 (file)
@@ -122,7 +122,7 @@ func @conv_view6(%arg0: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?,
 #trait = {
   indexing_maps = #accesses,
   n_views = [1, 1],
-  n_loop_types = [3, 0, 0],
+  iterator_types = ["parallel", "parallel", "parallel"],
   fun = @foo,
   library_call = "some_external_function_name_1"
 }
@@ -136,12 +136,12 @@ func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, %ar
 }
 // CHECK-LABEL: func @foo
 // CHECK-LABEL: func @generic
-//       CHECK:   linalg.generic {fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], library_call = "some_external_function_name_1", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
+//       CHECK:   linalg.generic {fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1", n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
 
 #trait2 = {
   indexing_maps = #accesses,
   n_views = [1, 1],
-  n_loop_types = [3, 0, 0],
+  iterator_types = ["parallel", "parallel", "parallel"],
   library_call = "some_external_function_name_2"
 }
 func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
@@ -152,7 +152,7 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
   return
 }
 // CHECK-LABEL: func @generic_region
-//       CHECK:   linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], library_call = "some_external_function_name_2", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {
+//       CHECK:   linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2", n_views = [1, 1]} %{{.*}}, %{{.*}} {
 //       CHECK:    ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):    // no predecessors
 //       CHECK:      linalg.yield %{{.*}} : f32
 //       CHECK:    } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
@@ -166,7 +166,7 @@ func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?,
   return
 }
 // CHECK-LABEL: func @indexed_generic
-//       CHECK:   linalg.indexed_generic {indexing_maps = [#{{.*}}, #{{.*}}], library_call = "some_external_function_name_2", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {
+//       CHECK:   linalg.indexed_generic {indexing_maps = [#{{.*}}, #{{.*}}],  iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2", n_views = [1, 1]} %{{.*}}, %{{.*}} {
 //       CHECK:    ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
 //       CHECK:      linalg.yield %{{.*}} : f32
 //       CHECK:    } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
index 47a35529b1fa2c77e835fab75404207f2025e938..4040063d77aaf2c32366220386f31bb448013b8d 100644 (file)
@@ -214,7 +214,7 @@ func @fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: f32) {
 #id_2d = (i, j) -> (i, j)
 #pointwise_2d_trait = {
   indexing_maps = [#id_2d, #id_2d, #id_2d],
-  n_loop_types = [2, 0, 0],
+  iterator_types = ["parallel", "parallel"],
   n_views = [2, 1]
 }
 
index d94342f52ca81e4864e25fc37dbb72891c2c8901..d3686714ee4fefe07411745d5cd0b78e2c53f669 100644 (file)
@@ -83,7 +83,7 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
     (i, j) -> (i, j)
   ],
   n_views = [1, 1],
-  n_loop_types = [2, 0, 0]
+  iterator_types = ["parallel", "parallel"]
 }
 func @fusion_test(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                   %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,