"""
return match(self, expr)
- def partition(self, expr: Expr, attrs=None) -> Expr:
+ def partition(self, expr: Expr, attrs=None, check=lambda x: True) -> Expr:
"""
Partition the expression into functions defined by this pattern.

Parameters
----------
expr : tvm.relay.Expr
The expression to match.
attrs : Optional[Dict[str, Object]]
A dictionary of Attribute name/values to add to the partitioned function.
+ check : Function
+ A function to perform more complicated checks on the matched expression.
+ Returns true if partitioning should proceed, false otherwise.
Returns
-------
result : tvm.relay.Expr
The expression with matched subgraphs replaced by function calls to those subgraphs.
"""
- return partition(self, expr, attrs)
+ return partition(self, expr, attrs, check)
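A usage sketch (the pattern, attribute name, and predicate below are illustrative, not part of this change): `check` receives the matched root expression before any rewriting, and returning False leaves that occurrence unpartitioned.

from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

x = relay.var("x")
expr = relay.op.nn.relu(x)
pattern = is_op("nn.relu")(wildcard())

def check(pre):
    # `pre` is the pre-rewrite expression rooted at the match.
    return isinstance(pre, relay.Call)

# The "Composite" attribute here is hypothetical; any attrs given are
# copied onto the partitioned function.
partitioned = pattern.partition(expr, attrs={"Composite": "my_relu"}, check=check)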
def dominates(self, parent, path=None):
"""
return ffi.rewrite(tmp, expr)
-def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr:
+def partition(pattern: DFPattern, expr: Expr, attrs=None, check=lambda x: True) -> Expr:
"""
Partition the expression into a series of functions that match the pattern.

Parameters
----------
pattern : tvm.relay.dataflow_pattern.DFPattern
The pattern to match.
expr : tvm.relay.Expr
The expression to split into functions
- expr : Optional[Dict[str, Object]]
+ attrs : Optional[Dict[str, Object]]
A dict of attributes to apply to the partitioned function
+ check : Function
+ A function to perform more complicated checks on the matched expression.
+ Returns true if partitioning should proceed, false otherwise.
Returns
-------
result : tvm.relay.Expr
The expression with matched subgraphs replaced by function calls to those subgraphs.
"""
- return ffi.partition(pattern, expr, attrs)
+ return ffi.partition(pattern, expr, attrs, check)
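For completeness, a sketch of the free-function spelling (same illustrative pattern as above); since the default `check` accepts every match, omitting it preserves the pre-existing partition-everything behavior:

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard, partition

x = relay.var("x")
expr = relay.op.nn.relu(x)
pattern = is_op("nn.relu")(wildcard())

# These two calls are equivalent: the default check accepts every match.
default_result = partition(pattern, expr)
explicit_result = partition(pattern, expr, check=lambda pre: True)
assert tvm.ir.structural_equal(default_result, explicit_result)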
class PatternPartitioner : protected MixedModeMutator {
public:
Expr Partition(const DFPattern& pattern, const Expr& pre,
- const Map<std::string, ObjectRef>& attrs) {
+ const Map<std::string, ObjectRef>& attrs, PackedFunc check) {
auto grouper = PatternGrouper();
groups_ = grouper.GroupMatches(pattern, pre);
gid_assignments_ = grouper.GetGIDAssignments();
attrs_ = attrs;
+ check_ = check;
return this->VisitExpr(pre);
}
Expr DispatchVisitExpr(const Expr& pre) override {
auto post = MixedModeMutator::DispatchVisitExpr(pre);
- if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) {
+ if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node &&
+ static_cast<bool>(check_(pre))) {
post = RewritePartition(groups_[gid_assignments_[pre]]);
}
return post;
Map<std::string, ObjectRef> attrs_;
std::vector<PatternGrouper::Group> groups_;
std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
+ PackedFunc check_;
};
-Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
- return PatternPartitioner().Partition(pattern, expr, attrs);
+Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
+ PackedFunc check) {
+ return PatternPartitioner().Partition(pattern, expr, attrs, check);
}
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition")
- .set_body_typed([](DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
- return PartitionPattern(pattern, expr, attrs);
- });
+ .set_body_typed([](DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
+ PackedFunc check) { return PartitionPattern(pattern, expr, attrs, check); });
} // namespace relay
} // namespace tvm
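A note on the plumbing (a sketch relying only on TVM's standard global-function registry): `TVM_REGISTER_GLOBAL` exposes the C++ entry point by name, and a plain Python callable passed as `check` is converted to a `PackedFunc` automatically at the FFI boundary.

import tvm

# Look up the global registered by TVM_REGISTER_GLOBAL above.
partition_global = tvm.get_global_func("relay.dataflow_pattern.partition")
assert partition_global is not None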
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
+from tvm.relay.testing import run_opt_pass
import numpy as np
# NB: 1 corresponds to the C++ enum that specifies this
def get_BN(x, var, mean, beta, gamma, eps=1e-5):
    # Standard batch-norm inference: y = gamma * (x - mean) / sqrt(var + eps) + beta
    return gamma * (x - mean) / relay.op.sqrt(var + relay.const(eps)) + beta
-def test_parition_batchnorm():
+def test_partition_batchnorm():
x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
partitioned = BatchnormCallback().pattern.partition(BN)
assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta))
-def test_parition_double_batchnorm():
+def test_partition_double_batchnorm():
x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern", "subtract_multiply_add_sqrt_divide_add_")
- # The paritioner doesn't replace duplicates, so we use two copies of the function
+ # The partitioner doesn't replace duplicates, so we use two copies of the function
xf2 = relay.var('xf2')
varf2 = relay.var('varf2')
meanf2 = relay.var('meanf2')
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
assert tvm.ir.structural_equal(partitioned, reference)
+def test_partition_check():
+ pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
+ def check(pre):
+ return pre.args[0].attrs.data_layout == "NCHW"
+
+ x = relay.var('input')
+ w = relay.var('weight')
+ conv2d = relay.op.nn.conv2d(x, w)
+ relu = relay.op.nn.relu(conv2d)
+
+ xf = relay.var('input')
+ wf = relay.var('weight')
+ conv2df = relay.op.nn.conv2d(xf, wf)
+ reluf = relay.op.nn.relu(conv2df)
+ func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_")
+
+ reference = func(x, w)
+ partitioned = pattern.partition(relu, check=check)
+ assert tvm.ir.structural_equal(partitioned, reference)
+
+ conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC")
+ relu = relay.op.nn.relu(conv2d)
+ assert relu == pattern.partition(relu, check=check)
+
+def test_partition_check_types():
+ pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
+ def check(pre):
+ conv = pre.args[0]
+ return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1)
+
+ x = relay.var('input', shape=(1, 10, 10, 10))
+ w = relay.var('weight', shape=(10, 10, 3, 3))
+ conv2d = relay.op.nn.conv2d(x, w)
+ relu = relay.op.nn.relu(conv2d)
+ relu = run_opt_pass(relu, relay.transform.InferType())  # the check reads checked_type, so run type inference first
+
+ partitioned = pattern.partition(relu, check=check)
+ assert partitioned.op.attrs["PartitionedFromPattern"] == "nn.conv2d_nn.relu_"
+
+ conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC")
+ relu = relay.op.nn.relu(conv2d)
+ relu = run_opt_pass(relu, relay.transform.InferType())
+ assert relu == pattern.partition(relu, check=check)
+
+ x = relay.var('input', shape=(2, 10, 10, 10))
+ w = relay.var('weight', shape=(10, 10, 3, 3))
+ conv2d = relay.op.nn.conv2d(x, w)
+ relu = relay.op.nn.relu(conv2d)
+ relu = run_opt_pass(relu, relay.transform.InferType())
+ assert relu == pattern.partition(relu, check=check)
+
+
if __name__ == "__main__":
test_match_op()
test_no_match_op()
test_algebraic_simplify()
test_partition_dominator()
test_quadruple_partition_dominator()
- test_parition_batchnorm()
- test_parition_double_batchnorm()
+ test_partition_batchnorm()
+ test_partition_double_batchnorm()
+ test_partition_check()
+ test_partition_check_types()