Relay C++ Build Module (#3082)
authorBing Xu <antinucleon@gmail.com>
Wed, 8 May 2019 07:16:15 +0000 (00:16 -0700)
committerJared Roesch <roeschinc@gmail.com>
Wed, 8 May 2019 07:16:15 +0000 (03:16 -0400)
* [Relay] C++ Build module

* asdf

include/tvm/build_module.h
src/codegen/build_module.cc
src/relay/backend/build_module.cc [new file with mode: 0644]
src/relay/backend/compile_engine.cc
src/relay/backend/graph_runtime_codegen.cc
tests/cpp/relay_build_module_test.cc [new file with mode: 0644]
tests/python/relay/test_cpp_build_module.py [new file with mode: 0644]

index 3c13644..334fe16 100644 (file)
@@ -344,6 +344,19 @@ TVM_DLL Array<LoweredFunc> lower(Schedule sch,
                                  const std::string& name,
                                  const std::unordered_map<Tensor, Buffer>& binds,
                                  const BuildConfig& config);
+/*!
+* \brief Split host/device function and running necessary pass before build
+* \param funcs The functions to be built.
+* \param target The target device to build for.
+* \param target_host The target for building host code. To use the default, pass Target()
+* \param config The build configuration.
+* \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array,
+          second is device function array
+*/
+TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
+                                                        const Target& target,
+                                                        const Target& target_host,
+                                                        const BuildConfig& config);
 
 /*!
 * \brief Build a device and host module for a specific target from an array of lowered functions.
index 92a12a0..01ebcac 100644 (file)
@@ -423,10 +423,10 @@ Array<LoweredFunc> lower(Schedule sch,
   return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
 }
 
-runtime::Module build(const Array<LoweredFunc>& funcs,
-                      const Target& target,
-                      const Target& target_host,
-                      const BuildConfig& config) {
+Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
+                                                const Target& target,
+                                                const Target& target_host,
+                                                const BuildConfig& config) {
   std::unordered_set<std::string> all_names;
   for (const auto &x : funcs) {
     CHECK(all_names.count(x->name) == 0) << "Duplicate function name " << x->name;
@@ -493,6 +493,17 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
     func = ir::CombineContextCall(func);
     fhost.Set(i, func);
   }
+  return {fhost, fdevice};
+}
+
+runtime::Module build(const Array<LoweredFunc>& funcs,
+                      const Target& target,
+                      const Target& target_host,
+                      const BuildConfig& config) {
+  auto target_host_val = target_host.defined() ? target_host : DefaultTargetHost(target);
+  auto host_dev_funcs = split_dev_host_funcs(funcs, target, target_host, config);
+  auto& fhost = host_dev_funcs[0];
+  auto& fdevice = host_dev_funcs[1];
 
   auto mhost = codegen::Build(fhost, target_host_val->str());
 
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
new file mode 100644 (file)
index 0000000..b60a048
--- /dev/null
@@ -0,0 +1,713 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file relay/backend/build_module.cc
+ * \brief Code generation for TVM's graph runtime.
+ */
+
+#include <tvm/build_module.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <vector>
+#include <string>
+#include <memory>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Context name / index 
+ *        See: python/tvm/_ffi/runtime_ctypes.py
+ */
+struct ContextMap {
+  static const std::unordered_map<int, std::string> mask2str;
+  static const std::unordered_map<std::string, int> str2mask;
+  static std::string Mask2Str(int mask) {
+    CHECK_GT(mask2str.count(mask), 0) << "Unknown mask.";
+    return mask2str.at(mask);
+  }
+  static int Str2Mask(const std::string& str) {
+    CHECK_GT(str2mask.count(str), 0) << "Unknown context.";
+    return str2mask.at(str);
+  }
+};
+
+const std::unordered_map<int, std::string> ContextMap::mask2str = {
+  {1, "cpu"},
+  {2, "gpu"},
+  {4, "opencl"},
+  {5, "aocl"},
+  {6, "sdaccel"},
+  {7, "vulkan"},
+  {8, "metal"},
+  {9, "vpi"},
+  {10, "rocm"},
+  {11, "opengl"},
+  {12, "ext_dev"}
+};
+
+const std::unordered_map<std::string, int> ContextMap::str2mask = {
+  {"llvm", 1},
+  {"cpu", 1},
+  {"c", 1},
+  {"gpu", 2},
+  {"cuda", 2},
+  {"nvptx", 2},
+  {"cl", 4},
+  {"opencl", 4},
+  {"aocl", 5},
+  {"aocl_sw_emu", 5},
+  {"vulkan", 7},
+  {"metal", 8},
+  {"vpi", 9},
+  {"rocm", 10},
+  {"opengl", 11},
+  {"ext_dev", 12}
+};
+
+/*!
+ * \brief A data structure to map the names of specific optimizations to
+ *        numeric optimization levels
+ * 
+ */
+struct OptPassLevel {
+  static const std::unordered_map<std::string, int> _data;
+  /*!
+   * \brief Get level for an optimization pass
+   * 
+   * \param key pass name
+   * \return int level
+   */
+  int operator[](const std::string& key) const {
+    auto it = _data.find(key);
+    if (it == _data.end()) {
+      return -1;
+    }
+    return it->second;
+  }
+};
+
+const std::unordered_map<std::string, int> OptPassLevel::_data = {
+  {"SimplifyInference", 0},
+  {"OpFusion", 1},
+  {"FoldConstant", 2},
+  {"CombineParallelConv2D", 3},
+  {"FoldScaleAxis", 3},
+  {"AlterOpLayout", 3},
+  {"CanonicalizeOps", 3},
+  {"EliminateCommonSubexpr", 3}
+};
+
+/*!
+ * \brief Output of building module
+ * 
+ */
+struct BuildOutput {
+  std::string graph_json;
+  runtime::Module mod;
+  std::unordered_map<std::string, tvm::runtime::NDArray> params;
+};
+
+/*!
+ * \brief Relay building config
+ * 
+ */
+struct RelayBuildConfig {
+  int opt_level{2};
+  std::string fallback_device{"llvm"};
+  std::unordered_set<std::string> enabled_pass;
+  std::unordered_set<std::string> disabled_pass;
+  OptPassLevel OPT_PASS_LEVEL;
+  inline bool pass_enabled(const std::string& pass_name) const {
+    if (disabled_pass.count(pass_name)) {
+      return false;
+    }
+    if (enabled_pass.count(pass_name)) {
+      return true;
+    }
+    return opt_level >= OPT_PASS_LEVEL[pass_name];
+  }
+};
+
+/*!
+ * \brief GraphCodegen module wrapper 
+ * 
+ */
+struct GraphCodegen {
+ public:
+  GraphCodegen() {
+    auto pf = GetPackedFunc("relay.build_module._GraphRuntimeCodegen");
+    mod = (*pf)();
+  }
+  ~GraphCodegen() {}
+
+  void Init(runtime::Module* m,
+            Map<HalideIR::Expr, HalideIR::Expr> targets) {
+    Array<HalideIR::Expr> tgts;
+    for (auto kv : targets) {
+      tgts.push_back(kv.first);
+      tgts.push_back(kv.second);
+    }
+    CallFunc("init", m, tgts);
+  }
+
+  void Codegen(const Function& func) {
+    CallFunc("codegen", func);
+  }
+
+  std::string GetJSON() {
+    return CallFunc<std::string>("get_graph_json", nullptr);
+  }
+
+  Map<std::string, Array<LoweredFunc> > GetLoweredFunc() {
+    return CallFunc<Map<std::string, Array<LoweredFunc> > >("get_lowered_funcs", nullptr);
+  }
+
+  std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
+    std::unordered_map<std::string, tvm::runtime::NDArray> ret;
+    auto names = CallFunc<Array<HalideIR::Expr> >("list_params_name", nullptr);
+    for (auto expr : names) {
+      auto key = expr.as<ir::StringImm>()->value;
+      ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
+    }
+    return ret;
+  }
+
+ protected:
+  tvm::runtime::Module mod;
+  template<typename R, typename ...Args>
+  R CallFunc(const std::string &name, Args... args) {
+    auto pf = mod.GetFunction(name, false);
+    return pf(std::forward<Args>(args)...);
+  }
+  template<typename ...Args>
+  void CallFunc(const std::string &name, Args... args) {
+    auto pf = mod.GetFunction(name, false);
+    pf(std::forward<Args>(args)...);
+    return;
+  }
+};
+
+template<typename R, typename ...Args>
+R CallPackedFunc(const std::string &name, Args... args) {
+  auto pf = GetPackedFunc(name);
+  return (*pf)(std::forward<Args>(args)...);
+}
+
+template<typename ...Args>
+Function CallPackedFunc(const std::string &name, Args... args) {
+  auto pf = GetPackedFunc(name);
+  return (*pf)(std::forward<Args>(args)...);
+}
+
+/*!
+ * \brief Relay build module
+ * 
+ */
+class RelayBuildModule : public runtime::ModuleNode {
+ public:
+  /*!
+   * \brief Get member function to front-end
+   * \param name The name of the function.
+   * \param sptr_to_self The pointer to the module node.
+   * \return The corresponding member function.
+   */
+  PackedFunc GetFunction(const std::string& name,
+                         const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+    if (name == "get_graph_json") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        *rv = this->GetGraphJSON();
+      });
+    } else if (name == "get_module") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        *rv = this->GetModule();
+      });
+    } else if (name == "build") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        CHECK_EQ(args.num_args, 3);
+        Array<HalideIR::Expr> tmp = args[1];
+        std::unordered_map<std::string, std::string> targets;
+        for (size_t i = 0; i < tmp.size(); i += 2) {
+          auto k = tmp[i].as<ir::StringImm>()->value;
+          auto v = tmp[i + 1].as<ir::StringImm>()->value;
+          targets[k] = v;
+        }
+        this->Build(args[0], targets, args[2]);
+      });
+    } else if (name == "list_params") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        *rv = this->ListParamNames();
+      });
+    } else if (name == "get_params") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        *rv = this->GetParams();
+      });
+    } else if (name == "set_opt_level") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        CHECK_EQ(args.num_args, 1);
+        int level = args[0];
+        this->SetOptLevel(level);
+      });
+    } else if (name == "set_fallback_device") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        std::string dev = args[0];
+        this->SetFallBackDev(dev);
+      });
+    } else if (name == "add_pass") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        std::string pass_name = args[0];
+        this->AddPass(pass_name);
+      });
+    } else if (name == "disable_pass") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        std::string pass_name = args[0];
+        this->DisablePass(pass_name);
+      });
+    } else if (name == "set_params") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        Map<std::string, Constant> params = args[0];
+        for (const auto& kv : params) {
+          this->SetParam(kv.first, kv.second->data);
+        }
+      });
+    } else {
+      LOG(FATAL) << "Unknown packed function: " << name;
+      return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
+    }
+  }
+
+  /*!
+   * \brief Get the GraphJSON for runtime
+   *
+   * \return const std::string graph_json
+   */
+  const std::string& GetGraphJSON() {
+    return ret_.graph_json;
+  }
+  /*!
+   * \brief Add extra pass into build cfg
+   * 
+   * \param pass_name name of pass 
+   */
+  void AddPass(const std::string& pass_name) {
+    cfg_.enabled_pass.insert(pass_name);
+  }
+  /*!
+   * \brief Disable a specific pass in cfg
+   * 
+   * \param pass_name name of pass
+   */
+  void DisablePass(const std::string& pass_name) {
+    cfg_.disabled_pass.insert(pass_name);
+  }
+  /*!
+   * \brief Set the Fallback device 
+   * 
+   * \param device name
+   */
+  void SetFallBackDev(const std::string& dev) {
+    cfg_.fallback_device = dev;
+  }
+  /*!
+   * \brief Get the Module object
+   *
+   * \return runtime::Module
+   */
+  runtime::Module GetModule() {
+    return ret_.mod;
+  }
+
+  /*!
+   * \brief List all paramter names
+   * 
+   * \return Array<StringImm> names of params
+   */
+  Array<HalideIR::Expr> ListParamNames() {
+    Array<HalideIR::Expr> ret;
+    for (const auto& kv : params_) {
+      ret.push_back(ir::StringImm::make(kv.first));
+    }
+    return ret;
+  }
+
+  /*!
+   * \brief Get params dictionary
+   * 
+   * \return Map<std::string, Constant> params dictionary
+   */
+  Map<std::string, Constant> GetParams() {
+    Map<std::string, Constant> ret;
+    for (const auto& kv : ret_.params) {
+      ret.Set(kv.first, ConstantNode::make(kv.second));
+    }
+    return ret;
+  }
+
+  /*!
+   * \brief Set the parameters
+   *
+   * \param name name of parameter
+   * \param data_in input DLTensor
+   */
+  void SetParam(const std::string& name, runtime::NDArray data_in) {
+    params_[name] = data_in;
+  }
+
+  /*!
+   * \brief Set the optimization level
+   *
+   * \param level
+   */
+  void SetOptLevel(char level) {
+    cfg_.opt_level = level;
+  }
+
+  /*!
+   * \brief type key
+   *
+   * \return const char*
+   */
+  const char* type_key() const final {
+    return "RelayBuildModule";
+  }
+
+  /*!
+   * \brief Build relay function for graph runtime
+   *
+   * \param func Relay Function
+   * \param target Target device
+   * \param target_host Host target device
+   */
+  void Build(Function func,
+             const std::unordered_map<std::string, std::string>& targets,
+             const std::string& target_host) {
+    targets_ = targets;
+    target_host_ = target_host;
+    BuildRelay(func, cfg_, params_);
+  }
+
+ protected:
+  /*!
+   * \brief Bind params to function by using name
+   * \param func Relay function
+   * \param params params dict
+   * \return relay::Function
+   */
+  relay::Function BindParamsByName(relay::Function func,
+                              const std::unordered_map<std::string, runtime::NDArray>& params) {
+    std::unordered_map<std::string, relay::Var> name_dict;
+    std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
+    for (auto arg : func->params) {
+      const auto &name = arg->name_hint();
+      if (name_dict.count(name)) {
+        repeat_var.insert(arg);
+      } else {
+        name_dict[name] = arg;
+      }
+    }
+
+    std::unordered_map<relay::Var, Expr, NodeHash, NodeEqual> bind_dict;
+    for (auto &kv : params) {
+      if (name_dict.count(kv.first) == 0) {
+        continue;
+      }
+      auto arg = name_dict.at(kv.first);
+      if (repeat_var.count(arg)) {
+        LOG(FATAL) << "Multiple args in the function have name " << kv.first;
+      }
+      auto e = CallPackedFunc<Expr>("relay._make.Constant", kv.second);
+      bind_dict[arg] = e;
+    }
+    return CallPackedFunc("relay._expr.Bind", func, tvm::Map<relay::Var, Expr>(bind_dict));
+  }
+
+  /*!
+   * \brief Optimize Relay function
+   *
+   * \param func Input function
+   * \param target target device
+   * \param cfg Relay build config
+   * \param params params dict
+   * \return relay::Function
+   */
+  relay::Function Optimize(relay::Function func,
+                           const std::unordered_map<std::string, std::string>& targets,
+                           const RelayBuildConfig& cfg,
+                           const std::unordered_map<std::string, runtime::NDArray>& params) {
+    if (params.size()) {
+      func = BindParamsByName(func, params);
+    }
+    if (cfg.pass_enabled("SimplifyInference")) {
+      func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
+      func = CallPackedFunc("relay._ir_pass.simplify_inference", func);
+    }
+    if (cfg.pass_enabled("EliminateCommonSubexpr")) {
+      auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+        Expr expr = args[0];
+        if (expr.as<CallNode>()) {
+          auto call_node = expr.as<CallNode>();
+          auto op_node = call_node->op.as<OpNode>();
+          if (op_node->name == "cast") {
+            auto attrs = call_node->attrs.as<CastAttrs>();
+            if (attrs->dtype == HalideIR::Int(32)) {
+              *rv = true;
+            }
+          }
+        }
+        *rv =  false;
+      });
+      func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
+      func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, fskip);
+    }
+    if (cfg.pass_enabled("CombineParallelConv2D")) {
+      const int min_num_branches = 3;
+      func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
+      func = CallPackedFunc("relay._ir_pass.CombineParallelConv2D", func, min_num_branches);
+    }
+    if (cfg.pass_enabled("FoldConstant")) {
+      func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
+    }
+    if (cfg.pass_enabled("FoldScaleAxis")) {
+      func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
+      func = CallPackedFunc("relay._ir_pass.backward_fold_scale_axis", func);
+      func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
+      func = CallPackedFunc("relay._ir_pass.forward_fold_scale_axis", func);
+      func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
+    }
+    if (cfg.pass_enabled("CanonicalizeOps")) {
+      func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
+      func = CallPackedFunc("relay._ir_pass.canonicalize_ops", func);
+    }
+    if (cfg.pass_enabled("AlterOpLayout")) {
+      if (targets.size() == 1) {
+        func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
+        func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
+      } else {
+        LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous"
+                  << " execution yet.";
+      }
+    }
+    if (cfg.pass_enabled("FoldConstant")) {
+      func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
+    }
+    return func;
+  }
+  /*!
+   * \brief Update the target and fallback device required for heterogeneous
+   * compilation. CPU is used as the fallback device if it wasn't provided.
+   * Meanwhile, a CPU device type and "llvm" pair will be added to the target
+   * dictionary in this case.
+   * 
+   * \param targets dictionary
+   * \param cfg 
+   * \return Map<HalideIR::Expr, HalideIR::Expr> 
+   */
+  Map<HalideIR::Expr, HalideIR::Expr> UpdateHeterogeneousInputs(
+    const std::unordered_map<std::string, std::string>& targets,
+    const RelayBuildConfig& cfg) {
+    Map<HalideIR::Expr, HalideIR::Expr> device_target;
+    std::unordered_map<int64_t, std::string> tmp_map;
+    auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device);
+
+    for (const auto& kv : targets) {
+      tmp_map[ContextMap::Str2Mask(kv.first)] = kv.second;
+    }
+    if (tmp_map.count(fallback_idx) == 0) {
+      tmp_map[fallback_idx] = cfg.fallback_device;
+    }
+    for (const auto& kv : tmp_map) {
+      device_target.Set(
+        ir::IntImm::make(HalideIR::Int(64), kv.first),
+        ir::StringImm::make(kv.second));
+    }
+    return device_target;
+  }
+  /*!
+   * \brief Execute the device annotation passes to update the input program and
+   *        target information.
+   * 
+   * \param func 
+   * \param cfg 
+   * \param targets_map_ptr 
+   * \return Function 
+   */
+  Function RunDeviceAnnotationPass(
+      Function func,
+      const RelayBuildConfig& cfg,
+      Map<HalideIR::Expr, HalideIR::Expr>* targets_map_ptr) {
+    auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device);
+    func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
+    func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, fallback_idx);
+    auto device_map = CallPackedFunc<Map<Expr, Integer> >("relay._ir_pass.CollectDeviceInfo",
+                                                       func,
+                                                       nullptr);
+    if (device_map.size() == 0) {
+      auto annotation_map =
+        CallPackedFunc<Map<Expr, Integer> >("relay._ir_pass.CollectDeviceAnnotationOps",
+                                            func,
+                                            nullptr);
+      if (annotation_map.size() == 0) {
+        targets_map_ptr->Set(
+          ir::IntImm::make(HalideIR::Int(64), 0),
+          ir::StringImm::make(cfg.fallback_device));
+      } else {
+        int64_t dev_type = -1;
+        for (auto kv : annotation_map) {
+          dev_type = kv.second->value;
+          break;
+        }
+        for (auto kv : annotation_map) {
+          CHECK_EQ(kv.second->value, dev_type)
+            << "Expressions in the function are "
+            << "annotated with various device types,"
+            << "but not device copy operators "
+            << "found. Please check the "
+            << "RewriteAnnotation pass.";
+        }
+        targets_map_ptr->Set(
+          ir::IntImm::make(HalideIR::Int(64), 0),
+          ir::StringImm::make(ContextMap::Mask2Str(dev_type)));
+      }
+    }
+    return func;
+  }
+  /*!
+   * \brief Build module given lowered functions for each target
+   * 
+   * \param lowered_funcs target_str -> Array<LoweredFunc> map
+   * \param targets Targets map
+   * \param cfg Building configuration
+   */
+  void BuildModule(const Map<std::string, Array<LoweredFunc> >& lowered_funcs,
+                   const Map<HalideIR::Expr, HalideIR::Expr>& targets,
+                   const BuildConfig& cfg) {
+    auto target_host = Target::create(cfg_.fallback_device);
+    for (const auto& kv : lowered_funcs) {
+      std::unordered_set<std::string> fname_set;
+      for (auto f : kv.second) {
+        if (fname_set.count(f->name)) {
+          LOG(FATAL) << "Duplicate function name "
+                     << f->name;
+        }
+        fname_set.insert(f->name);
+      }
+    }
+    std::unordered_map<std::string, Target> target_map;
+    for (const auto& kv : lowered_funcs) {
+      target_map[kv.first] = Target::create(kv.first);
+    }
+    Array<LoweredFunc> fhost_all;
+    std::vector<runtime::Module> device_module;
+    for (const auto& kv : lowered_funcs) {
+      auto target = target_map[kv.first];
+      auto host_dev_funcs = split_dev_host_funcs(kv.second, target, target_host, cfg);
+      for (auto f : host_dev_funcs[0]) {
+        fhost_all.push_back(f);
+      }
+      if (host_dev_funcs[1].size()) {
+        auto mdev = codegen::Build(host_dev_funcs[1], target->str());
+        device_module.push_back(mdev);
+      }
+    }
+
+    auto mhost = codegen::Build(fhost_all, target_host->str());
+
+    for (auto mdev : device_module) {
+      mhost.Import(mdev);
+    }
+    ret_.mod = mhost;
+  }
+
+  /*!
+   * \brief Build relay function to runtime module
+   *
+   * \param func Relay Function
+   * \param cfg Relay build config
+   * \param params parameters
+   */
+  void BuildRelay(Function func,
+                  const RelayBuildConfig& cfg,
+                  const std::unordered_map<std::string, tvm::runtime::NDArray> &params) {
+    // convert
+    tvm_cfg_ = build_config();
+    Map<HalideIR::Expr, HalideIR::Expr> device_target;
+    if (targets_.size() > 1) {
+      device_target = UpdateHeterogeneousInputs(targets_, cfg);
+    } else {
+      for (auto &kv : targets_) {
+        device_target.Set(
+          ir::IntImm::make(HalideIR::Int(64), ContextMap::Str2Mask(kv.first)),
+          ir::StringImm::make(kv.second));
+      }
+    }
+    func = Optimize(func, targets_, cfg, params);
+    if (device_target.size() > 1) {
+      func = RunDeviceAnnotationPass(func, cfg, &device_target);
+    }
+    func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
+    func = CallPackedFunc("relay._ir_pass.FuseOps", func, cfg.opt_level);
+    func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
+
+    graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
+    graph_codegen_->Init(nullptr, device_target);
+    graph_codegen_->Codegen(func);
+
+    ret_.graph_json = graph_codegen_->GetJSON();
+    ret_.params = graph_codegen_->GetParams();
+
+    BuildModule(graph_codegen_->GetLoweredFunc(),
+                device_target,
+                tvm_cfg_);
+  }
+
+ protected:
+  std::unique_ptr<GraphCodegen> graph_codegen_;
+  /*! \brief target device */
+  std::unordered_map<std::string, std::string> targets_;
+  /*! \brief target host device */
+  std::string target_host_;
+  /*! \brief frontend optimization configure */
+  RelayBuildConfig cfg_;
+  /*! \brief parameters */
+  std::unordered_map<std::string, runtime::NDArray> params_;
+  /*! \brief building output */
+  BuildOutput ret_;
+  /*! \brief tvm building cfg */
+  BuildConfig tvm_cfg_;
+};
+
+runtime::Module RelayBuildCreate() {
+  std::shared_ptr<RelayBuildModule> exec = std::make_shared<RelayBuildModule>();
+  return runtime::Module(exec);
+}
+
+TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = RelayBuildCreate();
+});
+
+}  // namespace backend
+}  // namespace relay
+}  // namespace tvm
index 4b5842c..a824c45 100644 (file)
@@ -371,7 +371,9 @@ class CompileEngineImpl : public CompileEngineNode {
       cache_node->funcs = (*f)(
           spair.first, all_args, cache_node->func_name, key->source_func);
     } else {
-      LOG(FATAL) << "relay.backend.lower is not registred";
+      tvm::BuildConfig bcfg = tvm::build_config();
+      std::unordered_map<Tensor, Buffer> binds;
+      cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
     }
     value->cached_func = CachedFunc(cache_node);
     return value;
index 7f16891..415e0ec 100644 (file)
@@ -416,7 +416,12 @@ class GraphRuntimeCodegen
     } else {
       // heterogeneous execution.
       const auto call_dev_key = std::to_string(call_dev_type);
-      const auto call_dev_name = runtime::DeviceName(call_dev_type);
+      std::string call_dev_name;
+      if (call_dev_type == 0) {
+        call_dev_name = "llvm";
+      } else {
+        call_dev_name = runtime::DeviceName(call_dev_type);
+      }
       if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) {
         LOG(FATAL) << "No target is provided for device "
                    << call_dev_name;
diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc
new file mode 100644 (file)
index 0000000..38481bf
--- /dev/null
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <tvm/tvm.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/type.h>
+#include <tvm/relay/pass.h>
+#include <topi/generic/injective.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/packed_func_ext.h>
+
+
+TVM_REGISTER_GLOBAL("test.sch")
+.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) {
+  *rv = topi::generic::schedule_injective(args[0], args[1]);
+  });
+
+TEST(Relay, BuildModule) {
+  using namespace tvm;
+  auto tensor_type = relay::TensorTypeNode::make({2, 3}, ::tvm::Float(32));
+  auto a = relay::VarNode::make("a", tensor_type);
+  auto b = relay::VarNode::make("b", tensor_type);
+  auto add_op = relay::Op::Get("add");
+  auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {});
+  auto c = relay::VarNode::make("c", tensor_type);
+  auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {});
+  auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {});
+  auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+  auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+  auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
+
+  auto pA = (float*)A.ToDLPack()->dl_tensor.data;
+  auto pB = (float*)B.ToDLPack()->dl_tensor.data;
+  auto pC = (float*)C.ToDLPack()->dl_tensor.data;
+
+  for (int i = 0; i < 6; ++i) {
+    pA[i] = i;
+    pB[i] = i + 1;
+    pC[i] = i + 2;
+  }
+  // get schedule
+  auto reg = tvm::runtime::Registry::Get("relay.op._Register");
+  auto s_i = tvm::runtime::Registry::Get("test.sch");
+  if (!reg) {
+    LOG(FATAL) << "no _Register";
+  }
+  if (!s_i) {
+    LOG(FATAL) << "no _Register";
+  }
+  (*reg)("add", "FTVMSchedule", *s_i, 10);
+  // build
+  auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
+  tvm::runtime::Module build_mod = (*pfb)();
+  auto build_f = build_mod.GetFunction("build", false);
+  auto json_f = build_mod.GetFunction("get_graph_json", false);
+  auto mod_f = build_mod.GetFunction("get_module", false);
+  Array<HalideIR::Expr> target_pair;
+  target_pair.push_back(ir::StringImm::make("cpu"));
+  target_pair.push_back(ir::StringImm::make("llvm"));
+  build_f(func, target_pair, "llvm");
+  std::string json = json_f();
+  tvm::runtime::Module mod = mod_f();
+  // run
+  auto ctx = A->ctx;
+  auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
+  tvm::runtime::Module run_mod = (*pfr)(json, mod, (int)ctx.device_type, (int)ctx.device_id);
+  auto set_input_f = run_mod.GetFunction("set_input", false);
+  auto run_f = run_mod.GetFunction("run", false);
+  auto get_output_f = run_mod.GetFunction("get_output", false);
+  set_input_f("a", A);
+  set_input_f("b", B);
+  set_input_f("c", C);
+  run_f();
+  tvm::runtime::NDArray Y = get_output_f(0);
+  auto pY = (float*)Y.ToDLPack()->dl_tensor.data;
+  for (int i = 0; i < 6; ++i) {
+    CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
+  }
+}
+
+int main(int argc, char ** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  testing::FLAGS_gtest_death_test_style = "threadsafe";
+  return RUN_ALL_TESTS();
+}
diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py
new file mode 100644 (file)
index 0000000..c69d877
--- /dev/null
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+
+import tvm
+from tvm import relay
+
+from tvm._ffi.function import _init_api
+_init_api("tvm.relay.build_module")
+
+class BuildModule(object):
+    def __init__(self):
+        self.mod = relay.build_module._BuildModule()
+        self._get_graph_json = self.mod["get_graph_json"]
+        self._get_module = self.mod["get_module"]
+        self._build = self.mod["build"]
+        self._set_opt_level = self.mod["set_opt_level"]
+        self._set_params_func = self.mod["set_params"]
+        self._get_params_func = self.mod["get_params"]
+
+  
+    def build(self, func, target, target_host, params):
+        tgts = []
+        for kv in target.items():
+            tgts.append(kv[0])
+            tgts.append(kv[1])
+        self._set_params(params)
+        self._build(func, tgts, target_host)
+
+    def get_json(self):
+        return self._get_graph_json()
+
+    def get_module(self):
+        return self._get_module()
+
+    def set_opt_level(self, level):
+        self._set_opt_level(level)
+
+    def _set_params(self, params):
+        inputs = {}
+        for name, param in params.items():
+            inputs[name] = relay.Constant(param)
+        self._set_params_func(inputs)
+
+    def get_params(self):
+        params = self._get_params_func()
+        ret = {}
+        for key, value in params.items():
+            ret[key] = value.data
+        return ret
+
+
+def test_build():
+    m_bld = BuildModule()
+    tgt_name = "llvm"
+    tgt = "llvm"
+    ctx = tvm.cpu()
+    # func
+    a = relay.var("a", dtype="float32", shape=(16, 8))
+    b = relay.var("b", dtype="float32", shape=(8, 8))
+    c = relay.var("c", dtype="float32", shape=(16, 8))
+    x = relay.nn.dense(a, b)
+    y = relay.nn.relu(x)
+    z = y + c
+    func = relay.Function([a, b, c], z)
+    A = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx)
+    B = tvm.nd.array(np.random.uniform(-1, 1, (8, 8)).astype("float32"), ctx=ctx)
+    C = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx)
+    params = {
+        "b" : B,
+        "c" : C
+    }
+    # build
+    targets = {
+        tgt: tgt
+    }
+    m_bld.set_opt_level(3)
+    m_bld.build(func, targets, "llvm -mcpu=sse3", params=params)
+    g_json = m_bld.get_json()
+    mmod = m_bld.get_module()
+    params = m_bld.get_params()
+   
+    # test
+    rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
+    rt.set_input("a", A)
+    rt.load_params(relay.save_param_dict(params))
+    rt.run()
+    out = rt.get_output(0)
+   
+    np.testing.assert_allclose(out.asnumpy(),
+        np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5)
+