From 5885c85fc6a78a65dd9219e1921a307f7c963dd5 Mon Sep 17 00:00:00 2001
From: Lorenzo Chelini
Date: Tue, 28 Feb 2023 15:31:18 +0100
Subject: [PATCH] [MLIR][Linalg] Fix propagation for rank-zero tensor

`isScalar` returns true only if the operand is non-shaped, but we also
need to handle rank-zero tensors.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D144989
---
 .../Linalg/Transforms/DataLayoutPropagation.cpp  |  2 +-
 .../Dialect/Linalg/data-layout-propagation.mlir  | 27 ++++++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index bc90980..2d1b544 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -157,7 +157,7 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
   llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
   SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
-  if (genericOp.isScalar(opOperand))
+  if (genericOp.isScalar(opOperand) || exprs.size() == 0)
     return std::make_tuple(opOperand->get(),
                            AffineMap::get(numLoops, 0, exprs, b.getContext()));
 
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 7d54e28..f9d6275 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -652,3 +652,30 @@ func.func @would_break_dominance(%arg0: tensor<128x256xi32>) -> tensor<4x16x16x3
 // CHECK-NEXT: %{{.+}} = tensor.pack %[[GEN]]
 // CHECK-SAME:  inner_dims_pos = [1, 0] inner_tiles = [16, 32]
 // CHECK-SAME:  into %[[ALLOC]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> ()>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @scalar_tensor(%arg0 : tensor<f32>) -> tensor<1x32x7x7x32xf32> {
+  %empty_gen = tensor.empty() : tensor<1x7x7x1024xf32>
+  %gen = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<f32>) outs(%empty_gen : tensor<1x7x7x1024xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<1x7x7x1024xf32>
+  %empty_pack = tensor.empty() : tensor<1x32x7x7x32xf32>
+  %pack = tensor.pack %gen outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %empty_pack : tensor<1x7x7x1024xf32> -> tensor<1x32x7x7x32xf32>
+  return %pack : tensor<1x32x7x7x32xf32>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> ()>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func.func @scalar_tensor
+// CHECK-SAME: %[[ARG0:.+]]: tensor<f32>)
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x32x7x7x32xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[EMPTY]]
-- 
2.7.4
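
Background on why the extra check is needed: a rank-zero tensor such as
tensor<f32> is still a ShapedType, so `isScalar` (which roughly tests for a
non-shaped operand type) returns false for it; yet its indexing map
`(d0, d1, d2, d3) -> ()` has zero results, so it can be forwarded unpacked
exactly like a scalar. A minimal sketch of the two predicates, assuming
standard MLIR C++ APIs; the helper names are illustrative, not upstream
functions:

  #include "mlir/IR/BuiltinTypes.h"

  using namespace mlir;

  // Roughly what GenericOp::isScalar() tests: the operand type is not a
  // shaped type (e.g. a plain f32 or index value).
  static bool isNonShapedOperand(Type type) {
    return !llvm::isa<ShapedType>(type);
  }

  // The case the patch adds: a rank-zero tensor (tensor<f32>) is shaped, so
  // the predicate above misses it, but it has no dimensions, and its
  // indexing map has no result expressions -- matching exprs.size() == 0
  // in the fix.
  static bool hasRankZeroShape(Type type) {
    auto shaped = llvm::dyn_cast<ShapedType>(type);
    return shaped && shaped.hasRank() && shaped.getRank() == 0;
  }

With the fix, a rank-zero tensor operand takes the same early-return path as
a true scalar, so the propagation never tries to pack a value that has no
dimensions to tile.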