[Target][Codegen] Use target class in all codegens (#6347)
authorJunru Shao <junrushao1994@gmail.com>
Sat, 29 Aug 2020 05:59:35 +0000 (22:59 -0700)
committerGitHub <noreply@github.com>
Sat, 29 Aug 2020 05:59:35 +0000 (22:59 -0700)
* [Target][Codegen] Make all code generator use Target class instead of target string

* Remove dep to TargetNode::str() in LLVM module

* Allow  for llvm nvptx codegen

* ...

* Address comments from Cody

* Rename UpdateTargetConfig => UpdateTargetConfigKeyValueEntry

23 files changed:
include/tvm/target/codegen.h
include/tvm/target/target.h
python/tvm/target/target.py
src/target/build_common.h
src/target/codegen.cc
src/target/llvm/codegen_amdgpu.cc
src/target/llvm/codegen_blob.cc
src/target/llvm/codegen_hexagon.cc
src/target/llvm/codegen_nvptx.cc
src/target/llvm/llvm_common.cc
src/target/llvm/llvm_common.h
src/target/llvm/llvm_module.cc
src/target/opt/build_cuda_on.cc
src/target/source/codegen_aocl.cc
src/target/source/codegen_c_host.cc
src/target/source/codegen_metal.cc
src/target/source/codegen_opencl.cc
src/target/source/codegen_vhls.cc
src/target/spirv/build_vulkan.cc
src/target/stackvm/codegen_stackvm.cc
src/target/target.cc
src/target/target_kind.cc
tests/cpp/build_module_test.cc

index e89d44d..b2cab0e 100644 (file)
@@ -45,7 +45,7 @@ using runtime::TVMRetValue;
  * \param target The target to be built.
  * \return The result runtime::Module.
  */
-runtime::Module Build(IRModule mod, const Target& target);
+runtime::Module Build(IRModule mod, Target target);
 
 /*!
  * \brief Pack imported device library to a C file.
index 258b2d8..4d5fba3 100644 (file)
@@ -52,9 +52,10 @@ class TargetNode : public Object {
   Array<String> keys;
   /*! \brief Collection of attributes */
   Map<String, ObjectRef> attrs;
-
   /*! \return the full device string to pass to codegen::Build */
   TVM_DLL const std::string& str() const;
+  /*! \return Export target to JSON-like configuration */
+  TVM_DLL Map<String, ObjectRef> Export() const;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("kind", &kind);
index 9dcc164..986caa1 100644 (file)
@@ -54,6 +54,9 @@ class Target(Object):
     def __exit__(self, ptype, value, trace):
         _ffi_api.ExitTargetScope(self)
 
+    def export(self):
+        return _ffi_api.TargetExport(self)
+
     @staticmethod
     def current(allow_none=True):
         """Returns the current target.
index ec5b522..531bd62 100644 (file)
@@ -62,6 +62,23 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
   }
   return fmap;
 }
+
+inline void UpdateTargetConfigKeyValueEntry(const String& key, const String& value,
+                                            Map<String, ObjectRef>* target_config,
+                                            bool error_if_inconsistent) {
+  if (target_config->count(key)) {
+    const ObjectRef& obj = (*target_config)[key];
+    CHECK(obj->IsInstance<StringObj>()) << "TypeError: Expect target key \"" << key
+                                        << "\" to be String, but gets type: " << obj->GetTypeKey();
+    if (error_if_inconsistent) {
+      String old_value = Downcast<String>(obj);
+      CHECK_EQ(old_value, value) << "ValueError: Target key \"" << key << "\" has been set to \""
+                                 << old_value << "\", and cannot be reset to \"" << value << "\"";
+    }
+  }
+  target_config->Set(key, value);
+}
+
 }  // namespace codegen
 }  // namespace tvm
 #endif  // TVM_TARGET_BUILD_COMMON_H_
index 0ac4993..47603e4 100644 (file)
@@ -41,7 +41,7 @@
 namespace tvm {
 namespace codegen {
 
-runtime::Module Build(IRModule mod, const Target& target) {
+runtime::Module Build(IRModule mod, Target target) {
   if (transform::PassContext::Current()
           ->GetConfig<Bool>("tir.disable_assert", Bool(false))
           .value()) {
@@ -55,8 +55,8 @@ runtime::Module Build(IRModule mod, const Target& target) {
   }
   // the build function.
   const PackedFunc* bf = runtime::Registry::Get(build_f_name);
-  CHECK(bf != nullptr) << "target.build." << target << " is not enabled";
-  return (*bf)(mod, target->str());
+  CHECK(bf != nullptr) << build_f_name << " is not enabled";
+  return (*bf)(mod, target);
 }
 
 /*! \brief Helper class to serialize module */
index 758a4f6..c19c01b 100644 (file)
@@ -191,12 +191,17 @@ class CodeGenAMDGPU : public CodeGenLLVM {
   }
 };
 
-inline int DetectROCMComputeVersion(const std::string& target) {
-  size_t pos = target.find("=gfx");
-  if (pos != std::string::npos) {
-    int value;
-    std::stringstream is(target.substr(pos + 4));
-    if (is >> value) return value;
+inline int DetectROCMComputeVersion(const Target& target) {
+  if (const Optional<String> mcpu = target->GetAttr<String>("mcpu")) {
+    std::string gfx = mcpu.value();
+    if (gfx.length() >= 3 && gfx.substr(0, 3) == "gfx") {
+      int version;
+      std::stringstream is(gfx.substr(3));
+      if (is >> version) {
+        return version;
+      }
+    }
+    LOG(FATAL) << "ValueError: Unrecognized -mcpu value: " << mcpu;
   }
   TVMContext tvm_ctx;
   tvm_ctx.device_type = kDLROCM;
@@ -228,23 +233,34 @@ inline int DetectROCMApiVersion() {
   return 305;
 }
 
-runtime::Module BuildAMDGPU(IRModule mod, std::string target) {
+Target UpdateTarget(const Target& original_target) {
+  Map<String, ObjectRef> target_config = original_target->Export();
+  UpdateTargetConfigKeyValueEntry("mtriple", "amdgcn-amd-amdhsa-hcc", &target_config, true);
+  UpdateTargetConfigKeyValueEntry("mcpu",
+                                  "gfx" + std::to_string(DetectROCMComputeVersion(original_target)),
+                                  &target_config, false);
+  if (DetectROCMApiVersion() < 305) {
+    // before ROCm 3.5 we needed code object v2, starting
+    // with 3.5 we need v3 (this argument disables v3)
+    Array<String> mattr;
+    if (target_config.count("mattr")) {
+      mattr = Downcast<Array<String>>(target_config["mattr"]);
+    }
+    mattr.push_back("-code-object-v3");
+    target_config.Set("mattr", mattr);
+  }
+  return Target::FromConfig(target_config);
+}
+
+runtime::Module BuildAMDGPU(IRModule mod, Target original_target) {
 #if TVM_LLVM_VERSION < 90
   LOG(FATAL) << "AMDGPU backend requires at least LLVM 9";
   // Lower versions will crash when loading the bitcode, see
   // issue #4087 for a discussion
 #endif
   InitializeLLVM();
-  CHECK(target.length() >= 4 && target.substr(0, 4) == "rocm");
-  std::ostringstream config;
-  config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target);
-  if (DetectROCMApiVersion() < 305) {
-    // before ROCm 3.5 we needed code object v2, starting
-    // with 3.5 we need v3 (this argument disables v3)
-    config << " -mattr=-code-object-v3 ";
-  }
-  config << target.substr(4, target.length() - 4);
-  std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
+  Target target = UpdateTarget(original_target);
+  std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
   // careful: cg will hold a naked pointer reference to ctx, so it should
   // have a shorter lifetime than the ctx.
index 6df4817..5d8a769 100644 (file)
@@ -24,6 +24,7 @@
 #include "codegen_blob.h"
 
 #include <tvm/runtime/module.h>
+#include <tvm/target/target.h>
 
 #include <cstring>
 
@@ -33,8 +34,8 @@ namespace codegen {
 std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(
     const std::string& data, bool system_lib, const std::string& target_triple) {
   InitializeLLVM();
-  std::string full_target_triple = std::string("-mtriple ") + target_triple;
-  auto tm = GetLLVMTargetMachine(full_target_triple);
+  Target target = Target::Create("llvm -mtriple " + target_triple);
+  auto tm = GetLLVMTargetMachine(target);
   auto triple = tm->getTargetTriple();
   auto ctx = std::make_shared<llvm::LLVMContext>();
   std::string module_name = "devc";
@@ -43,7 +44,7 @@ std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> Cod
   // Store full target string in metadata, because flags such as -mfloat-abi must be preserved for
   // ModulePackImportsToLLVM.
   module->addModuleFlag(llvm::Module::ModFlagBehavior::Override, "tvm_target",
-                        llvm::MDString::get(*ctx, full_target_triple));
+                        llvm::MDString::get(*ctx, LLVMTargetToString(target)));
   module->setDataLayout(tm->createDataLayout());
   auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false);
   auto* tvm_dev_mblob = new llvm::GlobalVariable(
index c77215d..c52f9b0 100644 (file)
@@ -658,11 +658,7 @@ void ProcessLLVMOptions(const std::vector<std::string>& llvm_vec) {
 
 }  // namespace
 
-runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
-  if (target_str.empty()) {
-    LOG(FATAL) << "Unknown or invalid target.";
-  }
-
+runtime::Module BuildHexagon(IRModule mod, Target target) {
   // Make sure all targets are registered. InitializeLLVM can be called
   // multiple times, after the first call all subsequent calls are no-ops.
   InitializeLLVM();
@@ -675,21 +671,12 @@ runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
     }
     return vec;
   };
-  auto starts_with = [](const std::string& s, const std::string& p) {
-    return !s.compare(0, p.size(), p);
-  };
-
-  std::vector<std::string> flags = split(target_str);
-  std::string llvm_target_str, llvm_options_str = "llvm";
-
-  for (const auto& s : flags) {
-    if (starts_with(s, "-mattr=") || starts_with(s, "-mtriple=") || starts_with(s, "-mcpu=")) {
-      llvm_target_str += " " + s;
-    } else if (starts_with(s, "-llvm-options=")) {
-      llvm_options_str += "," + s.substr(14 /*length of -llvm-options=*/);
-    }
+  std::string llvm_options_str;
+  if (const Optional<String> llvm_options = target->GetAttr<String>("llvm-options")) {
+    llvm_options_str = "llvm," + llvm_options.value();
+  } else {
+    llvm_options_str = "llvm";
   }
-
   // Postprocess the LLVM options string: replace '@' with '=', and ',' with ' '.
   for (int i = 0, e = llvm_options_str.size(); i != e; ++i) {
     switch (llvm_options_str[i]) {
@@ -716,7 +703,7 @@ runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
   static bool CallOnce = (ProcessLLVMOptions(llvm_options_vec), true);
   (void)CallOnce;
 
-  std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target_str);
+  std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
   std::unique_ptr<CodeGenHexagon> cg(new CodeGenHexagon());
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
   cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false);
@@ -802,9 +789,7 @@ runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
                              export_abi);
 }
 
-TVM_REGISTER_GLOBAL("target.build.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = BuildHexagon(args[0], args[1]);
-});
+TVM_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon);
 
 }  // namespace codegen
 }  // namespace tvm
index e2690b9..fe409ba 100644 (file)
@@ -254,14 +254,19 @@ inline int DetectCUDAComputeVersion() {
   }
 }
 
-runtime::Module BuildNVPTX(IRModule mod, std::string target) {
+Target UpdateTarget(const Target& original_target, int compute_ver) {
+  Map<String, ObjectRef> target_config = original_target->Export();
+  UpdateTargetConfigKeyValueEntry("mtriple", "nvptx64-nvidia-cuda", &target_config, true);
+  UpdateTargetConfigKeyValueEntry("mcpu", "sm_" + std::to_string(compute_ver), &target_config,
+                                  false);
+  return Target::FromConfig(target_config);
+}
+
+runtime::Module BuildNVPTX(IRModule mod, Target original_target) {
   InitializeLLVM();
-  CHECK(target.length() >= 5 && target.substr(0, 5) == "nvptx");
   int compute_ver = DetectCUDAComputeVersion();
-  std::ostringstream config;
-  config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver
-         << target.substr(5, target.length() - 5);
-  std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
+  Target target = UpdateTarget(original_target, compute_ver);
+  std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
   // careful: cg will hold a naked pointer reference to ctx, so it should
   // have a shorter lifetime than the ctx.
index 3a1036b..e8225ab 100644 (file)
@@ -25,6 +25,7 @@
 #include "llvm_common.h"
 
 #include <dmlc/logging.h>
+#include <tvm/target/target.h>
 
 #include <atomic>
 #include <memory>
@@ -58,53 +59,44 @@ void InitializeLLVM() {
   }
 }
 
-void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu,
+void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::string* mcpu,
                             std::string* mattr, llvm::TargetOptions* options) {
-  // setup target triple
-  size_t start = 0;
-  if (target_str.length() >= 4 && target_str.substr(0, 4) == "llvm") {
-    start = 4;
-  }
   // simple parser
   triple->resize(0);
   mcpu->resize(0);
   mattr->resize(0);
-
   bool soft_float_abi = false;
-  std::string key, value;
-  std::istringstream is(target_str.substr(start, target_str.length() - start));
-  while (is >> key) {
-    if (key == "-system-lib" || key == "-system-lib=0" || key == "-system-lib=1") {
-      continue;
-    }
-    size_t pos = key.find('=');
-    if (pos != std::string::npos) {
-      CHECK_GE(key.length(), pos + 1) << "invalid argument " << key;
-      value = key.substr(pos + 1, key.length() - 1);
-      key = key.substr(0, pos);
-    } else {
-      CHECK(is >> value) << "Unspecified value for option " << key;
+  if (const Optional<String>& v = target->GetAttr<String>("mtriple")) {
+    *triple = v.value();
+  }
+  if (const Optional<String>& v = target->GetAttr<String>("mcpu")) {
+    *mcpu = v.value();
+  }
+  if (const Optional<Array<String>>& v = target->GetAttr<Array<String>>("mattr")) {
+    std::ostringstream os;
+    bool is_first = true;
+    for (const String& s : v.value()) {
+      if (!is_first) {
+        os << ',';
+      }
+      is_first = false;
+      os << s;
     }
-    if (key == "-mtriple") {
-      *triple = value;
-    } else if (key == "-mcpu") {
-      *mcpu = value;
-    } else if (key == "-mattr") {
-      *mattr = value;
-    } else if (key == "-mfloat-abi") {
-      if (value == "hard") {
+    *mattr = os.str();
+  }
+  if (const Optional<String>& v = target->GetAttr<String>("mfloat-abi")) {
+    String value = v.value();
+    if (value == "hard") {
 #if TVM_LLVM_VERSION < 60
-        LOG(FATAL) << "-mfloat-abi hard is only supported for LLVM > 6.0";
+      LOG(FATAL) << "-mfloat-abi hard is only supported for LLVM > 6.0";
 #endif
-        soft_float_abi = false;
-      } else if (value == "soft") {
-        soft_float_abi = true;
-      } else {
-        LOG(FATAL) << "invalid -mfloat-abi option " << value;
-      }
+      soft_float_abi = false;
+    } else if (value == "soft") {
+      soft_float_abi = true;
+    } else {
+      LOG(FATAL) << "invalid -mfloat-abi option " << value;
     }
   }
-
   if (triple->length() == 0 || *triple == "default") {
     *triple = llvm::sys::getDefaultTargetTriple();
   }
@@ -125,12 +117,11 @@ void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple,
   }
 }
 
-std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const std::string& target_str,
-                                                          bool allow_null) {
+std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const Target& target, bool allow_null) {
   std::string target_triple, mcpu, mattr;
   llvm::TargetOptions opt;
 
-  ParseLLVMTargetOptions(target_str, &target_triple, &mcpu, &mattr, &opt);
+  ParseLLVMTargetOptions(target, &target_triple, &mcpu, &mattr, &opt);
 
   if (target_triple.length() == 0 || target_triple == "default") {
     target_triple = llvm::sys::getDefaultTargetTriple();
@@ -140,16 +131,42 @@ std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const std::string& tar
   }
 
   std::string err;
-  const llvm::Target* target = llvm::TargetRegistry::lookupTarget(target_triple, err);
-  if (target == nullptr) {
+  const llvm::Target* llvm_target = llvm::TargetRegistry::lookupTarget(target_triple, err);
+  if (llvm_target == nullptr) {
     CHECK(allow_null) << err << " target_triple=" << target_triple;
     return nullptr;
   }
   llvm::TargetMachine* tm =
-      target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_);
+      llvm_target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_);
   return std::unique_ptr<llvm::TargetMachine>(tm);
 }
 
+std::string LLVMTargetToString(const Target& target) {
+  std::ostringstream os;
+  os << "llvm";
+  if (Optional<String> mtriple = target->GetAttr<String>("mtriple")) {
+    os << " -mtriple=" << mtriple.value();
+  }
+  if (Optional<String> mcpu = target->GetAttr<String>("mcpu")) {
+    os << " -mcpu=" << mcpu.value();
+  }
+  if (Optional<Array<String>> mattr = target->GetAttr<Array<String>>("mattr")) {
+    bool is_first = true;
+    os << " -mattr=";
+    for (const String& attr : mattr.value()) {
+      if (!is_first) {
+        os << ",";
+      }
+      is_first = false;
+      os << attr;
+    }
+  }
+  if (Optional<String> mfloat_abo = target->GetAttr<String>("mfloat-abi")) {
+    os << " -mfloat-abi=" << mfloat_abo.value();
+  }
+  return os.str();
+}
+
 }  // namespace codegen
 }  // namespace tvm
 #endif  // TVM_LLVM_VERSION
index 738e055..42cb9db 100644 (file)
 #include <utility>
 
 namespace tvm {
+
+// The TVM target
+class Target;
+
 namespace codegen {
 
 /*!
@@ -89,24 +93,31 @@ void InitializeLLVM();
 
 /*!
  * \brief Parse target options
- * \param target_str Target string, in format "llvm -mtriple=xxx -mcpu=xxx"
+ * \param target The TVM target
  * \param triple Target triple
  * \param mcpu cpu info
  * \param options the options
  * \param mattr The attributes
  */
-void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu,
+void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::string* mcpu,
                             std::string* mattr, llvm::TargetOptions* options);
 
 /*!
- * \brief Get target machine from target_str string.
- * \param target_str Target string, in format "llvm -mtriple=xxx -mcpu=xxx"
+ * \brief Get target machine from TVM target.
+ * \param target The TVM target
  * \param allow_null Whether allow null to be returned.
  * \return target machine
  */
-std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const std::string& target_str,
+std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const Target& target,
                                                           bool allow_null = false);
 
+/*!
+ * \brief Convert the TVM's LLVM target to string by extracting only relevant fields
+ * \param target The TVM target to be extracted
+ * \return The raw string format for the TVM LLVM target
+ */
+std::string LLVMTargetToString(const Target& target);
+
 }  // namespace codegen
 }  // namespace tvm
 
index de2dadf..b3d448a 100644 (file)
@@ -189,10 +189,9 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     return "";
   }
 
-  void Init(const IRModule& mod, std::string target_str) {
+  void Init(const IRModule& mod, const Target& target) {
     InitializeLLVM();
-    tm_ = GetLLVMTargetMachine(target_str);
-    auto target = Target::Create(target_str);
+    tm_ = GetLLVMTargetMachine(target);
     bool system_lib = target->GetAttr<Bool>("system-lib").value_or(Bool(false));
     bool target_c_runtime = (target->GetAttr<String>("runtime").value_or("") == kTvmRuntimeCrt);
     ctx_ = std::make_shared<llvm::LLVMContext>();
@@ -225,7 +224,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
 
     module_ = cg->Finish();
     module_->addModuleFlag(llvm::Module::Warning, "tvm_target",
-                           llvm::MDString::get(*ctx_, target_str));
+                           llvm::MDString::get(*ctx_, LLVMTargetToString(target)));
     module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
                            llvm::DEBUG_METADATA_VERSION);
 
@@ -238,7 +237,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors))
         << "LLVM module verification failed with the following errors: \n"
         << verify_errors.str();
-    target_ = target_str;
+    target_ = target;
     mptr_ = module_.get();
   }
 
@@ -251,19 +250,22 @@ class LLVMModuleNode final : public runtime::ModuleNode {
       std::string msg = std::string(err.getMessage());
       LOG(FATAL) << "Fail to load module: " << msg;
     }
-    std::string target_;
-    llvm::Metadata* mtarget = module_->getModuleFlag("tvm_target");
-    if (mtarget != nullptr) {
-      llvm::MDString* pstr = llvm::dyn_cast<llvm::MDString>(mtarget);
+    std::string target_metadata;
+    llvm::Metadata* tvm_target = module_->getModuleFlag("tvm_target");
+    if (tvm_target != nullptr) {
+      llvm::MDString* pstr = llvm::dyn_cast<llvm::MDString>(tvm_target);
       CHECK(pstr != nullptr);
-      target_ = pstr->getString().str();
+      target_metadata = pstr->getString().str();
+      if (!(target_metadata.length() >= 4 && target_metadata.substr(0, 4) == "llvm")) {
+        target_metadata = "llvm " + target_metadata;
+      }
     } else {
       std::ostringstream os;
       os << "llvm -mtriple " << module_->getTargetTriple();
-      target_ = os.str();
+      target_metadata = os.str();
     }
     mptr_ = module_.get();
-    tm_ = GetLLVMTargetMachine(target_);
+    tm_ = GetLLVMTargetMachine(Target::Create(target_metadata));
   }
 
   void LoadIR(const std::string& file_name) {
@@ -284,6 +286,9 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     if (ee_) {
       return;
     }
+    if (!target_.defined()) {
+      target_ = Target::Create("llvm");
+    }
     llvm::EngineBuilder builder(std::move(module_));
     std::string triple, mcpu, mattr;
     llvm::TargetOptions opt;
@@ -299,7 +304,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     }
     builder.setTargetOptions(opt);
     auto tm = std::unique_ptr<llvm::TargetMachine>(builder.selectTarget());
-    std::unique_ptr<llvm::TargetMachine> tm_sys = GetLLVMTargetMachine("llvm");
+    std::unique_ptr<llvm::TargetMachine> tm_sys = GetLLVMTargetMachine(Target::Create("llvm"));
     if (tm_sys->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) {
       LOG(FATAL) << "Cannot run module, architecture mismatch "
                  << " module=" << tm->getTargetTriple().str()
@@ -340,7 +345,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
   }
 
   // The target configuration string
-  std::string target_;
+  Target target_;
   // JIT lock
   std::mutex mutex_;
   // execution engine
@@ -355,64 +360,62 @@ class LLVMModuleNode final : public runtime::ModuleNode {
   std::shared_ptr<llvm::LLVMContext> ctx_;
 };
 
-unsigned LookupLLVMIntrinsic(const std::string& name) {
-  return llvm::Function::lookupIntrinsicID(name);
-}
-
-TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed([](IRModule mod, std::string target) {
-  auto n = make_object<LLVMModuleNode>();
-  n->Init(mod, target);
-  return runtime::Module(n);
+TVM_REGISTER_GLOBAL("target.build.llvm")
+    .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
+      auto n = make_object<LLVMModuleNode>();
+      n->Init(mod, target);
+      return runtime::Module(n);
+    });
+
+TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
+    .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module {
+      Target target = Target::Create(target_str);
+      auto n = make_object<LLVMModuleNode>();
+      // Generate a LLVM module from an input target string
+      InitializeLLVM();
+      auto tm = GetLLVMTargetMachine(target);
+      auto ctx = std::make_shared<llvm::LLVMContext>();
+      std::unique_ptr<llvm::Module> module(new llvm::Module(module_name, *ctx));
+      // Use a default data layout and target triple
+      auto triple = tm->getTargetTriple();
+      module->setTargetTriple(triple.str());
+      module->setDataLayout(tm->createDataLayout());
+      n->Init(std::move(module), ctx);
+      return runtime::Module(n);
+    });
+
+TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id")
+    .set_body_typed([](std::string name) -> int64_t {
+      return static_cast<int64_t>(llvm::Function::lookupIntrinsicID(name));
+    });
+
+TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int {
+  return TVM_LLVM_VERSION / 10;
 });
 
-TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate").set_body([](TVMArgs args, TVMRetValue* rv) {
-  auto n = make_object<LLVMModuleNode>();
-  auto target = args[0].operator std::string();
-  auto module_name = args[1].operator std::string();
-
-  // Generate a LLVM module from an input target string
-  InitializeLLVM();
-  auto tm = GetLLVMTargetMachine(target);
-  auto ctx = std::make_shared<llvm::LLVMContext>();
-  std::unique_ptr<llvm::Module> module(new llvm::Module(module_name, *ctx));
-
-  // Use a default data layout and target triple
-  auto triple = tm->getTargetTriple();
-  module->setTargetTriple(triple.str());
-  module->setDataLayout(tm->createDataLayout());
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll")
+    .set_body_typed([](std::string filename, std::string fmt) -> runtime::Module {
+      auto n = make_object<LLVMModuleNode>();
+      n->LoadIR(filename);
+      return runtime::Module(n);
+    });
+
+TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled")
+    .set_body_typed([](std::string target_str) -> bool {
+      InitializeLLVM();
+      Target target = Target::Create(target_str);
+      return (GetLLVMTargetMachine(target, true) != nullptr);
+    });
+
+TVM_REGISTER_GLOBAL("codegen.codegen_blob")
+    .set_body_typed([](std::string data, bool system_lib,
+                       std::string target_triple) -> runtime::Module {
+      auto n = make_object<LLVMModuleNode>();
+      auto p = CodeGenBlob(data, system_lib, target_triple);
+      n->Init(std::move(p.first), p.second);
+      return runtime::Module(n);
+    });
 
-  n->Init(std::move(module), ctx);
-
-  *rv = runtime::Module(n);
-});
-
-TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
-});
-
-TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body([](TVMArgs args, TVMRetValue* rv) {
-  int major = TVM_LLVM_VERSION / 10;
-  *rv = major;
-});
-
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll").set_body([](TVMArgs args, TVMRetValue* rv) {
-  auto n = make_object<LLVMModuleNode>();
-  n->LoadIR(args[0]);
-  *rv = runtime::Module(n);
-});
-
-TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled").set_body([](TVMArgs args, TVMRetValue* rv) {
-  InitializeLLVM();
-  *rv = (GetLLVMTargetMachine(args[0], true) != nullptr);
-});
-
-TVM_REGISTER_GLOBAL("codegen.codegen_blob").set_body([](TVMArgs args, TVMRetValue* rv) {
-  auto n = make_object<LLVMModuleNode>();
-  auto p = CodeGenBlob(args[0].operator std::string(), args[1].operator bool(),
-                       args[2].operator std::string());
-  n->Init(std::move(p.first), p.second);
-  *rv = runtime::Module(n);
-});
 }  // namespace codegen
 }  // namespace tvm
 #endif  // TVM_LLVM_VERSION
index c9471d1..780829c 100644 (file)
@@ -121,7 +121,7 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) {
   return ptx;
 }
 
-runtime::Module BuildCUDA(IRModule mod, std::string target) {
+runtime::Module BuildCUDA(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
   CodeGenCUDA cg;
index 597fd37..e90b7d4 100644 (file)
@@ -33,7 +33,7 @@
 namespace tvm {
 namespace codegen {
 
-runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation) {
+runtime::Module BuildAOCL(IRModule mod, Target target, bool emulation) {
   // Get code.
   using tvm::runtime::Registry;
   bool output_ssa = false;
@@ -61,7 +61,6 @@ runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation)
   std::string cmd = "aoc aocl.cl";
   // AOCL supports fp64.
   cmd += " -Dcl_khr_fp64";
-  Target target = Target::Create(target_str);
   Optional<String> device = target->GetAttr<String>("device");
   if (device.defined()) {
     cmd += " -board=" + device.value();
@@ -80,13 +79,15 @@ runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation)
   return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(mod), code);
 }
 
-TVM_REGISTER_GLOBAL("target.build.aocl").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = BuildAOCL(args[0], args[1], false);
-});
+TVM_REGISTER_GLOBAL("target.build.aocl")
+    .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
+      return BuildAOCL(mod, target, false);
+    });
 
-TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = BuildAOCL(args[0], args[1], true);
-});
+TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu")
+    .set_body_typed([](IRModule mod, Target target) -> runtime::Module {
+      return BuildAOCL(mod, target, true);
+    });
 
 }  // namespace codegen
 }  // namespace tvm
index f4aa728..5bd7b2e 100644 (file)
@@ -298,12 +298,11 @@ void CodeGenCHost::GenerateCrtSystemLib() {
          << "}\n";
 }
 
-runtime::Module BuildCHost(IRModule mod, const std::string& target_str) {
+runtime::Module BuildCHost(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
   bool emit_asserts = false;
   CodeGenCHost cg;
-  auto target = Target::Create(target_str);
   cg.Init(output_ssa, emit_asserts);
 
   for (auto kv : mod->functions) {
@@ -323,8 +322,6 @@ runtime::Module BuildCHost(IRModule mod, const std::string& target_str) {
   return CSourceModuleCreate(code, "c");
 }
 
-TVM_REGISTER_GLOBAL("target.build.c").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = BuildCHost(args[0], args[1]);
-});
+TVM_REGISTER_GLOBAL("target.build.c").set_body_typed(BuildCHost);
 }  // namespace codegen
 }  // namespace tvm
index 1c4256c..fb235d2 100644 (file)
@@ -282,7 +282,7 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT
   }
 }
 
-runtime::Module BuildMetal(IRModule mod) {
+runtime::Module BuildMetal(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
   CodeGenMetal cg;
@@ -308,8 +308,6 @@ runtime::Module BuildMetal(IRModule mod) {
   return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source);
 }
 
-TVM_REGISTER_GLOBAL("target.build.metal").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = BuildMetal(args[0]);
-});
+TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal);
 }  // namespace codegen
 }  // namespace tvm
index 21e5ed6..10cc007 100644 (file)
@@ -280,7 +280,7 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) {  // N
   }
 }
 
-runtime::Module BuildOpenCL(IRModule mod, std::string target) {
+runtime::Module BuildOpenCL(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
   CodeGenOpenCL cg;
index 3d77dda..9401f06 100644 (file)
@@ -137,7 +137,7 @@ void CodeGenVivadoHLS::VisitExpr_(const MaxNode* op, std::ostream& os) {  // NOL
   PrintBinaryExpr(op, opstr, os, this);
 }
 
-runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
+runtime::Module BuildSDAccel(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
   CodeGenVivadoHLS cg;
@@ -178,7 +178,6 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
 
   std::string xclbin;
   if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) {
-    Target target = Target::Create(target_str);
     String device = target->GetAttr<String>("device", "").value();
     xclbin = (*f)(kernel_info, device).operator std::string();
   } else {
index 86d1614..1eef2f8 100644 (file)
@@ -63,7 +63,7 @@ class SPIRVTools {
   spv_context ctx_;
 };
 
-runtime::Module BuildSPIRV(IRModule mod, std::string target, bool webgpu_restriction) {
+runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) {
   using tvm::runtime::Registry;
   using tvm::runtime::VulkanShader;
 
@@ -116,11 +116,11 @@ runtime::Module BuildSPIRV(IRModule mod, std::string target, bool webgpu_restric
   return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), code_data.str());
 }
 
-TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, std::string target) {
+TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, Target target) {
   return BuildSPIRV(mod, target, false);
 });
 
-TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, std::string target) {
+TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, Target target) {
   return BuildSPIRV(mod, target, true);
 });
 
index 9cad92d..ac3ba78 100644 (file)
@@ -510,7 +510,7 @@ void CodeGenStackVM::VisitExpr_(const LetNode* op) {
   this->Push(op->body);
 }
 
-runtime::Module BuildStackVM(const IRModule& mod, const std::string& target) {
+runtime::Module BuildStackVM(IRModule mod, Target target) {
   std::unordered_map<std::string, StackVM> fmap;
   std::string entry_func;
 
index 47b4054..ccc0023 100644 (file)
@@ -276,6 +276,18 @@ std::unordered_set<std::string> TargetNode::GetLibs() const {
   return result;
 }
 
+Map<String, ObjectRef> TargetNode::Export() const {
+  Map<String, ObjectRef> result = {
+      {"kind", this->kind->name},
+      {"tag", this->tag},
+      {"keys", this->keys},
+  };
+  for (const auto& kv : attrs) {
+    result.Set(kv.first, kv.second);
+  }
+  return result;
+}
+
 const std::string& TargetNode::str() const {
   if (str_repr_.empty()) {
     std::ostringstream os;
@@ -527,10 +539,14 @@ TVM_REGISTER_GLOBAL("target.TargetFromString").set_body_typed(Target::Create);
 
 TVM_REGISTER_GLOBAL("target.TargetFromConfig").set_body_typed(Target::FromConfig);
 
+TVM_REGISTER_GLOBAL("target.TargetExport")
+    .set_body_typed([](Target target) -> Map<String, ObjectRef> { return target->Export(); });
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<TargetNode>([](const ObjectRef& node, ReprPrinter* p) {
-      auto* op = static_cast<const TargetNode*>(node.get());
-      p->stream << op->str();
+      const auto* target = node.as<TargetNode>();
+      CHECK(target);
+      p->stream << target->str();
     });
 
 namespace target {
index 40ade4d..29f1692 100644 (file)
@@ -106,6 +106,7 @@ TVM_REGISTER_TARGET_KIND("nvptx")
     .add_attr_option<Integer>("max_num_threads", Integer(1024))
     .add_attr_option<Integer>("thread_warp_size", Integer(32))
     .add_attr_option<String>("mcpu")
+    .add_attr_option<String>("mtriple")
     .set_default_keys({"cuda", "gpu"})
     .set_device_type(kDLGPU);
 
index 2462fd1..48edfcd 100644 (file)
@@ -56,9 +56,14 @@ TEST(BuildModule, Basic) {
   auto module = build(lowered, target, Target());
 
   auto mali_target = Target::Create("opencl -model=Mali-T860MP4@800Mhz -device=mali");
-  CHECK_EQ(
-      mali_target->str(),
-      "opencl -keys=mali,opencl,gpu -device=mali -max_num_threads=256 -model=Mali-T860MP4@800Mhz");
+  CHECK_EQ(mali_target->kind->name, "opencl");
+  CHECK_EQ(mali_target->keys.size(), 3);
+  CHECK_EQ(mali_target->keys[0], "mali");
+  CHECK_EQ(mali_target->keys[1], "opencl");
+  CHECK_EQ(mali_target->keys[2], "gpu");
+  CHECK_EQ(mali_target->GetAttr<String>("device").value(), "mali");
+  CHECK_EQ(mali_target->GetAttr<String>("model").value(), "Mali-T860MP4@800Mhz");
+  CHECK_EQ(mali_target->GetAttr<Integer>("max_num_threads").value(), 256);
 }
 
 TEST(BuildModule, Heterogeneous) {