[RELAY] Hotfix build_module creation (#3198)
authorTianqi Chen <tqchen@users.noreply.github.com>
Thu, 16 May 2019 20:22:31 +0000 (13:22 -0700)
committerGitHub <noreply@github.com>
Thu, 16 May 2019 20:22:31 +0000 (13:22 -0700)
src/relay/backend/build_module.cc

index 63ee2d5..8a0c32f 100644 (file)
  */
 
 /*!
- *  Copyright (c) 2019 by Contributors
  * \file relay/backend/build_module.cc
  * \brief Code generation for TVM's graph runtime.
  */
-
 #include <tvm/build_module.h>
+#include <tvm/runtime/device_api.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/attrs/nn.h>
@@ -41,31 +40,6 @@ namespace backend {
 using TargetsMap = Map<tvm::Integer, tvm::Target>;
 
 /*!
- * \brief Context index to Target
- */
-struct ContextTargetMap {
-  static const std::unordered_map<int, tvm::Target> mask2str;
-  static tvm::Target Mask2Str(int mask) {
-    CHECK_GT(mask2str.count(mask), 0) << "Unknown mask.";
-    return mask2str.at(mask);
-  }
-};
-
-const std::unordered_map<int, tvm::Target> ContextTargetMap::mask2str = {
-  {1, tvm::Target::create("llvm")},
-  {2, tvm::Target::create("cuda")},
-  {4, tvm::Target::create("opencl")},
-  {5, tvm::Target::create("aocl")},
-  {6, tvm::Target::create("sdaccel")},
-  {7, tvm::Target::create("vulkan")},
-  {8, tvm::Target::create("metal")},
-  {9, tvm::Target::create("vpi")},
-  {10, tvm::Target::create("rocm")},
-  {11, tvm::Target::create("opengl")},
-  {12, tvm::Target::create("ext_dev")}
-};
-
-/*!
  * \brief A data structure to map the names of specific optimizations to
  *        numeric optimization levels
  *
@@ -310,8 +284,8 @@ class RelayBuildModule : public runtime::ModuleNode {
    *
    * \return Array<StringImm> names of params
    */
-  Array<HalideIR::Expr> ListParamNames() {
-    Array<HalideIR::Expr> ret;
+  Array<tvm::Expr> ListParamNames() {
+    Array<tvm::Expr> ret;
     for (const auto& kv : params_) {
       ret.push_back(ir::StringImm::make(kv.first));
     }
@@ -470,12 +444,9 @@ class RelayBuildModule : public runtime::ModuleNode {
     if (cfg.pass_enabled("AlterOpLayout")) {
       if (targets.size() == 1) {
         func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-        auto enter_pf = GetPackedFunc("_EnterTargetScope");
-        auto exit_pf = GetPackedFunc("_ExitTargetScope");
         for (const auto& kv : targets) {
-          (*enter_pf)(kv.second);
+          TargetContext tctx(kv.second);
           func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
-          (*exit_pf)();
         }
       } else {
         LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous"
@@ -487,6 +458,18 @@ class RelayBuildModule : public runtime::ModuleNode {
     }
     return func;
   }
+
+  /*!
+   * \brief Create a default type.
+   * \param device_type The device type index.
+   * \return the default target for the device.
+   */
+  Target CreateDefaultTarget(int device_type) {
+    std::string name = runtime::DeviceName(device_type);
+    if (name == "cpu") return Target::create("llvm");
+    if (name == "gpu") return Target::create("cuda");
+    return Target::create(name);
+  }
   /*!
    * \brief Update the target and fallback device required for heterogeneous
    * compilation. CPU is used as the fallback device if it wasn't provided.
@@ -507,7 +490,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     if (tmp_map.count(cfg.fallback_device) == 0) {
       device_target.Set(
           cfg.fallback_device,
-          ContextTargetMap::Mask2Str(cfg.fallback_device));
+          CreateDefaultTarget(cfg.fallback_device));
     }
     return device_target;
   }
@@ -520,7 +503,8 @@ class RelayBuildModule : public runtime::ModuleNode {
    * \param targets_map_ptr
    * \return Function
    */
-  Function RunDeviceAnnotationPass(Function func, const RelayBuildConfig& cfg,
+  Function RunDeviceAnnotationPass(Function func,
+                                   const RelayBuildConfig& cfg,
                                    TargetsMap* targets_map_ptr) {
     func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
     func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func,
@@ -532,7 +516,7 @@ class RelayBuildModule : public runtime::ModuleNode {
           "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr);
       if (annotation_map.size() == 0) {
         targets_map_ptr->Set(
-            0, ContextTargetMap::Mask2Str(cfg.fallback_device));
+            0, CreateDefaultTarget(cfg.fallback_device));
       } else {
         int64_t dev_type = -1;
         for (auto kv : annotation_map) {
@@ -547,7 +531,7 @@ class RelayBuildModule : public runtime::ModuleNode {
             << "found. Please check the "
             << "RewriteAnnotation pass.";
         }
-        targets_map_ptr->Set(0, ContextTargetMap::Mask2Str(dev_type));
+        targets_map_ptr->Set(0, CreateDefaultTarget(dev_type));
       }
     }
     return func;
@@ -611,7 +595,8 @@ runtime::Module RelayBuildCreate() {
   return runtime::Module(exec);
 }
 
-TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
+TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
   *rv = RelayBuildCreate();
 });