From 7b52aeadfa38c8a1fc0e97066f50900f1efafd42 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 10 May 2021 23:19:59 +0200 Subject: [PATCH] [mlir][Tensor] Add folding for tensor.from_elements This trivially folds into a constant when all operands are constant. Differential Revision: https://reviews.llvm.org/D102199 --- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 1 + mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 6 ++++++ mlir/test/Dialect/Linalg/detensorize_trivial.mlir | 4 +--- mlir/test/Dialect/Tensor/canonicalize.mlir | 12 ++++++++++++ 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index a0e4738..17141da 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -137,6 +137,7 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [ ]; let hasCanonicalizer = 1; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 1beb458..2c9680a 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -238,6 +238,12 @@ void FromElementsOp::build(OpBuilder &builder, OperationState &result, build(builder, result, elements.front().getType(), elements); } +OpFoldResult FromElementsOp::fold(ArrayRef operands) { + if (!llvm::is_contained(operands, nullptr)) + return DenseElementsAttr::get(getType(), operands); + return {}; +} + namespace { // Canonicalizes the pattern of the form diff --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir index 6fcd056..4e0b8fd 100644 --- a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir @@ -35,9 +35,7 @@ func @main(%farg0 : tensor) -> (tensor) attributes {} { // DET-ALL-NEXT: } // DET-CF-LABEL: func @main(%{{.*}}: tensor) -// DET-CF-NEXT: constant 10 : i32 -// DET-CF-NEXT: tensor.from_elements %{{.*}} -// DET-CF-NEXT: linalg.tensor_reshape %{{.*}} +// DET-CF-NEXT: constant dense<10> : tensor // DET-CF-NEXT: linalg.init_tensor [] : tensor // DET-CF-NEXT: linalg.generic // DET-CF-NEXT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i1) diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index be22f323..478117b 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -238,3 +238,15 @@ func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xi // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex> return %0 : tensor<3x?x?x7x?xindex> } + +// ----- + +// CHECK-LABEL: @from_elements.constant +func @from_elements.constant() -> tensor<3xindex> { + // CHECK: %[[CST:.*]] = constant dense<[1, 2, 1]> : tensor<3xindex> + // CHECK: return %[[CST]] + %c1 = constant 1 : index + %c2 = constant 2 : index + %tensor = tensor.from_elements %c1, %c2, %c1 : tensor<3xindex> + return %tensor : tensor<3xindex> +} -- 2.7.4