}
};
+/// Canonicalizes the pattern of the form
+///
+/// %val = tensor_cast %source : : tensor<?xi32> to tensor<2xi32>
+/// %extracted_element = extract_element %val[%c0] : tensor<2xi32>
+///
+/// to
+///
+/// %extracted_element = extract_element %source[%c0] : tensor<?xi32>
+struct ExtractElementFromTensorCast
+ : public OpRewritePattern<ExtractElementOp> {
+ using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractElementOp extract,
+ PatternRewriter &rewriter) const final {
+ auto tensorCast = extract.aggregate().getDefiningOp<TensorCastOp>();
+ if (!tensorCast)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<ExtractElementOp>(extract, tensorCast.source(),
+ extract.getIndices());
+ return success();
+ }
+};
+
} // namespace
void DynamicTensorFromElementsOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ExtractElementFromDynamicTensorFromElements,
- StaticDynamicTensorFromElements>(context);
+ ExtractElementFromTensorCast, StaticDynamicTensorFromElements>(
+ context);
}
//===----------------------------------------------------------------------===//
return %2 : tensor<?x?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @extract_element_from_tensor_cast
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_element_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 {
+ // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+ %c0 = constant 0 : index
+ // CHECK-NOT: tensor_cast
+ %casted = tensor_cast %tensor : tensor<*xf32> to tensor<?xf32>
+ // CHECK-NEXT: extract_element %[[TENSOR]][%[[C0]]]
+ %result = extract_element %casted[%c0] : tensor<?xf32>
+ return %result : f32
+}