[mlir][Linalg] NFC - Add an OpFoldResult-based builder for InitTensorOp
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 12 Feb 2021 15:46:55 +0000 (15:46 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 12 Feb 2021 16:03:51 +0000 (16:03 +0000)
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

index 6916fa7..7212700 100644 (file)
@@ -116,7 +116,9 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
     OpBuilderDAG<(ins "ArrayRef<int64_t>":$staticShape, "Type":$elementType),
     [{
       build($_builder, $_state, ValueRange{}, staticShape, elementType);
-    }]>
+    }]>,
+    OpBuilderDAG<(ins "ArrayRef<OpFoldResult>":$sizes, "Type":$elementType,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
 
   let hasCanonicalizer = 1;
index 989a164..42a7900 100644 (file)
@@ -87,6 +87,24 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
 template <typename NamedStructuredOpType>
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
 
+/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
+/// it is a Value or into `staticVec` if it is an IntegerAttr.
+/// In the case of a Value, a copy of the `sentinel` value is also pushed to
+/// `staticVec`. This is useful to extract mixed static and dynamic entries that
+/// come from an AttrSizedOperandSegments trait.
+static void dispatchIndexOpFoldResult(OpFoldResult ofr,
+                                      SmallVectorImpl<Value> &dynamicVec,
+                                      SmallVectorImpl<int64_t> &staticVec,
+                                      int64_t sentinel) {
+  if (auto v = ofr.dyn_cast<Value>()) {
+    dynamicVec.push_back(v);
+    staticVec.push_back(sentinel);
+    return;
+  }
+  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
+  staticVec.push_back(apInt.getSExtValue());
+}
+
 /// This is a common class used for patterns of the form
 /// ```
 ///    someop(memrefcast) -> someop
@@ -539,6 +557,24 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
 //===----------------------------------------------------------------------===//
 // InitTensorOp
 //===----------------------------------------------------------------------===//
+void InitTensorOp::build(OpBuilder &b, OperationState &result,
+                         ArrayRef<OpFoldResult> sizes, Type elementType,
+                         ArrayRef<NamedAttribute> attrs) {
+  unsigned rank = sizes.size();
+  SmallVector<Value, 4> dynamicSizes;
+  SmallVector<int64_t, 4> staticSizes;
+  for (unsigned i = 0; i < rank; ++i) {
+    // staticLow and staticHigh have full information of the padding config.
+    // This will grow staticLow and staticHigh with 1 value. If the config is
+    // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
+    // value as well.
+    dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes,
+                              ShapedType::kDynamicSize);
+  }
+  auto resultType = RankedTensorType ::get(staticSizes, elementType);
+  build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
+  result.addAttributes(attrs);
+}
 
 static LogicalResult verify(InitTensorOp op) {
   RankedTensorType resultType = op.getType();
@@ -857,24 +893,6 @@ RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
   return RankedTensorType::get(resultShape, sourceType.getElementType());
 }
 
-/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
-/// it is a Value or into `staticVec` if it is an IntegerAttr.
-/// In the case of a Value, a copy of the `sentinel` value is also pushed to
-/// `staticVec`. This is useful to extract mixed static and dynamic entries that
-/// come from an AttrSizedOperandSegments trait.
-static void dispatchIndexOpFoldResult(OpFoldResult ofr,
-                                      SmallVectorImpl<Value> &dynamicVec,
-                                      SmallVectorImpl<int64_t> &staticVec,
-                                      int64_t sentinel) {
-  if (auto v = ofr.dyn_cast<Value>()) {
-    dynamicVec.push_back(v);
-    staticVec.push_back(sentinel);
-    return;
-  }
-  APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
-  staticVec.push_back(apInt.getSExtValue());
-}
-
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
                         ArrayRef<int64_t> staticLow,
                         ArrayRef<int64_t> staticHigh, ValueRange low,