From e4b80bda71bfad9874c6d418ac8127f675d5b1d0 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 11 Apr 2020 17:42:42 -0700 Subject: [PATCH] [IR][TRANSFORM] Enable CopyOnWrite for passes. (#5309) This PR enables the copy on write optimizations passes: - Enable COW for IRModule both TIR and relay passes. - Enabled COW for PrimFunc in TIR passes. Need more thoughts into whether/how to enable COW for relay::Function, due to some function passes depend on the presence of IRModule for context information, and the std::move of the related function to nullptr might affect the related behavior. --- include/tvm/ir/expr.h | 7 +- include/tvm/ir/transform.h | 23 ++++-- include/tvm/runtime/data_type.h | 1 + include/tvm/runtime/object.h | 2 +- include/tvm/runtime/packed_func.h | 2 +- include/tvm/tir/transform.h | 5 +- python/tvm/error.py | 1 + python/tvm/tir/transform/transform.py | 13 +++- src/ir/expr.cc | 20 ++--- src/ir/module.cc | 16 +++- src/ir/transform.cc | 29 ++++---- src/node/container.cc | 9 ++- src/relay/ir/transform.cc | 5 +- src/support/str_escape.h | 85 ++++++++++++++++++++++ src/tir/ir/expr.cc | 35 +-------- src/tir/ir/transform.cc | 43 +++++------ src/tir/transforms/narrow_datatype.cc | 11 +-- src/tir/transforms/split_host_device.cc | 34 ++++----- tests/cpp/packed_func_test.cc | 21 ++++-- .../unittest/test_tir_transform_narrow_datatype.py | 6 +- .../unittest/test_tir_transform_prim_func_pass.py | 20 +++++ topi/include/topi/util.h | 7 -- topi/src/broadcast.cc | 4 +- 23 files changed, 253 insertions(+), 146 deletions(-) create mode 100644 src/support/str_escape.h diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 13a699a..4e0a301 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -124,7 +124,7 @@ class PrimExpr : public BaseExpr { private: // Internal function for conversion. friend struct runtime::PackedFuncValueConverter; - TVM_DLL static PrimExpr FromObject_(ObjectPtr ptr); + TVM_DLL static PrimExpr FromObject_(ObjectRef ref); }; /*! @@ -464,9 +464,8 @@ struct PackedFuncValueConverter { if (val.type_code() == kDLFloat) { return PrimExpr(static_cast(val.operator double())); } - TVM_CHECK_TYPE_CODE(val.type_code(), kTVMObjectHandle); - Object* ptr = val.ptr(); - return PrimExpr::FromObject_(GetObjectPtr(ptr)); + + return PrimExpr::FromObject_(val.AsObjectRef()); } }; } // namespace runtime diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 3a9913f..8361902 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -62,6 +62,7 @@ #include #include #include +#include namespace tvm { namespace transform { @@ -251,8 +252,8 @@ class PassNode : public Object { * * \return The transformed module. */ - IRModule operator()(const IRModule& mod) const { - return this->operator()(mod, PassContext::Current()); + IRModule operator()(IRModule mod) const { + return this->operator()(std::move(mod), PassContext::Current()); } /*! @@ -263,7 +264,7 @@ class PassNode : public Object { * * \return The transformed module. */ - virtual IRModule operator()(const IRModule& mod, + virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0; void VisitAttrs(AttrVisitor* v) {} @@ -277,14 +278,22 @@ class Pass : public ObjectRef { /*! * \brief Transform mod using the default PassContext in the current scope. * + * \code + * + * // If you do no longer need the input module + * // it is recommended to use std::move to move your input module. + * mod = pass(std::move(mod)); + * + * \endcode + * * \param mod The module that an optimization pass runs on. * * \return The transformed module. */ - IRModule operator()(const IRModule& mod) const { + IRModule operator()(IRModule mod) const { const PassNode* node = operator->(); CHECK(node != nullptr); - return node->operator()(mod); + return node->operator()(std::move(mod)); } /*! * \brief Transform mod using a functor under a given pass context. @@ -294,11 +303,11 @@ class Pass : public ObjectRef { * * \return The transformed module. */ - IRModule operator()(const IRModule& mod, + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const { const PassNode* node = operator->(); CHECK(node != nullptr); - return node->operator()(mod, pass_ctx); + return node->operator()(std::move(mod), pass_ctx); } TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode); diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 9e92db9..44385d6 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -325,6 +325,7 @@ inline const char* TypeCode2Str(int type_code) { case kTVMModuleHandle: return "ModuleHandle"; case kTVMNDArrayHandle: return "NDArrayContainer"; case kTVMObjectHandle: return "Object"; + case kTVMObjectRValueRefArg: return "ObjectRValueRefArg"; default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); return ""; } diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 8005cf6..acbb939 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -885,7 +885,7 @@ inline ObjectPtr GetObjectPtr(ObjType* ptr) { template inline SubRef Downcast(BaseRef ref) { - CHECK(ref->template IsInstance()) + CHECK(!ref.defined() || ref->template IsInstance()) << "Downcast from " << ref->GetTypeKey() << " to " << SubRef::ContainerType::_type_key << " failed."; return SubRef(std::move(ref.data_)); diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 2dcb4ff..c5f0df5 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -1357,7 +1357,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { ptr->IsInstance())) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; - } else if (std::is_rvalue_reference::value) { + } else if (std::is_rvalue_reference::value) { values_[i].v_handle = const_cast(&(value.data_.data_)); type_codes_[i] = kTVMObjectRValueRefArg; } else { diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 5ad40a3..0c5b39b 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -197,11 +197,12 @@ TVM_DLL Pass CombineContextCall(); /*! * \brief Narrow down PrimExpr datatype in stmt to target_bits. * - * \note Run this pass after StorageFlatten. + * \param target_bits The target bits * + * \note Run this pass after storage flatten. * \return The pass. */ -TVM_DLL Pass NarrowDataType(); +TVM_DLL Pass NarrowDataType(int target_bits); } // namespace transform } // namespace tir diff --git a/python/tvm/error.py b/python/tvm/error.py index 02bc90b..4c3e606 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -54,6 +54,7 @@ class InternalError(TVMError): register_error("ValueError", ValueError) register_error("TypeError", TypeError) register_error("AttributeError", AttributeError) +register_error("KeyError", KeyError) @register_error diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 64c31a5..91321fb 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -38,7 +38,7 @@ def Apply(ftransform): # pylint: disable=unused-argument def _transform(func, mod, ctx): return ftransform(func) - return _fpass.prim_func_pass(_transform, opt_level=0) + return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") def Filter(fcond): @@ -57,7 +57,7 @@ def Filter(fcond): # pylint: disable=unused-argument def _transform(func, mod, ctx): return func if fcond(func) else None - return _fpass.prim_func_pass(_transform, opt_level=0) + return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") def LowerCustomDatatypes(): @@ -221,9 +221,14 @@ def CombineContextCall(): return _ffi_api.CombineContextCall() -def NarrowDataType(): +def NarrowDataType(target_bits): """Narrow down PrimExpr datatype in stmt to target_bits. + Parameters + ---------- + target_bits : int + The target bit configuration. + Returns ------- fpass : tvm.ir.transform.Pass @@ -233,4 +238,4 @@ def NarrowDataType(): ---- Run this pass after StorageFlatten. """ - return _ffi_api.NarrowDataType() + return _ffi_api.NarrowDataType(target_bits) diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 1f0337e..e08d832 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -43,21 +43,21 @@ PrimExpr::PrimExpr(float value) PrimExpr::PrimExpr(runtime::String value) : PrimExpr(tir::StringImmNode::make(value)) {} -PrimExpr PrimExpr::FromObject_(ObjectPtr ptr) { +PrimExpr PrimExpr::FromObject_(ObjectRef ref) { using runtime::ObjectTypeChecker; - if (ptr->IsInstance()) { - return tir::IterVar(ptr)->var; + if (auto* ptr = ref.as()) { + return GetRef(ptr)->var; } - if (ptr->IsInstance()) { - return te::Tensor(ptr)(); + if (auto* ptr = ref.as()) { + return GetRef(ptr)(); } - if (ptr->IsInstance()) { - return tir::StringImmNode::make(runtime::String(ptr)); + if (auto* ptr = ref.as()) { + return tir::StringImmNode::make(GetRef(ptr)); } - CHECK(ObjectTypeChecker::Check(ptr.get())) + CHECK(ObjectTypeChecker::Check(ref.get())) << "Expect type " << ObjectTypeChecker::TypeName() - << " but get " << ptr->GetTypeKey(); - return PrimExpr(ptr); + << " but get " << ref->GetTypeKey(); + return Downcast(ref); } diff --git a/src/ir/module.cc b/src/ir/module.cc index ea74f4c..bcf56aa 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -121,8 +121,20 @@ bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const { GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const { auto it = global_var_map_.find(name); - CHECK(it != global_var_map_.end()) - << "Cannot find global var " << name << " in the Module"; + if (it == global_var_map_.end()) { + std::ostringstream msg; + msg << "ValueError: Cannot find global var \"" << name << "\" in the Module\n" + << "candidates are: ["; + int counter = 0; + for (auto kv : global_var_map_) { + if (counter++ != 0) { + msg << ", "; + } + msg << "\"" << kv.first << "\""; + } + msg << "]"; + LOG(FATAL) << msg.str(); + } return (*it).second; } diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 49c0ef4..ef524c3 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -126,7 +126,7 @@ class ModulePassNode : public PassNode { * * \return Return the updated module. */ - IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; /*! * \brief Get the pass information/meta data. @@ -205,7 +205,7 @@ class SequentialNode : public PassNode { * * \return Return the updated module. */ - IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; static constexpr const char* _type_key = "transform.Sequential"; TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); @@ -231,19 +231,20 @@ ModulePass::ModulePass( } // Module -> Module optimizations. -IRModule ModulePassNode::operator()(const IRModule& mod, +IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); DLOG(INFO) << "Executing module pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; + CHECK(mod.defined()); pass_ctx.Trace(mod, pass_info, true); - IRModule updated_mod = pass_func(mod, pass_ctx); - CHECK(updated_mod.defined()); - pass_ctx.Trace(updated_mod, pass_info, false); - return updated_mod; + mod = pass_func(std::move(mod), pass_ctx); + CHECK(mod.defined()); + pass_ctx.Trace(mod, pass_info, false); + return mod; } Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { @@ -314,18 +315,17 @@ Pass GetPass(const std::string& pass_name) { // TODO(zhiics): we currenlty only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. -IRModule SequentialNode::operator()(const IRModule& module, +IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const { - IRModule mod = module; for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); if (!PassEnabled(pass_info)) continue; // resolve dependencies for (const auto& it : pass_info->required) { - mod = GetPass(it)(mod, pass_ctx); + mod = GetPass(it)(std::move(mod), pass_ctx); } - mod = pass(mod, pass_ctx); + mod = pass(std::move(mod), pass_ctx); } return mod; } @@ -375,11 +375,8 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass") }); TVM_REGISTER_GLOBAL("transform.RunPass") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Pass pass = args[0]; - IRModule mod = args[1]; - ObjectRef ref = args[1]; - *ret = pass(mod); +.set_body_typed([](Pass pass, IRModule mod) { + return pass(std::move(mod)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/node/container.cc b/src/node/container.cc index 0949da0..52e4bf1 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include "../support/str_escape.h" namespace tvm { @@ -63,6 +63,13 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait) static_cast(n)).operator std::string(); }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; +}); + + struct ADTObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index fa709eb..a06eb5a 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -68,7 +68,7 @@ class FunctionPassNode : public PassNode { * * \return Return the updated module. */ - IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; /*! * \brief Get the pass information/meta data. @@ -113,7 +113,7 @@ FunctionPass::FunctionPass( } // Perform Module -> Module optimizations at the Function level. -IRModule FunctionPassNode::operator()(const IRModule& mod, +IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); CHECK(mod.defined()); @@ -122,6 +122,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, << " with opt level: " << pass_info->opt_level; pass_ctx.Trace(mod, pass_info, true); + // Execute the pass function and return a new module. IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports()); std::vector > updates; diff --git a/src/support/str_escape.h b/src/support/str_escape.h new file mode 100644 index 0000000..fd25c01 --- /dev/null +++ b/src/support/str_escape.h @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file support/str_escape.h + * \brief Print escape sequence of a string. + */ +#ifndef TVM_SUPPORT_STR_ESCAPE_H_ +#define TVM_SUPPORT_STR_ESCAPE_H_ + +#include +#include + +namespace tvm { +namespace support { + +/*! + * \brief Create a stream with escape. + * \param data The data + * \param size The size of the string. + * \return the Result string. + */ +inline std::string StrEscape(const char* data, size_t size) { + std::ostringstream stream; + for (size_t i = 0; i < size; ++i) { + unsigned char c = data[i]; + if (c >= ' ' && c <= '~' && c != '\\' && c != '"') { + stream << c; + } else { + stream << '\\'; + switch (c) { + case '"': + stream << '"'; + break; + case '\\': + stream << '\\'; + break; + case '\t': + stream << 't'; + break; + case '\r': + stream << 'r'; + break; + case '\n': + stream << 'n'; + break; + default: + const char* hex_digits = "0123456789ABCDEF"; + stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf]; + } + } + } + return stream.str(); +} + +/*! + * \brief Create a stream with escape. + * \param data The data + * \param size The size of the string. + * \return the Result string. + */ +inline std::string StrEscape(const std::string& val) { + return StrEscape(val.data(), val.length()); +} + +} // namespace support +} // namespace tvm +#endif // TVM_SUPPORT_STR_ESCAPE_H_ diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 0efa33a..65d424e 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -28,6 +28,7 @@ #include #include #include "../pass/ir_util.h" +#include "../../support/str_escape.h" namespace tvm { namespace tir { @@ -425,38 +426,8 @@ TVM_REGISTER_NODE_TYPE(BufferLoadNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); - auto& stream = p->stream; - stream << '"'; - for (size_t i = 0; i < op->value.size(); ++i) { - unsigned char c = op->value[i]; - if (c >= ' ' && c <= '~' && c != '\\' && c != '"') { - stream << c; - } else { - stream << '\\'; - switch (c) { - case '"': - stream << '"'; - break; - case '\\': - stream << '\\'; - break; - case '\t': - stream << 't'; - break; - case '\r': - stream << 'r'; - break; - case '\n': - stream << 'n'; - break; - default: - const char* hex_digits = "0123456789ABCDEF"; - stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf]; - } - } - } - stream << '"'; - }); + p->stream << '\"' << support::StrEscape(op->value) << '\"'; +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 001c7cf..dda9ff4 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -55,7 +55,7 @@ class PrimFuncPassNode : public PassNode { * * \return Return the updated module. */ - IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final; /*! * \brief Get the pass information/meta data. @@ -90,34 +90,35 @@ PrimFuncPass::PrimFuncPass( } // Perform Module -> Module optimizations at the PrimFunc level. -IRModule PrimFuncPassNode::operator()(const IRModule& mod, +IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); CHECK(mod.defined()); pass_ctx.Trace(mod, pass_info, true); - // Execute the pass function and return a new module. - IRModule updated_mod = IRModule( - mod->functions, mod->type_definitions, mod->Imports()); - std::vector > updates; - for (const auto& it : updated_mod->functions) { - // only picks up relay::PrimFunc - if (auto* n = it.second.as()) { - PrimFunc func = GetRef(n); - auto updated_func = - pass_func(func, updated_mod, pass_ctx); - updates.push_back({it.first, updated_func}); + std::vector deleted_list; + IRModuleNode* mod_ptr = mod.CopyOnWrite(); + auto* func_dict = mod_ptr->functions.CopyOnWrite(); + // directly loop over the underlying dict + for (auto& kv : func_dict->data) { + // only picks up tir::PrimFunc + if (kv.second->IsInstance()) { + // move out the function so that it is the only copy. + PrimFunc func = Downcast(std::move(kv.second)); + func = pass_func(std::move(func), mod, pass_ctx); + kv.second = std::move(func); + + if (!kv.second.defined()) { + deleted_list.push_back(kv.first); + } } } + // automatic removal of None - for (const auto& pair : updates) { - if (pair.second.defined()) { - updated_mod->Add(pair.first, pair.second, true); - } else { - updated_mod->Remove(pair.first); - } + for (const auto& gv : deleted_list) { + func_dict->data.erase(gv); } - pass_ctx.Trace(updated_mod, pass_info, false); - return updated_mod; + pass_ctx.Trace(mod, pass_info, false); + return mod; } Pass CreatePrimFuncPass( diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 00bc45a..1f9d976 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -397,17 +397,14 @@ Stmt NarrowDataType(Stmt stmt, int target_bits) { namespace transform { -Pass NarrowDataType() { - auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { +Pass NarrowDataType(int target_bits) { + auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - IntImm target_bits = f->GetAttr("target_bits"); - CHECK(target_bits.defined()) - << "NarrowDataType: Require the target_bits"; - n->body = DataTypeRewriter(target_bits->value)(std::move(n->body)); + n->body = DataTypeRewriter(target_bits)(std::move(n->body)); return f; }; return CreatePrimFuncPass( - pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); + pass_func, 0, "tir.NarrowDataType", {}); } TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType") diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 5149d28..792a061 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -173,7 +173,7 @@ Array UndefinedVars(const Stmt& stmt, const Array& args) { class HostDeviceSplitter : public StmtMutator { public: - explicit HostDeviceSplitter(IRModuleNode* device_mod, + explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix) : device_mod_(device_mod), @@ -240,7 +240,7 @@ class HostDeviceSplitter : public StmtMutator { runtime::String(kernel_symbol)); device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); - device_mod_->Add(GlobalVar(kernel_symbol), device_func); + (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func); // generate calls to the device function Array call_args; @@ -257,7 +257,7 @@ class HostDeviceSplitter : public StmtMutator { } // target ir module - IRModuleNode* device_mod_; + IRModule* device_mod_; // Device target Target device_target_; // function name hint @@ -268,7 +268,7 @@ class HostDeviceSplitter : public StmtMutator { }; -PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) { +PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { auto target = func->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; @@ -287,26 +287,22 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) { } - namespace transform { Pass SplitHostDevice() { - auto pass_func = [](IRModule m, PassContext ctx) { - IRModuleNode* mptr = m.CopyOnWrite(); - std::vector > updates; - - for (const auto& kv : mptr->functions) { - if (auto* n = kv.second.as()) { - PrimFunc func = GetRef(n); - auto updated_func = SplitHostDevice(std::move(func), mptr); - updates.push_back({kv.first, updated_func}); + auto pass_func = [](IRModule mod, PassContext ctx) { + IRModuleNode* mod_ptr = mod.CopyOnWrite(); + auto* func_dict = mod_ptr->functions.CopyOnWrite(); + IRModule device_mod = IRModule::Empty(); + + for (auto& kv : func_dict->data) { + if (kv.second->IsInstance()) { + PrimFunc func = Downcast(std::move(kv.second)); + kv.second = SplitHostDevice(std::move(func), &device_mod); } } - - for (const auto& pair : updates) { - mptr->Add(pair.first, pair.second, true); - } - return m; + mod->Update(device_mod); + return mod; }; return tvm::transform::CreateModulePass( diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 99b6ca2..787e0c4 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include TEST(PackedFunc, Basic) { @@ -274,6 +275,16 @@ TEST(TypedPackedFunc, RValue) { using namespace tvm; using namespace tvm::runtime; { + + auto inspect = [](TVMArgs args, TVMRetValue* rv) { + for (int i = 0; i < args.size(); ++i) { + CHECK_EQ(args[0].type_code(), kTVMObjectRValueRefArg); + } + }; + PackedFunc finspect(inspect); + finspect(tir::Var("x")); + } + { auto f = [](tir::Var x, bool move) { if (move) { CHECK(x.unique()); @@ -287,9 +298,9 @@ TEST(TypedPackedFunc, RValue) { tir::Var var("x"); CHECK(var.unique()); - f(var, false); + tf(var, false); // move the result to the function. - tir::Var ret = f(std::move(var), true); + tir::Var ret = tf(std::move(var), true); CHECK(!var.defined()); } @@ -307,10 +318,10 @@ TEST(TypedPackedFunc, RValue) { tir::Var var("x"); CHECK(var.unique()); - f(var, false); - f(std::move(var), true); + tf(var, false); + tf(std::move(var), true); // auto conversion. - f(1, true); + tf(1, true); } } diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index 49df1c2..a5583d5 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -20,9 +20,9 @@ from tvm.tir import const def lower_stmt(params, stmt, target_bits): - func = tvm.tir.PrimFunc(params, stmt).with_attr( - "target_bits", target_bits) - func = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))["main"] + func = tvm.tir.PrimFunc(params, stmt) + func = tvm.tir.transform.NarrowDataType(target_bits)( + tvm.IRModule.from_expr(func))["main"] stmt = func.body return stmt diff --git a/tests/python/unittest/test_tir_transform_prim_func_pass.py b/tests/python/unittest/test_tir_transform_prim_func_pass.py index 1695cbc..f286bf0 100644 --- a/tests/python/unittest/test_tir_transform_prim_func_pass.py +++ b/tests/python/unittest/test_tir_transform_prim_func_pass.py @@ -46,5 +46,25 @@ def test_prim_func_pass(): assert tvm.ir.structural_equal(mod["main"].body, new_func.body) +def test_cow_pass(): + def fapply(f): + assert tvm.testing.object_use_count(f) == 1 + return f + + pidentity = tvm.tir.transform.Apply(fapply) + x = te.var('x') + func = tvm.tir.PrimFunc( + [x], tvm.tir.Evaluate(x)).with_attr("target_bits", 32) + func_hash = func.__hash__() + mod = tvm.IRModule({"main": func}) + del func + # copy on write + mod_hash = mod.__hash__() + mod = tvm.ir.transform.Sequential( + [pidentity, tvm.tir.transform.NarrowDataType(32)])(mod._move()) + assert mod_hash == mod.__hash__() + assert func_hash == mod["main"].__hash__() + if __name__ == "__main__": + test_cow_pass() test_prim_func_pass() diff --git a/topi/include/topi/util.h b/topi/include/topi/util.h index 95c5c55..133bc85 100644 --- a/topi/include/topi/util.h +++ b/topi/include/topi/util.h @@ -42,12 +42,5 @@ inline Array ArrayOrInt(TVMArgValue arg) { return arg; } } - -inline bool IsTensorType(TVMArgValue arg) { - return (arg.type_code() == kTVMObjectHandle && - static_cast( - arg.value().v_handle)->IsInstance()); -} - } // namespace topi #endif // TOPI_UTIL_H_ diff --git a/topi/src/broadcast.cc b/topi/src/broadcast.cc index 3ae3dae..0f0241f 100644 --- a/topi/src/broadcast.cc +++ b/topi/src/broadcast.cc @@ -35,8 +35,8 @@ using namespace tvm::runtime; #define TOPI_REGISTER_BCAST_OP(OpName, Op) \ TVM_REGISTER_GLOBAL(OpName) \ .set_body([](TVMArgs args, TVMRetValue *rv) { \ - bool lhs_is_tensor = IsTensorType(args[0]); \ - bool rhs_is_tensor = IsTensorType(args[1]); \ + bool lhs_is_tensor = args[0].IsObjectRef(); \ + bool rhs_is_tensor = args[1].IsObjectRef(); \ if (lhs_is_tensor && rhs_is_tensor) { \ *rv = Op(args[0].operator tvm::te::Tensor(), \ args[1].operator tvm::te::Tensor()); \ -- 2.7.4