}
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
- // The tensor operand must be a known constant.
- Attribute tensor = operands.front();
- if (!tensor)
- return {};
// If this is a splat elements attribute, simply return the value. All of the
// elements of a splat attribute are the same.
- if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
- return splatTensor.getSplatValue<Attribute>();
+ if (Attribute tensor = operands.front())
+ if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
+ return splatTensor.getSplatValue<Attribute>();
- // Otherwise, collect the constant indices into the tensor.
+ // Collect the constant indices into the tensor.
SmallVector<uint64_t, 8> indices;
for (Attribute indice : llvm::drop_begin(operands, 1)) {
if (!indice || !indice.isa<IntegerAttr>())
indices.push_back(indice.cast<IntegerAttr>().getInt());
}
+ // Fold extract(from_elements(...)).
+ if (auto fromElementsOp = tensor().getDefiningOp<FromElementsOp>()) {
+ auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
+ auto rank = tensorType.getRank();
+ assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
+ "rank mismatch");
+ int flatIndex = 0;
+ int stride = 1;
+ for (int i = rank - 1; i >= 0; --i) {
+ if (i < rank - 1)
+ stride *= tensorType.getDimSize(i);
+ flatIndex += indices[i] * stride;
+ }
+ // Prevent out of bounds accesses. This can happen in invalid code that will
+ // never execute.
+ if (static_cast<int>(fromElementsOp.elements().size()) <= flatIndex ||
+ flatIndex < 0)
+ return {};
+ return fromElementsOp.elements()[flatIndex];
+ }
+
// If this is an elements attribute, query the value at the given indices.
- auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
- if (elementsAttr && elementsAttr.isValidIndex(indices))
- return elementsAttr.getValues<Attribute>()[indices];
+ if (Attribute tensor = operands.front()) {
+ auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
+ if (elementsAttr && elementsAttr.isValidIndex(indices))
+ return elementsAttr.getValues<Attribute>()[indices];
+ }
+
return {};
}
namespace {
-// Canonicalizes the pattern of the form
-//
-// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
-// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
-//
-// to just %element.
-struct ExtractElementFromTensorFromElements
- : public OpRewritePattern<tensor::ExtractOp> {
- using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::ExtractOp extract,
- PatternRewriter &rewriter) const final {
- auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
- if (!tensorFromElements)
- return failure();
- auto tensorType = tensorFromElements.getType().cast<RankedTensorType>();
- auto rank = tensorType.getRank();
- if (rank == 0) {
- rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
- return success();
- }
- SmallVector<APInt, 3> indices(rank);
- int64_t flatIndex = 0;
- int64_t stride = 1;
- for (int i = rank - 1; i >= 0; --i) {
- APInt index;
- if (!matchPattern(extract.indices()[i], m_ConstantInt(&index)))
- return failure();
- if (i < rank - 1)
- stride *= tensorType.getDimSize(i);
- flatIndex += index.getSExtValue() * stride;
- }
- // Prevent out of bounds accesses. This can happen in invalid code that will
- // never execute.
- if (tensorFromElements->getNumOperands() <= flatIndex || flatIndex < 0)
- return failure();
- rewriter.replaceOp(extract, tensorFromElements.getOperand(flatIndex));
- return success();
- }
-};
-
// Pushes the index_casts that occur before extractions to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer. This changes:
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
- context);
+ results.add<ExtractElementFromIndexCast>(context);
}
//===----------------------------------------------------------------------===//
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
-// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
-// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
-// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
-// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
-// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
+// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
// -----