* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)
+
/*!
- * \brief construct from string.
- * \param str The value to be constructed.
+ * \brief construct from runtime String.
+ * \param value The value to be constructed.
*/
- TVM_DLL PrimExpr(std::string str); // NOLINT(*)
+ TVM_DLL PrimExpr(runtime::String value); // NOLINT(*)
/*! \return the data type of this expression. */
DataType dtype() const {
#define TVM_IR_TRANSFORM_H_
#include <tvm/support/with.h>
+#include <tvm/runtime/container.h>
#include <tvm/node/container.h>
#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
int fallback_device{static_cast<int>(kDLCPU)};
/*! \brief The list of required passes. */
- Array<PrimExpr> required_pass;
+ Array<runtime::String> required_pass;
/*! \brief The list of disabled passes. */
- Array<PrimExpr> disabled_pass;
+ Array<runtime::String> disabled_pass;
TraceFunc trace_func;
std::string name;
/*! \brief The passes that are required to perform the current pass. */
- Array<PrimExpr> required;
+ Array<runtime::String> required;
PassInfoNode() = default;
*/
TVM_DLL PassInfo(int opt_level,
std::string name,
- Array<PrimExpr> required);
+ Array<runtime::String> required);
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const Array<PrimExpr>& required);
+ const Array<runtime::String>& required);
} // namespace transform
} // namespace tvm
namespace tvm {
+using runtime::String;
+using runtime::StringObj;
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
#define TVM_NODE_NODE_H_
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/container.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
+using runtime::String;
} // namespace tvm
#endif // TVM_NODE_NODE_H_
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
+#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir/transform.h>
#include <tvm/relay/expr.h>
Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::PrimExpr>& required);
+ const tvm::Array<runtime::String>& required);
/*! \brief Remove expressions that do not affect the program result.
*
*
* \return The pass.
*/
-TVM_DLL Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions);
+TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);
} // namespace transform
 * \note If the user passes a const reference, it will trigger a copy. If it's
 * an rvalue, it will be moved into other.
*/
- explicit String(std::string other);
+ String(std::string other); // NOLINT(*)
+
+ /*!
+ * \brief Construct a new String object
+ *
+ * \param other a char array.
+ */
+ String(const char* other) // NOLINT(*)
+ : String(std::string(other)) {}
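As a quick illustration of the note above, a minimal sketch (not part of the patch; payload is a throwaway local):

    #include <string>
    #include <utility>
    #include <tvm/runtime/container.h>

    void StringCtorSketch() {
      std::string payload = "hello";
      tvm::runtime::String a = payload;             // const lvalue: copies payload
      tvm::runtime::String b = std::move(payload);  // rvalue: moved, no copy
      tvm::runtime::String c = "world";             // new const char* constructor
    }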
/*!
* \brief Change the value the reference object points to.
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
- Array<PrimExpr> keys_array;
+ Array<runtime::String> keys_array;
/*! \brief Options for this target */
- Array<PrimExpr> options_array;
+ Array<runtime::String> options_array;
/*! \brief Collection of imported libs */
- Array<PrimExpr> libs_array;
+ Array<runtime::String> libs_array;
/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
- * \param only_enable List of StringImm.
+ * \param only_enable List of runtime::String.
 * If it is empty, preorder/postorder will be called for every IRNode.
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
TVM_DLL Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
- const Array<PrimExpr>& only_enable = {});
+ const Array<runtime::String>& only_enable = {});
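A minimal call sketch with the new signature, mirroring the usage further down in this diff; replace_for stands in for any user-supplied mutation callback:

    // Assumes <tvm/runtime/container.h> and the IRTransform declaration above.
    tvm::tir::Stmt RewriteForsOnly(tvm::tir::Stmt body,
                                   const tvm::runtime::PackedFunc& replace_for) {
      // Only nodes whose type key is "For" reach the callback; the literal
      // turns into a runtime::String through the new const char* constructor.
      return IRTransform(body, nullptr, replace_for, {"For"});
    }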
/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::PrimExpr>& required);
+ const tvm::Array<runtime::String>& required);
/*!
* \brief Transform the high-level PrimFunc to a low-level version
*
* \return The pass.
*/
-TVM_DLL Pass RemapThreadAxis(Map<PrimExpr, IterVar> axis_map);
+TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);
/*!
import numpy as np
from tvm import target as _target
+from tvm import runtime
from tvm.ir import container
from tvm.tir import expr
from tvm.te import tensor, placeholder
return x
if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value
+ if isinstance(x, runtime.container.String):
+ return str(x)
if x is None:
return None
raise RuntimeError('Unsupported type "%s" in argument. Consider using'
lowered_func = self._get_irmodule()
param_names = self._list_params_name()
params = {}
- for name in param_names:
- key = name.value
+ for key in param_names:
arr = self._get_param_by_name(key)
param = empty(arr.shape, dtype=arr.dtype, ctx=arr.ctx)
arr.copyto(param)
# under the License.
"""Runtime container structures."""
import tvm._ffi
-
+from tvm._ffi.base import string_types
from tvm.runtime import Object, ObjectTypes
+from tvm.runtime import _ffi_api
def getitem_helper(obj, elem_getter, length, idx):
"""Helper function to implement a pythonic getitem function.
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or " \
"tvm NDArray type, but received : {0}".format(type(f))
- self.__init_handle_by_constructor__(_ADT, tag, *fields)
+ self.__init_handle_by_constructor__(_ffi_api.ADT, tag,
+ *fields)
@property
def tag(self):
- return _GetADTTag(self)
+ return _ffi_api.GetADTTag(self)
def __getitem__(self, idx):
return getitem_helper(
- self, _GetADTFields, len(self), idx)
+ self, _ffi_api.GetADTFields, len(self), idx)
def __len__(self):
- return _GetADTSize(self)
+ return _ffi_api.GetADTSize(self)
def tuple_object(fields=None):
for f in fields:
assert isinstance(f, ObjectTypes), "Expect object or tvm " \
"NDArray type, but received : {0}".format(type(f))
- return _Tuple(*fields)
+ return _ffi_api.Tuple(*fields)
@tvm._ffi.register_object("runtime.String")
Parameters
----------
- string : Str
+ string : str
The string used to construct a runtime String object
Returns
The created object.
"""
def __init__(self, string):
- self.__init_handle_by_constructor__(_String, string)
+ self.__init_handle_by_constructor__(_ffi_api.String, string)
+
+ def __str__(self):
+ return _ffi_api.GetStdString(self)
+
+ def __len__(self):
+ return _ffi_api.GetStringSize(self)
+
+ def __hash__(self):
+ return _ffi_api.StringHash(self)
+
+ def __eq__(self, other):
+ if isinstance(other, string_types):
+ return self.__str__() == other
+
+ if not isinstance(other, String):
+ return False
+
+ return _ffi_api.CompareString(self, other) == 0
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __gt__(self, other):
+ return _ffi_api.CompareString(self, other) > 0
+
+ def __lt__(self, other):
+ return _ffi_api.CompareString(self, other) < 0
+
+ def __getitem__(self, key):
+ return self.__str__()[key]
+
+ def startswith(self, string):
+ """Check if the runtime string starts with a given string
+ Parameters
+ ----------
+ string : str
+ The provided string
+ Returns
+ -------
+ ret : boolean
+ Return true if the runtime string starts with the given string,
+ otherwise, false.
+ """
+ return self.__str__().startswith(string)
-tvm._ffi._init_api("tvm.runtime.container")
from numbers import Number, Integral
from tvm._ffi.base import string_types
-from . import _ffi_node_api
+from . import _ffi_node_api, _ffi_api
from .object import ObjectBase, _set_class_object_generic
from .ndarray import NDArrayBase
from .packed_func import PackedFuncBase, convert_to_tvm_func
if isinstance(value, Number):
return const(value)
if isinstance(value, string_types):
- return _ffi_node_api.String(value)
+ return _ffi_api.String(value)
if isinstance(value, (list, tuple)):
value = [convert_to_object(x) for x in value]
return _ffi_node_api.Array(*value)
@property
def keys(self):
if not self._keys:
- self._keys = [k.value for k in self.keys_array]
+ self._keys = [str(k) for k in self.keys_array]
return self._keys
@property
def options(self):
if not self._options:
- self._options = [o.value for o in self.options_array]
+ self._options = [str(o) for o in self.options_array]
return self._options
@property
def libs(self):
if not self._libs:
- self._libs = [l.value for l in self.libs_array]
+ self._libs = [str(l) for l in self.libs_array]
return self._libs
@property
def model(self):
for opt in self.options_array:
- if opt.value.startswith('-model='):
- return opt.value[7:]
+ if opt.startswith('-model='):
+ return opt[7:]
return 'unknown'
@property
for (auto var : vars) {
Array<Array<PrimExpr> > feature_row;
ItervarFeature &fea = touch_analyzer.itervar_map[var];
- feature_row.push_back(Array<PrimExpr>{std::string("_itervar_"), var});
+ feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_itervar_"), var});
- Array<PrimExpr> attr{std::string("_attr_"),
+ Array<PrimExpr> attr{tvm::tir::StringImmNode::make("_attr_"),
FloatImm(DataType::Float(32), trans(fea.length)),
IntImm(DataType::Int(32), fea.nest_level),
FloatImm(DataType::Float(32), trans(fea.topdown_product)),
feature_row.push_back(attr);
// arithmetic
- feature_row.push_back(Array<PrimExpr>{std::string("_arith_"),
+ feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImmNode::make("_arith_"),
FloatImm(DataType::Float(32), trans(fea.add_ct)),
FloatImm(DataType::Float(32), trans(fea.mul_ct)),
FloatImm(DataType::Float(32), trans(fea.div_ct)),
for (auto k : bufs) {
TouchPattern &v = fea.touch_feature[k];
feature_row.push_back(
- Array<PrimExpr>{k,
+ Array<PrimExpr>{tvm::tir::StringImmNode::make(k),
FloatImm(DataType::Float(32), trans(v.stride)),
FloatImm(DataType::Float(32), trans(v.mod)),
FloatImm(DataType::Float(32), trans(v.count)),
if (val.IsObjectRef<ObjectRef>()) {
dict.Set(key, val.operator ObjectRef());
} else if (val.type_code() == kTVMStr) {
- dict.Set(key, PrimExpr(val.operator std::string()));
+ dict.Set(key, val.operator String());
} else {
dict.Set(key, val.operator PrimExpr());
}
PrimExpr::PrimExpr(float value)
: PrimExpr(FloatImm(DataType::Float(32), value)) {}
-PrimExpr::PrimExpr(std::string str)
- : PrimExpr(tir::StringImmNode::make(str)) {}
+PrimExpr::PrimExpr(runtime::String value)
+ : PrimExpr(tir::StringImmNode::make(value)) {}
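A sketch of what the new overload yields (assuming only the headers touched in this diff):

    #include <tvm/ir/expr.h>
    #include <tvm/tir/expr.h>
    #include <tvm/runtime/container.h>

    void PrimExprFromStringSketch() {
      // Routed through tir::StringImmNode::make above, so the result is a
      // tir::StringImm node.
      tvm::PrimExpr e = tvm::runtime::String("some_key");
      CHECK(e.as<tvm::tir::StringImmNode>() != nullptr);
    }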
PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
using runtime::ObjectTypeChecker;
if (ptr->IsInstance<te::TensorNode>()) {
return te::Tensor(ptr)();
}
+ if (ptr->IsInstance<runtime::StringObj>()) {
+ return tir::StringImmNode::make(runtime::String(ptr));
+ }
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
#include <tvm/ir/op.h>
#include <tvm/ir/type.h>
#include <tvm/runtime/module.h>
+#include <tvm/runtime/container.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
// Frontend APIs
TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
.set_body_typed([]() {
- Array<tvm::PrimExpr> ret;
- for (const std::string& name :
- dmlc::Registry<OpRegistry>::ListAllNames()) {
- ret.push_back(tvm::PrimExpr(name));
+ Array<runtime::String> ret;
+ for (const std::string& name : dmlc::Registry<OpRegistry>::ListAllNames()) {
+ ret.push_back(name);
}
return ret;
});
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
#include <tvm/runtime/device_api.h>
#include <tvm/node/repr_printer.h>
#include <tvm/ir/transform.h>
PassInfo::PassInfo(int opt_level,
std::string name,
- tvm::Array<tvm::PrimExpr> required) {
+ tvm::Array<runtime::String> required) {
auto pass_info = make_object<PassInfoNode>();
pass_info->opt_level = opt_level;
pass_info->name = std::move(name);
}
// linearly scan the pass array to match pass_name
-inline bool PassArrayContains(const Array<tvm::PrimExpr>& pass_array,
+inline bool PassArrayContains(const Array<runtime::String>& pass_array,
const std::string& pass_name) {
for (auto x : pass_array) {
- auto* str_name = x.as<tir::StringImmNode>();
- CHECK(str_name) << "pass name must be str";
- if (str_name->value == pass_name) return true;
+ if (x == pass_name) return true;
}
return false;
}
if (!PassEnabled(pass_info)) continue;
// resolve dependencies
for (const auto& it : pass_info->required) {
- const auto* name = it.as<tvm::tir::StringImmNode>();
- CHECK(name);
- mod = GetPass(name->value)(mod, pass_ctx);
+ mod = GetPass(it)(mod, pass_ctx);
}
mod = pass(mod, pass_ctx);
}
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::PrimExpr>& required) {
+ const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return ModulePass(pass_func, pass_info);
}
TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_GLOBAL("transform.PassInfo")
-.set_body_typed([](int opt_level, std::string name, tvm::Array<PrimExpr> required) {
+.set_body_typed([](int opt_level, std::string name, tvm::Array<runtime::String> required) {
return PassInfo(opt_level, name, required);
});
p->stream << "opt_level: " << node->opt_level;
p->stream << "required passes: [" << "\n";
for (const auto& it : node->required) {
- const auto* str = it.as<tvm::tir::StringImmNode>();
- p->stream << str->value << ", ";
+ p->stream << it << ", ";
}
p->stream << "]\n";
});
tvm::Array<Pass> passes = args[0];
int opt_level = args[1];
std::string name = args[2];
- tvm::Array<tvm::PrimExpr> required = args[3];
+ tvm::Array<runtime::String> required = args[3];
PassInfo pass_info = PassInfo(opt_level, name, required);
*ret = Sequential(passes, pass_info);
});
auto pctx = PassContext::Create();
int opt_level = args[0];
int fallback_device = args[1];
- tvm::Array<tvm::PrimExpr> required = args[2];
- tvm::Array<tvm::PrimExpr> disabled = args[3];
+ tvm::Array<runtime::String> required = args[2];
+ tvm::Array<runtime::String> disabled = args[3];
TraceFunc trace_func = args[4];
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;
static_cast<const runtime::StringObj*>(n)).operator std::string();
});
-
struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
std::unordered_map<std::string, tvm::runtime::NDArray> ret;
- auto names = CallFunc<Array<tvm::PrimExpr>>("list_params_name", nullptr);
- for (auto expr : names) {
- auto key = expr.as<tir::StringImmNode>()->value;
+ auto names = CallFunc<Array<runtime::String>>("list_params_name", nullptr);
+ for (const auto& name : names) {
+ // Implicit conversion from runtime::String to std::string
+ std::string key = name;
ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
}
return ret;
/*!
 * \brief List all parameter names
*
- * \return Array<StringImm> names of params
+ * \return Array<runtime::String> names of params
*/
- Array<tvm::PrimExpr> ListParamNames() {
- Array<tvm::PrimExpr> ret;
+ Array<runtime::String> ListParamNames() {
+ Array<runtime::String> ret;
for (const auto& kv : params_) {
- ret.push_back(tir::StringImmNode::make(kv.first));
+ ret.push_back(kv.first);
}
return ret;
}
}
Array<Pass> pass_seqs;
- Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
+ Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
// Run all dialect legalization passes.
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
- if (src_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
- auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
+ if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+ auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
CHECK(code_gen.defined()) << "No external codegen is set";
- if (ext_mods.find(code_gen->value) == ext_mods.end()) {
- ext_mods[code_gen->value] = IRModule({}, {});
+ std::string code_gen_name = code_gen;
+ if (ext_mods.find(code_gen_name) == ext_mods.end()) {
+ ext_mods[code_gen_name] = IRModule({}, {});
}
- auto symbol_name = src_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false);
auto gv = GlobalVar(std::string(symbol_name));
- ext_mods[code_gen->value]->Add(gv, src_func);
+ ext_mods[code_gen_name]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
}
}
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
- if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+ if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>();
const auto name_node =
- key->source_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined())
<< "External function has not been attached a name yet.";
cache_node->func_name = std::string(name_node);
*/
std::string GetExtSymbol(const Function& func) const {
const auto name_node =
- func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
return std::string(name_node);
}
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;
// Handle external function
- if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
return {};
}
std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
- CHECK(op->GetAttr<tir::StringImm>(attr::kCompiler).defined())
+ CHECK(op->GetAttr<String>(attr::kCompiler).defined())
<< "Only functions supported by custom codegen";
return {};
}
});
} else if (name == "list_params_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- Array<tvm::PrimExpr> ret;
+ Array<runtime::String> ret;
for (const auto &kv : this->output_.params) {
- tvm::PrimExpr name = tir::StringImmNode::make(kv.first);
- ret.push_back(name);
+ ret.push_back(kv.first);
}
*rv = ret;
});
Target target;
- if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
} else {
// Next generate the invoke instruction.
auto cfunc = engine_->Lower(key);
auto op_index = -1;
- if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
} else {
IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
Array<Pass> pass_seqs;
- Array<tvm::PrimExpr> entry_functions{tvm::PrimExpr{"main"}};
+ Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());
auto global = pair.first;
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
- if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
+ if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
DLOG(INFO) << "Before inlining primitives: " << global
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
if (auto* n = pair.second.as<FunctionNode>()) {
- if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
+ if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
func = Function(func->params,
VisitExpr(func->body),
* \return The module with dead functions removed.
*/
IRModule RemoveUnusedFunctions(const IRModule& module,
- Array<tvm::PrimExpr> entry_funcs) {
+ Array<runtime::String> entry_funcs) {
std::unordered_set<std::string> called_funcs{};
for (auto entry : entry_funcs) {
- auto* str_name = entry.as<tir::StringImmNode>();
- auto funcs = CallTracer(module).Trace(str_name->value);
+ auto funcs = CallTracer(module).Trace(entry);
called_funcs.insert(funcs.cbegin(), funcs.cend());
}
auto existing_functions = module->functions;
namespace transform {
-Pass RemoveUnusedFunctions(Array<tvm::PrimExpr> entry_functions) {
+Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) {
return relay::vm::RemoveUnusedFunctions(m, entry_functions);
bool FunctionPassNode::SkipFunction(const Function& func) const {
return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
- (func->GetAttr<tir::StringImm>(attr::kCompiler).defined());
+ (func->GetAttr<String>(attr::kCompiler).defined());
}
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::PrimExpr>& required) {
+ const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return FunctionPass(pass_func, pass_info);
}
te::Tensor start = inputs[0];
te::Tensor stop = inputs[1];
te::Tensor step = inputs[2];
- Array<tvm::PrimExpr> empty = {0};
return { DynamicArange(start, stop, step, param->dtype) };
}
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
};
- return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout")
// handle composite functions
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
- auto comp_name = func->GetAttr<tir::StringImm>(attr::kComposite);
+ auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined()) {
- size_t i = comp_name->value.find('.');
+ std::string comp_name_str = comp_name;
+ size_t i = comp_name_str.find('.');
if (i != std::string::npos) {
- std::string target = comp_name->value.substr(0, i);
+ std::string target = comp_name_str.substr(0, i);
if (target == target_) return true;
}
}
Function func;
Expr new_body;
// don't step into composite functions
- if (fn->GetAttr<tir::StringImm>(attr::kComposite).defined()) {
+ if (fn->GetAttr<String>(attr::kComposite).defined()) {
func = GetRef<Function>(fn);
new_body = func->body;
} else {
return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
- {tir::StringImmNode::make("InferType")});
+ {"InferType"});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
}
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeCast(f));
};
- return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast")
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeOps(f));
};
- return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps")
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
};
- return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D")
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelDense(f, min_num_branches));
};
- return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense")
batch_op_name,
min_num_branches));
};
- return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch")
return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout));
};
return CreateFunctionPass(
- pass_func, 3, "ConvertLayout",
- {tir::StringImmNode::make("InferType"),
- tir::StringImmNode::make("CanonicalizeOps")});
+ pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"});
}
TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::RewriteAnnotatedOps(f, fallback_device));
};
- return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation")
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
};
- return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FastMath(f));
};
- return CreateFunctionPass(pass_func, 4, "FastMath",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.FastMath")
return Downcast<Function>(
relay::fold_scale_axis::ForwardFoldScaleAxis(f));
};
- return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
return Downcast<Function>(
relay::fold_scale_axis::BackwardFoldScaleAxis(f));
};
- return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis")
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
return Downcast<Function>(FuseOps(f, opt_level, m));
};
- return CreateFunctionPass(pass_func, 1, "FuseOps",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.FuseOps")
fn->attrs);
// Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed.
- if (!func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
+ if (!func->GetAttr<String>(attr::kCompiler).defined()) {
CHECK_EQ(func->params.size(), args.size())
<< "Mismatch found in the number of parameters and call args";
// Bind the parameters with call args.
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
};
- return CreateFunctionPass(pass_func, 1, "Legalize", {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 1, "Legalize", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize);
if (call->op->IsInstance<FunctionNode>()) {
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
- const auto name_node = func->GetAttr<tir::StringImm>(attr::kComposite);
+ auto name_node = func->GetAttr<String>(attr::kComposite);
// don't step into existing composite functions
- if (name_node.defined() && name_node->value != "") {
+ if (name_node.defined() && name_node != "") {
tvm::Array<tvm::relay::Expr> new_args;
for (const auto& arg : call->args) {
auto new_e = this->Mutate(arg);
auto free_vars = FreeVars(extract);
// make the composite function
auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
- f = WithAttr(std::move(f), attr::kComposite, tir::StringImmNode::make(pattern_name_));
+ f = WithAttr(std::move(f), attr::kComposite, runtime::String(pattern_name_));
// find the expressions associated with the free vars using the args_map
// this tells us which expressions should be given as inputs to the composite function
Array<Expr> args;
PackedFunc check_;
};
-Expr MergeComposite(const Expr& expr, const Array<tir::StringImm>& pattern_names,
+Expr MergeComposite(const Expr& expr, const Array<runtime::String>& pattern_names,
const Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
CHECK_EQ(pattern_names.size(), patterns.size());
Expr merged_expr = expr;
// merge the patterns one-by-one in order
for (size_t i = 0; i < patterns.size(); i++) {
- std::string pattern_name = pattern_names[i]->value;
- Expr pattern = patterns[i];
- PackedFunc check = checks[i];
- merged_expr = MergeCompositeWrapper(pattern_name, pattern, check).Mutate(merged_expr);
+ merged_expr =
+ MergeCompositeWrapper(pattern_names[i], patterns[i], checks[i]).Mutate(merged_expr);
}
return merged_expr;
}
namespace transform {
-Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
+Pass MergeComposite(const tvm::Array<runtime::String>& pattern_names,
const tvm::Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return func_pass;
}
-TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) {
- tvm::Array<tir::StringImm> pattern_names = args[0];
+TVM_REGISTER_GLOBAL("relay._transform.MergeComposite")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+ tvm::Array<runtime::String> pattern_names = args[0];
tvm::Array<Expr> patterns = args[1];
std::vector<PackedFunc> checks;
for (int i = 2; i < args.size(); i++) {
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
- tvm::tir::StringImmNode::make(target));
+ tvm::runtime::String(target));
global_region_func =
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(SimplifyInference(f));
};
- return CreateFunctionPass(pass_func, 0, "SimplifyInference",
- {tir::StringImmNode::make("InferType")});
+ return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"});
}
TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference")
for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0);
if (const auto* n = it.second.as<FunctionNode>()) {
- if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
+ if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
}
Expr ret =
TransformF([&](const Expr& e) {
using namespace vm;
-TVM_REGISTER_GLOBAL("runtime.container._GetADTTag")
+TVM_REGISTER_GLOBAL("runtime.GetADTTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
*rv = static_cast<int64_t>(adt.tag());
});
-TVM_REGISTER_GLOBAL("runtime.container._GetADTSize")
+TVM_REGISTER_GLOBAL("runtime.GetADTSize")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
});
-TVM_REGISTER_GLOBAL("runtime.container._GetADTFields")
+TVM_REGISTER_GLOBAL("runtime.GetADTFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
int idx = args[1];
*rv = adt[idx];
});
-TVM_REGISTER_GLOBAL("runtime.container._Tuple")
+TVM_REGISTER_GLOBAL("runtime.Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) {
*rv = ADT::Tuple(fields);
});
-TVM_REGISTER_GLOBAL("runtime.container._ADT")
+TVM_REGISTER_GLOBAL("runtime.ADT")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
*rv = ADT(tag, fields);
});
-TVM_REGISTER_GLOBAL("runtime.container._String")
+TVM_REGISTER_GLOBAL("runtime.String")
.set_body_typed([](std::string str) {
return String(std::move(str));
});
+TVM_REGISTER_GLOBAL("runtime.GetStringSize")
+.set_body_typed([](String str) {
+ return static_cast<int64_t>(str.size());
+});
+
+TVM_REGISTER_GLOBAL("runtime.GetStdString")
+.set_body_typed([](String str) {
+ return std::string(str);
+});
+
+TVM_REGISTER_GLOBAL("runtime.CompareString")
+.set_body_typed([](String lhs, String rhs) {
+ return lhs.compare(rhs);
+});
+
+TVM_REGISTER_GLOBAL("runtime.StringHash")
+.set_body_typed([](String str) {
+ return static_cast<int64_t>(std::hash<String>()(str));
+});
+
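These globals back the Python String methods added above; as a sanity sketch, the same entry points are reachable from C++ through the registry:

    #include <tvm/runtime/registry.h>
    #include <tvm/runtime/container.h>

    void StringGlobalsSketch() {
      using tvm::runtime::Registry;
      using tvm::runtime::String;
      const tvm::runtime::PackedFunc* size_f = Registry::Get("runtime.GetStringSize");
      CHECK(size_f != nullptr);
      int64_t n = (*size_f)(String("hello"));  // matches str.size() == 5
      CHECK_EQ(n, 5);
    }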
TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(StringObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
info.thread_axis_tags.push_back(thread_axis[i]->thread_tag);
}
}
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol)] = info;
}
return fmap;
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
#include <tvm/node/node.h>
#include <tvm/node/repr_printer.h>
#include <tvm/target/target.h>
GenericFunc generic_func = args[0];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
- Array<PrimExpr> tags = args[2];
+ Array<runtime::String> tags = args[2];
bool allow_override = args[3];
std::vector<std::string> tags_vector;
for (auto& tag : tags) {
- tags_vector.push_back(tag.as<tvm::tir::StringImmNode>()->value);
+ tags_vector.push_back(tag);
}
generic_func
void CodeGenCPU::AddFunction(const PrimFunc& f) {
CodeGenLLVM::AddFunction(f);
if (f_tvm_register_system_symbol_ != nullptr) {
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
export_system_symbols_.emplace_back(
llvm::FunctionType* ftype = llvm::FunctionType::get(
ret_void ? t_void_ : t_int_, param_types, false);
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
CHECK(module_->getFunction(static_cast<std::string>(global_symbol)) == nullptr)
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined());
entry_func = global_symbol;
}
// reserve keywords
ReserveKeywordsAsUnique();
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
GetUniqueName("_");
// add to alloc buffer type.
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
arg_kinds.push_back(kind);
}
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute";
std::string whole_code = cg.Finish();
// Generate source code for compilation.
- Array<Array<PrimExpr> > kernel_info;
+ Array<Array<runtime::String> > kernel_info;
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
code = (*f)(code).operator std::string();
}
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
- std::string func_name = global_symbol;
- kernel_info.push_back(Array<PrimExpr>({func_name, code}));
+ kernel_info.push_back({global_symbol, code});
}
std::string xclbin;
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
builder_->MakeInst(spv::OpReturn);
builder_->MakeInst(spv::OpFunctionEnd);
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenStackVM: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
- auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute";
std::string f_name = global_symbol;
std::string device_flag = "-device=";
std::string keys_flag = "-keys=";
for (auto& item : options) {
- t->options_array.push_back(tir::StringImmNode::make(item));
+ t->options_array.push_back(item);
if (item.find(libs_flag) == 0) {
std::stringstream ss(item.substr(libs_flag.length()));
std::string lib_item;
while (std::getline(ss, lib_item, ',')) {
- t->libs_array.push_back(tir::StringImmNode::make(lib_item));
+ t->libs_array.push_back(lib_item);
}
} else if (item.find(device_flag) == 0) {
t->device_name = item.substr(device_flag.length());
- t->keys_array.push_back(tir::StringImmNode::make(t->device_name));
+ t->keys_array.push_back(t->device_name);
} else if (item.find(keys_flag) == 0) {
std::stringstream ss(item.substr(keys_flag.length()));
std::string key_item;
while (std::getline(ss, key_item, ',')) {
- t->keys_array.push_back(tir::StringImmNode::make(key_item));
+ t->keys_array.push_back(key_item);
}
}
}
if (t->device_name.length() > 0) {
- t->keys_array.push_back(tir::StringImmNode::make(t->device_name));
+ t->keys_array.push_back(t->device_name);
}
t->device_type = kDLCPU;
t->thread_warp_size = 1;
if (target_name == "c" && t->device_name == "micro_dev") {
t->device_type = kDLMicroDev;
} else if (target_name == "c" || target_name == "llvm") {
- t->keys_array.push_back(tir::StringImmNode::make("cpu"));
+ t->keys_array.push_back("cpu");
} else if (target_name == "cuda" || target_name == "nvptx") {
t->device_type = kDLGPU;
- t->keys_array.push_back(tir::StringImmNode::make("cuda"));
- t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+ t->keys_array.push_back("cuda");
+ t->keys_array.push_back("gpu");
t->max_num_threads = 1024;
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
} else {
t->device_type = kDLROCM;
}
- t->keys_array.push_back(tir::StringImmNode::make(target_name));
- t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+ t->keys_array.push_back(target_name);
+ t->keys_array.push_back("gpu");
t->max_num_threads = 256;
if (t->device_name == "intel_graphics") {
t->thread_warp_size = 16;
} else {
t->device_type = kDLVulkan;
}
- t->keys_array.push_back(tir::StringImmNode::make(target_name));
- t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+ t->keys_array.push_back(target_name);
+ t->keys_array.push_back("gpu");
t->max_num_threads = 256;
} else if (target_name == "sdaccel") {
t->device_type = kDLOpenCL;
- t->keys_array.push_back(tir::StringImmNode::make("sdaccel"));
- t->keys_array.push_back(tir::StringImmNode::make("hls"));
+ t->keys_array.push_back("sdaccel");
+ t->keys_array.push_back("hls");
} else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
t->device_type = kDLAOCL;
- t->keys_array.push_back(tir::StringImmNode::make("aocl"));
- t->keys_array.push_back(tir::StringImmNode::make("hls"));
+ t->keys_array.push_back("aocl");
+ t->keys_array.push_back("hls");
} else if (target_name == "opengl") {
t->device_type = kOpenGL;
- t->keys_array.push_back(tir::StringImmNode::make("opengl"));
+ t->keys_array.push_back("opengl");
} else if (target_name == "stackvm") {
t->device_type = kDLCPU;
} else if (target_name == "ext_dev") {
std::vector<std::string> TargetNode::keys() const {
std::vector<std::string> result;
for (auto& expr : keys_array) {
- result.push_back(expr.as<tir::StringImmNode>()->value);
+ result.push_back(expr);
}
return result;
}
std::vector<std::string> TargetNode::options() const {
std::vector<std::string> result;
for (auto& expr : options_array) {
- result.push_back(expr.as<tir::StringImmNode>()->value);
+ result.push_back(expr);
}
return result;
}
std::unordered_set<std::string> TargetNode::libs() const {
std::unordered_set<std::string> result;
for (auto& expr : libs_array) {
- result.insert(expr.as<tir::StringImmNode>()->value);
+ result.insert(expr);
}
return result;
}
data_ = std::move(n);
}
-
Var Var::copy_with_suffix(const std::string& suffix) const {
const VarNode* node = get();
ObjectPtr<VarNode> new_ptr;
}
});
-
-
TVM_REGISTER_GLOBAL("tir.Call")
.set_body_typed([](
DataType type, std::string name,
- Array<PrimExpr> args, int call_type,
+ Array<ObjectRef> args, int call_type,
FunctionRef func, int value_index
) {
+ Array<PrimExpr> prim_expr_args;
+ for (const auto& it : args) {
+ CHECK(it->IsInstance<runtime::StringObj>() ||
+ it->IsInstance<PrimExprNode>());
+ if (const auto* str = it.as<runtime::StringObj>()) {
+ prim_expr_args.push_back(StringImmNode::make(str->data));
+ } else {
+ prim_expr_args.push_back(Downcast<PrimExpr>(it));
+ }
+ }
return CallNode::make(type,
- name,
- args,
- static_cast<CallNode::CallType>(call_type),
- func,
- value_index);
+ name,
+ prim_expr_args,
+ static_cast<CallNode::CallType>(call_type),
+ func,
+ value_index);
});
} // namespace tir
Stmt IRTransform(Stmt ir_node,
const runtime::PackedFunc& f_preorder,
const runtime::PackedFunc& f_postorder,
- const Array<PrimExpr>& only_enable) {
+ const Array<runtime::String>& only_enable) {
std::unordered_set<uint32_t> only_type_index;
- for (PrimExpr s : only_enable) {
- only_type_index.insert(Object::TypeKey2Index(s.as<StringImmNode>()->value.c_str()));
+ for (auto s : only_enable) {
+ only_type_index.insert(Object::TypeKey2Index(s.c_str()));
}
IRTransformer transform(f_preorder, f_postorder, only_type_index);
return transform(std::move(ir_node));
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
- const tvm::Array<tvm::PrimExpr>& required) {
+ const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return PrimFuncPass(pass_func, pass_info);
}
if (!is_one(scond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint";
- asserts->emplace_back(AssertStmtNode::make(scond, os.str(), EvaluateNode::make(0)));
+ asserts->emplace_back(AssertStmtNode::make(scond, tvm::tir::StringImmNode::make(os.str()),
+ EvaluateNode::make(0)));
}
}
ndim_err_msg << arg_name
<< ".ndim is expected to equal "
<< buffer->shape.size();
- asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
+ auto msg = tvm::tir::StringImmNode::make(ndim_err_msg.str());
+ asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
// type checks
DataType dtype = buffer->dtype;
std::ostringstream type_err_msg;
if (!(dtype == DataType::Int(4) ||
dtype == DataType::UInt(4) ||
dtype == DataType::Int(1))) {
- asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
+ auto type_msg = tvm::tir::StringImmNode::make(type_err_msg.str());
+ asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop));
}
// data field
if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
stride_err_msg << arg_name << ".strides:"
<< " expected to be compact array";
if (conds.size() != 0) {
+ auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str());
Stmt check =
AssertStmtNode::make(arith::ComputeReduce<tir::AndNode>(conds, PrimExpr()),
- stride_err_msg.str(), EvaluateNode::make(0));
+ stride_msg, EvaluateNode::make(0));
check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
}
} else {
std::ostringstream stride_null_err_msg;
stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
- asserts_.emplace_back(
- AssertStmtNode::make(
- NotNode::make(is_null), stride_null_err_msg.str(), nop));
+ asserts_.emplace_back(AssertStmtNode::make(
+ NotNode::make(is_null), tvm::tir::StringImmNode::make(stride_null_err_msg.str()), nop));
for (size_t k = 0; k < buffer->strides.size(); ++k) {
std::ostringstream field_name;
}
});
- return IRTransform(parent_for_stmt, nullptr, replace_target_for,
- {PrimExpr("For")});
+ return IRTransform(parent_for_stmt, nullptr, replace_target_for, {"For"});
}
// Remove IfThenElse node from a For node.
}
});
- then_for = IRTransform(for_stmt, nullptr, replace_then_case,
- {PrimExpr("IfThenElse")});
+ then_for = IRTransform(for_stmt, nullptr, replace_then_case, {"IfThenElse"});
if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
- else_for = IRTransform(for_stmt, nullptr, replace_else_case,
- {PrimExpr("IfThenElse")});
+ else_for = IRTransform(for_stmt, nullptr, replace_else_case, {"IfThenElse"});
}
return std::make_pair(then_for, else_for);
*ret = new_for;
}
});
- return IRTransform(stmt, nullptr, replace_top_for, {PrimExpr("For")});
+ return IRTransform(stmt, nullptr, replace_top_for, {"For"});
}
Stmt HoistIfThenElse(Stmt stmt) {
auto it = matrix_abc_.find(simplify_name(node->name));
CHECK(it != matrix_abc_.end())
<< "Cannot find matrix info for " << node->name;
- auto matrix_abc = "wmma." + it->second;
+ auto matrix_abc = tvm::tir::StringImmNode::make("wmma." + it->second);
Stmt body = this->VisitStmt(op->body);
return AttrStmtNode::make(op->node,
op->attr_key,
var_ = nullptr;
std::ostringstream os;
os << "device_type need to be " << device_type_;
- return AssertStmtNode::make(op->value == value, os.str(), body);
+ return AssertStmtNode::make(op->value == value, tvm::tir::StringImmNode::make(os.str()),
+ body);
}
}
return StmtExprMutator::VisitStmt_(op);
namespace tir {
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
- return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
+ return AssertStmtNode::make(lhs == rhs, tvm::tir::StringImmNode::make(msg),
+ EvaluateNode::make(0));
}
PrimFunc MakePackedAPI(PrimFunc&& func,
int num_unpacked_args) {
- auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
std::string name_hint = global_symbol;
AssertStmtNode::make(tcode == kTVMOpaqueHandle ||
tcode == kTVMNDArrayHandle ||
tcode == kTVMDLTensorHandle ||
- tcode == kTVMNullptr, msg.str(), nop));
+ tcode == kTVMNullptr,
+ tvm::tir::StringImmNode::make(msg.str()), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
- seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop));
+ seq_check.emplace_back(
+ AssertStmtNode::make(tcode == kDLInt, tvm::tir::StringImmNode::make(msg.str()), nop));
} else {
CHECK(t.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_check.emplace_back(
- AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop));
+ AssertStmtNode::make(tcode == kDLFloat, tvm::tir::StringImmNode::make(msg.str()), nop));
}
} else {
args.push_back(v_arg);
};
-PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {
+PrimFunc RemapThreadAxis(PrimFunc&& f, Map<runtime::String, IterVar> thread_map) {
std::unordered_map<std::string, IterVar> tmap;
for (const auto& kv : thread_map) {
- const StringImmNode* str = kv.first.as<StringImmNode>();
- CHECK(str != nullptr);
- tmap[str->value] = kv.second;
+ tmap[kv.first] = kv.second;
}
auto thread_axis = f->GetAttr<Array<IterVar> >(tir::attr::kDeviceThreadAxis);
namespace transform {
-Pass RemapThreadAxis(Map<PrimExpr, IterVar> thread_map) {
+Pass RemapThreadAxis(Map<runtime::String, IterVar> thread_map) {
auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) {
return RemapThreadAxis(std::move(f), thread_map);
};
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "SplitHostDevice: Require the target attribute";
- auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";
using namespace std;
String s{"hello"};
CHECK_EQ(s.empty(), false);
- s = "";
+ s = std::string("");
CHECK_EQ(s.empty(), true);
}
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
- add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
+ add_relu = add_relu.with_attr("Composite", "test.add_relu")
# merged function
r = relay.Call(add_relu, [a, b])
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
- add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))
+ add_relu = add_relu.with_attr("Composite", "test.add_relu")
# merged function
cb_1 = relay.annotation.compiler_begin(a, "test")
func = relay.Function([i],
sb.get(),
ret_type=relay.TensorType([], 'int32'))
- func = func.with_attr("Compiler", tvm.tir.StringImm("a"))
+ func = func.with_attr("Compiler", "a")
mod[sum_up] = func
iarg = relay.var('i', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg], sum_up(iarg))
def set_external_func_attr(func, compiler, ext_symbol):
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
- func = func.with_attr("Compiler", tvm.tir.StringImm(compiler))
- func = func.with_attr("global_symbol",
- runtime.container.String(ext_symbol))
+ func = func.with_attr("Compiler", compiler)
+ func = func.with_attr("global_symbol", ext_symbol)
return func
body = relay.Tuple(tvm.runtime.convert([]))
type_params = tvm.runtime.convert([])
fn = relay.Function(params, body, ret_type, type_params)
- fn = fn.with_attr("test_attribute", tvm.tir.StringImm("value"))
+ fn = fn.with_attr("test_attribute", "value")
+ fn = fn.with_attr("test_attribute1", "value1")
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
assert fn.attrs["test_attribute"] == "value"
+ assert fn.attrs["test_attribute1"] == "value1"
str(fn)
check_json_roundtrip(fn)
p00 = relay.subtract(z00, w01)
q00 = relay.multiply(p00, w02)
func0 = relay.Function([x0, w00, w01, w02], q00)
- func0 = func0.with_attr("FuncName", tvm.runtime.container.String("a"))
+ func0 = func0.with_attr("FuncName", "a")
x1 = relay.var('x1', shape=(10, 10))
w10 = relay.var('w10', shape=(10, 10))
p10 = relay.subtract(z10, w11)
q10 = relay.multiply(p10, w12)
func1 = relay.Function([x1, w10, w11, w12], q10)
- func1 = func1.with_attr("FuncName", tvm.runtime.container.String("b"))
+ func1 = func1.with_attr("FuncName", "b")
assert not consistent_equal(func0, func1)
d = relay.var('d', shape=(10, 10))
add_1 = relay.add(c, d)
add_1_fn = relay.Function([c, d], add_1)
- add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.runtime.container.String("test"))
+ add_1_fn = add_1_fn.with_attr("TestAttribute", "test")
add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
assert not consistent_equal(add_1_fn, add_fn)
g11 = relay.GlobalVar("g11")
fn11 = relay.Function([x11], x11)
fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn11 = fn11.with_attr("Compiler", "a")
mod[g11] = fn11
x1 = relay.var("x1", shape=(3, 5))
x11 = relay.var("x11", shape=(3, 5))
fn11 = relay.Function([x11], x11)
fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn11 = fn11.with_attr("Compiler", "a")
x2 = relay.var("x2", shape=(3, 5))
y2 = relay.var("y2", shape=(3, 5))
x1 = relay.var("x1", shape=(2, 2))
fn1 = relay.Function([x1], x1)
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn1 = fn1.with_attr("Compiler", "a")
g1 = relay.GlobalVar("g1")
mod[g1] = fn1
mod["main"] = relay.Function([x, y], x + y + g1(x))
x1 = relay.var("x1", shape=(2, 2))
fn1 = relay.Function([x1], x1)
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn1 = fn1.with_attr("Compiler", "a")
mod["main"] = relay.Function([x, y], x + y + fn1(x))
return mod
sb.ret(x1 + y1)
fn1 = relay.Function([x1, y1], sb.get())
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn1 = fn1.with_attr("Compiler", "a")
g1 = relay.GlobalVar("g1")
mod[g1] = fn1
sb1.ret(x2 - y2)
fn2 = relay.Function([x2, y2], sb1.get())
fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
+ fn2 = fn2.with_attr("Compiler", "b")
g2 = relay.GlobalVar("g2")
mod[g2] = fn2
sb.ret(x1 + y1)
fn1 = relay.Function([x1, y1], sb.get())
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn1 = fn1.with_attr("Compiler", "a")
x2 = relay.var("x2", shape=(3, 5))
y2 = relay.var("y2", shape=(3, 5))
sb1.ret(x2 - y2)
fn2 = relay.Function([x2, y2], sb1.get())
fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
+ fn2 = fn2.with_attr("Compiler", "b")
p0 = relay.var("p0", shape=(3, 5))
p1 = relay.var("p1", shape=(3, 5))
mod = tvm.IRModule({})
fn1 = relay.Function([], relay.const(1))
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn1 = fn1.with_attr("Compiler", "a")
fn2 = relay.Function([], relay.const(2))
fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
+ fn2 = fn2.with_attr("Compiler", "b")
g1 = relay.GlobalVar('g1')
g2 = relay.GlobalVar('g2')
mod[g1] = fn1
mod = tvm.IRModule({})
fn1 = relay.Function([], relay.const(1))
fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a"))
+ fn1 = fn1.with_attr("Compiler", "a")
fn2 = relay.Function([], relay.const(2))
fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b"))
+ fn2 = fn2.with_attr("Compiler", "b")
p = relay.var('p', 'bool')
mod['main'] = relay.Function([p], relay.Call(
relay.If(p, fn1, fn2), []))
y0 = relay.var("y0", shape=(3, 5))
fn0 = relay.Function([x0, y0], x0 * y0)
fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa"))
+ fn0 = fn0.with_attr("Compiler", "aa")
g0 = relay.GlobalVar("g0")
mod[g0] = fn0
y0 = relay.var("y0", shape=(3, 5))
fn0 = relay.Function([x0, y0], x0 * y0)
fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa"))
+ fn0 = fn0.with_attr("Compiler", "aa")
x1 = relay.var("x1", shape=(3, 5))
y1 = relay.var("y1", shape=(3, 5))
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
- add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))
+ add_relu = add_relu.with_attr("Composite", "add_relu")
# merged function
r = relay.Call(add_relu, [a, b])
sub_node = relay.subtract(in_1, in_2)
mul_node = relay.multiply(add_node, sub_node)
add_sub_mul = relay.Function([in_1, in_2], mul_node)
- add_sub_mul = add_sub_mul.with_attr("Composite",
- tir.StringImm("add_sub_mul"))
+ add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul")
# add_sub_mul1 function
in_3 = relay.var('in_3', shape=(10, 10))
sub_node_1 = relay.subtract(in_3, in_4)
mul_node_1 = relay.multiply(add_node_1, sub_node_1)
add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
- add_sub_mul_1 = add_sub_mul_1.with_attr("Composite",
- tir.StringImm("add_sub_mul"))
+ add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", "add_sub_mul")
# merged function
m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
add_node_1 = relay.add(in_1, add_node)
add_node_2 = relay.add(add_node_1, add_node)
add_add_add = relay.Function([in_1, in_2], add_node_2)
- add_add_add = add_add_add.with_attr("Composite",
- tir.StringImm("add_add_add"))
+ add_add_add = add_add_add.with_attr("Composite", "add_add_add")
# merged function
sub_node = relay.subtract(a, b)
r = relay.nn.relu(bias_node)
conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
- conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite",
- tir.StringImm("conv2d_bias_relu"))
+ conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite", "conv2d_bias_relu")
# add_relu function
in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
add_node = relay.add(in_4, in_5)
r = relay.nn.relu(add_node)
add_relu = relay.Function([in_4, in_5], r)
- add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))
+ add_relu = add_relu.with_attr("Composite", "add_relu")
# merged function
conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
out = relay.abs(out)
out = relay.nn.relu(out)
merged_func = relay.Function([x, y], out)
- merged_func = merged_func.with_attr('Composite',
- tir.StringImm(composite_name))
+ merged_func = merged_func.with_attr('Composite', composite_name)
ret = relay.Call(merged_func, [input_1, input_2])
return relay.Function([input_1, input_2], ret)
y = relay.var('y')
branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
func_1 = relay.Function([x, y], branch_1)
- func_1 = func_1.with_attr('Composite', tir.StringImm("add_sub_mul"))
+ func_1 = func_1.with_attr('Composite', "add_sub_mul")
call_1 = relay.Call(func_1, [input_1, input_2])
x1 = relay.var('x1')
y1 = relay.var('y1')
branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
func_2 = relay.Function([x1, y1], branch_2)
- func_2 = func_2.with_attr('Composite', tir.StringImm("add_sub_mul"))
+ func_2 = func_2.with_attr('Composite', "add_sub_mul")
call_2 = relay.Call(func_2, [input_1, input_2])
out = relay.multiply(call_1, call_2)
return relay.Function([input_1, input_2], out)
add_relu_1 = relay.add(x, y)
add_relu_1 = relay.nn.relu(add_relu_1)
add_relu_1 = relay.Function([x, y], add_relu_1)
- add_relu_1 = add_relu_1.with_attr('Composite', tir.StringImm('add_relu'))
+ add_relu_1 = add_relu_1.with_attr('Composite', 'add_relu')
add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
x1 = relay.var('x1')
y1 = relay.var('y1')
add_relu_2 = relay.add(x1, y1)
add_relu_2 = relay.nn.relu(add_relu_2)
add_relu_2 = relay.Function([x1, y1], add_relu_2)
- add_relu_2 = add_relu_2.with_attr('Composite', tir.StringImm('add_relu'))
+ add_relu_2 = add_relu_2.with_attr('Composite', 'add_relu')
add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
x2 = relay.var('x2')
y2 = relay.var('y2')
sub = relay.subtract(x2, y2)
add_sub_mul = relay.multiply(add, sub)
add_sub_mul = relay.Function([x2, y2], add_sub_mul)
- add_sub_mul = add_sub_mul.with_attr('Composite', tir.StringImm('add_sub_mul'))
+ add_sub_mul = add_sub_mul.with_attr('Composite', 'add_sub_mul')
add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
return relay.Function(inputs, add_sub_mul_call)
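These expected graphs are exercised against the MergeComposite pass; the sketch below (hedged: the pattern-table form of (name, pattern-expr) pairs is assumed from the tests of this era, and the module contents are illustrative) shows how a "Composite" attribute like the plain strings above ends up on a matched subgraph.

import tvm
from tvm import relay

def make_add_relu_pattern():
    # pattern to match: relu(add(x, y))
    x = relay.var('x')
    y = relay.var('y')
    return relay.nn.relu(relay.add(x, y))

a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))
mod = tvm.IRModule.from_expr(relay.Function([a, b], relay.nn.relu(relay.add(a, b))))
pattern_table = [("add_relu", make_add_relu_pattern())]
# matched subgraphs are extracted into functions tagged Composite="add_relu"
merged = relay.transform.MergeComposite(pattern_table)(mod)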
add_relu = relay.add(x, y)
add_relu = relay.nn.relu(add_relu)
add_relu = relay.Function([x, y], add_relu)
- add_relu = add_relu.with_attr('Composite', tir.StringImm('add_relu'))
+ add_relu = add_relu.with_attr('Composite', 'add_relu')
add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
add_relu_calls.append(add_relu_call)
tuple_get_item_node = bn_node[0]
relu_node = relay.nn.relu(tuple_get_item_node)
bn_relu = relay.Function([in_1, in_2, in_3, in_4, in_5], relu_node)
- bn_relu = bn_relu.with_attr("Composite", tir.StringImm("bn_relu"))
+ bn_relu = bn_relu.with_attr("Composite", "bn_relu")
# merged function
r = relay.Call(bn_relu, [x, gamma, beta, moving_mean, moving_var])
import tvm.relay.testing
from tvm import relay
from tvm import runtime
-from tvm.runtime import container
from tvm.relay import transform
from tvm.contrib import util
from tvm.relay.op.annotation import compiler_begin, compiler_end
func = relay.Function([x0, y0], add)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
- func = func.with_attr("global_symbol", container.String("ccompiler_0"))
+ func = func.with_attr("Compiler", "ccompiler")
+ func = func.with_attr("global_symbol", "ccompiler_0")
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [x, y])
func = relay.Function([data0, input0, input1], out)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl"))
- func = func.with_attr("global_symbol", container.String("dnnl_0"))
+ func = func.with_attr("Compiler", "dnnl")
+ func = func.with_attr("global_symbol", "dnnl_0")
glb_var = relay.GlobalVar("dnnl_0")
mod = tvm.IRModule()
mod[glb_var] = func
bn.astuple())
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func0 = func0.with_attr("Compiler",
- tvm.tir.StringImm("test_compiler"))
- func0 = func0.with_attr("global_symbol",
- container.String("test_compiler_0"))
+ func0 = func0.with_attr("Compiler", "test_compiler")
+ func0 = func0.with_attr("global_symbol", "test_compiler_0")
gv0 = relay.GlobalVar("test_compiler_0")
mod[gv0] = func0
func1 = relay.Function([data1, weight1], conv)
func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func1 = func1.with_attr("Compiler",
- tvm.tir.StringImm("test_compiler"))
- func1 = func1.with_attr("global_symbol",
- container.String("test_compiler_1"))
+ func1 = func1.with_attr("Compiler", "test_compiler")
+ func1 = func1.with_attr("global_symbol", "test_compiler_1")
gv1 = relay.GlobalVar("test_compiler_1")
mod[gv1] = func1
bn.astuple())
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func0 = func0.with_attr("Compiler",
- tvm.tir.StringImm("test_compiler"))
- func0 = func0.with_attr("global_symbol",
- container.String("test_compiler_0"))
+ func0 = func0.with_attr("Compiler", "test_compiler")
+ func0 = func0.with_attr("global_symbol", "test_compiler_0")
# main function
data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
func = relay.Function([y0], add)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
- func = func.with_attr("global_symbol", container.String("ccompiler_0"))
+ func = func.with_attr("Compiler", "ccompiler")
+ func = func.with_attr("global_symbol", "ccompiler_0")
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [y])
bn_mean, bn_var], tuple_o)
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func0 = func0.with_attr("Compiler",
- tvm.tir.StringImm("test_target"))
- func0 = func0.with_attr("global_symbol",
- container.String("test_target_2"))
+ func0 = func0.with_attr("Compiler", "test_target")
+ func0 = func0.with_attr("global_symbol", "test_target_2")
gv0 = relay.GlobalVar("test_target_2")
mod[gv0] = func0
func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func1 = func1.with_attr("Compiler",
- tvm.tir.StringImm("test_target"))
- func1 = func1.with_attr("global_symbol",
- container.String("test_target_1"))
+ func1 = func1.with_attr("Compiler", "test_target")
+ func1 = func1.with_attr("global_symbol", "test_target_1")
gv1 = relay.GlobalVar("test_target_1")
mod[gv1] = func1
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func0 = func0.with_attr("Compiler",
- tvm.tir.StringImm("test_target"))
- func0 = func0.with_attr("global_symbol",
- container.String("test_target_0"))
+ func0 = func0.with_attr("Compiler", "test_target")
+ func0 = func0.with_attr("global_symbol", "test_target_0")
gv0 = relay.GlobalVar("test_target_0")
mod[gv0] = func0
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert dattr.x.value == 1
dattr = tvm.ir.load_json(tvm.ir.save_json(dattr))
- assert dattr.name.value == "xyz"
+ assert dattr.name == "xyz"
assert isinstance(dattr, tvm.ir.DictAttrs)
assert "name" in dattr
assert dattr["x"].value == 1
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- PrimExpr("tvm.contrib.cublas.matmul"),
+ runtime::String("tvm.contrib.cublas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
{ { b, n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- PrimExpr("tvm.contrib.cublas.batch_matmul"),
+ runtime::String("tvm.contrib.cublas.batch_matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
- PrimExpr("tvm.contrib.rocblas.matmul"),
+ runtime::String("tvm.contrib.rocblas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
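A hedged sketch of the FFI behavior these C++ hunks lean on (the registered name "testing.echo" is hypothetical): a Python str passed across a PackedFunc boundary arrives as runtime::String, which is why extern names such as "tvm.contrib.cublas.matmul" no longer need a PrimExpr/StringImm wrapper on the C++ side.

import tvm

@tvm.register_func("testing.echo")
def echo(name):
    # the argument arrives as a runtime String and round-trips back out
    return name

f = tvm.get_global_func("testing.echo")
assert f("tvm.contrib.cublas.matmul") == "tvm.contrib.cublas.matmul"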