def __truediv__(self, other):
    """Overload the ``/`` operator: build a pattern matching the "divide" op
    applied to ``(self, other)``, via ``is_op("divide")``."""
    return is_op("divide")(self, other)
-    def has_attr(self, attr_name: str, attr_value):
+    def has_attr(self, attrs):
        """
        Add an attribute constraint to this pattern
        Parameters
        ----------
-        attr_name: str
-            The name of the attribute to match
-        attr_value: Any
-            The value of the attribute to match
+        attrs: Dict[str, Object]
+            The attributes to match, as a map from attribute name to expected value
        Returns
        -------
        result: tvm.relay.dataflow_pattern.DFPattern
            The resulting AttrPattern
        """
-        attrs = make_node("DictAttrs", **{attr_name: attr_value})
+        attrs = make_node("DictAttrs", **attrs)
        return AttrPattern(self, attrs)
def has_type(self, ttype):
    # NOTE(review): `pattern` is not defined in this method's scope — as shown
    # this raises NameError when called. Should this be TypePattern(self, ttype)?
    # (Alternatively, this return line may belong to the elided free function
    # has_type(ttype, pattern=None), which does define `pattern`.) Confirm.
    return TypePattern(pattern, ttype)
-def has_attr(attr_name: DFPattern, attr_value, pattern=None) -> DFPattern:
+def has_attr(attrs, pattern=None) -> DFPattern:
    """
    Syntactic sugar for creating an AttrPattern
    Parameters
    ----------
-    pattern: tvm.relay.dataflow_pattern.DFPattern
-        The input pattern.
-
-    attrs: tvm.Attrs
+    attrs: Dict[str, Object]
        The attributes to match
+    pattern: Optional[tvm.relay.dataflow_pattern.DFPattern]
+        The input pattern; defaults to wildcard() when omitted.
+
    Returns
    -------
    result: tvm.relay.dataflow_pattern.DFPattern
+        The resulting AttrPattern
    """
    if pattern is None:
        pattern = wildcard()
-    return pattern.has_attr(attr_name, attr_value)
+    return pattern.has_attr(attrs)
def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern:
return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
}
+// Compare a pattern-supplied attribute value (lhs, an Object) against an op
+// attribute stored in a TVMRetValue (rhs). Returns true only when the types
+// correspond and the values are equal; an unsupported type code is a hard
+// failure (CHECK), and a type mismatch simply fails the match.
+bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
+  switch (rhs.type_code()) {
+    case kDLInt:
+      if (auto* val = lhs.as<IntImmNode>()) {
+        return val->value == rhs.operator int64_t();
+      }
+      break;
+    case kDLFloat:
+      if (auto* val = lhs.as<FloatImmNode>()) {
+        return val->value == rhs.operator double();
+      }
+      break;
+    case kTVMStr:
+      // Patterns may carry either a tir::StringImm or a runtime String.
+      if (auto* val = lhs.as<tir::StringImmNode>()) {
+        return val->value == rhs.operator std::string();
+      } else if (auto* val = lhs.as<StringObj>()) {
+        return val->data == rhs.operator std::string();
+      }
+      break;
+    default:
+      CHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code();
+  }
+  return false;
+}
+
+
+// AttrPattern: the pattern's attribute dict must agree with the matched expr.
+// Three cases: expr is an Op (compare registered op attrs via the op attr map),
+// a Call (compare the call's Attrs fields via reflection), or a Function
+// (compare its DictAttrs with structural equality).
bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
  bool matches = false;
+  auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
  if (const auto* op_node = expr.as<OpNode>()) {
    Op op = GetRef<Op>(op_node);
-    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
    for (auto kv : attributes) {
      auto attr_name = kv.first;
      auto attr_value = kv.second;
      auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
      if (op_map.count(op)) {
-        switch (op_map[op].type_code()) {
-          case kDLInt:
-            if (auto* val = kv.second.as<IntImmNode>()) {
-              matches = val->value == op_map[op].operator int64_t();
-            }
-            break;
-          case kDLFloat:
-            if (auto* val = kv.second.as<FloatImmNode>()) {
-              matches = val->value == op_map[op].operator double();
-            }
-            break;
-          case kTVMStr:
-            if (auto* val = kv.second.as<tir::StringImmNode>()) {
-              matches = val->value == op_map[op].operator std::string();
-            }
-            break;
-          default:
-            CHECK(false) << "Unsupported type in Type Pattern Node";
-        }
+        matches = MatchRetValue(attr_value, op_map[op]);
+      }
+    }
+  } else if (auto* op = expr.as<CallNode>()) {
+    matches = true;
+    // TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this
+    // and replace the whole thing with a Visitor-based approach
+    ReflectionVTable* reflection = ReflectionVTable::Global();
+    auto attrs_node = const_cast<Object*>(op->attrs.get());
+    auto attr_names = reflection->ListAttrNames(attrs_node);
+    for (auto kv : attributes) {
+      // Every constrained attr must both exist on the call's Attrs and match
+      // its value; any missing name or mismatch fails the whole pattern.
+      if (matches &&
+          std::find(attr_names.begin(), attr_names.end(), kv.first) != attr_names.end()) {
+        matches &= MatchRetValue(kv.second, reflection->GetAttr(attrs_node, kv.first));
+      } else {
+        matches = false;
+        break;
+      }
+    }
+  } else if (auto* op = expr.as<FunctionNode>()) {
+    matches = true;
+    // NOTE(review): op->attrs may be undefined for a Function created without
+    // attrs — confirm dereferencing op->attrs->dict is safe in that case.
+    for (auto kv : attributes) {
+      if (matches && op->attrs->dict.count(kv.first)) {
+        matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]);
+      } else {
+        matches = false;
+        break;
      }
    }
  }
-  return matches;
+  // Attr constraints are conjunctive with the wrapped sub-pattern.
+  return matches && VisitDFPattern(attr_pattern->pattern, expr);
}
Array<DFPattern> reverse(const Array<DFPattern>& args) {
assert ty_pat.type == ttype
def test_AttrPattern():
- op = is_op('add').has_attr("TOpPattern", K_ELEMWISE)
+ op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE})
assert isinstance(op, AttrPattern)
assert op.attrs["TOpPattern"] == K_ELEMWISE
ty_pat = has_type(relay.TensorType((10, 10), "float32"))
assert not ty_pat.match(x)
-def test_match_attr():
-    op = is_op('add').has_attr("TOpPattern", K_BROADCAST)
+def test_match_op_attr():
+    # 'add' is registered with TOpPattern == K_BROADCAST, so the
+    # attr-constrained op pattern should match x + y.
+    op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
    op_pat = op(wildcard(), wildcard())
    x = relay.var('x')
    y = relay.var('y')
    assert op_pat.match(x + y)
-def test_no_match_attr():
-    op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
+def test_no_match_op_attr():
+    # nn.dense is not K_ELEMWISE: the attr constraint rejects the match.
+    op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})
    op_pat = op(wildcard(), wildcard())
    x = relay.var('x')
    y = relay.var('y')
    assert not op_pat.match(relay.op.nn.dense(x, y))
+    # Attr value matches, but the op itself differs (subtract vs add) -> no match.
+    op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
+    op_pat = op(wildcard(), wildcard())
+    x = relay.var('x')
+    y = relay.var('y')
+    assert not op_pat.match(x - y)
+
+
+def test_match_func_attr():
+    # Function attrs: matches a Function carrying Composite="add".
+    pattern = wildcard().has_attr({"Composite": "add"})
+    x = relay.var('x')
+    y = relay.var('y')
+    f = relay.Function([x, y], x + y).with_attr("Composite", "add")
+    assert pattern.match(f)
+
+def test_no_match_func_attr():
+    pattern = wildcard().has_attr({"Composite": "add"})
+    x = relay.var('x')
+    y = relay.var('y')
+
+    # Wrong attr name.
+    f = relay.Function([x, y], x + y).with_attr("RandomTest", "add")
+    assert not pattern.match(f)
+    # Right attr name, wrong value.
+    f = relay.Function([x, y], x + y).with_attr("Composite", "conv_bias")
+    assert not pattern.match(f)
+
+def test_match_call_attr():
+    # Call attrs: the conv2d call's data_layout is expected to be NCHW here.
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
+    x = relay.var('x')
+    y = relay.var('y')
+    assert is_conv2d.match(relay.op.nn.conv2d(x, y))
+
+def test_no_match_call_attr():
+    x = relay.var('x')
+    y = relay.var('y')
+
+    # Existing attr, wrong value.
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NHWC"})
+    assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
+
+    # Nonexistent attr name.
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"})
+    assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
def test_match_diamond():
# Pattern
def test_match_dominator():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
- is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+ is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
# Fuzzy path/nested Diamond
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
- is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
+ is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
def test_not_match_dominator():
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
- is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+ is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
self.weight = wildcard()
is_conv2d = is_op('nn.conv2d')(self.inp, self.weight)
- is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
+ is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
reduction = is_op('add')(wildcard(), wildcard())
self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction)
def test_partition_dominator():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
- is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+ is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
def test_quadruple_partition_dominator():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
- is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
+ is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)