[mlir][std] Canonicalize extract_element(tensor_cast).
authorStephan Herhut <herhut@google.com>
Tue, 17 Nov 2020 12:59:26 +0000 (13:59 +0100)
committerStephan Herhut <herhut@google.com>
Tue, 17 Nov 2020 13:41:39 +0000 (14:41 +0100)
Canonicalize extract_element(tensor_cast(v)) to just extract_element(v).

Differential Revision: https://reviews.llvm.org/D91621

mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Transforms/canonicalize.mlir

index 4698893..d2a2ca1 100644 (file)
@@ -1945,12 +1945,37 @@ struct ExtractElementFromDynamicTensorFromElements
   }
 };
 
+/// Canonicalizes the pattern of the form
+///
+/// %val = tensor_cast %source : : tensor<?xi32> to tensor<2xi32>
+/// %extracted_element = extract_element %val[%c0] : tensor<2xi32>
+///
+/// to
+///
+/// %extracted_element = extract_element %source[%c0] : tensor<?xi32>
+struct ExtractElementFromTensorCast
+    : public OpRewritePattern<ExtractElementOp> {
+  using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractElementOp extract,
+                                PatternRewriter &rewriter) const final {
+    auto tensorCast = extract.aggregate().getDefiningOp<TensorCastOp>();
+    if (!tensorCast)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ExtractElementOp>(extract, tensorCast.source(),
+                                                  extract.getIndices());
+    return success();
+  }
+};
+
 } // namespace
 
 void DynamicTensorFromElementsOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<ExtractElementFromDynamicTensorFromElements,
-                 StaticDynamicTensorFromElements>(context);
+                 ExtractElementFromTensorCast, StaticDynamicTensorFromElements>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
index 08f3ac7..4a74f54 100644 (file)
@@ -1202,3 +1202,17 @@ func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
 
   return %2 : tensor<?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_element_from_tensor_cast
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_element_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 {
+  // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+  %c0 = constant 0 : index
+  // CHECK-NOT: tensor_cast
+  %casted = tensor_cast %tensor : tensor<*xf32> to tensor<?xf32>
+  // CHECK-NEXT: extract_element %[[TENSOR]][%[[C0]]]
+  %result = extract_element %casted[%c0] : tensor<?xf32>
+  return %result : f32
+}