[FIX] Fixes #6096 (#6131)
authorTristan Konolige <tristan.konolige@gmail.com>
Fri, 31 Jul 2020 19:50:16 +0000 (12:50 -0700)
committerGitHub <noreply@github.com>
Fri, 31 Jul 2020 19:50:16 +0000 (12:50 -0700)
Clear the compile cache between module builds so that schedule changes
will have an effect. Also, clear the warning cache so that schedule
changes properly list untuned ops.

python/tvm/autotvm/task/relay_integration.py
src/relay/backend/build_module.cc
src/relay/backend/compile_engine.cc
src/relay/backend/compile_engine.h

index 9a43f2f..67ebda4 100644 (file)
@@ -24,6 +24,7 @@ import threading
 import logging
 
 import tvm
+from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext
 from .task import create
 from .topi_integration import TaskExtractEnv
 
@@ -140,6 +141,10 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
             build_thread.start()
             build_thread.join()
             relay.backend.compile_engine.get().clear()
+            # Clear the warning message cache in FallbackContext
+            if isinstance(DispatchContext.current, FallbackContext):
+                DispatchContext.current.memory = {}
+                DispatchContext.warning_messages = set()
 
         logger.disabled = old_state
 
index 1392798..bfcc2a6 100644 (file)
@@ -31,6 +31,7 @@
 #include <memory>
 
 #include "../../target/source/codegen_source_base.h"
+#include "compile_engine.h"
 #include "utils.h"
 
 namespace tvm {
@@ -224,6 +225,8 @@ class RelayBuildModule : public runtime::ModuleNode {
     targets_ = targets;
     target_host_ = target_host;
     BuildRelay(mod, params_);
+    // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096.
+    CompileEngine::Global()->Clear();
   }
 
  protected:
index 2aae854..3c4faf7 100644 (file)
@@ -770,7 +770,7 @@ class CompileEngineImpl : public CompileEngineNode {
 };
 
 /*! \brief The global compile engine */
-const CompileEngine& CompileEngine::Global() {
+CompileEngine& CompileEngine::Global() {
   // intentionally allocate raw pointer to avoid
   // free during destructuion.
   static CompileEngine* inst = new CompileEngine(make_object<CompileEngineImpl>());
index a5f3f63..e392c79 100644 (file)
@@ -238,7 +238,7 @@ class CompileEngine : public ObjectRef {
   CompileEngineNode* operator->() { return static_cast<CompileEngineNode*>(get_mutable()); }
   using ContainerType = CompileEngineNode;
   /*! \brief The global compile engine. */
-  TVM_DLL static const CompileEngine& Global();
+  TVM_DLL static CompileEngine& Global();
 };
 
 /*!