[mlir][Linalg] NFC : Fix check for scalar case handling in LinalgToLoops
authorMaheshRavishankar <ravishankarm@google.com>
Sun, 12 Apr 2020 06:01:40 +0000 (23:01 -0700)
committerMaheshRavishankar <ravishankarm@google.com>
Mon, 13 Apr 2020 20:23:01 +0000 (13:23 -0700)
The invertPermutation method does not return a nullptr anymore, but
rather returns an empty map for the scalar case. Update the check in
LinalgToLoops to reflect this.
Also add test case for generating scalar code.

mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
mlir/test/Dialect/Linalg/loops.mlir

index 9717bb8..6be0bd8 100644 (file)
@@ -652,8 +652,8 @@ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
       linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
   auto maps =
       functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange);
-  auto invertedMap = inversePermutation(concatAffineMaps(maps));
-  if (!invertedMap) {
+  AffineMap invertedMap = inversePermutation(concatAffineMaps(maps));
+  if (invertedMap.isEmpty()) {
     LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
         {}, linalgOp);
     return LinalgLoops();
index a4d3acd..48e4b6e 100644 (file)
@@ -913,3 +913,46 @@ func @generic_const_init(%arg0: memref<?xf32>) {
 //       CHECKPARALLEL: %[[CONST:.*]] = constant 1.000000e+00 : f32
 //       CHECKPARALLEL: loop.parallel (%[[i:.*]])
 //       CHECKPARALLEL:   store %[[CONST]], %[[ARG0]]
+
+#scalar_access = [
+  affine_map<() -> ()>,
+  affine_map<() -> ()>,
+  affine_map<() -> ()>
+]
+#scalar_trait = {
+  args_in = 2,
+  args_out = 1,
+  iterator_types = [],
+  indexing_maps = #scalar_access,
+  library_call = "some_external_fn"
+}
+func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
+{
+  linalg.generic #scalar_trait %arg0, %arg1, %arg2 {
+  ^bb(%a : f32, %b : f32, %c : f32) :
+    %0 = addf %a, %b : f32
+    linalg.yield %0 : f32
+  } : memref<f32>, memref<f32>, memref<f32>
+  return
+}
+// CHECKLOOP-LABEL: @scalar_code
+//  CHECKLOOP-SAME: %[[ARG0]]: memref<f32>
+//  CHECKLOOP-SAME: %[[ARG1]]: memref<f32>
+//  CHECKLOOP-SAME: %[[ARG2]]: memref<f32>
+//   CHECKLOOP-NOT: loop.for
+//   CHECKLOOP-DAG: load %[[ARG0]][]
+//   CHECKLOOP-DAG: load %[[ARG1]][]
+//   CHECKLOOP-DAG: load %[[ARG2]][]
+//       CHECKLOOP: addf
+//       CHECKLOOP: store %{{.*}}, %[[ARG2]][]
+
+// CHECKPARALLEL-LABEL: @scalar_code
+//  CHECKPARALLEL-SAME: %[[ARG0]]: memref<f32>
+//  CHECKPARALLEL-SAME: %[[ARG1]]: memref<f32>
+//  CHECKPARALLEL-SAME: %[[ARG2]]: memref<f32>
+//   CHECKPARALLEL-NOT: loop.for
+//   CHECKPARALLEL-DAG: load %[[ARG0]][]
+//   CHECKPARALLEL-DAG: load %[[ARG1]][]
+//   CHECKPARALLEL-DAG: load %[[ARG2]][]
+//       CHECKPARALLEL: addf
+//       CHECKPARALLEL: store %{{.*}}, %[[ARG2]][]