[TableGen] Unify cOp and tAttr into NativeCodeCall
authorLei Zhang <antiagainst@google.com>
Mon, 22 Apr 2019 21:13:45 +0000 (14:13 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Wed, 24 Apr 2019 05:02:00 +0000 (22:02 -0700)
    Both cOp and tAttr were used to perform some native C++ code expression.
    Unifying them simplifies the concepts and reduces cognitive burden.

--

PiperOrigin-RevId: 244731946

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Pattern.h
mlir/lib/TableGen/Pattern.cpp
mlir/test/mlir-tblgen/pattern-NativeCodeCall.td [new file with mode: 0644]
mlir/test/mlir-tblgen/pattern-bound-symbol.td
mlir/test/mlir-tblgen/pattern-tAttr.td [deleted file]
mlir/tools/mlir-tblgen/RewriterGen.cpp

index 9c00505..8e9fb57 100644 (file)
@@ -984,42 +984,25 @@ class Pat<dag pattern, dag result, list<dag> preds = [],
   dag benefitAdded = (addBenefit 0)> :
   Pattern<pattern, [result], preds, benefitAdded>;
 
-// Attribute transformation. This is the base class to specify a transformation
-// of matched attributes. Used on the output attribute of a rewrite rule.
+// Native code call wrapper. This allows invoking an arbitrary C++ expression
+// to create an op operand/attribute or replace an op result.
 //
 // ## Placeholders
 //
-// The following special placeholders are supported
+// If used as a DAG leaf, i.e., `(... NativeCodeCall<"...">:$arg, ...)`,
+// the wrapped expression can take special placeholders listed below:
 //
 // * `$_builder` will be replaced by the current `mlir::PatternRewriter`.
 // * `$_self` will be replaced with the entity this transformer is attached to.
 //   E.g., with the definition `def transform : tAttr<$_self...>`, `$_self` in
 //   `transform:$attr` will be replaced by  the value for `$att`.
-
-// Besides, if this is used as a DAG node, i.e., `(tAttr <arg0>, ..., <argN>)`,
-// then positional placeholders are supported and placholder `$N` will be
-// replaced by `<argN>`.
-class tAttr<code transform> {
-  code attrTransform = transform;
-}
-
-// Native code op creation method. This allows performing an arbitrary op
-// creation/replacement by invoking a C++ function with the operands and
-// attributes. The function specified needs to have the signature:
 //
-//   void f(Operation *op, ArrayRef<Value *> operands,
-//          ArrayRef<Attribute> attrs, PatternRewriter &rewriter);
-//
-// The operands and attributes are passed to this function in the order of
-// the DAG specified. It is the responsibility of this function to replace the
-// matched op(s) using the rewriter. This is intended for the long tail op
-// creation and replacement.
-// TODO(antiagainst): Unify this and tAttr into a single creation mechanism.
-class cOp<string f> {
-  // Function to invoke with the given arguments to construct a new op. The
-  // operands will be passed to the function first followed by the attributes
-  // (as in the function signature above and required by Op arguments).
-  string function = f;
+// If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`,
+// then positional placeholders are also supported; placeholder `$N` in the
+// wrapped C++ expression will be replaced by `<argN>`.
+
+class NativeCodeCall<string expr> {
+  string expression = expr;
 }
 
 //===----------------------------------------------------------------------===//
index 22bc7b3..e833e49 100644 (file)
@@ -77,8 +77,8 @@ public:
   // specifies an attribute constraint.
   bool isAttrMatcher() const;
 
-  // Returns true if this DAG leaf is transforming an attribute.
-  bool isAttrTransformer() const;
+  // Returns true if this DAG leaf is wrapping native code call.
+  bool isNativeCodeCall() const;
 
   // Returns true if this DAG leaf is specifying a constant attribute.
   bool isConstantAttr() const;
@@ -100,9 +100,9 @@ public:
   // leaf is an operand/attribute matcher and asserts otherwise.
   std::string getConditionTemplate() const;
 
-  // Returns the transformation template inside this DAG leaf. Assumes the
-  // leaf is an attribute transformation and asserts otherwise.
-  std::string getTransformationTemplate() const;
+  // Returns the native code call template inside this DAG leaf.
+  // Precondition: isNativeCodeCall()
+  llvm::StringRef getNativeCodeTemplate() const;
 
 private:
   // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
@@ -162,10 +162,6 @@ public:
   // Returns the specified name of the `index`-th argument.
   llvm::StringRef getArgName(unsigned index) const;
 
-  // Returns the native builder for the pattern.
-  // Precondition: isNativeCodeBuilder.
-  llvm::StringRef getNativeCodeBuilder() const;
-
   // Returns true if this DAG construct means to replace with an existing SSA
   // value.
   bool isReplaceWithValue() const;
@@ -173,16 +169,12 @@ public:
   // Returns true if this DAG node is the `verifyUnusedValue` directive.
   bool isVerifyUnusedValue() const;
 
-  // Returns true if this DAG construct is meant to invoke a native code
-  // constructor.
-  bool isNativeCodeBuilder() const;
-
-  // Returns true if this DAG construct is transforming attributes.
-  bool isAttrTransformer() const;
+  // Returns true if this DAG node is wrapping native code call.
+  bool isNativeCodeCall() const;
 
-  // Returns the transformation template inside this DAG construct.
-  // Precondition: isAttrTransformer.
-  std::string getTransformationTemplate() const;
+  // Returns the native code call template inside this DAG node.
+  // Precondition: isNativeCodeCall()
+  llvm::StringRef getNativeCodeTemplate() const;
 
 private:
   const llvm::DagInit *node; // nullptr means null DagNode
index 92267d1..420f9d2 100644 (file)
@@ -44,8 +44,8 @@ bool tblgen::DagLeaf::isAttrMatcher() const {
   return isSubClassOf("AttrConstraint");
 }
 
-bool tblgen::DagLeaf::isAttrTransformer() const {
-  return isSubClassOf("tAttr");
+bool tblgen::DagLeaf::isNativeCodeCall() const {
+  return isSubClassOf("NativeCodeCall");
 }
 
 bool tblgen::DagLeaf::isConstantAttr() const {
@@ -76,12 +76,9 @@ std::string tblgen::DagLeaf::getConditionTemplate() const {
   return getAsConstraint().getConditionTemplate();
 }
 
-std::string tblgen::DagLeaf::getTransformationTemplate() const {
-  assert(isAttrTransformer() && "the DAG leaf must be attribute transformer");
-  return cast<llvm::DefInit>(def)
-      ->getDef()
-      ->getValueAsString("attrTransform")
-      .str();
+llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const {
+  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
+  return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
 }
 
 bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
@@ -90,19 +87,17 @@ bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
   return false;
 }
 
-bool tblgen::DagNode::isAttrTransformer() const {
-  auto op = node->getOperator();
-  if (!op || !isa<llvm::DefInit>(op))
-    return false;
-  return cast<llvm::DefInit>(op)->getDef()->isSubClassOf("tAttr");
+bool tblgen::DagNode::isNativeCodeCall() const {
+  if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
+    return defInit->getDef()->isSubClassOf("NativeCodeCall");
+  return false;
 }
 
-std::string tblgen::DagNode::getTransformationTemplate() const {
-  assert(isAttrTransformer() && "the DAG leaf must be attribute transformer");
+llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
+  assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
   return cast<llvm::DefInit>(node->getOperator())
       ->getDef()
-      ->getValueAsString("attrTransform")
-      .str();
+      ->getValueAsString("expression");
 }
 
 llvm::StringRef tblgen::DagNode::getOpName() const {
@@ -156,17 +151,6 @@ bool tblgen::DagNode::isVerifyUnusedValue() const {
   return dagOpDef->getName() == "verifyUnusedValue";
 }
 
-bool tblgen::DagNode::isNativeCodeBuilder() const {
-  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
-  return dagOpDef->isSubClassOf("cOp");
-}
-
-llvm::StringRef tblgen::DagNode::getNativeCodeBuilder() const {
-  assert(isNativeCodeBuilder());
-  auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
-  return dagOpDef->getValueAsString("function");
-}
-
 tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
     : def(*def), recordOpMap(mapper) {
   collectBoundArguments(getSourcePattern());
diff --git a/mlir/test/mlir-tblgen/pattern-NativeCodeCall.td b/mlir/test/mlir-tblgen/pattern-NativeCodeCall.td
new file mode 100644 (file)
index 0000000..317284d
--- /dev/null
@@ -0,0 +1,35 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def CreateOperand : NativeCodeCall<"buildOperand($0, $1)">;
+def CreateArrayAttr : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
+def CreateOpResult : NativeCodeCall<"buildOp($0, $1)">;
+
+def NS_AOp : Op<"a_op", []> {
+  let arguments = (ins I32:$input1, I32:$input2, I32Attr:$attr);
+  let results = (outs I32:$output);
+}
+
+def NS_BOp : Op<"b_op", []> {
+  let arguments = (ins I32:$input, I32Attr:$attr);
+  let results = (outs I32:$output);
+}
+
+def TestCreateOpResult : Pat<
+  (NS_BOp $input, $attr),
+  (CreateOpResult $input, $attr)>;
+
+// CHECK-LABEL: TestCreateOpResult
+
+// CHECK: rewriter.replaceOp(op, {buildOp(s.input, s.attr)});
+
+def TestCreateOperandAndAttr : Pat<
+  (NS_AOp $input1, $input2, $attr),
+  (NS_BOp (CreateOperand $input1, $input2), (CreateArrayAttr $attr, $attr))>;
+
+// CHECK-LABEL: TestCreateOperandAndAttr
+
+// CHECK:      rewriter.create<NS::BOp>
+// CHECK-NEXT: buildOperand(s.input1, s.input2),
+// CHECK-NEXT: rewriter.getArrayAttr({s.attr, s.attr})
index 0d1a53e..805a032 100644 (file)
@@ -23,10 +23,11 @@ def OpD : Op<"op_d", []> {
 }
 
 def hasOneUse: Constraint<CPred<"$0->hasOneUse()">, "has one use">;
+def getResult0 : NativeCodeCall<"$_self->getResult(0)">;
 
 def : Pattern<(OpA:$res_a $operand, $attr),
               [(OpC:$res_c (OpB:$res_b $operand)),
-               (OpD $res_b, $res_c, $res_a, $attr)],
+               (OpD $res_b, $res_c, getResult0:$res_a, $attr)],
               [(hasOneUse $res_a)]>;
 
 // CHECK-LABEL: GeneratedConvert0
@@ -59,5 +60,5 @@ def : Pattern<(OpA:$res_a $operand, $attr),
 // CHECK:   auto vOpD0 = rewriter.create<OpD>(
 // CHECK:     /*input1=*/res_b,
 // CHECK:     /*input2=*/res_c,
-// CHECK:     /*input3=*/s.res_a,
+// CHECK:     /*input3=*/s.res_a->getResult(0),
 // CHECK:     /*attr=*/s.attr
diff --git a/mlir/test/mlir-tblgen/pattern-tAttr.td b/mlir/test/mlir-tblgen/pattern-tAttr.td
deleted file mode 100644 (file)
index 02a1256..0000000
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
-
-include "mlir/IR/OpBase.td"
-
-// Create a Type and Attribute.
-def T : BuildableType<"buildT()">;
-def T_Attr : TypedAttrBase<T, "Attribute",CPred<"true">, "attribute of T type">;
-def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">;
-def T_Compose_Attr : tAttr<"$_builder.getArrayAttr({$0, $1})">;
-
-// Define ops to rewrite.
-def Y_Op : Op<"y.op"> {
-  let arguments = (ins T_Attr:$attrName);
-  let results = (outs I32:$result);
-}
-def Z_Op : Op<"z.op"> {
-  let arguments = (ins T_Attr:$attrName1, T_Attr:$attrName2);
-  let results = (outs I32:$result);
-}
-
-// Define rewrite pattern.
-def : Pat<(Y_Op $attr1), (Y_Op (T_Compose_Attr $attr1, T_Const_Attr))>;
-// CHECK-LABEL: struct GeneratedConvert0
-// CHECK: void rewrite(
-// CHECK:   /*attrName=*/rewriter.getArrayAttr({s.attr1, rewriter.getAttribute(rewriter.buildT(), attrValue)})
-
-def : Pat<(Z_Op $attr1, $attr2), (Y_Op (T_Compose_Attr $attr1, $attr2))>;
-// CHECK-LABEL: struct GeneratedConvert1
-// CHECK: void rewrite(
-// CHECK:   /*attrName=*/rewriter.getArrayAttr({s.attr1, s.attr2})
index 797d336..bbc961f 100644 (file)
@@ -166,9 +166,9 @@ private:
   std::string handleRewritePattern(DagNode resultTree, int resultIndex,
                                    int depth);
 
-  // Emits the C++ statement to replace the matched DAG with a native C++ built
-  // value.
-  std::string emitReplaceWithNativeBuilder(DagNode resultTree);
+  // Emits the C++ statement to replace the matched DAG with a value built via
+  // calling native C++ code.
+  std::string emitReplaceWithNativeCodeCall(DagNode resultTree);
 
   // Returns the C++ expression referencing the old value serving as the
   // replacement.
@@ -193,9 +193,6 @@ private:
   // `patArgName` is used to bound the argument to the source pattern.
   std::string handleOpArgument(DagLeaf leaf, llvm::StringRef patArgName);
 
-  // Returns the C++ expression to build an argument from the given DAG `tree`.
-  std::string handleOpArgument(DagNode tree);
-
   // Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol
   // is already bound.
   void addSymbol(DagNode node);
@@ -515,8 +512,8 @@ std::string PatternEmitter::getUniqueValueName(const Operator *op) {
 
 std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
                                                  int resultIndex, int depth) {
-  if (resultTree.isNativeCodeBuilder())
-    return emitReplaceWithNativeBuilder(resultTree);
+  if (resultTree.isNativeCodeCall())
+    return emitReplaceWithNativeCodeCall(resultTree);
 
   if (resultTree.isVerifyUnusedValue()) {
     if (depth > 0) {
@@ -584,22 +581,18 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
     return result;
   }
-  if (leaf.isAttrTransformer()) {
-    return tgfmt(leaf.getTransformationTemplate(),
-                 &rewriteCtx.withSelf(result));
+  if (leaf.isNativeCodeCall()) {
+    return tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(result));
   }
   PrintFatalError(loc, "unhandled case when rewriting op");
 }
 
-std::string PatternEmitter::handleOpArgument(DagNode tree) {
-  if (!tree.isAttrTransformer()) {
-    PrintFatalError(loc, "only tAttr is supported in nested dag attribute");
-  }
-  auto fmt = tree.getTransformationTemplate();
+std::string PatternEmitter::emitReplaceWithNativeCodeCall(DagNode tree) {
+  auto fmt = tree.getNativeCodeTemplate();
   // TODO(fengliuai): replace formatv arguments with the exact specified args.
   SmallVector<std::string, 8> attrs(8);
   if (tree.getNumArgs() > 8) {
-    PrintFatalError(loc, "unsupported tAttr argument numbers: " +
+    PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
                              Twine(tree.getNumArgs()));
   }
   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) {
@@ -692,7 +685,9 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
   // Create the builder call for the result.
   // Add operands.
   int i = 0;
-  for (auto operand : resultOp.getOperands()) {
+  for (int e = resultOp.getNumOperands(); i < e; ++i) {
+    const auto &operand = resultOp.getOperand(i);
+
     // Start each operand on its own line.
     (os << ",\n").indent(6);
 
@@ -702,11 +697,15 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
     if (tree.isNestedDagArg(i)) {
       os << childNodeNames[i];
     } else {
-      os << resolveSymbol(tree.getArgName(i));
+      DagLeaf leaf = tree.getArgAsLeaf(i);
+      auto symbol = resolveSymbol(tree.getArgName(i));
+      if (leaf.isNativeCodeCall()) {
+        os << tgfmt(leaf.getNativeCodeTemplate(), &rewriteCtx.withSelf(symbol));
+      } else {
+        os << symbol;
+      }
     }
-
     // TODO(jpienaar): verify types
-    ++i;
   }
 
   // Add attributes.
@@ -716,7 +715,11 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
     // The argument in the op definition.
     auto opArgName = resultOp.getArgName(i);
     if (auto subTree = tree.getArgAsNestedDag(i)) {
-      os << formatv("/*{0}=*/{1}", opArgName, handleOpArgument(subTree));
+      if (!subTree.isNativeCodeCall())
+        PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
+                             "for creating attribute");
+      os << formatv("/*{0}=*/{1}", opArgName,
+                    emitReplaceWithNativeCodeCall(subTree));
     } else {
       auto leaf = tree.getArgAsLeaf(i);
       // The argument in the result DAG pattern.
@@ -739,36 +742,6 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
   return resultValue;
 }
 
-std::string PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) {
-  // The variable's name for holding the result of this native builder call
-  std::string value = formatv("v{0}", nextValueId++).str();
-
-  os.indent(4) << "auto " << value << " = " << resultTree.getNativeCodeBuilder()
-               << "(op, {";
-  const auto &boundedValues = pattern.getSourcePatternBoundArgs();
-  bool first = true;
-  bool printingAttr = false;
-  for (int i = 0, e = resultTree.getNumArgs(); i != e; ++i) {
-    auto name = resultTree.getArgName(i);
-    pattern.ensureBoundInSourcePattern(name);
-    const auto &val = boundedValues.find(name);
-    if (val->second.dyn_cast<NamedAttribute *>() && !printingAttr) {
-      os << "}, {";
-      first = true;
-      printingAttr = true;
-    }
-    if (!first)
-      os << ",";
-    os << getBoundSymbol(name);
-    first = false;
-  }
-  if (!printingAttr)
-    os << "},{";
-  os << "}, rewriter);\n";
-
-  return value;
-}
-
 void PatternEmitter::emit(StringRef rewriteName, Record *p,
                           RecordOperatorMap *mapper, raw_ostream &os) {
   PatternEmitter(p, mapper, os).emit(rewriteName);