use_try_shorthand = false
use_field_init_shorthand = false
force_explicit_abi = true
+
"frontend",
"frontend/tests/basics",
"frontend/tests/callback",
- "frontend/examples/resnet"
+ "frontend/examples/resnet",
+ "tvm-sys"
]
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "tvm-sys"
+version = "0.1.0"
+authors = ["TVM Contributors"]
+license = "Apache-2.0"
+edition = "2018"
+
+[features]
+bindings = []
+
+[dependencies]
+thiserror = "^1.0"
+anyhow = "^1.0"
+ndarray = "0.12"
+enumn = "^0.1"
+
+[build-dependencies]
+bindgen = "0.51"
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+extern crate bindgen;
+
+use std::path::PathBuf;
+
+use std::env;
+
+fn main() {
+ let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({
+ let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+ .canonicalize()
+ .unwrap();
+ crate_dir
+ .parent()
+ .unwrap()
+ .parent()
+ .unwrap()
+ .to_str()
+ .unwrap()
+ .to_string()
+ });
+
+ if cfg!(feature = "bindings") {
+ println!("cargo:rerun-if-env-changed=TVM_HOME");
+ println!("cargo:rustc-link-lib=dylib=tvm");
+ println!("cargo:rustc-link-search={}/build", tvm_home);
+ }
+
+ // @see rust-bindgen#550 for `blacklist_type`
+ bindgen::Builder::default()
+ .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home))
+ .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home))
+ .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home))
+ .clang_arg(format!("-I{}/include/", tvm_home))
+ .blacklist_type("max_align_t")
+ .layout_tests(false)
+ .derive_partialeq(true)
+ .derive_eq(true)
+ .generate()
+ .expect("unable to generate bindings")
+ .write_to_file(PathBuf::from("src/c_runtime_api.rs"))
+ .expect("can not write the bindings!");
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+ mem,
+ os::raw::{c_int, c_void},
+};
+
+use crate::ffi::{
+ DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
+ DLDeviceType_kDLCPU, DLTensor,
+};
+
+/// `From` conversions to `DLTensor` for `ndarray::Array`.
+/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
+macro_rules! impl_dltensor_from_ndarray {
+ ($type:ty, $typecode:expr) => {
+ impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
+ fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
+ DLTensor {
+ data: arr.as_mut_ptr() as *mut c_void,
+ ctx: DLContext {
+ device_type: DLDeviceType_kDLCPU,
+ device_id: 0,
+ },
+ ndim: arr.ndim() as c_int,
+ dtype: DLDataType {
+ code: $typecode as u8,
+ bits: 8 * mem::size_of::<$type>() as u8,
+ lanes: 1,
+ },
+ shape: arr.shape().as_ptr() as *const i64 as *mut i64,
+ strides: arr.strides().as_ptr() as *const i64 as *mut i64,
+ byte_offset: 0,
+ }
+ }
+ }
+ };
+}
+
+impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
+impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
+impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
+impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
+impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
+impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+use std::os::raw::c_char;
+
+use crate::ffi::TVMByteArray;
+
+/// A newtype wrapping a raw TVM byte-array.
+///
+/// ## Example
+///
+/// ```
+/// let v = b"hello";
+/// let barr = tvm_sys::ByteArray::from(&v);
+/// assert_eq!(barr.len(), v.len());
+/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
+/// ```
+pub struct ByteArray {
+ /// The raw FFI ByteArray.
+ array: TVMByteArray,
+}
+
+impl ByteArray {
+ /// Gets the underlying byte-array
+ pub fn data(&self) -> &'static [u8] {
+ unsafe { std::slice::from_raw_parts(self.array.data as *const u8, self.array.size) }
+ }
+
+ /// Gets the length of the underlying byte-array
+ pub fn len(&self) -> usize {
+ self.array.size
+ }
+
+ /// Converts the underlying byte-array to `Vec<u8>`
+ pub fn to_vec(&self) -> Vec<u8> {
+ self.data().to_vec()
+ }
+
+ pub fn is_empty(&self) -> bool {
+ self.len() == 0
+ }
+}
+
+// Needs AsRef for Vec
+impl<T: AsRef<[u8]>> From<T> for ByteArray {
+ fn from(arg: T) -> Self {
+ let arg = arg.as_ref();
+ ByteArray {
+ array: TVMByteArray {
+ data: arg.as_ptr() as *const c_char,
+ size: arg.len(),
+ },
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn convert() {
+ let v = vec![1u8, 2, 3];
+ let barr = ByteArray::from(&v);
+ assert_eq!(barr.len(), v.len());
+ assert_eq!(barr.to_vec(), vec![1u8, 2, 3]);
+ let v = b"hello";
+ let barr = ByteArray::from(&v);
+ assert_eq!(barr.len(), v.len());
+ assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
+ }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! Provides [`Context`] and related device queries.
+//!
+//! Create a new context for device type and device id.
+//!
+//! # Example
+//!
+//! ```
+//! # use tvm_sys::{DeviceType, Context};
+//! let cpu = DeviceType::from("cpu");
+//! let ctx = Context::new(cpu , 0);
+//! let cpu0 = Context::cpu(0);
+//! assert_eq!(ctx, cpu0);
+//! ```
+//!
+//! Or from a supported device name.
+//!
+//! ```
+//! use tvm_sys::Context;
+//! let cpu0 = Context::from("cpu");
+//! println!("{}", cpu0);
+//! ```
+
+use std::convert::TryFrom;
+use std::fmt::{self, Display, Formatter};
+use std::str::FromStr;
+
+use crate::ffi::{self, *};
+use crate::packed_func::{ArgValue, RetValue};
+
+use anyhow::Result;
+use enumn::N;
+use thiserror::Error;
+
+/// Device type represents the set of devices supported by
+/// [TVM](https://github.com/apache/incubator-tvm).
+///
+/// ## Example
+///
+/// ```
+/// use tvm_sys::DeviceType;
+/// let cpu = DeviceType::from("cpu");
+/// println!("device is: {}", cpu);
+///```
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, N)]
+#[repr(i64)]
+pub enum DeviceType {
+ CPU = 1,
+ GPU,
+ CPUPinned,
+ OpenCL,
+ Vulkan,
+ Metal,
+ VPI,
+ ROCM,
+ ExtDev,
+}
+
+impl Default for DeviceType {
+ /// default device is cpu.
+ fn default() -> Self {
+ DeviceType::CPU
+ }
+}
+
+impl From<DeviceType> for ffi::DLDeviceType {
+ fn from(device_type: DeviceType) -> Self {
+ device_type as Self
+ }
+}
+
+impl From<ffi::DLDeviceType> for DeviceType {
+ fn from(device_type: ffi::DLDeviceType) -> Self {
+ Self::n(device_type as _).expect("invalid enumeration value for ffi::DLDeviceType")
+ }
+}
+
+impl Display for DeviceType {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(
+ f,
+ "{}",
+ match self {
+ DeviceType::CPU => "cpu",
+ DeviceType::GPU => "gpu",
+ DeviceType::CPUPinned => "cpu_pinned",
+ DeviceType::OpenCL => "opencl",
+ DeviceType::Vulkan => "vulkan",
+ DeviceType::Metal => "metal",
+ DeviceType::VPI => "vpi",
+ DeviceType::ROCM => "rocm",
+ DeviceType::ExtDev => "ext_device",
+ // DeviceType(_) => "rpc",
+ }
+ )
+ }
+}
+
+impl<'a> From<&'a str> for DeviceType {
+ fn from(type_str: &'a str) -> Self {
+ match type_str {
+ "cpu" => DeviceType::CPU,
+ "llvm" => DeviceType::CPU,
+ "stackvm" => DeviceType::CPU,
+ "gpu" => DeviceType::GPU,
+ "cuda" => DeviceType::GPU,
+ "nvptx" => DeviceType::GPU,
+ "cl" => DeviceType::OpenCL,
+ "opencl" => DeviceType::OpenCL,
+ "metal" => DeviceType::Metal,
+ "vpi" => DeviceType::VPI,
+ "rocm" => DeviceType::ROCM,
+ _ => panic!("{:?} not supported!", type_str),
+ }
+ }
+}
+
+impl<'a> From<&DeviceType> for ArgValue<'a> {
+ fn from(dev: &DeviceType) -> Self {
+ Self::Int(*dev as _)
+ }
+}
+
+#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
+pub struct Context {
+ pub device_type: DeviceType,
+ pub device_id: usize,
+}
+
+impl Context {
+ pub fn new(device_type: DeviceType, device_id: usize) -> Context {
+ Context {
+ device_type,
+ device_id,
+ }
+ }
+}
+
+impl<'a> From<&'a Context> for DLContext {
+ fn from(ctx: &'a Context) -> Self {
+ Self {
+ device_type: ctx.device_type.into(),
+ device_id: ctx.device_id as i32,
+ }
+ }
+}
+
+impl Default for Context {
+ fn default() -> Self {
+ Self {
+ device_type: DLDeviceType_kDLCPU.into(),
+ device_id: 0,
+ }
+ }
+}
+
+#[derive(Debug, Error)]
+#[error("unsupported device: {0}")]
+pub struct UnsupportedDeviceError(String);
+
+macro_rules! impl_tvm_context {
+ ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => {
+ /// Creates a Context from a string (e.g., "cpu", "gpu", "ext_dev")
+ impl FromStr for Context {
+ type Err = UnsupportedDeviceError;
+ fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+ Ok(Self {
+ device_type: match type_str {
+ $( $( stringify!($dev_name) )|+ => $dev_type.into()),+,
+ _ => return Err(UnsupportedDeviceError(type_str.to_string())),
+ },
+ device_id: 0,
+ })
+ }
+ }
+
+ impl Context {
+ $(
+ $(
+ pub fn $dev_name(device_id: usize) -> Self {
+ Self {
+ device_type: $dev_type.into(),
+ device_id: device_id,
+ }
+ }
+ )+
+ )+
+ }
+ };
+}
+
+impl_tvm_context!(
+ DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
+ DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+ DLDeviceType_kDLOpenCL: [cl],
+ DLDeviceType_kDLMetal: [metal],
+ DLDeviceType_kDLVPI: [vpi],
+ DLDeviceType_kDLROCM: [rocm],
+ DLDeviceType_kDLExtDev: [ext_dev]
+);
+
+impl<'a> From<&'a str> for Context {
+ fn from(target: &str) -> Self {
+ Context::new(DeviceType::from(target), 0)
+ }
+}
+
+impl From<ffi::DLContext> for Context {
+ fn from(ctx: ffi::DLContext) -> Self {
+ Context {
+ device_type: DeviceType::from(ctx.device_type),
+ device_id: ctx.device_id as usize,
+ }
+ }
+}
+
+impl From<Context> for ffi::DLContext {
+ fn from(ctx: Context) -> Self {
+ ffi::DLContext {
+ device_type: ctx.device_type.into(),
+ device_id: ctx.device_id as i32,
+ }
+ }
+}
+
+impl Display for Context {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(f, "{}({})", self.device_type, self.device_id)
+ }
+}
+
+impl From<Context> for RetValue {
+ fn from(ret_value: Context) -> RetValue {
+ RetValue::Context(ret_value.into())
+ }
+}
+
+impl TryFrom<RetValue> for Context {
+ type Error = anyhow::Error;
+ fn try_from(ret_value: RetValue) -> anyhow::Result<Context> {
+ match ret_value {
+ RetValue::Context(dt) => Ok(dt.into()),
+ // TODO(@jroesch): improve
+ _ => Err(anyhow::anyhow!("unable to convert datatype from ...")),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn context() {
+ let ctx = Context::cpu(0);
+ println!("ctx: {}", ctx);
+ let default_ctx = Context::new(DeviceType::CPU, 0);
+ assert_eq!(ctx.clone(), default_ctx);
+ assert_ne!(ctx, Context::gpu(0));
+
+ let str_ctx = Context::new(DeviceType::GPU, 0);
+ assert_eq!(str_ctx.clone(), str_ctx);
+ assert_ne!(str_ctx, Context::new(DeviceType::CPU, 0));
+ }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::any::TypeId;
+use std::convert::TryFrom;
+use std::str::FromStr;
+
+use crate::ffi::DLDataType;
+use crate::packed_func::RetValue;
+
+use thiserror::Error;
+
+const DL_INT_CODE: u8 = 0;
+const DL_UINT_CODE: u8 = 1;
+const DL_FLOAT_CODE: u8 = 2;
+const DL_HANDLE: u8 = 3;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct DataType {
+ code: u8,
+ bits: u8,
+ lanes: u16,
+}
+
+impl DataType {
+ pub fn new(code: u8, bits: u8, lanes: u16) -> DataType {
+ DataType { code, bits, lanes }
+ }
+
+ /// Returns the number of bytes occupied by an element of this `DataType`.
+ pub fn itemsize(&self) -> usize {
+ (self.bits as usize * self.lanes as usize) >> 3
+ }
+
+ /// Returns whether this `DataType` represents primitive type `T`.
+ pub fn is_type<T: 'static>(&self) -> bool {
+ if self.lanes != 1 {
+ return false;
+ }
+ let typ = TypeId::of::<T>();
+ (typ == TypeId::of::<i32>() && self.code == DL_INT_CODE && self.bits == 32)
+ || (typ == TypeId::of::<i64>() && self.code == DL_INT_CODE && self.bits == 64)
+ || (typ == TypeId::of::<u32>() && self.code == DL_UINT_CODE && self.bits == 32)
+ || (typ == TypeId::of::<u64>() && self.code == DL_UINT_CODE && self.bits == 64)
+ || (typ == TypeId::of::<f32>() && self.code == DL_FLOAT_CODE && self.bits == 32)
+ || (typ == TypeId::of::<f64>() && self.code == DL_FLOAT_CODE && self.bits == 64)
+ }
+
+ pub fn code(&self) -> usize {
+ self.code as usize
+ }
+
+ pub fn bits(&self) -> usize {
+ self.bits as usize
+ }
+
+ pub fn lanes(&self) -> usize {
+ self.lanes as usize
+ }
+}
+
+impl<'a> From<&'a DataType> for DLDataType {
+ fn from(dtype: &'a DataType) -> Self {
+ Self {
+ code: dtype.code as u8,
+ bits: dtype.bits as u8,
+ lanes: dtype.lanes as u16,
+ }
+ }
+}
+
+impl From<DLDataType> for DataType {
+ fn from(dtype: DLDataType) -> Self {
+ Self {
+ code: dtype.code,
+ bits: dtype.bits,
+ lanes: dtype.lanes,
+ }
+ }
+}
+
+#[derive(Debug, Error)]
+pub enum ParseDataTypeError {
+ #[error("invalid number: {0}")]
+ InvalidNumber(std::num::ParseIntError),
+ #[error("missing data type specifier (e.g., int32, float64)")]
+ MissingDataType,
+ #[error("unknown type: {0}")]
+ UnknownType(String),
+}
+
+/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
+/// such as "int32", "float32" or with lane "float32x1".
+impl FromStr for DataType {
+ type Err = ParseDataTypeError;
+
+ fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+ use ParseDataTypeError::*;
+
+ if type_str == "bool" {
+ return Ok(DataType::new(1, 1, 1));
+ }
+
+ let mut type_lanes = type_str.split('x');
+ let typ = type_lanes.next().ok_or(MissingDataType)?;
+ let lanes = type_lanes
+ .next()
+ .map(|l| <u16>::from_str_radix(l, 10))
+ .unwrap_or(Ok(1))
+ .map_err(InvalidNumber)?;
+ let (type_name, bits) = match typ.find(char::is_numeric) {
+ Some(idx) => {
+ let (name, bits_str) = typ.split_at(idx);
+ (
+ name,
+ u8::from_str_radix(bits_str, 10).map_err(InvalidNumber)?,
+ )
+ }
+ None => (typ, 32),
+ };
+
+ let type_code = match type_name {
+ "int" => DL_INT_CODE,
+ "uint" => DL_UINT_CODE,
+ "float" => DL_FLOAT_CODE,
+ "handle" => DL_HANDLE,
+ _ => return Err(UnknownType(type_name.to_string())),
+ };
+
+ Ok(DataType::new(type_code, bits, lanes))
+ }
+}
+
+impl std::fmt::Display for DataType {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ if self.bits == 1 && self.lanes == 1 {
+ return write!(f, "bool");
+ }
+ let mut type_str = match self.code {
+ DL_INT_CODE => "int",
+ DL_UINT_CODE => "uint",
+ DL_FLOAT_CODE => "float",
+ DL_HANDLE => "handle",
+ _ => "unknown",
+ }
+ .to_string();
+
+ type_str += &self.bits.to_string();
+ if self.lanes > 1 {
+ type_str += &format!("x{}", self.lanes);
+ }
+ f.write_str(&type_str)
+ }
+}
+
+impl From<DataType> for RetValue {
+ fn from(dt: DataType) -> RetValue {
+ RetValue::DataType((&dt).into())
+ }
+}
+
+impl TryFrom<RetValue> for DataType {
+ type Error = anyhow::Error;
+ fn try_from(ret_value: RetValue) -> anyhow::Result<DataType> {
+ match ret_value {
+ RetValue::DataType(dt) => Ok(dt.into()),
+ // TODO(@jroesch): improve
+ _ => Err(anyhow::anyhow!("unable to convert datatype from ...")),
+ }
+ }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+#[error("invalid header (expected {expected_type:?}, found {actual_type:?})")]
+pub struct ValueDowncastError {
+ pub actual_type: String,
+ pub expected_type: &'static str,
+}
+
+#[derive(Error, Debug)]
+#[error("Function call `{context:?}` returned error: {message:?}")]
+pub struct FuncCallError {
+ context: String,
+ message: String,
+}
+
+impl FuncCallError {
+ pub fn get_with_context(context: String) -> Self {
+ Self {
+ context,
+ message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) }
+ .to_str()
+ .expect("double fault")
+ .to_owned(),
+ }
+ }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This crate contains the minimal interface over TVM's
+//! C runtime API.
+//!
+//! These common bindings are useful to both runtimes
+//! written in Rust, as well as higher level API bindings.
+//!
+//! See the `tvm-rt` or `tvm` crates for full bindings to
+//! the TVM API.
+
+/// The low-level C runtime FFI API for TVM.
+pub mod ffi {
+ #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]
+
+ use std::os::raw::{c_char, c_int, c_void};
+
+ include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
+
+ pub type BackendPackedCFunc =
+ extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
+}
+
+pub mod array;
+pub mod byte_array;
+pub mod context;
+pub mod datatype;
+pub mod errors;
+#[macro_use]
+pub mod packed_func;
+pub mod value;
+
+pub use byte_array::ByteArray;
+pub use context::{Context, DeviceType};
+pub use datatype::DataType;
+pub use errors::*;
+pub use packed_func::{ArgValue, RetValue};
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+ convert::TryFrom,
+ ffi::{CStr, CString},
+ os::raw::c_void,
+};
+
+use crate::{errors::ValueDowncastError, ffi::*};
+
+pub use crate::ffi::TVMValue;
+
+pub trait PackedFunc:
+ Fn(&[ArgValue]) -> Result<RetValue, crate::errors::FuncCallError> + Send + Sync
+{
+}
+
+impl<T> PackedFunc for T where
+ T: Fn(&[ArgValue]) -> Result<RetValue, crate::errors::FuncCallError> + Send + Sync
+{
+}
+
+/// Calls a packed function and returns a `RetValue`.
+///
+/// # Example
+///
+/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
+#[macro_export]
+macro_rules! call_packed {
+ ($fn:expr, $($args:expr),+) => {
+ $fn(&[$($args.into(),)+])
+ };
+ ($fn:expr) => {
+ $fn(&Vec::new())
+ };
+}
+
+/// Constructs a derivative of a TVMPodValue.
+macro_rules! TVMPODValue {
+ {
+ $(#[$m:meta])+
+ $name:ident $(<$a:lifetime>)? {
+ $($extra_variant:ident ( $variant_type:ty ) ),+ $(,)?
+ },
+ match $value:ident {
+ $($tvm_type:ident => { $from_tvm_type:expr })+
+ },
+ match &self {
+ $($self_type:ident ( $val:ident ) => { $from_self_type:expr })+
+ }
+ $(,)?
+ } => {
+ $(#[$m])+
+ #[derive(Clone, Debug)]
+ pub enum $name $(<$a>)? {
+ Int(i64),
+ UInt(i64),
+ Float(f64),
+ Null,
+ DataType(DLDataType),
+ String(CString),
+ Context(TVMContext),
+ Handle(*mut c_void),
+ ArrayHandle(TVMArrayHandle),
+ ObjectHandle(*mut c_void),
+ ModuleHandle(TVMModuleHandle),
+ FuncHandle(TVMFunctionHandle),
+ NDArrayHandle(*mut c_void),
+ $($extra_variant($variant_type)),+
+ }
+
+ impl $(<$a>)? $name $(<$a>)? {
+ pub fn from_tvm_value($value: TVMValue, type_code: u32) -> Self {
+ use $name::*;
+ #[allow(non_upper_case_globals)]
+ unsafe {
+ match type_code as _ {
+ DLDataTypeCode_kDLInt => Int($value.v_int64),
+ DLDataTypeCode_kDLUInt => UInt($value.v_int64),
+ DLDataTypeCode_kDLFloat => Float($value.v_float64),
+ TVMTypeCode_kTVMNullptr => Null,
+ TVMTypeCode_kTVMDataType => DataType($value.v_type),
+ TVMTypeCode_kTVMContext => Context($value.v_ctx),
+ TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
+ TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
+ TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
+ TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
+ TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
+ TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
+ $( $tvm_type => { $from_tvm_type } ),+
+ _ => unimplemented!("{}", type_code),
+ }
+ }
+ }
+
+ pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) {
+ use $name::*;
+ match self {
+ Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
+ UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
+ Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
+ Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kTVMNullptr),
+ DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType),
+ Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext),
+ String(val) => {
+ (
+ TVMValue { v_handle: val.as_ptr() as *mut c_void },
+ TVMTypeCode_kTVMStr,
+ )
+ }
+ Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle),
+ ArrayHandle(val) => {
+ (
+ TVMValue { v_handle: *val as *const _ as *mut c_void },
+ TVMTypeCode_kTVMNDArrayHandle,
+ )
+ },
+ ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle),
+ ModuleHandle(val) =>
+ (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle),
+ FuncHandle(val) => (
+ TVMValue { v_handle: *val },
+ TVMTypeCode_kTVMPackedFuncHandle
+ ),
+ NDArrayHandle(val) =>
+ (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
+ $( $self_type($val) => { $from_self_type } ),+
+ }
+ }
+ }
+ }
+}
+
+TVMPODValue! {
+ /// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
+ /// to obtain a `ArgValue` is automatically via `call_packed!`.
+ ArgValue<'a> {
+ Bytes(&'a TVMByteArray),
+ Str(&'a CStr),
+ },
+ match value {
+ TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
+ TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) }
+ },
+ match &self {
+ Bytes(val) => {
+ (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes)
+ }
+ Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) }
+ }
+}
+
+TVMPODValue! {
+ /// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
+ /// Can be downcasted using `try_from` if it contains the desired type.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use std::convert::{TryFrom, TryInto};
+ /// use tvm_sys::RetValue;
+ ///
+ /// let a = 42u32;
+ /// let b: u32 = tvm_sys::RetValue::from(a).try_into().unwrap();
+ ///
+ /// let s = "hello, world!";
+ /// let t: RetValue = s.to_string().into();
+ /// assert_eq!(String::try_from(t).unwrap(), s);
+ /// ```
+ RetValue {
+ Bytes(TVMByteArray),
+ Str(&'static CStr),
+ },
+ match value {
+ TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
+ TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) }
+ },
+ match &self {
+ Bytes(val) =>
+ { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) }
+ Str(val) =>
+ { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) }
+ }
+}
+
+#[macro_export]
+macro_rules! try_downcast {
+ ($val:ident -> $into:ty, $( |$pat:pat| { $converter:expr } ),+ ) => {
+ match $val {
+ $( $pat => { Ok($converter) } )+
+ _ => Err($crate::errors::ValueDowncastError {
+ actual_type: format!("{:?}", $val),
+ expected_type: stringify!($into),
+ }),
+ }
+ };
+}
+
+/// Creates a conversion to a `ArgValue` for a primitive type and DLDataTypeCode.
+macro_rules! impl_pod_value {
+ ($variant:ident, $inner_ty:ty, [ $( $type:ty ),+ ] ) => {
+ $(
+ impl<'a> From<$type> for ArgValue<'a> {
+ fn from(val: $type) -> Self {
+ Self::$variant(val as $inner_ty)
+ }
+ }
+
+ impl<'a, 'v> From<&'a $type> for ArgValue<'v> {
+ fn from(val: &'a $type) -> Self {
+ Self::$variant(*val as $inner_ty)
+ }
+ }
+
+ impl<'a> TryFrom<ArgValue<'a>> for $type {
+ type Error = $crate::errors::ValueDowncastError;
+ fn try_from(val: ArgValue<'a>) -> Result<Self, Self::Error> {
+ try_downcast!(val -> $type, |ArgValue::$variant(val)| { val as $type })
+ }
+ }
+
+ impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type {
+ type Error = $crate::errors::ValueDowncastError;
+ fn try_from(val: &'a ArgValue<'v>) -> Result<Self, Self::Error> {
+ try_downcast!(val -> $type, |ArgValue::$variant(val)| { *val as $type })
+ }
+ }
+
+ impl From<$type> for RetValue {
+ fn from(val: $type) -> Self {
+ Self::$variant(val as $inner_ty)
+ }
+ }
+
+ impl TryFrom<RetValue> for $type {
+ type Error = $crate::errors::ValueDowncastError;
+ fn try_from(val: RetValue) -> Result<Self, Self::Error> {
+ try_downcast!(val -> $type, |RetValue::$variant(val)| { val as $type })
+ }
+ }
+ )+
+ };
+}
+
+impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]);
+impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]);
+impl_pod_value!(Float, f64, [f32, f64]);
+impl_pod_value!(DataType, DLDataType, [DLDataType]);
+impl_pod_value!(Context, TVMContext, [TVMContext]);
+
+impl<'a> From<&'a str> for ArgValue<'a> {
+ fn from(s: &'a str) -> Self {
+ Self::String(CString::new(s).unwrap())
+ }
+}
+
+impl<'a> From<String> for ArgValue<'a> {
+ fn from(s: String) -> Self {
+ Self::String(CString::new(s).unwrap())
+ }
+}
+
+impl<'a> From<&'a CStr> for ArgValue<'a> {
+ fn from(s: &'a CStr) -> Self {
+ Self::Str(s)
+ }
+}
+
+impl<'a> From<CString> for ArgValue<'a> {
+ fn from(s: CString) -> Self {
+ Self::String(s)
+ }
+}
+
+impl<'a> From<&'a TVMByteArray> for ArgValue<'a> {
+ fn from(s: &'a TVMByteArray) -> Self {
+ Self::Bytes(s)
+ }
+}
+
+impl<'a> TryFrom<ArgValue<'a>> for &'a str {
+ type Error = ValueDowncastError;
+ fn try_from(val: ArgValue<'a>) -> Result<Self, Self::Error> {
+ try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() })
+ }
+}
+
+impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str {
+ type Error = ValueDowncastError;
+ fn try_from(val: &'a ArgValue<'v>) -> Result<Self, Self::Error> {
+ try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() })
+ }
+}
+
+/// Converts an unspecialized handle to a ArgValue.
+impl<T> From<*const T> for ArgValue<'static> {
+ fn from(ptr: *const T) -> Self {
+ Self::Handle(ptr as *mut c_void)
+ }
+}
+
+/// Converts an unspecialized mutable handle to a ArgValue.
+impl<T> From<*mut T> for ArgValue<'static> {
+ fn from(ptr: *mut T) -> Self {
+ Self::Handle(ptr as *mut c_void)
+ }
+}
+
+impl<'a> From<&'a mut DLTensor> for ArgValue<'a> {
+ fn from(arr: &'a mut DLTensor) -> Self {
+ Self::ArrayHandle(arr as *mut DLTensor)
+ }
+}
+
+impl<'a> From<&'a DLTensor> for ArgValue<'a> {
+ fn from(arr: &'a DLTensor) -> Self {
+ Self::ArrayHandle(arr as *const _ as *mut DLTensor)
+ }
+}
+
+impl TryFrom<RetValue> for String {
+ type Error = ValueDowncastError;
+ fn try_from(val: RetValue) -> Result<String, Self::Error> {
+ try_downcast!(
+ val -> String,
+ |RetValue::String(s)| { s.into_string().unwrap() },
+ |RetValue::Str(s)| { s.to_str().unwrap().to_string() }
+ )
+ }
+}
+
+impl From<String> for RetValue {
+ fn from(s: String) -> Self {
+ Self::String(std::ffi::CString::new(s).unwrap())
+ }
+}
+
+impl From<TVMByteArray> for RetValue {
+ fn from(arr: TVMByteArray) -> Self {
+ Self::Bytes(arr)
+ }
+}
+
+impl TryFrom<RetValue> for TVMByteArray {
+ type Error = ValueDowncastError;
+ fn try_from(val: RetValue) -> Result<Self, Self::Error> {
+ try_downcast!(val -> TVMByteArray, |RetValue::Bytes(val)| { val })
+ }
+}
+
+impl Default for RetValue {
+ fn default() -> Self {
+ Self::Int(0)
+ }
+}
+
+impl TryFrom<RetValue> for std::ffi::CString {
+ type Error = ValueDowncastError;
+ fn try_from(val: RetValue) -> Result<CString, Self::Error> {
+ try_downcast!(val -> std::ffi::CString,
+ |RetValue::Str(val)| { val.into() })
+ }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::str::FromStr;
+
+use crate::ffi::*;
+
+use thiserror::Error;
+
+macro_rules! impl_pod_tvm_value {
+ ($field:ident, $field_ty:ty, $( $ty:ty ),+) => {
+ $(
+ impl From<$ty> for TVMValue {
+ fn from(val: $ty) -> Self {
+ TVMValue { $field: val as $field_ty }
+ }
+ }
+
+ impl From<TVMValue> for $ty {
+ fn from(val: TVMValue) -> Self {
+ unsafe { val.$field as $ty }
+ }
+ }
+ )+
+ };
+ ($field:ident, $ty:ty) => {
+ impl_pod_tvm_value!($field, $ty, $ty);
+ }
+}
+
+impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize);
+impl_pod_tvm_value!(v_float64, f64, f32, f64);
+impl_pod_tvm_value!(v_type, DLDataType);
+impl_pod_tvm_value!(v_ctx, TVMContext);
+
+#[derive(Debug, Error)]
+#[error("unsupported device: {0}")]
+pub struct UnsupportedDeviceError(String);
+
+macro_rules! impl_tvm_context {
+ ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => {
+ /// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev")
+ impl FromStr for TVMContext {
+ type Err = UnsupportedDeviceError;
+ fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+ Ok(Self {
+ device_type: match type_str {
+ $( $( stringify!($dev_name) )|+ => $dev_type ),+,
+ _ => return Err(UnsupportedDeviceError(type_str.to_string())),
+ },
+ device_id: 0,
+ })
+ }
+ }
+
+ impl TVMContext {
+ $(
+ $(
+ pub fn $dev_name(device_id: usize) -> Self {
+ Self {
+ device_type: $dev_type,
+ device_id: device_id as i32,
+ }
+ }
+ )+
+ )+
+ }
+ };
+}
+
+impl_tvm_context!(
+ DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
+ DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+ DLDeviceType_kDLOpenCL: [cl],
+ DLDeviceType_kDLMetal: [metal],
+ DLDeviceType_kDLVPI: [vpi],
+ DLDeviceType_kDLROCM: [rocm],
+ DLDeviceType_kDLExtDev: [ext_dev]
+);