[REFACTOR][IR] kExternalSymbol -> kGlobalSymbol (#5211)
authorZhi <5145158+zhiics@users.noreply.github.com>
Thu, 2 Apr 2020 15:30:46 +0000 (08:30 -0700)
committerGitHub <noreply@github.com>
Thu, 2 Apr 2020 15:30:46 +0000 (08:30 -0700)
* expose runtime::String to Python

* kExternalSymbol -> kGlobalSymbol

include/tvm/runtime/container.h
python/tvm/runtime/container.py
src/relay/backend/compile_engine.cc
src/relay/backend/contrib/codegen_c/codegen_c.h
src/relay/transforms/partition_graph.cc
src/runtime/container.cc
tests/python/relay/test_external_codegen.py
tests/python/relay/test_pass_partition_graph.py

index 4164451..50b406b 100644 (file)
@@ -512,12 +512,12 @@ class String : public ObjectRef {
 #endif
   }
 
-  TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
-
- private:
   /*! \return the internal StringObj pointer */
   const StringObj* get() const { return operator->(); }
 
+  TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
+
+ private:
   /*!
    * \brief Compare two char sequence
    *
index 02a082a..dd59011 100644 (file)
@@ -109,4 +109,22 @@ def tuple_object(fields=None):
     return _Tuple(*fields)
 
 
+@tvm._ffi.register_object("runtime.String")
+class String(Object):
+    """The string object.
+
+    Parameters
+    ----------
+    string : Str
+        The string used to construct a runtime String object
+
+    Returns
+    -------
+    ret : String
+        The created object.
+    """
+    def __init__(self, string):
+        self.__init_handle_by_constructor__(_String, string)
+
+
 tvm._ffi._init_api("tvm.runtime.container")
index 410a6df..f75da07 100644 (file)
@@ -26,6 +26,7 @@
 #include <tvm/te/operation.h>
 #include <tvm/te/schedule_pass.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
 #include <tvm/relay/attrs/device_copy.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr.h>
@@ -622,10 +623,10 @@ class CompileEngineImpl : public CompileEngineNode {
         if (ext_mods.find(code_gen->value) == ext_mods.end()) {
           ext_mods[code_gen->value] = IRModule({}, {});
         }
-        auto symbol_name = src_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
+        auto symbol_name = src_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
         CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
                                      << AsText(src_func, false);
-        auto gv = GlobalVar(symbol_name->value);
+        auto gv = GlobalVar(std::string(symbol_name));
         ext_mods[code_gen->value]->Add(gv, src_func);
         cached_ext_funcs.push_back(it.first);
       }
@@ -693,10 +694,10 @@ class CompileEngineImpl : public CompileEngineNode {
     if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
       auto cache_node = make_object<CachedFuncNode>();
       const auto name_node =
-          key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
+          key->source_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
       CHECK(name_node.defined())
           << "External function has not been attached a name yet.";
-      cache_node->func_name = name_node->value;
+      cache_node->func_name = std::string(name_node);
       cache_node->target = tvm::target::ext_dev();
       value->cached_func = CachedFunc(cache_node);
       return value;
index 60cecef..79d4d3f 100644 (file)
@@ -27,6 +27,7 @@
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/function.h>
+#include <tvm/runtime/container.h>
 #include <sstream>
 #include <string>
 #include <utility>
@@ -69,10 +70,9 @@ class CSourceModuleCodegenBase {
    */
   std::string GetExtSymbol(const Function& func) const {
     const auto name_node =
-        func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
+        func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
     CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
-    std::string ext_symbol = name_node->value;
-    return ext_symbol;
+    return std::string(name_node);
   }
 };
 
index d8e93ed..a4e3863 100644 (file)
@@ -35,6 +35,7 @@
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
+#include <tvm/runtime/container.h>
 
 #include <unordered_map>
 #include <unordered_set>
@@ -239,8 +240,8 @@ class Partitioner : public ExprMutator {
         std::string target = call->attrs.as<CompilerAttrs>()->compiler;
         std::string name = target + "_" + std::to_string(region->GetID());
 
-        global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol,
-                                      tir::StringImmNode::make(name));
+        global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
+                                      runtime::String(name));
         global_region_func =
             WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
         global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
index f54ae6c..400f646 100644 (file)
@@ -76,7 +76,13 @@ TVM_REGISTER_GLOBAL("runtime.container._ADT")
   *rv = ADT(tag, fields);
 });
 
+TVM_REGISTER_GLOBAL("runtime.container._String")
+.set_body_typed([](std::string str) {
+  return String(std::move(str));
+});
+
 TVM_REGISTER_OBJECT_TYPE(ADTObj);
+TVM_REGISTER_OBJECT_TYPE(StringObj);
 TVM_REGISTER_OBJECT_TYPE(ClosureObj);
 
 }  // namespace runtime
index bda590f..724e81d 100644 (file)
@@ -80,7 +80,8 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
 def set_external_func_attr(func, compiler, ext_symbol):
     func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
     func = func.with_attr("Compiler", tvm.tir.StringImm(compiler))
-    func = func.with_attr("ExternalSymbol", tvm.tir.StringImm(ext_symbol))
+    func = func.with_attr("global_symbol",
+                          runtime.container.String(ext_symbol))
     return func
 
 
index 9d4d711..ab9f47e 100644 (file)
@@ -23,6 +23,7 @@ import tvm
 import tvm.relay.testing
 from tvm import relay
 from tvm import runtime
+from tvm.runtime import container
 from tvm.relay import transform
 from tvm.contrib import util
 from tvm.relay.op.annotation import compiler_begin, compiler_end
@@ -305,10 +306,8 @@ def test_extern_ccompiler_default_ops():
         func = relay.Function([x0, y0], add)
         func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
-        func = func.with_attr("Compiler",
-                                  tvm.tir.StringImm("ccompiler"))
-        func = func.with_attr("ExternalSymbol",
-                                  tvm.tir.StringImm("ccompiler_0"))
+        func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
+        func = func.with_attr("global_symbol", container.String("ccompiler_0"))
         glb_0 = relay.GlobalVar("ccompiler_0")
         mod[glb_0] = func
         add_call = relay.Call(glb_0, [x, y])
@@ -319,7 +318,7 @@ def test_extern_ccompiler_default_ops():
         concat = relay.concatenate([log, exp], axis=0)
         fused_func = relay.Function([p0], concat)
         fused_func = fused_func.with_attr("Primitive",
-                                              tvm.tir.IntImm("int32", 1))
+                                          tvm.tir.IntImm("int32", 1))
         fused_call = relay.Call(fused_func, [add_call])
         main = relay.Function([x, y], fused_call)
         mod["main"] = main
@@ -393,8 +392,7 @@ def test_extern_dnnl():
         func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl"))
-        func = func.with_attr("ExternalSymbol",
-                                  tvm.tir.StringImm("dnnl_0"))
+        func = func.with_attr("global_symbol", container.String("dnnl_0"))
         glb_var = relay.GlobalVar("dnnl_0")
         mod = tvm.IRModule()
         mod[glb_var] = func
@@ -520,8 +518,8 @@ def test_function_lifting():
         func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         func0 = func0.with_attr("Compiler",
                                     tvm.tir.StringImm("test_compiler"))
-        func0 = func0.with_attr("ExternalSymbol",
-                                    tvm.tir.StringImm("test_compiler_0"))
+        func0 = func0.with_attr("global_symbol",
+                                container.String("test_compiler_0"))
         gv0 = relay.GlobalVar("test_compiler_0")
         mod[gv0] = func0
 
@@ -539,8 +537,8 @@ def test_function_lifting():
         func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         func1 = func1.with_attr("Compiler",
                                     tvm.tir.StringImm("test_compiler"))
-        func1 = func1.with_attr("ExternalSymbol",
-                                    tvm.tir.StringImm("test_compiler_1"))
+        func1 = func1.with_attr("global_symbol",
+                                container.String("test_compiler_1"))
         gv1 = relay.GlobalVar("test_compiler_1")
         mod[gv1] = func1
 
@@ -613,8 +611,8 @@ def test_function_lifting_inline():
         func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         func0 = func0.with_attr("Compiler",
                                     tvm.tir.StringImm("test_compiler"))
-        func0 = func0.with_attr("ExternalSymbol",
-                                    tvm.tir.StringImm("test_compiler_0"))
+        func0 = func0.with_attr("global_symbol",
+                                container.String("test_compiler_0"))
 
         # main function
         data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
@@ -649,8 +647,7 @@ def test_constant_propagation():
         func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
         func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
-        func = func.with_attr("ExternalSymbol",
-                              tvm.tir.StringImm("ccompiler_0"))
+        func = func.with_attr("global_symbol", container.String("ccompiler_0"))
         glb_0 = relay.GlobalVar("ccompiler_0")
         mod[glb_0] = func
         add_call = relay.Call(glb_0, [y])
@@ -751,8 +748,8 @@ def test_multiple_outputs():
         func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         func0 = func0.with_attr("Compiler",
                                 tvm.tir.StringImm("test_target"))
-        func0 = func0.with_attr("ExternalSymbol",
-                                tvm.tir.StringImm("test_target_2"))
+        func0 = func0.with_attr("global_symbol",
+                                container.String("test_target_2"))
         gv0 = relay.GlobalVar("test_target_2")
         mod[gv0] = func0
 
@@ -819,8 +816,8 @@ def test_mixed_single_multiple_outputs():
         func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         func1 = func1.with_attr("Compiler",
                                 tvm.tir.StringImm("test_target"))
-        func1 = func1.with_attr("ExternalSymbol",
-                                tvm.tir.StringImm("test_target_1"))
+        func1 = func1.with_attr("global_symbol",
+                                container.String("test_target_1"))
         gv1 = relay.GlobalVar("test_target_1")
         mod[gv1] = func1
 
@@ -834,8 +831,8 @@ def test_mixed_single_multiple_outputs():
         func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
         func0 = func0.with_attr("Compiler",
                                 tvm.tir.StringImm("test_target"))
-        func0 = func0.with_attr("ExternalSymbol",
-                                tvm.tir.StringImm("test_target_0"))
+        func0 = func0.with_attr("global_symbol",
+                                container.String("test_target_0"))
         gv0 = relay.GlobalVar("test_target_0")
         mod[gv0] = func0