"""
return _to_shape(node.checked_type.shape)
+def _operator_idx_inc(expr, count_meta, operator_current_idx):
+ """Increase operator index
+ """
+ if isinstance(expr, relay.expr.Constant):
+ operator_current_idx = operator_current_idx + 1 if count_meta else operator_current_idx
+ else:
+ operator_current_idx = operator_current_idx + 1
+ return operator_current_idx
+
class ExprPack(ExprMutator):
"""Visitor to perform graph packing on an AST.
"""
class BT(Exception):
pass
-def get_subgraph(expr, start_name, stop_name):
+def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta):
""" We assume stop_name only appears once for simplicity.
This constraint will be lifted in the future.
bitpack_start and bitpack_end are both inclusive.
bitpack_start = op.op.get('annotation.bitpack_start')
bitpack_end = op.op.get('annotation.bitpack_end')
anf = run_opt_pass(expr, transform.ToANormalForm())
- def _recursion(anf, start_found, stop_found):
+ operator_current_idx = 0
+ def _recursion(anf, start_found, stop_found, operator_current_idx):
""" Helper to obtain the subgraph.
"""
if isinstance(anf, relay.expr.Function):
return relay.expr.Function(anf.params,
- _recursion(anf.body, start_found, stop_found),
+ _recursion(anf.body, start_found, stop_found,
+ operator_current_idx),
anf.ret_type, anf.type_params, anf.attrs)
elif isinstance(anf, relay.expr.Let):
value = anf.value
if isinstance(value, relay.expr.Call):
if isinstance(value.op, relay.op.Op):
if value.op.name == start_name and not start_found:
- value = relay.expr.Call(bitpack_start, [value])
- start_found = True
+ if operator_current_idx == start_name_idx or start_name_idx is None:
+ value = relay.expr.Call(bitpack_start, [value])
+ start_found = True
elif value.op.name == stop_name:
- raise BT()
+ if operator_current_idx == stop_name_idx or stop_name_idx is None:
+ raise BT()
+
+ operator_current_idx = _operator_idx_inc(value, count_meta, operator_current_idx)
+
try:
- return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found))
+ return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found,
+ operator_current_idx))
except BT:
assert start_found
assert not stop_found
assert start_found
assert stop_found
return anf
- annotated = _recursion(anf, False, False)
+ annotated = _recursion(anf, False, False, operator_current_idx)
return run_opt_pass(annotated, transform.ToGraphNormalForm())
def graph_pack(expr,
cfactor,
weight_bits,
start_name="nn.max_pool2d",
- stop_name="nn.global_avg_pool2d"):
+ stop_name="nn.global_avg_pool2d",
+ start_name_idx=None,
+ stop_name_idx=None,
+ count_meta=False):
"""Pack the graph into batch&channel packed format.
Parameters
The bit-width of the weights.
start_name: str, optional
- Start packing from certain known node.
+ Start packing from certain known node when start_name_idx is None.
stop_name: str, optional
- Stop packing from certain known node.
+ Stop packing from certain known node when stop_name_idx is None.
+
+ start_name_idx: int, optional
+ When start_name_idx not None, start packing only when node name equal start_name
+ and node idx equals start_name_idx.
+
+ stop_name_idx: int, optional
+ When stop_name_idx not None, stop packing only when node name equal stop_name
+ and node index equals stop_name_idx.
+
+ count_meta:boolean, optional
+ When count_meta is False, the operator increase logic would not count the meta that have
+ the type 'relay.expr.Constant', start_name_idx and stop_name_idx follow the index from
+ 'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase
+ logic would count the meta.
Returns
-------
The transformed expression.
"""
assert isinstance(expr, relay.Function)
- expr = get_subgraph(expr, start_name, stop_name)
+ assert ((start_name != stop_name) or (start_name_idx < stop_name_idx))
+ expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
expr = run_opt_pass(expr, transform.InferType())
packer = ExprPack(
bfactor, cfactor,