From c899b3c9493cdf35dcaaed318a9486a24131b4ee Mon Sep 17 00:00:00 2001 From: Max Willsey Date: Fri, 28 Aug 2020 12:13:47 -0700 Subject: [PATCH] Improve Rust bindings: Map, Array, String, various IR nodes (#6339) * Fix datatype * Add initialize macro * Add some TIR nodes * Better downcasting * Improve Array and add Map * Convert to new string API * Clean up some warnings * Add ConstIntBound type * Run cargo fmt * Remove debug prints * Add some more ops * Fix some string code Co-authored-by: Jared Roesch --- rust/tvm-macros/src/object.rs | 6 +- rust/tvm-rt/src/array.rs | 45 +++++- rust/tvm-rt/src/function.rs | 2 +- rust/tvm-rt/src/lib.rs | 1 + rust/tvm-rt/src/map.rs | 264 +++++++++++++++++++++++++++++++++++ rust/tvm-rt/src/object/mod.rs | 4 + rust/tvm-rt/src/object/object_ptr.rs | 4 +- rust/tvm-rt/src/string.rs | 118 ++++++++++++---- rust/tvm-sys/src/datatype.rs | 3 +- rust/tvm-sys/src/packed_func.rs | 31 ++++ rust/tvm/src/ir/arith.rs | 46 ++++++ rust/tvm/src/ir/mod.rs | 40 +++++- rust/tvm/src/ir/relay/mod.rs | 10 +- rust/tvm/src/ir/tir.rs | 93 ++++++++++++ rust/tvm/src/transform.rs | 32 ++++- 15 files changed, 644 insertions(+), 55 deletions(-) create mode 100644 rust/tvm-rt/src/map.rs create mode 100644 rust/tvm/src/ir/arith.rs create mode 100644 rust/tvm/src/ir/tir.rs diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index 0170e1d..342be6b 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -82,11 +82,11 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { impl #tvm_rt_crate::object::IsObjectRef for #ref_id { type Object = #payload_id; - fn as_object_ptr(&self) -> Option<&ObjectPtr> { + fn as_object_ptr(&self) -> Option<&#tvm_rt_crate::object::ObjectPtr> { self.0.as_ref() } - fn from_object_ptr(object_ptr: Option>) -> Self { + fn from_object_ptr(object_ptr: Option<#tvm_rt_crate::object::ObjectPtr>) -> Self { #ref_id(object_ptr) } } @@ -104,7 +104,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> { use std::convert::TryInto; - let oref: ObjectRef = ret_val.try_into()?; + let oref: #tvm_rt_crate::ObjectRef = ret_val.try_into()?; let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?; let ptr = ptr.downcast::<#payload_id>()?; Ok(#ref_id(Some(ptr))) diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 128bb87..6e0efc9 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -25,7 +25,7 @@ use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef}; use crate::{ external, function::{Function, Result}, - RetValue, + ArgValue, RetValue, }; #[repr(C)] @@ -40,6 +40,8 @@ pub struct Array { external! { #[name("node.ArrayGetItem")] fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef; + #[name("node.ArraySize")] + fn array_size(array: ObjectRef) -> i64; } impl Array { @@ -76,4 +78,45 @@ impl Array { let oref: ObjectRef = array_get_item(self.object.clone(), index)?; oref.downcast() } + + pub fn len(&self) -> i64 { + array_size(self.object.clone()).expect("size should never fail") + } +} + +impl From> for ArgValue<'static> { + fn from(array: Array) -> ArgValue<'static> { + array.object.into() + } +} + +impl From> for RetValue { + fn from(array: Array) -> RetValue { + array.object.into() + } +} + +impl<'a, T: IsObjectRef> TryFrom> for Array { + type Error = Error; + + fn try_from(array: ArgValue<'a>) -> Result> { + let object_ref: ObjectRef = array.try_into()?; + // TODO: type check + Ok(Array { + object: object_ref, + _data: PhantomData, + }) + } +} + +impl<'a, T: IsObjectRef> TryFrom for Array { + type Error = Error; + + fn try_from(array: RetValue) -> Result> { + let object_ref = array.try_into()?; + Ok(Array { + object: object_ref, + _data: PhantomData, + }) + } } diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 94a20ac..bae06e9 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -133,7 +133,7 @@ impl Function { match rv { RetValue::ObjectHandle(object) => { let optr = crate::object::ObjectPtr::from_raw(object as _).unwrap(); - println!("after wrapped call: {}", optr.count()); + // println!("after wrapped call: {}", optr.count()); crate::object::ObjectPtr::leak(optr); } _ => {} diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index ad4c1ca..84951f4 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -95,6 +95,7 @@ pub mod array; pub mod context; pub mod errors; pub mod function; +pub mod map; pub mod module; pub mod ndarray; mod to_function; diff --git a/rust/tvm-rt/src/map.rs b/rust/tvm-rt/src/map.rs new file mode 100644 index 0000000..e28dd7a --- /dev/null +++ b/rust/tvm-rt/src/map.rs @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::collections::HashMap; +use std::convert::{TryFrom, TryInto}; +use std::iter::FromIterator; +use std::marker::PhantomData; + +use crate::object::debug_print; + +use crate::array::Array; +use crate::errors::Error; +use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef}; +use crate::ArgValue; +use crate::{ + external, + function::{Function, Result}, + RetValue, +}; + +#[repr(C)] +#[derive(Clone)] +pub struct Map +where + K: IsObjectRef, + V: IsObjectRef, +{ + object: ObjectRef, + _data: PhantomData<(K, V)>, +} + +// TODO(@jroesch): convert to use generics instead of casting inside +// the implementation. +external! { + #[name("node.ArrayGetItem")] + fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef; + #[name("node.MapSize")] + fn map_size(map: ObjectRef) -> i64; + #[name("node.MapGetItem")] + fn map_get_item(map_object: ObjectRef, key: ObjectRef) -> ObjectRef; + #[name("node.MapCount")] + fn map_count(map: ObjectRef, key: ObjectRef) -> ObjectRef; + #[name("node.MapItems")] + fn map_items(map: ObjectRef) -> Array; +} + +impl FromIterator<(K, V)> for Map +where + K: IsObjectRef, + V: IsObjectRef, +{ + fn from_iter>(iter: T) -> Self { + let iter = iter.into_iter(); + let (lower_bound, upper_bound) = iter.size_hint(); + let mut buffer: Vec = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2); + for (k, v) in iter { + buffer.push(k.to_object_ref().into()); + buffer.push(v.to_object_ref().into()) + } + Self::from_data(buffer).expect("failed to convert from data") + } +} + +impl Map +where + K: IsObjectRef, + V: IsObjectRef, +{ + pub fn from_data(data: Vec) -> Result> { + let func = Function::get("node.Map").expect( + "node.Map function is not registered, this is most likely a build or linking error", + ); + + let map_data: ObjectPtr = func.invoke(data)?.try_into()?; + + debug_assert!( + map_data.count() >= 1, + "map_data count is {}", + map_data.count() + ); + + Ok(Map { + object: ObjectRef(Some(map_data)), + _data: PhantomData, + }) + } + + pub fn get(&self, key: &K) -> Result + where + V: TryFrom, + { + let oref: ObjectRef = map_get_item(self.object.clone(), key.to_object_ref())?; + oref.downcast() + } +} + +pub struct IntoIter { + // NB: due to FFI this isn't as lazy as one might like + key_and_values: Array, + next_key: i64, + _data: PhantomData<(K, V)>, +} + +impl Iterator for IntoIter +where + K: IsObjectRef, + V: IsObjectRef, +{ + type Item = (K, V); + + #[inline] + fn next(&mut self) -> Option<(K, V)> { + if self.next_key < self.key_and_values.len() { + let key = self + .key_and_values + .get(self.next_key as isize) + .expect("this should always succeed"); + let value = self + .key_and_values + .get((self.next_key as isize) + 1) + .expect("this should always succeed"); + self.next_key += 2; + Some((key.downcast().unwrap(), value.downcast().unwrap())) + } else { + None + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + ((self.key_and_values.len() / 2) as usize, None) + } +} + +impl IntoIterator for Map +where + K: IsObjectRef, + V: IsObjectRef, +{ + type Item = (K, V); + type IntoIter = IntoIter; + + fn into_iter(self) -> IntoIter { + let items = map_items(self.object).expect("unable to get map items"); + IntoIter { + key_and_values: items, + next_key: 0, + _data: PhantomData, + } + } +} + +use std::fmt; + +impl fmt::Debug for Map +where + K: IsObjectRef, + V: IsObjectRef, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let ctr = debug_print(self.object.clone()).unwrap(); + fmt.write_fmt(format_args!("{:?}", ctr)) + } +} + +impl From> for HashMap +where + K: Eq + std::hash::Hash, + K: IsObjectRef, + V: IsObjectRef, + S: std::hash::BuildHasher + std::default::Default, +{ + fn from(map: Map) -> HashMap { + HashMap::from_iter(map.into_iter()) + } +} + +impl<'a, K, V> From> for ArgValue<'a> +where + K: IsObjectRef, + V: IsObjectRef, +{ + fn from(map: Map) -> ArgValue<'a> { + map.object.into() + } +} + +impl From> for RetValue +where + K: IsObjectRef, + V: IsObjectRef, +{ + fn from(map: Map) -> RetValue { + map.object.into() + } +} + +impl<'a, K, V> TryFrom> for Map +where + K: IsObjectRef, + V: IsObjectRef, +{ + type Error = Error; + + fn try_from(array: ArgValue<'a>) -> Result> { + let object_ref: ObjectRef = array.try_into()?; + // TODO: type check + Ok(Map { + object: object_ref, + _data: PhantomData, + }) + } +} + +impl TryFrom for Map +where + K: IsObjectRef, + V: IsObjectRef, +{ + type Error = Error; + + fn try_from(array: RetValue) -> Result> { + let object_ref = array.try_into()?; + // TODO: type check + Ok(Map { + object: object_ref, + _data: PhantomData, + }) + } +} + +#[cfg(test)] +mod test { + use std::collections::HashMap; + + use super::*; + use crate::string::String as TString; + + #[test] + fn test_from_into_hash_map() { + let mut std_map: HashMap = HashMap::new(); + std_map.insert("key1".into(), "value1".into()); + std_map.insert("key2".into(), "value2".into()); + let tvm_map = Map::from_iter(std_map.clone().into_iter()); + let back_map = tvm_map.into(); + assert_eq!(std_map, back_map); + } +} diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 73b6c99..3858db7 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -115,6 +115,10 @@ impl<'a> From for ArgValue<'a> { external! { #[name("ir.DebugPrint")] fn debug_print(object: ObjectRef) -> CString; + #[name("node.StructuralHash")] + fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef; + #[name("node.StructuralEqual")] + fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> ObjectRef; } // external! { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 6880824..1923854 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -290,7 +290,7 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { RetValue::ObjectHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); - println!("back to type {}", optr.count()); + // println!("back to type {}", optr.count()); optr.downcast() } _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")), @@ -315,7 +315,7 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { ArgValue::ObjectHandle(handle) => { let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; debug_assert!(optr.count() >= 1); - println!("count: {}", optr.count()); + // println!("count: {}", optr.count()); optr.downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index 7727e4b..73d0439 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -17,11 +17,10 @@ * under the License. */ -use std::ffi::{CString, NulError}; -use std::os::raw::c_char; +use std::cmp::{Ordering, PartialEq}; +use std::hash::{Hash, Hasher}; -use super::errors::Error; -use super::{Object, ObjectPtr, ObjectRef}; +use super::{Object, ObjectPtr}; use tvm_macros::Object; @@ -31,41 +30,102 @@ use tvm_macros::Object; #[type_key = "runtime.String"] pub struct StringObj { base: Object, - data: *const c_char, + data: *const u8, size: u64, } -impl String { - pub fn new(string: std::string::String) -> Result { - let cstring = CString::new(string)?; - - // The string is being corrupted. - // why is this wrong - let length = cstring.as_bytes().len(); +impl From for String { + fn from(s: std::string::String) -> Self { + let size = s.len() as u64; + let obj = StringObj { + base: Object::base_object::(), + data: s.as_bytes().as_ptr(), + size, + }; + std::mem::forget(s); + let obj_ptr = ObjectPtr::new(obj); + String(Some(obj_ptr)) + } +} - let string_obj = StringObj { +impl From<&'static str> for String { + fn from(s: &'static str) -> Self { + let size = s.len() as u64; + let obj = StringObj { base: Object::base_object::(), - data: cstring.into_raw(), - size: length as u64, + data: s.as_bytes().as_ptr(), + size, }; + let obj_ptr = ObjectPtr::new(obj); + String(Some(obj_ptr)) + } +} + +impl AsRef<[u8]> for String { + fn as_ref(&self) -> &[u8] { + self.as_bytes() + } +} - let object_ptr = ObjectPtr::new(string_obj); - Ok(String(Some(object_ptr))) +impl std::fmt::Display for String { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.to_string_lossy().fmt(f) } +} - pub fn to_cstring(&self) -> Result { - use std::slice; - let ptr = self.0.as_ref().unwrap().data; - let size = self.0.as_ref().unwrap().size; - unsafe { - let slice: &[u8] = slice::from_raw_parts(ptr as *const u8, size as usize); - CString::new(slice) - } +impl String { + pub fn is_empty(&self) -> bool { + self.len() == 0 } - pub fn to_string(&self) -> Result { - let string = self.to_cstring()?.into_string()?; - Ok(string) + pub fn len(&self) -> usize { + self.size as usize + } + + pub fn as_bytes(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.data, self.len()) } + } + + pub fn as_str(&self) -> Result<&str, std::str::Utf8Error> { + std::str::from_utf8(self.as_bytes()) + } + + pub fn to_string_lossy(&self) -> std::borrow::Cow { + std::string::String::from_utf8_lossy(self.as_bytes()) + } +} + +impl> PartialEq for String { + fn eq(&self, other: &T) -> bool { + self.as_bytes() == other.as_ref() + } +} + +impl> PartialOrd for String { + fn partial_cmp(&self, other: &T) -> Option { + self.as_bytes().partial_cmp(other.as_ref()) + } +} + +impl Eq for String {} + +impl Ord for String { + fn cmp(&self, other: &Self) -> Ordering { + self.as_bytes().cmp(other.as_bytes()) + } +} + +impl Hash for String { + fn hash(&self, state: &mut H) { + self.as_bytes().hash(state); + } +} + +impl std::fmt::Debug for String { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // TODO(@mwillsey): remove this clone? + let string: String = self.clone().into(); + formatter.write_fmt(format_args!("{:?}", string)) } } @@ -78,7 +138,7 @@ mod tests { #[test] fn test_string_debug() -> Result<()> { - let s = String::new("foo".to_string()).unwrap(); + let s = String::from("foo"); let object_ref = s.to_object_ref(); println!("about to call"); let string = debug_print(object_ref)?; diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs index c98d374..8050d93 100644 --- a/rust/tvm-sys/src/datatype.rs +++ b/rust/tvm-sys/src/datatype.rs @@ -31,7 +31,8 @@ const DL_UINT_CODE: u8 = 1; const DL_FLOAT_CODE: u8 = 2; const DL_HANDLE: u8 = 3; -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(C)] pub struct DataType { code: u8, bits: u8, diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index 3121deb..3588539 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -378,3 +378,34 @@ impl TryFrom for std::ffi::CString { |RetValue::Str(val)| { val.into() }) } } + +// Implementations for bool. + +impl<'a> From for ArgValue<'a> { + fn from(s: bool) -> Self { + (s as i64).into() + } +} + +impl From for RetValue { + fn from(s: bool) -> Self { + (s as i64).into() + } +} + +impl TryFrom for bool { + type Error = ValueDowncastError; + + fn try_from(val: RetValue) -> Result { + try_downcast!(val -> bool, + |RetValue::Int(val)| { !(val == 0) }) + } +} + +impl<'a> TryFrom> for bool { + type Error = ValueDowncastError; + + fn try_from(val: ArgValue<'a>) -> Result { + try_downcast!(val -> bool, |ArgValue::Int(val)| { !(val == 0) }) + } +} diff --git a/rust/tvm/src/ir/arith.rs b/rust/tvm/src/ir/arith.rs new file mode 100644 index 0000000..c2de24a --- /dev/null +++ b/rust/tvm/src/ir/arith.rs @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::runtime::{Object, ObjectPtr}; + +use super::*; + +macro_rules! define_node { + ($name:ident, $ref:expr, $typekey:expr; $node:ident { $($id:ident : $t:ty),*}) => { + #[repr(C)] + #[derive(Object)] + #[ref_name = $ref] + #[type_key = $typekey] + pub struct $node { + base: Object, + $(pub $id : $t),* + } + + impl $name { + pub fn new($($id : $t,)*) -> $name { + let base = Object::base_object::<$node>(); + let node = $node { base, $($id),* }; + $name(Some(ObjectPtr::new(node))) + } + } + } +} + +define_node!(ConstIntBound, "ConstIntBound", "arith.ConstIntBound"; + ConstIntBoundNode { min_value: i64, max_value: i64 }); diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index 4fe13a3..9222de5 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -18,10 +18,13 @@ */ use crate::runtime::String as TString; -use crate::runtime::{self, external, IsObjectRef, Object, ObjectRef}; +use crate::runtime::{self, external, IsObject, IsObjectRef, Object, ObjectRef}; use crate::DataType; +use tvm_macros::Object; +pub mod arith; pub mod relay; +pub mod tir; // TODO: figure out how to type the last argument runtime::TypedPackedFunc annotate) external! { @@ -33,18 +36,41 @@ pub fn as_text(object: T) -> String { let no_func = unsafe { runtime::Function::null() }; _as_text(object.to_object_ref(), 0, no_func) .unwrap() - .to_string() + .as_str() .unwrap() + .into() } #[repr(C)] -pub struct PrimExprNode { +#[derive(Object)] +#[ref_name = "BaseExpr"] +#[type_key = "Expr"] +pub struct BaseExprNode { pub base: Object, - pub dtype: DataType, +} + +impl BaseExprNode { + fn base() -> BaseExprNode { + BaseExprNode { + base: Object::base_object::(), + } + } } #[repr(C)] -pub struct IntImmNode { - pub base: PrimExprNode, - pub value: i64, +#[derive(Object)] +#[ref_name = "PrimExpr"] +#[type_key = "PrimExpr"] +pub struct PrimExprNode { + pub base: BaseExprNode, + pub datatype: DataType, +} + +impl PrimExprNode { + pub fn base(datatype: DataType) -> PrimExprNode { + PrimExprNode { + base: BaseExprNode::base::(), + datatype, + } + } } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index cad41ac..771882e 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -96,7 +96,7 @@ impl GlobalVar { pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar { let node = GlobalVarNode { base: RelayExpr::base::(), - name_hint: TString::new(name_hint).unwrap(), + name_hint: name_hint.into(), }; GlobalVar(Some(ObjectPtr::new(node))) } @@ -135,7 +135,7 @@ impl Var { pub fn new(name_hint: String, _span: ObjectRef) -> Var { let node = VarNode { base: RelayExpr::base::(), - vid: Id::new(TString::new(name_hint.to_string()).unwrap()), + vid: Id::new(name_hint.into()), type_annotation: ObjectRef::null(), }; Var(Some(ObjectPtr::new(node))) @@ -241,7 +241,7 @@ mod tests { #[test] fn test_id() -> Result<()> { - let string = TString::new("foo".to_string()).expect("bar"); + let string = TString::from("foo"); let id = Id::new(string); let text = as_text(id.clone()); assert!(text.contains("relay.Id")); @@ -275,8 +275,8 @@ mod tests { Var::new("bar".into(), ObjectRef::null()), ]; let array = Array::from_vec(vec)?; - assert_eq!(array.get(0)?.name_hint().to_string()?, "foo"); - assert_eq!(array.get(1)?.name_hint().to_string()?, "bar"); + assert_eq!(array.get(0)?.name_hint().to_string(), "foo"); + assert_eq!(array.get(1)?.name_hint().to_string(), "bar"); Ok(()) } } diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs new file mode 100644 index 0000000..bef63b1 --- /dev/null +++ b/rust/tvm/src/ir/tir.rs @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::runtime::{Object, ObjectPtr, String as TVMString}; +use crate::DataType; + +use super::*; + +macro_rules! define_node { + ($name:ident, $ref:expr, $typekey:expr; $node:ident { $($id:ident : $t:ty),*}) => { + #[repr(C)] + #[derive(Object)] + #[ref_name = $ref] + #[type_key = $typekey] + pub struct $node { + base: PrimExprNode, + $(pub $id : $t),* + } + + impl $name { + pub fn new(datatype: DataType, $($id : $t,)*) -> $name { + let base = PrimExprNode::base::<$node>(datatype); + let node = $node { base, $($id),* }; + $name(Some(ObjectPtr::new(node))) + } + } + + impl From<$name> for PrimExpr { + // TODO(@jroesch): Remove we with subtyping traits. + fn from(x: $name) -> PrimExpr { + x.downcast().expect(concat!( + "Failed to downcast `", + stringify!($name), + "` to PrimExpr")) + } + } + } +} + +define_node!(IntImm, "IntImm", "IntImm"; + IntImmNode { value: i64 }); +define_node!(Var, "Var", "tir.Var"; + VarNode { name_hint: TVMString }); + +define_node!(Add, "Add", "tir.Add"; AddNode { a: PrimExpr, b: PrimExpr }); +define_node!(Sub, "Sub", "tir.Sub"; SubNode { a: PrimExpr, b: PrimExpr }); +define_node!(Mul, "Mul", "tir.Mul"; MulNode { a: PrimExpr, b: PrimExpr }); + +define_node!(Div, "Div", "tir.Div"; DivNode { a: PrimExpr, b: PrimExpr }); +define_node!(Mod, "Mod", "tir.Mod"; ModNode { a: PrimExpr, b: PrimExpr }); +define_node!(FloorDiv, "FloorDiv", "tir.FloorDiv"; FloorDivNode { a: PrimExpr, b: PrimExpr }); +define_node!(FloorMod, "FloorMod", "tir.FloorMod"; FloorModNode { a: PrimExpr, b: PrimExpr }); + +define_node!(Min, "Min", "tir.Min"; MinNode { a: PrimExpr, b: PrimExpr }); +define_node!(Max, "Max", "tir.Max"; MaxNode { a: PrimExpr, b: PrimExpr }); + +// the new datatype is in the base expr +define_node!(Cast, "Cast", "tir.Cast"; CastNode { value: PrimExpr }); + +// renamed base to start to avoid name clash +define_node!(Ramp, "Ramp", "tir.Ramp"; RampNode { start: PrimExpr, stride: PrimExpr, lanes: i32 }); + +define_node!(Select, "Select", "tir.Select"; + SelectNode { condition: PrimExpr, true_value: PrimExpr, false_value: PrimExpr }); + +define_node!(Eq, "Eq", "tir.EQ"; EqNode { a: PrimExpr, b: PrimExpr }); +define_node!(Ne, "Ne", "tir.NE"; NeNode { a: PrimExpr, b: PrimExpr }); +define_node!(Lt, "Lt", "tir.LT"; LtNode { a: PrimExpr, b: PrimExpr }); +define_node!(Le, "Le", "tir.LE"; LeNode { a: PrimExpr, b: PrimExpr }); +define_node!(Gt, "Gt", "tir.GT"; GtNode { a: PrimExpr, b: PrimExpr }); +define_node!(Ge, "Ge", "tir.GE"; GeNode { a: PrimExpr, b: PrimExpr }); + +define_node!(And, "And", "tir.And"; AndNode { a: PrimExpr, b: PrimExpr }); +define_node!(Or, "Or", "tir.Or"; OrNode { a: PrimExpr, b: PrimExpr }); +define_node!(Not, "Not", "tir.Not"; NotNode { value: PrimExpr }); + +define_node!(Let, "Let", "tir.Let"; LetNode { var: Var, value: PrimExpr, body: PrimExpr }); diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs index ab84202..59fc604 100644 --- a/rust/tvm/src/transform.rs +++ b/rust/tvm/src/transform.rs @@ -45,17 +45,14 @@ pub struct PassInfoNode { impl PassInfo { pub fn new(opt_level: i32, name: String, required: Vec) -> Result { - let required: Result<_> = required - .into_iter() - .map(|name| TString::new(name)) - .collect(); + let required = required.into_iter().map(|name| name.into()).collect(); - let required = Array::from_vec(required?)?; + let required = Array::from_vec(required)?; let node = PassInfoNode { base: Object::base_object::(), opt_level, - name: TString::new(name).unwrap(), + name: name.into(), required, }; @@ -76,6 +73,29 @@ pub fn function_pass Function + 'stati create_func_pass(func, pass_info) } +/// A macro for generating the correct TVM symbols for plugin loading. +/// +/// The expression passed to the macro will be run when TVM loads the +/// shared library. +/// +/// This is useful for calling register to register packed functions +/// to consume via TVM's packed function APIs. +#[macro_export] +macro_rules! initialize { + ($body:expr) => { + #[no_mangle] + pub unsafe extern "C" fn initialize( + args: *mut tvm_sys::ffi::TVMValue, + type_codes: *mut c_int, + num_args: c_int, + ret: tvm_sys::ffi::TVMRetValueHandle, + ) -> c_int { + $body + return 0; + } + }; +} + #[macro_export] macro_rules! export_pass { ($name:literal,$func:expr) => { -- 2.7.4