Add methods for building array attributes in Builder
authorLei Zhang <antiagainst@google.com>
Fri, 5 Apr 2019 19:19:22 +0000 (12:19 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 8 Apr 2019 01:19:56 +0000 (18:19 -0700)
    I32/I64/F32/F64/Str array attributes are commonly used in ops. It helps
    to have handy methods for them.

--

PiperOrigin-RevId: 242170569

mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/OpBase.td
mlir/lib/IR/Builders.cpp
mlir/lib/TableGen/Attribute.cpp
mlir/test/mlir-tblgen/op-attribute.td

index 4cbf4ee..a3f8ad5 100644 (file)
@@ -132,6 +132,12 @@ public:
   IntegerAttr getI32IntegerAttr(int32_t value);
   IntegerAttr getI64IntegerAttr(int64_t value);
 
+  ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
+  ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
+  ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
+  ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
+  ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
+
   // Affine expressions and affine maps.
   AffineExpr getAffineDimExpr(unsigned position);
   AffineExpr getAffineSymbolExpr(unsigned position);
index 791e99c..602d50c 100644 (file)
@@ -550,11 +550,22 @@ class TypedArrayAttrBase<Attr element, string description>: ArrayAttrBase<
 }
 
 def I32ArrayAttr : TypedArrayAttrBase<I32Attr,
-                                      "32-bit integer array attribute">;
+                                      "32-bit integer array attribute"> {
+  let constBuilderCall = "{0}.getI32ArrayAttr({1})";
+}
 def I64ArrayAttr : TypedArrayAttrBase<I64Attr,
-                                      "64-bit integer array attribute">;
-def F32ArrayAttr : TypedArrayAttrBase<F32Attr, "32-bit float array attribute">;
-def StrArrayAttr : TypedArrayAttrBase<StrAttr, "string array attribute">;
+                                      "64-bit integer array attribute"> {
+  let constBuilderCall = "{0}.getI64ArrayAttr({1})";
+}
+def F32ArrayAttr : TypedArrayAttrBase<F32Attr, "32-bit float array attribute"> {
+  let constBuilderCall = "{0}.getF32ArrayAttr({1})";
+}
+def F64ArrayAttr : TypedArrayAttrBase<F64Attr, "64-bit float array attribute"> {
+  let constBuilderCall = "{0}.getF64ArrayAttr({1})";
+}
+def StrArrayAttr : TypedArrayAttrBase<StrAttr, "string array attribute"> {
+  let constBuilderCall = "{0}.getStrArrayAttr({1})";
+}
 
 // Attributes containing functions.
 def FunctionAttr : Attr<CPred<"{0}.isa<FunctionAttr>()">,
index a0d9367..962fa34 100644 (file)
@@ -23,6 +23,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/Functional.h"
 using namespace mlir;
 
 Builder::Builder(Module *module) : context(module->getContext()) {}
@@ -202,6 +203,36 @@ ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect,
   return OpaqueElementsAttr::get(dialect, type, bytes);
 }
 
+ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
+  auto attrs = functional::map(
+      [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }, values);
+  return getArrayAttr(attrs);
+}
+
+ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
+  auto attrs = functional::map(
+      [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }, values);
+  return getArrayAttr(attrs);
+}
+
+ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
+  auto attrs = functional::map(
+      [this](float v) -> Attribute { return getF32FloatAttr(v); }, values);
+  return getArrayAttr(attrs);
+}
+
+ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
+  auto attrs = functional::map(
+      [this](double v) -> Attribute { return getF64FloatAttr(v); }, values);
+  return getArrayAttr(attrs);
+}
+
+ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
+  auto attrs = functional::map(
+      [this](StringRef v) -> Attribute { return getStringAttr(v); }, values);
+  return getArrayAttr(attrs);
+}
+
 Attribute Builder::getZeroAttr(Type type) {
   switch (type.getKind()) {
   case StandardTypes::F32:
index 26c2204..3791d2c 100644 (file)
@@ -103,9 +103,15 @@ bool tblgen::Attribute::isOptional() const {
 
 std::string tblgen::Attribute::getDefaultValueTemplate() const {
   assert(isConstBuildable() && "requiers constBuilderCall");
-  const auto *init = def->getValueInit("defaultValue");
+  StringRef defaultValue = getValueAsString(def->getValueInit("defaultValue"));
+  // TODO(antiagainst): This is a temporary hack to support array initializers
+  // because '{' is the special marker for placeholders for formatv. Remove this
+  // after switching to our own formatting utility and $-placeholders.
+  bool needsEscape =
+      defaultValue.startswith("{") && !defaultValue.startswith("{{");
+
   return llvm::formatv(getConstBuilderTemplate().str().c_str(), "{0}",
-                       getValueAsString(init));
+                       needsEscape ? "{" + defaultValue : defaultValue);
 }
 
 StringRef tblgen::Attribute::getTableGenDefName() const {
index 17c88dc..4b8a881 100644 (file)
@@ -9,6 +9,9 @@ def SomeAttr : Attr<CPred<"some-condition">, "some attribute kind"> {
   let constBuilderCall = "some-const-builder-call({0}, {1})";
 }
 
+// Test required, optional, default-valued attributes
+// ---
+
 def AOp : Op<"a_op", []> {
   let arguments = (ins
       SomeAttr:$aAttr,
@@ -100,6 +103,28 @@ def BOp : Op<"b_op", []> {
 // CHECK: if (!((tblgen_array_attr.isa<ArrayAttr>())))
 // CHECK: if (!(((tblgen_some_attr_array.isa<ArrayAttr>())) && (llvm::all_of(tblgen_some_attr_array.cast<ArrayAttr>(), [](Attribute attr) { return (some-condition); }))))
 
+// Test building constant values for array attribute kinds
+// ---
+
+def COp : Op<"c_op", []> {
+  let arguments = (ins
+    DefaultValuedAttr<I32ArrayAttr, "{1, 2}">:$i32_array_attr,
+    DefaultValuedAttr<I64ArrayAttr, "{3, 4}">:$i64_array_attr,
+    DefaultValuedAttr<F32ArrayAttr, "{5.f, 6.f}">:$f32_array_attr,
+    DefaultValuedAttr<F64ArrayAttr, "{7., 8.}">:$f64_array_attr,
+    DefaultValuedAttr<StrArrayAttr, "{\"a\", \"b\"}">:$str_array_attr
+  );
+}
+
+// CHECK-LABEL: COp definitions
+// CHECK: mlir::Builder(this->getContext()).getI32ArrayAttr({1, 2})
+// CHECK: mlir::Builder(this->getContext()).getI64ArrayAttr({3, 4})
+// CHECK: mlir::Builder(this->getContext()).getF32ArrayAttr({5.f, 6.f})
+// CHECK: mlir::Builder(this->getContext()).getF64ArrayAttr({7., 8.})
+// CHECK: mlir::Builder(this->getContext()).getStrArrayAttr({"a", "b"})
+
+// Test mixing operands and attributes in arbitrary order
+// ---
 
 def MixOperandsAndAttrs : Op<"mix_operands_and_attrs", []> {
   let arguments = (ins F32Attr:$attr, F32:$operand, F32Attr:$otherAttr, F32:$otherArg);