[mlir] NFC - Extend inferResultType API for SubViewOp and SubTensorOp
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 10 Feb 2021 22:53:33 +0000 (22:53 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 10 Feb 2021 22:55:28 +0000 (22:55 +0000)
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp

index b9545e2..3d6eee4 100644 (file)
@@ -3001,6 +3001,10 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
                                 ArrayRef<int64_t> staticOffsets,
                                 ArrayRef<int64_t> staticSizes,
                                 ArrayRef<int64_t> staticStrides);
+    static Type inferResultType(MemRefType sourceMemRefType,
+                                ArrayRef<OpFoldResult> staticOffsets,
+                                ArrayRef<OpFoldResult> staticSizes,
+                                ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -3123,6 +3127,10 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
                                 ArrayRef<int64_t> staticOffsets,
                                 ArrayRef<int64_t> staticSizes,
                                 ArrayRef<int64_t> staticStrides);
+    static Type inferResultType(RankedTensorType sourceRankedTensorType,
+                                ArrayRef<OpFoldResult> staticOffsets,
+                                ArrayRef<OpFoldResult> staticSizes,
+                                ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
index ca2e273..9af00be 100644 (file)
@@ -2831,6 +2831,23 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
       sourceMemRefType.getMemorySpace());
 }
 
+Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
+                                ArrayRef<OpFoldResult> leadingStaticOffsets,
+                                ArrayRef<OpFoldResult> leadingStaticSizes,
+                                ArrayRef<OpFoldResult> leadingStaticStrides) {
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
+                                    staticSizes, staticStrides)
+      .cast<MemRefType>();
+}
+
 // Build a SubViewOp with mixed static and dynamic entries and custom result
 // type. If the type passed is nullptr, it is inferred.
 void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
@@ -3386,6 +3403,23 @@ Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
                                sourceRankedTensorType.getElementType());
 }
 
+Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
+                                  ArrayRef<OpFoldResult> leadingStaticOffsets,
+                                  ArrayRef<OpFoldResult> leadingStaticSizes,
+                                  ArrayRef<OpFoldResult> leadingStaticStrides) {
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
+                             staticOffsets, ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
+                             staticStrides, ShapedType::kDynamicStrideOrOffset);
+  return SubTensorOp::inferResultType(sourceRankedTensorType, staticOffsets,
+                                      staticSizes, staticStrides)
+      .cast<RankedTensorType>();
+}
+
 // Build a SubTensorOp with mixed static and dynamic entries and custom result
 // type. If the type passed is nullptr, it is inferred.
 void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result,