[MLIR] Make the implementations for getMixedOffsets/Sizes/Strides independent of...
authorFrederik Gossen <frgossen@google.com>
Wed, 3 Aug 2022 22:23:55 +0000 (18:23 -0400)
committerFrederik Gossen <frgossen@google.com>
Thu, 4 Aug 2022 15:58:15 +0000 (11:58 -0400)
The functions are effectively independent of the interface already, however, they take it as an argument for no reason.
The current state complicates reuse outside of MLIR.

Differential Revision: https://reviews.llvm.org/D131120

mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Interfaces/ViewLikeInterface.cpp

index 2b735f1..e590be2 100644 (file)
@@ -20,6 +20,7 @@
 #include "mlir/IR/OpImplementation.h"
 
 namespace mlir {
+
 /// Auxiliary range data structure to unpack the offset, size and stride
 /// operands into a list of triples. Such a list can be more convenient to
 /// manipulate.
@@ -29,31 +30,19 @@ struct Range {
   OpFoldResult stride;
 };
 
-class OffsetSizeAndStrideOpInterface;
-
 /// Return a vector of OpFoldResults given the special value
 /// that indicates whether of the value is dynamic or not.
 SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,
                                             ValueRange dynamicValues,
                                             int64_t dynamicValueIndicator);
 
-/// Return a vector of all the static or dynamic offsets of the op from provided
-/// external static and dynamic offsets.
-SmallVector<OpFoldResult, 4> getMixedOffsets(OffsetSizeAndStrideOpInterface op,
-                                             ArrayAttr staticOffsets,
-                                             ValueRange offsets);
+/// Return a vector of all the static and dynamic offsets/strides.
+SmallVector<OpFoldResult, 4> getMixedStridesOrOffsets(ArrayAttr staticValues,
+                                                      ValueRange dynamicValues);
 
-/// Return a vector of all the static or dynamic sizes of the op from provided
-/// external static and dynamic sizes.
-SmallVector<OpFoldResult, 4> getMixedSizes(OffsetSizeAndStrideOpInterface op,
-                                           ArrayAttr staticSizes,
-                                           ValueRange sizes);
-
-/// Return a vector of all the static or dynamic strides of the op from provided
-/// external static and dynamic strides.
-SmallVector<OpFoldResult, 4> getMixedStrides(OffsetSizeAndStrideOpInterface op,
-                                             ArrayAttr staticStrides,
-                                             ValueRange strides);
+/// Return a vector of all the static and dynamic sizes.
+SmallVector<OpFoldResult, 4> getMixedSizes(ArrayAttr staticValues,
+                                           ValueRange dynamicValues);
 
 /// Decompose a vector of mixed static or dynamic values into the corresponding
 /// pair of arrays. This is the inverse function of `getMixedValues`.
@@ -62,9 +51,9 @@ decomposeMixedValues(Builder &b,
                      const SmallVectorImpl<OpFoldResult> &mixedValues,
                      const int64_t dynamicValueIndicator);
 
-/// Decompose a vector of mixed static or dynamic strides/offsets into the
+/// Decompose a vector of mixed static and dynamic strides/offsets into the
 /// corresponding pair of arrays. This is the inverse function of
-/// `getMixedStrides` and `getMixedOffsets`.
+/// `getMixedStridesOrOffsets`.
 std::pair<ArrayAttr, SmallVector<Value>> decomposeMixedStridesOrOffsets(
     OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues);
 
@@ -75,12 +64,16 @@ std::pair<ArrayAttr, SmallVector<Value>>
 decomposeMixedSizes(OpBuilder &b,
                     const SmallVectorImpl<OpFoldResult> &mixedValues);
 
+class OffsetSizeAndStrideOpInterface;
+
 namespace detail {
+
 LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
 
 bool sameOffsetsSizesAndStrides(
     OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
     llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp);
+
 } // namespace detail
 } // namespace mlir
 
index ed4ba55..0c6bc86 100644 (file)
@@ -165,8 +165,8 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return ::mlir::getMixedOffsets($_op, $_op.static_offsets(),
-                                       $_op.offsets());
+        return ::mlir::getMixedStridesOrOffsets($_op.static_offsets(),
+                                                $_op.offsets());
       }]
     >,
     InterfaceMethod<
@@ -178,7 +178,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return ::mlir::getMixedSizes($_op, $_op.static_sizes(), $_op.sizes());
+        return ::mlir::getMixedSizes($_op.static_sizes(), $_op.sizes());
       }]
     >,
     InterfaceMethod<
@@ -190,8 +190,8 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return ::mlir::getMixedStrides($_op, $_op.static_strides(),
-                                       $_op.strides());
+        return ::mlir::getMixedStridesOrOffsets($_op.static_strides(),
+                                                $_op.strides());
       }]
     >,
 
@@ -237,30 +237,6 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
         return ::mlir::ShapedType::isDynamicStrideOrOffset(v.getSExtValue());
       }]
     >,
-    StaticInterfaceMethod<
-      /*desc=*/"Return constant that indicates the offset is dynamic",
-      /*retTy=*/"int64_t",
-      /*methodName=*/"getDynamicOffsetIndicator",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }]
-    >,
-    StaticInterfaceMethod<
-      /*desc=*/"Return constant that indicates the size is dynamic",
-      /*retTy=*/"int64_t",
-      /*methodName=*/"getDynamicSizeIndicator",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicSize; }]
-    >,
-    StaticInterfaceMethod<
-      /*desc=*/"Return constant that indicates the stride is dynamic",
-      /*retTy=*/"int64_t",
-      /*methodName=*/"getDynamicStrideIndicator",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImpl=*/[{ return ::mlir::ShapedType::kDynamicStrideOrOffset; }]
-    >,
     InterfaceMethod<
       /*desc=*/[{
         Assert the offset `idx` is a static constant and return its value.
index be35698..2d9cce6 100644 (file)
@@ -181,7 +181,7 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
 
 SmallVector<OpFoldResult, 4>
 mlir::getMixedValues(ArrayAttr staticValues, ValueRange dynamicValues,
-                     int64_t dynamicValueIndicator) {
+                     const int64_t dynamicValueIndicator) {
   SmallVector<OpFoldResult, 4> res;
   res.reserve(staticValues.size());
   unsigned numDynamic = 0;
@@ -196,21 +196,15 @@ mlir::getMixedValues(ArrayAttr staticValues, ValueRange dynamicValues,
 }
 
 SmallVector<OpFoldResult, 4>
-mlir::getMixedOffsets(OffsetSizeAndStrideOpInterface op,
-                      ArrayAttr staticOffsets, ValueRange offsets) {
-  return getMixedValues(staticOffsets, offsets, op.getDynamicOffsetIndicator());
+mlir::getMixedStridesOrOffsets(ArrayAttr staticValues,
+                               ValueRange dynamicValues) {
+  return getMixedValues(staticValues, dynamicValues,
+                        ShapedType::kDynamicStrideOrOffset);
 }
 
-SmallVector<OpFoldResult, 4>
-mlir::getMixedSizes(OffsetSizeAndStrideOpInterface op, ArrayAttr staticSizes,
-                    ValueRange sizes) {
-  return getMixedValues(staticSizes, sizes, op.getDynamicSizeIndicator());
-}
-
-SmallVector<OpFoldResult, 4>
-mlir::getMixedStrides(OffsetSizeAndStrideOpInterface op,
-                      ArrayAttr staticStrides, ValueRange strides) {
-  return getMixedValues(staticStrides, strides, op.getDynamicStrideIndicator());
+SmallVector<OpFoldResult, 4> mlir::getMixedSizes(ArrayAttr staticValues,
+                                                 ValueRange dynamicValues) {
+  return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamicSize);
 }
 
 std::pair<ArrayAttr, SmallVector<Value>>