[mlir][linalg] Add nicer builders for `map` and `reduce`.
authorOleg Shyshkov <shyshkov@google.com>
Thu, 27 Oct 2022 18:48:04 +0000 (20:48 +0200)
committerOleg Shyshkov <shyshkov@google.com>
Fri, 28 Oct 2022 06:51:09 +0000 (08:51 +0200)
The new builders get a list of additional attrs, a lambda to build the region
body and infer return types from `init`.

Differential Revision: https://reviews.llvm.org/D136838

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

index 1692a0f..510f883 100644 (file)
@@ -267,6 +267,12 @@ def MapOp : LinalgStructuredBase_Op<"map", [
   let results = (outs Variadic<AnyTensor>:$result);
   let regions = (region SizedRegion<1>:$mapper);
 
+  let builders = [
+    OpBuilder<(ins "ValueRange":$inputs, "Value":$init,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>",
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+  ];
+
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Implement functions necessary for LinalgStructuredInterface.
     SmallVector<StringRef> getIteratorTypesArray();
@@ -341,6 +347,13 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
   let results = (outs Variadic<AnyTensor>);
   let regions = (region SizedRegion<1>:$combiner);
 
+  let builders = [
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
+      "ArrayRef<int64_t>":$dimensions,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>",
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+  ];
+
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Declare functions necessary for LinalgStructuredInterface.
     SmallVector<StringRef> getIteratorTypesArray();
index 03a959a..443bc5c 100644 (file)
@@ -661,6 +661,26 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // GenericOp
 //===----------------------------------------------------------------------===//
 
+static void buildGenericRegion(
+    OpBuilder &builder, OperationState &result, ValueRange inputs,
+    ValueRange outputs,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+  SmallVector<Type, 4> blockArgTypes;
+  SmallVector<Location, 4> blockArgLocs;
+  for (ValueRange container : {inputs, outputs}) {
+    for (Value v : container) {
+      blockArgTypes.push_back(getElementTypeOrSelf(v));
+      blockArgLocs.push_back(v.getLoc());
+    }
+  }
+
+  OpBuilder::InsertionGuard guard(builder);
+  auto &region = *result.regions.front();
+  Block *bodyBlock =
+      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
+  bodyBuild(builder, result.location, bodyBlock->getArguments());
+}
+
 void GenericOp::getAsmBlockArgumentNames(Region &region,
                                          OpAsmSetValueNameFn setNameFn) {
   for (Value v : getRegionInputArgs())
@@ -678,23 +698,8 @@ void GenericOp::build(
   build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
         iteratorTypes, doc, libraryCall);
   result.addAttributes(attributes);
-  if (!bodyBuild)
-    return;
-
-  SmallVector<Type, 4> blockArgTypes;
-  SmallVector<Location, 4> blockArgLocs;
-  for (ValueRange container : {inputs, outputs}) {
-    for (Value v : container) {
-      blockArgTypes.push_back(getElementTypeOrSelf(v));
-      blockArgLocs.push_back(v.getLoc());
-    }
-  }
-
-  OpBuilder::InsertionGuard guard(builder);
-  auto &region = *result.regions.front();
-  Block *bodyBlock =
-      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
-  bodyBuild(builder, result.location, bodyBlock->getArguments());
+  if (bodyBuild)
+    buildGenericRegion(builder, result, inputs, outputs, bodyBuild);
 }
 
 void GenericOp::build(
@@ -1329,6 +1334,22 @@ void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
     setNameFn(getResults().front(), "mapped");
 }
 
+void MapOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, TypeRange{}, inputs, init);
+  result.addAttributes(attributes);
+
+  // Add output types for `RankedTensorType` output arguments.
+  Type initType = init.getType();
+  if (initType.isa<RankedTensorType>())
+    result.addTypes(initType);
+
+  if (bodyBuild)
+    buildGenericRegion(builder, result, inputs, /*outputs=*/{}, bodyBuild);
+}
+
 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parseDstStyleOp(parser, result))
     return failure();
@@ -1436,6 +1457,25 @@ void ReduceOp::getAsmResultNames(
     setNameFn(getResults().front(), "reduced");
 }
 
+void ReduceOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange inputs,
+    ValueRange inits, ArrayRef<int64_t> dimensions,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, TypeRange{}, inputs, inits, dimensions);
+  result.addAttributes(attributes);
+
+  // Add output types for `RankedTensorType` output arguments.
+  for (Value init : inits) {
+    Type initType = init.getType();
+    if (initType.isa<RankedTensorType>())
+      result.addTypes(initType);
+  }
+
+  if (bodyBuild)
+    buildGenericRegion(builder, result, inputs, inits, bodyBuild);
+}
+
 SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
   int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
   SmallVector<StringRef> iteratorTypes(inputRank,
@@ -1618,45 +1658,32 @@ TransposeOp::getRegionBuilder() {
   };
 }
 
-void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder,
-                               ::mlir::OperationState &odsState) {
-  Region *region = odsState.addRegion();
-
-  SmallVector<Type> argTypes;
-  SmallVector<Location> argLocs;
-  for (auto t : odsState.operands) {
-    argTypes.push_back(getElementTypeOrSelf(t));
-    argLocs.push_back(opBuilder.getUnknownLoc());
-  }
-
-  // RAII.
-  OpBuilder::InsertionGuard guard(opBuilder);
-  Block *body =
-      opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs);
-
-  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
-  getRegionBuilder()(b, *body, odsState.attributes.getAttrs());
-}
-
-void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
-                        ::mlir::OperationState &odsState, Value input,
-                        Value init, DenseI64ArrayAttr permutation,
+void TransposeOp::build(::mlir::OpBuilder &builder,
+                        ::mlir::OperationState &result, Value input, Value init,
+                        DenseI64ArrayAttr permutation,
                         ArrayRef<NamedAttribute> attributes) {
-  odsState.addOperands(input);
-  odsState.addOperands(init);
-  odsState.addAttribute(getPermutationAttrName(odsState.name), permutation);
-  odsState.addAttributes(attributes);
-  odsState.addTypes(init.getType());
+  result.addOperands(input);
+  result.addOperands(init);
+  result.addAttribute(getPermutationAttrName(result.name), permutation);
+  result.addAttributes(attributes);
+
+  // Add output types for `RankedTensorType` output arguments.
+  Type initType = init.getType();
+  if (initType.isa<RankedTensorType>())
+    result.addTypes(initType);
 
-  createRegion(odsBuilder, odsState);
+  buildGenericRegion(builder, result, input, init,
+                     [&](OpBuilder &b, Location loc, ValueRange args) {
+                       b.create<linalg::YieldOp>(loc, args[0]);
+                     });
 }
 
-void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
-                        ::mlir::OperationState &odsState, Value input,
-                        Value init, ArrayRef<int64_t> permutation,
+void TransposeOp::build(::mlir::OpBuilder &builder,
+                        ::mlir::OperationState &result, Value input, Value init,
+                        ArrayRef<int64_t> permutation,
                         ArrayRef<NamedAttribute> attributes) {
-  build(odsBuilder, odsState, input, init,
-        odsBuilder.getDenseI64ArrayAttr(permutation), attributes);
+  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
+        attributes);
 }
 
 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1666,8 +1693,13 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
           })))
     return failure();
 
-  OpBuilder opBuilder(parser.getContext());
-  createRegion(opBuilder, result);
+  (void)result.addRegion();
+  OpBuilder builder(parser.getContext());
+  buildGenericRegion(builder, result, /*inputs=*/result.operands,
+                     /*outputs=*/{},
+                     [&](OpBuilder &b, Location loc, ValueRange args) {
+                       b.create<linalg::YieldOp>(loc, args[0]);
+                     });
   return success();
 }