// We use type traits to eliminate un-necessary checks.
template<typename T>
inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
- using TObjectRef = typename std::remove_reference<T>::type;
+ using ContainerType = typename std::remove_reference<T>::type::ContainerType;
if (value.defined()) {
Object* ptr = value.data_.data_;
- if (std::is_base_of<NDArray, TObjectRef>::value ||
- (std::is_base_of<TObjectRef, NDArray>::value &&
+ if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
+ (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
ptr->IsInstance<NDArray::ContainerType>())) {
values_[i].v_handle = NDArray::FFIGetHandle(value);
type_codes_[i] = kTVMNDArrayHandle;
- } else if (std::is_base_of<Module, TObjectRef>::value ||
- (std::is_base_of<TObjectRef, Module>::value &&
+ } else if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
+ (std::is_base_of<ContainerType, Module::ContainerType>::value &&
ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr;
type_codes_[i] = kTVMModuleHandle;
inline bool TVMPODValue_::IsObjectRef() const {
using ContainerType = typename TObjectRef::ContainerType;
// NOTE: the following code can be optimized by constant folding.
- if (std::is_base_of<NDArray, TObjectRef>::value) {
+ if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
return type_code_ == kTVMNDArrayHandle &&
TVMArrayHandleToObjectHandle(
static_cast<TVMArrayHandle>(value_.v_handle))->IsInstance<ContainerType>();
}
- if (std::is_base_of<Module, TObjectRef>::value) {
+ if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
return type_code_ == kTVMModuleHandle &&
static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
}
*static_cast<Object**>(value_.v_handle));
}
return
- (std::is_base_of<TObjectRef, NDArray>::value && type_code_ == kTVMNDArrayHandle) ||
- (std::is_base_of<TObjectRef, Module>::value && type_code_ == kTVMModuleHandle) ||
+ (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
+ type_code_ == kTVMNDArrayHandle) ||
+ (std::is_base_of<ContainerType, Module::ContainerType>::value &&
+ type_code_ == kTVMModuleHandle) ||
(type_code_ == kTVMObjectHandle &&
ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
}
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
using ContainerType = typename TObjectRef::ContainerType;
+
if (type_code_ == kTVMNullptr) {
CHECK(TObjectRef::_type_is_nullable)
<< "Expect a not null value of " << ContainerType::_type_key;
return TObjectRef(ObjectPtr<Object>(nullptr));
}
// NOTE: the following code can be optimized by constant folding.
- if (std::is_base_of<NDArray, TObjectRef>::value) {
+ if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
// Casting to a sub-class of NDArray
TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle);
ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
<< "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey();
return TObjectRef(data);
}
- if (std::is_base_of<Module, TObjectRef>::value) {
+ if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
// Casting to a sub-class of Module
TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle);
ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
<< "Expect " << ObjectTypeChecker<TObjectRef>::TypeName()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr));
- } else if (std::is_base_of<TObjectRef, NDArray>::value &&
+ } else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
type_code_ == kTVMNDArrayHandle) {
// Casting to a base class that NDArray can sub-class
ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
static_cast<TVMArrayHandle>(value_.v_handle));
return TObjectRef(data);
- } else if (std::is_base_of<TObjectRef, Module>::value &&
+ } else if (std::is_base_of<ContainerType, Module::ContainerType>::value &&
type_code_ == kTVMModuleHandle) {
// Casting to a base class that Module can sub-class
return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
template<typename TObjectRef, typename>
inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
+ using ContainerType = typename TObjectRef::ContainerType;
const Object* ptr = other.get();
if (ptr != nullptr) {
- if (std::is_base_of<NDArray, TObjectRef>::value ||
- (std::is_base_of<TObjectRef, NDArray>::value &&
+ if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
+ (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
ptr->IsInstance<NDArray::ContainerType>())) {
return operator=(NDArray(std::move(other.data_)));
}
- if (std::is_base_of<Module, TObjectRef>::value ||
- (std::is_base_of<TObjectRef, Module>::value &&
+ if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
+ (std::is_base_of<ContainerType, Module::ContainerType>::value &&
ptr->IsInstance<Module::ContainerType>())) {
return operator=(Module(std::move(other.data_)));
}