[mlir] Clean-up ViewLikeOpInterface w.r.t. kDynamic change.
authorAlexander Belyaev <pifon@google.com>
Tue, 22 Nov 2022 07:55:59 +0000 (08:55 +0100)
committerAlexander Belyaev <pifon@google.com>
Tue, 22 Nov 2022 09:51:53 +0000 (10:51 +0100)
Differential Revision: https://reviews.llvm.org/D138478

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/IR/BuiltinTypeInterfaces.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp

index 9ee5d3d..1723f58 100644 (file)
@@ -949,11 +949,9 @@ def TileToForeachThreadOp :
   let assemblyFormat = [{
     $target oilist(
         `num_threads` custom<DynamicIndexList>($num_threads,
-                                               $static_num_threads,
-                                               "ShapedType::kDynamic") |
+                                               $static_num_threads) |
          `tile_sizes` custom<DynamicIndexList>($tile_sizes,
-                                               $static_tile_sizes,
-                                               "ShapedType::kDynamic"))
+                                               $static_tile_sizes))
     (`(` `mapping` `=` $mapping^ `)`)? attr-dict
   }];
   let hasVerifier = 1;
index 319c089..ccf01be 100644 (file)
@@ -1267,14 +1267,11 @@ def MemRef_ReinterpretCastOp
 
   let assemblyFormat = [{
     $source `to` `offset` `` `:`
-    custom<DynamicIndexList>($offsets, $static_offsets,
-                               "ShapedType::kDynamic")
+    custom<DynamicIndexList>($offsets, $static_offsets)
     `` `,` `sizes` `` `:`
-    custom<DynamicIndexList>($sizes, $static_sizes,
-                               "ShapedType::kDynamic")
+    custom<DynamicIndexList>($sizes, $static_sizes)
     `` `,` `strides` `` `:`
-    custom<DynamicIndexList>($strides, $static_strides,
-                               "ShapedType::kDynamic")
+    custom<DynamicIndexList>($strides, $static_strides)
     attr-dict `:` type($source) `to` type($result)
   }];
 
@@ -1865,12 +1862,9 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
 
   let assemblyFormat = [{
     $source ``
-    custom<DynamicIndexList>($offsets, $static_offsets,
-                               "ShapedType::kDynamic")
-    custom<DynamicIndexList>($sizes, $static_sizes,
-                               "ShapedType::kDynamic")
-    custom<DynamicIndexList>($strides, $static_strides,
-                               "ShapedType::kDynamic")
+    custom<DynamicIndexList>($offsets, $static_offsets)
+    custom<DynamicIndexList>($sizes, $static_sizes)
+    custom<DynamicIndexList>($strides, $static_strides)
     attr-dict `:` type($source) `to` type($result)
   }];
 
index 661a8f8..24acd4b 100644 (file)
@@ -334,12 +334,9 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
 
   let assemblyFormat = [{
     $source ``
-    custom<DynamicIndexList>($offsets, $static_offsets,
-                               "ShapedType::kDynamic")
-    custom<DynamicIndexList>($sizes, $static_sizes,
-                               "ShapedType::kDynamic")
-    custom<DynamicIndexList>($strides, $static_strides,
-                               "ShapedType::kDynamic")
+    custom<DynamicIndexList>($offsets, $static_offsets)
+    custom<DynamicIndexList>($sizes, $static_sizes)
+    custom<DynamicIndexList>($strides, $static_strides)
     attr-dict `:` type($source) `to` type($result)
   }];
 
@@ -818,12 +815,9 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
 
   let assemblyFormat = [{
     $source `into` $dest ``
-    custom<DynamicIndexList>($offsets, $static_offsets,
-                               "ShapedType::kDynamic")
-    custom<DynamicIndexList>($sizes, $static_sizes,
-                               "ShapedType::kDynamic")
-    custom<DynamicIndexList>($strides, $static_strides,
-                               "ShapedType::kDynamic")
+    custom<DynamicIndexList>($offsets, $static_offsets)
+    custom<DynamicIndexList>($sizes, $static_sizes)
+    custom<DynamicIndexList>($strides, $static_strides)
     attr-dict `:` type($source) `into` type($dest)
   }];
 
@@ -1221,10 +1215,8 @@ def Tensor_PadOp : Tensor_Op<"pad", [
   let assemblyFormat = [{
     $source
     (`nofold` $nofold^)?
-    `low` `` custom<DynamicIndexList>($low, $static_low,
-                                        "ShapedType::kDynamic")
-    `high` `` custom<DynamicIndexList>($high, $static_high,
-                                         "ShapedType::kDynamic")
+    `low` `` custom<DynamicIndexList>($low, $static_low)
+    `high` `` custom<DynamicIndexList>($high, $static_high)
     $region attr-dict `:` type($source) `to` type($result)
   }];
 
@@ -1411,12 +1403,9 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
   );
   let assemblyFormat = [{
     $source `into` $dest ``
-    custom<DynamicIndexList>($offsets, $static_offsets,
-                               "ShapedType::kDynamic")
-    custom<DynamicIndexList>($sizes, $static_sizes,
-                               "ShapedType::kDynamic")
-    custom<DynamicIndexList>($strides, $static_strides,
-                               "ShapedType::kDynamic")
+    custom<DynamicIndexList>($offsets, $static_offsets)
+    custom<DynamicIndexList>($sizes, $static_sizes)
+    custom<DynamicIndexList>($strides, $static_strides)
     attr-dict `:` type($source) `into` type($dest)
   }];
 
index 05b34a7..700546d 100644 (file)
 
 namespace mlir {
 
-/// Return a vector of OpFoldResults given the special value
-/// that indicates whether of the value is dynamic or not.
+/// Return a vector of OpFoldResults with the same size a staticValues, but all
+/// elements for which ShapedType::isDynamic is true, will be replaced by
+/// dynamicValues.
 SmallVector<OpFoldResult, 4> getMixedValues(ArrayAttr staticValues,
-                                            ValueRange dynamicValues,
-                                            int64_t dynamicValueIndicator);
-
-/// Return a vector of all the static and dynamic offsets/strides.
-SmallVector<OpFoldResult, 4> getMixedStridesOrOffsets(ArrayAttr staticValues,
-                                                      ValueRange dynamicValues);
-
-/// Return a vector of all the static and dynamic sizes.
-SmallVector<OpFoldResult, 4> getMixedSizes(ArrayAttr staticValues,
-                                           ValueRange dynamicValues);
+                                            ValueRange dynamicValues);
 
 /// Decompose a vector of mixed static or dynamic values into the corresponding
 /// pair of arrays. This is the inverse function of `getMixedValues`.
 std::pair<ArrayAttr, SmallVector<Value>>
 decomposeMixedValues(Builder &b,
-                     const SmallVectorImpl<OpFoldResult> &mixedValues,
-                     const int64_t dynamicValueIndicator);
-
-/// Decompose a vector of mixed static and dynamic strides/offsets into the
-/// corresponding pair of arrays. This is the inverse function of
-/// `getMixedStridesOrOffsets`.
-std::pair<ArrayAttr, SmallVector<Value>> decomposeMixedStridesOrOffsets(
-    OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues);
-
-/// Decompose a vector of mixed static or dynamic strides/offsets into the
-/// corresponding pair of arrays. This is the inverse function of
-/// `getMixedSizes`.
-std::pair<ArrayAttr, SmallVector<Value>>
-decomposeMixedSizes(OpBuilder &b,
-                    const SmallVectorImpl<OpFoldResult> &mixedValues);
+                     const SmallVectorImpl<OpFoldResult> &mixedValues);
 
 class OffsetSizeAndStrideOpInterface;
 
@@ -83,8 +61,7 @@ namespace mlir {
 /// idiomatic printing of mixed value and integer attributes in a list. E.g.
 /// `[%arg0, 7, 42, %arg42]`.
 void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
-                           OperandRange values, ArrayAttr integers,
-                           int64_t dynVal);
+                           OperandRange values, ArrayAttr integers);
 
 /// Pasrer hook for custom directive in assemblyFormat.
 ///
@@ -102,13 +79,13 @@ void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
 ParseResult
 parseDynamicIndexList(OpAsmParser &parser,
                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-                      ArrayAttr &integers, int64_t dynVal);
+                      ArrayAttr &integers);
 
 /// Verify that a the `values` has as many elements as the number of entries in
 /// `attr` for which `isDynamic` evaluates to true.
-LogicalResult verifyListOfOperandsOrIntegers(
-    Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
-    ValueRange values, function_ref<bool(int64_t)> isDynamic);
+LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name,
+                                             unsigned expectedNumElements,
+                                             ArrayAttr attr, ValueRange values);
 
 } // namespace mlir
 
index ea94f65..aca0126 100644 (file)
@@ -165,8 +165,8 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return ::mlir::getMixedStridesOrOffsets($_op.getStaticOffsets(),
-                                                $_op.getOffsets());
+        return ::mlir::getMixedValues($_op.getStaticOffsets(),
+                                      $_op.getOffsets());
       }]
     >,
     InterfaceMethod<
@@ -178,7 +178,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return ::mlir::getMixedSizes($_op.getStaticSizes(), $_op.sizes());
+        return ::mlir::getMixedValues($_op.getStaticSizes(), $_op.sizes());
       }]
     >,
     InterfaceMethod<
@@ -190,15 +190,13 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return ::mlir::getMixedStridesOrOffsets($_op.getStaticStrides(),
-                                                $_op.getStrides());
+        return ::mlir::getMixedValues($_op.getStaticStrides(),
+                                      $_op.getStrides());
       }]
     >,
 
     InterfaceMethod<
-      /*desc=*/[{
-        Return true if the offset `idx` is dynamic.
-      }],
+      /*desc=*/"Return true if the offset `idx` is dynamic.",
       /*retTy=*/"bool",
       /*methodName=*/"isDynamicOffset",
       /*args=*/(ins "unsigned":$idx),
@@ -210,9 +208,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
       }]
     >,
     InterfaceMethod<
-      /*desc=*/[{
-        Return true if the size `idx` is dynamic.
-      }],
+      /*desc=*/"Return true if the size `idx` is dynamic.",
       /*retTy=*/"bool",
       /*methodName=*/"isDynamicSize",
       /*args=*/(ins "unsigned":$idx),
@@ -224,9 +220,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
       }]
     >,
     InterfaceMethod<
-      /*desc=*/[{
-       Return true if the stride `idx` is dynamic.
-      }],
+      /*desc=*/"Return true if the stride `idx` is dynamic.",
       /*retTy=*/"bool",
       /*methodName=*/"isDynamicStride",
       /*args=*/(ins "unsigned":$idx),
index 26b63a9..f02ccfa 100644 (file)
@@ -1321,8 +1321,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
   auto pdlOperationType = pdl::OperationType::get(parser.getContext());
   if (parser.parseOperand(target) ||
       parser.resolveOperand(target, pdlOperationType, result.operands) ||
-      parseDynamicIndexList(parser, dynamicSizes, staticSizes,
-                            ShapedType::kDynamic) ||
+      parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
       parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
       parser.parseOptionalAttrDict(result.attributes))
     return ParseResult::failure();
@@ -1336,8 +1335,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
 
 void TileOp::print(OpAsmPrinter &p) {
   p << ' ' << getTarget();
-  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
-                        ShapedType::kDynamic);
+  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
   p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
 }
 
@@ -1549,11 +1547,11 @@ void transform::TileToForeachThreadOp::getEffects(
 }
 
 SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() {
-  return getMixedSizes(getStaticNumThreads(), getNumThreads());
+  return getMixedValues(getStaticNumThreads(), getNumThreads());
 }
 
 SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() {
-  return getMixedSizes(getStaticTileSizes(), getTileSizes());
+  return getMixedValues(getStaticTileSizes(), getTileSizes());
 }
 
 LogicalResult TileToForeachThreadOp::verify() {
@@ -1680,8 +1678,7 @@ ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser,
   auto pdlOperationType = pdl::OperationType::get(parser.getContext());
   if (parser.parseOperand(target) ||
       parser.resolveOperand(target, pdlOperationType, result.operands) ||
-      parseDynamicIndexList(parser, dynamicSizes, staticSizes,
-                            ShapedType::kDynamic) ||
+      parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
       parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
       parser.parseOptionalAttrDict(result.attributes))
     return ParseResult::failure();
@@ -1695,8 +1692,7 @@ ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser,
 
 void TileToScfForOp::print(OpAsmPrinter &p) {
   p << ' ' << getTarget();
-  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
-                        ShapedType::kDynamic);
+  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
   p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
 }
 
index bd7f617..88791fc 100644 (file)
@@ -24,7 +24,6 @@ using namespace mlir::detail;
 //===----------------------------------------------------------------------===//
 
 constexpr int64_t ShapedType::kDynamic;
-constexpr int64_t ShapedType::kDynamic;
 
 int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
   int64_t num = 1;
index f818df1..775d26a 100644 (file)
@@ -17,16 +17,18 @@ using namespace mlir;
 /// Include the definitions of the loop-like interfaces.
 #include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
 
-LogicalResult mlir::verifyListOfOperandsOrIntegers(
-    Operation *op, StringRef name, unsigned numElements, ArrayAttr attr,
-    ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
+LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
+                                                   StringRef name,
+                                                   unsigned numElements,
+                                                   ArrayAttr attr,
+                                                   ValueRange values) {
   /// Check static and dynamic offsets/sizes/strides does not overflow type.
   if (attr.size() != numElements)
     return op->emitError("expected ")
            << numElements << " " << name << " values";
   unsigned expectedNumDynamicEntries =
       llvm::count_if(attr.getValue(), [&](Attribute attr) {
-        return isDynamic(attr.cast<IntegerAttr>().getInt());
+        return ShapedType::isDynamic(attr.cast<IntegerAttr>().getInt());
       });
   if (values.size() != expectedNumDynamicEntries)
     return op->emitError("expected ")
@@ -56,23 +58,19 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
            << ") so the rank of the result type is well-formed.";
 
   if (failed(verifyListOfOperandsOrIntegers(op, "offset", maxRanks[0],
-                                            op.static_offsets(), op.offsets(),
-                                            ShapedType::isDynamic)))
+                                            op.static_offsets(), op.offsets())))
     return failure();
   if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
-                                            op.static_sizes(), op.sizes(),
-                                            ShapedType::isDynamic)))
+                                            op.static_sizes(), op.sizes())))
     return failure();
   if (failed(verifyListOfOperandsOrIntegers(op, "stride", maxRanks[2],
-                                            op.static_strides(), op.strides(),
-                                            ShapedType::isDynamic)))
+                                            op.static_strides(), op.strides())))
     return failure();
   return success();
 }
 
 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
-                                 OperandRange values, ArrayAttr integers,
-                                 int64_t dynVal) {
+                                 OperandRange values, ArrayAttr integers) {
   printer << '[';
   if (integers.empty()) {
     printer << "]";
@@ -81,7 +79,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
   unsigned idx = 0;
   llvm::interleaveComma(integers, printer, [&](Attribute a) {
     int64_t val = a.cast<IntegerAttr>().getInt();
-    if (val == dynVal)
+    if (ShapedType::isDynamic(val))
       printer << values[idx++];
     else
       printer << val;
@@ -92,7 +90,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
 ParseResult mlir::parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-    ArrayAttr &integers, int64_t dynVal) {
+    ArrayAttr &integers) {
   if (failed(parser.parseLSquare()))
     return failure();
   // 0-D.
@@ -107,7 +105,7 @@ ParseResult mlir::parseDynamicIndexList(
     auto res = parser.parseOptionalOperand(operand);
     if (res.has_value() && succeeded(res.value())) {
       values.push_back(operand);
-      attrVals.push_back(dynVal);
+      attrVals.push_back(ShapedType::kDynamic);
     } else {
       IntegerAttr attr;
       if (failed(parser.parseAttribute<IntegerAttr>(attr)))
@@ -147,57 +145,33 @@ bool mlir::detail::sameOffsetsSizesAndStrides(
   return true;
 }
 
-SmallVector<OpFoldResult, 4>
-mlir::getMixedValues(ArrayAttr staticValues, ValueRange dynamicValues,
-                     const int64_t dynamicValueIndicator) {
+SmallVector<OpFoldResult, 4> mlir::getMixedValues(ArrayAttr staticValues,
+                                                  ValueRange dynamicValues) {
   SmallVector<OpFoldResult, 4> res;
   res.reserve(staticValues.size());
   unsigned numDynamic = 0;
   unsigned count = static_cast<unsigned>(staticValues.size());
   for (unsigned idx = 0; idx < count; ++idx) {
     APInt value = staticValues[idx].cast<IntegerAttr>().getValue();
-    res.push_back(value.getSExtValue() == dynamicValueIndicator
+    res.push_back(ShapedType::isDynamic(value.getSExtValue())
                       ? OpFoldResult{dynamicValues[numDynamic++]}
                       : OpFoldResult{staticValues[idx]});
   }
   return res;
 }
 
-SmallVector<OpFoldResult, 4>
-mlir::getMixedStridesOrOffsets(ArrayAttr staticValues,
-                               ValueRange dynamicValues) {
-  return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamic);
-}
-
-SmallVector<OpFoldResult, 4> mlir::getMixedSizes(ArrayAttr staticValues,
-                                                 ValueRange dynamicValues) {
-  return getMixedValues(staticValues, dynamicValues, ShapedType::kDynamic);
-}
-
 std::pair<ArrayAttr, SmallVector<Value>>
 mlir::decomposeMixedValues(Builder &b,
-                           const SmallVectorImpl<OpFoldResult> &mixedValues,
-                           const int64_t dynamicValueIndicator) {
+                           const SmallVectorImpl<OpFoldResult> &mixedValues) {
   SmallVector<int64_t> staticValues;
   SmallVector<Value> dynamicValues;
   for (const auto &it : mixedValues) {
     if (it.is<Attribute>()) {
       staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
     } else {
-      staticValues.push_back(dynamicValueIndicator);
+      staticValues.push_back(ShapedType::kDynamic);
       dynamicValues.push_back(it.get<Value>());
     }
   }
   return {b.getI64ArrayAttr(staticValues), dynamicValues};
 }
-
-std::pair<ArrayAttr, SmallVector<Value>> mlir::decomposeMixedStridesOrOffsets(
-    OpBuilder &b, const SmallVectorImpl<OpFoldResult> &mixedValues) {
-  return decomposeMixedValues(b, mixedValues, ShapedType::kDynamic);
-}
-
-std::pair<ArrayAttr, SmallVector<Value>>
-mlir::decomposeMixedSizes(OpBuilder &b,
-                          const SmallVectorImpl<OpFoldResult> &mixedValues) {
-  return decomposeMixedValues(b, mixedValues, ShapedType::kDynamic);
-}