From 8544523dcb6249bf3055c3a6ab0cb48586999a30 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm@google.com>
Date: Wed, 20 Apr 2022 23:42:33 +0900
Subject: [PATCH] [mlir][tensor] Promote extract(from_elements(...)) to folding pattern

Differential Revision: https://reviews.llvm.org/D123617
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp           | 86 ++++++++--------------
 .../resolve-shaped-type-result-dims.mlir           |  6 +-
 2 files changed, 33 insertions(+), 59 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f1181e6..1f9f977 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -361,16 +361,13 @@ LogicalResult ExtractOp::verify() {
 }
 
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
-  // The tensor operand must be a known constant.
-  Attribute tensor = operands.front();
-  if (!tensor)
-    return {};
   // If this is a splat elements attribute, simply return the value. All of the
   // elements of a splat attribute are the same.
-  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
-    return splatTensor.getSplatValue<Attribute>();
+  if (Attribute tensor = operands.front())
+    if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
+      return splatTensor.getSplatValue<Attribute>();
 
-  // Otherwise, collect the constant indices into the tensor.
+  // Collect the constant indices into the tensor.
   SmallVector<uint64_t, 8> indices;
   for (Attribute indice : llvm::drop_begin(operands, 1)) {
     if (!indice || !indice.isa<IntegerAttr>())
@@ -378,10 +375,34 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
     indices.push_back(indice.cast<IntegerAttr>().getInt());
   }
 
+  // Fold extract(from_elements(...)).
+  if (auto fromElementsOp = tensor().getDefiningOp<FromElementsOp>()) {
+    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
+    auto rank = tensorType.getRank();
+    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
+           "rank mismatch");
+    int flatIndex = 0;
+    int stride = 1;
+    for (int i = rank - 1; i >= 0; --i) {
+      if (i < rank - 1)
+        stride *= tensorType.getDimSize(i);
+      flatIndex += indices[i] * stride;
+    }
+    // Prevent out of bounds accesses. This can happen in invalid code that will
+    // never execute.
+    if (static_cast<int>(fromElementsOp.elements().size()) <= flatIndex ||
+        flatIndex < 0)
+      return {};
+    return fromElementsOp.elements()[flatIndex];
+  }
+
   // If this is an elements attribute, query the value at the given indices.
-  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
-  if (elementsAttr && elementsAttr.isValidIndex(indices))
-    return elementsAttr.getValues<Attribute>()[indices];
+  if (Attribute tensor = operands.front()) {
+    auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
+    if (elementsAttr && elementsAttr.isValidIndex(indices))
+      return elementsAttr.getValues<Attribute>()[indices];
+  }
+
 
   return {};
 }
@@ -411,47 +432,6 @@ OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
 
 namespace {
 
-// Canonicalizes the pattern of the form
-//
-// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
-// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
-//
-// to just %element.
-struct ExtractElementFromTensorFromElements
-    : public OpRewritePattern<tensor::ExtractOp> {
-  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
-                                PatternRewriter &rewriter) const final {
-    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
-    if (!tensorFromElements)
-      return failure();
-    auto tensorType = tensorFromElements.getType().cast<RankedTensorType>();
-    auto rank = tensorType.getRank();
-    if (rank == 0) {
-      rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
-      return success();
-    }
-    SmallVector<APInt, 3> indices(rank);
-    int64_t flatIndex = 0;
-    int64_t stride = 1;
-    for (int i = rank - 1; i >= 0; --i) {
-      APInt index;
-      if (!matchPattern(extract.indices()[i], m_ConstantInt(&index)))
-        return failure();
-      if (i < rank - 1)
-        stride *= tensorType.getDimSize(i);
-      flatIndex += index.getSExtValue() * stride;
-    }
-    // Prevent out of bounds accesses. This can happen in invalid code that will
-    // never execute.
-    if (tensorFromElements->getNumOperands() <= flatIndex || flatIndex < 0)
-      return failure();
-    rewriter.replaceOp(extract, tensorFromElements.getOperand(flatIndex));
-    return success();
-  }
-};
-
 // Pushes the index_casts that occur before extractions to after the extract.
 // This minimizes type conversion in some cases and enables the extract
 // canonicalizer. This changes:
@@ -494,9 +474,7 @@ struct ExtractElementFromIndexCast
 
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results
-      .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
-          context);
+  results.add<ExtractElementFromIndexCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index ced1ca5..a3bd3d0 100644
--- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -22,12 +22,8 @@ func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
 // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
 // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
 // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
-// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
-// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
 // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
-// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
-// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
-// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
+// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
 
 // -----
 
-- 
2.7.4