From 3fb937fe019ed824de309d09281a99587df17335 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 4 Feb 2020 09:21:51 -0800 Subject: [PATCH] [DOCS] Fix vta tutorial (#4809) --- tests/scripts/task_python_docs.sh | 1 + vta/tutorials/autotvm/tune_relay_vta.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index 9859ccc..951d1a3 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -25,6 +25,7 @@ rm -rf docs/_build/html/javadoc # remove stale tutorials and always build from scratch. rm -rf docs/tutorials +rm -rf docs/vta/tutorials # C++ doc make doc diff --git a/vta/tutorials/autotvm/tune_relay_vta.py b/vta/tutorials/autotvm/tune_relay_vta.py index eab23ee..25360ce 100644 --- a/vta/tutorials/autotvm/tune_relay_vta.py +++ b/vta/tutorials/autotvm/tune_relay_vta.py @@ -165,7 +165,7 @@ def compile_network(env, target, model, start_pack, stop_pack): # ---------------------------------- # key total free pending # ---------------------------------- -# pynq 6 6 0 +# pynq 6 6 0 # rpi3b 11 11 0 # ---------------------------------- # @@ -223,7 +223,7 @@ tuning_option = { # .. note:: How to set tuning options # # In general, the default values provided here work well. -# If you have enough time budget, you can set :code:`n_trial`, :code:`early_stopping` +# If you have enough time budget, you can set :code:`n_trial`, :code:`early_stopping` # to larger values, makes the tuning run for longer. # If your device is under-powered or your conv2d operators are large, consider # setting a longer timeout. @@ -348,12 +348,13 @@ def tune_and_evaluate(tuning_opt): # Perform task extraction on Relay program print("Extract tasks...") relay_prog, params = compile_network(env, target, network, start_pack, stop_pack) - tasks = autotvm.task.extract_from_program(func=relay_prog, + mod = relay.Module.from_expr(relay_prog) + tasks = autotvm.task.extract_from_program(mod, params=params, ops=(tvm.relay.op.nn.conv2d,), target=target, target_host=env.target_host) - + # We should have extracted 10 convolution tasks assert len(tasks) == 10 print("Extracted {} conv2d tasks:".format(len(tasks))) -- 2.7.4