[ODS] Generate builders taking unwrapped value and defaults for attributes
authorLei Zhang <antiagainst@google.com>
Mon, 2 Dec 2019 17:33:24 +0000 (09:33 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 2 Dec 2019 17:33:57 +0000 (09:33 -0800)
Existing builders generated by ODS require attributes to be passed
in as mlir::Attribute or its subclasses. This is okay foraggregate-
parameter builders, which is primarily to be used by programmatic
C++ code generation; it is inconvenient for separate-parameter
builders meant to be called in manually written C++ code because
it requires developers to wrap raw values into mlir::Attribute by
themselves.

This CL extends to generate additional builder methods that
take raw values for attributes and handles the wrapping in the
builder implementation. Additionally, if an attribute appears
late in the arguments list and has a default value, the default
value is supplied in the declaration if possible.

PiperOrigin-RevId: 283355919

mlir/g3doc/OpDefinitions.md
mlir/include/mlir/TableGen/Attribute.h
mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/Attribute.cpp
mlir/test/mlir-tblgen/op-attribute.td
mlir/test/mlir-tblgen/op-decl.td
mlir/test/mlir-tblgen/op-result.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp

index ea79496..2586559 100644 (file)
@@ -382,27 +382,86 @@ def OpWithInferTypeInterfaceOp : Op<...
     [DeclareOpInterfaceMethods<MyInterface>]> { ... }
 ```
 
-### Custom builder methods
+### Builder methods
 
-For each operation, there are two builders automatically generated based on the
-arguments and returns types:
+For each operation, there are a few builders automatically generated based on
+the arguments and returns types. For example, given the following op definition:
 
-```c++
-static void build(Builder *, OperationState &tblgen_state,
-                  Type <result0-name>, Type <result1-name>, ...,
-                  Value <arg0-name>, Value <arg1-name>, ...,
-                  Attribute <attr0-name>, Attribute <attr1-name>, ...);
+```tablegen
+def MyOp : ... {
+  let arguments = (ins
+    I32:$i32_operand,
+    F32:$f32_operand,
+    ...,
 
-static void build(Builder *, OperationState &tblgen_state,
+    I32Attr:$i32_attr,
+    F32Attr:$f32_attr,
+    ...
+  );
+
+  let results = (outs
+    I32:$i32_result,
+    F32:$f32_result,
+    ...
+  );
+}
+```
+
+The following builders are generated:
+
+```c++
+// All result-types/operands/attributes have one aggregate parameter.
+static void build(Builder *tblgen_builder, OperationState &tblgen_state,
                   ArrayRef<Type> resultTypes,
                   ArrayRef<Value> operands,
                   ArrayRef<NamedAttribute> attributes);
+
+// Each result-type/operand/attribute has a separate parameter. The parameters
+// for attributes are of mlir::Attribute types.
+static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+                  Type i32_result, Type f32_result, ...,
+                  Value *i32_operand, Value *f32_operand, ...,
+                  IntegerAttr i32_attr, FloatAttr f32_attr, ...);
+
+// Each result-type/operand/attribute has a separate parameter. The parameters
+// for attributes are raw values unwrapped with mlir::Attribute instances.
+// (Note that this builder will not always be generated. See the following
+// explanation for more details.)
+static void build(Builder *tblgen_builder, OperationState &tblgen_state,
+                  Type i32_result, Type f32_result, ...,
+                  Value *i32_operand, Value *f32_operand, ...,
+                  APInt i32_attr, StringRef f32_attr, ...);
+
+// (And potentially others depending on the specific op.)
 ```
 
-The above cases make sure basic uniformity so that we can create ops using the
+The first form provides basic uniformity so that we can create ops using the
 same form regardless of the exact op. This is particularly useful for
 implementing declarative pattern rewrites.
 
+The second and third forms are good for use in manually written code given that
+they provide better guarantee via signatures.
+
+The third form will be generated if any of the op's attribute has different
+`Attr.returnType` from `Attr.storageType` and we know how to build an attribute
+from an unwrapped value (i.e., `Attr.constBuilderCall` is defined.)
+Additionally, for the third form, if an attribute appearing later in the
+`arguments` list has a default value, the default value will be supplied in the
+declaration. This works for `BoolAttr`, `StrAttr`, `EnumAttr` for now and the
+list can grow in the future. So if possible, default valued attribute should be
+placed at the end of the `arguments` list to leverage this feature. (This
+behavior is essentially due to C++ function parameter default value placement
+restrictions.) Otherwise, the builder of the third form will still be generated
+but default values for the attributes not at the end of the `arguments` list
+will not be supplied in the builder's signature.
+
+And there may potentially exist other builders depending on the specific op;
+please refer to the
+[generated C++ file](#run-mlir-tblgen-to-see-the-generated-content) for the
+complete list.
+
+#### Custom builder methods
+
 However, if the above cases cannot satisfy all needs, you can define additional
 convenience build methods with `OpBuilder`.
 
index 60f9515..242376e 100644 (file)
@@ -81,10 +81,10 @@ public:
   // built upon.
   Attribute getBaseAttr() const;
 
-  // Returns whether this attribute has a default value's initializer.
-  bool hasDefaultValueInitializer() const;
-  // Returns the default value's initializer for this attribute.
-  StringRef getDefaultValueInitializer() const;
+  // Returns whether this attribute has a default value.
+  bool hasDefaultValue() const;
+  // Returns the default value for this attribute.
+  StringRef getDefaultValue() const;
 
   // Returns whether this attribute is optional.
   bool isOptional() const;
index 7b636dd..89fd4ed 100644 (file)
@@ -103,6 +103,7 @@ public:
   llvm::iterator_range<attribute_iterator> getAttributes() const;
 
   int getNumAttributes() const { return attributes.size(); }
+  int getNumNativeAttributes() const { return numNativeAttributes; }
 
   // Op attribute accessors.
   NamedAttribute &getAttribute(int index) { return attributes[index]; }
index c2b673a..ec946a8 100644 (file)
@@ -107,12 +107,12 @@ tblgen::Attribute tblgen::Attribute::getBaseAttr() const {
   return *this;
 }
 
-bool tblgen::Attribute::hasDefaultValueInitializer() const {
+bool tblgen::Attribute::hasDefaultValue() const {
   const auto *init = def->getValueInit("defaultValue");
   return !getValueAsString(init).empty();
 }
 
-StringRef tblgen::Attribute::getDefaultValueInitializer() const {
+StringRef tblgen::Attribute::getDefaultValue() const {
   const auto *init = def->getValueInit("defaultValue");
   return getValueAsString(init);
 }
index 61fe70f..d5c6a4a 100644 (file)
@@ -1,4 +1,5 @@
-// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
+// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
 
 include "mlir/IR/OpBase.td"
 
@@ -26,51 +27,55 @@ def AOp : NS_Op<"a_op", []> {
   );
 }
 
-// CHECK-LABEL: AOp definitions
+// DEF-LABEL: AOp definitions
 
 // Test getter methods
 // ---
 
-// CHECK:      some-return-type AOp::aAttr() {
-// CHECK-NEXT:   auto attr = this->getAttr("aAttr").cast<some-attr-kind>();
-// CHECK-NEXT:   return attr.some-convert-from-storage();
+// DEF:      some-return-type AOp::aAttr() {
+// DEF-NEXT:   auto attr = this->getAttr("aAttr").cast<some-attr-kind>();
+// DEF-NEXT:   return attr.some-convert-from-storage();
 
-// CHECK:      some-return-type AOp::bAttr() {
-// CHECK-NEXT:   auto attr = this->getAttr("bAttr").dyn_cast_or_null<some-attr-kind>();
-// CHECK-NEXT:   if (!attr)
-// CHECK-NEXT:       return some-const-builder-call(mlir::Builder(this->getContext()), 4.2).some-convert-from-storage();
-// CHECK-NEXT:   return attr.some-convert-from-storage();
+// DEF:      some-return-type AOp::bAttr() {
+// DEF-NEXT:   auto attr = this->getAttr("bAttr").dyn_cast_or_null<some-attr-kind>();
+// DEF-NEXT:   if (!attr)
+// DEF-NEXT:       return some-const-builder-call(mlir::Builder(this->getContext()), 4.2).some-convert-from-storage();
+// DEF-NEXT:   return attr.some-convert-from-storage();
 
-// CHECK:      Optional<some-return-type> AOp::cAttr() {
-// CHECK-NEXT:   auto attr = this->getAttr("cAttr").dyn_cast_or_null<some-attr-kind>();
-// CHECK-NEXT:   return attr ? Optional<some-return-type>(attr.some-convert-from-storage()) : (llvm::None);
+// DEF:      Optional<some-return-type> AOp::cAttr() {
+// DEF-NEXT:   auto attr = this->getAttr("cAttr").dyn_cast_or_null<some-attr-kind>();
+// DEF-NEXT:   return attr ? Optional<some-return-type>(attr.some-convert-from-storage()) : (llvm::None);
 
 // Test build methods
 // ---
 
-// CHECK:      void AOp::build(
-// CHECK:        tblgen_state.addAttribute("aAttr", aAttr);
-// CHECK:        tblgen_state.addAttribute("bAttr", bAttr);
-// CHECK:        if (cAttr) {
-// CHECK-NEXT:     tblgen_state.addAttribute("cAttr", cAttr);
+// DEF:      void AOp::build(
+// DEF:        tblgen_state.addAttribute("aAttr", aAttr);
+// DEF:        tblgen_state.addAttribute("bAttr", bAttr);
+// DEF:        if (cAttr) {
+// DEF-NEXT:     tblgen_state.addAttribute("cAttr", cAttr);
 
-// CHECK:      void AOp::build(
-// CHECK-SAME:   ArrayRef<NamedAttribute> attributes
-// CHECK:      tblgen_state.addAttributes(attributes);
+// DEF:      void AOp::build(
+// DEF-SAME:   some-return-type aAttr, some-return-type bAttr, /*optional*/some-attr-kind cAttr
+// DEF:        tblgen_state.addAttribute("aAttr", some-const-builder-call((*tblgen_builder), aAttr));
+
+// DEF:      void AOp::build(
+// DEF-SAME:   ArrayRef<NamedAttribute> attributes
+// DEF:      tblgen_state.addAttributes(attributes);
 
 // Test verify method
 // ---
 
-// CHECK:      AOp::verify()
-// CHECK:      auto tblgen_aAttr = this->getAttr("aAttr");
-// CHECK-NEXT: if (!tblgen_aAttr) return emitOpError("requires attribute 'aAttr'");
-// CHECK:        if (!((some-condition))) return emitOpError("attribute 'aAttr' failed to satisfy constraint: some attribute kind");
-// CHECK:      auto tblgen_bAttr = this->getAttr("bAttr");
-// CHECK-NEXT: if (tblgen_bAttr) {
-// CHECK-NEXT:   if (!((some-condition))) return emitOpError("attribute 'bAttr' failed to satisfy constraint: some attribute kind");
-// CHECK:      auto tblgen_cAttr = this->getAttr("cAttr");
-// CHECK-NEXT: if (tblgen_cAttr) {
-// CHECK-NEXT:   if (!((some-condition))) return emitOpError("attribute 'cAttr' failed to satisfy constraint: some attribute kind");
+// DEF:      AOp::verify()
+// DEF:      auto tblgen_aAttr = this->getAttr("aAttr");
+// DEF-NEXT: if (!tblgen_aAttr) return emitOpError("requires attribute 'aAttr'");
+// DEF:        if (!((some-condition))) return emitOpError("attribute 'aAttr' failed to satisfy constraint: some attribute kind");
+// DEF:      auto tblgen_bAttr = this->getAttr("bAttr");
+// DEF-NEXT: if (tblgen_bAttr) {
+// DEF-NEXT:   if (!((some-condition))) return emitOpError("attribute 'bAttr' failed to satisfy constraint: some attribute kind");
+// DEF:      auto tblgen_cAttr = this->getAttr("cAttr");
+// DEF-NEXT: if (tblgen_cAttr) {
+// DEF-NEXT:   if (!((some-condition))) return emitOpError("attribute 'cAttr' failed to satisfy constraint: some attribute kind");
 
 def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">;
 
@@ -95,37 +100,37 @@ def BOp : NS_Op<"b_op", []> {
 // Test common attribute kind getters' return types
 // ---
 
-// CHECK: Attribute BOp::any_attr()
-// CHECK: bool BOp::bool_attr()
-// CHECK: APInt BOp::i32_attr()
-// CHECK: APInt BOp::i64_attr()
-// CHECK: APFloat BOp::f32_attr()
-// CHECK: APFloat BOp::f64_attr()
-// CHECK: StringRef BOp::str_attr()
-// CHECK: ElementsAttr BOp::elements_attr()
-// CHECK: StringRef BOp::function_attr()
-// CHECK: SomeType BOp::type_attr()
-// CHECK: ArrayAttr BOp::array_attr()
-// CHECK: ArrayAttr BOp::some_attr_array()
-// CHECK: Type BOp::type_attr()
+// DEF: Attribute BOp::any_attr()
+// DEF: bool BOp::bool_attr()
+// DEF: APInt BOp::i32_attr()
+// DEF: APInt BOp::i64_attr()
+// DEF: APFloat BOp::f32_attr()
+// DEF: APFloat BOp::f64_attr()
+// DEF: StringRef BOp::str_attr()
+// DEF: ElementsAttr BOp::elements_attr()
+// DEF: StringRef BOp::function_attr()
+// DEF: SomeType BOp::type_attr()
+// DEF: ArrayAttr BOp::array_attr()
+// DEF: ArrayAttr BOp::some_attr_array()
+// DEF: Type BOp::type_attr()
 
 // Test common attribute kinds' constraints
 // ---
 
-// CHECK-LABEL: BOp::verify
-// CHECK: if (!((true)))
-// CHECK: if (!((tblgen_bool_attr.isa<BoolAttr>())))
-// CHECK: if (!(((tblgen_i32_attr.isa<IntegerAttr>())) && ((tblgen_i32_attr.cast<IntegerAttr>().getType().isInteger(32)))))
-// CHECK: if (!(((tblgen_i64_attr.isa<IntegerAttr>())) && ((tblgen_i64_attr.cast<IntegerAttr>().getType().isInteger(64)))))
-// CHECK: if (!(((tblgen_f32_attr.isa<FloatAttr>())) && ((tblgen_f32_attr.cast<FloatAttr>().getType().isF32()))))
-// CHECK: if (!(((tblgen_f64_attr.isa<FloatAttr>())) && ((tblgen_f64_attr.cast<FloatAttr>().getType().isF64()))))
-// CHECK: if (!((tblgen_str_attr.isa<StringAttr>())))
-// CHECK: if (!((tblgen_elements_attr.isa<ElementsAttr>())))
-// CHECK: if (!((tblgen_function_attr.isa<FlatSymbolRefAttr>())))
-// CHECK: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<SomeType>()))))
-// CHECK: if (!((tblgen_array_attr.isa<ArrayAttr>())))
-// CHECK: if (!(((tblgen_some_attr_array.isa<ArrayAttr>())) && (llvm::all_of(tblgen_some_attr_array.cast<ArrayAttr>(), [](Attribute attr) { return (some-condition); }))))
-// CHECK: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<Type>()))))
+// DEF-LABEL: BOp::verify
+// DEF: if (!((true)))
+// DEF: if (!((tblgen_bool_attr.isa<BoolAttr>())))
+// DEF: if (!(((tblgen_i32_attr.isa<IntegerAttr>())) && ((tblgen_i32_attr.cast<IntegerAttr>().getType().isInteger(32)))))
+// DEF: if (!(((tblgen_i64_attr.isa<IntegerAttr>())) && ((tblgen_i64_attr.cast<IntegerAttr>().getType().isInteger(64)))))
+// DEF: if (!(((tblgen_f32_attr.isa<FloatAttr>())) && ((tblgen_f32_attr.cast<FloatAttr>().getType().isF32()))))
+// DEF: if (!(((tblgen_f64_attr.isa<FloatAttr>())) && ((tblgen_f64_attr.cast<FloatAttr>().getType().isF64()))))
+// DEF: if (!((tblgen_str_attr.isa<StringAttr>())))
+// DEF: if (!((tblgen_elements_attr.isa<ElementsAttr>())))
+// DEF: if (!((tblgen_function_attr.isa<FlatSymbolRefAttr>())))
+// DEF: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<SomeType>()))))
+// DEF: if (!((tblgen_array_attr.isa<ArrayAttr>())))
+// DEF: if (!(((tblgen_some_attr_array.isa<ArrayAttr>())) && (llvm::all_of(tblgen_some_attr_array.cast<ArrayAttr>(), [](Attribute attr) { return (some-condition); }))))
+// DEF: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<Type>()))))
 
 // Test building constant values for array attribute kinds
 // ---
@@ -140,12 +145,70 @@ def COp : NS_Op<"c_op", []> {
   );
 }
 
-// CHECK-LABEL: COp definitions
-// CHECK: mlir::Builder(this->getContext()).getI32ArrayAttr({1, 2})
-// CHECK: mlir::Builder(this->getContext()).getI64ArrayAttr({3, 4})
-// CHECK: mlir::Builder(this->getContext()).getF32ArrayAttr({5.f, 6.f})
-// CHECK: mlir::Builder(this->getContext()).getF64ArrayAttr({7., 8.})
-// CHECK: mlir::Builder(this->getContext()).getStrArrayAttr({"a", "b"})
+// DEF-LABEL: COp definitions
+// DEF: mlir::Builder(this->getContext()).getI32ArrayAttr({1, 2})
+// DEF: mlir::Builder(this->getContext()).getI64ArrayAttr({3, 4})
+// DEF: mlir::Builder(this->getContext()).getF32ArrayAttr({5.f, 6.f})
+// DEF: mlir::Builder(this->getContext()).getF64ArrayAttr({7., 8.})
+// DEF: mlir::Builder(this->getContext()).getStrArrayAttr({"a", "b"})
+
+
+// Test builder method which takes unwrapped values for attributes
+// ---
+
+def I32Case5:  I32EnumAttrCase<"case5", 5>;
+def I32Case10: I32EnumAttrCase<"case10", 10>;
+
+def SomeI32Enum: I32EnumAttr<
+  "SomeI32Enum", "", [I32Case5, I32Case10]>;
+
+def DOp : NS_Op<"d_op", []> {
+  let arguments = (ins
+    I32Attr:$i32_attr,
+    F64Attr:$f64_attr,
+    StrAttr:$str_attr,
+    BoolAttr:$bool_attr,
+    SomeI32Enum:$enum_attr,
+    DefaultValuedAttr<I32Attr, "42">:$dv_i32_attr,
+    DefaultValuedAttr<F64Attr, "8.">:$dv_f64_attr,
+    DefaultValuedAttr<StrAttr, "abc">:$dv_str_attr,
+    DefaultValuedAttr<BoolAttr, "true">:$dv_bool_attr,
+    DefaultValuedAttr<SomeI32Enum, "::SomeI32Enum::case5">:$dv_enum_attr
+  );
+}
+
+// DECL-LABEL: DOp declarations
+// DECL: static void build({{.*}}, APInt i32_attr, APFloat f64_attr,
+// DECL-SAME: StringRef str_attr, bool bool_attr, ::SomeI32Enum enum_attr,
+// DECL-SAME: APInt dv_i32_attr, APFloat dv_f64_attr,
+// DECL-SAME: StringRef dv_str_attr = "abc", bool dv_bool_attr = true,
+// DECL-SAME: ::SomeI32Enum dv_enum_attr = ::SomeI32Enum::case5)
+
+// Test that only default valued attributes at the end of the arguments
+// list get default values in the builder signature
+// ---
+
+def EOp : NS_Op<"e_op", []> {
+  let arguments = (ins
+    I32Attr:$i32_attr,
+    DefaultValuedAttr<I32Attr, "42">:$dv_i32_attr,
+    F64Attr:$f64_attr,
+    DefaultValuedAttr<F64Attr, "8.">:$dv_f64_attr,
+    StrAttr:$str_attr,
+    DefaultValuedAttr<StrAttr, "abc">:$dv_str_attr,
+    BoolAttr:$bool_attr,
+    DefaultValuedAttr<BoolAttr, "true">:$dv_bool_attr,
+    SomeI32Enum:$enum_attr,
+    DefaultValuedAttr<SomeI32Enum, "::SomeI32Enum::case5">:$dv_enum_attr
+  );
+}
+
+// DECL-LABEL: EOp declarations
+// DECL: static void build({{.*}}, APInt i32_attr, APInt dv_i32_attr,
+// DECL-SAME: APFloat f64_attr, APFloat dv_f64_attr,
+// DECL-SAME: StringRef str_attr, StringRef dv_str_attr,
+// DECL-SAME: bool bool_attr, bool dv_bool_attr,
+// DECL-SAME: ::SomeI32Enum enum_attr, ::SomeI32Enum dv_enum_attr = ::SomeI32Enum::case5)
 
 // Test mixing operands and attributes in arbitrary order
 // ---
@@ -154,12 +217,12 @@ def MixOperandsAndAttrs : NS_Op<"mix_operands_and_attrs", []> {
   let arguments = (ins F32Attr:$attr, F32:$operand, F32Attr:$otherAttr, F32:$otherArg);
 }
 
-// CHECK-LABEL: MixOperandsAndAttrs definitions
-// CHECK-DAG: Value *MixOperandsAndAttrs::operand()
-// CHECK-DAG: Value *MixOperandsAndAttrs::otherArg()
-// CHECK-DAG: void MixOperandsAndAttrs::build(Builder *, OperationState &tblgen_state, FloatAttr attr, Value *operand, FloatAttr otherAttr, Value *otherArg)
-// CHECK-DAG: APFloat MixOperandsAndAttrs::attr()
-// CHECK-DAG: APFloat MixOperandsAndAttrs::otherAttr()
+// DEF-LABEL: MixOperandsAndAttrs definitions
+// DEF-DAG: Value *MixOperandsAndAttrs::operand()
+// DEF-DAG: Value *MixOperandsAndAttrs::otherArg()
+// DEF-DAG: void MixOperandsAndAttrs::build(Builder *tblgen_builder, OperationState &tblgen_state, FloatAttr attr, Value *operand, FloatAttr otherAttr, Value *otherArg)
+// DEF-DAG: APFloat MixOperandsAndAttrs::attr()
+// DEF-DAG: APFloat MixOperandsAndAttrs::otherAttr()
 
 // Test unit attributes.
 // ---
@@ -168,8 +231,8 @@ def UnitAttrOp : NS_Op<"unit_attr_op", []> {
   let arguments = (ins UnitAttr:$attr);
 }
 
-// CHECK-LABEL: UnitAttrOp definitions
-// CHECK: bool UnitAttrOp::attr() {
-// CHECK:   return {{.*}} != nullptr
+// DEF-LABEL: UnitAttrOp definitions
+// DEF: bool UnitAttrOp::attr() {
+// DEF:   return {{.*}} != nullptr
 
-// CHECK: build(Builder *, OperationState &tblgen_state, /*optional*/UnitAttr attr)
+// DEF: build(Builder *tblgen_builder, OperationState &tblgen_state, /*optional*/UnitAttr attr)
index 672cfef..e66ea43 100644 (file)
@@ -66,7 +66,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> {
 // CHECK:   APInt attr1();
 // CHECK:   Optional< APFloat > attr2();
 // CHECK:   static void build(Value *val);
-// CHECK:   static void build(Builder *, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value *a, ArrayRef<Value *> b, IntegerAttr attr1, /*optional*/FloatAttr attr2);
+// CHECK:   static void build(Builder *tblgen_builder, OperationState &tblgen_state, Type r, ArrayRef<Type> s, Value *a, ArrayRef<Value *> b, IntegerAttr attr1, /*optional*/FloatAttr attr2);
 // CHECK:   static void build(Builder *, OperationState &tblgen_state, ArrayRef<Type> resultTypes, ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes);
 // CHECK:   static ParseResult parse(OpAsmParser &parser, OperationState &result);
 // CHECK:   void print(OpAsmPrinter &p);
index 9979480..007d747 100644 (file)
@@ -23,9 +23,9 @@ def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
 }
 
 // CHECK-LABEL: OpB definitions
-// CHECK: void OpB::build(Builder *, OperationState &tblgen_state, Type y, Value *x)
+// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Type y, Value *x)
 // CHECK:   tblgen_state.addTypes(y);
-// CHECK: void OpB::build(Builder *, OperationState &tblgen_state, Value *x)
+// CHECK: void OpB::build(Builder *tblgen_builder, OperationState &tblgen_state, Value *x)
 // CHECK:   tblgen_state.addTypes({x->getType()});
 
 def OpC : NS_Op<"three_normal_result_op", []> {
@@ -33,7 +33,7 @@ def OpC : NS_Op<"three_normal_result_op", []> {
 }
 
 // CHECK-LABEL: OpC definitions
-// CHECK:       void OpC::build(Builder *, OperationState &tblgen_state, Type x, Type resultType1, Type z)
+// CHECK:       void OpC::build(Builder *tblgen_builder, OperationState &tblgen_state, Type x, Type resultType1, Type z)
 // CHECK-NEXT:   tblgen_state.addTypes(x)
 // CHECK-NEXT:   tblgen_state.addTypes(resultType1)
 // CHECK-NEXT:   tblgen_state.addTypes(z)
@@ -73,7 +73,7 @@ def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
 
 // CHECK-LABEL: OpG definitions
 
-// CHECK:      void OpG::build(Builder *, OperationState &tblgen_state, Type x, ArrayRef<Type> y)
+// CHECK:      void OpG::build(Builder *tblgen_builder, OperationState &tblgen_state, Type x, ArrayRef<Type> y)
 // CHECK-NEXT:   tblgen_state.addTypes(x);
 // CHECK-NEXT:   tblgen_state.addTypes(y);
 
@@ -105,5 +105,5 @@ def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameOperandsA
   let results = (outs AnyTensor:$result);
 }
 
-// CHECK-LABEL: OpK::build(Builder *, OperationState &tblgen_state, ArrayRef<Value *> input)
+// CHECK-LABEL: OpK::build(Builder *tblgen_builder, OperationState &tblgen_state, ArrayRef<Value *> input)
 // CHECK: tblgen_state.addTypes({input.front()->getType()});
index 864f773..dcecd1c 100644 (file)
@@ -32,6 +32,8 @@
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 
+#define DEBUG_TYPE "mlir-tblgen-opdefgen"
+
 using namespace llvm;
 using namespace mlir;
 using namespace mlir::tblgen;
@@ -113,6 +115,14 @@ static std::string getArgumentName(const Operator &op, int index) {
     return formatv("{0}_{1}", generatedArgName, index);
 }
 
+// Returns true if we can use unwrapped value for the given `attr` in builders.
+static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
+  return attr.getReturnType() != attr.getStorageType() &&
+         // We need to wrap the raw value into an attribute in the builder impl
+         // so we need to make sure that the attribute specifies how to do that.
+         !attr.getConstBuilderTemplate().empty();
+}
+
 namespace {
 // Simple RAII helper for defining ifdef-undef-endif scopes.
 class IfDefScope {
@@ -506,46 +516,66 @@ private:
   void genBuilder();
 
   // Generates the build() method that takes each result-type/operand/attribute
-  // as a stand-alone parameter. This build() method also requires specifying
-  // result types for all results.
-  void genSeparateParamBuilder();
+  // as a stand-alone parameter. Attributes will take wrapped mlir::Attribute
+  // values. The generated build() method also requires specifying result types
+  // for all results.
+  void genSeparateParamWrappedAttrBuilder();
+
+  // Generates the build() method that takes each result-type/operand/attribute
+  // as a stand-alone parameter. Attributes will take raw values without
+  // mlir::Attribute wrapper. The generated build() method also requires
+  // specifying result types for all results.
+  void genSeparateParamUnwrappedAttrBuilder();
 
   // Generates the build() method that takes a single parameter for all the
   // result types and a separate parameter for each operand/attribute.
   void genCollectiveTypeParamBuilder();
 
   // Generates the build() method that takes each operand/attribute as a
-  // stand-alone parameter. This build() method uses first operand's type
-  // as all results' types.
+  // stand-alone parameter. The generated build() method uses first operand's
+  // type as all results' types.
   void genUseOperandAsResultTypeSeparateParamBuilder();
 
   // Generates the build() method that takes all operands/attributes
-  // collectively as one parameter. This build() method uses first operand's
-  // type as all results' types.
+  // collectively as one parameter. The generated build() method uses first
+  // operand's type as all results' types.
   void genUseOperandAsResultTypeCollectiveParamBuilder();
 
   // Generates the build() method that takes each operand/attribute as a
-  // stand-alone parameter. This build() method uses first attribute's type
-  // as all result's types.
+  // stand-alone parameter. The generated build() method uses first attribute's
+  // type as all result's types.
   void genUseAttrAsResultTypeBuilder();
 
   // Generates the build() method that takes all result types collectively as
   // one parameter. Similarly for operands and attributes.
   void genCollectiveParamBuilder();
 
-  enum class TypeParamKind { None, Separate, Collective };
+  // The kind of parameter to generate for result types in builders.
+  enum class TypeParamKind {
+    None,       // No result type in parameter list.
+    Separate,   // A separate parameter for each result type.
+    Collective, // An ArrayRef<Type> for all result types.
+  };
+
+  // The kind of parameter to generate for attributes in builders.
+  enum class AttrParamKind {
+    WrappedAttr,    // A wrapped MLIR Attribute instance.
+    UnwrappedValue, // A raw value without MLIR Attribute wrapper.
+  };
 
   // Builds the parameter list for build() method of this op. This method writes
-  // to `paramList` the comma-separated parameter list. If `includeResultTypes`
-  // is true then `paramList` will also contain the parameters for all results
-  // and `resultTypeNames` will be populated with the parameter name for each
-  // result type.
+  // to `paramList` the comma-separated parameter list and updates
+  // `resultTypeNames` with the names for parameters for specifying result
+  // types. The given `typeParamKind` and `attrParamKind` controls how result
+  // types and attributes are placed in the parameter list.
   void buildParamList(std::string &paramList,
                       SmallVectorImpl<std::string> &resultTypeNames,
-                      TypeParamKind kind);
+                      TypeParamKind typeParamKind,
+                      AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
 
   // Adds op arguments and regions into operation state for build() methods.
-  void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body);
+  void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
+                                              bool isRawValueAttr = false);
 
   // Generates canonicalizer declaration for the operation.
   void genCanonicalizerDecls();
@@ -650,18 +680,18 @@ void OpEmitter::genAttrGetters() {
 
     // Return the queried attribute with the correct return type.
     auto attrVal =
-        (attr.hasDefaultValueInitializer() || attr.isOptional())
+        (attr.hasDefaultValue() || attr.isOptional())
             ? formatv("this->getAttr(\"{0}\").dyn_cast_or_null<{1}>()", name,
                       attr.getStorageType())
             : formatv("this->getAttr(\"{0}\").cast<{1}>()", name,
                       attr.getStorageType());
     body << "  auto attr = " << attrVal << ";\n";
-    if (attr.hasDefaultValueInitializer()) {
+    if (attr.hasDefaultValue()) {
       // Returns the default value if not set.
       // TODO: this is inefficient, we are recreating the attribute for every
       // call. This should be set instead.
-      std::string defaultValue = tgfmt(attr.getConstBuilderTemplate(), &fctx,
-                                       attr.getDefaultValueInitializer());
+      std::string defaultValue =
+          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue());
       body << "    if (!attr)\n      return "
            << tgfmt(attr.getConvertFromStorageCall(),
                     &fctx.withSelf(defaultValue))
@@ -847,7 +877,7 @@ void OpEmitter::genNamedRegionGetters() {
   }
 }
 
-void OpEmitter::genSeparateParamBuilder() {
+void OpEmitter::genSeparateParamWrappedAttrBuilder() {
   std::string paramList;
   llvm::SmallVector<std::string, 4> resultNames;
   buildParamList(paramList, resultNames, TypeParamKind::Separate);
@@ -862,6 +892,42 @@ void OpEmitter::genSeparateParamBuilder() {
   }
 }
 
+void OpEmitter::genSeparateParamUnwrappedAttrBuilder() {
+  // If this op does not have native attributes at all, return directly to avoid
+  // redefining builders.
+  if (op.getNumNativeAttributes() == 0)
+    return;
+
+  bool canGenerate = false;
+  // We are generating builders that take raw values for attributes. We need to
+  // make sure the native attributes have a meaningful "unwrapped" value type
+  // different from the wrapped mlir::Attribute type to avoid redefining
+  // builders. This checks for the op has at least one such native attribute.
+  for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
+    NamedAttribute &namedAttr = op.getAttribute(i);
+    if (canUseUnwrappedRawValue(namedAttr.attr)) {
+      canGenerate = true;
+      break;
+    }
+  }
+  if (!canGenerate)
+    return;
+
+  std::string paramList;
+  llvm::SmallVector<std::string, 4> resultNames;
+  buildParamList(paramList, resultNames, TypeParamKind::Separate,
+                 AttrParamKind::UnwrappedValue);
+
+  auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
+  genCodeForAddingArgAndRegionForBuilder(m.body(), /*isRawValueAttr=*/true);
+
+  // Push all result types to the operation state.
+  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
+    m.body() << "  " << builderOpState << ".addTypes(" << resultNames[i]
+             << ");\n";
+  }
+}
+
 void OpEmitter::genCollectiveTypeParamBuilder() {
   auto numResults = op.getNumResults();
 
@@ -1006,7 +1072,8 @@ void OpEmitter::genBuilder() {
   // We generate three builders here:
   // 1. one having a stand-alone parameter for each result type / operand /
   //    attribute, and
-  genSeparateParamBuilder();
+  genSeparateParamWrappedAttrBuilder();
+  genSeparateParamUnwrappedAttrBuilder();
   // 2. one having a stand-alone parameter for each operand / attribute and
   //    an aggregated parameter for all result types, and
   genCollectiveTypeParamBuilder();
@@ -1069,15 +1136,16 @@ void OpEmitter::genCollectiveParamBuilder() {
 
 void OpEmitter::buildParamList(std::string &paramList,
                                SmallVectorImpl<std::string> &resultTypeNames,
-                               TypeParamKind kind) {
+                               TypeParamKind typeParamKind,
+                               AttrParamKind attrParamKind) {
   resultTypeNames.clear();
   auto numResults = op.getNumResults();
   resultTypeNames.reserve(numResults);
 
-  paramList = "Builder *, OperationState &";
+  paramList = "Builder *tblgen_builder, OperationState &";
   paramList.append(builderOpState);
 
-  switch (kind) {
+  switch (typeParamKind) {
   case TypeParamKind::None:
     break;
   case TypeParamKind::Separate: {
@@ -1100,10 +1168,36 @@ void OpEmitter::buildParamList(std::string &paramList,
   } break;
   }
 
+  // Add parameters for all arguments (operands and attributes).
+
   int numOperands = 0;
   int numAttrs = 0;
 
-  // Add parameters for all arguments (operands and attributes).
+  int defaultValuedAttrStartIndex = op.getNumArgs();
+  if (attrParamKind == AttrParamKind::UnwrappedValue) {
+    // Calculate the start index from which we can attach default values in the
+    // builder declaration.
+    for (int i = op.getNumArgs() - 1; i >= 0; --i) {
+      auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
+      if (!namedAttr || !namedAttr->attr.hasDefaultValue())
+        break;
+
+      if (!canUseUnwrappedRawValue(namedAttr->attr))
+        break;
+
+      // Creating an APInt requires us to provide bitwidth, value, and
+      // signedness, which is complicated compared to others. Similarly
+      // for APFloat.
+      // TODO(b/144412160) Adjust the 'returnType' field of such attributes
+      // to support them.
+      StringRef retType = namedAttr->attr.getReturnType();
+      if (retType == "APInt" || retType == "APFloat")
+        break;
+
+      defaultValuedAttrStartIndex = i;
+    }
+  }
+
   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
     auto argument = op.getArg(i);
     if (argument.is<tblgen::NamedTypeConstraint *>()) {
@@ -1113,24 +1207,46 @@ void OpEmitter::buildParamList(std::string &paramList,
       paramList.append(getArgumentName(op, numOperands));
       ++numOperands;
     } else {
-      // TODO(antiagainst): Support default initializer for attributes
       const auto &namedAttr = op.getAttribute(numAttrs);
       const auto &attr = namedAttr.attr;
       paramList.append(", ");
+
       if (attr.isOptional())
         paramList.append("/*optional*/");
-      paramList.append(attr.getStorageType());
+
+      switch (attrParamKind) {
+      case AttrParamKind::WrappedAttr:
+        paramList.append(attr.getStorageType());
+        break;
+      case AttrParamKind::UnwrappedValue:
+        if (canUseUnwrappedRawValue(attr)) {
+          paramList.append(attr.getReturnType());
+        } else {
+          paramList.append(attr.getStorageType());
+        }
+        break;
+      }
       paramList.append(" ");
       paramList.append(namedAttr.name);
+
+      // Attach default value if requested and possible.
+      if (attrParamKind == AttrParamKind::UnwrappedValue &&
+          i >= defaultValuedAttrStartIndex) {
+        bool isString = attr.getReturnType() == "StringRef";
+        paramList.append(" = ");
+        if (isString)
+          paramList.append("\"");
+        paramList.append(attr.getDefaultValue());
+        if (isString)
+          paramList.append("\"");
+      }
       ++numAttrs;
     }
   }
-
-  if (numOperands + numAttrs != op.getNumArgs())
-    PrintFatalError("op arguments must be either operands or attributes");
 }
 
-void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body) {
+void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
+                                                       bool isRawValueAttr) {
   // Push all operands to the result
   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
     body << "  " << builderOpState << ".addOperands(" << getArgumentName(op, i)
@@ -1139,13 +1255,25 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body) {
 
   // Push all attributes to the result
   for (const auto &namedAttr : op.getAttributes()) {
-    if (!namedAttr.attr.isDerivedAttr()) {
-      bool emitNotNullCheck = namedAttr.attr.isOptional();
+    auto &attr = namedAttr.attr;
+    if (!attr.isDerivedAttr()) {
+      bool emitNotNullCheck = attr.isOptional();
       if (emitNotNullCheck) {
         body << formatv("  if ({0}) ", namedAttr.name) << "{\n";
       }
-      body << formatv("  {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
-                      namedAttr.name);
+      if (isRawValueAttr and canUseUnwrappedRawValue(attr)) {
+        // If this is a raw value, then we need to wrap it in an Attribute
+        // instance.
+        FmtContext fctx;
+        fctx.withBuilder("(*tblgen_builder)");
+        std::string value =
+            tgfmt(attr.getConstBuilderTemplate(), &fctx, namedAttr.name);
+        body << formatv("  {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
+                        namedAttr.name, value);
+      } else {
+        body << formatv("  {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
+                        namedAttr.name);
+      }
       if (emitNotNullCheck) {
         body << "  }\n";
       }
@@ -1282,8 +1410,7 @@ void OpEmitter::genVerifier() {
     body << formatv("  auto {0} = this->getAttr(\"{1}\");\n", varName,
                     attrName);
 
-    bool allowMissingAttr =
-        attr.hasDefaultValueInitializer() || attr.isOptional();
+    bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
     if (allowMissingAttr) {
       // If the attribute has a default value, then only verify the predicate if
       // set. This does effectively assume that the default value is valid.
index a2cace7..d321b20 100644 (file)
@@ -342,10 +342,10 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
       attr.getStorageType(), namedAttr->name);
 
   // TODO(antiagainst): This should use getter method to avoid duplication.
-  if (attr.hasDefaultValueInitializer()) {
+  if (attr.hasDefaultValue()) {
     os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
                       << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
-                               attr.getDefaultValueInitializer())
+                               attr.getDefaultValue())
                       << ";\n";
   } else if (attr.isOptional()) {
     // For a missing attribute that is optional according to definition, we