From 73d9e01c8f8149c8a99449afcfe077b5d579be92 Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Fri, 4 Sep 2020 16:29:49 -0700 Subject: [PATCH] Add safe up/downcasting to the Rust object system (#6384) * Revamp the rust object system with safe subtyping * Small nits --- rust/tvm-macros/src/object.rs | 119 +++++++++++++++++++++-------------- rust/tvm-macros/src/util.rs | 18 ++++++ rust/tvm-rt/src/array.rs | 7 +-- rust/tvm-rt/src/map.rs | 9 +-- rust/tvm-rt/src/object/mod.rs | 110 +++++++++++++------------------- rust/tvm-rt/src/object/object_ptr.rs | 62 ++++++++---------- rust/tvm-rt/src/string.rs | 25 +++----- rust/tvm/src/ir/mod.rs | 2 +- rust/tvm/src/ir/relay/mod.rs | 2 +- rust/tvm/src/ir/tir.rs | 14 +---- 10 files changed, 178 insertions(+), 190 deletions(-) diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index 342be6b..ff72d6a 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -23,56 +23,67 @@ use quote::quote; use syn::DeriveInput; use syn::Ident; -use crate::util::get_tvm_rt_crate; +use crate::util::*; pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { let tvm_rt_crate = get_tvm_rt_crate(); let result = quote! { #tvm_rt_crate::function::Result }; let error = quote! { #tvm_rt_crate::errors::Error }; let derive_input = syn::parse_macro_input!(input as DeriveInput); - let payload_id = derive_input.ident; - - let mut type_key = None; - let mut ref_name = None; - let base = Some(Ident::new("base", Span::call_site())); - - for attr in derive_input.attrs { - if attr.path.is_ident("type_key") { - type_key = Some(attr.parse_meta().expect("foo")) - } - - if attr.path.is_ident("ref_name") { - ref_name = Some(attr.parse_meta().expect("foo")) - } - } - - let type_key = if let Some(syn::Meta::NameValue(name_value)) = type_key { - match name_value.lit { - syn::Lit::Str(type_key) => type_key, - _ => panic!("foo"), - } - } else { - panic!("bar"); - }; - - let ref_name = if let Some(syn::Meta::NameValue(name_value)) = ref_name { - match name_value.lit { - syn::Lit::Str(ref_name) => ref_name, - _ => panic!("foo"), - } - } else { - panic!("bar"); + let payload_id = derive_input.ident.clone(); + + let type_key = get_attr(&derive_input, "type_key") + .map(attr_to_str) + .expect("Failed to get type_key"); + + let ref_id = get_attr(&derive_input, "ref_name") + .map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site())) + .unwrap_or_else(|| { + let id = payload_id.to_string(); + let suffixes = ["Node", "Obj"]; + if let Some(suf) = suffixes + .iter() + .find(|&suf| id.len() > suf.len() && id.ends_with(suf)) + { + Ident::new(&id[..id.len() - suf.len()], payload_id.span()) + } else { + panic!( + "Either 'ref_name' must be given, or the struct name must end one of {:?}", + suffixes + ) + } + }); + + let base_tokens = match &derive_input.data { + syn::Data::Struct(s) => s.fields.iter().next().and_then(|f| { + let (base_id, base_ty) = (f.ident.clone()?, f.ty.clone()); + if base_id == "base" { + // The transitive case of subtyping + Some(quote! { + impl AsRef for #payload_id + where #base_ty: AsRef + { + fn as_ref(&self) -> &O { + self.#base_id.as_ref() + } + } + }) + } else { + None + } + }), + _ => panic!("derive only works for structs"), }; - let ref_id = Ident::new(&ref_name.value(), Span::call_site()); - let base = base.expect("should be present"); - - let expanded = quote! { + let mut expanded = quote! { unsafe impl #tvm_rt_crate::object::IsObject for #payload_id { const TYPE_KEY: &'static str = #type_key; + } - fn as_object<'s>(&'s self) -> &'s Object { - &self.#base.as_object() + // a silly AsRef impl is necessary for subtyping to work + impl AsRef<#payload_id> for #payload_id { + fn as_ref(&self) -> &Self { + self } } @@ -82,11 +93,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { impl #tvm_rt_crate::object::IsObjectRef for #ref_id { type Object = #payload_id; - fn as_object_ptr(&self) -> Option<&#tvm_rt_crate::object::ObjectPtr> { + fn as_ptr(&self) -> Option<&#tvm_rt_crate::object::ObjectPtr> { self.0.as_ref() } - fn from_object_ptr(object_ptr: Option<#tvm_rt_crate::object::ObjectPtr>) -> Self { + fn into_ptr(self) -> Option<#tvm_rt_crate::object::ObjectPtr> { + self.0 + } + + fn from_ptr(object_ptr: Option<#tvm_rt_crate::object::ObjectPtr>) -> Self { #ref_id(object_ptr) } } @@ -99,15 +114,26 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } } + impl std::convert::From<#payload_id> for #ref_id { + fn from(payload: #payload_id) -> Self { + let ptr = #tvm_rt_crate::object::ObjectPtr::new(payload); + #tvm_rt_crate::object::IsObjectRef::from_ptr(Some(ptr)) + } + } + + impl std::convert::From<#tvm_rt_crate::object::ObjectPtr<#payload_id>> for #ref_id { + fn from(ptr: #tvm_rt_crate::object::ObjectPtr<#payload_id>) -> Self { + #tvm_rt_crate::object::IsObjectRef::from_ptr(Some(ptr)) + } + } + impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id { type Error = #error; fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> { use std::convert::TryInto; - let oref: #tvm_rt_crate::ObjectRef = ret_val.try_into()?; - let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?; - let ptr = ptr.downcast::<#payload_id>()?; - Ok(#ref_id(Some(ptr))) + let ptr: #tvm_rt_crate::object::ObjectPtr<#payload_id> = ret_val.try_into()?; + Ok(ptr.into()) } } @@ -155,8 +181,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } } } - }; + expanded.extend(base_tokens); + TokenStream::from(expanded) } diff --git a/rust/tvm-macros/src/util.rs b/rust/tvm-macros/src/util.rs index 1e720f0..2a342bc 100644 --- a/rust/tvm-macros/src/util.rs +++ b/rust/tvm-macros/src/util.rs @@ -28,3 +28,21 @@ pub fn get_tvm_rt_crate() -> TokenStream { quote!(tvm_rt) } } + +pub(crate) fn get_attr<'a>( + derive_input: &'a syn::DeriveInput, + name: &str, +) -> Option<&'a syn::Attribute> { + derive_input.attrs.iter().find(|a| a.path.is_ident(name)) +} + +pub(crate) fn attr_to_str(attr: &syn::Attribute) -> syn::LitStr { + match attr.parse_meta() { + Ok(syn::Meta::NameValue(syn::MetaNameValue { + lit: syn::Lit::Str(s), + .. + })) => s, + Ok(m) => panic!("Expected a string literal, got {:?}", m), + Err(e) => panic!(e), + } +} diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 6e0efc9..d2c82fc 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -46,10 +46,7 @@ external! { impl Array { pub fn from_vec(data: Vec) -> Result> { - let iter = data - .iter() - .map(|element| element.to_object_ref().into()) - .collect(); + let iter = data.into_iter().map(T::into_arg_value).collect(); let func = Function::get("node.Array").expect( "node.Array function is not registered, this is most likely a build or linking error", @@ -66,7 +63,7 @@ impl Array { ); Ok(Array { - object: ObjectRef(Some(array_data)), + object: array_data.into(), _data: PhantomData, }) } diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs index e28dd7a..721fb1e 100644 --- a/rust/tvm-rt/src/map.rs +++ b/rust/tvm-rt/src/map.rs @@ -70,8 +70,8 @@ where let (lower_bound, upper_bound) = iter.size_hint(); let mut buffer: Vec = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2); for (k, v) in iter { - buffer.push(k.to_object_ref().into()); - buffer.push(v.to_object_ref().into()) + buffer.push(k.into()); + buffer.push(v.into()) } Self::from_data(buffer).expect("failed to convert from data") } @@ -96,7 +96,7 @@ where ); Ok(Map { - object: ObjectRef(Some(map_data)), + object: map_data.into(), _data: PhantomData, }) } @@ -105,7 +105,8 @@ where where V: TryFrom, { - let oref: ObjectRef = map_get_item(self.object.clone(), key.to_object_ref())?; + let key = key.clone(); + let oref: ObjectRef = map_get_item(self.object.clone(), key.upcast())?; oref.downcast() } } diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 3858db7..46e0342 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -18,7 +18,6 @@ */ use std::convert::TryFrom; -use std::convert::TryInto; use std::ffi::CString; use crate::errors::Error; @@ -28,87 +27,62 @@ use tvm_sys::{ArgValue, RetValue}; mod object_ptr; -pub use object_ptr::{IsObject, Object, ObjectPtr}; - -#[derive(Clone)] -pub struct ObjectRef(pub Option>); - -impl ObjectRef { - pub fn null() -> ObjectRef { - ObjectRef(None) - } -} - -pub trait IsObjectRef: Sized { +pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef}; + +// TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we +// can't because of coherence rules. Instead, we generate them in the macro, and +// add what we can (including Into instead of From) as subtraits. +// We also add named conversions for clarity +pub trait IsObjectRef: + Sized + + Clone + + Into + + TryFrom + + for<'a> Into> + + for<'a> TryFrom, Error = Error> +{ type Object: IsObject; - fn as_object_ptr(&self) -> Option<&ObjectPtr>; - fn from_object_ptr(object_ptr: Option>) -> Self; + fn as_ptr(&self) -> Option<&ObjectPtr>; + fn into_ptr(self) -> Option>; + fn from_ptr(object_ptr: Option>) -> Self; - fn to_object_ref(&self) -> ObjectRef { - let object_ptr = self.as_object_ptr().cloned(); - ObjectRef(object_ptr.map(|ptr| ptr.upcast())) + fn null() -> Self { + Self::from_ptr(None) } - fn downcast(&self) -> Result { - let ptr = self - .as_object_ptr() - .cloned() - .map(|ptr| ptr.downcast::()); - let ptr = ptr.transpose()?; - Ok(U::from_object_ptr(ptr)) + fn into_arg_value<'a>(self) -> ArgValue<'a> { + self.into() } -} -impl IsObjectRef for ObjectRef { - type Object = Object; - - fn as_object_ptr(&self) -> Option<&ObjectPtr> { - self.0.as_ref() + fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result { + Self::try_from(arg_value) } - fn from_object_ptr(object_ptr: Option>) -> Self { - ObjectRef(object_ptr) + fn into_ret_value<'a>(self) -> RetValue { + self.into() } -} -impl TryFrom for ObjectRef { - type Error = Error; - - fn try_from(ret_val: RetValue) -> Result { - let optr = ret_val.try_into()?; - Ok(ObjectRef(Some(optr))) - } -} - -impl From for RetValue { - fn from(object_ref: ObjectRef) -> RetValue { - use std::ffi::c_void; - let object_ptr = object_ref.0; - match object_ptr { - None => RetValue::ObjectHandle(std::ptr::null::() as *mut c_void), - Some(value) => value.clone().into(), - } + fn from_ret_value<'a>(ret_value: RetValue) -> Result { + Self::try_from(ret_value) } -} - -impl<'a> std::convert::TryFrom> for ObjectRef { - type Error = Error; - fn try_from(arg_value: ArgValue<'a>) -> Result { - let optr: ObjectPtr = arg_value.try_into()?; - debug_assert!(optr.count() >= 1); - Ok(ObjectRef(Some(optr))) + fn upcast(self) -> U + where + U: IsObjectRef, + Self::Object: AsRef, + { + let ptr = self.into_ptr().map(ObjectPtr::upcast); + U::from_ptr(ptr) } -} -impl<'a> From for ArgValue<'a> { - fn from(object_ref: ObjectRef) -> ArgValue<'a> { - use std::ffi::c_void; - let object_ptr = object_ref.0; - match object_ptr { - None => ArgValue::ObjectHandle(std::ptr::null::() as *mut c_void), - Some(value) => value.into(), - } + fn downcast(self) -> Result + where + U: IsObjectRef, + U::Object: AsRef, + { + let ptr = self.into_ptr().map(ObjectPtr::downcast); + let ptr = ptr.transpose()?; + Ok(U::from_ptr(ptr)) } } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 1923854..1388d3c 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -22,6 +22,7 @@ use std::ffi::CString; use std::ptr::NonNull; use std::sync::atomic::AtomicI32; +use tvm_macros::Object; use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index}; use tvm_sys::{ArgValue, RetValue}; @@ -35,7 +36,9 @@ type Deleter = unsafe extern "C" fn(object: *mut Object) -> (); /// table, an atomic reference count, and a customized deleter which /// will be invoked when the reference count is zero. /// -#[derive(Debug)] +#[derive(Debug, Object)] +#[ref_name = "ObjectRef"] +#[type_key = "runtime.Object"] #[repr(C)] pub struct Object { /// The index into TVM's runtime type information table. @@ -151,25 +154,15 @@ impl Object { /// index, a method for accessing the base object given the /// subtype, and a typed delete method which is specialized /// to the subtype. -pub unsafe trait IsObject { +pub unsafe trait IsObject: AsRef { const TYPE_KEY: &'static str; - fn as_object<'s>(&'s self) -> &'s Object; - unsafe extern "C" fn typed_delete(object: *mut Self) { let object = Box::from_raw(object); drop(object) } } -unsafe impl IsObject for Object { - const TYPE_KEY: &'static str = "runtime.Object"; - - fn as_object<'s>(&'s self) -> &'s Object { - self - } -} - /// A smart pointer for types which implement IsObject. /// This type directly corresponds to TVM's C++ type ObjectPtr. /// @@ -179,14 +172,6 @@ pub struct ObjectPtr { pub ptr: NonNull, } -fn inc_ref(ptr: NonNull) { - unsafe { ptr.as_ref().as_object().inc_ref() } -} - -fn dec_ref(ptr: NonNull) { - unsafe { ptr.as_ref().as_object().dec_ref() } -} - impl ObjectPtr { pub fn from_raw(object_ptr: *mut Object) -> Option> { let non_null = NonNull::new(object_ptr); @@ -199,14 +184,14 @@ impl ObjectPtr { impl Clone for ObjectPtr { fn clone(&self) -> Self { - inc_ref(self.ptr); + unsafe { self.ptr.as_ref().as_ref().inc_ref() } ObjectPtr { ptr: self.ptr } } } impl Drop for ObjectPtr { fn drop(&mut self) { - dec_ref(self.ptr); + unsafe { self.ptr.as_ref().as_ref().dec_ref() } } } @@ -219,34 +204,42 @@ impl ObjectPtr { } pub fn new(object: T) -> ObjectPtr { + object.as_ref().inc_ref(); let object_ptr = Box::new(object); let object_ptr = Box::leak(object_ptr); let ptr = NonNull::from(object_ptr); - inc_ref(ptr); ObjectPtr { ptr } } pub fn count(&self) -> i32 { // need to do atomic read in C++ // ABI compatible atomics is funky/hard. - self.as_object() + self.as_ref() .ref_count .load(std::sync::atomic::Ordering::Relaxed) } - fn as_object<'s>(&'s self) -> &'s Object { - unsafe { self.ptr.as_ref().as_object() } + /// This method avoid running the destructor on self once it's dropped, so we don't accidentally release the memory + unsafe fn cast(self) -> ObjectPtr { + let ptr = self.ptr.cast(); + std::mem::forget(self); + ObjectPtr { ptr } } - pub fn upcast(self) -> ObjectPtr { - ObjectPtr { - ptr: self.ptr.cast(), - } + pub fn upcast(self) -> ObjectPtr + where + U: IsObject, + T: AsRef, + { + unsafe { self.cast() } } - pub fn downcast(self) -> Result, Error> { + pub fn downcast(self) -> Result, Error> + where + U: IsObject + AsRef, + { let child_index = Object::get_type_index::(); - let object_index = self.as_object().type_index; + let object_index = self.as_ref().type_index; let is_derived = if child_index == object_index { true @@ -256,10 +249,7 @@ impl ObjectPtr { }; if is_derived { - // NB: self gets dropped here causng a dec ref which we need to migtigate with an inc ref before it is dropped. - inc_ref(self.ptr); - let ptr = self.ptr.cast(); - Ok(ObjectPtr { ptr }) + Ok(unsafe { self.cast() }) } else { Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) } diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index 73d0439..a5ee1f1 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -20,7 +20,7 @@ use std::cmp::{Ordering, PartialEq}; use std::hash::{Hash, Hasher}; -use super::{Object, ObjectPtr}; +use super::Object; use tvm_macros::Object; @@ -37,27 +37,18 @@ pub struct StringObj { impl From for String { fn from(s: std::string::String) -> Self { let size = s.len() as u64; - let obj = StringObj { - base: Object::base_object::(), - data: s.as_bytes().as_ptr(), - size, - }; - std::mem::forget(s); - let obj_ptr = ObjectPtr::new(obj); - String(Some(obj_ptr)) + let data = Box::into_raw(s.into_boxed_str()).cast(); + let base = Object::base_object::(); + StringObj { base, data, size }.into() } } impl From<&'static str> for String { fn from(s: &'static str) -> Self { let size = s.len() as u64; - let obj = StringObj { - base: Object::base_object::(), - data: s.as_bytes().as_ptr(), - size, - }; - let obj_ptr = ObjectPtr::new(obj); - String(Some(obj_ptr)) + let data = s.as_bytes().as_ptr(); + let base = Object::base_object::(); + StringObj { base, data, size }.into() } } @@ -139,7 +130,7 @@ mod tests { #[test] fn test_string_debug() -> Result<()> { let s = String::from("foo"); - let object_ref = s.to_object_ref(); + let object_ref = s.upcast(); println!("about to call"); let string = debug_print(object_ref)?; println!("after call"); diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index 9222de5..b615c1e 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -34,7 +34,7 @@ external! { pub fn as_text(object: T) -> String { let no_func = unsafe { runtime::Function::null() }; - _as_text(object.to_object_ref(), 0, no_func) + _as_text(object.upcast(), 0, no_func) .unwrap() .as_str() .unwrap() diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 771882e..4f4497e 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -18,7 +18,7 @@ */ use crate::runtime::array::Array; -use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString}; +use crate::runtime::{object::*, String as TString}; use crate::DataType; use tvm_macros::Object; diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs index bef63b1..ee30c51 100644 --- a/rust/tvm/src/ir/tir.rs +++ b/rust/tvm/src/ir/tir.rs @@ -17,7 +17,7 @@ * under the License. */ -use crate::runtime::{Object, ObjectPtr, String as TVMString}; +use crate::runtime::String as TVMString; use crate::DataType; use super::*; @@ -37,17 +37,7 @@ macro_rules! define_node { pub fn new(datatype: DataType, $($id : $t,)*) -> $name { let base = PrimExprNode::base::<$node>(datatype); let node = $node { base, $($id),* }; - $name(Some(ObjectPtr::new(node))) - } - } - - impl From<$name> for PrimExpr { - // TODO(@jroesch): Remove we with subtyping traits. - fn from(x: $name) -> PrimExpr { - x.downcast().expect(concat!( - "Failed to downcast `", - stringify!($name), - "` to PrimExpr")) + node.into() } } } -- 2.7.4