[IR][TRANSFORM] Enable CopyOnWrite for passes. (#5309)
authorTianqi Chen <tqchen@users.noreply.github.com>
Sun, 12 Apr 2020 00:42:42 +0000 (17:42 -0700)
committerGitHub <noreply@github.com>
Sun, 12 Apr 2020 00:42:42 +0000 (17:42 -0700)
This PR enables the copy on write optimizations passes:
- Enable COW for IRModule both TIR and relay passes.
- Enabled COW for PrimFunc in TIR passes.

Need more thoughts into whether/how to enable COW
for relay::Function, due to some function passes depend
on the presence of IRModule for context information,
and the std::move of the related function to nullptr
might affect the related behavior.

23 files changed:
include/tvm/ir/expr.h
include/tvm/ir/transform.h
include/tvm/runtime/data_type.h
include/tvm/runtime/object.h
include/tvm/runtime/packed_func.h
include/tvm/tir/transform.h
python/tvm/error.py
python/tvm/tir/transform/transform.py
src/ir/expr.cc
src/ir/module.cc
src/ir/transform.cc
src/node/container.cc
src/relay/ir/transform.cc
src/support/str_escape.h [new file with mode: 0644]
src/tir/ir/expr.cc
src/tir/ir/transform.cc
src/tir/transforms/narrow_datatype.cc
src/tir/transforms/split_host_device.cc
tests/cpp/packed_func_test.cc
tests/python/unittest/test_tir_transform_narrow_datatype.py
tests/python/unittest/test_tir_transform_prim_func_pass.py
topi/include/topi/util.h
topi/src/broadcast.cc

index 13a699a..4e0a301 100644 (file)
@@ -124,7 +124,7 @@ class PrimExpr : public BaseExpr {
  private:
   // Internal function for conversion.
   friend struct runtime::PackedFuncValueConverter<PrimExpr>;
-  TVM_DLL static PrimExpr FromObject_(ObjectPtr<Object> ptr);
+  TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
 };
 
 /*!
@@ -464,9 +464,8 @@ struct PackedFuncValueConverter<PrimExpr> {
     if (val.type_code() == kDLFloat) {
       return PrimExpr(static_cast<float>(val.operator double()));
     }
-    TVM_CHECK_TYPE_CODE(val.type_code(), kTVMObjectHandle);
-    Object* ptr = val.ptr<Object>();
-    return PrimExpr::FromObject_(GetObjectPtr<Object>(ptr));
+
+    return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
   }
 };
 }  // namespace runtime
index 3a9913f..8361902 100644 (file)
@@ -62,6 +62,7 @@
 #include <tvm/ir/error.h>
 #include <tvm/ir/module.h>
 #include <string>
+#include <utility>
 
 namespace tvm {
 namespace transform {
@@ -251,8 +252,8 @@ class PassNode : public Object {
    *
    * \return The transformed module.
    */
-  IRModule operator()(const IRModule& mod) const {
-    return this->operator()(mod, PassContext::Current());
+  IRModule operator()(IRModule mod) const {
+    return this->operator()(std::move(mod), PassContext::Current());
   }
 
   /*!
@@ -263,7 +264,7 @@ class PassNode : public Object {
    *
    * \return The transformed module.
    */
-  virtual IRModule operator()(const IRModule& mod,
+  virtual IRModule operator()(IRModule mod,
                               const PassContext& pass_ctx) const = 0;
 
   void VisitAttrs(AttrVisitor* v) {}
@@ -277,14 +278,22 @@ class Pass : public ObjectRef {
   /*!
    * \brief Transform mod using the default PassContext in the current scope.
    *
+   * \code
+   *
+   * // If you do no longer need the input module
+   * // it is recommended to use std::move to move your input module.
+   * mod = pass(std::move(mod));
+   *
+   * \endcode
+   *
    * \param mod The module that an optimization pass runs on.
    *
    * \return The transformed module.
    */
-  IRModule operator()(const IRModule& mod) const {
+  IRModule operator()(IRModule mod) const {
     const PassNode* node = operator->();
     CHECK(node != nullptr);
-    return node->operator()(mod);
+    return node->operator()(std::move(mod));
   }
   /*!
    * \brief Transform mod using a functor under a given pass context.
@@ -294,11 +303,11 @@ class Pass : public ObjectRef {
    *
    * \return The transformed module.
    */
-  IRModule operator()(const IRModule& mod,
+  IRModule operator()(IRModule mod,
                       const PassContext& pass_ctx) const {
     const PassNode* node = operator->();
     CHECK(node != nullptr);
-    return node->operator()(mod, pass_ctx);
+    return node->operator()(std::move(mod), pass_ctx);
   }
 
   TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
index 9e92db9..44385d6 100644 (file)
@@ -325,6 +325,7 @@ inline const char* TypeCode2Str(int type_code) {
     case kTVMModuleHandle: return "ModuleHandle";
     case kTVMNDArrayHandle: return "NDArrayContainer";
     case kTVMObjectHandle: return "Object";
+    case kTVMObjectRValueRefArg: return "ObjectRValueRefArg";
     default: LOG(FATAL) << "unknown type_code="
                         << static_cast<int>(type_code); return "";
   }
index 8005cf6..acbb939 100644 (file)
@@ -885,7 +885,7 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
 
 template <typename SubRef, typename BaseRef>
 inline SubRef Downcast(BaseRef ref) {
-  CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
+  CHECK(!ref.defined() || ref->template IsInstance<typename SubRef::ContainerType>())
       << "Downcast from " << ref->GetTypeKey() << " to "
       << SubRef::ContainerType::_type_key << " failed.";
   return SubRef(std::move(ref.data_));
index 2dcb4ff..c5f0df5 100644 (file)
@@ -1357,7 +1357,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
                 ptr->IsInstance<Module::ContainerType>())) {
       values_[i].v_handle = ptr;
       type_codes_[i] = kTVMModuleHandle;
-    } else if (std::is_rvalue_reference<T>::value) {
+    } else if (std::is_rvalue_reference<decltype(value)>::value) {
       values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
       type_codes_[i] = kTVMObjectRValueRefArg;
     } else {
index 5ad40a3..0c5b39b 100644 (file)
@@ -197,11 +197,12 @@ TVM_DLL Pass CombineContextCall();
 /*!
  * \brief Narrow down PrimExpr datatype in stmt to target_bits.
  *
- * \note Run this pass after StorageFlatten.
+ * \param target_bits The target bits
  *
+ * \note Run this pass after storage flatten.
  * \return The pass.
  */
-TVM_DLL Pass NarrowDataType();
+TVM_DLL Pass NarrowDataType(int target_bits);
 
 }  // namespace transform
 }  // namespace tir
index 02bc90b..4c3e606 100644 (file)
@@ -54,6 +54,7 @@ class InternalError(TVMError):
 register_error("ValueError", ValueError)
 register_error("TypeError", TypeError)
 register_error("AttributeError", AttributeError)
+register_error("KeyError", KeyError)
 
 
 @register_error
index 64c31a5..91321fb 100644 (file)
@@ -38,7 +38,7 @@ def Apply(ftransform):
     # pylint: disable=unused-argument
     def _transform(func, mod, ctx):
         return ftransform(func)
-    return _fpass.prim_func_pass(_transform, opt_level=0)
+    return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply")
 
 
 def Filter(fcond):
@@ -57,7 +57,7 @@ def Filter(fcond):
     # pylint: disable=unused-argument
     def _transform(func, mod, ctx):
         return func if fcond(func) else None
-    return _fpass.prim_func_pass(_transform, opt_level=0)
+    return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")
 
 
 def LowerCustomDatatypes():
@@ -221,9 +221,14 @@ def CombineContextCall():
     return _ffi_api.CombineContextCall()
 
 
-def NarrowDataType():
+def NarrowDataType(target_bits):
     """Narrow down PrimExpr datatype in stmt to target_bits.
 
+    Parameters
+    ----------
+    target_bits : int
+        The target bit configuration.
+
     Returns
     -------
     fpass : tvm.ir.transform.Pass
@@ -233,4 +238,4 @@ def NarrowDataType():
     ----
     Run this pass after StorageFlatten.
     """
-    return _ffi_api.NarrowDataType()
+    return _ffi_api.NarrowDataType(target_bits)
index 1f0337e..e08d832 100644 (file)
@@ -43,21 +43,21 @@ PrimExpr::PrimExpr(float value)
 PrimExpr::PrimExpr(runtime::String value)
     : PrimExpr(tir::StringImmNode::make(value)) {}
 
-PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
+PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
   using runtime::ObjectTypeChecker;
-  if (ptr->IsInstance<tir::IterVarNode>()) {
-    return tir::IterVar(ptr)->var;
+  if (auto* ptr = ref.as<tir::IterVarNode>()) {
+    return GetRef<tir::IterVar>(ptr)->var;
   }
-  if (ptr->IsInstance<te::TensorNode>()) {
-    return te::Tensor(ptr)();
+  if (auto* ptr = ref.as<te::TensorNode>()) {
+    return GetRef<te::Tensor>(ptr)();
   }
-  if (ptr->IsInstance<runtime::StringObj>()) {
-    return tir::StringImmNode::make(runtime::String(ptr));
+  if (auto* ptr = ref.as<runtime::StringObj>()) {
+    return tir::StringImmNode::make(GetRef<runtime::String>(ptr));
   }
-  CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
+  CHECK(ObjectTypeChecker<PrimExpr>::Check(ref.get()))
       << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
-      << " but get " << ptr->GetTypeKey();
-  return PrimExpr(ptr);
+      << " but get " << ref->GetTypeKey();
+  return Downcast<PrimExpr>(ref);
 }
 
 
index ea74f4c..bcf56aa 100644 (file)
@@ -121,8 +121,20 @@ bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const {
 
 GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const {
   auto it = global_var_map_.find(name);
-  CHECK(it != global_var_map_.end())
-    << "Cannot find global var " << name << " in the Module";
+  if (it == global_var_map_.end()) {
+    std::ostringstream msg;
+    msg << "ValueError: Cannot find global var \"" << name << "\" in the Module\n"
+        << "candidates are: [";
+    int counter = 0;
+    for (auto kv : global_var_map_) {
+      if (counter++ != 0) {
+        msg << ", ";
+      }
+      msg << "\"" << kv.first << "\"";
+    }
+    msg << "]";
+    LOG(FATAL) << msg.str();
+  }
   return (*it).second;
 }
 
index 49c0ef4..ef524c3 100644 (file)
@@ -126,7 +126,7 @@ class ModulePassNode : public PassNode {
    *
    * \return Return the updated module.
    */
-  IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
+  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
 
   /*!
    * \brief Get the pass information/meta data.
@@ -205,7 +205,7 @@ class SequentialNode : public PassNode {
    *
    * \return Return the updated module.
    */
-  IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
+  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
 
   static constexpr const char* _type_key = "transform.Sequential";
   TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
@@ -231,19 +231,20 @@ ModulePass::ModulePass(
 }
 
 // Module -> Module optimizations.
-IRModule ModulePassNode::operator()(const IRModule& mod,
+IRModule ModulePassNode::operator()(IRModule mod,
                                     const PassContext& pass_ctx) const {
   const PassInfo& pass_info = Info();
   DLOG(INFO) << "Executing module pass : "
              << pass_info->name
              << " with opt level: "
              << pass_info->opt_level;
+
   CHECK(mod.defined());
   pass_ctx.Trace(mod, pass_info, true);
-  IRModule updated_mod = pass_func(mod, pass_ctx);
-  CHECK(updated_mod.defined());
-  pass_ctx.Trace(updated_mod, pass_info, false);
-  return updated_mod;
+  mod = pass_func(std::move(mod), pass_ctx);
+  CHECK(mod.defined());
+  pass_ctx.Trace(mod, pass_info, false);
+  return mod;
 }
 
 Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
@@ -314,18 +315,17 @@ Pass GetPass(const std::string& pass_name) {
 // TODO(zhiics): we currenlty only sequentially execute each pass in
 // a Sequential without the consideration of their orders. The phase
 // ordering problem needs to be handled in the future.
-IRModule SequentialNode::operator()(const IRModule& module,
+IRModule SequentialNode::operator()(IRModule mod,
                                     const PassContext& pass_ctx) const {
-  IRModule mod = module;
   for (const Pass& pass : passes) {
     CHECK(pass.defined()) << "Found undefined pass for optimization.";
     const PassInfo& pass_info = pass->Info();
     if (!PassEnabled(pass_info))  continue;
     // resolve dependencies
     for (const auto& it : pass_info->required) {
-      mod = GetPass(it)(mod, pass_ctx);
+      mod = GetPass(it)(std::move(mod), pass_ctx);
     }
-    mod = pass(mod, pass_ctx);
+    mod = pass(std::move(mod), pass_ctx);
   }
   return mod;
 }
@@ -375,11 +375,8 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass")
 });
 
 TVM_REGISTER_GLOBAL("transform.RunPass")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  Pass pass = args[0];
-  IRModule mod = args[1];
-  ObjectRef ref = args[1];
-  *ret = pass(mod);
+.set_body_typed([](Pass pass, IRModule mod) {
+  return pass(std::move(mod));
 });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
index 0949da0..52e4bf1 100644 (file)
@@ -24,7 +24,7 @@
 #include <tvm/runtime/container.h>
 #include <tvm/node/container.h>
 #include <tvm/tir/expr.h>
-#include <cstring>
+#include "../support/str_escape.h"
 
 namespace tvm {
 
@@ -63,6 +63,13 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
       static_cast<const runtime::StringObj*>(n)).operator std::string();
 });
 
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<runtime::StringObj>([](const ObjectRef& node, ReprPrinter* p) {
+  auto* op = static_cast<const runtime::StringObj*>(node.get());
+  p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
+});
+
+
 struct ADTObjTrait {
   static constexpr const std::nullptr_t VisitAttrs = nullptr;
 
index fa709eb..a06eb5a 100644 (file)
@@ -68,7 +68,7 @@ class FunctionPassNode : public PassNode {
    *
    * \return Return the updated module.
    */
-  IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
+  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
 
   /*!
    * \brief Get the pass information/meta data.
@@ -113,7 +113,7 @@ FunctionPass::FunctionPass(
 }
 
 // Perform Module -> Module optimizations at the Function level.
-IRModule FunctionPassNode::operator()(const IRModule& mod,
+IRModule FunctionPassNode::operator()(IRModule mod,
                                       const PassContext& pass_ctx) const {
   const PassInfo& pass_info = Info();
   CHECK(mod.defined());
@@ -122,6 +122,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
              << " with opt level: "
              << pass_info->opt_level;
   pass_ctx.Trace(mod, pass_info, true);
+
   // Execute the pass function and return a new module.
   IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
   std::vector<std::pair<GlobalVar, Function> > updates;
diff --git a/src/support/str_escape.h b/src/support/str_escape.h
new file mode 100644 (file)
index 0000000..fd25c01
--- /dev/null
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file support/str_escape.h
+ * \brief Print escape sequence of a string.
+ */
+#ifndef TVM_SUPPORT_STR_ESCAPE_H_
+#define TVM_SUPPORT_STR_ESCAPE_H_
+
+#include <string>
+#include <sstream>
+
+namespace tvm {
+namespace support {
+
+/*!
+ * \brief Create a stream with escape.
+ * \param data The data
+ * \param size The size of the string.
+ * \return the Result string.
+ */
+inline std::string StrEscape(const char* data, size_t size) {
+  std::ostringstream stream;
+  for (size_t i = 0; i < size; ++i) {
+    unsigned char c = data[i];
+    if (c >= ' ' && c <= '~' && c != '\\' && c != '"') {
+      stream << c;
+    } else {
+      stream << '\\';
+      switch (c) {
+        case '"':
+          stream << '"';
+          break;
+        case '\\':
+          stream << '\\';
+          break;
+        case '\t':
+          stream << 't';
+          break;
+        case '\r':
+          stream << 'r';
+          break;
+        case '\n':
+          stream << 'n';
+          break;
+        default:
+          const char* hex_digits = "0123456789ABCDEF";
+          stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf];
+      }
+    }
+  }
+  return stream.str();
+}
+
+/*!
+ * \brief Create a stream with escape.
+ * \param data The data
+ * \param size The size of the string.
+ * \return the Result string.
+ */
+inline std::string StrEscape(const std::string& val) {
+  return StrEscape(val.data(), val.length());
+}
+
+}  // namespace support
+}  // namespace tvm
+#endif  // TVM_SUPPORT_STR_ESCAPE_H_
index 0efa33a..65d424e 100644 (file)
@@ -28,6 +28,7 @@
 #include <memory>
 #include <limits>
 #include "../pass/ir_util.h"
+#include "../../support/str_escape.h"
 
 namespace tvm {
 namespace tir {
@@ -425,38 +426,8 @@ TVM_REGISTER_NODE_TYPE(BufferLoadNode);
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const StringImmNode*>(node.get());
-    auto& stream = p->stream;
-    stream << '"';
-    for (size_t i = 0; i < op->value.size(); ++i) {
-      unsigned char c = op->value[i];
-      if (c >= ' ' && c <= '~' && c != '\\' && c != '"') {
-        stream << c;
-      } else {
-        stream << '\\';
-        switch (c) {
-          case '"':
-            stream << '"';
-            break;
-          case '\\':
-            stream << '\\';
-            break;
-          case '\t':
-            stream << 't';
-            break;
-          case '\r':
-            stream << 'r';
-            break;
-          case '\n':
-            stream << 'n';
-            break;
-          default:
-            const char* hex_digits = "0123456789ABCDEF";
-            stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf];
-        }
-      }
-    }
-    stream << '"';
-  });
+    p->stream << '\"' << support::StrEscape(op->value) << '\"';
+});
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<CastNode>([](const ObjectRef& node, ReprPrinter* p) {
index 001c7cf..dda9ff4 100644 (file)
@@ -55,7 +55,7 @@ class PrimFuncPassNode : public PassNode {
    *
    * \return Return the updated module.
    */
-  IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
+  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
 
   /*!
    * \brief Get the pass information/meta data.
@@ -90,34 +90,35 @@ PrimFuncPass::PrimFuncPass(
 }
 
 // Perform Module -> Module optimizations at the PrimFunc level.
-IRModule PrimFuncPassNode::operator()(const IRModule& mod,
+IRModule PrimFuncPassNode::operator()(IRModule mod,
                                       const PassContext& pass_ctx) const {
   const PassInfo& pass_info = Info();
   CHECK(mod.defined());
   pass_ctx.Trace(mod, pass_info, true);
-  // Execute the pass function and return a new module.
-  IRModule updated_mod = IRModule(
-      mod->functions, mod->type_definitions, mod->Imports());
-  std::vector<std::pair<GlobalVar, PrimFunc> > updates;
-  for (const auto& it : updated_mod->functions) {
-    // only picks up relay::PrimFunc
-    if (auto* n = it.second.as<PrimFuncNode>()) {
-      PrimFunc func = GetRef<PrimFunc>(n);
-      auto updated_func =
-          pass_func(func, updated_mod, pass_ctx);
-      updates.push_back({it.first, updated_func});
+  std::vector<ObjectRef> deleted_list;
+  IRModuleNode* mod_ptr = mod.CopyOnWrite();
+  auto* func_dict = mod_ptr->functions.CopyOnWrite();
+  // directly loop over the underlying dict
+  for (auto& kv : func_dict->data) {
+    // only picks up tir::PrimFunc
+    if (kv.second->IsInstance<PrimFuncNode>()) {
+      // move out the function so that it is the only copy.
+      PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
+      func = pass_func(std::move(func), mod, pass_ctx);
+      kv.second = std::move(func);
+
+      if (!kv.second.defined()) {
+        deleted_list.push_back(kv.first);
+      }
     }
   }
+
   // automatic removal of None
-  for (const auto& pair : updates) {
-    if (pair.second.defined()) {
-      updated_mod->Add(pair.first, pair.second, true);
-    } else {
-      updated_mod->Remove(pair.first);
-    }
+  for (const auto& gv : deleted_list) {
+    func_dict->data.erase(gv);
   }
-  pass_ctx.Trace(updated_mod, pass_info, false);
-  return updated_mod;
+  pass_ctx.Trace(mod, pass_info, false);
+  return mod;
 }
 
 Pass CreatePrimFuncPass(
index 00bc45a..1f9d976 100644 (file)
@@ -397,17 +397,14 @@ Stmt NarrowDataType(Stmt stmt, int target_bits) {
 
 namespace transform {
 
-Pass NarrowDataType() {
-  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+Pass NarrowDataType(int target_bits) {
+  auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
-    IntImm target_bits = f->GetAttr<IntImm>("target_bits");
-    CHECK(target_bits.defined())
-      << "NarrowDataType: Require the target_bits";
-    n->body = DataTypeRewriter(target_bits->value)(std::move(n->body));
+    n->body = DataTypeRewriter(target_bits)(std::move(n->body));
     return f;
   };
   return CreatePrimFuncPass(
-      pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
+      pass_func, 0, "tir.NarrowDataType", {});
 }
 
 TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType")
index 5149d28..792a061 100644 (file)
@@ -173,7 +173,7 @@ Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
 
 class HostDeviceSplitter : public StmtMutator {
  public:
-  explicit HostDeviceSplitter(IRModuleNode* device_mod,
+  explicit HostDeviceSplitter(IRModule* device_mod,
                               Target device_target,
                               std::string name_prefix)
       : device_mod_(device_mod),
@@ -240,7 +240,7 @@ class HostDeviceSplitter : public StmtMutator {
                            runtime::String(kernel_symbol));
     device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
     device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
-    device_mod_->Add(GlobalVar(kernel_symbol), device_func);
+    (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func);
 
     // generate calls to the device function
     Array<PrimExpr> call_args;
@@ -257,7 +257,7 @@ class HostDeviceSplitter : public StmtMutator {
   }
 
   // target ir module
-  IRModuleNode* device_mod_;
+  IRModule* device_mod_;
   // Device target
   Target device_target_;
   // function name hint
@@ -268,7 +268,7 @@ class HostDeviceSplitter : public StmtMutator {
 };
 
 
-PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
+PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) {
   auto target = func->GetAttr<Target>(tvm::attr::kTarget);
   CHECK(target.defined())
       << "SplitHostDevice: Require the target attribute";
@@ -287,26 +287,22 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
 }
 
 
-
 namespace transform {
 
 Pass SplitHostDevice() {
-  auto pass_func = [](IRModule m, PassContext ctx) {
-    IRModuleNode* mptr = m.CopyOnWrite();
-    std::vector<std::pair<GlobalVar, PrimFunc> > updates;
-
-    for (const auto& kv : mptr->functions) {
-      if (auto* n = kv.second.as<PrimFuncNode>()) {
-        PrimFunc func = GetRef<PrimFunc>(n);
-        auto updated_func = SplitHostDevice(std::move(func), mptr);
-        updates.push_back({kv.first, updated_func});
+  auto pass_func = [](IRModule mod, PassContext ctx) {
+    IRModuleNode* mod_ptr = mod.CopyOnWrite();
+    auto* func_dict = mod_ptr->functions.CopyOnWrite();
+    IRModule device_mod = IRModule::Empty();
+
+    for (auto& kv : func_dict->data) {
+      if (kv.second->IsInstance<PrimFuncNode>()) {
+        PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
+        kv.second = SplitHostDevice(std::move(func), &device_mod);
       }
     }
-
-    for (const auto& pair : updates) {
-      mptr->Add(pair.first, pair.second, true);
-    }
-    return m;
+    mod->Update(device_mod);
+    return mod;
   };
 
   return tvm::transform::CreateModulePass(
index 99b6ca2..787e0c4 100644 (file)
@@ -22,6 +22,7 @@
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/container.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/tir/transform.h>
 #include <tvm/tir/expr.h>
 
 TEST(PackedFunc, Basic) {
@@ -274,6 +275,16 @@ TEST(TypedPackedFunc, RValue) {
   using namespace tvm;
   using namespace tvm::runtime;
   {
+
+    auto inspect = [](TVMArgs args, TVMRetValue* rv) {
+      for (int i = 0; i < args.size(); ++i) {
+        CHECK_EQ(args[0].type_code(), kTVMObjectRValueRefArg);
+      }
+    };
+    PackedFunc  finspect(inspect);
+    finspect(tir::Var("x"));
+  }
+  {
     auto f = [](tir::Var x, bool move) {
       if (move) {
         CHECK(x.unique());
@@ -287,9 +298,9 @@ TEST(TypedPackedFunc, RValue) {
 
     tir::Var var("x");
     CHECK(var.unique());
-    f(var, false);
+    tf(var, false);
     // move the result to the function.
-    tir::Var ret = f(std::move(var), true);
+    tir::Var ret = tf(std::move(var), true);
     CHECK(!var.defined());
   }
 
@@ -307,10 +318,10 @@ TEST(TypedPackedFunc, RValue) {
 
     tir::Var var("x");
     CHECK(var.unique());
-    f(var, false);
-    f(std::move(var), true);
+    tf(var, false);
+    tf(std::move(var), true);
     // auto conversion.
-    f(1, true);
+    tf(1, true);
   }
 }
 
index 49df1c2..a5583d5 100644 (file)
@@ -20,9 +20,9 @@ from tvm.tir import const
 
 
 def lower_stmt(params, stmt, target_bits):
-    func = tvm.tir.PrimFunc(params, stmt).with_attr(
-        "target_bits", target_bits)
-    func = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))["main"]
+    func = tvm.tir.PrimFunc(params, stmt)
+    func = tvm.tir.transform.NarrowDataType(target_bits)(
+        tvm.IRModule.from_expr(func))["main"]
     stmt = func.body
     return stmt
 
index 1695cbc..f286bf0 100644 (file)
@@ -46,5 +46,25 @@ def test_prim_func_pass():
     assert tvm.ir.structural_equal(mod["main"].body, new_func.body)
 
 
+def test_cow_pass():
+    def fapply(f):
+        assert tvm.testing.object_use_count(f) == 1
+        return f
+
+    pidentity = tvm.tir.transform.Apply(fapply)
+    x = te.var('x')
+    func = tvm.tir.PrimFunc(
+        [x], tvm.tir.Evaluate(x)).with_attr("target_bits", 32)
+    func_hash = func.__hash__()
+    mod = tvm.IRModule({"main": func})
+    del func
+    # copy on write
+    mod_hash = mod.__hash__()
+    mod = tvm.ir.transform.Sequential(
+        [pidentity, tvm.tir.transform.NarrowDataType(32)])(mod._move())
+    assert mod_hash == mod.__hash__()
+    assert func_hash == mod["main"].__hash__()
+
 if __name__ == "__main__":
+    test_cow_pass()
     test_prim_func_pass()
index 95c5c55..133bc85 100644 (file)
@@ -42,12 +42,5 @@ inline Array<Integer> ArrayOrInt(TVMArgValue arg) {
     return arg;
   }
 }
-
-inline bool IsTensorType(TVMArgValue arg) {
-  return (arg.type_code() == kTVMObjectHandle &&
-          static_cast<Object*>(
-              arg.value().v_handle)->IsInstance<tvm::te::TensorNode>());
-}
-
 }  // namespace topi
 #endif  // TOPI_UTIL_H_
index 3ae3dae..0f0241f 100644 (file)
@@ -35,8 +35,8 @@ using namespace tvm::runtime;
 #define TOPI_REGISTER_BCAST_OP(OpName, Op)                              \
   TVM_REGISTER_GLOBAL(OpName)                                           \
   .set_body([](TVMArgs args, TVMRetValue *rv) {                         \
-      bool lhs_is_tensor = IsTensorType(args[0]);                       \
-      bool rhs_is_tensor = IsTensorType(args[1]);                       \
+      bool lhs_is_tensor = args[0].IsObjectRef<tvm::te::Tensor>();      \
+      bool rhs_is_tensor = args[1].IsObjectRef<tvm::te::Tensor>();      \
       if (lhs_is_tensor && rhs_is_tensor) {                             \
         *rv = Op(args[0].operator tvm::te::Tensor(),                    \
                  args[1].operator tvm::te::Tensor());                   \