[Rust] Add first stage of updating and rewriting Rust bindings. (#5526)
authorJared Roesch <jroesch@octoml.ai>
Fri, 8 May 2020 23:53:47 +0000 (16:53 -0700)
committerGitHub <noreply@github.com>
Fri, 8 May 2020 23:53:47 +0000 (16:53 -0700)
* Add tvm-sys

* Use as_mut_ptr

* Address CR feedback

* Update rust/tvm-sys/src/datatype.rs

Co-authored-by: Nick Hynes <nhynes@berkeley.edu>
* Final CR comments

* Fix find and replace error in frontend

Co-authored-by: Nick Hynes <nhynes@berkeley.edu>
12 files changed:
rust/.rustfmt.toml
rust/Cargo.toml
rust/tvm-sys/Cargo.toml [new file with mode: 0644]
rust/tvm-sys/build.rs [new file with mode: 0644]
rust/tvm-sys/src/array.rs [new file with mode: 0644]
rust/tvm-sys/src/byte_array.rs [new file with mode: 0644]
rust/tvm-sys/src/context.rs [new file with mode: 0644]
rust/tvm-sys/src/datatype.rs [new file with mode: 0644]
rust/tvm-sys/src/errors.rs [new file with mode: 0644]
rust/tvm-sys/src/lib.rs [new file with mode: 0644]
rust/tvm-sys/src/packed_func.rs [new file with mode: 0644]
rust/tvm-sys/src/value.rs [new file with mode: 0644]

index 3c51bb3..5a1f1d2 100644 (file)
@@ -29,3 +29,4 @@ merge_derives = true
 use_try_shorthand = false
 use_field_init_shorthand = false
 force_explicit_abi = true
+
index f08f861..b4a159c 100644 (file)
@@ -27,5 +27,6 @@ members = [
        "frontend",
        "frontend/tests/basics",
        "frontend/tests/callback",
-       "frontend/examples/resnet"
+       "frontend/examples/resnet",
+        "tvm-sys"
 ]
diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml
new file mode 100644 (file)
index 0000000..fe4d0bf
--- /dev/null
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "tvm-sys"
+version = "0.1.0"
+authors = ["TVM Contributors"]
+license = "Apache-2.0"
+edition = "2018"
+
+[features]
+bindings = []
+
+[dependencies]
+thiserror = "^1.0"
+anyhow = "^1.0"
+ndarray = "0.12"
+enumn = "^0.1"
+
+[build-dependencies]
+bindgen = "0.51"
diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs
new file mode 100644 (file)
index 0000000..85e16be
--- /dev/null
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+extern crate bindgen;
+
+use std::path::PathBuf;
+
+use std::env;
+
+fn main() {
+    let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({
+        let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+            .canonicalize()
+            .unwrap();
+        crate_dir
+            .parent()
+            .unwrap()
+            .parent()
+            .unwrap()
+            .to_str()
+            .unwrap()
+            .to_string()
+    });
+
+    if cfg!(feature = "bindings") {
+        println!("cargo:rerun-if-env-changed=TVM_HOME");
+        println!("cargo:rustc-link-lib=dylib=tvm");
+        println!("cargo:rustc-link-search={}/build", tvm_home);
+    }
+
+    // @see rust-bindgen#550 for `blacklist_type`
+    bindgen::Builder::default()
+        .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home))
+        .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home))
+        .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home))
+        .clang_arg(format!("-I{}/include/", tvm_home))
+        .blacklist_type("max_align_t")
+        .layout_tests(false)
+        .derive_partialeq(true)
+        .derive_eq(true)
+        .generate()
+        .expect("unable to generate bindings")
+        .write_to_file(PathBuf::from("src/c_runtime_api.rs"))
+        .expect("can not write the bindings!");
+}
diff --git a/rust/tvm-sys/src/array.rs b/rust/tvm-sys/src/array.rs
new file mode 100644 (file)
index 0000000..1627e9e
--- /dev/null
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+    mem,
+    os::raw::{c_int, c_void},
+};
+
+use crate::ffi::{
+    DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
+    DLDeviceType_kDLCPU, DLTensor,
+};
+
+/// `From` conversions to `DLTensor` for `ndarray::Array`.
+/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
+macro_rules! impl_dltensor_from_ndarray {
+    ($type:ty, $typecode:expr) => {
+        impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
+            fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
+                DLTensor {
+                    data: arr.as_mut_ptr() as *mut c_void,
+                    ctx: DLContext {
+                        device_type: DLDeviceType_kDLCPU,
+                        device_id: 0,
+                    },
+                    ndim: arr.ndim() as c_int,
+                    dtype: DLDataType {
+                        code: $typecode as u8,
+                        bits: 8 * mem::size_of::<$type>() as u8,
+                        lanes: 1,
+                    },
+                    shape: arr.shape().as_ptr() as *const i64 as *mut i64,
+                    strides: arr.strides().as_ptr() as *const i64 as *mut i64,
+                    byte_offset: 0,
+                }
+            }
+        }
+    };
+}
+
+impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
+impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
+impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
+impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
+impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
+impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs
new file mode 100644 (file)
index 0000000..40f28f4
--- /dev/null
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+use std::os::raw::c_char;
+
+use crate::ffi::TVMByteArray;
+
+/// A newtype wrapping a raw TVM byte-array.
+///
+/// ## Example
+///
+/// ```
+/// let v = b"hello";
+/// let barr = tvm_sys::ByteArray::from(&v);
+/// assert_eq!(barr.len(), v.len());
+/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
+/// ```
+pub struct ByteArray {
+    /// The raw FFI ByteArray.
+    array: TVMByteArray,
+}
+
+impl ByteArray {
+    /// Gets the underlying byte-array
+    pub fn data(&self) -> &'static [u8] {
+        unsafe { std::slice::from_raw_parts(self.array.data as *const u8, self.array.size) }
+    }
+
+    /// Gets the length of the underlying byte-array
+    pub fn len(&self) -> usize {
+        self.array.size
+    }
+
+    /// Converts the underlying byte-array to `Vec<u8>`
+    pub fn to_vec(&self) -> Vec<u8> {
+        self.data().to_vec()
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+}
+
+// Needs AsRef for Vec
+impl<T: AsRef<[u8]>> From<T> for ByteArray {
+    fn from(arg: T) -> Self {
+        let arg = arg.as_ref();
+        ByteArray {
+            array: TVMByteArray {
+                data: arg.as_ptr() as *const c_char,
+                size: arg.len(),
+            },
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn convert() {
+        let v = vec![1u8, 2, 3];
+        let barr = ByteArray::from(&v);
+        assert_eq!(barr.len(), v.len());
+        assert_eq!(barr.to_vec(), vec![1u8, 2, 3]);
+        let v = b"hello";
+        let barr = ByteArray::from(&v);
+        assert_eq!(barr.len(), v.len());
+        assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
+    }
+}
diff --git a/rust/tvm-sys/src/context.rs b/rust/tvm-sys/src/context.rs
new file mode 100644 (file)
index 0000000..64b58b9
--- /dev/null
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! Provides [`Context`] and related device queries.
+//!
+//! Create a new context for device type and device id.
+//!
+//! # Example
+//!
+//! ```
+//! # use tvm_sys::{DeviceType, Context};
+//! let cpu = DeviceType::from("cpu");
+//! let ctx = Context::new(cpu , 0);
+//! let cpu0 = Context::cpu(0);
+//! assert_eq!(ctx, cpu0);
+//! ```
+//!
+//! Or from a supported device name.
+//!
+//! ```
+//! use tvm_sys::Context;
+//! let cpu0 = Context::from("cpu");
+//! println!("{}", cpu0);
+//! ```
+
+use std::convert::TryFrom;
+use std::fmt::{self, Display, Formatter};
+use std::str::FromStr;
+
+use crate::ffi::{self, *};
+use crate::packed_func::{ArgValue, RetValue};
+
+use anyhow::Result;
+use enumn::N;
+use thiserror::Error;
+
+/// Device type represents the set of devices supported by
+/// [TVM](https://github.com/apache/incubator-tvm).
+///
+/// ## Example
+///
+/// ```
+/// use tvm_sys::DeviceType;
+/// let cpu = DeviceType::from("cpu");
+/// println!("device is: {}", cpu);
+///```
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, N)]
+#[repr(i64)]
+pub enum DeviceType {
+    CPU = 1,
+    GPU,
+    CPUPinned,
+    OpenCL,
+    Vulkan,
+    Metal,
+    VPI,
+    ROCM,
+    ExtDev,
+}
+
+impl Default for DeviceType {
+    /// default device is cpu.
+    fn default() -> Self {
+        DeviceType::CPU
+    }
+}
+
+impl From<DeviceType> for ffi::DLDeviceType {
+    fn from(device_type: DeviceType) -> Self {
+        device_type as Self
+    }
+}
+
+impl From<ffi::DLDeviceType> for DeviceType {
+    fn from(device_type: ffi::DLDeviceType) -> Self {
+        Self::n(device_type as _).expect("invalid enumeration value for ffi::DLDeviceType")
+    }
+}
+
+impl Display for DeviceType {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        write!(
+            f,
+            "{}",
+            match self {
+                DeviceType::CPU => "cpu",
+                DeviceType::GPU => "gpu",
+                DeviceType::CPUPinned => "cpu_pinned",
+                DeviceType::OpenCL => "opencl",
+                DeviceType::Vulkan => "vulkan",
+                DeviceType::Metal => "metal",
+                DeviceType::VPI => "vpi",
+                DeviceType::ROCM => "rocm",
+                DeviceType::ExtDev => "ext_device",
+                // DeviceType(_) => "rpc",
+            }
+        )
+    }
+}
+
+impl<'a> From<&'a str> for DeviceType {
+    fn from(type_str: &'a str) -> Self {
+        match type_str {
+            "cpu" => DeviceType::CPU,
+            "llvm" => DeviceType::CPU,
+            "stackvm" => DeviceType::CPU,
+            "gpu" => DeviceType::GPU,
+            "cuda" => DeviceType::GPU,
+            "nvptx" => DeviceType::GPU,
+            "cl" => DeviceType::OpenCL,
+            "opencl" => DeviceType::OpenCL,
+            "metal" => DeviceType::Metal,
+            "vpi" => DeviceType::VPI,
+            "rocm" => DeviceType::ROCM,
+            _ => panic!("{:?} not supported!", type_str),
+        }
+    }
+}
+
+impl<'a> From<&DeviceType> for ArgValue<'a> {
+    fn from(dev: &DeviceType) -> Self {
+        Self::Int(*dev as _)
+    }
+}
+
+#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
+pub struct Context {
+    pub device_type: DeviceType,
+    pub device_id: usize,
+}
+
+impl Context {
+    pub fn new(device_type: DeviceType, device_id: usize) -> Context {
+        Context {
+            device_type,
+            device_id,
+        }
+    }
+}
+
+impl<'a> From<&'a Context> for DLContext {
+    fn from(ctx: &'a Context) -> Self {
+        Self {
+            device_type: ctx.device_type.into(),
+            device_id: ctx.device_id as i32,
+        }
+    }
+}
+
+impl Default for Context {
+    fn default() -> Self {
+        Self {
+            device_type: DLDeviceType_kDLCPU.into(),
+            device_id: 0,
+        }
+    }
+}
+
+#[derive(Debug, Error)]
+#[error("unsupported device: {0}")]
+pub struct UnsupportedDeviceError(String);
+
+macro_rules! impl_tvm_context {
+    ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => {
+        /// Creates a Context from a string (e.g., "cpu", "gpu", "ext_dev")
+        impl FromStr for Context {
+            type Err = UnsupportedDeviceError;
+            fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+                Ok(Self {
+                    device_type: match type_str {
+                         $( $(  stringify!($dev_name)  )|+ => $dev_type.into()),+,
+                        _ => return Err(UnsupportedDeviceError(type_str.to_string())),
+                    },
+                    device_id: 0,
+                })
+            }
+        }
+
+        impl Context {
+            $(
+                $(
+                    pub fn $dev_name(device_id: usize) -> Self {
+                        Self {
+                            device_type: $dev_type.into(),
+                            device_id: device_id,
+                        }
+                    }
+                )+
+            )+
+        }
+    };
+}
+
+impl_tvm_context!(
+    DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
+    DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+    DLDeviceType_kDLOpenCL: [cl],
+    DLDeviceType_kDLMetal: [metal],
+    DLDeviceType_kDLVPI: [vpi],
+    DLDeviceType_kDLROCM: [rocm],
+    DLDeviceType_kDLExtDev: [ext_dev]
+);
+
+impl<'a> From<&'a str> for Context {
+    fn from(target: &str) -> Self {
+        Context::new(DeviceType::from(target), 0)
+    }
+}
+
+impl From<ffi::DLContext> for Context {
+    fn from(ctx: ffi::DLContext) -> Self {
+        Context {
+            device_type: DeviceType::from(ctx.device_type),
+            device_id: ctx.device_id as usize,
+        }
+    }
+}
+
+impl From<Context> for ffi::DLContext {
+    fn from(ctx: Context) -> Self {
+        ffi::DLContext {
+            device_type: ctx.device_type.into(),
+            device_id: ctx.device_id as i32,
+        }
+    }
+}
+
+impl Display for Context {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        write!(f, "{}({})", self.device_type, self.device_id)
+    }
+}
+
+impl From<Context> for RetValue {
+    fn from(ret_value: Context) -> RetValue {
+        RetValue::Context(ret_value.into())
+    }
+}
+
+impl TryFrom<RetValue> for Context {
+    type Error = anyhow::Error;
+    fn try_from(ret_value: RetValue) -> anyhow::Result<Context> {
+        match ret_value {
+            RetValue::Context(dt) => Ok(dt.into()),
+            // TODO(@jroesch): improve
+            _ => Err(anyhow::anyhow!("unable to convert datatype from ...")),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn context() {
+        let ctx = Context::cpu(0);
+        println!("ctx: {}", ctx);
+        let default_ctx = Context::new(DeviceType::CPU, 0);
+        assert_eq!(ctx.clone(), default_ctx);
+        assert_ne!(ctx, Context::gpu(0));
+
+        let str_ctx = Context::new(DeviceType::GPU, 0);
+        assert_eq!(str_ctx.clone(), str_ctx);
+        assert_ne!(str_ctx, Context::new(DeviceType::CPU, 0));
+    }
+}
diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs
new file mode 100644 (file)
index 0000000..5dd414c
--- /dev/null
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::any::TypeId;
+use std::convert::TryFrom;
+use std::str::FromStr;
+
+use crate::ffi::DLDataType;
+use crate::packed_func::RetValue;
+
+use thiserror::Error;
+
+const DL_INT_CODE: u8 = 0;
+const DL_UINT_CODE: u8 = 1;
+const DL_FLOAT_CODE: u8 = 2;
+const DL_HANDLE: u8 = 3;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct DataType {
+    code: u8,
+    bits: u8,
+    lanes: u16,
+}
+
+impl DataType {
+    pub fn new(code: u8, bits: u8, lanes: u16) -> DataType {
+        DataType { code, bits, lanes }
+    }
+
+    /// Returns the number of bytes occupied by an element of this `DataType`.
+    pub fn itemsize(&self) -> usize {
+        (self.bits as usize * self.lanes as usize) >> 3
+    }
+
+    /// Returns whether this `DataType` represents primitive type `T`.
+    pub fn is_type<T: 'static>(&self) -> bool {
+        if self.lanes != 1 {
+            return false;
+        }
+        let typ = TypeId::of::<T>();
+        (typ == TypeId::of::<i32>() && self.code == DL_INT_CODE && self.bits == 32)
+            || (typ == TypeId::of::<i64>() && self.code == DL_INT_CODE && self.bits == 64)
+            || (typ == TypeId::of::<u32>() && self.code == DL_UINT_CODE && self.bits == 32)
+            || (typ == TypeId::of::<u64>() && self.code == DL_UINT_CODE && self.bits == 64)
+            || (typ == TypeId::of::<f32>() && self.code == DL_FLOAT_CODE && self.bits == 32)
+            || (typ == TypeId::of::<f64>() && self.code == DL_FLOAT_CODE && self.bits == 64)
+    }
+
+    pub fn code(&self) -> usize {
+        self.code as usize
+    }
+
+    pub fn bits(&self) -> usize {
+        self.bits as usize
+    }
+
+    pub fn lanes(&self) -> usize {
+        self.lanes as usize
+    }
+}
+
+impl<'a> From<&'a DataType> for DLDataType {
+    fn from(dtype: &'a DataType) -> Self {
+        Self {
+            code: dtype.code as u8,
+            bits: dtype.bits as u8,
+            lanes: dtype.lanes as u16,
+        }
+    }
+}
+
+impl From<DLDataType> for DataType {
+    fn from(dtype: DLDataType) -> Self {
+        Self {
+            code: dtype.code,
+            bits: dtype.bits,
+            lanes: dtype.lanes,
+        }
+    }
+}
+
+#[derive(Debug, Error)]
+pub enum ParseDataTypeError {
+    #[error("invalid number: {0}")]
+    InvalidNumber(std::num::ParseIntError),
+    #[error("missing data type specifier (e.g., int32, float64)")]
+    MissingDataType,
+    #[error("unknown type: {0}")]
+    UnknownType(String),
+}
+
+/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
+/// such as "int32", "float32" or with lane "float32x1".
+impl FromStr for DataType {
+    type Err = ParseDataTypeError;
+
+    fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+        use ParseDataTypeError::*;
+
+        if type_str == "bool" {
+            return Ok(DataType::new(1, 1, 1));
+        }
+
+        let mut type_lanes = type_str.split('x');
+        let typ = type_lanes.next().ok_or(MissingDataType)?;
+        let lanes = type_lanes
+            .next()
+            .map(|l| <u16>::from_str_radix(l, 10))
+            .unwrap_or(Ok(1))
+            .map_err(InvalidNumber)?;
+        let (type_name, bits) = match typ.find(char::is_numeric) {
+            Some(idx) => {
+                let (name, bits_str) = typ.split_at(idx);
+                (
+                    name,
+                    u8::from_str_radix(bits_str, 10).map_err(InvalidNumber)?,
+                )
+            }
+            None => (typ, 32),
+        };
+
+        let type_code = match type_name {
+            "int" => DL_INT_CODE,
+            "uint" => DL_UINT_CODE,
+            "float" => DL_FLOAT_CODE,
+            "handle" => DL_HANDLE,
+            _ => return Err(UnknownType(type_name.to_string())),
+        };
+
+        Ok(DataType::new(type_code, bits, lanes))
+    }
+}
+
+impl std::fmt::Display for DataType {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        if self.bits == 1 && self.lanes == 1 {
+            return write!(f, "bool");
+        }
+        let mut type_str = match self.code {
+            DL_INT_CODE => "int",
+            DL_UINT_CODE => "uint",
+            DL_FLOAT_CODE => "float",
+            DL_HANDLE => "handle",
+            _ => "unknown",
+        }
+        .to_string();
+
+        type_str += &self.bits.to_string();
+        if self.lanes > 1 {
+            type_str += &format!("x{}", self.lanes);
+        }
+        f.write_str(&type_str)
+    }
+}
+
+impl From<DataType> for RetValue {
+    fn from(dt: DataType) -> RetValue {
+        RetValue::DataType((&dt).into())
+    }
+}
+
+impl TryFrom<RetValue> for DataType {
+    type Error = anyhow::Error;
+    fn try_from(ret_value: RetValue) -> anyhow::Result<DataType> {
+        match ret_value {
+            RetValue::DataType(dt) => Ok(dt.into()),
+            // TODO(@jroesch): improve
+            _ => Err(anyhow::anyhow!("unable to convert datatype from ...")),
+        }
+    }
+}
diff --git a/rust/tvm-sys/src/errors.rs b/rust/tvm-sys/src/errors.rs
new file mode 100644 (file)
index 0000000..8479ec6
--- /dev/null
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+#[error("invalid header (expected {expected_type:?}, found {actual_type:?})")]
+pub struct ValueDowncastError {
+    pub actual_type: String,
+    pub expected_type: &'static str,
+}
+
+#[derive(Error, Debug)]
+#[error("Function call `{context:?}` returned error: {message:?}")]
+pub struct FuncCallError {
+    context: String,
+    message: String,
+}
+
+impl FuncCallError {
+    pub fn get_with_context(context: String) -> Self {
+        Self {
+            context,
+            message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) }
+                .to_str()
+                .expect("double fault")
+                .to_owned(),
+        }
+    }
+}
diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs
new file mode 100644 (file)
index 0000000..dd28e36
--- /dev/null
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This crate contains the minimal interface over TVM's
+//! C runtime API.
+//!
+//! These common bindings are useful to both runtimes
+//! written in Rust, as well as higher level API bindings.
+//!
+//! See the `tvm-rt` or `tvm` crates for full bindings to
+//! the TVM API.
+
+/// The low-level C runtime FFI API for TVM.
+pub mod ffi {
+    #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)]
+
+    use std::os::raw::{c_char, c_int, c_void};
+
+    include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
+
+    pub type BackendPackedCFunc =
+        extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
+}
+
+pub mod array;
+pub mod byte_array;
+pub mod context;
+pub mod datatype;
+pub mod errors;
+#[macro_use]
+pub mod packed_func;
+pub mod value;
+
+pub use byte_array::ByteArray;
+pub use context::{Context, DeviceType};
+pub use datatype::DataType;
+pub use errors::*;
+pub use packed_func::{ArgValue, RetValue};
diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs
new file mode 100644 (file)
index 0000000..e4b2739
--- /dev/null
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+    convert::TryFrom,
+    ffi::{CStr, CString},
+    os::raw::c_void,
+};
+
+use crate::{errors::ValueDowncastError, ffi::*};
+
+pub use crate::ffi::TVMValue;
+
+pub trait PackedFunc:
+    Fn(&[ArgValue]) -> Result<RetValue, crate::errors::FuncCallError> + Send + Sync
+{
+}
+
+impl<T> PackedFunc for T where
+    T: Fn(&[ArgValue]) -> Result<RetValue, crate::errors::FuncCallError> + Send + Sync
+{
+}
+
+/// Calls a packed function and returns a `RetValue`.
+///
+/// # Example
+///
+/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)`
+#[macro_export]
+macro_rules! call_packed {
+    ($fn:expr, $($args:expr),+) => {
+        $fn(&[$($args.into(),)+])
+    };
+    ($fn:expr) => {
+        $fn(&Vec::new())
+    };
+}
+
+/// Constructs a derivative of a TVMPodValue.
+macro_rules! TVMPODValue {
+    {
+        $(#[$m:meta])+
+        $name:ident $(<$a:lifetime>)? {
+            $($extra_variant:ident ( $variant_type:ty ) ),+ $(,)?
+        },
+        match $value:ident {
+            $($tvm_type:ident => { $from_tvm_type:expr })+
+        },
+        match &self {
+            $($self_type:ident ( $val:ident ) => { $from_self_type:expr })+
+        }
+        $(,)?
+    } => {
+        $(#[$m])+
+        #[derive(Clone, Debug)]
+        pub enum $name $(<$a>)? {
+            Int(i64),
+            UInt(i64),
+            Float(f64),
+            Null,
+            DataType(DLDataType),
+            String(CString),
+            Context(TVMContext),
+            Handle(*mut c_void),
+            ArrayHandle(TVMArrayHandle),
+            ObjectHandle(*mut c_void),
+            ModuleHandle(TVMModuleHandle),
+            FuncHandle(TVMFunctionHandle),
+            NDArrayHandle(*mut c_void),
+            $($extra_variant($variant_type)),+
+        }
+
+        impl $(<$a>)? $name $(<$a>)? {
+            pub fn from_tvm_value($value: TVMValue, type_code: u32) -> Self {
+                use $name::*;
+                #[allow(non_upper_case_globals)]
+                unsafe {
+                    match type_code as _ {
+                        DLDataTypeCode_kDLInt => Int($value.v_int64),
+                        DLDataTypeCode_kDLUInt => UInt($value.v_int64),
+                        DLDataTypeCode_kDLFloat => Float($value.v_float64),
+                        TVMTypeCode_kTVMNullptr => Null,
+                        TVMTypeCode_kTVMDataType => DataType($value.v_type),
+                        TVMTypeCode_kTVMContext => Context($value.v_ctx),
+                        TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
+                        TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
+                        TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
+                        TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
+                        TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
+                        TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
+                        $( $tvm_type => { $from_tvm_type } ),+
+                        _ => unimplemented!("{}", type_code),
+                    }
+                }
+            }
+
+            pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) {
+                use $name::*;
+                match self {
+                    Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
+                    UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
+                    Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
+                    Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kTVMNullptr),
+                    DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType),
+                    Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext),
+                    String(val) => {
+                        (
+                            TVMValue { v_handle: val.as_ptr() as *mut c_void },
+                            TVMTypeCode_kTVMStr,
+                        )
+                    }
+                    Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle),
+                    ArrayHandle(val) => {
+                        (
+                            TVMValue { v_handle: *val as *const _ as *mut c_void },
+                            TVMTypeCode_kTVMNDArrayHandle,
+                        )
+                    },
+                    ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle),
+                    ModuleHandle(val) =>
+                        (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle),
+                    FuncHandle(val) => (
+                        TVMValue { v_handle: *val },
+                        TVMTypeCode_kTVMPackedFuncHandle
+                    ),
+                    NDArrayHandle(val) =>
+                        (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
+                    $( $self_type($val) => { $from_self_type } ),+
+                }
+            }
+        }
+    }
+}
+
+TVMPODValue! {
+    /// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way
+    /// to obtain a `ArgValue` is automatically via `call_packed!`.
+    ArgValue<'a> {
+        Bytes(&'a TVMByteArray),
+        Str(&'a CStr),
+    },
+    match value {
+        TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
+        TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) }
+    },
+    match &self {
+        Bytes(val) => {
+            (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes)
+        }
+        Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) }
+    }
+}
+
+TVMPODValue! {
+    /// An owned TVMPODValue. Can be converted from a variety of primitive and object types.
+    /// Can be downcasted using `try_from` if it contains the desired type.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use std::convert::{TryFrom, TryInto};
+    /// use tvm_sys::RetValue;
+    ///
+    /// let a = 42u32;
+    /// let b: u32 = tvm_sys::RetValue::from(a).try_into().unwrap();
+    ///
+    /// let s = "hello, world!";
+    /// let t: RetValue = s.to_string().into();
+    /// assert_eq!(String::try_from(t).unwrap(), s);
+    /// ```
+    RetValue {
+        Bytes(TVMByteArray),
+        Str(&'static CStr),
+    },
+    match value {
+        TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
+        TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) }
+    },
+    match &self {
+        Bytes(val) =>
+            { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) }
+        Str(val) =>
+            { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) }
+    }
+}
+
+#[macro_export]
+macro_rules! try_downcast {
+    ($val:ident -> $into:ty, $( |$pat:pat| { $converter:expr } ),+ ) => {
+        match $val {
+            $( $pat => { Ok($converter) } )+
+            _ => Err($crate::errors::ValueDowncastError {
+                actual_type: format!("{:?}", $val),
+                expected_type: stringify!($into),
+            }),
+        }
+    };
+}
+
+/// Creates a conversion to a `ArgValue` for a primitive type and DLDataTypeCode.
+macro_rules! impl_pod_value {
+    ($variant:ident, $inner_ty:ty, [ $( $type:ty ),+ ] ) => {
+        $(
+            impl<'a> From<$type> for ArgValue<'a> {
+                fn from(val: $type) -> Self {
+                    Self::$variant(val as $inner_ty)
+                }
+            }
+
+            impl<'a, 'v> From<&'a $type> for ArgValue<'v> {
+                fn from(val: &'a $type) -> Self {
+                    Self::$variant(*val as $inner_ty)
+                }
+            }
+
+            impl<'a> TryFrom<ArgValue<'a>> for $type {
+                type Error = $crate::errors::ValueDowncastError;
+                fn try_from(val: ArgValue<'a>) -> Result<Self, Self::Error> {
+                    try_downcast!(val -> $type, |ArgValue::$variant(val)| { val as $type })
+                }
+            }
+
+            impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type {
+                type Error = $crate::errors::ValueDowncastError;
+                fn try_from(val: &'a ArgValue<'v>) -> Result<Self, Self::Error> {
+                    try_downcast!(val -> $type, |ArgValue::$variant(val)| { *val as $type })
+                }
+            }
+
+            impl From<$type> for RetValue {
+                fn from(val: $type) -> Self {
+                    Self::$variant(val as $inner_ty)
+                }
+            }
+
+            impl TryFrom<RetValue> for $type {
+              type Error = $crate::errors::ValueDowncastError;
+                fn try_from(val: RetValue) -> Result<Self, Self::Error> {
+                    try_downcast!(val -> $type, |RetValue::$variant(val)| { val as $type })
+                }
+            }
+        )+
+    };
+}
+
+impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]);
+impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]);
+impl_pod_value!(Float, f64, [f32, f64]);
+impl_pod_value!(DataType, DLDataType, [DLDataType]);
+impl_pod_value!(Context, TVMContext, [TVMContext]);
+
+impl<'a> From<&'a str> for ArgValue<'a> {
+    fn from(s: &'a str) -> Self {
+        Self::String(CString::new(s).unwrap())
+    }
+}
+
+impl<'a> From<String> for ArgValue<'a> {
+    fn from(s: String) -> Self {
+        Self::String(CString::new(s).unwrap())
+    }
+}
+
+impl<'a> From<&'a CStr> for ArgValue<'a> {
+    fn from(s: &'a CStr) -> Self {
+        Self::Str(s)
+    }
+}
+
+impl<'a> From<CString> for ArgValue<'a> {
+    fn from(s: CString) -> Self {
+        Self::String(s)
+    }
+}
+
+impl<'a> From<&'a TVMByteArray> for ArgValue<'a> {
+    fn from(s: &'a TVMByteArray) -> Self {
+        Self::Bytes(s)
+    }
+}
+
+impl<'a> TryFrom<ArgValue<'a>> for &'a str {
+    type Error = ValueDowncastError;
+    fn try_from(val: ArgValue<'a>) -> Result<Self, Self::Error> {
+        try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() })
+    }
+}
+
+impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str {
+    type Error = ValueDowncastError;
+    fn try_from(val: &'a ArgValue<'v>) -> Result<Self, Self::Error> {
+        try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() })
+    }
+}
+
+/// Converts an unspecialized handle to a ArgValue.
+impl<T> From<*const T> for ArgValue<'static> {
+    fn from(ptr: *const T) -> Self {
+        Self::Handle(ptr as *mut c_void)
+    }
+}
+
+/// Converts an unspecialized mutable handle to a ArgValue.
+impl<T> From<*mut T> for ArgValue<'static> {
+    fn from(ptr: *mut T) -> Self {
+        Self::Handle(ptr as *mut c_void)
+    }
+}
+
+impl<'a> From<&'a mut DLTensor> for ArgValue<'a> {
+    fn from(arr: &'a mut DLTensor) -> Self {
+        Self::ArrayHandle(arr as *mut DLTensor)
+    }
+}
+
+impl<'a> From<&'a DLTensor> for ArgValue<'a> {
+    fn from(arr: &'a DLTensor) -> Self {
+        Self::ArrayHandle(arr as *const _ as *mut DLTensor)
+    }
+}
+
+impl TryFrom<RetValue> for String {
+    type Error = ValueDowncastError;
+    fn try_from(val: RetValue) -> Result<String, Self::Error> {
+        try_downcast!(
+            val -> String,
+            |RetValue::String(s)| { s.into_string().unwrap() },
+            |RetValue::Str(s)| { s.to_str().unwrap().to_string() }
+        )
+    }
+}
+
+impl From<String> for RetValue {
+    fn from(s: String) -> Self {
+        Self::String(std::ffi::CString::new(s).unwrap())
+    }
+}
+
+impl From<TVMByteArray> for RetValue {
+    fn from(arr: TVMByteArray) -> Self {
+        Self::Bytes(arr)
+    }
+}
+
+impl TryFrom<RetValue> for TVMByteArray {
+    type Error = ValueDowncastError;
+    fn try_from(val: RetValue) -> Result<Self, Self::Error> {
+        try_downcast!(val -> TVMByteArray, |RetValue::Bytes(val)| { val })
+    }
+}
+
+impl Default for RetValue {
+    fn default() -> Self {
+        Self::Int(0)
+    }
+}
+
+impl TryFrom<RetValue> for std::ffi::CString {
+    type Error = ValueDowncastError;
+    fn try_from(val: RetValue) -> Result<CString, Self::Error> {
+        try_downcast!(val -> std::ffi::CString,
+            |RetValue::Str(val)| { val.into() })
+    }
+}
diff --git a/rust/tvm-sys/src/value.rs b/rust/tvm-sys/src/value.rs
new file mode 100644 (file)
index 0000000..a9ad5f5
--- /dev/null
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::str::FromStr;
+
+use crate::ffi::*;
+
+use thiserror::Error;
+
+macro_rules! impl_pod_tvm_value {
+    ($field:ident, $field_ty:ty, $( $ty:ty ),+) => {
+        $(
+            impl From<$ty> for TVMValue {
+                fn from(val: $ty) -> Self {
+                    TVMValue { $field: val as $field_ty }
+                }
+            }
+
+            impl From<TVMValue> for $ty {
+                fn from(val: TVMValue) -> Self {
+                    unsafe { val.$field as $ty }
+                }
+            }
+        )+
+    };
+    ($field:ident, $ty:ty) => {
+        impl_pod_tvm_value!($field, $ty, $ty);
+    }
+}
+
+impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize);
+impl_pod_tvm_value!(v_float64, f64, f32, f64);
+impl_pod_tvm_value!(v_type, DLDataType);
+impl_pod_tvm_value!(v_ctx, TVMContext);
+
+#[derive(Debug, Error)]
+#[error("unsupported device: {0}")]
+pub struct UnsupportedDeviceError(String);
+
+macro_rules! impl_tvm_context {
+    ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => {
+        /// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev")
+        impl FromStr for TVMContext {
+            type Err = UnsupportedDeviceError;
+            fn from_str(type_str: &str) -> Result<Self, Self::Err> {
+                Ok(Self {
+                    device_type: match type_str {
+                         $( $(  stringify!($dev_name)  )|+ => $dev_type ),+,
+                        _ => return Err(UnsupportedDeviceError(type_str.to_string())),
+                    },
+                    device_id: 0,
+                })
+            }
+        }
+
+        impl TVMContext {
+            $(
+                $(
+                    pub fn $dev_name(device_id: usize) -> Self {
+                        Self {
+                            device_type: $dev_type,
+                            device_id: device_id as i32,
+                        }
+                    }
+                )+
+            )+
+        }
+    };
+}
+
+impl_tvm_context!(
+    DLDeviceType_kDLCPU: [cpu, llvm, stackvm],
+    DLDeviceType_kDLGPU: [gpu, cuda, nvptx],
+    DLDeviceType_kDLOpenCL: [cl],
+    DLDeviceType_kDLMetal: [metal],
+    DLDeviceType_kDLVPI: [vpi],
+    DLDeviceType_kDLROCM: [rocm],
+    DLDeviceType_kDLExtDev: [ext_dev]
+);