}
};
+// Pushes the index_casts that occur before extractions to after the extract.
+// This minimizes type conversion in some cases and enables the extract
+// canonicalizer. This changes:
+//
+// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
+// %extract = tensor.extract %cast[%index] : tensor<1xindex>
+//
+// to the following:
+//
+// %extract = tensor.extract %tensor[%index] : tensor<1xindex>
+// %cast = arith.index_cast %extract : i32 to index
+//
+// to just %element.
+//
+// Consider expanding this to a template and handle all tensor cast operations.
+struct ExtractElementFromIndexCast
+ : public OpRewritePattern<tensor::ExtractOp> {
+ using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+ PatternRewriter &rewriter) const final {
+ Location loc = extract.getLoc();
+ auto indexCast = extract.tensor().getDefiningOp<arith::IndexCastOp>();
+ if (!indexCast)
+ return failure();
+
+ Type elementTy = getElementTypeOrSelf(indexCast.getIn());
+
+ auto newExtract = rewriter.create<tensor::ExtractOp>(
+ loc, elementTy, indexCast.getIn(), extract.indices());
+
+ rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
+ newExtract);
+
+ return success();
+ }
+};
+
} // namespace
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractElementFromTensorFromElements>(context);
+ results
+ .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
+ context);
}
//===----------------------------------------------------------------------===//
%1 = tensor.expand_shape %0 [] : tensor<i32> into tensor<1xi32>
return %1 : tensor<1xi32>
}
+
+// -----
+
+// CHECK-LABEL: func @propogate_index_cast
+func @propogate_index_cast(%arg0: tensor<1xi32>) -> index {
+ // CHECK: %[[IDX:.+]] = arith.constant 0
+ // CHECK: %[[EXT:.+]] = tensor.extract %arg0[%[[IDX]]] : tensor<1xi32>
+ // CHECK: %[[CAST:.+]] = arith.index_cast %[[EXT]]
+ // CHECK: return %[[CAST]] : index
+ %c0 = arith.constant 0 : index
+ %0 = arith.index_cast %arg0 : tensor<1xi32> to tensor<1xindex>
+ %1 = tensor.extract %0[%c0] : tensor<1xindex>
+ return %1 : index
+}