[DOCS] Fix vta tutorial (#4809)
authorTianqi Chen <tqchen@users.noreply.github.com>
Tue, 4 Feb 2020 17:21:51 +0000 (09:21 -0800)
committerGitHub <noreply@github.com>
Tue, 4 Feb 2020 17:21:51 +0000 (09:21 -0800)
tests/scripts/task_python_docs.sh
vta/tutorials/autotvm/tune_relay_vta.py

index 9859ccc..951d1a3 100755 (executable)
@@ -25,6 +25,7 @@ rm -rf docs/_build/html/javadoc
 
 # remove stale tutorials and always build from scratch.
 rm -rf docs/tutorials
+rm -rf docs/vta/tutorials
 
 # C++ doc
 make doc
index eab23ee..25360ce 100644 (file)
@@ -165,7 +165,7 @@ def compile_network(env, target, model, start_pack, stop_pack):
 #    ----------------------------------
 #    key          total  free  pending
 #    ----------------------------------
-#    pynq         6      6     0 
+#    pynq         6      6     0
 #    rpi3b        11     11    0
 #    ----------------------------------
 #
@@ -223,7 +223,7 @@ tuning_option = {
 # .. note:: How to set tuning options
 #
 #   In general, the default values provided here work well.
-#   If you have enough time budget, you can set :code:`n_trial`, :code:`early_stopping` 
+#   If you have enough time budget, you can set :code:`n_trial`, :code:`early_stopping`
 #   to larger values, makes the tuning run for longer.
 #   If your device is under-powered or your conv2d operators are large, consider
 #   setting a longer timeout.
@@ -348,12 +348,13 @@ def tune_and_evaluate(tuning_opt):
     # Perform task extraction on Relay program
     print("Extract tasks...")
     relay_prog, params = compile_network(env, target, network, start_pack, stop_pack)
-    tasks = autotvm.task.extract_from_program(func=relay_prog,
+    mod = relay.Module.from_expr(relay_prog)
+    tasks = autotvm.task.extract_from_program(mod,
                                               params=params,
                                               ops=(tvm.relay.op.nn.conv2d,),
                                               target=target,
                                               target_host=env.target_host)
-    
+
     # We should have extracted 10 convolution tasks
     assert len(tasks) == 10
     print("Extracted {} conv2d tasks:".format(len(tasks)))