[mlir][Linalg] Allow all build methods of Structured ops to specify additional attrib...
authorMaheshRavishankar <ravishankarm@google.com>
Mon, 23 Aug 2021 17:15:35 +0000 (10:15 -0700)
committerMaheshRavishankar <ravishankarm@google.com>
Mon, 23 Aug 2021 20:06:34 +0000 (13:06 -0700)
Differential Revision: https://reviews.llvm.org/D108338

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

index 33f4992..1d4e6d5 100644 (file)
@@ -620,18 +620,22 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
       "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
       "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
       "StringRef":$libraryCall,
-      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>,
+      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
       "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
       "StringRef":$doc, "StringRef":$libraryCall,
-      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>,
+      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
       "ArrayRef<StringRef>":$iteratorTypes,
-      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>,
+      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
       "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
-      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>
+      CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
   ];
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
index 6b65a9e..f4750ca 100644 (file)
@@ -502,13 +502,15 @@ void GenericOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
     ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
   build(builder, result, resultTensorTypes, inputs, outputs,
         builder.getAffineMapArrayAttr(indexingMaps),
         builder.getStrArrayAttr(iteratorTypes),
         doc.empty() ? StringAttr() : builder.getStringAttr(doc),
         libraryCall.empty() ? StringAttr()
                             : builder.getStringAttr(libraryCall));
+  result.addAttributes(attributes);
   if (!bodyBuild)
     return;
 
@@ -527,30 +529,33 @@ void GenericOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
     ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
   build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
-        iteratorTypes, doc, libraryCall, bodyBuild);
+        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
 }
 
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
     ArrayRef<StringRef> iteratorTypes,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
   build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
         /*doc=*/"",
-        /*libraryCall=*/"", bodyBuild);
+        /*libraryCall=*/"", bodyBuild, attributes);
 }
 
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
     ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
     ArrayRef<StringRef> iteratorTypes,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
   build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
         iteratorTypes,
         /*doc=*/"",
-        /*libraryCall=*/"", bodyBuild);
+        /*libraryCall=*/"", bodyBuild, attributes);
 }
 
 static void print(OpAsmPrinter &p, GenericOp op) {
index 471961f..743bdbd 100644 (file)
@@ -169,7 +169,8 @@ It has one output.
 // ODS-LABEL: def Test7Op
 // ODS:         OpBuilder<
 // ODS:           (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
-// ODS:            "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b)
+// ODS:            "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b,
+// ODS:           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)
 // ODS:           $_state.addAttribute("attr_a", attr_a);
 // ODS:           $_state.addAttribute("attr_b", attr_b);
 //
index 1bdb5b8..590f17f 100644 (file)
@@ -1910,7 +1910,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
       let skipDefaultBuilders = 1;
       let builders = [
         OpBuilder<
-        (ins "ValueRange":$inputs, "ValueRange":$outputs),
+        (ins "ValueRange":$inputs, "ValueRange":$outputs,
+             CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
         [{{
           $_state.addOperands(inputs);
           $_state.addOperands(outputs);
@@ -1919,6 +1920,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             $_builder.getI32VectorAttr({{
               static_cast<int32_t>(inputs.size()),
               static_cast<int32_t>(outputs.size())}));
+          $_state.addAttributes(attributes);
           createAndFillStructuredOpRegion<{0}>(
             $_builder,
             $_state,
@@ -1927,7 +1929,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         }]>,
         OpBuilder<
         (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
-             "ValueRange":$outputs),
+             "ValueRange":$outputs,
+             CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
         [{{
           $_state.addOperands(inputs);
           $_state.addOperands(outputs);
@@ -1937,6 +1940,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             $_builder.getI32VectorAttr({{
               static_cast<int32_t>(inputs.size()),
               static_cast<int32_t>(outputs.size())}));
+          $_state.addAttributes(attributes);
           createAndFillStructuredOpRegion<{0}>(
             $_builder,
             $_state,
@@ -2020,7 +2024,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
     const char *builderFmt = R"FMT(
       , OpBuilder<
       (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
-           "ValueRange":$outputs, {1}),
+           "ValueRange":$outputs, {1},
+           CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
       [{{
         $_state.addOperands(inputs);
         $_state.addOperands(outputs);
@@ -2030,6 +2035,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
           $_builder.getI32VectorAttr({{
             static_cast<int32_t>(inputs.size()),
             static_cast<int32_t>(outputs.size())}));
+        $_state.addAttributes(attributes);
         createAndFillStructuredOpRegion<{0}>(
           $_builder,
           $_state,
index a0eb1de..98e90b6 100644 (file)
@@ -457,7 +457,8 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
     let skipDefaultBuilders = 1;
     let builders = [
       OpBuilder<
-      (ins "ValueRange":$inputs, "ValueRange":$outputs),
+      (ins "ValueRange":$inputs, "ValueRange":$outputs,
+            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
       [{{
         $_state.addOperands(inputs);
         $_state.addOperands(outputs);
@@ -471,6 +472,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
           $_builder.getI32VectorAttr({{
             static_cast<int32_t>(inputs.size()),
             static_cast<int32_t>(outputs.size())}));
+        $_state.addAttributes(attributes);
         createAndFillStructuredOpRegion<{0}>(
           $_builder,
           $_state,
@@ -539,7 +541,8 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
 static const char structuredOpBuilderFormat[] = R"FMT(
   , OpBuilder<
   (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
-       "ValueRange":$outputs, {1}),
+       "ValueRange":$outputs, {1},
+       CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
   [{{
     $_state.addOperands(inputs);
     $_state.addOperands(outputs);
@@ -555,6 +558,7 @@ static const char structuredOpBuilderFormat[] = R"FMT(
       TypeRange(inputs),
       TypeRange(outputs));
     {2}
+    $_state.addAttributes(attributes);
   }]>
 )FMT";