From 3b2f26ab05a80ffb3fcee62fd690da2e6d39c4a3 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Sat, 11 Apr 2020 23:01:40 -0700 Subject: [PATCH] [mlir][Linalg] NFC : Fix check for scalar case handling in LinalgToLoops The invertPermutation method does not return a nullptr anymore, but rather returns an empty map for the scalar case. Update the check in LinalgToLoops to reflect this. Also add test case for generating scalar code. --- .../Dialect/Linalg/Transforms/LinalgToLoops.cpp | 4 +- mlir/test/Dialect/Linalg/loops.mlir | 43 ++++++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp index 9717bb8..6be0bd8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -652,8 +652,8 @@ LinalgOpToLoopsImpl::doit(Operation *op, linalgOp.indexing_maps().template getAsRange(); auto maps = functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange); - auto invertedMap = inversePermutation(concatAffineMaps(maps)); - if (!invertedMap) { + AffineMap invertedMap = inversePermutation(concatAffineMaps(maps)); + if (invertedMap.isEmpty()) { LinalgScopedEmitter::emitScalarImplementation( {}, linalgOp); return LinalgLoops(); diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir index a4d3acd..48e4b6e 100644 --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -913,3 +913,46 @@ func @generic_const_init(%arg0: memref) { // CHECKPARALLEL: %[[CONST:.*]] = constant 1.000000e+00 : f32 // CHECKPARALLEL: loop.parallel (%[[i:.*]]) // CHECKPARALLEL: store %[[CONST]], %[[ARG0]] + +#scalar_access = [ + affine_map<() -> ()>, + affine_map<() -> ()>, + affine_map<() -> ()> +] +#scalar_trait = { + args_in = 2, + args_out = 1, + iterator_types = [], + indexing_maps = #scalar_access, + library_call = "some_external_fn" +} +func @scalar_code(%arg0: memref, %arg1 : memref, %arg2 : memref) +{ + linalg.generic #scalar_trait %arg0, %arg1, %arg2 { + ^bb(%a : f32, %b : f32, %c : f32) : + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } : memref, memref, memref + return +} +// CHECKLOOP-LABEL: @scalar_code +// CHECKLOOP-SAME: %[[ARG0]]: memref +// CHECKLOOP-SAME: %[[ARG1]]: memref +// CHECKLOOP-SAME: %[[ARG2]]: memref +// CHECKLOOP-NOT: loop.for +// CHECKLOOP-DAG: load %[[ARG0]][] +// CHECKLOOP-DAG: load %[[ARG1]][] +// CHECKLOOP-DAG: load %[[ARG2]][] +// CHECKLOOP: addf +// CHECKLOOP: store %{{.*}}, %[[ARG2]][] + +// CHECKPARALLEL-LABEL: @scalar_code +// CHECKPARALLEL-SAME: %[[ARG0]]: memref +// CHECKPARALLEL-SAME: %[[ARG1]]: memref +// CHECKPARALLEL-SAME: %[[ARG2]]: memref +// CHECKPARALLEL-NOT: loop.for +// CHECKPARALLEL-DAG: load %[[ARG0]][] +// CHECKPARALLEL-DAG: load %[[ARG1]][] +// CHECKPARALLEL-DAG: load %[[ARG2]][] +// CHECKPARALLEL: addf +// CHECKPARALLEL: store %{{.*}}, %[[ARG2]][] -- 2.7.4