[VTA] Support network which have no unique operator as start/stop name for graph...
authorHua Jiang <huaj@xilinx.com>
Thu, 23 Jan 2020 22:05:07 +0000 (14:05 -0800)
committerThierry Moreau <tmoreau@octoml.ai>
Thu, 23 Jan 2020 22:05:07 +0000 (14:05 -0800)
* [VTA] Support network which have no unique operator as start/stop name
for graph pack.

[Issue]
  Current vta use 'start' and 'stop' name to define the pack start point
  and end point, but this method not work for these network which have
  no 2 unique operator as  start point and stop point.

[Solution]
  In this solution we give 2 addtional parameters start_name_indx and
  stop_name_indx to make vta pack logic work with the said network,
  for exampl for following networks which have no unique operator,

  %0 = nn.add
  %1 = nn.conv2d
  %2 = nn.batch_norm
  %3 = nn.leaky_relu
  %4 = nn.add
  %5 = nn.conv2d
  %6 = nn.batch_norm
  %7 = nn.leaky_relu
  %8 = nn.add

  with this solution we can use following parameter format to make
  vta work on it.

  relay_prog = graph_pack(
                //....
                start_name="nn.add",
                stop_name="nn.add",
                start_name_idx=0,
                stop_name_idx=4)

  to apply on new network, by printing the network we can get index information like following.

  print(mod.astext(show_meta_data=False))
  relay_prog = graph_pack(mod
                          ...
                          start_name="nn.add",
                          stop_name="nn.add",
                          start_name_idx=0,
                          stop_name_idx=4)

* address review comments and fix index count bug

issue:
when do print(mod), the output not only the Call is also have other type
like Var, need add logic to count all except meta.

solution:
add related logic

* address review comments.

* address review comments

* add more detail comments.

vta/python/vta/top/graphpack.py

index a4c0548..ba139a8 100644 (file)
@@ -110,6 +110,15 @@ def _get_shape(node):
     """
     return _to_shape(node.checked_type.shape)
 
+def _operator_idx_inc(expr, count_meta, operator_current_idx):
+    """Increase operator index
+    """
+    if isinstance(expr, relay.expr.Constant):
+        operator_current_idx = operator_current_idx + 1 if count_meta else operator_current_idx
+    else:
+        operator_current_idx = operator_current_idx + 1
+    return operator_current_idx
+
 class ExprPack(ExprMutator):
     """Visitor to perform graph packing on an AST.
     """
@@ -246,7 +255,7 @@ class ExprPack(ExprMutator):
 
 class BT(Exception):
     pass
-def get_subgraph(expr, start_name, stop_name):
+def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta):
     """ We assume stop_name only appears once for simplicity.
         This constraint will be lifted in the future.
         bitpack_start and bitpack_end are both inclusive.
@@ -254,24 +263,32 @@ def get_subgraph(expr, start_name, stop_name):
     bitpack_start = op.op.get('annotation.bitpack_start')
     bitpack_end = op.op.get('annotation.bitpack_end')
     anf = run_opt_pass(expr, transform.ToANormalForm())
-    def _recursion(anf, start_found, stop_found):
+    operator_current_idx = 0
+    def _recursion(anf, start_found, stop_found, operator_current_idx):
         """ Helper to obtain the subgraph.
         """
         if isinstance(anf, relay.expr.Function):
             return relay.expr.Function(anf.params,
-                                       _recursion(anf.body, start_found, stop_found),
+                                       _recursion(anf.body, start_found, stop_found,
+                                                  operator_current_idx),
                                        anf.ret_type, anf.type_params, anf.attrs)
         elif isinstance(anf, relay.expr.Let):
             value = anf.value
             if isinstance(value, relay.expr.Call):
                 if isinstance(value.op, relay.op.Op):
                     if value.op.name == start_name and not start_found:
-                        value = relay.expr.Call(bitpack_start, [value])
-                        start_found = True
+                        if operator_current_idx == start_name_idx or start_name_idx is None:
+                            value = relay.expr.Call(bitpack_start, [value])
+                            start_found = True
                     elif value.op.name == stop_name:
-                        raise BT()
+                        if operator_current_idx == stop_name_idx or stop_name_idx is None:
+                            raise BT()
+
+            operator_current_idx = _operator_idx_inc(value, count_meta, operator_current_idx)
+
             try:
-                return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found))
+                return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found,
+                                                                 operator_current_idx))
             except BT:
                 assert start_found
                 assert not stop_found
@@ -283,7 +300,7 @@ def get_subgraph(expr, start_name, stop_name):
             assert start_found
             assert stop_found
             return anf
-    annotated = _recursion(anf, False, False)
+    annotated = _recursion(anf, False, False, operator_current_idx)
     return run_opt_pass(annotated, transform.ToGraphNormalForm())
 
 def graph_pack(expr,
@@ -291,7 +308,10 @@ def graph_pack(expr,
                cfactor,
                weight_bits,
                start_name="nn.max_pool2d",
-               stop_name="nn.global_avg_pool2d"):
+               stop_name="nn.global_avg_pool2d",
+               start_name_idx=None,
+               stop_name_idx=None,
+               count_meta=False):
     """Pack the graph into batch&channel packed format.
 
     Parameters
@@ -309,10 +329,24 @@ def graph_pack(expr,
         The bit-width of the weights.
 
     start_name: str, optional
-       Start packing from certain known node.
+       Start packing from certain known node when start_name_idx is None.
 
     stop_name: str, optional
-       Stop packing from certain known node.
+       Stop packing from certain known node when stop_name_idx is None.
+
+    start_name_idx: int, optional
+        When start_name_idx not None, start packing only when node name equal start_name
+        and node idx equals start_name_idx.
+
+    stop_name_idx: int, optional
+        When stop_name_idx not None, stop packing only when node name equal stop_name
+        and node index equals stop_name_idx.
+
+    count_meta:boolean, optional
+        When count_meta is False, the operator increase logic would not count the meta that have
+        the type 'relay.expr.Constant', start_name_idx and stop_name_idx follow the index from
+        'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase
+        logic would count the meta.
 
     Returns
     -------
@@ -320,7 +354,8 @@ def graph_pack(expr,
         The transformed expression.
     """
     assert isinstance(expr, relay.Function)
-    expr = get_subgraph(expr, start_name, stop_name)
+    assert ((start_name != stop_name) or (start_name_idx < stop_name_idx))
+    expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
     expr = run_opt_pass(expr, transform.InferType())
     packer = ExprPack(
         bfactor, cfactor,