Extend AttrPattern to support CallNode and FunctionNode attributes (#5637)
authorMatthew Brookhart <mbrookhart@octoml.ai>
Thu, 21 May 2020 03:16:40 +0000 (20:16 -0700)
committerGitHub <noreply@github.com>
Thu, 21 May 2020 03:16:40 +0000 (12:16 +0900)
* Extend AttrPattern to support CallNode and FunctionNode attributes

* Update tutorial and add breaks

* add func attr test

docs/langref/relay_pattern.rst
python/tvm/relay/dataflow_pattern/__init__.py
src/relay/ir/dataflow_matcher.cc
tests/python/relay/test_dataflow_pattern.py

index 7f81b9b..f56d496 100644 (file)
@@ -41,7 +41,7 @@ There are quite a few properties that are worth matching of operators below we e
 The next example is a dense operation with any operator that is marked element-wise::
 
     def test_no_match_attr():
-        op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
+        op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})
         op_pat = op(wildcard(), wildcard())
         x = relay.var('x')
         y = relay.var('y')
@@ -97,7 +97,7 @@ The high level design is to introduce a language of patterns for now we propose
             | *
             | pattern(pattern1, ... patternN)
             | has_type(pattern, type)
-            | has_attr(pattern, attr, attr_value)
+            | has_attr(pattern, attrs)
             | is_input(name)
             | pattern1 `|` pattern2
             | dominates(parent_pattern, path_pattern, child_pattern)
index 54fe80a..2582894 100644 (file)
@@ -61,23 +61,20 @@ class DFPattern(Node):
     def __truediv__(self, other):
         return is_op("divide")(self, other)
 
-    def has_attr(self, attr_name: str, attr_value):
+    def has_attr(self, attrs):
         """
         Add an attribute constraint to this pattern
 
         Parameters
         ----------
-        attr_name: str
-            The name of the attribute to match
-        attr_value: Any
-            The value of the attribute to match
+        attrs: Dict[str, Object]
 
         Returns
         -------
         result: tvm.relay.dataflow_pattern.DFPattern
             The resulting AttrPattern
         """
-        attrs = make_node("DictAttrs", **{attr_name: attr_value})
+        attrs = make_node("DictAttrs", **attrs)
         return AttrPattern(self, attrs)
 
     def has_type(self, ttype):
@@ -237,18 +234,18 @@ def has_type(ttype, pattern: DFPattern = None) -> DFPattern:
     return TypePattern(pattern, ttype)
 
 
-def has_attr(attr_name: DFPattern, attr_value, pattern=None) -> DFPattern:
+def has_attr(attrs, pattern=None) -> DFPattern:
     """
     Syntatic sugar for creating an AttrPattern
 
     Parameters
     ----------
-    pattern: tvm.relay.dataflow_pattern.DFPattern
-        The input pattern.
-
-    attrs: tvm.Attrs
+    attrs: Dict[str, Object]
         The attributes to match
 
+    pattern: Optional[tvm.relay.dataflow_pattern.DFPattern]
+        The input pattern.
+
     Returns
     -------
     result: tvm.relay.dataflow_pattern.DFPattern
@@ -256,7 +253,7 @@ def has_attr(attr_name: DFPattern, attr_value, pattern=None) -> DFPattern:
     """
     if pattern is None:
         pattern = wildcard()
-    return pattern.has_attr(attr_name, attr_value)
+    return pattern.has_attr(attrs)
 
 
 def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern:
index 0cd3bf7..3f21581 100644 (file)
@@ -101,39 +101,73 @@ bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& exp
   return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
 }
 
+bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
+  switch (rhs.type_code()) {
+    case kDLInt:
+      if (auto* val = lhs.as<IntImmNode>()) {
+        return val->value == rhs.operator int64_t();
+      }
+      break;
+    case kDLFloat:
+      if (auto* val = lhs.as<FloatImmNode>()) {
+        return val->value == rhs.operator double();
+      }
+      break;
+    case kTVMStr:
+      std::cout << lhs << std::endl;
+      if (auto* val = lhs.as<tir::StringImmNode>()) {
+        return val->value == rhs.operator std::string();
+      } else if (auto* val = lhs.as<StringObj>()) {
+        return val->data == rhs.operator std::string();
+      }
+      break;
+    default:
+      CHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code();
+  }
+  return false;
+}
+
 bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
   bool matches = false;
+  auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
   if (const auto* op_node = expr.as<OpNode>()) {
     Op op = GetRef<Op>(op_node);
-    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
     for (auto kv : attributes) {
       auto attr_name = kv.first;
       auto attr_value = kv.second;
       auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
       if (op_map.count(op)) {
-        switch (op_map[op].type_code()) {
-          case kDLInt:
-            if (auto* val = kv.second.as<IntImmNode>()) {
-              matches = val->value == op_map[op].operator int64_t();
-            }
-            break;
-          case kDLFloat:
-            if (auto* val = kv.second.as<FloatImmNode>()) {
-              matches = val->value == op_map[op].operator double();
-            }
-            break;
-          case kTVMStr:
-            if (auto* val = kv.second.as<tir::StringImmNode>()) {
-              matches = val->value == op_map[op].operator std::string();
-            }
-            break;
-          default:
-            CHECK(false) << "Unsupported type in Type Pattern Node";
-        }
+        matches = MatchRetValue(attr_value, op_map[op]);
+      }
+    }
+  } else if (auto* op = expr.as<CallNode>()) {
+    matches = true;
+    // TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this
+    // and replace the whole thing with a Visitor-based approach
+    ReflectionVTable* reflection = ReflectionVTable::Global();
+    auto attrs_node = const_cast<Object*>(op->attrs.get());
+    auto attr_names = reflection->ListAttrNames(attrs_node);
+    for (auto kv : attributes) {
+      if (matches &&
+          std::find(attr_names.begin(), attr_names.end(), kv.first) != attr_names.end()) {
+        matches &= MatchRetValue(kv.second, reflection->GetAttr(attrs_node, kv.first));
+      } else {
+        matches = false;
+        break;
+      }
+    }
+  } else if (auto* op = expr.as<FunctionNode>()) {
+    matches = true;
+    for (auto kv : attributes) {
+      if (matches && op->attrs->dict.count(kv.first)) {
+        matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]);
+      } else {
+        matches = false;
+        break;
       }
     }
   }
-  return matches;
+  return matches && VisitDFPattern(attr_pattern->pattern, expr);
 }
 
 Array<DFPattern> reverse(const Array<DFPattern>& args) {
index 4f3560c..411ef0f 100644 (file)
@@ -77,7 +77,7 @@ def test_TypePattern():
     assert ty_pat.type == ttype
 
 def test_AttrPattern():
-    op = is_op('add').has_attr("TOpPattern", K_ELEMWISE)
+    op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE})
     assert isinstance(op, AttrPattern)
     assert op.attrs["TOpPattern"] == K_ELEMWISE
 
@@ -225,19 +225,57 @@ def test_no_match_type():
     ty_pat = has_type(relay.TensorType((10, 10), "float32"))
     assert not ty_pat.match(x)
 
-def test_match_attr():
-    op = is_op('add').has_attr("TOpPattern", K_BROADCAST)
+def test_match_op_attr():
+    op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
     op_pat = op(wildcard(), wildcard())
     x = relay.var('x')
     y = relay.var('y')
     assert op_pat.match(x + y)
 
-def test_no_match_attr():
-    op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
+def test_no_match_op_attr():
+    op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})
     op_pat = op(wildcard(), wildcard())
     x = relay.var('x')
     y = relay.var('y')
     assert not op_pat.match(relay.op.nn.dense(x, y))
+    op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
+    op_pat = op(wildcard(), wildcard())
+    x = relay.var('x')
+    y = relay.var('y')
+    assert not op_pat.match(x - y)
+
+def test_match_func_attr():
+    pattern = wildcard().has_attr({"Composite": "add"})
+    x = relay.var('x')
+    y = relay.var('y')
+    f = relay.Function([x, y], x + y).with_attr("Composite", "add")
+    assert pattern.match(f)
+
+def test_no_match_func_attr():
+    pattern = wildcard().has_attr({"Composite": "add"})
+    x = relay.var('x')
+    y = relay.var('y')
+
+    f = relay.Function([x, y], x + y).with_attr("RandomTest", "add")
+    assert not pattern.match(f)
+    f = relay.Function([x, y], x + y).with_attr("Composite", "conv_bias")
+    assert not pattern.match(f)
+
+def test_match_call_attr():
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
+    x = relay.var('x')
+    y = relay.var('y')
+    assert is_conv2d.match(relay.op.nn.conv2d(x, y))
+
+def test_no_match_call_attr():
+    x = relay.var('x')
+    y = relay.var('y')
+
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NHWC"})
+    assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
+
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"})
+    assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
 
 def test_match_diamond():
     # Pattern
@@ -301,7 +339,7 @@ def test_match_fake_diamond():
 def test_match_dominator():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
@@ -344,7 +382,7 @@ def test_match_dominator():
     
     # Fuzzy path/nested Diamond
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
@@ -361,7 +399,7 @@ def test_match_dominator():
 
 def test_not_match_dominator():
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
@@ -578,7 +616,7 @@ def test_quadruple_rewrite_dominator():
             self.weight = wildcard()
             
             is_conv2d = is_op('nn.conv2d')(self.inp, self.weight)
-            is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
+            is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
             reduction = is_op('add')(wildcard(), wildcard())
             self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction)
 
@@ -740,7 +778,7 @@ def test_double_partition():
 def test_partition_dominator():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
@@ -765,7 +803,7 @@ def test_partition_dominator():
 def test_quadruple_partition_dominator():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)