From c37ed7762e9f473e9497c52c7669a025965651f7 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 23 Aug 2022 11:38:54 +0200 Subject: [PATCH] [tensor][bufferize] Use affine.apply instead of arith.addi in PadOp lowering Affine exprs compose better than arith ops. Differential Revision: https://reviews.llvm.org/D132456 --- .../Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp | 10 +++++++--- mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt | 1 + mlir/test/Dialect/Tensor/bufferize.mlir | 7 +++---- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 + 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 881237c..3600524 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -790,9 +791,12 @@ struct PadOpInterface Value srcDim = rewriter.create(loc, padOp.getSource(), i); Value lowPad = toValue(mixedLowPad[i]); Value highPad = toValue(mixedHighPad[i]); - Value s1 = rewriter.create(loc, lowPad, highPad); - Value s2 = rewriter.create(loc, s1, srcDim); - dynamicSizes.push_back(s2); + AffineExpr s0, s1, s2; + bindSymbols(op->getContext(), s0, s1, s2); + AffineExpr sumExpr = s0 + s1 + s2; + Value sum = rewriter.create( + loc, sumExpr, ValueRange{srcDim, lowPad, highPad}); + dynamicSizes.push_back(sum); } // Create tensor::GenerateOp. diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt index 8479c43..66e4cc9 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRTensorTransforms MLIRTensorTransformsIncGen LINK_LIBS PUBLIC + MLIRAffineDialect MLIRArithmeticDialect MLIRBufferizationDialect MLIRBufferizationTransforms diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index 8d53585c..7cde99d 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -547,6 +547,7 @@ func.func @tensor.reshape(%t1: tensor) -> tensor<2x2x5xf32> { // ----- +// CHECK: #[[$sum_map:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)> // CHECK-LABEL: func @tensor.pad( // CHECK-SAME: %[[t1:.*]]: tensor, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index func.func @tensor.pad(%t1: tensor, %l2: index, %h1: index, @@ -557,10 +558,8 @@ func.func @tensor.pad(%t1: tensor, %l2: index, %h1: index, // CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]] // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]] - // CHECK-DAG: %[[pad0:.*]] = arith.addi %[[c5]], %[[h1]] - // CHECK-DAG: %[[size0:.*]] = arith.addi %[[pad0]], %[[dim0]] - // CHECK-DAG: %[[pad1:.*]] = arith.addi %[[l2]], %[[h2]] - // CHECK-DAG: %[[size1:.*]] = arith.addi %[[pad1]], %[[dim1]] + // CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map]]()[%[[dim0]], %[[c5]], %[[h1]]] + // CHECK-DAG: %[[size1:.*]] = affine.apply #[[$sum_map]]()[%[[dim1]], %[[l2]], %[[h2]]] // CHECK: %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) {{.*}} : memref // CHECK: scf.parallel ({{.*}}) = (%[[c0]], %[[c0]]) to (%[[size0]], %[[size1]]) step (%[[c1]], %[[c1]]) { // CHECK: memref.store diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index f6f9539..7d45b2e 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5048,6 +5048,7 @@ cc_library( ], includes = ["include"], deps = [ + ":AffineDialect", ":ArithmeticDialect", ":BufferizationDialect", ":BufferizationTransforms", -- 2.7.4