use syn::DeriveInput;
use syn::Ident;
-use crate::util::get_tvm_rt_crate;
+use crate::util::*;
pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
let tvm_rt_crate = get_tvm_rt_crate();
let result = quote! { #tvm_rt_crate::function::Result };
let error = quote! { #tvm_rt_crate::errors::Error };
let derive_input = syn::parse_macro_input!(input as DeriveInput);
- let payload_id = derive_input.ident;
-
- let mut type_key = None;
- let mut ref_name = None;
- let base = Some(Ident::new("base", Span::call_site()));
-
- for attr in derive_input.attrs {
- if attr.path.is_ident("type_key") {
- type_key = Some(attr.parse_meta().expect("foo"))
- }
-
- if attr.path.is_ident("ref_name") {
- ref_name = Some(attr.parse_meta().expect("foo"))
- }
- }
-
- let type_key = if let Some(syn::Meta::NameValue(name_value)) = type_key {
- match name_value.lit {
- syn::Lit::Str(type_key) => type_key,
- _ => panic!("foo"),
- }
- } else {
- panic!("bar");
- };
-
- let ref_name = if let Some(syn::Meta::NameValue(name_value)) = ref_name {
- match name_value.lit {
- syn::Lit::Str(ref_name) => ref_name,
- _ => panic!("foo"),
- }
- } else {
- panic!("bar");
+ let payload_id = derive_input.ident.clone();
+
+ let type_key = get_attr(&derive_input, "type_key")
+ .map(attr_to_str)
+ .expect("Failed to get type_key");
+
+ let ref_id = get_attr(&derive_input, "ref_name")
+ .map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site()))
+ .unwrap_or_else(|| {
+ let id = payload_id.to_string();
+ let suffixes = ["Node", "Obj"];
+ if let Some(suf) = suffixes
+ .iter()
+ .find(|&suf| id.len() > suf.len() && id.ends_with(suf))
+ {
+ Ident::new(&id[..id.len() - suf.len()], payload_id.span())
+ } else {
+ panic!(
+ "Either 'ref_name' must be given, or the struct name must end one of {:?}",
+ suffixes
+ )
+ }
+ });
+
+ let base_tokens = match &derive_input.data {
+ syn::Data::Struct(s) => s.fields.iter().next().and_then(|f| {
+ let (base_id, base_ty) = (f.ident.clone()?, f.ty.clone());
+ if base_id == "base" {
+ // The transitive case of subtyping
+ Some(quote! {
+ impl<O> AsRef<O> for #payload_id
+ where #base_ty: AsRef<O>
+ {
+ fn as_ref(&self) -> &O {
+ self.#base_id.as_ref()
+ }
+ }
+ })
+ } else {
+ None
+ }
+ }),
+ _ => panic!("derive only works for structs"),
};
- let ref_id = Ident::new(&ref_name.value(), Span::call_site());
- let base = base.expect("should be present");
-
- let expanded = quote! {
+ let mut expanded = quote! {
unsafe impl #tvm_rt_crate::object::IsObject for #payload_id {
const TYPE_KEY: &'static str = #type_key;
+ }
- fn as_object<'s>(&'s self) -> &'s Object {
- &self.#base.as_object()
+ // a silly AsRef impl is necessary for subtyping to work
+ impl AsRef<#payload_id> for #payload_id {
+ fn as_ref(&self) -> &Self {
+ self
}
}
impl #tvm_rt_crate::object::IsObjectRef for #ref_id {
type Object = #payload_id;
- fn as_object_ptr(&self) -> Option<&#tvm_rt_crate::object::ObjectPtr<Self::Object>> {
+ fn as_ptr(&self) -> Option<&#tvm_rt_crate::object::ObjectPtr<Self::Object>> {
self.0.as_ref()
}
- fn from_object_ptr(object_ptr: Option<#tvm_rt_crate::object::ObjectPtr<Self::Object>>) -> Self {
+ fn into_ptr(self) -> Option<#tvm_rt_crate::object::ObjectPtr<Self::Object>> {
+ self.0
+ }
+
+ fn from_ptr(object_ptr: Option<#tvm_rt_crate::object::ObjectPtr<Self::Object>>) -> Self {
#ref_id(object_ptr)
}
}
}
}
+ impl std::convert::From<#payload_id> for #ref_id {
+ fn from(payload: #payload_id) -> Self {
+ let ptr = #tvm_rt_crate::object::ObjectPtr::new(payload);
+ #tvm_rt_crate::object::IsObjectRef::from_ptr(Some(ptr))
+ }
+ }
+
+ impl std::convert::From<#tvm_rt_crate::object::ObjectPtr<#payload_id>> for #ref_id {
+ fn from(ptr: #tvm_rt_crate::object::ObjectPtr<#payload_id>) -> Self {
+ #tvm_rt_crate::object::IsObjectRef::from_ptr(Some(ptr))
+ }
+ }
+
impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id {
type Error = #error;
fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> {
use std::convert::TryInto;
- let oref: #tvm_rt_crate::ObjectRef = ret_val.try_into()?;
- let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?;
- let ptr = ptr.downcast::<#payload_id>()?;
- Ok(#ref_id(Some(ptr)))
+ let ptr: #tvm_rt_crate::object::ObjectPtr<#payload_id> = ret_val.try_into()?;
+ Ok(ptr.into())
}
}
}
}
}
-
};
+ expanded.extend(base_tokens);
+
TokenStream::from(expanded)
}
quote!(tvm_rt)
}
}
+
+pub(crate) fn get_attr<'a>(
+ derive_input: &'a syn::DeriveInput,
+ name: &str,
+) -> Option<&'a syn::Attribute> {
+ derive_input.attrs.iter().find(|a| a.path.is_ident(name))
+}
+
+pub(crate) fn attr_to_str(attr: &syn::Attribute) -> syn::LitStr {
+ match attr.parse_meta() {
+ Ok(syn::Meta::NameValue(syn::MetaNameValue {
+ lit: syn::Lit::Str(s),
+ ..
+ })) => s,
+ Ok(m) => panic!("Expected a string literal, got {:?}", m),
+ Err(e) => panic!(e),
+ }
+}
impl<T: IsObjectRef> Array<T> {
pub fn from_vec(data: Vec<T>) -> Result<Array<T>> {
- let iter = data
- .iter()
- .map(|element| element.to_object_ref().into())
- .collect();
+ let iter = data.into_iter().map(T::into_arg_value).collect();
let func = Function::get("node.Array").expect(
"node.Array function is not registered, this is most likely a build or linking error",
);
Ok(Array {
- object: ObjectRef(Some(array_data)),
+ object: array_data.into(),
_data: PhantomData,
})
}
let (lower_bound, upper_bound) = iter.size_hint();
let mut buffer: Vec<ArgValue> = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2);
for (k, v) in iter {
- buffer.push(k.to_object_ref().into());
- buffer.push(v.to_object_ref().into())
+ buffer.push(k.into());
+ buffer.push(v.into())
}
Self::from_data(buffer).expect("failed to convert from data")
}
);
Ok(Map {
- object: ObjectRef(Some(map_data)),
+ object: map_data.into(),
_data: PhantomData,
})
}
where
V: TryFrom<RetValue, Error = Error>,
{
- let oref: ObjectRef = map_get_item(self.object.clone(), key.to_object_ref())?;
+ let key = key.clone();
+ let oref: ObjectRef = map_get_item(self.object.clone(), key.upcast())?;
oref.downcast()
}
}
*/
use std::convert::TryFrom;
-use std::convert::TryInto;
use std::ffi::CString;
use crate::errors::Error;
mod object_ptr;
-pub use object_ptr::{IsObject, Object, ObjectPtr};
-
-#[derive(Clone)]
-pub struct ObjectRef(pub Option<ObjectPtr<Object>>);
-
-impl ObjectRef {
- pub fn null() -> ObjectRef {
- ObjectRef(None)
- }
-}
-
-pub trait IsObjectRef: Sized {
+pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef};
+
+// TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we
+// can't because of coherence rules. Instead, we generate them in the macro, and
+// add what we can (including Into instead of From) as subtraits.
+// We also add named conversions for clarity
+pub trait IsObjectRef:
+ Sized
+ + Clone
+ + Into<RetValue>
+ + TryFrom<RetValue, Error = Error>
+ + for<'a> Into<ArgValue<'a>>
+ + for<'a> TryFrom<ArgValue<'a>, Error = Error>
+{
type Object: IsObject;
- fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>>;
- fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self;
+ fn as_ptr(&self) -> Option<&ObjectPtr<Self::Object>>;
+ fn into_ptr(self) -> Option<ObjectPtr<Self::Object>>;
+ fn from_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self;
- fn to_object_ref(&self) -> ObjectRef {
- let object_ptr = self.as_object_ptr().cloned();
- ObjectRef(object_ptr.map(|ptr| ptr.upcast()))
+ fn null() -> Self {
+ Self::from_ptr(None)
}
- fn downcast<U: IsObjectRef>(&self) -> Result<U, Error> {
- let ptr = self
- .as_object_ptr()
- .cloned()
- .map(|ptr| ptr.downcast::<U::Object>());
- let ptr = ptr.transpose()?;
- Ok(U::from_object_ptr(ptr))
+ fn into_arg_value<'a>(self) -> ArgValue<'a> {
+ self.into()
}
-}
-impl IsObjectRef for ObjectRef {
- type Object = Object;
-
- fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
- self.0.as_ref()
+ fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result<Self, Error> {
+ Self::try_from(arg_value)
}
- fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
- ObjectRef(object_ptr)
+ fn into_ret_value<'a>(self) -> RetValue {
+ self.into()
}
-}
-impl TryFrom<RetValue> for ObjectRef {
- type Error = Error;
-
- fn try_from(ret_val: RetValue) -> Result<ObjectRef, Self::Error> {
- let optr = ret_val.try_into()?;
- Ok(ObjectRef(Some(optr)))
- }
-}
-
-impl From<ObjectRef> for RetValue {
- fn from(object_ref: ObjectRef) -> RetValue {
- use std::ffi::c_void;
- let object_ptr = object_ref.0;
- match object_ptr {
- None => RetValue::ObjectHandle(std::ptr::null::<c_void>() as *mut c_void),
- Some(value) => value.clone().into(),
- }
+ fn from_ret_value<'a>(ret_value: RetValue) -> Result<Self, Error> {
+ Self::try_from(ret_value)
}
-}
-
-impl<'a> std::convert::TryFrom<ArgValue<'a>> for ObjectRef {
- type Error = Error;
- fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectRef, Self::Error> {
- let optr: ObjectPtr<Object> = arg_value.try_into()?;
- debug_assert!(optr.count() >= 1);
- Ok(ObjectRef(Some(optr)))
+ fn upcast<U>(self) -> U
+ where
+ U: IsObjectRef,
+ Self::Object: AsRef<U::Object>,
+ {
+ let ptr = self.into_ptr().map(ObjectPtr::upcast);
+ U::from_ptr(ptr)
}
-}
-impl<'a> From<ObjectRef> for ArgValue<'a> {
- fn from(object_ref: ObjectRef) -> ArgValue<'a> {
- use std::ffi::c_void;
- let object_ptr = object_ref.0;
- match object_ptr {
- None => ArgValue::ObjectHandle(std::ptr::null::<c_void>() as *mut c_void),
- Some(value) => value.into(),
- }
+ fn downcast<U>(self) -> Result<U, Error>
+ where
+ U: IsObjectRef,
+ U::Object: AsRef<Self::Object>,
+ {
+ let ptr = self.into_ptr().map(ObjectPtr::downcast);
+ let ptr = ptr.transpose()?;
+ Ok(U::from_ptr(ptr))
}
}
use std::ptr::NonNull;
use std::sync::atomic::AtomicI32;
+use tvm_macros::Object;
use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index};
use tvm_sys::{ArgValue, RetValue};
/// table, an atomic reference count, and a customized deleter which
/// will be invoked when the reference count is zero.
///
-#[derive(Debug)]
+#[derive(Debug, Object)]
+#[ref_name = "ObjectRef"]
+#[type_key = "runtime.Object"]
#[repr(C)]
pub struct Object {
/// The index into TVM's runtime type information table.
/// index, a method for accessing the base object given the
/// subtype, and a typed delete method which is specialized
/// to the subtype.
-pub unsafe trait IsObject {
+pub unsafe trait IsObject: AsRef<Object> {
const TYPE_KEY: &'static str;
- fn as_object<'s>(&'s self) -> &'s Object;
-
unsafe extern "C" fn typed_delete(object: *mut Self) {
let object = Box::from_raw(object);
drop(object)
}
}
-unsafe impl IsObject for Object {
- const TYPE_KEY: &'static str = "runtime.Object";
-
- fn as_object<'s>(&'s self) -> &'s Object {
- self
- }
-}
-
/// A smart pointer for types which implement IsObject.
/// This type directly corresponds to TVM's C++ type ObjectPtr<T>.
///
pub ptr: NonNull<T>,
}
-fn inc_ref<T: IsObject>(ptr: NonNull<T>) {
- unsafe { ptr.as_ref().as_object().inc_ref() }
-}
-
-fn dec_ref<T: IsObject>(ptr: NonNull<T>) {
- unsafe { ptr.as_ref().as_object().dec_ref() }
-}
-
impl ObjectPtr<Object> {
pub fn from_raw(object_ptr: *mut Object) -> Option<ObjectPtr<Object>> {
let non_null = NonNull::new(object_ptr);
impl<T: IsObject> Clone for ObjectPtr<T> {
fn clone(&self) -> Self {
- inc_ref(self.ptr);
+ unsafe { self.ptr.as_ref().as_ref().inc_ref() }
ObjectPtr { ptr: self.ptr }
}
}
impl<T: IsObject> Drop for ObjectPtr<T> {
fn drop(&mut self) {
- dec_ref(self.ptr);
+ unsafe { self.ptr.as_ref().as_ref().dec_ref() }
}
}
}
pub fn new(object: T) -> ObjectPtr<T> {
+ object.as_ref().inc_ref();
let object_ptr = Box::new(object);
let object_ptr = Box::leak(object_ptr);
let ptr = NonNull::from(object_ptr);
- inc_ref(ptr);
ObjectPtr { ptr }
}
pub fn count(&self) -> i32 {
// need to do atomic read in C++
// ABI compatible atomics is funky/hard.
- self.as_object()
+ self.as_ref()
.ref_count
.load(std::sync::atomic::Ordering::Relaxed)
}
- fn as_object<'s>(&'s self) -> &'s Object {
- unsafe { self.ptr.as_ref().as_object() }
+ /// This method avoid running the destructor on self once it's dropped, so we don't accidentally release the memory
+ unsafe fn cast<U: IsObject>(self) -> ObjectPtr<U> {
+ let ptr = self.ptr.cast();
+ std::mem::forget(self);
+ ObjectPtr { ptr }
}
- pub fn upcast(self) -> ObjectPtr<Object> {
- ObjectPtr {
- ptr: self.ptr.cast(),
- }
+ pub fn upcast<U>(self) -> ObjectPtr<U>
+ where
+ U: IsObject,
+ T: AsRef<U>,
+ {
+ unsafe { self.cast() }
}
- pub fn downcast<U: IsObject>(self) -> Result<ObjectPtr<U>, Error> {
+ pub fn downcast<U>(self) -> Result<ObjectPtr<U>, Error>
+ where
+ U: IsObject + AsRef<T>,
+ {
let child_index = Object::get_type_index::<U>();
- let object_index = self.as_object().type_index;
+ let object_index = self.as_ref().type_index;
let is_derived = if child_index == object_index {
true
};
if is_derived {
- // NB: self gets dropped here causng a dec ref which we need to migtigate with an inc ref before it is dropped.
- inc_ref(self.ptr);
- let ptr = self.ptr.cast();
- Ok(ObjectPtr { ptr })
+ Ok(unsafe { self.cast() })
} else {
Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY))
}
use std::cmp::{Ordering, PartialEq};
use std::hash::{Hash, Hasher};
-use super::{Object, ObjectPtr};
+use super::Object;
use tvm_macros::Object;
impl From<std::string::String> for String {
fn from(s: std::string::String) -> Self {
let size = s.len() as u64;
- let obj = StringObj {
- base: Object::base_object::<StringObj>(),
- data: s.as_bytes().as_ptr(),
- size,
- };
- std::mem::forget(s);
- let obj_ptr = ObjectPtr::new(obj);
- String(Some(obj_ptr))
+ let data = Box::into_raw(s.into_boxed_str()).cast();
+ let base = Object::base_object::<StringObj>();
+ StringObj { base, data, size }.into()
}
}
impl From<&'static str> for String {
fn from(s: &'static str) -> Self {
let size = s.len() as u64;
- let obj = StringObj {
- base: Object::base_object::<StringObj>(),
- data: s.as_bytes().as_ptr(),
- size,
- };
- let obj_ptr = ObjectPtr::new(obj);
- String(Some(obj_ptr))
+ let data = s.as_bytes().as_ptr();
+ let base = Object::base_object::<StringObj>();
+ StringObj { base, data, size }.into()
}
}
#[test]
fn test_string_debug() -> Result<()> {
let s = String::from("foo");
- let object_ref = s.to_object_ref();
+ let object_ref = s.upcast();
println!("about to call");
let string = debug_print(object_ref)?;
println!("after call");
pub fn as_text<T: IsObjectRef>(object: T) -> String {
let no_func = unsafe { runtime::Function::null() };
- _as_text(object.to_object_ref(), 0, no_func)
+ _as_text(object.upcast(), 0, no_func)
.unwrap()
.as_str()
.unwrap()
*/
use crate::runtime::array::Array;
-use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString};
+use crate::runtime::{object::*, String as TString};
use crate::DataType;
use tvm_macros::Object;
* under the License.
*/
-use crate::runtime::{Object, ObjectPtr, String as TVMString};
+use crate::runtime::String as TVMString;
use crate::DataType;
use super::*;
pub fn new(datatype: DataType, $($id : $t,)*) -> $name {
let base = PrimExprNode::base::<$node>(datatype);
let node = $node { base, $($id),* };
- $name(Some(ObjectPtr::new(node)))
- }
- }
-
- impl From<$name> for PrimExpr {
- // TODO(@jroesch): Remove we with subtyping traits.
- fn from(x: $name) -> PrimExpr {
- x.downcast().expect(concat!(
- "Failed to downcast `",
- stringify!($name),
- "` to PrimExpr"))
+ node.into()
}
}
}