Add safe up/downcasting to the Rust object system (#6384)
authorMax Willsey <me@mwillsey.com>
Fri, 4 Sep 2020 23:29:49 +0000 (16:29 -0700)
committerGitHub <noreply@github.com>
Fri, 4 Sep 2020 23:29:49 +0000 (16:29 -0700)
* Revamp the rust object system with safe subtyping

* Small nits

rust/tvm-macros/src/object.rs
rust/tvm-macros/src/util.rs
rust/tvm-rt/src/array.rs
rust/tvm-rt/src/map.rs
rust/tvm-rt/src/object/mod.rs
rust/tvm-rt/src/object/object_ptr.rs
rust/tvm-rt/src/string.rs
rust/tvm/src/ir/mod.rs
rust/tvm/src/ir/relay/mod.rs
rust/tvm/src/ir/tir.rs

index 342be6b..ff72d6a 100644 (file)
@@ -23,56 +23,67 @@ use quote::quote;
 use syn::DeriveInput;
 use syn::Ident;
 
-use crate::util::get_tvm_rt_crate;
+use crate::util::*;
 
 pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
     let tvm_rt_crate = get_tvm_rt_crate();
     let result = quote! { #tvm_rt_crate::function::Result };
     let error = quote! { #tvm_rt_crate::errors::Error };
     let derive_input = syn::parse_macro_input!(input as DeriveInput);
-    let payload_id = derive_input.ident;
-
-    let mut type_key = None;
-    let mut ref_name = None;
-    let base = Some(Ident::new("base", Span::call_site()));
-
-    for attr in derive_input.attrs {
-        if attr.path.is_ident("type_key") {
-            type_key = Some(attr.parse_meta().expect("foo"))
-        }
-
-        if attr.path.is_ident("ref_name") {
-            ref_name = Some(attr.parse_meta().expect("foo"))
-        }
-    }
-
-    let type_key = if let Some(syn::Meta::NameValue(name_value)) = type_key {
-        match name_value.lit {
-            syn::Lit::Str(type_key) => type_key,
-            _ => panic!("foo"),
-        }
-    } else {
-        panic!("bar");
-    };
-
-    let ref_name = if let Some(syn::Meta::NameValue(name_value)) = ref_name {
-        match name_value.lit {
-            syn::Lit::Str(ref_name) => ref_name,
-            _ => panic!("foo"),
-        }
-    } else {
-        panic!("bar");
+    let payload_id = derive_input.ident.clone();
+
+    let type_key = get_attr(&derive_input, "type_key")
+        .map(attr_to_str)
+        .expect("Failed to get type_key");
+
+    let ref_id = get_attr(&derive_input, "ref_name")
+        .map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site()))
+        .unwrap_or_else(|| {
+            let id = payload_id.to_string();
+            let suffixes = ["Node", "Obj"];
+            if let Some(suf) = suffixes
+                .iter()
+                .find(|&suf| id.len() > suf.len() && id.ends_with(suf))
+            {
+                Ident::new(&id[..id.len() - suf.len()], payload_id.span())
+            } else {
+                panic!(
+                    "Either 'ref_name' must be given, or the struct name must end one of {:?}",
+                    suffixes
+                )
+            }
+        });
+
+    let base_tokens = match &derive_input.data {
+        syn::Data::Struct(s) => s.fields.iter().next().and_then(|f| {
+            let (base_id, base_ty) = (f.ident.clone()?, f.ty.clone());
+            if base_id == "base" {
+                // The transitive case of subtyping
+                Some(quote! {
+                    impl<O> AsRef<O> for #payload_id
+                        where #base_ty: AsRef<O>
+                    {
+                        fn as_ref(&self) -> &O {
+                            self.#base_id.as_ref()
+                        }
+                    }
+                })
+            } else {
+                None
+            }
+        }),
+        _ => panic!("derive only works for structs"),
     };
 
-    let ref_id = Ident::new(&ref_name.value(), Span::call_site());
-    let base = base.expect("should be present");
-
-    let expanded = quote! {
+    let mut expanded = quote! {
         unsafe impl #tvm_rt_crate::object::IsObject for #payload_id {
             const TYPE_KEY: &'static str = #type_key;
+        }
 
-            fn as_object<'s>(&'s self) -> &'s Object {
-                &self.#base.as_object()
+        // a silly AsRef impl is necessary for subtyping to work
+        impl AsRef<#payload_id> for #payload_id {
+            fn as_ref(&self) -> &Self {
+                self
             }
         }
 
@@ -82,11 +93,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
         impl #tvm_rt_crate::object::IsObjectRef for #ref_id {
             type Object = #payload_id;
 
-            fn as_object_ptr(&self) -> Option<&#tvm_rt_crate::object::ObjectPtr<Self::Object>> {
+            fn as_ptr(&self) -> Option<&#tvm_rt_crate::object::ObjectPtr<Self::Object>> {
                 self.0.as_ref()
             }
 
-            fn from_object_ptr(object_ptr: Option<#tvm_rt_crate::object::ObjectPtr<Self::Object>>) -> Self {
+            fn into_ptr(self) -> Option<#tvm_rt_crate::object::ObjectPtr<Self::Object>> {
+                self.0
+            }
+
+            fn from_ptr(object_ptr: Option<#tvm_rt_crate::object::ObjectPtr<Self::Object>>) -> Self {
                 #ref_id(object_ptr)
             }
         }
@@ -99,15 +114,26 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
             }
         }
 
+        impl std::convert::From<#payload_id> for #ref_id {
+            fn from(payload: #payload_id) -> Self {
+                let ptr = #tvm_rt_crate::object::ObjectPtr::new(payload);
+                #tvm_rt_crate::object::IsObjectRef::from_ptr(Some(ptr))
+            }
+        }
+
+        impl std::convert::From<#tvm_rt_crate::object::ObjectPtr<#payload_id>> for #ref_id {
+            fn from(ptr: #tvm_rt_crate::object::ObjectPtr<#payload_id>) -> Self {
+                #tvm_rt_crate::object::IsObjectRef::from_ptr(Some(ptr))
+            }
+        }
+
         impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id {
             type Error = #error;
 
             fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> {
                 use std::convert::TryInto;
-                let oref: #tvm_rt_crate::ObjectRef = ret_val.try_into()?;
-                let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?;
-                let ptr = ptr.downcast::<#payload_id>()?;
-                Ok(#ref_id(Some(ptr)))
+                let ptr: #tvm_rt_crate::object::ObjectPtr<#payload_id> = ret_val.try_into()?;
+                Ok(ptr.into())
             }
         }
 
@@ -155,8 +181,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
                 }
             }
         }
-
     };
 
+    expanded.extend(base_tokens);
+
     TokenStream::from(expanded)
 }
index 1e720f0..2a342bc 100644 (file)
@@ -28,3 +28,21 @@ pub fn get_tvm_rt_crate() -> TokenStream {
         quote!(tvm_rt)
     }
 }
+
+pub(crate) fn get_attr<'a>(
+    derive_input: &'a syn::DeriveInput,
+    name: &str,
+) -> Option<&'a syn::Attribute> {
+    derive_input.attrs.iter().find(|a| a.path.is_ident(name))
+}
+
+pub(crate) fn attr_to_str(attr: &syn::Attribute) -> syn::LitStr {
+    match attr.parse_meta() {
+        Ok(syn::Meta::NameValue(syn::MetaNameValue {
+            lit: syn::Lit::Str(s),
+            ..
+        })) => s,
+        Ok(m) => panic!("Expected a string literal, got {:?}", m),
+        Err(e) => panic!(e),
+    }
+}
index 6e0efc9..d2c82fc 100644 (file)
@@ -46,10 +46,7 @@ external! {
 
 impl<T: IsObjectRef> Array<T> {
     pub fn from_vec(data: Vec<T>) -> Result<Array<T>> {
-        let iter = data
-            .iter()
-            .map(|element| element.to_object_ref().into())
-            .collect();
+        let iter = data.into_iter().map(T::into_arg_value).collect();
 
         let func = Function::get("node.Array").expect(
             "node.Array function is not registered, this is most likely a build or linking error",
@@ -66,7 +63,7 @@ impl<T: IsObjectRef> Array<T> {
         );
 
         Ok(Array {
-            object: ObjectRef(Some(array_data)),
+            object: array_data.into(),
             _data: PhantomData,
         })
     }
index e28dd7a..721fb1e 100644 (file)
@@ -70,8 +70,8 @@ where
         let (lower_bound, upper_bound) = iter.size_hint();
         let mut buffer: Vec<ArgValue> = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2);
         for (k, v) in iter {
-            buffer.push(k.to_object_ref().into());
-            buffer.push(v.to_object_ref().into())
+            buffer.push(k.into());
+            buffer.push(v.into())
         }
         Self::from_data(buffer).expect("failed to convert from data")
     }
@@ -96,7 +96,7 @@ where
         );
 
         Ok(Map {
-            object: ObjectRef(Some(map_data)),
+            object: map_data.into(),
             _data: PhantomData,
         })
     }
@@ -105,7 +105,8 @@ where
     where
         V: TryFrom<RetValue, Error = Error>,
     {
-        let oref: ObjectRef = map_get_item(self.object.clone(), key.to_object_ref())?;
+        let key = key.clone();
+        let oref: ObjectRef = map_get_item(self.object.clone(), key.upcast())?;
         oref.downcast()
     }
 }
index 3858db7..46e0342 100644 (file)
@@ -18,7 +18,6 @@
  */
 
 use std::convert::TryFrom;
-use std::convert::TryInto;
 use std::ffi::CString;
 
 use crate::errors::Error;
@@ -28,87 +27,62 @@ use tvm_sys::{ArgValue, RetValue};
 
 mod object_ptr;
 
-pub use object_ptr::{IsObject, Object, ObjectPtr};
-
-#[derive(Clone)]
-pub struct ObjectRef(pub Option<ObjectPtr<Object>>);
-
-impl ObjectRef {
-    pub fn null() -> ObjectRef {
-        ObjectRef(None)
-    }
-}
-
-pub trait IsObjectRef: Sized {
+pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef};
+
+// TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we
+// can't because of coherence rules. Instead, we generate them in the macro, and
+// add what we can (including Into instead of From) as subtraits.
+// We also add named conversions for clarity
+pub trait IsObjectRef:
+    Sized
+    + Clone
+    + Into<RetValue>
+    + TryFrom<RetValue, Error = Error>
+    + for<'a> Into<ArgValue<'a>>
+    + for<'a> TryFrom<ArgValue<'a>, Error = Error>
+{
     type Object: IsObject;
-    fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>>;
-    fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self;
+    fn as_ptr(&self) -> Option<&ObjectPtr<Self::Object>>;
+    fn into_ptr(self) -> Option<ObjectPtr<Self::Object>>;
+    fn from_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self;
 
-    fn to_object_ref(&self) -> ObjectRef {
-        let object_ptr = self.as_object_ptr().cloned();
-        ObjectRef(object_ptr.map(|ptr| ptr.upcast()))
+    fn null() -> Self {
+        Self::from_ptr(None)
     }
 
-    fn downcast<U: IsObjectRef>(&self) -> Result<U, Error> {
-        let ptr = self
-            .as_object_ptr()
-            .cloned()
-            .map(|ptr| ptr.downcast::<U::Object>());
-        let ptr = ptr.transpose()?;
-        Ok(U::from_object_ptr(ptr))
+    fn into_arg_value<'a>(self) -> ArgValue<'a> {
+        self.into()
     }
-}
 
-impl IsObjectRef for ObjectRef {
-    type Object = Object;
-
-    fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
-        self.0.as_ref()
+    fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result<Self, Error> {
+        Self::try_from(arg_value)
     }
 
-    fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
-        ObjectRef(object_ptr)
+    fn into_ret_value<'a>(self) -> RetValue {
+        self.into()
     }
-}
 
-impl TryFrom<RetValue> for ObjectRef {
-    type Error = Error;
-
-    fn try_from(ret_val: RetValue) -> Result<ObjectRef, Self::Error> {
-        let optr = ret_val.try_into()?;
-        Ok(ObjectRef(Some(optr)))
-    }
-}
-
-impl From<ObjectRef> for RetValue {
-    fn from(object_ref: ObjectRef) -> RetValue {
-        use std::ffi::c_void;
-        let object_ptr = object_ref.0;
-        match object_ptr {
-            None => RetValue::ObjectHandle(std::ptr::null::<c_void>() as *mut c_void),
-            Some(value) => value.clone().into(),
-        }
+    fn from_ret_value<'a>(ret_value: RetValue) -> Result<Self, Error> {
+        Self::try_from(ret_value)
     }
-}
-
-impl<'a> std::convert::TryFrom<ArgValue<'a>> for ObjectRef {
-    type Error = Error;
 
-    fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectRef, Self::Error> {
-        let optr: ObjectPtr<Object> = arg_value.try_into()?;
-        debug_assert!(optr.count() >= 1);
-        Ok(ObjectRef(Some(optr)))
+    fn upcast<U>(self) -> U
+    where
+        U: IsObjectRef,
+        Self::Object: AsRef<U::Object>,
+    {
+        let ptr = self.into_ptr().map(ObjectPtr::upcast);
+        U::from_ptr(ptr)
     }
-}
 
-impl<'a> From<ObjectRef> for ArgValue<'a> {
-    fn from(object_ref: ObjectRef) -> ArgValue<'a> {
-        use std::ffi::c_void;
-        let object_ptr = object_ref.0;
-        match object_ptr {
-            None => ArgValue::ObjectHandle(std::ptr::null::<c_void>() as *mut c_void),
-            Some(value) => value.into(),
-        }
+    fn downcast<U>(self) -> Result<U, Error>
+    where
+        U: IsObjectRef,
+        U::Object: AsRef<Self::Object>,
+    {
+        let ptr = self.into_ptr().map(ObjectPtr::downcast);
+        let ptr = ptr.transpose()?;
+        Ok(U::from_ptr(ptr))
     }
 }
 
index 1923854..1388d3c 100644 (file)
@@ -22,6 +22,7 @@ use std::ffi::CString;
 use std::ptr::NonNull;
 use std::sync::atomic::AtomicI32;
 
+use tvm_macros::Object;
 use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index};
 use tvm_sys::{ArgValue, RetValue};
 
@@ -35,7 +36,9 @@ type Deleter = unsafe extern "C" fn(object: *mut Object) -> ();
 /// table, an atomic reference count, and a customized deleter which
 /// will be invoked when the reference count is zero.
 ///
-#[derive(Debug)]
+#[derive(Debug, Object)]
+#[ref_name = "ObjectRef"]
+#[type_key = "runtime.Object"]
 #[repr(C)]
 pub struct Object {
     /// The index into TVM's runtime type information table.
@@ -151,25 +154,15 @@ impl Object {
 /// index, a method for accessing the base object given the
 /// subtype, and a typed delete method which is specialized
 /// to the subtype.
-pub unsafe trait IsObject {
+pub unsafe trait IsObject: AsRef<Object> {
     const TYPE_KEY: &'static str;
 
-    fn as_object<'s>(&'s self) -> &'s Object;
-
     unsafe extern "C" fn typed_delete(object: *mut Self) {
         let object = Box::from_raw(object);
         drop(object)
     }
 }
 
-unsafe impl IsObject for Object {
-    const TYPE_KEY: &'static str = "runtime.Object";
-
-    fn as_object<'s>(&'s self) -> &'s Object {
-        self
-    }
-}
-
 /// A smart pointer for types which implement IsObject.
 /// This type directly corresponds to TVM's C++ type ObjectPtr<T>.
 ///
@@ -179,14 +172,6 @@ pub struct ObjectPtr<T: IsObject> {
     pub ptr: NonNull<T>,
 }
 
-fn inc_ref<T: IsObject>(ptr: NonNull<T>) {
-    unsafe { ptr.as_ref().as_object().inc_ref() }
-}
-
-fn dec_ref<T: IsObject>(ptr: NonNull<T>) {
-    unsafe { ptr.as_ref().as_object().dec_ref() }
-}
-
 impl ObjectPtr<Object> {
     pub fn from_raw(object_ptr: *mut Object) -> Option<ObjectPtr<Object>> {
         let non_null = NonNull::new(object_ptr);
@@ -199,14 +184,14 @@ impl ObjectPtr<Object> {
 
 impl<T: IsObject> Clone for ObjectPtr<T> {
     fn clone(&self) -> Self {
-        inc_ref(self.ptr);
+        unsafe { self.ptr.as_ref().as_ref().inc_ref() }
         ObjectPtr { ptr: self.ptr }
     }
 }
 
 impl<T: IsObject> Drop for ObjectPtr<T> {
     fn drop(&mut self) {
-        dec_ref(self.ptr);
+        unsafe { self.ptr.as_ref().as_ref().dec_ref() }
     }
 }
 
@@ -219,34 +204,42 @@ impl<T: IsObject> ObjectPtr<T> {
     }
 
     pub fn new(object: T) -> ObjectPtr<T> {
+        object.as_ref().inc_ref();
         let object_ptr = Box::new(object);
         let object_ptr = Box::leak(object_ptr);
         let ptr = NonNull::from(object_ptr);
-        inc_ref(ptr);
         ObjectPtr { ptr }
     }
 
     pub fn count(&self) -> i32 {
         // need to do atomic read in C++
         // ABI compatible atomics is funky/hard.
-        self.as_object()
+        self.as_ref()
             .ref_count
             .load(std::sync::atomic::Ordering::Relaxed)
     }
 
-    fn as_object<'s>(&'s self) -> &'s Object {
-        unsafe { self.ptr.as_ref().as_object() }
+    /// This method avoid running the destructor on self once it's dropped, so we don't accidentally release the memory
+    unsafe fn cast<U: IsObject>(self) -> ObjectPtr<U> {
+        let ptr = self.ptr.cast();
+        std::mem::forget(self);
+        ObjectPtr { ptr }
     }
 
-    pub fn upcast(self) -> ObjectPtr<Object> {
-        ObjectPtr {
-            ptr: self.ptr.cast(),
-        }
+    pub fn upcast<U>(self) -> ObjectPtr<U>
+    where
+        U: IsObject,
+        T: AsRef<U>,
+    {
+        unsafe { self.cast() }
     }
 
-    pub fn downcast<U: IsObject>(self) -> Result<ObjectPtr<U>, Error> {
+    pub fn downcast<U>(self) -> Result<ObjectPtr<U>, Error>
+    where
+        U: IsObject + AsRef<T>,
+    {
         let child_index = Object::get_type_index::<U>();
-        let object_index = self.as_object().type_index;
+        let object_index = self.as_ref().type_index;
 
         let is_derived = if child_index == object_index {
             true
@@ -256,10 +249,7 @@ impl<T: IsObject> ObjectPtr<T> {
         };
 
         if is_derived {
-            // NB: self gets dropped here causng a dec ref which we need to migtigate with an inc ref before it is dropped.
-            inc_ref(self.ptr);
-            let ptr = self.ptr.cast();
-            Ok(ObjectPtr { ptr })
+            Ok(unsafe { self.cast() })
         } else {
             Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY))
         }
index 73d0439..a5ee1f1 100644 (file)
@@ -20,7 +20,7 @@
 use std::cmp::{Ordering, PartialEq};
 use std::hash::{Hash, Hasher};
 
-use super::{Object, ObjectPtr};
+use super::Object;
 
 use tvm_macros::Object;
 
@@ -37,27 +37,18 @@ pub struct StringObj {
 impl From<std::string::String> for String {
     fn from(s: std::string::String) -> Self {
         let size = s.len() as u64;
-        let obj = StringObj {
-            base: Object::base_object::<StringObj>(),
-            data: s.as_bytes().as_ptr(),
-            size,
-        };
-        std::mem::forget(s);
-        let obj_ptr = ObjectPtr::new(obj);
-        String(Some(obj_ptr))
+        let data = Box::into_raw(s.into_boxed_str()).cast();
+        let base = Object::base_object::<StringObj>();
+        StringObj { base, data, size }.into()
     }
 }
 
 impl From<&'static str> for String {
     fn from(s: &'static str) -> Self {
         let size = s.len() as u64;
-        let obj = StringObj {
-            base: Object::base_object::<StringObj>(),
-            data: s.as_bytes().as_ptr(),
-            size,
-        };
-        let obj_ptr = ObjectPtr::new(obj);
-        String(Some(obj_ptr))
+        let data = s.as_bytes().as_ptr();
+        let base = Object::base_object::<StringObj>();
+        StringObj { base, data, size }.into()
     }
 }
 
@@ -139,7 +130,7 @@ mod tests {
     #[test]
     fn test_string_debug() -> Result<()> {
         let s = String::from("foo");
-        let object_ref = s.to_object_ref();
+        let object_ref = s.upcast();
         println!("about to call");
         let string = debug_print(object_ref)?;
         println!("after call");
index 9222de5..b615c1e 100644 (file)
@@ -34,7 +34,7 @@ external! {
 
 pub fn as_text<T: IsObjectRef>(object: T) -> String {
     let no_func = unsafe { runtime::Function::null() };
-    _as_text(object.to_object_ref(), 0, no_func)
+    _as_text(object.upcast(), 0, no_func)
         .unwrap()
         .as_str()
         .unwrap()
index 771882e..4f4497e 100644 (file)
@@ -18,7 +18,7 @@
  */
 
 use crate::runtime::array::Array;
-use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString};
+use crate::runtime::{object::*, String as TString};
 use crate::DataType;
 use tvm_macros::Object;
 
index bef63b1..ee30c51 100644 (file)
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-use crate::runtime::{Object, ObjectPtr, String as TVMString};
+use crate::runtime::String as TVMString;
 use crate::DataType;
 
 use super::*;
@@ -37,17 +37,7 @@ macro_rules! define_node {
             pub fn new(datatype: DataType, $($id : $t,)*) -> $name {
                 let base = PrimExprNode::base::<$node>(datatype);
                 let node = $node { base, $($id),* };
-                $name(Some(ObjectPtr::new(node)))
-            }
-        }
-
-        impl From<$name> for PrimExpr {
-            // TODO(@jroesch): Remove we with subtyping traits.
-            fn from(x: $name) -> PrimExpr {
-                x.downcast().expect(concat!(
-                    "Failed to downcast `",
-                    stringify!($name),
-                    "` to PrimExpr"))
+                node.into()
             }
         }
     }