--- /dev/null
+/*
+ * Copyright (c) 2013 Nordic Semiconductor ASA
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification,
+ * are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this list
+ * of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form, except as embedded into a Nordic Semiconductor ASA
+ * integrated circuit in a product or a software update for such product, must reproduce
+ * the above copyright notice, this list of conditions and the following disclaimer in
+ * the documentation and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of Nordic Semiconductor ASA nor the names of its contributors may be
+ * used to endorse or promote products derived from this software without specific prior
+ * written permission.
+ *
+ * 4. This software, with or without modification, must only be used with a
+ * Nordic Semiconductor ASA integrated circuit.
+ *
+ * 5. Any software provided in binary or object form under this license must not be reverse
+ * engineered, decompiled, modified and/or disassembled.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+#include "crc16.h"
+
+#include <stdlib.h>
+
+uint16_t crc16_compute(uint8_t const* p_data, uint32_t size, uint16_t const* p_crc) {
+ uint16_t crc = (p_crc == NULL) ? 0xFFFF : *p_crc;
+
+ for (uint32_t i = 0; i < size; i++) {
+ crc = (uint8_t)(crc >> 8) | (crc << 8);
+ crc ^= p_data[i];
+ crc ^= (uint8_t)(crc & 0xFF) >> 4;
+ crc ^= (crc << 8) << 4;
+ crc ^= ((crc & 0xFF) << 4) << 1;
+ }
+
+ return crc;
+}
--- /dev/null
+/*
+ * Copyright (c) 2013 Nordic Semiconductor ASA
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification,
+ * are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this list
+ * of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form, except as embedded into a Nordic Semiconductor ASA
+ * integrated circuit in a product or a software update for such product, must reproduce
+ * the above copyright notice, this list of conditions and the following disclaimer in
+ * the documentation and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of Nordic Semiconductor ASA nor the names of its contributors may be
+ * used to endorse or promote products derived from this software without specific prior
+ * written permission.
+ *
+ * 4. This software, with or without modification, must only be used with a
+ * Nordic Semiconductor ASA integrated circuit.
+ *
+ * 5. Any software provided in binary or object form under this license must not be reverse
+ * engineered, decompiled, modified and/or disassembled.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+/** @file
+ *
+ * @defgroup crc_compute CRC compute
+ * @{
+ * @ingroup hci_transport
+ *
+ * @brief This module implements CRC-16-CCITT (polynomial 0x1021) with 0xFFFF initial value.
+ * The data can be passed in multiple blocks.
+ */
+
+#ifndef CRC16_H__
+#define CRC16_H__
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include <stdint.h>
+
+/**@brief Function for calculating CRC-16 in blocks.
+ *
+ * Feed each consecutive data block into this function, along with the current value of p_crc as
+ * returned by the previous call of this function. The first call of this function should pass NULL
+ * as the initial value of the crc in p_crc.
+ *
+ * @param[in] p_data The input data block for computation.
+ * @param[in] size The size of the input data block in bytes.
+ * @param[in] p_crc The previous calculated CRC-16 value or NULL if first call.
+ *
+ * @return The updated CRC-16 value, based on the input supplied.
+ */
+uint16_t crc16_compute(uint8_t const* p_data, uint32_t size, uint16_t const* p_crc);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // CRC16_H__
+
+/** @} */
tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF)
tvm_option(USE_RTTI "Build with RTTI" ON)
tvm_option(USE_MSVC_MT "Build with MT" OFF)
-tvm_option(USE_MICRO "Build with Micro" OFF)
+tvm_option(USE_MICRO "Build with Micro TVM support" OFF)
tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF)
tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF)
tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF)
tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF)
tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF)
-
# include directories
include_directories(${CMAKE_INCLUDE_PATH})
include_directories("include")
set(CMAKE_THREAD_PREFER_PTHREAD TRUE)
set(THREADS_PREFER_PTHREAD_FLAG TRUE)
find_package(Threads REQUIRED)
- target_link_libraries(tvm Threads::Threads)
- target_link_libraries(tvm_runtime Threads::Threads)
+ target_link_libraries(tvm PUBLIC Threads::Threads)
+ target_link_libraries(tvm_runtime PUBLIC Threads::Threads)
endif()
-target_link_libraries(tvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS})
-target_link_libraries(tvm_runtime ${TVM_RUNTIME_LINKER_LIBS})
+target_link_libraries(tvm PRIVATE ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS})
+target_link_libraries(tvm_runtime PRIVATE ${TVM_RUNTIME_LINKER_LIBS})
# Related headers
target_include_directories(
target_include_directories(
tvm_objs
PUBLIC "topi/include")
+set(CRC16_INCLUDE_PATH "3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/libraries/crc16")
+target_include_directorieS(
+ tvm_objs
+ PRIVATE "${CRC16_INCLUDE_PATH}")
+target_include_directorieS(
+ tvm_runtime_objs
+ PRIVATE "${CRC16_INCLUDE_PATH}")
set(TVM_TEST_LIBRARY_NAME tvm)
if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
add_library(tvm_allvisible SHARED $<TARGET_OBJECTS:tvm_objs>)
target_include_directories(tvm_allvisible PUBLIC "$<TARGET_PROPERTY:tvm,INCLUDE_DIRECTORIES>")
- target_link_libraries(tvm_allvisible PUBLIC "$<TARGET_PROPERTY:tvm,LINK_LIBRARIES>")
+ target_link_libraries(tvm_allvisible PRIVATE "$<TARGET_PROPERTY:tvm,LINK_LIBRARIES>")
set(TVM_TEST_LIBRARY_NAME tvm_allvisible)
-set(HIDE_SYMBOLS_LINKER_FLAGS "-Wl,--exclude-libs,ALL")
+ set(HIDE_SYMBOLS_LINKER_FLAGS "-Wl,--exclude-libs,ALL")
# Note: 'target_link_options' with 'PRIVATE' keyword would be cleaner
# but it's not available until CMake 3.13. Switch to 'target_link_options'
# once minimum CMake version is bumped up to 3.13 or above.
- target_link_libraries(tvm ${HIDE_SYMBOLS_LINKER_FLAGS})
- target_link_libraries(tvm_runtime ${HIDE_SYMBOLS_LINKER_FLAGS})
+ target_link_libraries(tvm PRIVATE ${HIDE_SYMBOLS_LINKER_FLAGS})
+ target_link_libraries(tvm_runtime PRIVATE ${HIDE_SYMBOLS_LINKER_FLAGS})
endif()
# Tests
3rdparty/bfloat16/bfloat16.cc
3rdparty/dlpack
3rdparty/dmlc-core
+3rdparty/mbed-os
BSD 2-clause License
# Setup build environment
TVM_ROOT=$(shell cd ../..; pwd)
-CRT_ROOT ?= ../../src/runtime/crt
+CRT_ROOT ?= ../../build/standalone_crt
+ifeq ($(shell ls -lhd $(CRT_ROOT)),)
+$(error "CRT not found. Ensure you have built the standalone_crt target and try again")
+endif
ENABLE_TVM_PLATFORM_ABORT_BACKTRACE ?= 1
QUIET ?= @
$(endif)
+CRT_SRCS = $(shell find $(CRT_ROOT))
demo_dynamic: $(build_dir)/demo_dynamic $(build_dir)/bundle.so $(build_dir)/bundle_c.so $(build_dir)/bundle.so $(build_dir)/graph_cpp.json $(build_dir)/graph_c.json $(build_dir)/params_cpp.bin $(build_dir)/params_c.bin $(build_dir)/cat.bin
$(QUIET)TVM_NUM_THREADS=1 $(build_dir)/demo_dynamic $(build_dir)/bundle.so $(build_dir)/graph_cpp.json $(build_dir)/params_cpp.bin $(build_dir)/cat.bin
test_static: $(build_dir)/test_static $(build_dir)/test_data_c.bin $(build_dir)/test_output_c.bin
$(QUIET)TVM_NUM_THREADS=1 $(build_dir)/test_static $(build_dir)/test_data_c.bin $(build_dir)/test_output_c.bin $(build_dir)/test_graph_c.json $(build_dir)/test_params_c.bin
-$(build_dir)/crt/graph_runtime/libgraph_runtime.a:
+$(build_dir)/crt/libgraph_runtime.a: $(CRT_SRCS)
$(QUIET)cd $(CRT_ROOT) && make QUIET= BUILD_DIR=$(abspath $(build_dir))/crt CRT_CONFIG=$(abspath crt_config/crt_config.h) "EXTRA_CFLAGS=$(PKG_COMPILE_OPTS)" graph_runtime
-$(build_dir)/crt/common/libcommon.a:
+$(build_dir)/crt/libcommon.a: $(CRT_SRCS)
$(QUIET)cd $(CRT_ROOT) && make QUIET= BUILD_DIR=$(abspath $(build_dir))/crt CRT_CONFIG=$(abspath crt_config/crt_config.h) "EXTRA_CFLAGS=$(PKG_COMPILE_OPTS)" common
$(build_dir)/demo_dynamic: demo.cc
$(QUIET)mkdir -p $(@D)
$(QUIET)g++ $(PKG_CXXFLAGS) -o $@ test.cc $(BACKTRACE_OBJS) $(BACKTRACE_LDFLAGS)
-$(build_dir)/demo_static: demo_static.c ${build_dir}/bundle_static.o ${build_dir}/model_c.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a $(BACKTRACE_OBJS)
+$(build_dir)/demo_static: demo_static.c ${build_dir}/bundle_static.o ${build_dir}/model_c.o ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a ${build_dir}/graph_c.json.c ${build_dir}/params_c.bin.c $(BACKTRACE_OBJS)
$(QUIET)mkdir -p $(@D)
$(QUIET)gcc $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS)
-$(build_dir)/test_static: test_static.c ${build_dir}/bundle_static.o ${build_dir}/test_model_c.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a $(BACKTRACE_OBJS)
+$(build_dir)/test_static: test_static.c ${build_dir}/bundle_static.o ${build_dir}/test_model_c.o ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a $(BACKTRACE_OBJS)
$(QUIET)mkdir -p $(@D)
$(QUIET)gcc $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_LDFLAGS)
$(QUIET)mkdir -p $(@D)
$(QUIET)g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS)
-$(build_dir)/bundle_c.so: bundle.c $(build_dir)/model_c.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a $(BACKTRACE_OBJS)
+$(build_dir)/bundle_c.so: bundle.c $(build_dir)/model_c.o ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a $(BACKTRACE_OBJS)
$(QUIET)mkdir -p $(@D)
$(QUIET)gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS)
$(QUIET)mkdir -p $(@D)
$(QUIET)g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS)
-$(build_dir)/test_bundle_c.so: bundle.c $(build_dir)/test_model_c.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a $(BACKTRACE_OBJS)
+$(build_dir)/test_bundle_c.so: bundle.c $(build_dir)/test_model_c.o ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a $(BACKTRACE_OBJS)
$(QUIET)mkdir -p $(@D)
$(QUIET)gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS)
* under the License.
*/
+#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <tvm/runtime/c_runtime_api.h>
#include "backtrace.h"
#endif
+#define CRT_MEMORY_NUM_PAGES 16384
+#define CRT_MEMORY_PAGE_SIZE_LOG2 10
+
+static uint8_t g_crt_memory[CRT_MEMORY_NUM_PAGES * (1 << CRT_MEMORY_PAGE_SIZE_LOG2)];
+
/*! \brief macro to do C API call */
#define TVM_CCALL(func) \
do { \
ctx.device_id = device_id;
// declare pointers
- TVM_CCALL(TVMInitializeRuntime());
+ TVM_CCALL(TVMInitializeRuntime(g_crt_memory, sizeof(g_crt_memory), CRT_MEMORY_PAGE_SIZE_LOG2));
TVMPackedFunc pf;
TVMArgs args = TVMArgs_Create(NULL, NULL, 0);
TVM_CCALL(TVMPackedFunc_InitGlobalFunc(&pf, "runtime.SystemLib", &args));
TVMGraphRuntime_GetOutput(graph_runtime, index, tensor);
}
-void __attribute__((noreturn)) TVMPlatformAbort(int error_code) {
+void TVMLogf(const char* msg, ...) {
+ va_list args;
+ va_start(args, msg);
+ vfprintf(stderr, msg, args);
+ va_end(args);
+}
+
+void __attribute__((noreturn)) TVMPlatformAbort(tvm_crt_error_t error_code) {
fprintf(stderr, "TVMPlatformAbort: %d\n", error_code);
#ifdef ENABLE_TVM_ABORT_BACKTRACE
tvm_platform_abort_backtrace();
* under the License.
*/
+#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <tvm/runtime/crt/crt.h>
#endif
#include "bundle.h"
+#define CRT_MEMORY_NUM_PAGES 16384
+#define CRT_MEMORY_PAGE_SIZE_LOG2 10
+
+static uint8_t g_crt_memory[CRT_MEMORY_NUM_PAGES * (1 << CRT_MEMORY_PAGE_SIZE_LOG2)];
+
/*! \brief macro to do C API call */
#define TVM_CCALL(func) \
do { \
ctx.device_id = device_id;
// get pointers
- TVM_CCALL(TVMInitializeRuntime());
+ TVM_CCALL(TVMInitializeRuntime(g_crt_memory, sizeof(g_crt_memory), CRT_MEMORY_PAGE_SIZE_LOG2));
TVMPackedFunc pf;
TVMArgs args = TVMArgs_Create(NULL, NULL, 0);
TVM_CCALL(TVMPackedFunc_InitGlobalFunc(&pf, "runtime.SystemLib", &args));
TVMGraphRuntime_GetOutput(graph_runtime, index, tensor);
}
-void __attribute__((noreturn)) TVMPlatformAbort(int error_code) {
+void TVMLogf(const char* msg, ...) {
+ va_list args;
+ va_start(args, msg);
+ vfprintf(stderr, msg, args);
+ va_end(args);
+}
+
+void __attribute__((noreturn)) TVMPlatformAbort(tvm_crt_error_t error_code) {
fprintf(stderr, "TVMPlatformAbort: %d\n", error_code);
#ifdef ENABLE_TVM_PLATFORM_ABORT_BACKTRACE
tvm_platform_abort_backtrace();
#ifndef TVM_RUNTIME_CRT_CONFIG_H_
#define TVM_RUNTIME_CRT_CONFIG_H_
+/*! Log level of the CRT runtime */
+#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG
+
/*! Support low-level debugging in MISRA-C runtime */
#define TVM_CRT_DEBUG 0
/*! Maximum supported string length in function names */
#define TVM_CRT_STRLEN_NAME 80
-/*!
- * \brief Log memory pool size for virtual memory allocation
- *
- * Here is a list of possible choices:
- * * use 16 for 64 KiB memory space
- * * use 17 for 128 KiB memory space
- * * use 18 for 256 KiB memory space
- * * use 19 for 512 KiB memory space
- * * use 20 for 1 MiB memory space
- * * use 21 for 2 MiB memory space
- * * use 22 for 4 MiB memory space
- * * use 23 for 8 MiB memory space
- * * use 24 for 16 MiB memory space
- * * use 25 for 32 MiB memory space
- * * use 26 for 64 MiB memory space
- * * use 27 for 128 MiB memory space
- * * use 28 for 256 MiB memory space
- */
-#define TVM_CRT_LOG_VIRT_MEM_SIZE 24
-
-/*! \brief Page size for virtual memory allocation */
-#define TVM_CRT_PAGE_BYTES_LOG 12
-
/*! Maximum number of registered modules. */
#define TVM_CRT_MAX_REGISTERED_MODULES 2
#include <sys/time.h>
#include <tvm/runtime/c_runtime_api.h>
-#include "build/graph_c.json.c"
-#include "build/params_c.bin.c"
#include "bundle.h"
+extern const char build_graph_c_json[];
+extern unsigned int build_graph_c_json_len;
+
+extern const char build_params_c_bin[];
+extern unsigned int build_params_c_bin_len;
+
#define OUTPUT_LEN 1000
int main(int argc, char** argv) {
char* data = (char*)malloc(st.st_size);
FILE* fp = fopen(file_path, "rb");
+ size_t bytes_to_read = st.st_size;
size_t bytes_read = 0;
- while (bytes_read < st.st_size) {
+ while (bytes_read < bytes_to_read) {
size_t this_round = fread(data, 1, st.st_size, fp);
if (this_round == 0) {
if (ferror(fp)) {
struct timeval t0, t1, t2, t3, t4, t5;
gettimeofday(&t0, 0);
- auto* handle = tvm_runtime_create(json_data, params_data, params_size, argv[0]);
+ void* handle = tvm_runtime_create(json_data, params_data, params_size, argv[0]);
gettimeofday(&t1, 0);
float input_storage[10 * 5];
# specific language governing permissions and limitations
# under the License.
-if(USE_STANDALONE_CRT)
- include(ExternalProject)
-
- message(STATUS "Build with standalone CRT")
+if(USE_MICRO)
+ message(STATUS "Build standalone CRT for micro TVM")
file(GLOB crt_srcs src/runtime/crt/**)
function(tvm_crt_add_copy_file var src dest)
set("${var}" "${${var}}" PARENT_SCOPE)
endfunction(tvm_crt_add_copy_file)
- # Build an isolated build directory, separate from the TVM tree.
- file(GLOB_RECURSE crt_srcs
- RELATIVE "${CMAKE_SOURCE_DIR}/src/runtime/crt"
- "${CMAKE_SOURCE_DIR}/src/runtime/crt/common/*.c"
- "${CMAKE_SOURCE_DIR}/src/runtime/crt/graph_runtime/*.c"
- "${CMAKE_SOURCE_DIR}/src/runtime/crt/include/*.h")
-
- foreach(src IN LISTS crt_srcs)
- tvm_crt_add_copy_file(host_isolated_build_deps ${CMAKE_SOURCE_DIR}/src/runtime/crt/${src} standalone_crt/${src})
- endforeach()
-
- file(GLOB_RECURSE crt_headers RELATIVE "${CMAKE_SOURCE_DIR}/include" include/tvm/runtime/crt/*.h)
- foreach(hdr IN LISTS crt_headers)
- tvm_crt_add_copy_file(host_isolated_build_deps ${CMAKE_SOURCE_DIR}/include/${hdr} standalone_crt/include/${hdr})
- endforeach()
-
- tvm_crt_add_copy_file(host_isolated_build_deps
- ${CMAKE_SOURCE_DIR}/include/tvm/runtime/c_runtime_api.h standalone_crt/include/tvm/runtime/c_runtime_api.h)
- tvm_crt_add_copy_file(host_isolated_build_deps
- ${CMAKE_SOURCE_DIR}/include/tvm/runtime/c_backend_api.h standalone_crt/include/tvm/runtime/c_backend_api.h)
- tvm_crt_add_copy_file(host_isolated_build_deps
- ${CMAKE_SOURCE_DIR}/src/runtime/crt/Makefile standalone_crt/Makefile)
-
- get_filename_component(crt_config_abspath src/runtime/crt/host/crt_config.h ABSOLUTE)
- list(APPEND host_isolated_build_deps src/runtime/crt/host/crt_config.h)
- add_custom_target(standalone_crt DEPENDS ${host_isolated_build_deps})
-
- get_filename_component(host_build_dir_abspath "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt" ABSOLUTE)
-
- if(${VERBOSE})
- set(make_quiet QUIET=)
- else(${VERBOSE})
- set(make_quiet )
- endif(${VERBOSE})
-
- ExternalProject_Add(host_standalone_crt
- DOWNLOAD_COMMAND ""
- SOURCE_DIR standalone_crt
- CONFIGURE_COMMAND ""
- BUILD_COMMAND make
- DLPACK_INCLUDE_DIR=${CMAKE_SOURCE_DIR}/3rdparty/dlpack/include
- TVM_INCLUDE_DIR=${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include
- CRT_CONFIG=${crt_config_abspath}
- BUILD_DIR=${host_build_dir_abspath} all ${make_quiet}
- BUILD_IN_SOURCE ON
- WORKING_DIRECTORY standalone_crt
- COMMENT "Building host CRT runtime"
- BUILD_BYPRODUCTS host_standalone_crt/common/libcommon.a host_standalone_crt/graph_runtime/libgraph_runtime.a
- DEPENDS standalone_crt
- INSTALL_COMMAND ""
- )
- ExternalProject_Add_StepDependencies(host_standalone_crt build ${host_isolated_build_deps})
-# add_custom_command(
-# OUTPUT host_standalone_crt/common/libcommon.a host_standalone_crt/graph_runtime/libgraph_runtime.a
-# COMMAND make
-# DLPACK_INCLUDE_DIR=${CMAKE_SOURCE_DIR}/3rdparty/dlpack/include
-# TVM_INCLUDE_DIR=${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include
-# CRT_CONFIG=${crt_config_abspath}
-# BUILD_DIR=${host_build_dir_abspath} all ${make_quiet}
-# WORKING_DIRECTORY standalone_crt
-# DEPENDS ${host_isolated_build_deps})
-# add_custom_target(host_standalone_crt DEPENDS host_standalone_crt/common/libcommon.a host_standalone_crt/graph_runtime/libgraph_runtime.a)
-
-# # add_custom_target(host_standalone_crt ALL
-# # DEPENDS host_standalone_crt/common/libcommon.a host_standalone_crt/graph_runtime/libgraph_runtime.a)
- add_library(host_standalone_crt_common STATIC IMPORTED GLOBAL)
- add_dependencies(host_standalone_crt_common host_standalone_crt)
- set_target_properties(host_standalone_crt_common PROPERTIES
- IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/common/libcommon.a"
- IMPORTED_OBJECTS "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/common/libcommon.a"
- PUBLIC_HEADER "${crt_headers}")
-# add_dependencies(host_standalone_crt_common host_standalone_crt)
-# # ${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/common/libcommon.a)
-
- add_library(host_standalone_crt_graph_runtime STATIC IMPORTED GLOBAL)
- add_dependencies(host_standalone_crt_graph_runtime host_standalone_crt)
- set_target_properties(host_standalone_crt_graph_runtime PROPERTIES
- IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/graph_runtime/libgraph_runtime.a"
- IMPORTED_OBJECTS "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/graph_runtime/libgraph_runtime.a"
- PUBLIC_HEADER "${crt_headers}")
-# add_dependencies(host_standalone_crt_graph_runtime host_standalone_crt)
-# # ${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/graph_runtime/libgraph_runtime.a)
-
- # Standalone CRT tests
- file(GLOB TEST_SRCS ${CMAKE_SOURCE_DIR}/tests/crt/*.cc)
- find_path(GTEST_INCLUDE_DIR gtest/gtest.h)
- find_library(GTEST_LIB gtest "$ENV{GTEST_LIB}")
-
- # Create the `crttest` target if we can find GTest. If not, we create dummy
- # targets that give the user an informative error message.
- if(GTEST_INCLUDE_DIR AND GTEST_LIB)
- foreach(__srcpath ${TEST_SRCS})
- get_filename_component(__srcname ${__srcpath} NAME)
- string(REPLACE ".cc" "" __execname ${__srcname})
- add_executable(${__execname} ${__srcpath})
- list(APPEND TEST_EXECS ${__execname})
- target_include_directories(${__execname} PUBLIC ${GTEST_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include ${CMAKE_SOURCE_DIR}/src/runtime/crt/host)
- target_compile_options(${__execname} PRIVATE -pthread)
-# target_link_directories(${__execname} PRIVATE
-# ${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/common
-# ${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/graph_runtime)
- target_link_libraries(${__execname} host_standalone_crt_graph_runtime host_standalone_crt_common ${GTEST_LIB} pthread)
- set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1)
- set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1)
+ function(tvm_crt_define_targets)
+ # Build an isolated build directory, separate from the TVM tree.
+ set(CRC16_PATH "3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/libraries/crc16")
+ list(APPEND CRT_FILE_COPY_JOBS
+ "${CRC16_PATH} *.h -> include *.c -> src/runtime/crt/utvm_rpc_common"
+ "3rdparty/dlpack/include *.h -> include"
+ "3rdparty/dmlc-core/include *.h -> include"
+ "include/tvm/runtime c_*_api.h -> include/tvm/runtime"
+ "include/tvm/runtime/crt *.h -> include/tvm/runtime/crt"
+ "src/runtime/crt Makefile -> ."
+ "src/runtime/crt/include *.h -> include"
+ "src/runtime/crt/common *.c -> src/runtime/crt/common"
+ "src/runtime/crt/graph_runtime *.c -> src/runtime/crt/graph_runtime"
+ "src/runtime/crt/host crt_config.h -> src/runtime/crt/host"
+ "src/runtime/crt/utvm_rpc_common *.cc -> src/runtime/crt/utvm_rpc_common"
+ "src/runtime/crt/utvm_rpc_server *.cc -> src/runtime/crt/utvm_rpc_server"
+ "src/runtime/minrpc *.h -> src/runtime/minrpc"
+ "src/support generic_arena.h -> src/support"
+ )
+
+ set(standalone_crt_base "${CMAKE_CURRENT_BINARY_DIR}/standalone_crt")
+
+ foreach(job_spec IN LISTS CRT_FILE_COPY_JOBS)
+ string(REPLACE " " ";" job_spec "${job_spec}")
+ list(LENGTH job_spec job_spec_length)
+ math(EXPR job_spec_length_mod "${job_spec_length} % 3")
+ if(NOT "${job_spec_length_mod}" EQUAL 1)
+ message(FATAL_ERROR "CRT copy job spec list length is ${job_spec_length}; parsed job spec is ${job_spec}")
+ endif()
+ math(EXPR job_spec_stop "${job_spec_length} - 3")
+
+ list(GET job_spec 0 job_src_base)
+ set(job_src_base "${CMAKE_SOURCE_DIR}/${job_src_base}")
+ foreach(copy_pattern_index RANGE 1 "${job_spec_stop}" 3)
+ list(GET job_spec ${copy_pattern_index} copy_pattern)
+ math(EXPR copy_dest_index "${copy_pattern_index} + 2")
+ list(GET job_spec ${copy_dest_index} copy_dest)
+
+ file(GLOB_RECURSE copy_files
+ RELATIVE "${job_src_base}"
+ "${job_src_base}/${copy_pattern}")
+ list(LENGTH copy_files copy_files_length)
+ if("${copy_files_length}" EQUAL 0)
+ message(FATAL_ERROR "CRT copy job matched 0 files: ${job_src_base}/${copy_pattern} -> ${copy_dest}")
+ endif()
+ foreach(copy_src IN LISTS copy_files)
+ get_filename_component(dest_path "${standalone_crt_base}/${copy_dest}/${copy_src}" ABSOLUTE)
+ tvm_crt_add_copy_file(host_isolated_build_deps ${job_src_base}/${copy_src} ${dest_path})
+ endforeach()
+ endforeach()
+ endforeach()
+
+ add_custom_target(standalone_crt DEPENDS ${host_isolated_build_deps})
+
+ get_filename_component(host_build_dir_abspath "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt" ABSOLUTE)
+
+ if(${VERBOSE})
+ set(make_quiet QUIET=)
+ else(${VERBOSE})
+ set(make_quiet )
+ endif(${VERBOSE})
+
+ list(APPEND crt_libraries graph_runtime utvm_rpc_server utvm_rpc_common common) # NOTE: listed in link order.
+ foreach(crt_lib_name IN LISTS crt_libraries)
+ list(APPEND crt_library_paths "host_standalone_crt/lib${crt_lib_name}.a")
endforeach()
- add_custom_target(crttest DEPENDS ${TEST_EXECS})
- elseif(NOT GTEST_INCLUDE_DIR)
- add_custom_target(crttest
- COMMAND echo "Missing Google Test headers in include path"
- COMMAND exit 1)
- elseif(NOT GTEST_LIB)
- add_custom_target(crttest
- COMMAND echo "Missing Google Test library"
- COMMAND exit 1)
+
+ set(make_common_args
+ "DLPACK_INCLUDE_DIR=${CMAKE_SOURCE_DIR}/3rdparty/dlpack/include"
+ "TVM_INCLUDE_DIR=${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include"
+ "CRT_CONFIG=src/runtime/crt/host/crt_config.h"
+ "BUILD_DIR=${host_build_dir_abspath}"
+ "EXTRA_CFLAGS=-fPIC"
+ "EXTRA_CXXFLAGS=-fPIC"
+ "EXTRA_LDFLAGS=-fPIC"
+ "${make_quiet}")
+
+ add_custom_command(
+ OUTPUT ${crt_library_paths}
+ COMMAND make ARGS ${make_common_args} clean
+ COMMAND make ARGS ${make_common_args} all
+ WORKING_DIRECTORY "${standalone_crt_base}"
+ DEPENDS standalone_crt ${host_isolated_build_deps})
+
+ add_custom_target(host_standalone_crt DEPENDS ${crt_library_paths})
+
+ foreach(crt_lib IN LISTS crt_libraries)
+ set(cmake_crt_lib_name host_standalone_crt_${crt_lib})
+ list(APPEND cmake_crt_libraries ${cmake_crt_lib_name})
+ add_library(${cmake_crt_lib_name} STATIC IMPORTED GLOBAL)
+ set(cmake_crt_lib_path "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/lib${crt_lib}.a")
+ add_dependencies(${cmake_crt_lib_name} host_standalone_crt "${cmake_crt_lib_path}")
+ set_target_properties(${cmake_crt_lib_name} PROPERTIES
+ IMPORTED_LOCATION "${cmake_crt_lib_path}"
+ IMPORTED_OBJECTS "${cmake_crt_lib_path}"
+ PUBLIC_HEADER "${crt_headers}")
+ endforeach()
+
+ # Standalone CRT tests
+ file(GLOB TEST_SRCS ${CMAKE_SOURCE_DIR}/tests/crt/*_test.cc)
+ find_path(GTEST_INCLUDE_DIR gtest/gtest.h)
+ find_library(GTEST_LIB gtest "$ENV{GTEST_LIB}")
+
+ # Create the `crttest` target if we can find GTest. If not, we create dummy
+ # targets that give the user an informative error message.
+ if(GTEST_INCLUDE_DIR AND GTEST_LIB)
+ foreach(__srcpath ${TEST_SRCS})
+ get_filename_component(__srcname ${__srcpath} NAME)
+ string(REPLACE ".cc" "" __execname ${__srcname})
+ add_executable(${__execname} ${__srcpath})
+ list(APPEND TEST_EXECS ${__execname})
+ target_include_directories(${__execname} PUBLIC ${GTEST_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include ${CMAKE_SOURCE_DIR}/src/runtime/crt/host)
+ target_compile_options(${__execname} PRIVATE -pthread)
+ target_link_libraries(${__execname} ${cmake_crt_libraries} ${GTEST_LIB} pthread)
+ set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1)
+ set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1)
+ endforeach()
+ add_custom_target(crttest DEPENDS ${TEST_EXECS})
+ elseif(NOT GTEST_INCLUDE_DIR)
+ add_custom_target(crttest
+ COMMAND echo "Missing Google Test headers in include path"
+ COMMAND exit 1)
+ elseif(NOT GTEST_LIB)
+ add_custom_target(crttest
+ COMMAND echo "Missing Google Test library"
+ COMMAND exit 1)
+ endif()
+
+ endfunction()
+
+ tvm_crt_define_targets()
+
+ set(TVM_CRT_LINKER_LIB host_standalone_crt_utvm_rpc_common)
+ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
+ list(APPEND TVM_RUNTIME_LINKER_LIBS -Wl,--whole-archive ${TVM_CRT_LINKER_LIB} -Wl,--no-whole-archive)
+ elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES ".*Clang")
+ list(APPEND TVM_RUNTIME_LINKER_LIBS -Wl,-force_load $<TARGET_PROPERTY:${TVM_CRT_LINKER_LIB},IMPORTED_LOCATION>)
+ else()
+ list(APPEND TVM_RUNTIME_LINKER_LIBS ${TVM_CRT_LINKER_LIB})
endif()
-endif(USE_STANDALONE_CRT)
+endif(USE_MICRO)
#ifndef TVM_RUNTIME_CRT_CRT_H_
#define TVM_RUNTIME_CRT_CRT_H_
+#include <inttypes.h>
#include <tvm/runtime/crt/error_codes.h>
#ifdef __cplusplus
/*!
* \brief Initialize various data structures used by the rutnime.
+ * \param memory_pool Pointer to the global memory pool used by the CRT.
+ * \param memory_pool_size_bytes Size of `memory_pool`, in bytes.
+ * \param page_size_bytes_log2 log2 of the page size, in bytes.
* \return An error code describing the outcome of intialization. Generally, initialization
* is only expected to fail due to a misconfiguration.
*/
-tvm_crt_error_t TVMInitializeRuntime(void);
+tvm_crt_error_t TVMInitializeRuntime(uint8_t* memory_pool, size_t memory_pool_size_bytes,
+ size_t page_size_bytes_log2);
#ifdef __cplusplus
} // extern "C"
#define DEFINE_TVM_CRT_ERROR(category, code) \
(((category) << TVM_CRT_ERROR_CATEGORY_Pos) | ((code) << TVM_CRT_ERROR_CODE_Pos))
-typedef enum { kTvmErrorCategoryFunctionRegistry = 1 } tvm_crt_error_category_t;
+typedef enum {
+ kTvmErrorCategoryFunctionRegistry = 1,
+ kTvmErrorCategoryFraming = 2,
+ kTvmErrorCategoryWriteStream = 3,
+ kTvmErrorCategorySession = 4,
+ kTvmErrorCategoryPlatform = 5,
+} tvm_crt_error_category_t;
typedef enum {
kTvmErrorNoError = 0,
kTvmErrorFunctionRegistryFull = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionRegistry, 2),
kTvmErrorFunctionAlreadyDefined = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionRegistry, 3),
kTvmErrorBufferTooSmall = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionRegistry, 4),
+
+ // Framing
+ kTvmErrorFramingInvalidState = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFraming, 0),
+ kTvmErrorFramingShortPacket = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFraming, 1),
+ kTvmErrorFramingInvalidEscape = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFraming, 2),
+ kTvmErrorFramingPayloadOverflow = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFraming, 3),
+ kTvmErrorFramingPayloadIncomplete = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFraming, 4),
+
+ // Write stream
+ kTvmErrorWriteStreamShortWrite = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryWriteStream, 0),
+ kTvmErrorWriteStreamLongWrite = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryWriteStream, 1),
+
+ // Session
+ kTvmErrorSessionInvalidState = DEFINE_TVM_CRT_ERROR(kTvmErrorCategorySession, 0),
+ kTvmErrorSessionReceiveBufferBusy = DEFINE_TVM_CRT_ERROR(kTvmErrorCategorySession, 1),
+ kTvmErrorSessionReceiveBufferShortWrite = DEFINE_TVM_CRT_ERROR(kTvmErrorCategorySession, 2),
+
+ // Platform
+ kTvmErrorPlatformCheckFailure = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 0),
+ kTvmErrorPlatformMemoryManagerInitialized = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 1),
+
+ // System errors are always negative integers; this mask indicates presence of a system error.
+ // Cast tvm_crt_error_t to a signed integer to interpret the negative error code.
+ kTvmErrorSystemErrorMask = (1 << (sizeof(int) * 4 - 1)),
} tvm_crt_error_t;
#ifdef __cplusplus
*/
/*!
- * \file runtime/crt/include/tvm/runtime/crt/internal/common/logging.h
+ * \file runtime/crt/logging.h
* \brief A replacement of the dmlc logging system that avoids
* the usage of GLOG and C++ headers
*/
-#ifndef TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_COMMON_LOGGING_H_
-#define TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_COMMON_LOGGING_H_
+#ifndef TVM_RUNTIME_CRT_LOGGING_H_
+#define TVM_RUNTIME_CRT_LOGGING_H_
+
+#include <tvm/runtime/crt/platform.h>
+
+#define TVM_CRT_LOG_LEVEL_DEBUG 3
+#define TVM_CRT_LOG_LEVEL_INFO 2
+#define TVM_CRT_LOG_LEVEL_WARN 1
+#define TVM_CRT_LOG_LEVEL_ERROR 0
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void __attribute__((format(printf, 1, 2))) TVMLogf(const char* fmt, ...);
+
+#define LOG(level, x, ...) \
+ if (TVM_CRT_LOG_LEVEL >= level) { \
+ TVMLogf(x, ##__VA_ARGS__); \
+ }
+
+#define LOG_ERROR(x, ...) LOG(TVM_CRT_LOG_LEVEL_ERROR, x, ##__VA_ARGS__)
+#define LOG_WARN(x, ...) LOG(TVM_CRT_LOG_LEVEL_WARN, x, ##__VA_ARGS__)
+#define LOG_INFO(x, ...) LOG(TVM_CRT_LOG_LEVEL_INFO, x, ##__VA_ARGS__)
+#define LOG_DEBUG(x, ...) LOG(TVM_CRT_LOG_LEVEL_DEBUG, x, ##__VA_ARGS__)
#ifndef CHECK
-#define CHECK(x) \
- do { \
- if (!(x)) { \
- fprintf(stderr, "Check failed: %s\n", #x); \
- exit(-1); \
- } \
+#define CHECK(x) \
+ do { \
+ if (!(x)) { \
+ LOG_ERROR(__FILE__ ":%d: Check failed: %s\n", __LINE__, #x); \
+ TVMPlatformAbort(kTvmErrorPlatformCheckFailure); \
+ } \
} while (0)
#endif
#ifndef CHECK_BINARY_OP
-#define CHECK_BINARY_OP(op, x, y, fmt, ...) \
- do { \
- if (!(x op y)) { \
- fprintf(stderr, "Check failed: %s %s %s: " fmt "\n", #x, #op, #y, ##__VA_ARGS__); \
- exit(-1); \
- } \
+#define CHECK_BINARY_OP(op, x, y, fmt, ...) \
+ do { \
+ if (!(x op y)) { \
+ LOG_ERROR(__FILE__ ":%d: Check failed: %s %s %s: " fmt "\n", __LINE__, #x, #op, #y, \
+ ##__VA_ARGS__); \
+ TVMPlatformAbort(kTvmErrorPlatformCheckFailure); \
+ } \
} while (0)
#endif
#define CHECK_NE(x, y, fmt, ...) CHECK_BINARY_OP(!=, x, y, fmt, ##__VA_ARGS__)
#endif
-#endif // TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_COMMON_LOGGING_H_
+#ifdef __cplusplus
+} // extern "C"
+#endif
+
+#endif // TVM_RUNTIME_CRT_LOGGING_H_
#ifndef TVM_RUNTIME_CRT_PLATFORM_H_
#define TVM_RUNTIME_CRT_PLATFORM_H_
+#include <tvm/runtime/crt/error_codes.h>
+
#ifdef __cplusplus
extern "C" {
#endif
*
* \param code An error code.
*/
-void __attribute__((noreturn)) TVMPlatformAbort(int code);
+void __attribute__((noreturn)) TVMPlatformAbort(tvm_crt_error_t code);
#ifdef __cplusplus
} // extern "C"
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/crt/rpc_common/frame_buffer.h
+ * \brief Defines a buffer for use by the RPC framing layer.
+ */
+
+#ifndef TVM_RUNTIME_CRT_RPC_COMMON_FRAME_BUFFER_H_
+#define TVM_RUNTIME_CRT_RPC_COMMON_FRAME_BUFFER_H_
+
+#include <inttypes.h>
+#include <stdlib.h>
+
+namespace tvm {
+namespace runtime {
+namespace micro_rpc {
+
+class FrameBuffer {
+ public:
+ FrameBuffer(uint8_t* data, size_t data_size_bytes)
+ : data_{data}, capacity_{data_size_bytes}, num_valid_bytes_{0}, read_cursor_{0} {}
+
+ size_t Write(const uint8_t* data, size_t data_size_bytes);
+
+ size_t Read(uint8_t* data, size_t data_size_bytes);
+
+ size_t Peek(uint8_t* data, size_t data_size_bytes);
+
+ void Clear();
+
+ size_t ReadAvailable() const { return num_valid_bytes_ - read_cursor_; }
+
+ size_t Size() const { return num_valid_bytes_; }
+
+ private:
+ /*! \brief pointer to data buffer. */
+ uint8_t* data_;
+
+ /*! \brief The total number of bytes available in data_. Always a power of 2. */
+ size_t capacity_;
+
+ /*! \brief index into data_ of the next potentially-available byte in the buffer.
+ * The byte is available when tail_ != data_ + capacity_.
+ */
+ size_t num_valid_bytes_;
+
+ /*! \brief Read cursor position. */
+ size_t read_cursor_;
+};
+
+} // namespace micro_rpc
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_CRT_RPC_COMMON_FRAME_BUFFER_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file framing.h
+ * \brief Framing for RPC.
+ */
+
+#ifndef TVM_RUNTIME_CRT_RPC_COMMON_FRAMING_H_
+#define TVM_RUNTIME_CRT_RPC_COMMON_FRAMING_H_
+
+#include <crc16.h>
+#include <inttypes.h>
+#include <stddef.h>
+#include <tvm/runtime/crt/error_codes.h>
+#include <tvm/runtime/crt/rpc_common/write_stream.h>
+
+namespace tvm {
+namespace runtime {
+namespace micro_rpc {
+
+enum class Escape : uint8_t { kEscapeStart = 0xff, kEscapeNop = 0xfe, kPacketStart = 0xfd };
+
+class PacketFieldSizeBytes {
+ public:
+ static constexpr const size_t kPayloadLength = sizeof(uint32_t);
+ static constexpr const size_t kCrc = sizeof(uint16_t);
+};
+
+class Unframer {
+ public:
+ explicit Unframer(WriteStream* stream)
+ : stream_{stream},
+ state_{State::kFindPacketStart},
+ saw_escape_start_{false},
+ num_buffer_bytes_valid_{0} {}
+
+ /*!
+ * \brief Push data into unframer and try to decode one packet.
+ *
+ * This function will return when exactly one packet has been decoded. It may not consume all of
+ * `data` in this case, and valid bytes may remain at the end of data.
+ *
+ * \param data The new data to unframe and send downstream.
+ * \param data_size_bytes The number of valid bytes in data.
+ * \param bytes_consumed Pointer written with the number of bytes consumed from data.
+ * \return
+ * - kTvmErrorNoError when successful -- continue writing data.
+ * - kTvmErrorFramingInvalidState when the Unframer was in or enters an invalid state
+ * (probably indicates memory corruption).
+ * - kTvmErrorFramingShortPacket when a new packet started before the current one ended.
+ * - kTvmErrorFramingInvalidEscape when an invalid escape sequence was seen
+ */
+ tvm_crt_error_t Write(const uint8_t* data, size_t data_size_bytes, size_t* bytes_consumed);
+
+ /*! \brief Reset unframer to initial state. */
+ void Reset();
+
+ /*! \brief Return an underestimate of the number of bytes needed from the wire. */
+ size_t BytesNeeded();
+
+ private:
+ tvm_crt_error_t FindPacketStart();
+ tvm_crt_error_t FindPacketLength();
+ tvm_crt_error_t FindPacketCrc();
+ tvm_crt_error_t FindCrcEnd();
+
+ bool IsBufferFull(size_t buffer_full_bytes) {
+ return num_buffer_bytes_valid_ >= buffer_full_bytes;
+ }
+
+ /*! \brief Consume input into buffer_ until buffer_ has buffer_full_bytes. */
+ tvm_crt_error_t AddToBuffer(size_t buffer_full_bytes, bool update_crc);
+
+ void ClearBuffer();
+
+ /*! \brief Unescape and consume input bytes, storing into buffer.
+ *
+ * \param buffer A buffer to fill with consumed, unescaped bytes.
+ * \param buffer_size_bytes Size of buffer, in bytes.
+ * \param bytes_filled A pointer to an accumulator to which is added the number of bytes written
+ * to `buffer`.
+ * \param update_crc true when the CRC should be updated with the escaped bytes.
+ * \return
+ * - kTvmErrorNoError if successful
+ * - kTvmErrorFramingShortPacket if a start-of-packet escape code was encountered. If so,
+ * *bytes_filled indicates the number of bytes before the Escape::kEscapeStart byte.
+ * - kTvmErrorFramingInvalidEscape if an invalid escape sequence was seen.
+ * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write()
+ * function returns 0.
+ * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write()
+ * function returns an invalid positive number.
+ * - Any negative value (i.e. with bits in kTvmErrorSystemErrorMask set) returned by the
+ * WriteStream's Write() function.
+ */
+ tvm_crt_error_t ConsumeInput(uint8_t* buffer, size_t buffer_size_bytes, size_t* bytes_filled,
+ bool update_crc);
+
+ WriteStream* stream_;
+
+ enum class State : uint8_t {
+ kFindPacketStart = 0,
+ kFindPacketLength = 1,
+ kFindPacketCrc = 2,
+ kFindCrcEnd = 3,
+ };
+ State state_;
+
+ const uint8_t* input_;
+ size_t input_size_bytes_;
+
+ bool saw_escape_start_;
+
+ /*! \brief unframe buffer, sized to the longest framing field. */
+ uint8_t buffer_[128];
+
+ /*! \brief number of bytes in buffer that are currently valid. */
+ size_t num_buffer_bytes_valid_;
+
+ /*! \brief number of payload bytes left to write before the CRC begins. */
+ size_t num_payload_bytes_remaining_;
+
+ /*! \brief Running CRC value. */
+ uint16_t crc_;
+};
+
+class Framer {
+ public:
+ typedef ssize_t (*WriteFunc)(const uint8_t* data, size_t data_size_bytes);
+
+ explicit Framer(WriteStream* stream)
+ : stream_{stream}, state_{State::kReset}, num_payload_bytes_remaining_{0} {}
+
+ /*! \brief Frame and write a full packet.
+ * \param payload The entire packet payload.
+ * \param payload_size_bytes Number of bytes in the packet.
+ * \return
+ * - kTvmErrorNoError when no error occurs
+ * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write()
+ * function returns 0.
+ * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write()
+ * function returns an invalid positive number.
+ * - Any negative value (i.e. with bits in kTvmErrorSystemErrorMask set) returned by the
+ * WriteStream's Write() function.
+ */
+ tvm_crt_error_t Write(const uint8_t* payload, size_t payload_size_bytes);
+
+ /*! \brief Start framing and writing a new packet to the wire.
+ *
+ * When transmitting payloads that are too large to be buffered, call this function first to send
+ * the packet header and length fields.
+ *
+ * \param payload_size_bytes Number of payload bytes included as part of this packet.
+ * \return
+ * - kTvmErrorNoError when no error occurs
+ * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write()
+ * function returns 0.
+ * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write()
+ * function returns an invalid positive number.
+ * - Any negative value (i.e. with bits in kTvmErrorSystemErrorMask set) returned by the
+ * WriteStream's Write() function.
+ */
+ tvm_crt_error_t StartPacket(size_t payload_size_bytes);
+
+ /*! \brief Write payload data to the wire.
+ *
+ * When transmitting payloads that are too large to be buffered, call this function after calling
+ * StartPacket to escape and transmit framed payloads. This function can be called multiple times
+ * for a single packet.
+ *
+ * \param payload_chunk A piece of the packet payload.
+ * \param payload_chunk_size_bytes Number of valid bytes in payload_chunk.
+ * \return
+ * - kTvmErrorNoError when no error occurs
+ * - kTvmErrorFramingInvalidState when StartPacket() has not been called.
+ * - kTvmErrorFramingPayloadOverflow when more bytes were requested to be written than were
+ * declared in the payload_size_bytes parameter given to StartPacket().
+ * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write()
+ * function returns 0.
+ * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write()
+ * function returns an invalid positive number.
+ * - Any negative value (i.e. with bits in kTvmErrorSystemErrorMask set) returned by the
+ * WriteStream's Write() function.
+ */
+ tvm_crt_error_t WritePayloadChunk(const uint8_t* payload_chunk, size_t payload_chunk_size_bytes);
+
+ /* \brief Finish writing one packet by sending the CRC.
+ *
+ * When transmitting paylaods that are too large to be buffered, call this function after sending
+ * the entire payload using WritePayloadChunk.
+ *
+ * \return
+ * - kTvmErrorNoError when no error occurs
+ * - kTvmErrorFramingInvalidState when StartPacket() has not been called.
+ * - kTvmErrorFramingPayloadIncomplete when less bytes were written using WritePayloadChunk()
+ * than were declared in the payload_size_bytes parameter given to StartPacket().
+ * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write()
+ * function returns 0.
+ * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write()
+ * function returns an invalid positive number.
+ * - Any negative value (i.e. with bits in kTvmErrorSystemErrorMask set) returned by the
+ * WriteStream's Write() function.
+ */
+ tvm_crt_error_t FinishPacket();
+
+ /* \brief Reset state of the Framer. */
+ void Reset();
+
+ private:
+ /*! \brief Maximum size of stack-based buffer. */
+ static constexpr const size_t kMaxStackBufferSizeBytes = 128;
+
+ enum class State : uint8_t {
+ /*! \brief State entered at construction time or after write error, before first packet sent. */
+ kReset = 0,
+
+ /*! \brief State entered after a packet has successfully finished transmitting. */
+ kIdle = 1,
+
+ /*! \brief State entered when a packet payload or CRC needs to be transmitted. */
+ kTransmitPacketPayload = 2,
+ };
+
+ /*!
+ * \brief Escape data and write the result to wire, and update crc_.
+ *
+ * \param data Unescaped data to write.
+ * \param data_size_bytes Number of valid bytes in data.
+ * \param escape true if escaping should be applied.
+ * \param update_crc true if escaping should be applied.
+ * \return kTvmErrorNoError on success, negative value on error.
+ */
+ tvm_crt_error_t WriteAndCrc(const uint8_t* data, size_t data_size_bytes, bool escape,
+ bool update_crc);
+
+ /*! \brief Called to write framed data to the transport. */
+ WriteStream* stream_;
+
+ /*! \brief State fo the Framer. */
+ State state_;
+
+ /*! \brief When state_ == kTransmitPacketPayload, number of payload bytes left to transmit. */
+ size_t num_payload_bytes_remaining_;
+
+ /*! \brief Running CRC value. */
+ uint16_t crc_;
+};
+
+} // namespace micro_rpc
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_CRT_RPC_COMMON_FRAMING_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file session.h
+ * \brief RPC Session
+ */
+
+#ifndef TVM_RUNTIME_CRT_RPC_COMMON_SESSION_H_
+#define TVM_RUNTIME_CRT_RPC_COMMON_SESSION_H_
+
+#include <inttypes.h>
+#include <tvm/runtime/crt/error_codes.h>
+#include <tvm/runtime/crt/rpc_common/frame_buffer.h>
+#include <tvm/runtime/crt/rpc_common/framing.h>
+#include <tvm/runtime/crt/rpc_common/write_stream.h>
+
+namespace tvm {
+namespace runtime {
+namespace micro_rpc {
+
+enum class MessageType : uint8_t {
+ kStartSessionInit = 0x00,
+ kStartSessionReply = 0x01,
+ kTerminateSession = 0x02,
+ kLog = 0x03,
+ kNormal = 0x10,
+};
+
+typedef struct SessionHeader {
+ uint16_t session_id;
+ MessageType message_type;
+} __attribute__((packed)) SessionHeader;
+
+/*!
+ * \brief CRT communication session management class.
+ * Assumes the following properties provided by the underlying transport:
+ * - in-order delivery.
+ * - reliable delivery.
+ *
+ * Specifically, designed for use with UARTs. Will probably work over semihosting, USB, and TCP;
+ * will probably not work reliably enough over UDP.
+ */
+class Session {
+ public:
+ /*! \brief Callback invoked when a full message is received.
+ *
+ * This function is called in the following situations:
+ * - When a new session is established (this typically indicates the remote end reset).
+ * In this case, buf is NULL.
+ * - When a log message or normal traffic is received. In this case, buf points to a
+ * valid buffer containing the message content.
+ *
+ * \param context The value of `message_received_func_context` passed to the constructor.
+ * \param message_type The type of session message received. Currently, this is always
+ * either kNormal or kLog.
+ * \param buf When message_type is not kStartSessionMessage, a FrameBuffer whose read cursor is
+ * at the first byte of the message payload. Otherwise, NULL.
+ */
+ typedef void (*MessageReceivedFunc)(void* context, MessageType message_type, FrameBuffer* buf);
+
+ /*! \brief An invalid nonce value that typically indicates an unknown nonce. */
+ static constexpr const uint8_t kInvalidNonce = 0;
+
+ Session(uint8_t initial_session_nonce, Framer* framer, FrameBuffer* receive_buffer,
+ MessageReceivedFunc message_received_func, void* message_received_func_context)
+ : local_nonce_{initial_session_nonce},
+ session_id_{0},
+ state_{State::kReset},
+ receiver_{this},
+ framer_{framer},
+ receive_buffer_{receive_buffer},
+ receive_buffer_has_complete_message_{false},
+ message_received_func_{message_received_func},
+ message_received_func_context_{message_received_func_context} {
+ // Session can be used for system startup logging, before the RPC server is instantiated. In
+ // this case, allow receive_buffer_ to be nullptr. The instantiator agrees not to use
+ // Receiver().
+ if (receive_buffer_ != nullptr) {
+ receive_buffer_->Clear();
+ }
+ }
+
+ /*!
+ * \brief Send a session terminate message, usually done at startup to interrupt a hanging remote.
+ * \return kTvmErrorNoError on success, or an error code otherwise.
+ */
+ tvm_crt_error_t Initialize();
+
+ /*!
+ * \brief Terminate any previously-established session.
+ * \return kTvmErrorNoError on success, or an error code otherwise.
+ */
+ tvm_crt_error_t TerminateSession();
+
+ /*!
+ * \brief Start a new session regardless of state. Sends kStartSessionMessage.
+ *
+ * Generally speaking, this function should be called once per device reset by exactly one side
+ * in the system. No traffic can flow until this function is called.
+ *
+ * \return kTvmErrorNoError on success, or an error code otherwise.
+ */
+ tvm_crt_error_t StartSession();
+
+ /*!
+ * \brief Obtain a WriteStream implementation for use by the framing layer.
+ * \return A WriteStream to which received data should be written. Owned by this class.
+ */
+ WriteStream* Receiver() { return &receiver_; }
+
+ /*!
+ * \brief Send a full message including header, payload, and CRC footer.
+ * \param message_type One of MessageType; distinguishes the type of traffic at the session layer.
+ * \param message_data The data contained in the message.
+ * \param message_size_bytes The number of valid bytes in message_data.
+ * \return kTvmErrorNoError on success, or an error code otherwise.
+ */
+ tvm_crt_error_t SendMessage(MessageType message_type, const uint8_t* message_data,
+ size_t message_size_bytes);
+
+ /*!
+ * \brief Send the framing and session layer headers.
+ *
+ * This function allows messages to be sent in pieces.
+ *
+ * \param message_type One of MessageType; distinguishes the type of traffic at the session layer.
+ * \param message_size_bytes The size of the message body, in bytes. Excludes the framing and
+ * session layer headers. \return 0 on success, negative error code on failure.
+ * \return kTvmErrorNoError on success, or an error code otherwise.
+ */
+ tvm_crt_error_t StartMessage(MessageType message_type, size_t message_size_bytes);
+
+ /*!
+ * \brief Send a part of the message body.
+ *
+ * This function allows messages to be sent in pieces.
+ *
+ * \param chunk_data The data contained in this message body chunk.
+ * \param chunk_size_bytes The number of valid bytes in chunk_data.
+ * \return kTvmErrorNoError on success, or an error code otherwise.
+ */
+ tvm_crt_error_t SendBodyChunk(const uint8_t* chunk_data, size_t chunk_size_bytes);
+
+ /*!
+ * \brief Finish sending the message by sending the framing layer footer.
+ * \return kTvmErrorNoError on success, or an error code otherwise.
+ */
+ tvm_crt_error_t FinishMessage();
+
+ /*! \brief Returns true if the session is in the established state. */
+ bool IsEstablished() const { return state_ == State::kSessionEstablished; }
+
+ /*!
+ * \brief Clear the receive buffer and prepare to receive next message.
+ *
+ * Call this function after MessageReceivedFunc is invoked. Any SessionReceiver::Write() calls
+ * made will return errors until this function is called to prevent them from corrupting the
+ * valid message in the receive buffer.
+ */
+ void ClearReceiveBuffer();
+
+ /*! \brief A version number used to check compatibility of the remote session implementation. */
+ static const constexpr uint8_t kVersion = 0x01;
+
+ private:
+ class SessionReceiver : public WriteStream {
+ public:
+ explicit SessionReceiver(Session* session) : session_{session} {}
+ virtual ~SessionReceiver() {}
+
+ ssize_t Write(const uint8_t* data, size_t data_size_bytes) override;
+ void PacketDone(bool is_valid) override;
+
+ private:
+ void operator delete(void*) noexcept {} // NOLINT(readability/casting)
+ Session* session_;
+ };
+
+ enum class State : uint8_t {
+ kReset = 0,
+ kNoSessionEstablished = 1,
+ kStartSessionSent = 2,
+ kSessionEstablished = 3,
+ };
+
+ void RegenerateNonce();
+
+ tvm_crt_error_t SendInternal(MessageType message_type, const uint8_t* message_data,
+ size_t message_size_bytes);
+
+ void SendSessionStartReply(const SessionHeader& header);
+
+ void ProcessStartSessionInit(const SessionHeader& header);
+
+ void ProcessStartSessionReply(const SessionHeader& header);
+
+ void OnSessionEstablishedMessage();
+
+ void OnSessionTerminatedMessage();
+
+ void SetSessionId(uint8_t initiator_nonce, uint8_t responder_nonce) {
+ session_id_ = initiator_nonce | (((uint16_t)responder_nonce) << 8);
+ }
+
+ uint8_t InitiatorNonce(uint16_t session_id) { return session_id & 0xff; }
+
+ uint8_t ResponderNonce(uint16_t session_id) { return (session_id >> 8) & 0xff; }
+
+ uint8_t local_nonce_;
+ uint16_t session_id_;
+ State state_;
+ SessionReceiver receiver_;
+ Framer* framer_;
+ FrameBuffer* receive_buffer_;
+ bool receive_buffer_has_complete_message_;
+ MessageReceivedFunc message_received_func_;
+ void* message_received_func_context_;
+};
+
+} // namespace micro_rpc
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_CRT_RPC_COMMON_SESSION_H_
*/
/*!
- * \file utvm_timer.c
- * \brief uTVM timer API stubs for Spike
+ * \file framing.h
+ * \brief Framing for RPC.
*/
-#ifdef __cplusplus
-extern "C" {
-#endif
+#ifndef TVM_RUNTIME_CRT_RPC_COMMON_WRITE_STREAM_H_
+#define TVM_RUNTIME_CRT_RPC_COMMON_WRITE_STREAM_H_
-#include "utvm_runtime.h"
+#include <inttypes.h>
+#include <stddef.h>
+#include <sys/types.h>
+#include <tvm/runtime/crt/error_codes.h>
-int32_t UTVMTimerStart() { return UTVM_ERR_OK; }
+namespace tvm {
+namespace runtime {
+namespace micro_rpc {
-uint32_t UTVMTimerStop(int32_t* err) {
- *err = UTVM_ERR_OK;
- return 0;
-}
+class WriteStream {
+ public:
+ virtual ~WriteStream();
+ virtual ssize_t Write(const uint8_t* data, size_t data_size_bytes) = 0;
+ virtual void PacketDone(bool is_valid) = 0;
-#ifdef __cplusplus
-} // TVM_EXTERN_C
-#endif
+ tvm_crt_error_t WriteAll(uint8_t* data, size_t data_size_bytes, size_t* bytes_consumed);
+};
+
+} // namespace micro_rpc
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_CRT_RPC_COMMON_WRITE_STREAM_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file utvm_rpc_server.h
+ * \brief MicroTVM RPC Server
+ */
+
+#ifndef TVM_RUNTIME_CRT_UTVM_RPC_SERVER_H_
+#define TVM_RUNTIME_CRT_UTVM_RPC_SERVER_H_
+
+#include <stdlib.h>
+#include <sys/types.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*! \brief TVM RPC channel write function.
+ *
+ * Tries to write `num_bytes` from `data` to the underlying channel.
+ * \param data Pointer to data to write.
+ * \param num_bytes Number of bytes avaiable in data.
+ * \return The number of bytes written.
+ */
+typedef ssize_t (*utvm_rpc_channel_write_t)(void* context, const uint8_t* data, size_t num_bytes);
+
+/*! \brief Opaque pointer type to TVM RPC Server. */
+typedef void* utvm_rpc_server_t;
+
+/*! \brief Initialize the TVM RPC Server.
+ *
+ * Call this on device startup before calling anyother utvm_rpc_server_ functions.
+ *
+ * \param memory A memory block used by the runtime as dynamic memory, primarily to allocate
+ * tensors.
+ * \param memory_size_bytes Size of the memory block, in bytes. Should be a multiple of
+ * (1 << page_size_bytes_log2)
+ * \param page_size_bytes_log2 Log2 of the size of each memory page. The internal allocator
+ * allocates one page at a time; more pages reduces waste but
+ * increases overhead.
+ * \param write_func A callback function invoked by the TVM RPC Server to write data back to the
+ * host. Internally, the TVM RPC Server will block until all data in a reply
+ * packet has been written.
+ * \param write_func_ctx An opaque pointer passed to write_func when it is called.
+ * \return A pointer to the TVM RPC Server. The pointer is allocated in the same memory space as
+ * the TVM workspace.
+ */
+utvm_rpc_server_t UTvmRpcServerInit(uint8_t* memory, size_t memory_size_bytes,
+ size_t page_size_bytes_log2,
+ utvm_rpc_channel_write_t write_func, void* write_func_ctx);
+
+/*! \brief Copy received data into an internal buffer for processing.
+ *
+ * Currently only handles 1 byte of data. In the future, the goal of this function is to be safe to
+ * invoke from an ISR. At that time, this function will just append to an internal buffer.
+ *
+ * \param server The TVM RPC Server pointer.
+ * \param byte The received byte of data.
+ * \return The number of bytes copied to the internal buffer. May be less than data_size_bytes when
+ * the internal buffer fills.
+ */
+size_t UTvmRpcServerReceiveByte(utvm_rpc_server_t server, uint8_t byte);
+
+/*! \brief Perform normal processing of received data.
+ *
+ * \param server The TVM RPC Server pointer.
+ * \return true while the server is still running. false when it shuts down gracefully.
+ */
+bool UTvmRpcServerLoop(utvm_rpc_server_t server);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // TVM_RUNTIME_CRT_UTVM_RPC_SERVER_H_
# under the License.
"""MicroTVM module for bare-metal backends"""
-from ..contrib import binutil
-from .base import DEVICE_SECTIONS
-from .base import Session, create_micro_mod, cross_compiler, LibType
-from .base import get_micro_host_driven_dir, get_micro_device_dir
-from . import device
+from .artifact import Artifact
+from .build import build_static_runtime, default_options, TVM_ROOT_DIR
+from .build import CRT_ROOT_DIR, Workspace
+from .compiler import Compiler, DefaultCompiler, Flasher
+from .debugger import GdbRemoteDebugger
+from .micro_library import MicroLibrary
+from .micro_binary import MicroBinary
+from .session import create_local_graph_runtime, Session
+from .transport import TransportLogger, DebugWrapperTransport, SubprocessTransport
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""""Defines abstractions around compiler artifacts produced in compiling micro TVM binaries."""
+
+import io
+import os
+import json
+import shutil
+import tarfile
+
+
+class ArtifactFileNotFoundError(Exception):
+ """Raised when an artifact file cannot be found on disk."""
+
+
+class ArtifactBadSymlinkError(Exception):
+ """Raised when an artifact symlink points outside the base directory."""
+
+
+class ArtifactBadArchiveError(Exception):
+ """Raised when an artifact archive is malformed."""
+
+
+class Artifact:
+ """Describes a compiler artifact and defines common logic to archive it for transport."""
+
+ # A version number written to the archive.
+ ENCODING_VERSION = 1
+
+ # A unique string identifying the type of artifact in an archive. Subclasses must redefine this
+ # variable.
+ ARTIFACT_TYPE = None
+
+ @classmethod
+ def unarchive(cls, archive_path, base_dir):
+ """Unarchive an artifact into base_dir.
+
+ Parameters
+ ----------
+ archive_path : str
+ Path to the archive file.
+ base_dir : str
+ Path to a non-existent, empty directory under which the artifact will live.
+
+ Returns
+ -------
+ Artifact :
+ The unarchived artifact.
+ """
+ if os.path.exists(base_dir):
+ raise ValueError(f'base_dir exists: {base_dir}')
+
+ base_dir_parent, base_dir_name = os.path.split(base_dir)
+ temp_dir = os.path.join(base_dir_parent, f'__tvm__{base_dir_name}')
+ os.mkdir(temp_dir)
+ try:
+ with tarfile.open(archive_path) as tar_f:
+ tar_f.extractall(temp_dir)
+
+ temp_dir_contents = os.listdir(temp_dir)
+ if len(temp_dir_contents) != 1:
+ raise ArtifactBadArchiveError(
+ 'Expected exactly 1 subdirectory at root of archive, got '
+ f'{temp_dir_contents!r}')
+
+ metadata_path = os.path.join(temp_dir, temp_dir_contents[0], 'metadata.json')
+ if not metadata_path:
+ raise ArtifactBadArchiveError('No metadata.json found in archive')
+
+ with open(metadata_path) as metadata_f:
+ metadata = json.load(metadata_f)
+
+ version = metadata.get('version')
+ if version != cls.ENCODING_VERSION:
+ raise ArtifactBadArchiveError(
+ f'archive version: expect {cls.EXPECTED_VERSION}, found {version}')
+
+ os.rename(os.path.join(temp_dir, temp_dir_contents[0]), base_dir)
+
+ artifact_cls = cls
+ for sub_cls in cls.__subclasses__():
+ if (sub_cls.ARTIFACT_TYPE is not None and
+ sub_cls.ARTIFACT_TYPE == metadata.get('artifact_type')):
+ artifact_cls = sub_cls
+ break
+
+ return artifact_cls.from_unarchived(
+ base_dir, metadata['labelled_files'], metadata['metadata'])
+ finally:
+ shutil.rmtree(temp_dir)
+
+ @classmethod
+ def from_unarchived(cls, base_dir, labelled_files, metadata):
+ return cls(base_dir, labelled_files, metadata)
+
+ def __init__(self, base_dir, labelled_files, metadata):
+ """Create a new artifact.
+
+ Parameters
+ ----------
+ base_dir : str
+ The path to a directory on disk which contains all the files in this artifact.
+ labelled_files : Dict[str, str]
+ A dict mapping a file label to the relative paths of the files that carry that label.
+ metadata : Dict
+ A dict containing artitrary JSON-serializable key-value data describing the artifact.
+ """
+ self.base_dir = os.path.realpath(base_dir)
+ self.labelled_files = labelled_files
+ self.metadata = metadata
+
+ for label, files in labelled_files.items():
+ for f in files:
+ f_path = os.path.join(self.base_dir, f)
+ if not os.path.lexists(f_path):
+ raise ArtifactFileNotFoundError(f'{f} (label {label}): not found at {f_path}')
+
+ if os.path.islink(f_path):
+ link_path = os.path.readlink(f_path)
+ if os.path.isabs(link_path):
+ link_fullpath = link_path
+ else:
+ link_fullpath = os.path.join(os.path.dirname(f_path), link_path)
+
+ link_fullpath = os.path.realpath(link_fullpath)
+ if not link_fullpath.startswith(self.base_dir):
+ raise ArtifactBadSymlinkError(
+ f'{f} (label {label}): symlink points outside artifact tree')
+
+ def abspath(self, rel_path):
+ """Return absolute path to the member with the given relative path."""
+ return os.path.join(self.base_dir, rel_path)
+
+ def label(self, label):
+ """Return a list of relative paths to files with the given label."""
+ return self.labelled_files[label]
+
+ def label_abspath(self, label):
+ return [self.abspath(p) for p in self.labelled_files[label]]
+
+ def archive(self, archive_path):
+ """Create a relocatable tar archive of the artifacts.
+
+ Parameters
+ ----------
+ archive_path : str
+ Path to the tar file to create. Or, path to a directory, under which a tar file will be
+ created named {base_dir}.tar.
+
+ Returns
+ -------
+ str :
+ The value of archive_path, after potentially making the computation describe above.
+ """
+ if os.path.isdir(archive_path):
+ archive_path = os.path.join(archive_path, f'{os.path.basename(self.base_dir)}.tar')
+
+ archive_name = os.path.splitext(os.path.basename(archive_path))[0]
+ with tarfile.open(archive_path, 'w') as tar_f:
+ def _add_file(name, data, f_type):
+ tar_info = tarfile.TarInfo(name=name)
+ tar_info.type = f_type
+ data_bytes = bytes(data, 'utf-8')
+ tar_info.size = len(data)
+ tar_f.addfile(tar_info, io.BytesIO(data_bytes))
+
+ _add_file(f'{archive_name}/metadata.json',
+ json.dumps({'version': self.ENCODING_VERSION,
+ 'labelled_files': self.labelled_files,
+ 'metadata': self.metadata},
+ indent=2,
+ sort_keys=True),
+ tarfile.REGTYPE)
+ for dir_path, _, files in os.walk(self.base_dir):
+ for f in files:
+ file_path = os.path.join(dir_path, f)
+ archive_file_path = os.path.join(
+ archive_name, os.path.relpath(file_path, self.base_dir))
+ if not os.path.islink(file_path):
+ tar_f.add(file_path, archive_file_path, recursive=False)
+ continue
+
+ link_path = os.readlink(file_path)
+ if not os.path.isabs(link_path):
+ tar_f.add(file_path, archive_file_path, recursive=False)
+ continue
+
+ relpath = os.path.relpath(link_path, os.path.dirname(file_path))
+ _add_file(archive_file_path, relpath, tarfile.LNKTYPE)
+
+ return archive_path
# under the License.
"""Base definitions for MicroTVM"""
-from __future__ import absolute_import
-
-import os
-import re
-import sys
-from enum import Enum
-
import tvm
import tvm._ffi
-from tvm.contrib import util as _util
-from tvm.contrib import cc as _cc
-
-# all sections that comprise a device's memory layout, in order from lowest
-# starting address to highest
-DEVICE_SECTIONS = [
- "text",
- "rodata",
- "data",
- "bss",
- "args",
- "heap",
- "workspace",
- "stack",
-]
-
-
-class LibType(Enum):
- """Enumeration of library types that can be compiled and loaded onto a device"""
-
- # library to be used as a MicroTVM runtime
- RUNTIME = 0
- # library to be used as an operator
- OPERATOR = 1
-
-
-class Session:
- """MicroTVM Device Session
-
- Parameters
- ----------
- config : dict
- configuration for this session (as generated by
- `tvm.micro.device.host.default_config()`, for example)
-
- Example
- --------
- .. code-block:: python
-
- c_mod = ... # some module generated with "c" as the target
- dev_config = micro.device.arm.stm32f746xx.default_config('127.0.0.1', 6666)
- with tvm.micro.Session(dev_config) as sess:
- micro_mod = sess.create_micro_mod(c_mod)
- """
-
- def __init__(self, config):
- self._check_system()
- # TODO(weberlo): add config validation
-
- # grab a binutil instance from the ID in the config
- dev_funcs = tvm.micro.device.get_device_funcs(config["device_id"])
- self.toolchain_prefix = config["toolchain_prefix"]
- self.mem_layout = config["mem_layout"]
- self.word_size_bits = config["word_size_bits"]
- self.thumb_mode = config["thumb_mode"]
- self.use_device_timer = config["use_device_timer"]
- self.comms_method = config["comms_method"]
-
- # First, find and compile runtime library.
- runtime_src_path = os.path.join(get_micro_host_driven_dir(), "utvm_runtime.c")
- tmp_dir = _util.tempdir()
- runtime_obj_path = tmp_dir.relpath("utvm_runtime.obj")
- options = ["-I{}".format(get_micro_host_driven_dir())]
- dev_funcs["create_micro_lib"](
- runtime_obj_path, runtime_src_path, LibType.RUNTIME, options=options
- )
-
- comms_method = config["comms_method"]
- if comms_method == "openocd":
- server_addr = config["server_addr"]
- server_port = config["server_port"]
- elif comms_method == "host":
- server_addr = ""
- server_port = 0
- else:
- raise RuntimeError(f"unknown communication method: f{self.comms_method}")
-
- assert all(
- map(lambda sec: sec in self.mem_layout, DEVICE_SECTIONS)
- ), "not all sections have an assigned memory layout"
- self.module = _CreateSession(
- comms_method,
- runtime_obj_path,
- self.toolchain_prefix,
- self.mem_layout["text"].get("start", 0),
- self.mem_layout["text"]["size"],
- self.mem_layout["rodata"].get("start", 0),
- self.mem_layout["rodata"]["size"],
- self.mem_layout["data"].get("start", 0),
- self.mem_layout["data"]["size"],
- self.mem_layout["bss"].get("start", 0),
- self.mem_layout["bss"]["size"],
- self.mem_layout["args"].get("start", 0),
- self.mem_layout["args"]["size"],
- self.mem_layout["heap"].get("start", 0),
- self.mem_layout["heap"]["size"],
- self.mem_layout["workspace"].get("start", 0),
- self.mem_layout["workspace"]["size"],
- self.mem_layout["stack"].get("start", 0),
- self.mem_layout["stack"]["size"],
- self.word_size_bits,
- self.thumb_mode,
- self.use_device_timer,
- server_addr,
- server_port,
- config.get("debug_func"),
- )
- self._enter = self.module["enter"]
- self._exit = self.module["exit"]
- self.get_last_batch_time = self.module["get_last_batch_time"]
- self.get_last_batch_cycles = self.module["get_last_batch_cycles"]
-
- def _check_system(self):
- """Check if the user's system is supported by MicroTVM.
-
- Raises error if not supported.
- """
- if not sys.platform.startswith("linux"):
- raise RuntimeError("MicroTVM is currently only supported on Linux")
- # TODO(weberlo): Add 32-bit support.
- # It's primarily the compilation pipeline that isn't compatible.
- if sys.maxsize <= 2 ** 32:
- raise RuntimeError("MicroTVM is currently only supported on 64-bit host platforms")
-
- def __enter__(self):
- self._enter()
- return self
-
- def __exit__(self, exc_type, exc_value, exc_traceback):
- self._exit()
-
-
-def _calc_max_workspace_usage(src):
- # TODO factor in alignment to the calculation (alloc sizes will be aligned up to the word size)
- alloc_re = re.compile(
- r".*\* ?(.+) = (\(.+\))? TVMBackendAllocWorkspace\(.+, .+, \(uint64_t\)(.+), .+, .+\).*"
- )
- free_re = re.compile(r".*if \(TVMBackendFreeWorkspace\(.+, .+, (\(void\*\))? (.+)\) != 0\) {.*")
- max_usage = 0
- alloc_map = {}
- for line in src.split("\n"):
- if line.strip().startswith("//"):
- continue
- match = alloc_re.match(line)
- if match is not None:
- alloc_map[match.group(1)] = int(match.group(3))
- max_usage = max(max_usage, sum(alloc_map.values()))
- else:
- match = free_re.match(line)
- if match is not None:
- print(alloc_map)
- del alloc_map[match.group(2)]
- return max_usage
-
-
-def create_micro_mod(
- c_mod, dev_config, lib_src_paths=None, lib_headers=None, lib_include_paths=None
-):
- """Produces a micro module from a given module.
-
- Parameters
- ----------
- c_mod : tvm.module.Module
- module with "c" as its target backend
-
- lib_src_paths: TODO
- TODO
-
- lib_headers: TODO
- TODO
-
- lib_include_paths: TODO
- TODO
-
- Return
- ------
- micro_mod : tvm.module.Module
- micro module for the target device
- """
- temp_dir = _util.tempdir()
- lib_obj_path = temp_dir.relpath("dev_lib.obj")
- # TODO use dev config to dispatch on the type of C codegen to run through
- # (e.g., CodeGenCArm, CodeGenCHost, CodeGenCRiscV)
- c_mod.export_library(
- lib_obj_path,
- fcompile=cross_compiler(
- dev_config,
- LibType.OPERATOR,
- lib_src_paths=lib_src_paths,
- lib_headers=lib_headers,
- lib_include_paths=lib_include_paths,
- ),
- )
- micro_mod = tvm.runtime.load_module(lib_obj_path)
- return micro_mod
-
-
-def cross_compiler(
- dev_config, lib_type, lib_src_paths=None, lib_headers=None, lib_include_paths=None
-):
- """Create a cross compile function that wraps `create_lib` for a `Binutil` instance.
-
- For use in `tvm.runtime.Module.export_library`.
-
- Parameters
- ----------
- create_micro_lib : func
- function for creating MicroTVM libraries for a specific device (e.g.,
- `tvm.micro.device.get_device_funcs('arm.stm32f746xx')['create_micro_lib']`)
-
- lib_type : micro.LibType
- whether to compile a MicroTVM runtime or operator library
-
- lib_src_paths: TODO
- TODO
-
- lib_headers: TODO
- e.g., `['cmsis_gcc.h', 'arm_math.h']`
-
- lib_include_paths: TODO
- TODO
-
- Return
- ------
- func : Callable[[str, str, Optional[str]], None]
- cross compile function taking a destination path for the object file
- and a path for the input source file.
-
- Example
- --------
- .. code-block:: python
-
- c_mod = ... # some module generated with "c" as the target
- fcompile = tvm.micro.cross_compiler(dev_config, LibType.OPERATOR)
- c_mod.export_library('dev_lib.obj', fcompile=fcompile)
- """
- assert (lib_headers is None) == (
- lib_include_paths is None
- ), "must specify both `lib_headers` and `lib_include_paths` or neither"
-
- if lib_src_paths is None:
- lib_src_paths = []
- if lib_include_paths is None:
- lib_include_paths = []
- include_options = []
- for include_path in lib_include_paths:
- include_options.append("-I")
- include_options.append(include_path)
- create_micro_lib = tvm.micro.device.get_device_funcs(dev_config["device_id"])[
- "create_micro_lib"
- ]
- mem_layout = dev_config["mem_layout"]
-
- def compile_func(obj_path, src_path, **kwargs):
- if isinstance(obj_path, list):
- obj_path = obj_path[0]
- if isinstance(src_path, list):
- src_path = src_path[0]
- options = kwargs.get("options", [])
- options += include_options
-
- # check that workspace allocations don't exceed available workspace memory
- with open(src_path) as f:
- src_contents = f.read()
- max_ws_usage = _calc_max_workspace_usage(src_contents)
- available_mem = mem_layout["workspace"]["size"]
- if max_ws_usage > available_mem:
- raise RuntimeError(
- f"workspace allocations in library ({max_ws_usage}) "
- f"exceed available memory ({available_mem})"
- )
- # inject headers into new source path, if requested
- if lib_headers:
- headers_to_inject = "\n".join(map(lambda s: f"#include <{s}>", lib_headers)) + "\n"
- new_src_contents = headers_to_inject + src_contents
- tmp_dir = _util.tempdir()
- src_path = tmp_dir.relpath(os.path.basename(src_path))
- with open(src_path, "w") as f:
- f.write(new_src_contents)
-
- create_micro_lib(obj_path, src_path, lib_type, options, lib_src_paths=lib_src_paths)
-
- return _cc.cross_compiler(compile_func, output_format="obj")
-
-
-def get_micro_host_driven_dir():
- """Get directory path for uTVM host-driven runtime source files.
-
- Return
- ------
- micro_device_dir : str
- directory path
- """
- micro_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
- micro_host_driven_dir = os.path.join(
- micro_dir, "..", "..", "..", "src", "runtime", "micro", "host_driven"
- )
- return micro_host_driven_dir
-
-
-def get_micro_device_dir():
- """Get directory path for parent directory of device-specific source files
-
- Return
- ------
- micro_device_dir : str
- directory path
- """
- micro_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
- micro_device_dir = os.path.join(
- micro_dir, "..", "..", "..", "src", "runtime", "micro", "device"
- )
- return micro_device_dir
-
-
tvm._ffi._init_api("tvm.micro", "tvm.micro.base")
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines top-level glue functions for building microTVM artifacts."""
+
+import copy
+import logging
+import os
+import re
+from tvm.contrib import util
+
+
+_LOG = logging.getLogger(__name__)
+
+
+class Workspace:
+ """Defines helper functions for manipulating temporary compilation workspaces."""
+
+ def __init__(self, root=None, debug=False):
+ if debug or root is not None:
+ with util.TempDirectory.set_keep_for_debug():
+ self.tempdir = util.tempdir(custom_path=root)
+ _LOG.info('Created debug mode workspace at: %s', self.tempdir.temp_dir)
+ else:
+ self.tempdir = util.tempdir()
+
+ def relpath(self, path):
+ return self.tempdir.relpath(path)
+
+ def listdir(self):
+ return self.tempdir.listdir()
+
+ @property
+ def path(self):
+ return self.tempdir.temp_dir
+
+
+# Required C runtime libraries, in link order.
+CRT_RUNTIME_LIB_NAMES = ['utvm_rpc_server', 'utvm_rpc_common', 'common']
+
+
+TVM_ROOT_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
+
+
+CRT_ROOT_DIR = os.path.join(TVM_ROOT_DIR, 'src', 'runtime', 'crt')
+
+
+RUNTIME_LIB_SRC_DIRS = (
+ [os.path.join(CRT_ROOT_DIR, n) for n in CRT_RUNTIME_LIB_NAMES] +
+ [os.path.join(TVM_ROOT_DIR,
+ '3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/'
+ 'libraries/crc16')])
+
+
+RUNTIME_SRC_REGEX = re.compile(r'^.*\.cc?$', re.IGNORECASE)
+
+
+_CRT_DEFAULT_OPTIONS = {
+ 'ccflags': ['-std=c++11'],
+ 'ldflags': ['-std=gnu++14'],
+ 'include_dirs': [
+ f'{TVM_ROOT_DIR}/include',
+ f'{TVM_ROOT_DIR}/3rdparty/dlpack/include',
+ f'{TVM_ROOT_DIR}/3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/'
+ 'TARGET_SDK_11/libraries/crc16/',
+ f'{TVM_ROOT_DIR}/3rdparty/dmlc-core/include',
+ f'{CRT_ROOT_DIR}/include'
+ ],
+ 'profile': {
+ 'common': ['-Wno-unused-variable']
+ }
+}
+
+
+def default_options(target_include_dir):
+ """Return default opts passed to Compile commands."""
+ bin_opts = copy.deepcopy(_CRT_DEFAULT_OPTIONS)
+ bin_opts['include_dirs'].append(target_include_dir)
+ lib_opts = copy.deepcopy(bin_opts)
+ lib_opts['profile']['common'].append('-Werror')
+ lib_opts['cflags'] = ['-Wno-error=incompatible-pointer-types']
+ return {'bin_opts': bin_opts, 'lib_opts': lib_opts}
+
+
+def build_static_runtime(workspace, compiler, module, lib_opts=None, bin_opts=None):
+ """Build the on-device runtime, statically linking the given modules.
+
+ Parameters
+ ----------
+ compiler : tvm.micro.Compiler
+ Compiler instance used to build the runtime.
+
+ module : IRModule
+ Module to statically link.
+
+ lib_opts : dict
+ Extra kwargs passed to library(),
+
+ bin_opts : dict
+ Extra kwargs passed to binary(),
+
+ Returns
+ -------
+ MicroBinary :
+ The compiled runtime.
+ """
+ lib_opts = _CRT_DEFAULT_OPTIONS if lib_opts is None else lib_opts
+ bin_opts = _CRT_DEFAULT_OPTIONS if bin_opts is None else bin_opts
+
+ mod_build_dir = workspace.relpath(os.path.join('build', 'module'))
+ os.makedirs(mod_build_dir)
+ mod_src_dir = workspace.relpath(os.path.join('src', 'module'))
+ os.makedirs(mod_src_dir)
+ mod_src_path = os.path.join(mod_src_dir, 'module.c')
+ module.save(mod_src_path, 'cc')
+
+ libs = []
+ for lib_src_dir in RUNTIME_LIB_SRC_DIRS:
+ lib_name = os.path.basename(lib_src_dir)
+ lib_build_dir = workspace.relpath(f'build/{lib_name}')
+ os.makedirs(lib_build_dir)
+
+ lib_srcs = []
+ for p in os.listdir(lib_src_dir):
+ if RUNTIME_SRC_REGEX.match(p):
+ lib_srcs.append(os.path.join(lib_src_dir, p))
+
+ libs.append(compiler.library(lib_build_dir, lib_srcs, lib_opts))
+
+ libs.append(compiler.library(mod_build_dir, [mod_src_path], lib_opts))
+
+ runtime_build_dir = workspace.relpath(f'build/runtime')
+ os.makedirs(runtime_build_dir)
+ return compiler.binary(runtime_build_dir, libs, bin_opts)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines a utility for representing deferred class instatiations as JSON."""
+
+import importlib
+import json
+import typing
+
+
+JsonSerializable = typing.Union[int, float, str, None, bool]
+
+
+class SerializedFactoryError(Exception):
+ """Raised when ClassFactory.from_json is invoked with an invalid JSON blob."""
+
+
+class ClassFactory:
+ """Describes a JSON-serializable class instantiation, for use with the RPC server."""
+
+ # When not None, the superclass from which all cls must derive.
+ SUPERCLASS = None
+
+ def __init__(self, cls: typing.Callable, init_args: typing.List[JsonSerializable],
+ init_kw: typing.Dict[str, JsonSerializable]):
+ self.cls = cls
+ self.init_args = init_args
+ self.init_kw = init_kw
+
+ def override_kw(self, **kw_overrides):
+ kwargs = self.init_kw
+ if kw_overrides:
+ kwargs = dict(kwargs)
+ for k, v in kw_overrides.items():
+ kwargs[k] = v
+
+ return self.__class__(self.cls, self.init_args, kwargs)
+
+ def instantiate(self):
+ return self.cls(*self.init_args, **self.init_kw)
+
+ @property
+ def to_json(self):
+ return json.dumps({
+ 'cls': '.'.join([self.cls.__module__, self.cls.__name__]),
+ 'init_args': self.init_args,
+ 'init_kw': self.init_kw,
+ })
+
+ EXPECTED_KEYS = ('cls', 'init_args', 'init_kw')
+
+ @classmethod
+ def from_json(cls, data):
+ """Reconstruct a ClassFactory instance from its JSON representation.
+
+ Parameters
+ ----------
+ data : str
+ The JSON representation of the ClassFactory.
+
+ Returns
+ -------
+ ClassFactory :
+ The reconstructed ClassFactory instance.
+
+ Raises
+ ------
+ SerializedFactoryError :
+ If the JSON object represented by `data` is malformed.
+ """
+ obj = json.loads(data)
+ if not isinstance(obj, dict):
+ raise SerializedFactoryError(f'deserialized json payload: want dict, got: {obj!r}')
+
+ for key in cls.EXPECTED_KEYS:
+ if key not in obj:
+ raise SerializedFactoryError(
+ f'deserialized json payload: expect key {key}, got: {obj!r}')
+
+ cls_package_name, cls_name = obj['cls'].rsplit('.', 1)
+ cls_package = importlib.import_module(cls_package_name)
+ cls_obj = getattr(cls_package, cls_name)
+ return cls(cls_obj, obj['init_args'], obj['init_kw'])
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines interfaces and default implementations for compiling and flashing code."""
+
+import abc
+import glob
+import os
+import re
+
+from tvm.contrib import binutil
+import tvm.target
+from . import build
+from . import class_factory
+from . import debugger
+from . import transport
+
+
+class DetectTargetError(Exception):
+ """Raised when no target comment was detected in the sources given."""
+
+
+class NoDefaultToolchainMatchedError(Exception):
+ """Raised when no default toolchain matches the target string."""
+
+
+class Compiler(metaclass=abc.ABCMeta):
+ """The compiler abstraction used with micro TVM."""
+
+ TVM_TARGET_RE = re.compile(r'^// tvm target: (.*)$')
+
+ @classmethod
+ def _target_from_sources(cls, sources):
+ """Determine the target used to generate the given source files.
+
+ Parameters
+ ----------
+ sources : List[str]
+ The paths to source files to analyze.
+
+ Returns
+ -------
+ tvm.target.Target :
+ A Target instance reconstructed from the target string listed in the source files.
+ """
+ target_strs = set()
+
+ for obj in sources:
+ with open(obj) as obj_f:
+ for line in obj_f:
+ m = cls.TVM_TARGET_RE.match(line)
+ if m:
+ target_strs.add(m.group(1))
+
+ if len(target_strs) != 1:
+ raise DetectTargetError(
+ 'autodetecting cross-compiler: could not extract TVM target from C source; regex '
+ f'{cls.TVM_TARGET_RE.pattern} does not match any line in sources: '
+ f'{", ".join(sources)}')
+
+ target_str = next(iter(target_strs))
+ return tvm.target.create(target_str)
+
+ # Maps regexes identifying CPUs to the default toolchain prefix for that CPU.
+ TOOLCHAIN_PREFIX_BY_CPU_REGEX = {
+ r'cortex-[am].*': 'arm-none-eabi-',
+ 'x86[_-]64': '',
+ 'native': '',
+ }
+
+ def _autodetect_toolchain_prefix(self, target):
+ matches = []
+ for regex, prefix in self.TOOLCHAIN_PREFIX_BY_CPU_REGEX.items():
+ if re.match(regex, target.attrs['mcpu']):
+ matches.append(prefix)
+
+ if matches:
+ if len(matches) != 1:
+ raise NoDefaultToolchainMatchedError(
+ f'{opt} matched more than 1 default toolchain prefix: {", ".join(matches)}. '
+ 'Specify cc.cross_compiler to create_micro_library()')
+
+ return matches[0]
+
+ raise NoDefaultToolchainMatchedError(
+ f'target {str(target)} did not match any default toolchains')
+
+ def _defaults_from_target(self, target):
+ """Determine the default compiler options from the target specified.
+
+ Parameters
+ ----------
+ target : tvm.target.Target
+
+ Returns
+ -------
+ List[str] :
+ Default options used the configure the compiler for that target.
+ """
+ opts = []
+ # TODO use march for arm(https://gcc.gnu.org/onlinedocs/gcc/ARM-Options.html)?
+ if target.attrs.get('mcpu'):
+ opts.append(f'-march={target.attrs["mcpu"]}')
+ if target.attrs.get('mfpu'):
+ opts.append(f'-mfpu={target.attrs["mfpu"]}')
+
+ return opts
+
+ @abc.abstractmethod
+ def library(self, output, sources, options=None):
+ """Build a library from the given source files.
+
+ Parameters
+ ----------
+ output : str
+ The path to the library that should be created. The containing directory
+ is guaranteed to be empty and should be the base_dir for the returned
+ Artifact.
+ sources : List[str]
+ A list of paths to source files that should be compiled.
+ options : Optional[List[str]]
+ If given, additional command-line flags to pass to the compiler.
+
+ Returns
+ -------
+ MicroLibrary :
+ The compiled library, as a MicroLibrary instance.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def binary(self, output, objects, options=None, link_main=True, main_options=None):
+ """Link a binary from the given object and/or source files.
+
+ Parameters
+ ----------
+ output : str
+ The path to the binary that should be created. The containing directory
+ is guaranteed to be empty and should be the base_dir for the returned
+ Artifact.
+ objects : List[MicroLibrary]
+ A list of paths to source files or libraries that should be compiled. The final binary
+ should be statically-linked.
+ options: Optional[List[str]]
+ If given, additional command-line flags to pass to the compiler.
+ link_main: Optional[bool]
+ True if the standard main entry point for this Compiler should be included in the
+ binary. False if a main entry point is provided in one of `objects`.
+ main_options: Optional[List[str]]
+ If given, additional command-line flags to pass to the compiler when compiling the
+ main() library. In some cases, the main() may be compiled directly into the final binary
+ along with `objects` for logistical reasons. In those cases, specifying main_options is
+ an error and ValueError will be raised.
+
+ Returns
+ -------
+ MicroBinary :
+ The compiled binary, as a MicroBinary instance.
+ """
+ raise NotImplementedError()
+
+ @property
+ def flasher_factory(self):
+ """Produce a FlasherFactory for a Flasher instance suitable for this Compiler."""
+ raise NotImplementedError("The Compiler base class doesn't define a flasher.")
+
+ def flasher(self, **kw):
+ """Return a Flasher that can be used to program a produced MicroBinary onto the target."""
+ return self.flasher_factory.override_kw(**kw).instantiate()
+
+
+class IncompatibleTargetError(Exception):
+ """Raised when source files specify a target that differs from the compiler target."""
+
+
+class DefaultCompiler(Compiler):
+ """A Compiler implementation that attempts to use the system-installed GCC."""
+
+ def __init__(self, target=None):
+ super(DefaultCompiler, self).__init__()
+ self.target = target
+ if isinstance(target, str):
+ self.target = tvm.target.create(target)
+
+ def library(self, output, sources, options=None):
+ options = options if options is not None else {}
+ try:
+ target = self._target_from_sources(sources)
+ except DetectTargetError:
+ assert self.target is not None, (
+ "Must specify target= to constructor when compiling sources which don't specify a "
+ "target")
+
+ target = self.target
+
+ if self.target is not None and str(self.target) != str(target):
+ raise IncompatibleTargetError(
+ f'auto-detected target {target} differs from configured {self.target}')
+
+ prefix = self._autodetect_toolchain_prefix(target)
+ outputs = []
+ for src in sources:
+ src_base, src_ext = os.path.splitext(os.path.basename(src))
+
+ compiler_name = {'.c': 'gcc', '.cc': 'g++', '.cpp': 'g++'}[src_ext]
+ args = [prefix + compiler_name, '-g']
+ args.extend(self._defaults_from_target(target))
+
+ args.extend(options.get(f'{src_ext[1:]}flags', []))
+
+ for include_dir in options.get('include_dirs', []):
+ args.extend(['-I', include_dir])
+
+ output_filename = f'{src_base}.o'
+ output_abspath = os.path.join(output, output_filename)
+ binutil.run_cmd(args + ['-c', '-o', output_abspath, src])
+ outputs.append(output_abspath)
+
+ output_filename = f'{os.path.basename(output)}.a'
+ output_abspath = os.path.join(output, output_filename)
+ binutil.run_cmd([prefix + 'ar', '-r', output_abspath] + outputs)
+ binutil.run_cmd([prefix + 'ranlib', output_abspath])
+
+ return tvm.micro.MicroLibrary(output, [output_filename])
+
+ def binary(self, output, objects, options=None, link_main=True, main_options=None):
+ assert self.target is not None, (
+ 'must specify target= to constructor, or compile sources which specify the target '
+ 'first')
+
+ args = [self._autodetect_toolchain_prefix(self.target) + 'g++']
+ args.extend(self._defaults_from_target(self.target))
+ if options is not None:
+ args.extend(options.get('ldflags', []))
+
+ for include_dir in options.get('include_dirs', []):
+ args.extend(['-I', include_dir])
+
+ output_filename = os.path.basename(output)
+ output_abspath = os.path.join(output, output_filename)
+ args.extend(['-g', '-o', output_abspath])
+
+ if link_main:
+ host_main_srcs = glob.glob(os.path.join(build.CRT_ROOT_DIR, 'host', '*.cc'))
+ if main_options:
+ main_lib = self.library(os.path.join(output, 'host'), host_main_srcs, main_options)
+ for lib_name in main_lib.library_files:
+ args.append(main_lib.abspath(lib_name))
+ else:
+ args.extend(host_main_srcs)
+
+ for obj in objects:
+ for lib_name in obj.library_files:
+ args.append(obj.abspath(lib_name))
+
+ binutil.run_cmd(args)
+ return tvm.micro.MicroBinary(output, output_filename, [])
+
+ @property
+ def flasher_factory(self):
+ return FlasherFactory(HostFlasher, [], {})
+
+
+class Flasher(metaclass=abc.ABCMeta):
+ """An interface for flashing binaries and returning a transport factory."""
+
+ @abc.abstractmethod
+ def flash(self, micro_binary):
+ """Flash a binary onto the device.
+
+ Parameters
+ ----------
+ micro_binary : MicroBinary
+ A MicroBinary instance.
+
+ Returns
+ -------
+ transport.TransportContextManager :
+ A ContextManager that can be used to create and tear down an RPC transport layer between
+ this TVM instance and the newly-flashed binary.
+ """
+ raise NotImplementedError()
+
+
+class FlasherFactory(class_factory.ClassFactory):
+ """A ClassFactory for Flasher instances."""
+
+ SUPERCLASS = Flasher
+
+
+class HostFlasher(Flasher):
+ """A Flasher implementation that spawns a subprocess on the host."""
+
+ def __init__(self, debug=False):
+ self.debug = debug
+
+ def flash(self, micro_binary):
+ if self.debug:
+ gdb_wrapper = debugger.GdbTransportDebugger(
+ [micro_binary.abspath(micro_binary.binary_file)])
+ return transport.DebugWrapperTransport(
+ debugger=gdb_wrapper, transport=gdb_wrapper.Transport())
+
+ return transport.SubprocessTransport([micro_binary.abspath(micro_binary.binary_file)])
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines functions for controlling debuggers for micro TVM binaries."""
+
+import abc
+import os
+import signal
+import subprocess
+import threading
+
+from . import transport as _transport
+
+
+class Debugger(metaclass=abc.ABCMeta):
+ """An interface for controlling micro TVM debuggers."""
+
+ def __init__(self):
+ self.on_terminate_callbacks = []
+
+ @abc.abstractmethod
+ def start(self):
+ """Start the debugger, but do not block on it.
+
+ The runtime will continue to be driven in the background.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def stop(self):
+ """Terminate the debugger."""
+ raise NotImplementedError()
+
+
+class GdbDebugger(Debugger):
+ """Handles launching, suspending signals, and potentially dealing with terminal issues."""
+
+ @abc.abstractmethod
+ def popen_kwargs(self):
+ raise NotImplementedError()
+
+ def _wait_restore_signal(self):
+ self.popen.wait()
+ if not self.did_terminate.is_set():
+ for callback in self.on_terminate_callbacks:
+ try:
+ callback()
+ except Exception: # pylint: disable=broad-except
+ logging.warn('on_terminate_callback raised exception', exc_info=True)
+
+ def start(self):
+ kwargs = self.popen_kwargs()
+ self.did_terminate = threading.Event()
+ self.old_signal = signal.signal(signal.SIGINT, signal.SIG_IGN)
+ self.popen = subprocess.Popen(**kwargs)
+ threading.Thread(target=self._WaitRestoreSignal).start()
+
+ def stop(self):
+ self.did_terminate.set()
+ self.popen.terminate()
+ signal.signal(signal.SIGINT, self.old_signal)
+
+
+class GdbTransportDebugger(GdbDebugger):
+ """A debugger that uses a single GDB subprocess as both the transport and the debugger.
+
+ Opens pipes for the target's stdin and stdout, launches GDB and configures GDB's target
+ arguments to read and write from the pipes using /dev/fd.
+ """
+
+ def __init__(self, args, **popen_kw):
+ super(GdbTransportDebugger, self).__init__()
+ self.args = args
+ self.popen_kw = popen_kw
+
+ def popen_kwargs(self):
+ stdin_read, stdin_write = os.pipe()
+ stdout_read, stdout_write = os.pipe()
+
+ os.set_inheritable(stdin_read, True)
+ os.set_inheritable(stdout_write, True)
+
+ sysname = os.uname()[0]
+ if sysname == 'Darwin':
+ args = ['lldb',
+ '-O', f'target create {self.args[0]}',
+ '-O', f'settings set target.input-path /dev/fd/{stdin_read}',
+ '-O', f'settings set target.output-path /dev/fd/{stdout_write}']
+ if len(self.args) > 1:
+ args.extend(
+ ['-O', 'settings set target.run-args {}'.format(' '.join(self.args[1:]))])
+ elif sysname == 'Linux':
+ args = (['gdb', '--args'] +
+ self.args +
+ ['</dev/fd/{stdin_read}', '>/dev/fd/{stdout_write}'])
+ else:
+ raise NotImplementedError(f'System {sysname} is not yet supported')
+
+ self.stdin = os.fdopen(stdin_write, 'wb', buffering=0)
+ self.stdout = os.fdopen(stdout_read, 'rb', buffering=0)
+
+ return {
+ 'args': args,
+ 'pass_fds': [stdin_read, stdout_write],
+ }
+
+ def _wait_for_process_death(self):
+ self.popen.wait()
+ self.stdin.close()
+ self.stdout.close()
+
+ def start(self):
+ to_return = super(GdbTransportDebugger, self).Start()
+ threading.Thread(target=self._wait_for_process_death, daemon=True).start()
+ return to_return
+
+ def stop(self):
+ self.stdin.close()
+ self.stdout.close()
+ super(GdbTransportDebugger, self).Stop()
+
+ class _Transport(_transport.Transport):
+ def __init__(self, gdb_transport_debugger):
+ self.gdb_transport_debugger = gdb_transport_debugger
+
+ def open(self):
+ pass # Pipes opened by parent class.
+
+ def write(self, data):
+ return self.gdb_transport_debugger.stdin.write(data)
+
+ def read(self, n):
+ return self.gdb_transport_debugger.stdout.read(n)
+
+ def close(self):
+ pass # Pipes closed by parent class.
+
+ def transport(self):
+ return self._Transport(self)
+
+
+class GdbRemoteDebugger(GdbDebugger):
+ """A Debugger that invokes GDB and attaches to a remote GDBserver-based target."""
+
+ def __init__(self, gdb_binary, remote_hostport, debug_binary, wrapping_context_manager=None,
+ **popen_kw):
+ super(GdbRemoteDebugger, self).__init__()
+ self.gdb_binary = gdb_binary
+ self.remote_hostport = remote_hostport
+ self.debug_binary = debug_binary
+ self.wrapping_context_manager = wrapping_context_manager
+ self.popen_kw = popen_kw
+
+ def popen_kwargs(self):
+ kwargs = {
+ 'args': [self.gdb_binary,
+ '-iex', f'file {self.debug_binary}',
+ '-iex', f'target remote {self.remote_hostport}'],
+ }
+ kwargs.update(self.popen_kw)
+
+ return kwargs
+
+ def start(self):
+ if self.wrapping_context_manager is not None:
+ self.wrapping_context_manager.__enter__()
+ super(GdbRemoteDebugger, self).Start()
+
+ def stop(self):
+ try:
+ super(GdbRemoteDebugger, self).Stop()
+ finally:
+ if self.wrapping_context_manager is not None:
+ self.wrapping_context_manager.__exit__(None, None, None)
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Device-specific configuration for MicroTVM"""
-
-from .base import create_micro_lib_base, gen_mem_layout
-from .base import MemConstraint, register_device, get_device_funcs
-from . import host
-from . import arm
-from . import riscv_spike
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Base module for ARM device configurations"""
-
-from . import stm32f746xx
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Compilation and config definitions for Arm STM32F746XX devices"""
-import os
-from .. import create_micro_lib_base, register_device, gen_mem_layout, MemConstraint
-
-DEVICE_ID = "arm.stm32f746xx"
-TOOLCHAIN_PREFIX = "arm-none-eabi-"
-WORD_SIZE_BITS = 32
-#
-# [Device Memory Layout]
-# RAM (rwx) : START = 0x20000000, LENGTH = 320K
-# Flash (rx) : START = 0x8000000, LENGTH = 1024K
-#
-BASE_ADDR = 0x20000000
-AVAILABLE_MEM = 320000
-DEFAULT_SECTION_CONSTRAINTS = {
- "text": (18000, MemConstraint.ABSOLUTE_BYTES),
- "rodata": (512, MemConstraint.ABSOLUTE_BYTES),
- "data": (100, MemConstraint.ABSOLUTE_BYTES),
- "bss": (640, MemConstraint.ABSOLUTE_BYTES),
- "args": (4096, MemConstraint.ABSOLUTE_BYTES),
- "heap": (100.0, MemConstraint.WEIGHT),
- "workspace": (64000, MemConstraint.ABSOLUTE_BYTES),
- "stack": (32, MemConstraint.ABSOLUTE_BYTES),
-}
-
-
-def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None):
- """Wrapper over `create_micro_lib_base` to add device-specific options
-
- Parameters
- ----------
- obj_path : str
- path to generated object file
-
- src_path : str
- path to source file
-
- lib_type : micro.LibType
- whether to compile a MicroTVM runtime or operator library
-
- options : Optional[List[str]]
- additional options to pass to GCC
-
- lib_src_paths : Optional[List[str]]
- TODO
- """
- if options is None:
- options = []
- else:
- options = list(options)
-
- options += [
- # TODO(weberlo): make a debug flag
- "-O2",
- "-mcpu=cortex-m7",
- "-mlittle-endian",
- "-mfloat-abi=hard",
- "-mfpu=fpv5-sp-d16",
- "-mthumb",
- "-ffast-math",
- "-gdwarf-5",
- "-DARM_MATH_CM7",
- "-D__FPU_PRESENT=1U",
- "-DARM_MATH_DSP",
- "-Wno-unused-variable",
- "-Wno-unused-parameter",
- "-I{}".format(os.environ["CMSIS_ST_PATH"]),
- "-I{}/Core/Include".format(os.environ["CMSIS_ST_PATH"]),
- ]
- create_micro_lib_base(
- obj_path,
- src_path,
- TOOLCHAIN_PREFIX,
- DEVICE_ID,
- lib_type,
- options=options,
- lib_src_paths=lib_src_paths,
- )
-
-
-def generate_config(server_addr, server_port, section_constraints=None):
- """Generates a configuration for Arm STM32F746XX devices
-
- Parameters
- ----------
- server_addr : str
- address of OpenOCD server to connect to
-
- server_port : int
- port of OpenOCD server to connect to
-
- section_constraints: Optional[Dict[str, [Number, MemConstraint]]]
- maps section name to the quantity of available memory
-
- Return
- ------
- config : Dict[str, Any]
- MicroTVM config dict for this device
- """
- if section_constraints is None:
- section_constraints = DEFAULT_SECTION_CONSTRAINTS
- return {
- "device_id": DEVICE_ID,
- "toolchain_prefix": TOOLCHAIN_PREFIX,
- "mem_layout": gen_mem_layout(BASE_ADDR, AVAILABLE_MEM, WORD_SIZE_BITS, section_constraints),
- "word_size_bits": WORD_SIZE_BITS,
- "thumb_mode": True,
- "use_device_timer": True,
- "comms_method": "openocd",
- "server_addr": server_addr,
- "server_port": server_port,
- }
-
-
-register_device(
- DEVICE_ID,
- {
- "create_micro_lib": create_micro_lib,
- "generate_config": generate_config,
- },
-)
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Base definitions for MicroTVM config"""
-import glob
-import os
-import enum
-import pathlib
-
-from tvm.contrib import util as _util
-from tvm.contrib.binutil import run_cmd
-from tvm._ffi.libinfo import find_include_path
-from tvm.micro import DEVICE_SECTIONS, LibType, get_micro_host_driven_dir, get_micro_device_dir
-
-_DEVICE_REGISTRY = {}
-
-
-def register_device(device_id, device_funcs):
- """Register a device and associated compilation/config functions
-
- Parameters
- ----------
- device_id : str
- unique identifier for the device
-
- device_funcs : Dict[str, func]
- dictionary with compilation and config generation functions as values
- """
- if device_id in _DEVICE_REGISTRY:
- raise RuntimeError(f'"{device_id}" already exists in the device registry')
- _DEVICE_REGISTRY[device_id] = device_funcs
-
-
-def get_device_funcs(device_id):
- """Get compilation and config generation functions for device
-
- Parameters
- ----------
- device_id : str
- unique identifier for the device
-
- Return
- ------
- device_funcs : Dict[str, func]
- dictionary with compilation and config generation functions as values
- """
- if device_id not in _DEVICE_REGISTRY:
- raise RuntimeError(f'"{device_id}" does not exist in the binutil registry')
- device_funcs = _DEVICE_REGISTRY[device_id]
- return device_funcs
-
-
-def create_micro_lib_base(
- out_obj_path,
- in_src_path,
- toolchain_prefix,
- device_id,
- lib_type,
- options=None,
- lib_src_paths=None,
-):
- """Compiles code into a binary for the target micro device.
-
- Parameters
- ----------
- out_obj_path : str
- path to generated object file
-
- in_src_path : str
- path to source file
-
- toolchain_prefix : str
- toolchain prefix to be used. For example, a prefix of
- "riscv64-unknown-elf-" means "riscv64-unknown-elf-gcc" is used as
- the compiler and "riscv64-unknown-elf-ld" is used as the linker,
- etc.
-
- device_id : str
- unique identifier for the target device
-
- lib_type : micro.LibType
- whether to compile a MicroTVM runtime or operator library
-
- options : List[str]
- additional options to pass to GCC
-
- lib_src_paths : Optional[List[str]]
- paths to additional source files to be compiled into the library
- """
- # look at these (specifically `strip`):
- # https://stackoverflow.com/questions/15314581/g-compiler-flag-to-minimize-binary-size
- base_compile_cmd = [
- f"{toolchain_prefix}gcc",
- "-std=c11",
- "-Wall",
- "-Wextra",
- "--pedantic",
- "-c",
- "-g",
- "-nostartfiles",
- "-nodefaultlibs",
- "-nostdlib",
- "-fdata-sections",
- "-ffunction-sections",
- ]
- if options is not None:
- base_compile_cmd += options
-
- src_paths = []
- include_paths = find_include_path() + [get_micro_host_driven_dir()]
- tmp_dir = _util.tempdir()
- # we need to create a new src file in the operator branch
- new_in_src_path = in_src_path
- if lib_type == LibType.RUNTIME:
- dev_dir = _get_device_source_dir(device_id)
-
- dev_src_paths = glob.glob(f"{dev_dir}/*.[csS]")
- # there needs to at least be a utvm_timer.c file
- assert dev_src_paths
- assert "utvm_timer.c" in map(os.path.basename, dev_src_paths)
-
- src_paths += dev_src_paths
- elif lib_type == LibType.OPERATOR:
- # create a temporary copy of the operator source, so we can inject the dev lib
- # header without modifying the original.
- temp_src_path = tmp_dir.relpath("temp.c")
- with open(in_src_path, "r") as f:
- src_lines = f.read().splitlines()
- src_lines.insert(0, '#include "utvm_device_dylib_redirect.c"')
- with open(temp_src_path, "w") as f:
- f.write("\n".join(src_lines))
- new_in_src_path = temp_src_path
- else:
- raise RuntimeError("unknown lib type")
-
- src_paths += [new_in_src_path]
-
- # add any src paths required by the operator
- if lib_src_paths is not None:
- src_paths += lib_src_paths
-
- # print(f"include paths: {include_paths}")
- for path in include_paths:
- base_compile_cmd += ["-I", path]
-
- prereq_obj_paths = []
- # print(src_paths)
- for src_path in src_paths:
- curr_obj_path = tmp_dir.relpath(pathlib.Path(src_path).with_suffix(".o").name)
- assert curr_obj_path not in prereq_obj_paths
- prereq_obj_paths.append(curr_obj_path)
- curr_compile_cmd = base_compile_cmd + [src_path, "-o", curr_obj_path]
- # TODO(weberlo): make compilation fail if there are any warnings
- run_cmd(curr_compile_cmd)
-
- ld_cmd = [f"{toolchain_prefix}ld", "-relocatable"]
- ld_cmd += prereq_obj_paths
- ld_cmd += ["-o", out_obj_path]
- run_cmd(ld_cmd)
-
-
-# TODO we shouldn't need an enum for this. too much bureaucracy.
-class MemConstraint(enum.Enum):
- """Represents a constraint on the device's memory layout"""
-
- ABSOLUTE_BYTES = 0
- WEIGHT = 1
-
-
-def gen_mem_layout(base_addr, available_mem, word_size_bits, section_constraints):
- """Template function to generate memory layout for devices.
-
- Parameters
- ----------
- base_addr: Number
- The address where usable memory begins on this device.
-
- available_mem: Number
- Available memory at base_addr, given in bytes.
-
- word_size_bits: Number
- Number of bits in one word on this device.
-
- section_constraints: Optional[Dict[str, [Number, MemConstraint]]]
- maps section name to the quantity of available memory
- """
- assert word_size_bits in (32, 64), "only 32- or 64-bit devices are supported now"
- word_size_bytes = word_size_bits // 8
- byte_sum = sum(
- x[0] for x in section_constraints.values() if x[1] == MemConstraint.ABSOLUTE_BYTES
- )
- weight_sum = sum(x[0] for x in section_constraints.values() if x[1] == MemConstraint.WEIGHT)
- assert byte_sum <= available_mem
- available_weight_mem = available_mem - byte_sum
-
- res = {}
- curr_addr = base_addr
- for section in DEVICE_SECTIONS:
- (val, cons_type) = section_constraints[section]
- if cons_type == MemConstraint.ABSOLUTE_BYTES:
- assert (
- val % word_size_bytes == 0
- ), f"constraint {val} for {section} section is not word-aligned"
- size = val
- res[section] = {
- "start": curr_addr,
- "size": size,
- }
- else:
- size = int((val / weight_sum) * available_weight_mem)
- size = (size // word_size_bytes) * word_size_bytes
- res[section] = {
- "start": curr_addr,
- "size": size,
- }
- curr_addr += size
-
- return res
-
-
-def _get_device_source_dir(device_id):
- """Grabs the source directory for device-specific uTVM files"""
- dev_subdir = "/".join(device_id.split("."))
- return get_micro_device_dir() + "/" + dev_subdir
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Compilation and config definitions for the host emulated device"""
-import sys
-
-from . import create_micro_lib_base, register_device, gen_mem_layout, MemConstraint
-
-DEVICE_ID = "host"
-TOOLCHAIN_PREFIX = ""
-WORD_SIZE_BITS = 64 if sys.maxsize > 2 ** 32 else 32
-
-# we pretend we only have 320kb in the default case, so we can use `gen_mem_layout`
-DEFAULT_AVAILABLE_MEM = 3200000
-DEFAULT_SECTION_CONSTRAINTS = {
- "text": (20480, MemConstraint.ABSOLUTE_BYTES),
- "rodata": (20480, MemConstraint.ABSOLUTE_BYTES),
- "data": (768, MemConstraint.ABSOLUTE_BYTES),
- "bss": (4096, MemConstraint.ABSOLUTE_BYTES),
- "args": (4096, MemConstraint.ABSOLUTE_BYTES),
- "heap": (262144, MemConstraint.ABSOLUTE_BYTES),
- "workspace": (64000, MemConstraint.ABSOLUTE_BYTES),
- "stack": (80, MemConstraint.ABSOLUTE_BYTES),
-}
-
-
-def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None):
- """Wrapper over `create_micro_lib_base` to add device-specific options
-
- Parameters
- ----------
- obj_path : str
- path to generated object file
-
- src_path : str
- path to source file
-
- lib_type : micro.LibType
- whether to compile a MicroTVM runtime or operator library
-
- options : Optional[List[str]]
- additional options to pass to GCC
-
- lib_src_paths : Optional[List[str]]
- paths to additional source files to be compiled into the library
- """
- if options is None:
- options = []
- else:
- options = list(options)
- # Cannot increase optimization level on host due to code loading method.
- options.append("-O0")
- if sys.maxsize > 2 ** 32 and sys.platform.startswith("linux"):
- options += ["-mcmodel=large"]
- options.append("-DUTVM_TARGET_HOST")
- create_micro_lib_base(
- obj_path,
- src_path,
- TOOLCHAIN_PREFIX,
- DEVICE_ID,
- lib_type,
- options=options,
- lib_src_paths=lib_src_paths,
- )
-
-
-def generate_config(available_mem=None, section_constraints=None):
- """Generates a configuration for the host emulated device
-
- Parameters
- ----------
- available_mem: int
- number of RW bytes available for use on device
-
- section_constraints: Optional[Dict[str, Dict[Number, MemConstraint]]]
- maps section name to the quantity of available memory
-
- Return
- ------
- config : Dict[str, Any]
- MicroTVM config dict for this device
- """
- if available_mem is None:
- available_mem = DEFAULT_AVAILABLE_MEM
- if section_constraints is None:
- section_constraints = DEFAULT_SECTION_CONSTRAINTS
- mem_layout = gen_mem_layout(0, available_mem, WORD_SIZE_BITS, section_constraints)
- # TODO the host emulated device is an outlier, since we don't know how what
- # its base address will be until we've created it in the C++. is there any
- # way to change the infrastructure around this so it's not so much of an
- # outlier?
-
- # need to zero out all start addresses, because they don't make sense for a
- # host device (the memory region is allocated in the backend)
- for section in mem_layout:
- mem_layout[section]["start"] = 0
- return {
- "device_id": DEVICE_ID,
- "toolchain_prefix": TOOLCHAIN_PREFIX,
- "mem_layout": mem_layout,
- "word_size_bits": WORD_SIZE_BITS,
- "thumb_mode": False,
- "use_device_timer": False,
- "comms_method": "host",
- }
-
-
-register_device(
- DEVICE_ID,
- {
- "create_micro_lib": create_micro_lib,
- "generate_config": generate_config,
- },
-)
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Compilation and config definitions for Spike, a RISC-V functional ISA simulator"""
-
-from . import create_micro_lib_base, register_device, gen_mem_layout, MemConstraint
-
-DEVICE_ID = "riscv_spike"
-TOOLCHAIN_PREFIX = "riscv64-unknown-elf-"
-WORD_SIZE_BITS = 64
-
-DEFAULT_SECTION_CONSTRAINTS = {
- "text": (18000, MemConstraint.ABSOLUTE_BYTES),
- "rodata": (128, MemConstraint.ABSOLUTE_BYTES),
- "data": (128, MemConstraint.ABSOLUTE_BYTES),
- "bss": (2048, MemConstraint.ABSOLUTE_BYTES),
- "args": (4096, MemConstraint.ABSOLUTE_BYTES),
- "heap": (100.0, MemConstraint.WEIGHT),
- "workspace": (64000, MemConstraint.ABSOLUTE_BYTES),
- "stack": (32, MemConstraint.ABSOLUTE_BYTES),
-}
-
-
-def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None):
- """Wrapper over `create_micro_lib_base` to add device-specific options
-
- Parameters
- ----------
- obj_path : str
- path to generated object file
-
- src_path : str
- path to source file
-
- lib_type : micro.LibType
- whether to compile a MicroTVM runtime or operator library
-
- options : Optional[List[str]]
- additional options to pass to GCC
-
- lib_src_paths : Optional[List[str]]
- TODO
- """
- create_micro_lib_base(
- obj_path,
- src_path,
- TOOLCHAIN_PREFIX,
- DEVICE_ID,
- lib_type,
- options=options,
- lib_src_paths=lib_src_paths,
- )
-
-
-def generate_config(base_addr, available_mem, server_addr, server_port, section_constraints=None):
- """Generates a configuration for Spike
-
- Parameters
- ----------
- base_addr : int
- base address of the simulator (for calculating the memory layout)
-
- server_addr : str
- address of OpenOCD server to connect to
-
- server_port : int
- port of OpenOCD server to connect to
-
- TODO correct type annotation?
- section_constraints: Optional[Dict[str, Tuple[Number, MemConstraint]]]
- TODO
-
- Return
- ------
- config : Dict[str, Any]
- MicroTVM config dict for this device
- """
- if section_constraints is None:
- section_constraints = DEFAULT_SECTION_CONSTRAINTS
- return {
- "device_id": DEVICE_ID,
- "toolchain_prefix": TOOLCHAIN_PREFIX,
- "mem_layout": gen_mem_layout(base_addr, available_mem, WORD_SIZE_BITS, section_constraints),
- "word_size_bits": WORD_SIZE_BITS,
- "thumb_mode": False,
- "use_device_timer": False,
- "comms_method": "openocd",
- "server_addr": server_addr,
- "server_port": server_port,
- }
-
-
-register_device(
- DEVICE_ID,
- {
- "create_micro_lib": create_micro_lib,
- "generate_config": generate_config,
- },
-)
lines.append("static TVMBackendPackedCFunc funcs[] = {")
for f in funcs:
- lines.append(f" &{f},")
+ lines.append(f" (TVMBackendPackedCFunc) &{f},")
lines += [
"};",
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines an Artifact implementation for representing compiled micro TVM binaries."""
+
+from . import artifact
+
+
+class MicroBinary(artifact.Artifact):
+ """An Artifact that describes a compiled binary."""
+
+ ARTIFACT_TYPE = 'micro_binary'
+
+ @classmethod
+ def from_unarchived(cls, base_dir, labelled_files, metadata):
+ binary_file = labelled_files['binary_file'][0]
+ del labelled_files['binary_file']
+
+ debug_files = None
+ if 'debug_files' in labelled_files:
+ debug_files = labelled_files['debug_files']
+ del labelled_files['debug_files']
+
+ return cls(base_dir, binary_file, debug_files=debug_files, labelled_files=labelled_files,
+ metadata=metadata)
+
+ def __init__(self, base_dir, binary_file, debug_files=None, labelled_files=None, metadata=None):
+ labelled_files = {} if labelled_files is None else dict(labelled_files)
+ metadata = {} if metadata is None else dict(metadata)
+ labelled_files['binary_file'] = [binary_file]
+ if debug_files is not None:
+ labelled_files['debug_files'] = debug_files
+
+ super(MicroBinary, self).__init__(base_dir, labelled_files, metadata)
+
+ self.binary_file = binary_file
+ self.debug_files = debug_files
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines an Artifact subclass that describes a compiled static library."""
+
+from tvm.contrib import util
+from . import artifact
+from . import compiler
+
+
+class MicroLibrary(artifact.Artifact):
+ """An Artifact that describes a compiled static library."""
+
+ ARTIFACT_TYPE = 'micro_library'
+
+ @classmethod
+ def from_unarchived(cls, base_dir, labelled_files, metadata):
+ library_files = labelled_files['library_files']
+ del labelled_files['library_files']
+
+ debug_files = None
+ if 'debug_files' in labelled_files:
+ debug_files = labelled_files['debug_files']
+ del labelled_files['debug_files']
+
+ return cls(base_dir, library_files, debug_files=debug_files, labelled_files=labelled_files,
+ metadata=metadata)
+
+ def __init__(self, base_dir, library_files, debug_files=None, labelled_files=None,
+ metadata=None):
+ labelled_files = {} if labelled_files is None else dict(labelled_files)
+ metadata = {} if metadata is None else dict(metadata)
+ labelled_files['library_files'] = library_files
+ if debug_files is not None:
+ labelled_files['debug_files'] = debug_files
+
+ super(MicroLibrary, self).__init__(base_dir, labelled_files, metadata)
+
+ self.library_files = library_files
+ self.debug_file = debug_files
+
+
+def create_micro_library(output, objects, options=None):
+ """Create a MicroLibrary using the default compiler options.
+
+ Parameters
+ ----------
+ output : str
+ Path to the output file, expected to end in .tar.
+ objects : List[str]
+ Paths to the source files to include in the library.
+ options : Optional[List[str]]
+ If given, additional command-line flags for the compiler.
+ """
+ temp_dir = util.tempdir()
+ comp = compiler.DefaultCompiler()
+ output = temp_dir.relpath('micro-library.o')
+ comp.library(output, objects, options=options)
+
+ with open(output, 'rb') as output_f:
+ elf_data = output_f.read()
+
+ # TODO(areusch): Define a mechanism to determine compiler and linker flags for each lib
+ # enabled by the target str, and embed here.
+ micro_lib = MicroLibrary('', elf_data, {'target': comp.target.str()})
+ micro_lib.save(output)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines a top-level glue class that operates the Transport and Flasher classes."""
+
+import logging
+import time
+
+from .._ffi import get_global_func
+from ..contrib import graph_runtime
+from .base import _rpc_connect
+from ..rpc import RPCSession
+from .transport import TransportLogger
+
+
+class Session:
+ """MicroTVM Device Session
+
+ Parameters
+ ----------
+ config : dict
+ configuration for this session (as generated by
+ `tvm.micro.device.host.default_config()`, for example)
+
+ Example
+ --------
+ .. code-block:: python
+
+ c_mod = ... # some module generated with "c" as the target
+ dev_config = micro.device.arm.stm32f746xx.default_config('127.0.0.1', 6666)
+ with tvm.micro.Session(dev_config) as sess:
+ micro_mod = sess.create_micro_mod(c_mod)
+ """
+
+ def __init__(self, binary=None, flasher=None, transport_context_manager=None,
+ session_name='micro-rpc'):
+ """Configure a new session.
+
+ Parameters
+ ----------
+ binary : MicroBinary
+ If given, `flasher` must also be given. During session initialization, this binary will
+ be flashed to the device before the transport is created.
+ flasher : Flasher
+ If given, `binary` must also be given. Used to flash `binary` during session
+ initialization.
+ transport_context_manager : ContextManager[transport.Transport]
+ If given, `flasher` and `binary` should not be given. On entry, this context manager
+ should establish a tarnsport between this TVM instance and the device.
+ session_name : str
+ Name of the session, used for debugging.
+ """
+ self.binary = binary
+ self.flasher = flasher
+ self.transport_context_manager = transport_context_manager
+ self.session_name = session_name
+
+ self._rpc = None
+ self._graph_runtime = None
+
+ def get_system_lib(self):
+ return self._rpc.get_function('runtime.SystemLib')()
+
+ def __enter__(self):
+ """Initialize this session and establish an RPC session with the on-device RPC server.
+
+ Returns
+ -------
+ Session :
+ Returns self.
+ """
+ if self.flasher is not None:
+ self.transport_context_manager = self.flasher.flash(self.binary)
+ time.sleep(3.0)
+
+ self.transport = TransportLogger(
+ self.session_name, self.transport_context_manager, level=logging.INFO).__enter__()
+ self._rpc = RPCSession(_rpc_connect(
+ self.session_name, self.transport.write, self.transport.read))
+ self.context = self._rpc.cpu(0)
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ """Tear down this session and associated RPC session resources."""
+ self.transport.__exit__(exc_type, exc_value, exc_traceback)
+
+
+def create_local_graph_runtime(graph_json_str, mod, ctx):
+ """Create a local graph runtime driving execution on the remote CPU context given.
+
+ Parameters
+ ----------
+ graph_json_str : str
+ A string containing the graph representation.
+
+ mod : tvm.runtime.Module
+ The remote module containing functions in graph_json_str.
+
+ ctx : tvm.Context
+ The remote CPU execution context.
+
+ Returns
+ -------
+ tvm.contrib.GraphRuntime :
+ A local graph runtime instance that executes on the remote device.
+ """
+ device_type_id = [ctx.device_type, ctx.device_id]
+ fcreate = get_global_func("tvm.graph_runtime.create")
+ return graph_runtime.GraphModule(fcreate(
+ graph_json_str, mod, *device_type_id))
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines abstractions and implementations of the RPC transport used with micro TVM."""
+
+import abc
+import logging
+import string
+import subprocess
+import typing
+
+import tvm
+
+_LOG = logging.getLogger(__name__)
+
+
+@tvm.error.register_error
+class SessionTerminatedError(Exception):
+ """Raised when a transport read operationd discovers that the remote session is terminated."""
+
+
+class Transport(metaclass=abc.ABCMeta):
+ """The abstract Transport class used for micro TVM."""
+
+ def __enter__(self):
+ self.open()
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ self.close()
+
+ @abc.abstractmethod
+ def open(self):
+ """Open any resources needed to send and receive RPC protocol data for a single session."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def close(self):
+ """Release resources associated with this transport."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def read(self, n):
+ """Read up to n bytes from the transport.
+
+ Parameters
+ ----------
+ n : int
+ Maximum number of bytes to read from the transport.
+
+ Returns
+ -------
+ bytes :
+ Data read from the channel. Less than `n` bytes may be returned, but 0 bytes should
+ never be returned except in error. Note that if a transport error occurs, an Exception
+ should be raised rather than simply returning empty bytes.
+
+
+ Raises
+ ------
+ SessionTerminatedError :
+ When the transport layer determines that the active session was terminated by the
+ remote side. Typically this indicates that the remote device has reset.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def write(self, data):
+ """Write data to the transport channel.
+
+ Parameters
+ ----------
+ data : bytes
+ The data to write over the channel.
+
+ Returns
+ -------
+ int :
+ The number of bytes written to the underlying channel. This can be less than the length
+ of `data`, but cannot be 0.
+ """
+ raise NotImplementedError()
+
+
+class TransportLogger(Transport):
+ """Wraps a Transport implementation and logs traffic to the Python logging infrastructure."""
+
+ def __init__(self, name, child, logger=None, level=logging.INFO):
+ self.name = name
+ self.child = child
+ self.logger = logger or _LOG
+ self.level = level
+
+ # Construct PRINTABLE to exclude whitespace from string.printable.
+ PRINTABLE = (string.digits + string.ascii_letters + string.punctuation)
+
+ @classmethod
+ def _to_hex(cls, data):
+ lines = []
+ if not data:
+ lines.append('')
+ return lines
+
+ for i in range(0, (len(data) + 15) // 16):
+ chunk = data[i * 16:(i + 1) * 16]
+ hex_chunk = ' '.join(f'{c:02x}' for c in chunk)
+ ascii_chunk = ''.join((chr(c) if chr(c) in cls.PRINTABLE else '.') for c in chunk)
+ lines.append(f'{i * 16:04x} {hex_chunk:47} {ascii_chunk}')
+
+ if len(lines) == 1:
+ lines[0] = lines[0][6:]
+
+ return lines
+
+ def open(self):
+ self.logger.log(self.level, 'opening transport')
+ self.child.open()
+
+ def close(self):
+ self.logger.log(self.level, 'closing transport')
+ return self.child.close()
+
+ def read(self, n):
+ data = self.child.read(n)
+ hex_lines = self._to_hex(data)
+ if len(hex_lines) > 1:
+ self.logger.log(self.level, '%s read %4d B -> [%d B]:\n%s',
+ self.name, n, len(data), '\n'.join(hex_lines))
+ else:
+ self.logger.log(self.level, '%s read %4d B -> [%d B]: %s',
+ self.name, n, len(data), hex_lines[0])
+
+ return data
+
+ def write(self, data):
+ bytes_written = self.child.write(data)
+ hex_lines = self._to_hex(data[:bytes_written])
+ if len(hex_lines) > 1:
+ self.logger.log(self.level, '%s write <- [%d B]:\n%s',
+ self.name, bytes_written, '\n'.join(hex_lines))
+ else:
+ self.logger.log(self.level, '%s write <- [%d B]: %s',
+ self.name, bytes_written, hex_lines[0])
+
+ return bytes_written
+
+
+class SubprocessTransport(Transport):
+ """A Transport implementation that uses a subprocess's stdin/stdout as the channel."""
+
+ def __init__(self, args, **kwargs):
+ self.args = args
+ self.kwargs = kwargs
+ self.popen = None
+
+ def open(self):
+ self.kwargs['stdout'] = subprocess.PIPE
+ self.kwargs['stdin'] = subprocess.PIPE
+ self.kwargs['bufsize'] = 0
+ self.popen = subprocess.Popen(self.args, **self.kwargs)
+ self.stdin = self.popen.stdin
+ self.stdout = self.popen.stdout
+
+ def write(self, data):
+ to_return = self.stdin.write(data)
+ self.stdin.flush()
+
+ return to_return
+
+ def read(self, n):
+ return self.stdout.read(n)
+
+ def close(self):
+ self.stdin.close()
+ self.stdout.close()
+ self.popen.terminate()
+
+
+class DebugWrapperTransport(Transport):
+ """A Transport wrapper class that launches a debugger before opening the transport.
+
+ This is primiarly useful when debugging the other end of a SubprocessTransport. It allows you
+ to pipe data through the GDB process to drive the subprocess with a debugger attached.
+ """
+
+ def __init__(self, debugger, transport):
+ self.debugger = debugger
+ self.transport = transport
+ self.debugger.on_terminate_callbacks.append(self.transport.close)
+
+ def open(self):
+ self.debugger.Start()
+
+ try:
+ self.transport.open()
+ except Exception:
+ self.debugger.Stop()
+ raise
+
+ def write(self, data):
+ return self.transport.write(data)
+
+ def read(self, n):
+ return self.transport.read(n)
+
+ def close(self):
+ self.transport.close()
+ self.debugger.Stop()
+
+
+TransportContextManager = typing.ContextManager[Transport]
"""
curr_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
source_dir = os.path.abspath(os.path.join(curr_dir, "..", "..", ".."))
-
- path = os.path.join(source_dir, "src", "runtime", "rpc", "minrpc", ("%s.cc" % server))
+ minrpc_dir = os.path.join(source_dir, "src", "runtime", "minrpc")
+ path = os.path.join(minrpc_dir, server, ("%s.cc" % server))
candidates = [path]
if not os.path.isfile(path):
raise RuntimeError("Cannot find minserver %s, in candidates %s" % (server, candidates))
- return path
+ return minrpc_dir, path
def with_minrpc(compile_func, server="posix_popen_server", runtime="libtvm"):
fcompile : function
The return compilation.
"""
- server_path = find_minrpc_server_libpath(server)
+ minrpc_dir, server_path = find_minrpc_server_libpath(server)
runtime_path = libinfo.find_lib_path([runtime, runtime + ".so", runtime + ".dylib"])[0]
runtime_dir = os.path.abspath(os.path.dirname(runtime_path))
# Always recommend to to link statically.
options += ["-Wl,-rpath=" + runtime_dir]
options += ["-I" + path for path in libinfo.find_include_path()]
+ options += ["-I" + minrpc_dir]
fcompile = cc.cross_compiler(
compile_func, options=options, add_files=[server_path, runtime_path]
)
opts = _merge_opts(opts, options)
return Target(" ".join(["opencl"] + opts))
+def micro(hardware="unknown", options=None):
+ """Returns a microTVM target.
+
+ Parameters
+ ----------
+ hardware : str
+ Canonically identifies the target device; typicaly one of cortex-mX, or a specific SoC model
+ when that model has been tested to work with microTVM.
+ options : str or list of str
+ Additional options
+ """
+ trans_table = {
+ "host": ["-mcpu=native"],
+ }
+ opts = _merge_opts(trans_table[hardware] + ["-runtime=c", "--system-lib"], options)
+
+ # NOTE: in the future, the default micro target will be LLVM except when
+ # external dependencies are present.
+ return Target(" ".join(["c"] + opts))
def arm_cpu(model="unknown", options=None):
"""Returns a ARM CPU target.
# specific language governing permissions and limitations
# under the License.
+# NOTE: Although this Makefile contains build commands for the C runtime, it isn't intended to be
+# used directly in the TVM source tree. Instead, build the "standalone_crt" target, which produces a
+# directory tree suitable for this Makefile. If this Makefile looks like it's the top-level of a
+# source tree, you can probably ignore this message.
+
+# NOTE: If files appear to be missing in the generated standalone_crt target, consult the copy job
+# specs listed in the TVM repo in cmake/modules/StandaloneCrt.cmake.
+
ifeq ($(CRT_CONFIG),)
$(error "Must supply path to crt_config.h: CRT_CONFIG=...")
endif
-DLPACK_INCLUDE_DIR ?= ../../../3rdparty/dlpack/include
-TVM_INCLUDE_DIR ?= ../../../include
+
+ifneq ($(wildcard .gitignore),)
+$(error "detected building inside tvm source tree.")
+$(error "build the standalone_crt target, and re-invoke makefile in build/standalone_crt")
+endif
BUILD_DIR ?= build
PREFIX ?=
AR ?= ${PREFIX}ar
CC ?= ${PREFIX}gcc
+CXX ?= ${PREFIX}g++
RANLIB ?= ${PREFIX}ranlib
QUIET ?= @
-CFLAGS += -isystem "${TVM_INCLUDE_DIR}" -isystem "${DLPACK_INCLUDE_DIR}" -I include -I $(dir ${CRT_CONFIG})
-CFLAGS += -Werror -g $(EXTRA_CFLAGS)
+CRT_PREFIX = $(wildcard src/crt)
+
+INCLUDES ?= -isystem include -iquote $(dir ${CRT_CONFIG})
+CFLAGS += ${INCLUDES} -Werror -g $(EXTRA_CFLAGS)
+CXXFLAGS += ${INCLUDES} -std=c++11 -Werror -g $(EXTRA_CXXFLAGS)
LDFLAGS += -Werror -g $(EXTRA_LDFLAGS)
-${BUILD_DIR}/%.o: %.c
+${BUILD_DIR}/%.o: src/%.c $(CRT_CONFIG)
${QUIET}mkdir -p $(dir $@)
${QUIET}${CC} ${CFLAGS} -c -o "$@" "$<"
-${BUILD_DIR}/common/libcommon.a: $(patsubst %.c,${BUILD_DIR}/%.o,$(wildcard common/*.c))
- ${QUIET}${AR} -cr "$@" $^
- ${QUIET}${RANLIB} ${RANLIBFLAGS} "$@"
+${BUILD_DIR}/%.o: src/%.cc $(CRT_CONFIG)
+ ${QUIET}mkdir -p $(dir $@)
+ ${QUIET}${CXX} ${CXXFLAGS} -c -o "$@" "$<"
+
+define LIB_template
+$${BUILD_DIR}/lib$(notdir $(1)).a: $$(patsubst src/%.c,$${BUILD_DIR}/%.o,$$(wildcard src/$(1:src/%=%)/*.c)) $$(patsubst src/%.cc,${BUILD_DIR}/%.o,$$(wildcard src/$(1:src/%=%)/*.cc))
+ $${QUIET}$${AR} -cr "$$@" $$^
+ $${QUIET}$${RANLIB} $${RANLIBFLAGS} "$$@"
+$(notdir $(1)): $${BUILD_DIR}/lib$(notdir $(1)).a
+
+endef
-${BUILD_DIR}/graph_runtime/libgraph_runtime.a: $(patsubst %.c,${BUILD_DIR}/%.o,$(wildcard graph_runtime/*.c))
- ${QUIET}${AR} -cr "$@" $^
- ${QUIET}${RANLIB} ${RANLIBFLAGS} "$@"
+LIBS = src/runtime/crt/common src/runtime/crt/graph_runtime src/runtime/crt/utvm_rpc_common src/runtime/crt/utvm_rpc_server
-common: ${BUILD_DIR}/common/libcommon.a
-graph_runtime: ${BUILD_DIR}/graph_runtime/libgraph_runtime.a
+$(foreach lib,$(LIBS),$(eval $(call LIB_template,$(lib))))
-all: common graph_runtime
+all: $(notdir $(LIBS))
clean:
rm -rf "${BUILD_DIR}"
-.PHONY: all common graph_runtime
+.PHONY: all $(notdir $(LIBS))
.DEFAULT_GOAL: all
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/crt/crt.h>
#include <tvm/runtime/crt/func_registry.h>
+#include <tvm/runtime/crt/internal/common/memory.h>
#include <tvm/runtime/crt/internal/common/ndarray.h>
#include <tvm/runtime/crt/internal/graph_runtime/graph_runtime.h>
#include <tvm/runtime/crt/memory.h>
return 0;
}
-tvm_crt_error_t TVMInitializeRuntime() {
+tvm_crt_error_t TVMInitializeRuntime(uint8_t* memory_pool, size_t memory_pool_size_bytes,
+ size_t page_size_bytes_log2) {
int idx;
- int error;
+ tvm_crt_error_t error;
+
+ error =
+ TVMInitializeGlobalMemoryManager(memory_pool, memory_pool_size_bytes, page_size_bytes_log2);
+ if (error != kTvmErrorNoError) {
+ return error;
+ }
system_lib_handle = kTVMModuleHandleUninitialized;
}
error = TVMFuncRegisterGlobal("runtime.SystemLib", &SystemLibraryCreate, 0);
- if (error != 0) {
+ if (error != kTvmErrorNoError) {
return error;
}
error = TVMFuncRegisterGlobal("tvm.rpc.server.ModuleGetFunction", &ModuleGetFunction, 0);
- if (error != 0) {
+ if (error != kTvmErrorNoError) {
return error;
}
- return 0;
+ return kTvmErrorNoError;
}
#include <stdlib.h>
#include <string.h>
#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/crt/internal/common/logging.h>
+#include <tvm/runtime/crt/error_codes.h>
#include <tvm/runtime/crt/internal/common/memory.h>
+#include <tvm/runtime/crt/logging.h>
#include <tvm/runtime/crt/memory.h>
#include <tvm/runtime/crt/platform.h>
-/**
- * \brief Memory pool for virtual dynamic memory allocation
- */
-static uint8_t g_memory_pool[TVM_CRT_VIRT_MEM_SIZE];
-
// construct a new page
Page PageCreate(uint8_t* memory_pool, size_t page_size_bytes, tvm_index_t ptable_begin,
tvm_index_t num_pages) {
} else {
start = ptable->num_pages;
CHECK_LE((unsigned)(start + npage), ptable->max_pages,
- "insufficient memory, start=%" PRId64 ", npage=%" PRId64 ", total=%" PRId64 "", start,
- npage, start + npage);
+ "insufficient memory, start=%" PRId32 ", npage=%" PRId32 ", total=%" PRId32 " / %zu",
+ (int32_t)start, (int32_t)npage, (int32_t)(start + npage), mgr->pmap.max_pages);
/* insert page entry */
Page p = PageCreate(ptable->memory_pool, ptable->page_size_bytes, start, npage);
ptable->resize(ptable, start + npage, &p);
#define ROUND_UP(qty, modulo) (((qty) + ((modulo)-1)) / (modulo) * (modulo))
+static bool g_memory_manager_initialized = 0;
+static MemoryManager g_memory_manager;
+
void MemoryManagerCreate(MemoryManager* manager, uint8_t* memory_pool,
size_t memory_pool_size_bytes, size_t page_size_bytes_log2) {
memset(manager, 0, sizeof(MemoryManager));
manager->free_map.insert = MultiMap_Insert;
}
-MemoryManager* TVMGetGlobalMemoryManager() {
- /* initialize once */
- static uint32_t initialized = 0;
- static MemoryManager mgr;
- if (!initialized) {
- memset(g_memory_pool, 0, sizeof(g_memory_pool));
- MemoryManagerCreate(&mgr, g_memory_pool, TVM_CRT_VIRT_MEM_SIZE, TVM_CRT_PAGE_BYTES_LOG);
- initialized = 1;
+tvm_crt_error_t TVMInitializeGlobalMemoryManager(uint8_t* memory_pool,
+ size_t memory_pool_size_bytes,
+ size_t page_size_bytes_log2) {
+ if (g_memory_manager_initialized) {
+ return kTvmErrorPlatformMemoryManagerInitialized;
}
- return &mgr;
+
+ MemoryManagerCreate(&g_memory_manager, memory_pool, memory_pool_size_bytes, page_size_bytes_log2);
+
+ g_memory_manager_initialized = true;
+ return kTvmErrorNoError;
+}
+
+MemoryManager* TVMGetGlobalMemoryManager() {
+ CHECK(g_memory_manager_initialized);
+ return &g_memory_manager;
}
/** \brief Allocate memory from manager */
*/
#include <stdio.h>
#include <string.h>
-#include <tvm/runtime/crt/internal/common/logging.h>
+#include <tvm/runtime/crt/logging.h>
#include <tvm/runtime/crt/packed_func.h>
DLDataType String2DLDataType(const char* s) {
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/crt_config.h.template
+ * \brief Template for CRT configuration, to be modified on each target.
+ */
+#ifndef TVM_RUNTIME_CRT_CRT_CONFIG_TEMPLATE_H_
+#define TVM_RUNTIME_CRT_CRT_CONFIG_TEMPLATE_H_
+
+/*! Maximum supported dimension in NDArray */
+#define TVM_CRT_MAX_NDIM 6
+
+/*! Maximum supported arguments in generated functions */
+#define TVM_CRT_MAX_ARGS 10
+
+/*! Size of the global function registry, in bytes. */
+#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200
+
+/*! Maximum number of registered modules. */
+#define TVM_CRT_MAX_REGISTERED_MODULES 2
+
+/*! Maximum packet size, in bytes, including the length header. */
+#define TVM_CRT_MAX_PACKET_SIZE_BYTES 2048
+
+/*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */
+#define TVM_CRT_MAX_STRLEN_DLTYPE 10
+
+/*! Maximum supported string length in function names */
+#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80
+
+/*! \brief Maximum length of a PackedFunc function name. */
+#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30
+
+/*! \brief DLDataType for the return value from strlen */
+#define TVM_CRT_STRLEN_DLTYPE 10
+
+#endif // TVM_RUNTIME_CRT_CRT_CONFIG_TEMPLATE_H_
*/
#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/crt/internal/common/logging.h>
#include <tvm/runtime/crt/internal/graph_runtime/graph_runtime.h>
+#include <tvm/runtime/crt/logging.h>
#include <tvm/runtime/crt/memory.h>
#include <tvm/runtime/crt/module.h>
#include <tvm/runtime/crt/packed_func.h>
#ifndef TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_
#define TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_
+/*! Log level of the CRT runtime */
+#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG
+
/*! Support low-level debugging in MISRA-C runtime */
#define TVM_CRT_DEBUG 0
/*! Maximum supported string length in function names */
#define TVM_CRT_STRLEN_NAME 80
-/*!
- * \brief Log memory pool size for virtual memory allocation
- *
- * Here is a list of possible choices:
- * * use 16 for 64 KiB memory space
- * * use 17 for 128 KiB memory space
- * * use 18 for 256 KiB memory space
- * * use 19 for 512 KiB memory space
- * * use 20 for 1 MiB memory space
- * * use 21 for 2 MiB memory space
- * * use 22 for 4 MiB memory space
- * * use 23 for 8 MiB memory space
- * * use 24 for 16 MiB memory space
- * * use 25 for 32 MiB memory space
- * * use 26 for 64 MiB memory space
- * * use 27 for 128 MiB memory space
- * * use 28 for 256 MiB memory space
- */
-#define TVM_CRT_LOG_VIRT_MEM_SIZE 24
-
-/*! \brief Log2 of page size for virtual memory allocation */
-#define TVM_CRT_PAGE_BYTES_LOG 12
-
/*! Maximum number of registered modules. */
#define TVM_CRT_MAX_REGISTERED_MODULES 2
/*! Size of the global function registry, in bytes. */
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200
+/*! Maximum packet size, in bytes, including the length header. */
+#define TVM_CRT_MAX_PACKET_SIZE_BYTES 64000
+
+/*! \brief Maximum length of a PackedFunc function name. */
+#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30
+
+// #define TVM_CRT_FRAMER_ENABLE_LOGS
+
#endif // TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file main.cc
+ * \brief main entry point for host subprocess-based CRT
+ */
+#include <inttypes.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/crt/logging.h>
+#include <tvm/runtime/crt/utvm_rpc_server.h>
+#include <unistd.h>
+
+#include <chrono>
+#include <iostream>
+
+#include "crt_config.h"
+
+using namespace std::chrono;
+
+extern "C" {
+
+ssize_t UTvmWriteFunc(void* context, const uint8_t* data, size_t num_bytes) {
+ ssize_t to_return = write(STDOUT_FILENO, data, num_bytes);
+ fflush(stdout);
+ fsync(STDOUT_FILENO);
+ return to_return;
+}
+
+void TVMPlatformAbort(tvm_crt_error_t error_code) {
+ std::cerr << "TVMPlatformAbort: " << error_code << std::endl;
+ throw "Aborted";
+}
+
+high_resolution_clock::time_point g_utvm_start_time;
+int g_utvm_timer_running = 0;
+
+int TVMPlatformTimerStart() {
+ if (g_utvm_timer_running) {
+ std::cerr << "timer already running" << std::endl;
+ return -1;
+ }
+ g_utvm_start_time = high_resolution_clock::now();
+ g_utvm_timer_running = 1;
+ return 0;
+}
+
+int TVMPlatformTimerStop(double* res_us) {
+ if (!g_utvm_timer_running) {
+ std::cerr << "timer not running" << std::endl;
+ return -1;
+ }
+ auto utvm_stop_time = high_resolution_clock::now();
+ duration<double, std::micro> time_span(utvm_stop_time - g_utvm_start_time);
+ *res_us = time_span.count();
+ g_utvm_timer_running = 0;
+ return 0;
+}
+}
+
+uint8_t memory[512 * 1024];
+
+static char** g_argv = NULL;
+
+int testonly_reset_server(TVMValue* args, int* type_codes, int num_args, TVMValue* out_ret_value,
+ int* out_ret_tcode, void* resource_handle) {
+ execvp(g_argv[0], g_argv);
+ perror("utvm runtime: error restarting");
+ return -1;
+}
+
+int main(int argc, char** argv) {
+ g_argv = argv;
+ utvm_rpc_server_t rpc_server =
+ UTvmRpcServerInit(memory, sizeof(memory), 8, &UTvmWriteFunc, nullptr);
+
+ if (TVMFuncRegisterGlobal("tvm.testing.reset_server", (TVMFunctionHandle)&testonly_reset_server,
+ 0)) {
+ fprintf(stderr, "utvm runtime: internal error registering global packedfunc; exiting\n");
+ return 2;
+ }
+
+ setbuf(stdin, NULL);
+ setbuf(stdout, NULL);
+
+ for (;;) {
+ uint8_t c;
+ int ret_code = read(STDIN_FILENO, &c, 1);
+ if (ret_code < 0) {
+ perror("utvm runtime: read failed");
+ return 2;
+ } else if (ret_code == 0) {
+ fprintf(stderr, "utvm runtime: 0-length read, exiting!\n");
+ return 2;
+ }
+ if (UTvmRpcServerReceiveByte(rpc_server, c) != 1) {
+ abort();
+ }
+ if (!UTvmRpcServerLoop(rpc_server)) {
+ execvp(argv[0], argv);
+ perror("utvm runtime: error restarting");
+ return 2;
+ }
+ }
+ return 0;
+}
#define TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_COMMON_MEMORY_H_
#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/crt/error_codes.h>
#include "crt_config.h"
extern "C" {
#endif
-/*! Number of bits in a page */
-#define TVM_CRT_PAGE_BITS ((1 << TVM_CRT_PAGE_BYTES_LOG) << 3)
-
-/*! \brief Translate log memory size into bytes */
-#define TVM_CRT_VIRT_MEM_SIZE (1 << TVM_CRT_LOG_VIRT_MEM_SIZE)
-
-/*! \brief Number of possible page entries in total */
-#define TVM_CRT_MAX_PAGES (TVM_CRT_VIRT_MEM_SIZE / TVM_CRT_PAGE_BYTES)
-
/*! \brief A page in the DRAM */
typedef struct Page {
/*! \brief Start location in page table */
MultiMap free_map;
} MemoryManager;
-// Exposed for testing
+/*!
+ * Exposed for testing.
+ *
+ * \param manager The memory manager to initialize.
+ * \param memory_pool Pointer to the global memory pool used by the CRT.
+ * \param memory_pool_size_bytes Size of `memory_pool`, in bytes.
+ * \param page_size_bytes_log2 log2 of the page size, in bytes.
+ */
void MemoryManagerCreate(MemoryManager* manager, uint8_t* memory_pool,
size_t memory_pool_size_bytes, size_t page_size_bytes_log2);
+/*!
+ * Initialize the global memory manager.
+ *
+ * Call this function once before invoking any other CRT functions beginning with `TVM`.
+ * Repeated calls will cause TVMPlatformAbort to be invoked.
+ * \param memory_pool Pointer to the global memory pool used by the CRT.
+ * \param memory_pool_size_bytes Size of `memory_pool`, in bytes.
+ * \param page_size_bytes_log2 log2 of the page size, in bytes.
+ * \return An error code indicating the status of the operation.
+ */
+tvm_crt_error_t TVMInitializeGlobalMemoryManager(uint8_t* memory_pool,
+ size_t memory_pool_size_bytes,
+ size_t page_size_bytes_log2);
+
#ifdef __cplusplus
} // extern "C"
#endif
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file frame_buffer.cc
+ * \brief Defines a buffer for use by the RPC framing layer.
+ */
+
+#include <stdio.h>
+#include <string.h>
+#include <tvm/runtime/crt/rpc_common/frame_buffer.h>
+
+namespace tvm {
+namespace runtime {
+namespace micro_rpc {
+
+size_t FrameBuffer::Write(const uint8_t* data, size_t data_size_bytes) {
+ size_t num_bytes_available = capacity_ - num_valid_bytes_;
+ size_t num_bytes_to_copy = data_size_bytes;
+ if (num_bytes_available < num_bytes_to_copy) {
+ num_bytes_to_copy = num_bytes_available;
+ }
+
+ memcpy(&data_[num_valid_bytes_], data, num_bytes_to_copy);
+ num_valid_bytes_ += num_bytes_to_copy;
+ return num_bytes_to_copy;
+}
+
+size_t FrameBuffer::Read(uint8_t* data, size_t data_size_bytes) {
+ size_t num_bytes_to_copy = data_size_bytes;
+ size_t num_bytes_available = num_valid_bytes_ - read_cursor_;
+ if (num_bytes_available < num_bytes_to_copy) {
+ num_bytes_to_copy = num_bytes_available;
+ }
+
+ memcpy(data, &data_[read_cursor_], num_bytes_to_copy);
+ read_cursor_ += num_bytes_to_copy;
+ return num_bytes_to_copy;
+}
+
+void FrameBuffer::Clear() {
+ num_valid_bytes_ = 0;
+ read_cursor_ = 0;
+}
+
+} // namespace micro_rpc
+} // namespace runtime
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file framing.cc
+ * \brief Framing for RPC.
+ */
+
+#include <string.h>
+#include <tvm/runtime/crt/logging.h>
+#include <tvm/runtime/crt/rpc_common/framing.h>
+
+#include "crt_config.h"
+
+// For debugging purposes, Framer logs can be enabled, but this should only be done when
+// running from the host. This is done differently from TVMLogf() because TVMLogf() uses the
+// framer in its implementation.
+#ifdef TVM_CRT_FRAMER_ENABLE_LOGS
+#include <cstdio>
+#define TVM_FRAMER_DEBUG_LOG(msg, ...) fprintf(stderr, "utvm framer: " msg " \n", ##__VA_ARGS__)
+#define TVM_UNFRAMER_DEBUG_LOG(msg, ...) fprintf(stderr, "utvm unframer: " msg " \n", ##__VA_ARGS__)
+#else
+#define TVM_FRAMER_DEBUG_LOG(msg, ...)
+#define TVM_UNFRAMER_DEBUG_LOG(msg, ...)
+#endif
+
+namespace tvm {
+namespace runtime {
+namespace micro_rpc {
+
+template <typename E>
+static constexpr uint8_t to_integral(E e) {
+ return static_cast<uint8_t>(e);
+}
+
+void Unframer::Reset() {
+ state_ = State::kFindPacketStart;
+ saw_escape_start_ = false;
+ num_buffer_bytes_valid_ = 0;
+}
+
+tvm_crt_error_t Unframer::Write(const uint8_t* data, size_t data_size_bytes,
+ size_t* bytes_consumed) {
+ tvm_crt_error_t return_code = kTvmErrorNoError;
+ input_ = data;
+ input_size_bytes_ = data_size_bytes;
+
+ while (return_code == kTvmErrorNoError && input_size_bytes_ > 0) {
+ TVM_UNFRAMER_DEBUG_LOG("state: %02x size 0x%02zx", to_integral(state_), input_size_bytes_);
+ switch (state_) {
+ case State::kFindPacketStart:
+ return_code = FindPacketStart();
+ break;
+ case State::kFindPacketLength:
+ return_code = FindPacketLength();
+ break;
+ case State::kFindPacketCrc:
+ return_code = FindPacketCrc();
+ break;
+ case State::kFindCrcEnd:
+ return_code = FindCrcEnd();
+ break;
+ default:
+ return_code = kTvmErrorFramingInvalidState;
+ break;
+ }
+ }
+
+ *bytes_consumed = data_size_bytes - input_size_bytes_;
+ input_ = nullptr;
+ input_size_bytes_ = 0;
+
+ if (return_code != kTvmErrorNoError) {
+ state_ = State::kFindPacketStart;
+ ClearBuffer();
+ }
+
+ return return_code;
+}
+
+tvm_crt_error_t Unframer::FindPacketStart() {
+ size_t i;
+ for (i = 0; i < input_size_bytes_; ++i) {
+ if (input_[i] == to_integral(Escape::kEscapeStart)) {
+ saw_escape_start_ = true;
+ } else if (input_[i] == to_integral(Escape::kPacketStart) && saw_escape_start_) {
+ uint8_t packet_start_sequence[2]{to_integral(Escape::kEscapeStart),
+ to_integral(Escape::kPacketStart)};
+ crc_ = crc16_compute(packet_start_sequence, sizeof(packet_start_sequence), nullptr);
+ saw_escape_start_ = false;
+ state_ = State::kFindPacketLength;
+ i++;
+ break;
+ } else {
+ saw_escape_start_ = false;
+ }
+ }
+
+ input_ += i;
+ input_size_bytes_ -= i;
+ return kTvmErrorNoError;
+}
+
+tvm_crt_error_t Unframer::ConsumeInput(uint8_t* buffer, size_t buffer_size_bytes,
+ size_t* bytes_filled, bool update_crc) {
+ CHECK(*bytes_filled < buffer_size_bytes);
+ tvm_crt_error_t to_return = kTvmErrorNoError;
+ size_t i;
+ for (i = 0; i < input_size_bytes_; ++i) {
+ uint8_t c = input_[i];
+ if (saw_escape_start_) {
+ saw_escape_start_ = false;
+ if (c == to_integral(Escape::kPacketStart)) {
+ // When the start packet sequence is seen, abort unframing the current packet. Since the
+ // escape byte has already been parsed, update the CRC include only the escape byte. This
+ // readies the unframer to consume the kPacketStart byte on the next Write() call.
+ uint8_t escape_start = to_integral(Escape::kEscapeStart);
+ crc_ = crc16_compute(&escape_start, 1, NULL);
+ to_return = kTvmErrorFramingShortPacket;
+ saw_escape_start_ = true;
+
+ break;
+ } else if (c == to_integral(Escape::kEscapeNop)) {
+ continue;
+ } else if (c == to_integral(Escape::kEscapeStart)) {
+ // do nothing (allow character to be printed)
+ } else {
+ // Invalid escape sequence.
+ to_return = kTvmErrorFramingInvalidEscape;
+ i++;
+ break;
+ }
+ } else if (c == to_integral(Escape::kEscapeStart)) {
+ saw_escape_start_ = true;
+ continue;
+ } else {
+ saw_escape_start_ = false;
+ }
+
+ buffer[*bytes_filled] = c;
+ (*bytes_filled)++;
+ if (*bytes_filled == buffer_size_bytes) {
+ i++;
+ break;
+ }
+ }
+
+ if (update_crc) {
+ crc_ = crc16_compute(input_, i, &crc_);
+ }
+
+ input_ += i;
+ input_size_bytes_ -= i;
+ return to_return;
+}
+
+tvm_crt_error_t Unframer::AddToBuffer(size_t buffer_full_bytes, bool update_crc) {
+ CHECK(!IsBufferFull(buffer_full_bytes));
+ return ConsumeInput(buffer_, buffer_full_bytes, &num_buffer_bytes_valid_, update_crc);
+}
+
+void Unframer::ClearBuffer() { num_buffer_bytes_valid_ = 0; }
+
+tvm_crt_error_t Unframer::FindPacketLength() {
+ tvm_crt_error_t to_return = AddToBuffer(PacketFieldSizeBytes::kPayloadLength, true);
+ if (to_return != kTvmErrorNoError) {
+ return to_return;
+ }
+
+ if (!IsBufferFull(PacketFieldSizeBytes::kPayloadLength)) {
+ return to_return;
+ }
+
+ num_payload_bytes_remaining_ = *reinterpret_cast<uint32_t*>(buffer_);
+ TVM_UNFRAMER_DEBUG_LOG("payload length: 0x%zx", num_payload_bytes_remaining_);
+ ClearBuffer();
+ state_ = State::kFindPacketCrc;
+ return to_return;
+}
+
+tvm_crt_error_t Unframer::FindPacketCrc() {
+ // CHECK(num_buffer_bytes_valid_ == 0);
+ while (num_payload_bytes_remaining_ > 0) {
+ size_t num_bytes_to_buffer = num_payload_bytes_remaining_;
+ if (num_bytes_to_buffer > sizeof(buffer_)) {
+ num_bytes_to_buffer = sizeof(buffer_);
+ }
+
+ // remember in case we need to rewind due to WriteAll() error.
+ size_t prev_input_size_bytes = input_size_bytes_;
+ size_t prev_num_buffer_bytes_valid = num_buffer_bytes_valid_;
+ {
+ tvm_crt_error_t to_return = AddToBuffer(num_bytes_to_buffer, true);
+ if (to_return != kTvmErrorNoError) {
+ return to_return;
+ }
+ }
+
+ if (prev_num_buffer_bytes_valid == num_buffer_bytes_valid_) {
+ // Return if no bytes were consumed from the input.
+ return kTvmErrorNoError;
+ }
+
+ {
+ size_t bytes_consumed;
+ tvm_crt_error_t to_return =
+ stream_->WriteAll(buffer_, num_buffer_bytes_valid_, &bytes_consumed);
+ num_payload_bytes_remaining_ -= bytes_consumed;
+ if (to_return != kTvmErrorNoError) {
+ // rewind input, skipping escape bytes.
+ size_t buffer_bytes_consumed;
+ const uint8_t* input = input_ - (prev_input_size_bytes - input_size_bytes_);
+ for (buffer_bytes_consumed = 0; bytes_consumed > 0; ++buffer_bytes_consumed) {
+ if (input[buffer_bytes_consumed] != uint8_t(Escape::kEscapeStart)) {
+ bytes_consumed--;
+ }
+ }
+
+ size_t bytes_to_rewind = prev_input_size_bytes - buffer_bytes_consumed;
+ input_ -= bytes_to_rewind;
+ input_size_bytes_ += bytes_to_rewind;
+
+ // must not have seen escape, since AddToBuffer won't stop in the middle.
+ saw_escape_start_ = false;
+
+ return to_return;
+ }
+ }
+
+ ClearBuffer();
+ }
+
+ if (num_payload_bytes_remaining_ == 0) {
+ state_ = State::kFindCrcEnd;
+ }
+
+ return kTvmErrorNoError;
+}
+
+tvm_crt_error_t Unframer::FindCrcEnd() {
+ tvm_crt_error_t to_return = AddToBuffer(PacketFieldSizeBytes::kCrc, false);
+ if (to_return != kTvmErrorNoError) {
+ return to_return;
+ }
+
+ if (!IsBufferFull(PacketFieldSizeBytes::kCrc)) {
+ return kTvmErrorNoError;
+ }
+
+ // TODO(areusch): Handle endianness.
+ stream_->PacketDone(crc_ == *reinterpret_cast<uint16_t*>(buffer_));
+ ClearBuffer();
+ state_ = State::kFindPacketStart;
+ return kTvmErrorNoError;
+}
+
+void Framer::Reset() { state_ = State::kReset; }
+
+tvm_crt_error_t Framer::Write(const uint8_t* payload, size_t payload_size_bytes) {
+ tvm_crt_error_t to_return;
+ to_return = StartPacket(payload_size_bytes);
+ if (to_return != kTvmErrorNoError) {
+ return to_return;
+ }
+
+ to_return = WritePayloadChunk(payload, payload_size_bytes);
+ if (to_return != 0) {
+ return to_return;
+ }
+
+ to_return = FinishPacket();
+ return to_return;
+}
+
+tvm_crt_error_t Framer::StartPacket(size_t payload_size_bytes) {
+ uint8_t packet_header[sizeof(uint32_t)];
+ size_t ptr = 0;
+ if (state_ == State::kReset) {
+ packet_header[ptr] = to_integral(Escape::kEscapeNop);
+ ptr++;
+ tvm_crt_error_t to_return =
+ WriteAndCrc(packet_header, ptr, false /* escape */, false /* update_crc */);
+ if (to_return != kTvmErrorNoError) {
+ return to_return;
+ }
+
+ ptr = 0;
+ }
+
+ packet_header[ptr] = to_integral(Escape::kEscapeStart);
+ ptr++;
+ packet_header[ptr] = to_integral(Escape::kPacketStart);
+ ptr++;
+
+ crc_ = 0xffff;
+ tvm_crt_error_t to_return =
+ WriteAndCrc(packet_header, ptr, false /* escape */, true /* update_crc */);
+ if (to_return != kTvmErrorNoError) {
+ return to_return;
+ }
+
+ uint32_t payload_size_wire = payload_size_bytes;
+ to_return = WriteAndCrc(reinterpret_cast<uint8_t*>(&payload_size_wire), sizeof(payload_size_wire),
+ true /* escape */, true /* update_crc */);
+ if (to_return == kTvmErrorNoError) {
+ state_ = State::kTransmitPacketPayload;
+ num_payload_bytes_remaining_ = payload_size_bytes;
+ }
+
+ return to_return;
+}
+
+tvm_crt_error_t Framer::WriteAndCrc(const uint8_t* data, size_t data_size_bytes, bool escape,
+ bool update_crc) {
+ while (data_size_bytes > 0) {
+ uint8_t buffer[kMaxStackBufferSizeBytes];
+ size_t buffer_ptr = 0;
+ size_t i;
+ for (i = 0; i < data_size_bytes && buffer_ptr != kMaxStackBufferSizeBytes; ++i) {
+ uint8_t c = data[i];
+ if (!escape || c != to_integral(Escape::kEscapeStart)) {
+ buffer[buffer_ptr] = c;
+ buffer_ptr++;
+ continue;
+ }
+
+ if (buffer_ptr == kMaxStackBufferSizeBytes - 1) {
+ break;
+ }
+
+ buffer[buffer_ptr] = to_integral(Escape::kEscapeStart);
+ buffer_ptr++;
+
+ buffer[buffer_ptr] = to_integral(Escape::kEscapeStart);
+ buffer_ptr++;
+ }
+
+ size_t bytes_consumed;
+ tvm_crt_error_t to_return = stream_->WriteAll(buffer, buffer_ptr, &bytes_consumed);
+ if (to_return != kTvmErrorNoError) {
+ return to_return;
+ }
+
+ if (update_crc) {
+ crc_ = crc16_compute(buffer, buffer_ptr, &crc_);
+ }
+
+ data_size_bytes -= i;
+ data += i;
+ }
+
+ return kTvmErrorNoError;
+}
+
+tvm_crt_error_t Framer::WritePayloadChunk(const uint8_t* payload_chunk,
+ size_t payload_chunk_size_bytes) {
+ if (state_ != State::kTransmitPacketPayload) {
+ return kTvmErrorFramingInvalidState;
+ } else if (payload_chunk_size_bytes > num_payload_bytes_remaining_) {
+ return kTvmErrorFramingPayloadOverflow;
+ }
+
+ TVM_FRAMER_DEBUG_LOG("write payload chunk: %" PRIuMAX " bytes", payload_chunk_size_bytes);
+ tvm_crt_error_t to_return = WriteAndCrc(payload_chunk, payload_chunk_size_bytes,
+ true /* escape */, true /* update_crc */);
+ if (to_return != kTvmErrorNoError) {
+ state_ = State::kReset;
+ return to_return;
+ }
+
+ num_payload_bytes_remaining_ -= payload_chunk_size_bytes;
+ return kTvmErrorNoError;
+}
+
+tvm_crt_error_t Framer::FinishPacket() {
+ if (state_ != State::kTransmitPacketPayload) {
+ return kTvmErrorFramingInvalidState;
+ } else if (num_payload_bytes_remaining_ != 0) {
+ return kTvmErrorFramingPayloadIncomplete;
+ }
+
+ tvm_crt_error_t to_return = WriteAndCrc(reinterpret_cast<uint8_t*>(&crc_), sizeof(crc_),
+ true /* escape */, false /* update_crc */);
+ if (to_return != kTvmErrorNoError) {
+ TVM_FRAMER_DEBUG_LOG("write and crc returned: %02x", to_return);
+ state_ = State::kReset;
+ } else {
+ state_ = State::kIdle;
+ }
+ return to_return;
+}
+
+} // namespace micro_rpc
+} // namespace runtime
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file session.h
+ * \brief RPC Session
+ */
+
+#include <tvm/runtime/crt/logging.h>
+#include <tvm/runtime/crt/rpc_common/session.h>
+
+#include "crt_config.h"
+
+namespace tvm {
+namespace runtime {
+namespace micro_rpc {
+
+struct utvm_session_start_payload_t {
+ uint8_t version;
+};
+
+void Session::RegenerateNonce() {
+ local_nonce_ = (((local_nonce_ << 5) | (local_nonce_ >> 5)) + 1);
+
+ if (local_nonce_ == kInvalidNonce) {
+ local_nonce_++;
+ }
+}
+
+tvm_crt_error_t Session::SendInternal(MessageType message_type, const uint8_t* message_data,
+ size_t message_size_bytes) {
+ tvm_crt_error_t to_return = StartMessage(message_type, message_size_bytes);
+ if (to_return != kTvmErrorNoError) {
+ return to_return;
+ }
+
+ if (message_size_bytes > 0) {
+ to_return = SendBodyChunk(message_data, message_size_bytes);
+ if (to_return != kTvmErrorNoError) {
+ return to_return;
+ }
+ }
+
+ return framer_->FinishPacket();
+}
+
+tvm_crt_error_t Session::StartMessage(MessageType message_type, size_t message_size_bytes) {
+ SessionHeader header{session_id_, message_type};
+ if (message_type == MessageType::kLog) {
+ header.session_id = 0;
+ }
+
+ tvm_crt_error_t to_return = framer_->StartPacket(message_size_bytes + sizeof(SessionHeader));
+ if (to_return != 0) {
+ return to_return;
+ }
+
+ return framer_->WritePayloadChunk(reinterpret_cast<uint8_t*>(&header), sizeof(SessionHeader));
+}
+
+tvm_crt_error_t Session::SendBodyChunk(const uint8_t* chunk, size_t chunk_size_bytes) {
+ return framer_->WritePayloadChunk(chunk, chunk_size_bytes);
+}
+
+tvm_crt_error_t Session::FinishMessage() { return framer_->FinishPacket(); }
+
+tvm_crt_error_t Session::StartSession() {
+ CHECK_NE(state_, State::kReset, "must call Initialize");
+
+ RegenerateNonce();
+ SetSessionId(local_nonce_, 0);
+ utvm_session_start_payload_t payload = {Session::kVersion};
+ tvm_crt_error_t to_return = SendInternal(MessageType::kStartSessionInit,
+ reinterpret_cast<uint8_t*>(&payload), sizeof(payload));
+ if (to_return == 0) {
+ state_ = State::kStartSessionSent;
+ }
+
+ return to_return;
+}
+
+tvm_crt_error_t Session::Initialize() { return TerminateSession(); }
+
+tvm_crt_error_t Session::TerminateSession() {
+ SetSessionId(0, 0);
+ state_ = State::kNoSessionEstablished;
+ return SendInternal(MessageType::kTerminateSession, nullptr, 0);
+}
+
+tvm_crt_error_t Session::SendMessage(MessageType message_type, const uint8_t* message_data,
+ size_t message_size_bytes) {
+ if (state_ != State::kSessionEstablished && message_type != MessageType::kLog) {
+ return kTvmErrorSessionInvalidState;
+ }
+
+ return SendInternal(message_type, message_data, message_size_bytes);
+}
+
+ssize_t Session::SessionReceiver::Write(const uint8_t* data, size_t data_size_bytes) {
+ if (session_->receive_buffer_has_complete_message_) {
+ return kTvmErrorSessionReceiveBufferBusy;
+ }
+
+ size_t bytes_written = session_->receive_buffer_->Write(data, data_size_bytes);
+ if (bytes_written != data_size_bytes) {
+ return kTvmErrorSessionReceiveBufferShortWrite;
+ }
+
+ return bytes_written;
+}
+
+void Session::SessionReceiver::PacketDone(bool is_valid) {
+ if (!is_valid) {
+ return;
+ }
+
+ SessionHeader header;
+ int bytes_read =
+ session_->receive_buffer_->Read(reinterpret_cast<uint8_t*>(&header), sizeof(header));
+ if (bytes_read != sizeof(header)) {
+ return;
+ }
+ session_->receive_buffer_has_complete_message_ = true;
+
+ switch (header.message_type) {
+ case MessageType::kStartSessionInit:
+ session_->ProcessStartSessionInit(header);
+ session_->receive_buffer_has_complete_message_ = false;
+ break;
+ case MessageType::kStartSessionReply:
+ session_->ProcessStartSessionReply(header);
+ session_->receive_buffer_has_complete_message_ = false;
+ break;
+ case MessageType::kTerminateSession:
+ if (session_->state_ == State::kSessionEstablished) {
+ session_->state_ = State::kNoSessionEstablished;
+ session_->OnSessionTerminatedMessage();
+ }
+ session_->receive_buffer_has_complete_message_ = false;
+ break;
+ case MessageType::kLog:
+ if (header.session_id == 0 || header.session_id == session_->session_id_) {
+ // Special case for log messages: session id can be 0.
+ session_->message_received_func_(session_->message_received_func_context_,
+ header.message_type, session_->receive_buffer_);
+ }
+ break;
+ default:
+ if (session_->state_ == State::kSessionEstablished &&
+ header.session_id == session_->session_id_) {
+ session_->message_received_func_(session_->message_received_func_context_,
+ header.message_type, session_->receive_buffer_);
+ }
+ break;
+ }
+}
+
+void Session::ClearReceiveBuffer() {
+ receive_buffer_has_complete_message_ = false;
+ receive_buffer_->Clear();
+}
+
+void Session::SendSessionStartReply(const SessionHeader& header) {
+ RegenerateNonce();
+ SetSessionId(InitiatorNonce(header.session_id), local_nonce_);
+ utvm_session_start_payload_t payload = {Session::kVersion};
+ tvm_crt_error_t to_return = SendInternal(MessageType::kStartSessionReply,
+ reinterpret_cast<uint8_t*>(&payload), sizeof(payload));
+ state_ = State::kSessionEstablished;
+ CHECK_EQ(to_return, kTvmErrorNoError, "SendSessionStartReply");
+ OnSessionEstablishedMessage();
+}
+
+void Session::ProcessStartSessionInit(const SessionHeader& header) {
+ if (InitiatorNonce(header.session_id) == kInvalidNonce) {
+ return;
+ }
+
+ utvm_session_start_payload_t payload;
+ int bytes_read = receive_buffer_->Read(reinterpret_cast<uint8_t*>(&payload), sizeof(payload));
+ if (bytes_read != sizeof(payload)) {
+ return;
+ }
+
+ switch (state_) {
+ case State::kReset:
+ case State::kNoSessionEstablished:
+ // Normal case: received a StartSession packet from reset.
+ SendSessionStartReply(header);
+ break;
+
+ case State::kStartSessionSent:
+ // When two StartSessionInit packets sent simultaneously: lowest nonce wins; ties retry.
+ if (InitiatorNonce(header.session_id) < local_nonce_) {
+ if (payload.version == Session::kVersion) {
+ SendSessionStartReply(header);
+ }
+ } else if (InitiatorNonce(header.session_id) == local_nonce_) {
+ StartSession();
+ }
+
+ break;
+
+ case State::kSessionEstablished:
+ SendSessionStartReply(header);
+ OnSessionEstablishedMessage();
+ break;
+
+ default:
+ state_ = State::kReset;
+ }
+}
+
+void Session::ProcessStartSessionReply(const SessionHeader& header) {
+ if (ResponderNonce(header.session_id) == kInvalidNonce) {
+ return;
+ }
+
+ utvm_session_start_payload_t payload;
+ int bytes_read = receive_buffer_->Read(reinterpret_cast<uint8_t*>(&payload), sizeof(payload));
+ if (bytes_read != sizeof(payload)) {
+ return;
+ }
+
+ switch (state_) {
+ case State::kReset:
+ case State::kNoSessionEstablished:
+ break;
+ case State::kStartSessionSent:
+ if (InitiatorNonce(header.session_id) == local_nonce_ &&
+ payload.version == Session::kVersion) {
+ SetSessionId(local_nonce_, ResponderNonce(header.session_id));
+ state_ = State::kSessionEstablished;
+ OnSessionEstablishedMessage();
+ }
+ break;
+ case State::kSessionEstablished:
+ if (InitiatorNonce(header.session_id) != kInvalidNonce &&
+ ResponderNonce(header.session_id) == kInvalidNonce) {
+ if (payload.version == Session::kVersion) {
+ SendSessionStartReply(header);
+ } else {
+ SetSessionId(local_nonce_, 0);
+ state_ = State::kReset;
+ }
+ } else {
+ state_ = State::kReset;
+ }
+ break;
+ }
+}
+
+void Session::OnSessionEstablishedMessage() {
+ message_received_func_(message_received_func_context_, MessageType::kStartSessionReply, NULL);
+}
+
+void Session::OnSessionTerminatedMessage() {
+ message_received_func_(message_received_func_context_, MessageType::kTerminateSession, NULL);
+}
+
+} // namespace micro_rpc
+} // namespace runtime
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file framing.h
+ * \brief Framing for RPC.
+ */
+#include <tvm/runtime/crt/rpc_common/write_stream.h>
+
+namespace tvm {
+namespace runtime {
+namespace micro_rpc {
+
+WriteStream::~WriteStream() {}
+
+tvm_crt_error_t WriteStream::WriteAll(uint8_t* data, size_t data_size_bytes,
+ size_t* bytes_consumed) {
+ *bytes_consumed = 0;
+ while (data_size_bytes > 0) {
+ ssize_t to_return = Write(data, data_size_bytes);
+ if (to_return == 0) {
+ return kTvmErrorWriteStreamShortWrite;
+ } else if (to_return < 0) {
+ return (tvm_crt_error_t)to_return;
+ } else if (to_return > 0 && ((size_t)to_return) > data_size_bytes) {
+ return kTvmErrorWriteStreamLongWrite;
+ }
+
+ data += to_return;
+ data_size_bytes -= to_return;
+ *bytes_consumed += to_return;
+ }
+
+ return kTvmErrorNoError;
+}
+
+} // namespace micro_rpc
+} // namespace runtime
+} // namespace tvm
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file utvm_rpc_server.cc
+ * \brief MicroTVM RPC Server
+ */
+
+#include <inttypes.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/types.h>
+
+// NOTE: dmlc/base.h contains some declarations that are incompatible with some C embedded
+// toolchains. Just pull the bits we need for this file.
+#define DMLC_CMAKE_LITTLE_ENDIAN DMLC_IO_USE_LITTLE_ENDIAN
+#define DMLC_LITTLE_ENDIAN true
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/crt/crt.h>
+#include <tvm/runtime/crt/internal/common/memory.h>
+#include <tvm/runtime/crt/logging.h>
+#include <tvm/runtime/crt/memory.h>
+#include <tvm/runtime/crt/module.h>
+#include <tvm/runtime/crt/platform.h>
+#include <tvm/runtime/crt/rpc_common/frame_buffer.h>
+#include <tvm/runtime/crt/rpc_common/framing.h>
+#include <tvm/runtime/crt/rpc_common/session.h>
+#include <tvm/runtime/crt/utvm_rpc_server.h>
+
+#include "../../minrpc/minrpc_server.h"
+#include "crt_config.h"
+
+namespace tvm {
+namespace runtime {
+namespace micro_rpc {
+
+class MicroIOHandler {
+ public:
+ MicroIOHandler(Session* session, FrameBuffer* receive_buffer)
+ : session_{session}, receive_buffer_{receive_buffer} {}
+
+ void MessageStart(size_t message_size_bytes) {
+ session_->StartMessage(MessageType::kNormal, message_size_bytes + 8);
+ }
+
+ ssize_t PosixWrite(const uint8_t* buf, size_t buf_size_bytes) {
+ int to_return = session_->SendBodyChunk(buf, buf_size_bytes);
+ if (to_return < 0) {
+ return to_return;
+ }
+ return buf_size_bytes;
+ }
+
+ void MessageDone() { CHECK_EQ(session_->FinishMessage(), kTvmErrorNoError, "FinishMessage"); }
+
+ ssize_t PosixRead(uint8_t* buf, size_t buf_size_bytes) {
+ return receive_buffer_->Read(buf, buf_size_bytes);
+ }
+
+ void Close() {}
+
+ void Exit(int code) {
+ for (;;) {
+ }
+ }
+
+ private:
+ Session* session_;
+ FrameBuffer* receive_buffer_;
+};
+
+namespace {
+// Stored as globals so that they can be used to report initialization errors.
+utvm_rpc_channel_write_t g_write_func = nullptr;
+void* g_write_func_ctx = nullptr;
+} // namespace
+
+class SerialWriteStream : public WriteStream {
+ public:
+ SerialWriteStream() {}
+ virtual ~SerialWriteStream() {}
+
+ ssize_t Write(const uint8_t* data, size_t data_size_bytes) override {
+ return g_write_func(g_write_func_ctx, data, data_size_bytes);
+ }
+
+ void PacketDone(bool is_valid) override {}
+
+ private:
+ void operator delete(void*) noexcept {} // NOLINT(readability/casting)
+};
+
+class MicroRPCServer {
+ public:
+ MicroRPCServer(uint8_t* receive_storage, size_t receive_storage_size_bytes,
+ utvm_rpc_channel_write_t write_func, void* write_func_ctx)
+ : receive_buffer_{receive_storage, receive_storage_size_bytes},
+ framer_{&send_stream_},
+ session_{0xa5, &framer_, &receive_buffer_, &HandleCompleteMessageCb, this},
+ io_{&session_, &receive_buffer_},
+ unframer_{session_.Receiver()},
+ rpc_server_{&io_},
+ has_pending_byte_{false},
+ is_running_{true} {}
+
+ void* operator new(size_t count, void* ptr) { return ptr; }
+
+ void Initialize() { CHECK_EQ(kTvmErrorNoError, session_.Initialize(), "rpc server init"); }
+
+ /*! \brief Process one message from the receive buffer, if possible.
+ *
+ * \return true if additional messages could be processed. false if the server shutdown request
+ * has been received.
+ */
+ bool Loop() {
+ if (has_pending_byte_) {
+ size_t bytes_consumed;
+ CHECK_EQ(unframer_.Write(&pending_byte_, 1, &bytes_consumed), kTvmErrorNoError,
+ "unframer_.Write");
+ CHECK_EQ(bytes_consumed, 1, "bytes_consumed");
+ has_pending_byte_ = false;
+ }
+
+ return is_running_;
+ }
+
+ void HandleReceivedByte(uint8_t byte) {
+ CHECK(!has_pending_byte_);
+ has_pending_byte_ = true;
+ pending_byte_ = byte;
+ }
+
+ void Log(const uint8_t* message, size_t message_size_bytes) {
+ tvm_crt_error_t to_return =
+ session_.SendMessage(MessageType::kLog, message, message_size_bytes);
+ if (to_return != 0) {
+ TVMPlatformAbort(to_return);
+ }
+ }
+
+ private:
+ FrameBuffer receive_buffer_;
+ SerialWriteStream send_stream_;
+ Framer framer_;
+ Session session_;
+ MicroIOHandler io_;
+ Unframer unframer_;
+ MinRPCServer<MicroIOHandler> rpc_server_;
+
+ bool has_pending_byte_;
+ uint8_t pending_byte_;
+ bool is_running_;
+
+ void HandleCompleteMessage(MessageType message_type, FrameBuffer* buf) {
+ if (message_type != MessageType::kNormal) {
+ return;
+ }
+
+ is_running_ = rpc_server_.ProcessOnePacket();
+ session_.ClearReceiveBuffer();
+ }
+
+ static void HandleCompleteMessageCb(void* context, MessageType message_type, FrameBuffer* buf) {
+ static_cast<MicroRPCServer*>(context)->HandleCompleteMessage(message_type, buf);
+ }
+};
+
+} // namespace micro_rpc
+} // namespace runtime
+} // namespace tvm
+
+void* operator new[](size_t count, void* ptr) noexcept { return ptr; }
+
+extern "C" {
+
+static utvm_rpc_server_t g_rpc_server = nullptr;
+
+utvm_rpc_server_t UTvmRpcServerInit(uint8_t* memory, size_t memory_size_bytes,
+ size_t page_size_bytes_log2,
+ utvm_rpc_channel_write_t write_func, void* write_func_ctx) {
+ tvm::runtime::micro_rpc::g_write_func = write_func;
+ tvm::runtime::micro_rpc::g_write_func_ctx = write_func_ctx;
+
+ tvm_crt_error_t err = TVMInitializeRuntime(memory, memory_size_bytes, page_size_bytes_log2);
+ if (err != kTvmErrorNoError) {
+ TVMPlatformAbort(err);
+ }
+
+ auto receive_buffer =
+ new (vmalloc(TVM_CRT_MAX_PACKET_SIZE_BYTES)) uint8_t[TVM_CRT_MAX_PACKET_SIZE_BYTES];
+ auto rpc_server = new (vmalloc(sizeof(tvm::runtime::micro_rpc::MicroRPCServer)))
+ tvm::runtime::micro_rpc::MicroRPCServer(receive_buffer, TVM_CRT_MAX_PACKET_SIZE_BYTES,
+ write_func, write_func_ctx);
+ g_rpc_server = static_cast<utvm_rpc_server_t>(rpc_server);
+ rpc_server->Initialize();
+ return g_rpc_server;
+}
+
+void TVMLogf(const char* format, ...) {
+ va_list args;
+ char log_buffer[256];
+ va_start(args, format);
+ size_t num_bytes_logged = vsnprintf(log_buffer, sizeof(log_buffer), format, args);
+ va_end(args);
+
+ // Most header-based logging frameworks tend to insert '\n' at the end of the log message.
+ // Remove that for remote logging, since the remote logger will do the same.
+ if (num_bytes_logged > 0 && log_buffer[num_bytes_logged - 1] == '\n') {
+ log_buffer[num_bytes_logged - 1] = 0;
+ num_bytes_logged--;
+ }
+
+ if (g_rpc_server != nullptr) {
+ static_cast<tvm::runtime::micro_rpc::MicroRPCServer*>(g_rpc_server)
+ ->Log(reinterpret_cast<uint8_t*>(log_buffer), num_bytes_logged);
+ } else {
+ tvm::runtime::micro_rpc::SerialWriteStream write_stream;
+ tvm::runtime::micro_rpc::Framer framer{&write_stream};
+ tvm::runtime::micro_rpc::Session session{0xa5, &framer, nullptr, nullptr, nullptr};
+ tvm_crt_error_t err =
+ session.SendMessage(tvm::runtime::micro_rpc::MessageType::kLog,
+ reinterpret_cast<uint8_t*>(log_buffer), num_bytes_logged);
+ if (err != kTvmErrorNoError) {
+ TVMPlatformAbort(err);
+ }
+ }
+}
+
+size_t UTvmRpcServerReceiveByte(utvm_rpc_server_t server_ptr, uint8_t byte) {
+ // NOTE(areusch): In the future, this function is intended to work from an IRQ context. That's not
+ // needed at present.
+ tvm::runtime::micro_rpc::MicroRPCServer* server =
+ static_cast<tvm::runtime::micro_rpc::MicroRPCServer*>(server_ptr);
+ server->HandleReceivedByte(byte);
+ return 1;
+}
+
+bool UTvmRpcServerLoop(utvm_rpc_server_t server_ptr) {
+ tvm::runtime::micro_rpc::MicroRPCServer* server =
+ static_cast<tvm::runtime::micro_rpc::MicroRPCServer*>(server_ptr);
+ return server->Loop();
+}
+
+} // extern "C"
+++ /dev/null
-/*\r
- * Licensed to the Apache Software Foundation (ASF) under one\r
- * or more contributor license agreements. See the NOTICE file\r
- * distributed with this work for additional information\r
- * regarding copyright ownership. The ASF licenses this file\r
- * to you under the Apache License, Version 2.0 (the\r
- * "License"); you may not use this file except in compliance\r
- * with the License. You may obtain a copy of the License at\r
- *\r
- * http://www.apache.org/licenses/LICENSE-2.0\r
- *\r
- * Unless required by applicable law or agreed to in writing,\r
- * software distributed under the License is distributed on an\r
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\r
- * KIND, either express or implied. See the License for the\r
- * specific language governing permissions and limitations\r
- * under the License.\r
- */\r
-\r
-.syntax unified\r
-.cpu cortex-m7\r
-.fpu softvfp\r
-.thumb\r
-\r
-.section .text.UTVMInit\r
-.type UTVMInit, %function\r
-UTVMInit:\r
- /* enable fpu */\r
- ldr r0, =0xE000ED88\r
- ldr r1, [r0]\r
- ldr r2, =0xF00000\r
- orr r1, r2\r
- str r1, [r0]\r
- dsb\r
- isb\r
- /* set stack pointer */\r
- ldr sp, =_utvm_stack_pointer_init\r
- bl UTVMMain\r
-.size UTVMInit, .-UTVMInit\r
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file utvm_timer.c
- * \brief uTVM timer API definitions for STM32F746XX-series boards
- */
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#include <stdint.h>
-
-#include "utvm_runtime.h"
-// NOTE: This expects ST CMSIS to be in your include path.
-// Download STM32CubeF7 here:
-// https://www.st.com/content/st_com/en/products/embedded-software/mcu-mpu-embedded-software/stm32-embedded-software/stm32cube-mcu-mpu-packages/stm32cubef7.html
-// and add Drivers/CMSIS to your C include path.
-#include "Device/ST/STM32F7xx/Include/stm32f746xx.h"
-
-#define utvm_SystemCoreClock 216000000UL
-
-int32_t UTVMTimerStart() {
- UTVMTimerReset();
- TIM2->CR1 = TIM_CR1_CEN; // Start counter
- return UTVM_ERR_OK;
-}
-
-uint32_t UTVMTimerStop(int32_t* err) {
- TIM2->CR1 &= TIM_CR1_CEN;
- if (TIM2->SR & TIM_SR_UIF_Msk) {
- *err = UTVM_ERR_TIMER_OVERFLOW;
- return 0;
- }
- *err = UTVM_ERR_OK;
- uint32_t tim_cnt = TIM2->CNT;
- uint32_t millis = tim_cnt / (utvm_SystemCoreClock / 1000);
- uint32_t micros =
- (tim_cnt - (millis * (utvm_SystemCoreClock / 1000))) / (utvm_SystemCoreClock / 1000000);
- return millis * 1000 + micros;
-}
-
-void UTVMTimerReset() {
- RCC->APB1RSTR |= RCC_APB1RSTR_TIM2RST; // Hold TIM2 in reset
- RCC->DCKCFGR1 = (RCC->DCKCFGR1 & ~RCC_DCKCFGR1_TIMPRE_Msk); // disable 2x clock boost to TIM2
- RCC->CFGR = (RCC->CFGR & ~RCC_CFGR_PPRE1_Msk); // No AHB clock division to APB1 (1:1).
- RCC->APB1ENR |= RCC_APB1ENR_TIM2EN; // Enable TIM2 clock.
- RCC->APB1RSTR &= ~RCC_APB1RSTR_TIM2RST; // Exit TIM2 reset.
-
- DBGMCU->APB1FZ |= DBGMCU_APB1_FZ_DBG_TIM2_STOP; // stop TIM2 clock during debug halt.
- TIM2->ARR = 0xffffffff;
- if (TIM2->SR & TIM_SR_UIF_Msk) {
- for (;;) {
- }
- }
-}
-
-#ifdef __cplusplus
-} // TVM_EXTERN_C
-#endif
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file utvm_timer.c
- * \brief uTVM timer API stubs for the host emulated device
- */
-
-#include <stdint.h>
-
-#include "utvm_runtime.h"
-
-// TODO(weberlo): use this? https://stackoverflow.com/questions/5141960/get-the-current-time-in-c
-
-int32_t UTVMTimerStart() { return UTVM_ERR_OK; }
-
-uint32_t UTVMTimerStop(int32_t* err) {
- *err = UTVM_ERR_OK;
- return 0;
-}
+++ /dev/null
-/*\r
- * Licensed to the Apache Software Foundation (ASF) under one\r
- * or more contributor license agreements. See the NOTICE file\r
- * distributed with this work for additional information\r
- * regarding copyright ownership. The ASF licenses this file\r
- * to you under the Apache License, Version 2.0 (the\r
- * "License"); you may not use this file except in compliance\r
- * with the License. You may obtain a copy of the License at\r
- *\r
- * http://www.apache.org/licenses/LICENSE-2.0\r
- *\r
- * Unless required by applicable law or agreed to in writing,\r
- * software distributed under the License is distributed on an\r
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\r
- * KIND, either express or implied. See the License for the\r
- * specific language governing permissions and limitations\r
- * under the License.\r
- */\r
-\r
-UTVMInit:\r
- /* set stack pointer */\r
- la sp, _utvm_stack_pointer_init\r
- call UTVMMain\r
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file utvm_device_dylib_redirect.cc
- * \brief uTVM dynamic linking stubs
- *
- * This is a library that gets included in each uTVM library. We redirect
- * each library call into a pre-defined global function pointer, and we patch
- * the correct addresses of each function into the pointers when we load the
- * library.
- */
-#ifdef __cplusplus
-extern "C" {
-#endif
-#include <stddef.h>
-#include <stdint.h>
-
-// TODO(weberlo, areusch): compiler errors say volatile qualifier is discarded.
-// should we just get rid of em?
-void* (*volatile TVMBackendAllocWorkspace_)(int, int, uint64_t, int, int) = NULL;
-int (*volatile TVMBackendFreeWorkspace_)(int, int, void*) = NULL;
-void (*volatile TVMAPISetLastError_)(const char*) = NULL;
-
-void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint,
- int dtype_bits_hint) {
- return (*TVMBackendAllocWorkspace_)(device_type, device_id, size, dtype_code_hint,
- dtype_bits_hint);
-}
-
-int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
- return (*TVMBackendFreeWorkspace_)(device_type, device_id, ptr);
-}
-
-void TVMAPISetLastError(const char* msg) { (*TVMAPISetLastError_)(msg); }
-
-void* memset(void* s, int c, size_t n) {
- char* p = (char*)s; // NOLINT(readability/casting): linter is configured for c++
- while (n > 0) {
- *p = (char)c; // NOLINT(readability/casting): linter is configured for c++
- p++;
- n--;
- }
- return s;
-}
-
-void* memmove(void* to, const void* from, size_t n) {
- // TODO(weberlo, areusch): will need to factor memmove calls into workspace size calculation
- // NOLINTNEXTLINE(readability/casting): linter is configured for c++
- char* temp = (char*)TVMBackendAllocWorkspace(1, 1, (uint64_t)n, 2, 8);
- if (temp == NULL) {
- return NULL;
- }
-
- const char* from_pp = (char*)from; // NOLINT(readability/casting): linter is configured for c++
- for (size_t i = 0; i < n; i++) {
- temp[i] = from_pp[i];
- }
- char* to_pp = (char*)to; // NOLINT(readability/casting): linter is configured for c++
- for (size_t i = 0; i < n; i++) {
- to_pp[i] = temp[i];
- }
-
- // NOLINTNEXTLINE(readability/casting): linter is configured for c++
- if (TVMBackendFreeWorkspace(1, (uint64_t)1, (void*)temp) != 0) {
- return NULL;
- }
-
- return to;
-}
-
-#ifdef __cplusplus
-} // TVM_EXTERN_C
-#endif
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file utvm_runtime.cc
- * \brief uTVM runtime
- *
- * All function calls go through the externally defined `UTVMInit`, which
- * performs device-specific setup, then calls `UTVMMain`. `UTVMMain` then
- * calls the function in `utvm_task` with the arguments from the task.
- *
- * Additionally included in this file are definitions for some of the most
- * common functions used in the C runtime API.
- */
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#include "utvm_runtime.h"
-
-// TODO(weberlo, areusch): move defines into header
-// TODO(weberlo, areusch): unify TASK_QUEUE_SIZE and MicroSession::kTaskQueueCapacity.
-#define TASK_QUEUE_SIZE 20
-volatile UTVMTask utvm_tasks[TASK_QUEUE_SIZE] = {};
-volatile uint32_t utvm_num_tasks = 0;
-volatile uint32_t utvm_task_times[TASK_QUEUE_SIZE] = {};
-
-// These pointers are patched at load time to point to the workspace section.
-volatile char* utvm_workspace_start = NULL; // NOLINT(*)
-volatile char* utvm_workspace_end = NULL; // NOLINT(*)
-volatile char* utvm_workspace_curr = NULL; // NOLINT(*)
-#define MAX_WS_ALLOCS 10
-volatile char* utvm_alloc_ends[MAX_WS_ALLOCS] = {}; // NOLINT(*)
-volatile uint32_t utvm_alloc_idx = 0;
-// Keep track of how many active allocations there are on the workspace.
-volatile uint32_t utvm_num_active_allocs = 0;
-
-volatile uint32_t utvm_word_size = 0;
-
-volatile int32_t utvm_last_error = 0; // NOLINT(*)
-
-volatile uint32_t utvm_done = 0;
-
-// Gets called by UTVMInit, after device-specific initialization is finished.
-void UTVMMain() {
- utvm_done = 0;
- // loss of precision should be fine here, since we only care about the lower bits
- if (((uint32_t)utvm_workspace_start) % utvm_word_size) {
- utvm_last_error = UTVM_ERR_WS_UNALIGNED_START;
- UTVMDone();
- return;
- }
- utvm_workspace_curr = utvm_workspace_start;
- utvm_num_active_allocs = 0;
- utvm_alloc_idx = 0;
- utvm_last_error = UTVM_ERR_NOT_FINISHED;
- for (uint32_t i = 0; i < utvm_num_tasks; i++) {
- int32_t err = UTVM_ERR_OK;
- utvm_task_times[i] = 0;
- err = UTVMTimerStart();
- if (err < 0) {
- utvm_last_error = err;
- UTVMDone();
- return;
- }
- err = utvm_tasks[i].func((void*)utvm_tasks[i].arg_values, // NOLINT(*)
- (void*)utvm_tasks[i].arg_type_codes, // NOLINT(*)
- utvm_tasks[i].num_args);
- if (err < 0) {
- UTVMDone();
- return;
- }
- utvm_task_times[i] = UTVMTimerStop(&err);
- if (err < 0) {
- utvm_last_error = err;
- UTVMDone();
- return;
- }
- }
- if (utvm_last_error == UTVM_ERR_NOT_FINISHED) {
- utvm_last_error = UTVM_ERR_OK;
- }
- UTVMDone();
-}
-
-// We use a dummy function to signal execution is finished for device
-// backends which require breakpoints.
-void __attribute__((noinline)) UTVMDone() {
- utvm_done = 1;
-#ifndef UTVM_TARGET_HOST
- for (;;) {
- }
-#endif
-}
-
-#define ALIGNED_UP(x, word_size) \
- ((((word_size) - (((uintptr_t)(x)) % (word_size))) % (word_size)) + (x))
-
-void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint,
- int dtype_bits_hint) {
- if (size == 0) {
- utvm_last_error = UTVM_ERR_WS_ZERO_SIZE_ALLOC;
- return NULL;
- }
- size_t alloc_requested_bytes = size;
- size_t alloc_size_words = (alloc_requested_bytes + utvm_word_size - 1) / utvm_word_size;
- size_t alloc_size_bytes = alloc_size_words * utvm_word_size;
-
- // Align up to the target word size.
- if (utvm_workspace_curr + alloc_size_bytes > utvm_workspace_end) {
- // Out of space in workspace.
- utvm_last_error = UTVM_ERR_WS_OUT_OF_SPACE;
- return NULL;
- }
- if (utvm_alloc_idx == MAX_WS_ALLOCS - 1) {
- // Exceeded number of allocs we can keep track of.
- utvm_last_error = UTVM_ERR_WS_TOO_MANY_ALLOCS;
- return NULL;
- }
- void* ret_ptr = (void*)utvm_workspace_curr; // NOLINT(*)
- utvm_workspace_curr = utvm_workspace_curr + alloc_size_bytes;
- // store the *end* of the alloc, so we can restore the WS pointer when freeing
- utvm_alloc_ends[utvm_alloc_idx] = utvm_workspace_curr;
- utvm_alloc_idx++;
- utvm_num_active_allocs++;
- return ret_ptr;
-}
-
-int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
- // TODO(weberlo, areusch): add dev type check
- if (utvm_num_active_allocs == 0) {
- TVMAPISetLastError("free called with no active workspace allocations");
- // Reset allocations and workspace (for future task executions).
- utvm_num_active_allocs = 0;
- utvm_workspace_curr = utvm_workspace_start;
- utvm_last_error = UTVM_ERR_WS_DOUBLE_FREE;
- return -1;
- } else {
- utvm_num_active_allocs--;
- if (ptr == utvm_workspace_start) {
- // it's the first allocation
- utvm_alloc_ends[0] = NULL;
- } else {
- for (uint32_t i = utvm_alloc_idx - 1; i >= 0; i--) {
- if (utvm_alloc_ends[i] == ptr) {
- utvm_alloc_ends[i + 1] = NULL;
- break;
- }
- }
- }
- while (utvm_alloc_idx > 0 && utvm_alloc_ends[utvm_alloc_idx - 1] == NULL) {
- utvm_alloc_idx--;
- }
- if (utvm_alloc_idx == 0) {
- utvm_workspace_curr = utvm_workspace_start;
- } else {
- // TODO(weberlo, areusch): could you possibly have utvm_alloc_idx pointing to a NULL entry in
- // this branch?
- utvm_workspace_curr = utvm_alloc_ends[utvm_alloc_idx - 1];
- }
- return 0;
- }
-}
-
-void TVMAPISetLastError(const char* msg) {}
-
-#ifdef __cplusplus
-} // TVM_EXTERN_C
-#endif
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file utvm_runtime.h
- * \brief uTVM runtime headers
- */
-#ifndef TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_H_
-#define TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_H_
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#include <stdint.h>
-#include <tvm/runtime/c_backend_api.h>
-#include <tvm/runtime/c_runtime_api.h>
-
-#include "utvm_runtime_enum.h"
-
-/*!
- * \brief Task structure for uTVM
- */
-typedef struct {
- /*! \brief Pointer to function to call for this task */
- int32_t (*func)(void*, void*, int32_t);
- /*! \brief Array of argument values */
- TVMValue* arg_values;
- /*! \brief Array of type codes for each argument value */
- int* arg_type_codes;
- /*! \brief Number of arguments */
- int32_t num_args;
-} UTVMTask;
-
-/*!
- * \brief microTVM processor startup.
- * Expected to reset the stack pointer, configure any hardware required to support the CRT
- * (i.e. FPU), and then jump to UTVMMain.
- */
-extern void UTVMInit();
-
-/*!
- * \brief Start the on-device timer.
- * \return UTVMReturnCode indicating the outcome of the operation.
- */
-extern int32_t UTVMTimerStart();
-
-/*!
- * \brief Stop the on-device timer.
- * TODO(areusch): Use an SI specification of timer units here.
- * \param err Receives a UTVMReturnCode indicating the outcome of the operation.
- * \return elapsed time since UTVMTimerStart returned, in device timer ticks.
- */
-extern uint32_t UTVMTimerStop(int32_t* err);
-
-/*!
- * \brief Main entry point for UTVM runtime.
- * Waits for "go" signal, then executes tasks and reports result. Should never return.
- */
-void UTVMMain();
-
-/*!
- * \brief Function entered when UTVMMain is complete.
- * Should never return. The host sets a breakpoint here to detect end of computation.
- */
-void UTVMDone();
-
-// GCC -O3 begins to inject memset and memmove calls, so we provide impls in
-// the runtime for this case and for general usage.
-
-void* memset(void* s, int c, size_t n);
-
-void* memmove(void* to, const void* from, size_t n);
-
-#ifdef __cplusplus
-} // TVM_EXTERN_C
-#endif
-
-#endif // TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file utvm_runtime_enum.h
- * \brief Defines constants used both on the host and on device.
- */
-#ifndef TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_ENUM_H_
-#define TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_ENUM_H_
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/*!
- * \brief TODO
- */
-enum UTVMReturnCode {
- UTVM_ERR_OK = 0,
- UTVM_ERR_NOT_FINISHED = -1,
- UTVM_ERR_TIMER_NOT_IMPLEMENTED = -2,
- UTVM_ERR_TIMER_OVERFLOW = -3,
- UTVM_ERR_WS_DOUBLE_FREE = -4,
- UTVM_ERR_WS_OUT_OF_SPACE = -5,
- UTVM_ERR_WS_TOO_MANY_ALLOCS = -6,
- UTVM_ERR_WS_ZERO_SIZE_ALLOC = -7,
- UTVM_ERR_WS_UNALIGNED_START = -8,
- UTVM_ERR_WS_UNALIGNED_ALLOC_SIZE = -9,
-};
-
-#ifdef __cplusplus
-} // TVM_EXTERN_C
-#endif
-
-#endif // TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_ENUM_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file host_low_level_device.cc
- * \brief emulated low-level micro device implementation on host machine
- */
-
-#include <sys/mman.h>
-
-#include <cstring>
-#include <memory>
-
-#include "low_level_device.h"
-#include "micro_common.h"
-
-namespace tvm {
-namespace runtime {
-
-/*! \brief number of bytes in each page */
-constexpr int kPageSize = 4096;
-
-/*!
- * \brief emulated low-level device on host machine
- */
-class HostLowLevelDevice final : public LowLevelDevice {
- public:
- /*!
- * \brief constructor to initialize on-host memory region to act as device
- * \param num_bytes size of the emulated on-device memory region
- */
- explicit HostLowLevelDevice(size_t num_bytes, TargetPtr* base_addr) : size_(num_bytes) {
- size_t size_in_pages = (num_bytes + kPageSize - 1) / kPageSize;
- // TODO(weberlo): Set permissions per section (e.g., read-write perms for
- // the heap, execute perms for text, etc.).
- int mmap_prot = PROT_READ | PROT_WRITE | PROT_EXEC;
- int mmap_flags = MAP_ANONYMOUS | MAP_PRIVATE;
- base_addr_ = mmap(nullptr, size_in_pages * kPageSize, mmap_prot, mmap_flags, -1, 0);
- *base_addr =
- TargetPtr(TargetWordSize(sizeof(size_t) * 8), reinterpret_cast<uint64_t>(base_addr_));
- }
-
- /*!
- * \brief destructor to deallocate on-host device region
- */
- virtual ~HostLowLevelDevice() { munmap(base_addr_, size_); }
-
- void Read(TargetPtr addr, void* buf, size_t num_bytes) {
- std::memcpy(buf, addr.cast_to<void*>(), num_bytes);
- }
-
- void Write(TargetPtr addr, const void* buf, size_t num_bytes) {
- std::memcpy(addr.cast_to<void*>(), buf, num_bytes);
- }
-
- void Execute(TargetPtr func_addr, TargetPtr breakpoint_addr) {
- reinterpret_cast<void (*)(void)>(func_addr.value().uint64())();
- }
-
- const char* device_type() const final { return "host"; }
-
- private:
- /*! \brief base address of the micro device memory region */
- void* base_addr_;
- /*! \brief size of memory region */
- size_t size_;
-};
-
-const std::shared_ptr<LowLevelDevice> HostLowLevelDeviceCreate(size_t num_bytes,
- TargetPtr* base_addr) {
- std::shared_ptr<LowLevelDevice> lld = std::make_shared<HostLowLevelDevice>(num_bytes, base_addr);
- return lld;
-}
-
-} // namespace runtime
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file low_level_device.h
- * \brief Abstract low-level micro device management
- */
-#ifndef TVM_RUNTIME_MICRO_LOW_LEVEL_DEVICE_H_
-#define TVM_RUNTIME_MICRO_LOW_LEVEL_DEVICE_H_
-
-#include <memory>
-#include <string>
-
-#include "micro_common.h"
-
-namespace tvm {
-namespace runtime {
-/*!
- * \brief virtual interface for low-level micro device management
- */
-class LowLevelDevice {
- public:
- /*! \brief virtual destructor */
- virtual ~LowLevelDevice() {}
-
- /*!
- * \brief reads num_bytes from device memory at addr into buffer
- * \param addr on-device memory address to read from
- * \param buffer on-host buffer to be read into
- * \param num_bytes number of bytes to read
- */
- virtual void Read(TargetPtr addr, void* buffer, size_t num_bytes) = 0;
-
- /*!
- * \brief writes num_bytes from buffer to device memory at addr
- * \param addr on-device memory address to write into
- * \param buffer host buffer to write from
- * \param num_bytes number of bytes to write
- */
- virtual void Write(TargetPtr addr, const void* buffer, size_t num_bytes) = 0;
-
- /*!
- * \brief starts execution of device at func_addr
- * \param func_addr offset of the init stub function
- * \param breakpoint_addr address at which to stop function execution
- */
- virtual void Execute(TargetPtr func_addr, TargetPtr breakpoint_addr) = 0;
-
- /*!
- * \brief getter function for low-level device type
- * \return string containing device type
- */
- virtual const char* device_type() const = 0;
-};
-
-/*!
- * \brief create a host low-level device
- * \param num_bytes size of the memory region
- * \param base_addr pointer to write the host device's resulting base address into
- */
-const std::shared_ptr<LowLevelDevice> HostLowLevelDeviceCreate(size_t num_bytes,
- TargetPtr* base_addr);
-
-/*!
- * \brief connect to OpenOCD and create an OpenOCD low-level device
- * \param addr address of the OpenOCD server to connect to
- * \param port port of the OpenOCD server to connect to
- */
-const std::shared_ptr<LowLevelDevice> OpenOCDLowLevelDeviceCreate(const std::string& addr,
- int port);
-
-} // namespace runtime
-} // namespace tvm
-#endif // TVM_RUNTIME_MICRO_LOW_LEVEL_DEVICE_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file micro_common.cc
- * \brief common utilties for uTVM
- */
-
-#include "micro_common.h"
-
-#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/registry.h>
-
-#include <cstdint>
-#include <cstdio>
-#include <sstream>
-#include <string>
-
-#include "low_level_device.h"
-#include "micro_session.h"
-
-namespace tvm {
-namespace runtime {
-
-const char* SectionToString(SectionKind section) {
- switch (section) {
- case SectionKind::kText:
- return "text";
- case SectionKind::kRodata:
- return "rodata";
- case SectionKind::kData:
- return "data";
- case SectionKind::kBss:
- return "bss";
- case SectionKind::kArgs:
- return "args";
- case SectionKind::kHeap:
- return "heap";
- case SectionKind::kWorkspace:
- return "workspace";
- case SectionKind::kStack:
- return "stack";
- default:
- return "";
- }
-}
-
-std::string RelocateBinarySections(const std::string& binary_path, TargetWordSize word_size,
- TargetPtr text_start, TargetPtr rodata_start,
- TargetPtr data_start, TargetPtr bss_start, TargetPtr stack_end,
- const std::string& toolchain_prefix) {
- const auto* f = Registry::Get("tvm_callback_relocate_binary");
- CHECK(f != nullptr) << "Require tvm_callback_relocate_binary to exist in registry";
- std::string relocated_bin =
- (*f)(binary_path, word_size.bytes(), text_start.cast_to<uint64_t>(),
- rodata_start.cast_to<uint64_t>(), data_start.cast_to<uint64_t>(),
- bss_start.cast_to<uint64_t>(), stack_end.cast_to<uint64_t>(), toolchain_prefix);
- return relocated_bin;
-}
-
-std::string ReadSection(const std::string& binary, SectionKind section,
- const std::string& toolchain_prefix) {
- CHECK(section == SectionKind::kText || section == SectionKind::kRodata ||
- section == SectionKind::kData || section == SectionKind::kBss)
- << "ReadSection requires section to be one of text, rodata, data, or bss.";
- const auto* f = Registry::Get("tvm_callback_read_binary_section");
- CHECK(f != nullptr) << "Require tvm_callback_read_binary_section to exist in registry";
- TVMByteArray arr;
- arr.data = &binary[0];
- arr.size = binary.length();
- std::string section_contents = (*f)(arr, SectionToString(section), toolchain_prefix);
- return section_contents;
-}
-
-size_t GetSectionSize(const std::string& binary_path, SectionKind section,
- const std::string& toolchain_prefix, TargetWordSize word_size) {
- CHECK(section == SectionKind::kText || section == SectionKind::kRodata ||
- section == SectionKind::kData || section == SectionKind::kBss)
- << "GetSectionSize requires section to be one of text, rodata, data, or bss.";
- const auto* f = Registry::Get("tvm_callback_get_section_size");
- CHECK(f != nullptr) << "Require tvm_callback_get_section_size to exist in registry";
- int size = (*f)(binary_path, SectionToString(section), toolchain_prefix);
- return UpperAlignValue(size, word_size.bytes());
-}
-
-std::ostream& operator<<(std::ostream& os, const TargetVal& v) {
- std::ios_base::fmtflags f(os.flags());
- os << std::dec << "0x";
- switch (v.width_bits()) {
- case 8:
- os << uint8_t(v.uint32());
- break;
- case 16:
- os << uint16_t(v.uint32());
- break;
- case 32:
- os << v.uint32();
- break;
- case 64:
- os << v.uint64();
- break;
- default:
- os << (v.uint64() & ((1 << v.width_bits()) - 1));
- }
- os.flags(f);
- return os;
-}
-
-std::ostream& operator<<(std::ostream& os, const TargetPtr& v) {
- os << "*" << v.value_;
- return os;
-}
-
-} // namespace runtime
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file micro_common.h
- */
-#ifndef TVM_RUNTIME_MICRO_MICRO_COMMON_H_
-#define TVM_RUNTIME_MICRO_MICRO_COMMON_H_
-
-#include <stdio.h>
-#include <tvm/runtime/registry.h>
-
-#include <sstream>
-#include <string>
-#include <unordered_map>
-#include <utility>
-
-namespace tvm {
-namespace runtime {
-
-/*!
- * \brief enum of device memory region sections
- *
- * The order in which the enum variants are defined also defines the order of
- * the sections in device memory.
- */
-enum class SectionKind : size_t {
- kText = 0,
- kRodata,
- kData,
- kBss,
- kArgs,
- kHeap,
- kWorkspace,
- kStack,
- kNumKinds,
-};
-
-/*! \brief data type for word sizes */
-class TargetWordSize {
- public:
- explicit TargetWordSize(size_t word_size_bits) : word_size_bits_{word_size_bits} {
- CHECK(word_size_bits == 32 || word_size_bits == 64)
- << "only 32-bit and 64-bit are supported now";
- }
-
- size_t bytes() const { return word_size_bits_ / 8; }
-
- size_t bits() const { return word_size_bits_; }
-
- private:
- size_t word_size_bits_;
-};
-
-/*! \brief class for storing values on varying target word sizes */
-class TargetVal {
- private:
- size_t width_bits_;
- uint64_t value_;
-
- public:
- /*! \brief construct a TargetVal matching the size of the given integral argument */
- template <typename T, typename U = typename std::enable_if<std::is_integral<T>::value, T>::type>
- explicit constexpr TargetVal(T value) : TargetVal(sizeof(T) * 8, value) {}
-
- /*! \brief construct an uninitialized value */
- TargetVal() : width_bits_{0}, value_{0} {}
-
- /*! \brief construct a TargetVal with explicit size and value */
- TargetVal(size_t width_bits, uint64_t value) : width_bits_{width_bits} {
- CHECK(width_bits >= 8 && width_bits <= 64 && (width_bits & (width_bits - 1)) == 0)
- << "width_bits must be a power of 2 in [8, 64], got " << width_bits;
- value_ = value & Bitmask();
- }
-
- bool IsInitialized() const { return width_bits_ != 0; }
-
- size_t width_bits() const {
- CHECK(IsInitialized()) << "TargetVal is not initialized";
- return width_bits_;
- }
-
- uint64_t Bitmask() const {
- CHECK(IsInitialized()) << "TargetVal is not initialized";
-
- if (width_bits_ == 64) {
- return ~0UL;
- } else {
- return (1UL << width_bits_) - 1;
- }
- }
-
- uint32_t uint32() const {
- CHECK(IsInitialized()) << "TargetVal is not initialized";
- CHECK(width_bits_ <= 32) << "TargetVal: requested 32-bit value, actual width is "
- << width_bits_;
- return uint32_t(value_ & Bitmask());
- }
-
- uint64_t uint64() const {
- CHECK(IsInitialized()) << "TargetVal is not initialized";
- return value_;
- }
-
- TargetVal& operator=(const TargetVal& other) {
- CHECK(other.IsInitialized()) << "Cannot assign an uninitialized TargetVal";
-
- if (!IsInitialized()) {
- width_bits_ = other.width_bits_;
- }
-
- CHECK(width_bits_ >= other.width_bits_)
- << "Cannot assign TargetVal with width " << other.width_bits_
- << "bits to TargetVal with width " << width_bits_ << "bits";
-
- value_ = other.value_ & Bitmask();
- return *this;
- }
-
- private:
- friend std::ostream& operator<<(std::ostream& os, const TargetVal& v);
-};
-
-// TODO(weberlo, areusch): just get rid of `TargetPtr`.
-/*! \brief absolute device address */
-class TargetPtr {
- public:
- /*! \brief construct a device address with variable-length value `value` */
- TargetPtr(TargetWordSize word_size, std::uint64_t value)
- : value_(TargetVal(word_size.bits(), value)) {}
-
- /*! \brief construct a null address */
- TargetPtr(TargetWordSize word_size, std::nullptr_t value)
- : value_{TargetVal(word_size.bits(), 0)} {}
-
- /*! \brief construct an uninitialized pointer whose word_size can be changed once */
- TargetPtr() = default;
-
- /*! \brief construct a device address using the given TargetVal */
- explicit TargetPtr(const TargetVal& value) : value_{value} {}
-
- /*! \brief destructor */
- ~TargetPtr() {}
-
- /*!
- * \brief get value of pointer
- * \return value of pointer
- */
- TargetVal value() const { return value_; }
-
- /*!
- * \brief cast location to type `T`
- * \return casted result
- */
- template <typename T>
- T cast_to() const {
- return reinterpret_cast<T>(value_.uint64());
- }
-
- /*! \brief check if location is null */
- bool operator==(std::nullptr_t) const { return value_.uint64() == 0; }
-
- /*! \brief check if location is not null */
- bool operator!=(std::nullptr_t) const { return value_.uint64() != 0; }
-
- /*! \brief add an integer to this absolute address to get a larger absolute address */
- TargetPtr operator+(size_t n) const {
- return TargetPtr(TargetWordSize(value_.width_bits()), value_.uint64() + n);
- }
-
- /*! \brief mutably add an integer to this absolute address */
- TargetPtr& operator+=(size_t n) {
- value_ = TargetVal(value_.width_bits(), value_.uint64() + n);
- return *this;
- }
-
- /*! \brief subtract an integer from this absolute address to get a smaller absolute address */
- TargetPtr operator-(size_t n) const {
- return TargetPtr(TargetWordSize(value_.width_bits()), value_.uint64() - n);
- }
-
- /*! \brief mutably subtract an integer from this absolute address */
- TargetPtr& operator-=(size_t n) {
- value_ = TargetVal(value_.width_bits(), value_.uint64() - n);
- return *this;
- }
-
- private:
- /*! \brief raw value storing the pointer */
- TargetVal value_;
-
- friend std::ostream& operator<<(std::ostream& os, const TargetPtr& v);
-};
-
-/*!
- * \brief map from symbols to their on-device offsets
- */
-class SymbolMap {
- public:
- /*!
- * \brief default constructor
- */
- SymbolMap() {}
-
- /*!
- * \brief constructor that builds the mapping
- * \param binary contents of binary object file
- * \param toolchain_prefix prefix of compiler toolchain to use
- */
- SymbolMap(const std::string& binary, const std::string& toolchain_prefix,
- TargetWordSize word_size) {
- const auto* f = Registry::Get("tvm_callback_get_symbol_map");
- CHECK(f != nullptr) << "require tvm_callback_get_symbol_map to exist in registry";
- TVMByteArray arr;
- arr.data = &binary[0];
- arr.size = binary.length();
- std::string map_str = (*f)(arr, toolchain_prefix);
- // Parse symbols and addresses from returned string.
- std::stringstream stream;
- stream << map_str;
- std::string name;
- std::uintptr_t addr;
- stream >> name;
- stream >> std::hex >> addr;
- while (stream) {
- map_.emplace(std::make_pair(name, TargetPtr(word_size, addr)));
- stream >> name;
- stream >> std::hex >> addr;
- }
- }
-
- /*!
- * \brief retrieve on-device offset for a symbol name
- * \param name name of the symbol
- * \return on-device offset of the symbol
- */
- TargetPtr operator[](const std::string& name) const {
- auto result = map_.find(name);
- CHECK(result != map_.end()) << "\"" << name << "\" not in symbol map";
- return result->second;
- }
-
- bool HasSymbol(const std::string& name) const { return map_.find(name) != map_.end(); }
-
- void Dump(std::ostream& stream) const {
- for (auto e : map_) {
- stream << "Entry:" << e.first << std::endl;
- }
- }
-
- private:
- /*! \brief backing map */
- std::unordered_map<std::string, TargetPtr> map_;
-};
-
-/*! \brief struct containing start and size of a device memory region */
-struct DevMemRegion {
- /*! \brief section start offset */
- TargetPtr start;
- /*! \brief size of section */
- size_t size;
-};
-
-/*! \brief struct containing section locations and symbol mappings */
-struct BinaryInfo {
- /*! \brief text section region */
- DevMemRegion text_section;
- /*! \brief rodata section region */
- DevMemRegion rodata_section;
- /*! \brief data section region */
- DevMemRegion data_section;
- /*! \brief bss section region */
- DevMemRegion bss_section;
- /*! \brief symbol map to offsets */
- SymbolMap symbol_map;
-};
-
-struct BinaryContents {
- BinaryInfo binary_info;
- std::string text_contents;
- std::string rodata_contents;
- std::string data_contents;
- std::string bss_contents;
-};
-
-/*!
- * \brief upper-aligns value according to specified alignment
- * \param value value to be aligned
- * \param align alignment
- * \return upper-aligned value
- */
-inline size_t UpperAlignValue(size_t value, size_t align) {
- return value + (align - (value % align)) % align;
-}
-
-/*!
- * \brief maps section enums to text
- * \param section section type
- * \return text form of the specified section
- */
-const char* SectionToString(SectionKind section);
-
-/*!
- * \brief links binary by repositioning section addresses
- * \param binary_name input binary filename
- * \param word_size word size on the target machine
- * \param text_start text section address
- * \param rodata_start rodata section address
- * \param data_start data section address
- * \param bss_start bss section address
- * \param stack_end stack section end address
- * \param toolchain_prefix prefix of compiler toolchain to use
- * \return relocated binary file contents
- */
-std::string RelocateBinarySections(const std::string& binary_path, TargetWordSize word_size,
- TargetPtr text_start, TargetPtr rodata_start,
- TargetPtr data_start, TargetPtr bss_start, TargetPtr stack_end,
- const std::string& toolchain_prefix);
-
-/*!
- * \brief reads section from binary
- * \param binary input binary contents
- * \param section section type to be read
- * \param toolchain_prefix prefix of compiler toolchain to use
- * \return contents of the section
- */
-std::string ReadSection(const std::string& binary, SectionKind section,
- const std::string& toolchain_prefix);
-
-/*!
- * \brief finds size of the section in the binary
- * \param binary input binary contents
- * \param section section type
- * \param toolchain_prefix prefix of compiler toolchain to use
- * \param word_size word size of the target, for alignment
- * \return size of the section if it exists, 0 otherwise
- */
-size_t GetSectionSize(const std::string& binary_name, SectionKind section,
- const std::string& toolchain_prefix, TargetWordSize word_size);
-
-} // namespace runtime
-} // namespace tvm
-#endif // TVM_RUNTIME_MICRO_MICRO_COMMON_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file micro_device_api.cc
- */
-
-#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/device_api.h>
-#include <tvm/runtime/registry.h>
-
-#include "../workspace_pool.h"
-#include "micro_session.h"
-
-namespace tvm {
-namespace runtime {
-/*!
- * \brief device API for uTVM micro devices
- */
-class MicroDeviceAPI final : public DeviceAPI {
- public:
- /*! \brief constructor */
- MicroDeviceAPI() {}
-
- void SetDevice(TVMContext ctx) final {}
-
- void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
- if (kind == kExist) {
- *rv = 1;
- }
- }
-
- void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
- DLDataType type_hint) final {
- ObjectPtr<MicroSession>& session = MicroSession::Current();
- TargetPtr data = session->AllocateInSection(SectionKind::kHeap, nbytes);
- CHECK(data != nullptr) << "unable to allocate " << nbytes << " bytes on device heap";
- return reinterpret_cast<void*>(new MicroDevSpace{data, session});
- }
-
- void FreeDataSpace(TVMContext ctx, void* ptr) final {
- MicroDevSpace* dev_space = static_cast<MicroDevSpace*>(ptr);
- dev_space->session->FreeInSection(SectionKind::kHeap, dev_space->data);
- delete dev_space;
- }
-
- void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
- TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
- TVMStreamHandle stream) final {
- std::tuple<int, int> type_from_to(ctx_from.device_type, ctx_to.device_type);
- if (type_from_to == std::make_tuple(kDLMicroDev, kDLMicroDev)) {
- // Copying from the device to the device.
- MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from));
- MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to));
- CHECK(from_space->session == to_space->session)
- << "attempt to copy data between different micro sessions (" << from_space->session.get()
- << " != " << to_space->session.get() << ")";
- CHECK(ctx_from.device_id == ctx_to.device_id)
- << "can only copy between the same micro device";
- ObjectPtr<MicroSession>& session = from_space->session;
- // flush all pending tasks to ensure data is consistent
- session->FlushTaskQueue();
- const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();
-
- TargetPtr from_dev_addr = GetDevLoc(from_space, from_offset);
- TargetPtr to_dev_addr = GetDevLoc(to_space, to_offset);
-
- std::vector<uint8_t> buffer(size);
- lld->Read(from_dev_addr, static_cast<void*>(buffer.data()), size);
- lld->Write(to_dev_addr, static_cast<void*>(buffer.data()), size);
-
- } else if (type_from_to == std::make_tuple(kDLMicroDev, kDLCPU)) {
- // Reading from the device.
- MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from));
- ObjectPtr<MicroSession>& session = from_space->session;
- // flush all pending tasks to ensure data is consistent
- session->FlushTaskQueue();
- const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();
-
- TargetPtr from_dev_addr = GetDevLoc(from_space, from_offset);
- void* to_host_ptr = GetHostLoc(to, to_offset);
- lld->Read(from_dev_addr, to_host_ptr, size);
-
- } else if (type_from_to == std::make_tuple(kDLCPU, kDLMicroDev)) {
- // Writing to the device.
- MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to));
- ObjectPtr<MicroSession>& session = to_space->session;
- // flush all pending tasks to ensure data is consistent
- session->FlushTaskQueue();
- const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();
-
- void* from_host_ptr = GetHostLoc(from, from_offset);
- TargetPtr to_dev_addr = GetDevLoc(to_space, to_offset);
- lld->Write(to_dev_addr, from_host_ptr, size);
-
- } else {
- LOG(FATAL) << "Expect copy from/to micro device or between micro device\n";
- }
- }
-
- void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
- MicroSession::Current()->FlushTaskQueue();
- }
-
- void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
- CHECK(false) << "the on-device workspace allocator isn't aware of this function";
- ObjectPtr<MicroSession>& session = MicroSession::Current();
-
- TargetPtr data = session->AllocateInSection(SectionKind::kWorkspace, size);
- CHECK(data.value().uint64() != 0)
- << "unable to allocate " << size << " bytes on device workspace";
- return static_cast<void*>(new MicroDevSpace{data, session});
- }
-
- void FreeWorkspace(TVMContext ctx, void* data) final {
- CHECK(false) << "the on-device workspace allocator isn't aware of this function";
- MicroDevSpace* dev_space = static_cast<MicroDevSpace*>(data);
- ObjectPtr<MicroSession>& session = dev_space->session;
- session->FreeInSection(SectionKind::kWorkspace, dev_space->data);
- delete dev_space;
- }
-
- /*!
- * \brief obtain a global singleton of MicroDeviceAPI
- * \return global shared pointer to MicroDeviceAPI
- */
- static MicroDeviceAPI* Global() {
- static MicroDeviceAPI* inst = new MicroDeviceAPI();
- return inst;
- }
-
- private:
- TargetPtr GetDevLoc(MicroDevSpace* dev_space, size_t offset) { return dev_space->data + offset; }
-
- void* GetHostLoc(const void* ptr, size_t offset) {
- return reinterpret_cast<void*>(reinterpret_cast<std::uintptr_t>(ptr) + offset);
- }
-};
-
-// register device that can be obtained from Python frontend
-TVM_REGISTER_GLOBAL("device_api.micro_dev").set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = MicroDeviceAPI::Global();
- *rv = static_cast<void*>(ptr);
-});
-} // namespace runtime
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file micro_module.cc
- */
-
-#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/module.h>
-#include <tvm/runtime/registry.h>
-
-#include <string>
-#include <unordered_map>
-
-#include "../pack_args.h"
-#include "low_level_device.h"
-#include "micro_common.h"
-#include "micro_session.h"
-
-namespace tvm {
-namespace runtime {
-/*!
- * \brief module for uTVM micro devices
- */
-class MicroModuleNode final : public ModuleNode {
- public:
- MicroModuleNode() {}
-
- ~MicroModuleNode() {}
-
- const char* type_key() const final { return "micro"; }
-
- PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
-
- /*!
- * \brief initializes module by establishing device connection and loads binary
- * \param binary_path path of the binary to be loaded
- */
- void InitMicroModule(const std::string& binary_path) {
- // std::cout << "[MicroModuleNode::InitMicroModule]" << std::endl;
- // std::cout << " start" << std::endl;
- session_ = MicroSession::Current();
- symbol_map_ = session_->LoadBinary(binary_path, true).symbol_map;
- }
-
- private:
- SymbolMap symbol_map_;
- /*! \brief global session pointer */
- ObjectPtr<MicroSession> session_;
-};
-
-class MicroWrappedFunc {
- public:
- MicroWrappedFunc(ObjectPtr<MicroSession> session, TargetPtr func_ptr) {
- session_ = session;
- func_ptr_ = func_ptr;
- }
-
- void operator()(TVMArgs args, TVMRetValue* rv) const {
- session_->PushToTaskQueue(func_ptr_, args);
- }
-
- private:
- /*! \brief reference to the session for this function (to keep the session alive) */
- ObjectPtr<MicroSession> session_;
- /*! \brief offset of the function to be called */
- TargetPtr func_ptr_;
-};
-
-PackedFunc MicroModuleNode::GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
- TargetPtr func_ptr;
- if (name == tvm::runtime::symbol::tvm_module_main) {
- if (symbol_map_.HasSymbol(tvm::runtime::symbol::tvm_module_main)) {
- func_ptr = symbol_map_[tvm::runtime::symbol::tvm_module_main];
- } else {
- func_ptr = symbol_map_["default_function"];
- }
- } else {
- func_ptr = symbol_map_[name];
- }
- MicroWrappedFunc f(session_, func_ptr);
- return PackedFunc(f);
-}
-
-// register loadfile function to load module from Python frontend
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_micro_dev")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- auto n = make_object<MicroModuleNode>();
- n->InitMicroModule(args[0]);
- *rv = runtime::Module(n);
- });
-} // namespace runtime
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file micro_section_allocator.h
- */
-#ifndef TVM_RUNTIME_MICRO_MICRO_SECTION_ALLOCATOR_H_
-#define TVM_RUNTIME_MICRO_MICRO_SECTION_ALLOCATOR_H_
-
-#include <string>
-#include <unordered_map>
-
-#include "micro_common.h"
-
-namespace tvm {
-namespace runtime {
-
-/*!
- * \brief allocator for an on-device memory section
- */
-class MicroSectionAllocator {
- public:
- /*!
- * \brief constructor that specifies section boundaries
- * \param region location and size of the section on the device
- */
- explicit MicroSectionAllocator(std::string section_name, DevMemRegion region,
- TargetWordSize word_size)
- : section_name_(section_name),
- start_addr_(region.start),
- size_(0),
- capacity_(region.size),
- word_size_(word_size) {
- CHECK_EQ(start_addr_.value().uint64() % word_size.bytes(), 0)
- << "micro section start not aligned to " << word_size.bytes() << " bytes";
- CHECK_EQ(capacity_ % word_size.bytes(), 0)
- << "micro section end not aligned to " << word_size.bytes() << " bytes";
- }
-
- /*!
- * \brief destructor
- */
- ~MicroSectionAllocator() {}
-
- /*!
- * \brief memory allocator
- * \param alloc_size size of allocated memory in bytes
- * \return pointer to allocated memory region in section, nullptr if out of space
- */
- TargetPtr Allocate(size_t size) {
- size_ = UpperAlignValue(size_, word_size_.bytes());
- CHECK(size_ + size < capacity_)
- << "cannot alloc " << size << " bytes in section \"" << section_name_
- << "\" (start_addr=" << start_addr_.cast_to<void*>() << ", used=" << size_
- << ", capacity=" << capacity_ << ")";
- TargetPtr alloc_addr = start_addr_ + size_;
- size_ += size;
- alloc_map_[alloc_addr.value().uint64()] = size;
- return alloc_addr;
- }
-
- /*!
- * \brief free prior allocation from section
- * \param offs offset to allocated memory
- * \note simple allocator scheme, more complex versions will be implemented later
- */
- void Free(TargetPtr addr) {
- CHECK(alloc_map_.find(addr.value().uint64()) != alloc_map_.end())
- << "freed pointer was never allocated";
- alloc_map_.erase(addr.value().uint64());
- if (alloc_map_.empty()) {
- size_ = 0;
- }
- }
-
- /*!
- * \brief start offset of the memory region managed by this allocator
- */
- TargetPtr start_addr() const { return start_addr_; }
-
- /*!
- * \brief current end addr of the space being used in this memory region
- */
- TargetPtr curr_end_addr() const { return start_addr_ + size_; }
-
- /*!
- * \brief end addr of the memory region managed by this allocator
- */
- TargetPtr max_addr() const { return start_addr_ + capacity_; }
-
- /*!
- * \brief size of the section
- */
- size_t size() const { return size_; }
-
- /*!
- * \brief capacity of the section
- */
- size_t capacity() const { return capacity_; }
-
- private:
- /*! \brief name of the section (for debugging) */
- std::string section_name_;
- /*! \brief start address of the section */
- TargetPtr start_addr_;
- /*! \brief current size of the section */
- size_t size_;
- /*! \brief total storage capacity of the section */
- size_t capacity_;
- /*! \brief number of bytes in a word on the target device */
- TargetWordSize word_size_;
- /*! \brief allocation map for allocation sizes */
- std::unordered_map<uint64_t, size_t> alloc_map_;
-};
-
-} // namespace runtime
-} // namespace tvm
-#endif // TVM_RUNTIME_MICRO_MICRO_SECTION_ALLOCATOR_H_
#include "micro_session.h"
-#include <dmlc/thread_local.h>
-#include <tvm/runtime/device_api.h>
+#include <dmlc/logging.h>
+#include <tvm/runtime/crt/rpc_common/framing.h>
+#include <tvm/runtime/crt/rpc_common/session.h>
#include <tvm/runtime/registry.h>
-#include <chrono>
-#include <locale>
+#include <cstdarg>
#include <memory>
-#include <stack>
-#include <tuple>
-#include <vector>
+#include <string>
+#include <utility>
-#include "low_level_device.h"
-#include "target_data_layout_encoder.h"
+#include "../../support/str_escape.h"
+#include "../crt/host/crt_config.h"
+#include "../rpc/rpc_channel.h"
+#include "../rpc/rpc_endpoint.h"
+#include "../rpc/rpc_session.h"
namespace tvm {
namespace runtime {
+namespace micro_rpc {
-struct TVMMicroSessionThreadLocalEntry {
- std::stack<ObjectPtr<MicroSession>> session_stack;
-};
-
-typedef dmlc::ThreadLocalStore<TVMMicroSessionThreadLocalEntry> TVMMicroSessionThreadLocalStore;
-
-ObjectPtr<MicroSession>& MicroSession::Current() {
- TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get();
- CHECK_GT(entry->session_stack.size(), 0) << "No current session";
- return entry->session_stack.top();
-}
-
-void MicroSession::EnterWithScope(ObjectPtr<MicroSession> session) {
- TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get();
- entry->session_stack.push(session);
-}
+class CallbackWriteStream : public WriteStream {
+ public:
+ explicit CallbackWriteStream(PackedFunc fsend) : fsend_{fsend} {}
-void MicroSession::ExitWithScope() {
- TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get();
- CHECK(!entry->session_stack.empty());
- entry->session_stack.pop();
-}
-
-MicroSession::MicroSession(const std::string& comms_method, const std::string& binary_path,
- const std::string& toolchain_prefix, uint64_t text_start,
- size_t text_size, uint64_t rodata_start, size_t rodata_size,
- uint64_t data_start, size_t data_size, uint64_t bss_start,
- size_t bss_size, uint64_t args_start, size_t args_size,
- uint64_t heap_start, size_t heap_size, uint64_t workspace_start,
- size_t workspace_size, uint64_t stack_start, size_t stack_size,
- TargetWordSize word_size, bool thumb_mode, bool use_device_timer,
- const std::string& server_addr, int port, PackedFunc debug_func)
- : toolchain_prefix_(toolchain_prefix),
- word_size_(word_size),
- thumb_mode_(thumb_mode),
- use_device_timer_(use_device_timer),
- batch_args_encoder_(args_size, word_size),
- debug_func_{debug_func} {
- if (comms_method == "host") {
- // TODO(weberlo): move checks to python
- CHECK(text_start == 0 && rodata_start == 0 && data_start == 0 && bss_start == 0 &&
- args_start == 0 && heap_start == 0 && workspace_start == 0 && stack_start == 0)
- << "unable to specify section addresses for host device";
- size_t memory_size = text_size + rodata_size + data_size + bss_size + args_size + heap_size +
- workspace_size + stack_size;
- TargetPtr base_addr;
- low_level_device_ = HostLowLevelDeviceCreate(memory_size, &base_addr);
- CHECK_EQ(base_addr.value().uint64() % word_size.bytes(), 0)
- << "base address not aligned to " << word_size.bytes() << " bytes";
- TargetPtr curr_addr = base_addr;
-
- section_allocators_[0] = std::make_shared<MicroSectionAllocator>("text",
- DevMemRegion{
- .start = curr_addr,
- .size = text_size,
- },
- word_size_);
- curr_addr += text_size;
- section_allocators_[1] = std::make_shared<MicroSectionAllocator>("rodata",
- DevMemRegion{
- .start = curr_addr,
- .size = rodata_size,
- },
- word_size_);
- curr_addr += rodata_size;
- section_allocators_[2] = std::make_shared<MicroSectionAllocator>("data",
- DevMemRegion{
- .start = curr_addr,
- .size = data_size,
- },
- word_size_);
- curr_addr += data_size;
- section_allocators_[3] = std::make_shared<MicroSectionAllocator>("bss",
- DevMemRegion{
- .start = curr_addr,
- .size = bss_size,
- },
- word_size_);
- curr_addr += bss_size;
- section_allocators_[4] = std::make_shared<MicroSectionAllocator>("args",
- DevMemRegion{
- .start = curr_addr,
- .size = args_size,
- },
- word_size_);
- curr_addr += args_size;
- section_allocators_[5] = std::make_shared<MicroSectionAllocator>("heap",
- DevMemRegion{
- .start = curr_addr,
- .size = heap_size,
- },
- word_size_);
- curr_addr += heap_size;
- section_allocators_[6] = std::make_shared<MicroSectionAllocator>("workspace",
- DevMemRegion{
- .start = curr_addr,
- .size = workspace_size,
- },
- word_size_);
- curr_addr += workspace_size;
- section_allocators_[7] = std::make_shared<MicroSectionAllocator>("stack",
- DevMemRegion{
- .start = curr_addr,
- .size = stack_size,
- },
- word_size_);
- curr_addr += stack_size;
- } else if (comms_method == "openocd") {
- low_level_device_ = OpenOCDLowLevelDeviceCreate(server_addr, port);
- section_allocators_[0] =
- std::make_shared<MicroSectionAllocator>("text",
- DevMemRegion{
- .start = TargetPtr(word_size_, text_start),
- .size = text_size,
- },
- word_size_);
- section_allocators_[1] =
- std::make_shared<MicroSectionAllocator>("rodata",
- DevMemRegion{
- .start = TargetPtr(word_size_, rodata_start),
- .size = rodata_size,
- },
- word_size_);
- section_allocators_[2] =
- std::make_shared<MicroSectionAllocator>("data",
- DevMemRegion{
- .start = TargetPtr(word_size_, data_start),
- .size = data_size,
- },
- word_size_);
- section_allocators_[3] =
- std::make_shared<MicroSectionAllocator>("bss",
- DevMemRegion{
- .start = TargetPtr(word_size_, bss_start),
- .size = bss_size,
- },
- word_size_);
- section_allocators_[4] =
- std::make_shared<MicroSectionAllocator>("args",
- DevMemRegion{
- .start = TargetPtr(word_size_, args_start),
- .size = args_size,
- },
- word_size_);
- section_allocators_[5] =
- std::make_shared<MicroSectionAllocator>("heap",
- DevMemRegion{
- .start = TargetPtr(word_size_, heap_start),
- .size = heap_size,
- },
- word_size_);
- section_allocators_[6] =
- std::make_shared<MicroSectionAllocator>("workspace",
- DevMemRegion{
- .start = TargetPtr(word_size_, workspace_start),
- .size = workspace_size,
- },
- word_size_);
- section_allocators_[7] =
- std::make_shared<MicroSectionAllocator>("stack",
- DevMemRegion{
- .start = TargetPtr(word_size_, stack_start),
- .size = stack_size,
- },
- word_size_);
- } else {
- LOG(FATAL) << "unsupported micro low-level device";
+ ssize_t Write(const uint8_t* data, size_t data_size_bytes) override {
+ TVMByteArray bytes;
+ bytes.data = (const char*)data;
+ bytes.size = data_size_bytes;
+ int64_t n = fsend_(bytes);
+ return n;
}
- TargetPtr args_start_addr = GetAllocator(SectionKind::kArgs)->start_addr();
- batch_args_encoder_.set_start_addr(args_start_addr);
-
- runtime_symbol_map_ = LoadBinary(binary_path, false).symbol_map;
-
- // Patch pointers to define the bounds of the workspace section and the word
- // size (for allocation alignment).
- std::shared_ptr<MicroSectionAllocator> ws_allocator = GetAllocator(SectionKind::kWorkspace);
- DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_start", ws_allocator->start_addr());
- DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_end", ws_allocator->max_addr());
- if (word_size.bytes() == 4) {
- DevSymbolWrite(runtime_symbol_map_, "utvm_word_size", uint32_t(word_size.bytes()));
- } else if (word_size.bytes() == 8) {
- DevSymbolWrite(runtime_symbol_map_, "utvm_word_size", uint64_t(word_size.bytes()));
- } else {
- CHECK(false) << "Unsupported word size unexpectedly here";
- }
-}
+ void PacketDone(bool is_valid) override {}
-MicroSession::~MicroSession() {
- for (size_t i = 0; i < static_cast<size_t>(SectionKind::kNumKinds); i++) {
- section_allocators_[i] = nullptr;
- }
- low_level_device_ = nullptr;
-}
-
-void MicroSession::PushToTaskQueue(TargetPtr func_ptr, const TVMArgs& args) {
- if (thumb_mode_) {
- // TODO(areusch): should be |=
- func_ptr += 1;
- }
- TargetVal func_dev_addr = func_ptr.value();
+ private:
+ PackedFunc fsend_;
+};
- std::tuple<TargetPtr, TargetPtr> arg_field_addrs = EncoderAppend(&batch_args_encoder_, args);
- TargetVal arg_values_dev_addr{std::get<0>(arg_field_addrs).value()};
- TargetVal arg_type_codes_dev_addr{std::get<1>(arg_field_addrs).value()};
+class MicroTransportChannel : public RPCChannel {
+ public:
+ MicroTransportChannel(PackedFunc fsend, PackedFunc frecv)
+ : write_stream_{fsend},
+ framer_{&write_stream_},
+ receive_buffer_{new uint8_t[TVM_CRT_MAX_PACKET_SIZE_BYTES], TVM_CRT_MAX_PACKET_SIZE_BYTES},
+ session_{0x5b, &framer_, &receive_buffer_, &HandleMessageReceivedCb, this},
+ unframer_{session_.Receiver()},
+ did_receive_message_{false},
+ frecv_{frecv},
+ message_buffer_{nullptr} {}
+
+ size_t ReceiveUntil(TypedPackedFunc<bool(void)> pf) {
+ size_t bytes_received = 0;
+ if (pf()) {
+ return 0;
+ }
- task_queue_.push_back(DevTask{.func = func_dev_addr,
- .arg_values = arg_values_dev_addr,
- .arg_type_codes = arg_type_codes_dev_addr,
- .num_args = args.num_args});
+ for (;;) {
+ while (pending_chunk_.size() > 0) {
+ size_t bytes_consumed = 0;
+ int unframer_error = unframer_.Write((const uint8_t*)pending_chunk_.data(),
+ pending_chunk_.size(), &bytes_consumed);
+
+ CHECK(bytes_consumed <= pending_chunk_.size());
+ pending_chunk_ = pending_chunk_.substr(bytes_consumed);
+ bytes_received += bytes_consumed;
+ if (unframer_error < 0) {
+ LOG(ERROR) << "unframer got error code: " << unframer_error;
+ } else {
+ if (pf()) {
+ return bytes_received;
+ }
+ }
+ }
- if (task_queue_.size() == MicroSession::kTaskQueueCapacity) {
- FlushTaskQueue();
+ std::string chunk = frecv_(128);
+ pending_chunk_ = chunk;
+ CHECK(pending_chunk_.size() != 0) << "zero-size chunk encountered";
+ CHECK_GT(pending_chunk_.size(), 0);
+ }
}
-}
-void MicroSession::FlushTaskQueue() {
- if (task_queue_.size() == 0) {
- // nothing to run
- return;
- }
- if (word_size_.bytes() == 4) {
- FlushTaskQueuePriv<StructUTVMTask32>();
- } else if (word_size_.bytes() == 8) {
- FlushTaskQueuePriv<StructUTVMTask64>();
+ void StartSession() {
+ CHECK_EQ(kTvmErrorNoError, session_.Initialize());
+ CHECK_EQ(kTvmErrorNoError, session_.StartSession());
+ ReceiveUntil([this]() -> bool { return session_.IsEstablished(); });
}
-}
-template <typename T>
-void MicroSession::FlushTaskQueuePriv() {
- std::vector<T> prepped_tasks;
- for (const auto& task : task_queue_) {
- prepped_tasks.push_back(T(task));
- }
+ size_t Send(const void* data, size_t size) override {
+ const uint8_t* data_bytes = static_cast<const uint8_t*>(data);
+ ssize_t ret = session_.SendMessage(MessageType::kNormal, data_bytes, size);
+ CHECK(ret == 0) << "SendMessage returned " << ret;
- // Flush `args` to device memory.
- low_level_device()->Write(batch_args_encoder_.start_addr(),
- reinterpret_cast<void*>(batch_args_encoder_.data()),
- batch_args_encoder_.buf_size());
-
- // Flush `tasks` to device memory.
- TargetPtr dev_tasks_addr = runtime_symbol_map_["utvm_tasks"];
- low_level_device()->Write(dev_tasks_addr, reinterpret_cast<void*>(prepped_tasks.data()),
- prepped_tasks.size() * sizeof(T));
- DevSymbolWrite<uint32_t>(runtime_symbol_map_, "utvm_num_tasks", prepped_tasks.size());
-
- TargetPtr utvm_init_addr = runtime_symbol_map_["UTVMInit"];
- TargetPtr utvm_done_addr = runtime_symbol_map_["UTVMDone"];
- if (thumb_mode_) {
- // TODO(areusch): should be |=
- utvm_init_addr += 1;
+ return size;
}
- bool did_debug = false;
- if (debug_func_ != nullptr) {
- TVMRetValue rv = debug_func_();
- if (rv.type_code() == kTVMNullptr) {
- did_debug = true;
- } else {
- did_debug = static_cast<bool>(rv);
- }
+ size_t Recv(void* data, size_t size) override {
+ size_t num_bytes_recv = 0;
+ while (num_bytes_recv < size) {
+ if (message_buffer_ != nullptr) {
+ num_bytes_recv += message_buffer_->Read(static_cast<uint8_t*>(data), size);
+ if (message_buffer_->ReadAvailable() == 0) {
+ message_buffer_ = nullptr;
+ session_.ClearReceiveBuffer();
+ }
+ if (num_bytes_recv == size) {
+ CHECK(message_buffer_ == nullptr || message_buffer_->ReadAvailable() > 0);
+ return num_bytes_recv;
+ }
+ }
- if (did_debug && !use_device_timer_) {
- LOG(INFO) << "NOTE: when debugging and use_device_timer == false, reported execution time "
- << "will be inaccurate!";
+ did_receive_message_ = false;
+ ReceiveUntil([this]() -> bool { return did_receive_message_; });
}
- }
- if (!did_debug) {
- std::chrono::time_point<std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin,
- tend;
- tbegin = std::chrono::high_resolution_clock::now();
- low_level_device()->Execute(utvm_init_addr, utvm_done_addr);
- tend = std::chrono::high_resolution_clock::now();
- if (!use_device_timer_) {
- last_batch_time_ +=
- std::chrono::duration_cast<std::chrono::duration<double>>(tend - tbegin).count() * 1000;
- }
+ return num_bytes_recv;
}
- // Check if there was an error during execution. If so, log it.
- CheckDeviceError();
-
- if (use_device_timer_) {
- uint64_t sum = 0;
- std::vector<uint32_t> times;
- times.resize(task_queue_.size());
- low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], times.data(),
- task_queue_.size() * sizeof(uint32_t));
- int i = 0;
- for (uint32_t time : times) {
- LOG(INFO) << "Time " << i++ << ": " << time;
- sum += time;
- }
- last_batch_time_ += static_cast<double>(sum) / 1e3;
- } else {
- // TODO(weberlo): Reading internal data structure is hacky.
- uint64_t sum = 0;
- std::vector<uint32_t> times;
- times.resize(task_queue_.size());
- low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], times.data(),
- task_queue_.size() * sizeof(uint32_t));
- for (uint32_t time : times) {
- sum += time;
+ FrameBuffer* GetReceivedMessage() {
+ if (did_receive_message_) {
+ did_receive_message_ = false;
+ return message_buffer_;
}
- last_batch_cycles_ += static_cast<double>(sum);
- }
- batch_args_encoder_.Clear();
- task_queue_.clear();
-}
-
-BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_dylib_pointers) {
- DevMemRegion text_section;
- DevMemRegion rodata_section;
- DevMemRegion data_section;
- DevMemRegion bss_section;
-
- text_section.size =
- GetSectionSize(binary_path, SectionKind::kText, toolchain_prefix_, word_size_);
- rodata_section.size =
- GetSectionSize(binary_path, SectionKind::kRodata, toolchain_prefix_, word_size_);
- data_section.size =
- GetSectionSize(binary_path, SectionKind::kData, toolchain_prefix_, word_size_);
- bss_section.size = GetSectionSize(binary_path, SectionKind::kBss, toolchain_prefix_, word_size_);
-
- text_section.start = AllocateInSection(SectionKind::kText, text_section.size);
- rodata_section.start = AllocateInSection(SectionKind::kRodata, rodata_section.size);
- data_section.start = AllocateInSection(SectionKind::kData, data_section.size);
- bss_section.start = AllocateInSection(SectionKind::kBss, bss_section.size);
-
- std::string relocated_bin = RelocateBinarySections(
- binary_path, word_size_, text_section.start, rodata_section.start, data_section.start,
- bss_section.start, GetAllocator(SectionKind::kStack)->max_addr(), toolchain_prefix_);
- std::string text_contents = ReadSection(relocated_bin, SectionKind::kText, toolchain_prefix_);
- std::string rodata_contents = ReadSection(relocated_bin, SectionKind::kRodata, toolchain_prefix_);
- std::string data_contents = ReadSection(relocated_bin, SectionKind::kData, toolchain_prefix_);
- std::string bss_contents = ReadSection(relocated_bin, SectionKind::kBss, toolchain_prefix_);
-
- low_level_device_->Write(text_section.start, &text_contents[0], text_section.size);
- low_level_device_->Write(rodata_section.start, &rodata_contents[0], rodata_section.size);
- low_level_device_->Write(data_section.start, &data_contents[0], data_section.size);
- low_level_device_->Write(bss_section.start, &bss_contents[0], bss_section.size);
- SymbolMap symbol_map{relocated_bin, toolchain_prefix_, word_size_};
-
- if (patch_dylib_pointers) {
- // Patch device lib pointers.
- PatchImplHole(symbol_map, "TVMBackendAllocWorkspace");
- PatchImplHole(symbol_map, "TVMBackendFreeWorkspace");
- PatchImplHole(symbol_map, "TVMAPISetLastError");
+ return nullptr;
}
- return BinaryInfo{
- .text_section = text_section,
- .rodata_section = rodata_section,
- .data_section = data_section,
- .bss_section = bss_section,
- .symbol_map = symbol_map,
- };
-}
-
-std::tuple<TargetPtr, TargetPtr> MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder,
- const TVMArgs& args) {
- const int* type_codes = args.type_codes;
- int num_args = args.num_args;
-
- auto tvm_vals_alloc = encoder->Alloc<TVMValue>(num_args);
- auto type_codes_alloc = encoder->Alloc<const int>(num_args);
-
- for (int i = 0; i < num_args; i++) {
- switch (type_codes[i]) {
- case kTVMNDArrayHandle:
- case kTVMDLTensorHandle: {
- DLTensor* base_arr_handle = args[i];
- // All uTVM arrays store a `MicroDevSpace` struct in their `data` field,
- // which wraps the actual data and stores a reference to the session, in
- // order to prevent premature session destruction.
- void* old_data = base_arr_handle->data;
- // Mutate the array to unwrap the `data` field.
- MicroDevSpace* dev_arr_ptr = reinterpret_cast<MicroDevSpace*>(old_data);
- base_arr_handle->data = reinterpret_cast<void*>(dev_arr_ptr->data.value().uint64());
- // Now, encode the unwrapped version.
- void* arr_ptr = nullptr;
- if (word_size_.bytes() == 4) {
- arr_ptr = EncoderAppend<TVMArray32>(encoder, *base_arr_handle).cast_to<void*>();
- } else if (word_size_.bytes() == 8) {
- arr_ptr = EncoderAppend<TVMArray64>(encoder, *base_arr_handle).cast_to<void*>();
- }
- // And restore the original wrapped version.
- base_arr_handle->data = old_data;
+ private:
+ static void HandleMessageReceivedCb(void* context, MessageType message_type, FrameBuffer* buf) {
+ static_cast<MicroTransportChannel*>(context)->HandleMessageReceived(message_type, buf);
+ }
- TVMValue val;
- val.v_handle = arr_ptr;
- tvm_vals_alloc->WriteValue(val);
+ void HandleMessageReceived(MessageType message_type, FrameBuffer* buf) {
+ size_t message_size_bytes;
+ switch (message_type) {
+ case MessageType::kStartSessionInit:
+ case MessageType::kStartSessionReply:
break;
- }
- // TODO(weberlo): Implement `double` and `int64` case.
- case kDLFloat:
- case kDLInt:
- case kDLUInt:
- default:
- LOG(FATAL) << "unsupported type code for writing args: " << type_codes[i];
- break;
- }
- }
- type_codes_alloc->WriteArray(type_codes, num_args);
- encoder->CheckUnfilledAllocs();
- return std::make_tuple(tvm_vals_alloc->start_addr(), type_codes_alloc->start_addr());
-}
-template <typename T>
-TargetPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr) {
- // `shape` and `strides` are stored on the host, so we need to write them to
- // the device first. The `data` field is already allocated on the device and
- // is a device pointer, so we don't need to write it.
- auto shape_alloc = encoder->Alloc<int64_t>(arr.ndim);
- shape_alloc->WriteArray(arr.shape, arr.ndim);
- TargetPtr shape_dev_addr = shape_alloc->start_addr();
- TargetPtr strides_dev_addr = TargetPtr(word_size_, nullptr);
- if (arr.strides != nullptr) {
- auto stride_alloc = encoder->Alloc<int64_t>(arr.ndim);
- stride_alloc->WriteArray(arr.strides, arr.ndim);
- strides_dev_addr = stride_alloc->start_addr();
- }
+ case MessageType::kTerminateSession:
+ LOG(FATAL) << "SessionTerminatedError: remote side has probably reset";
+ break;
- T dev_arr(TargetVal{word_size_.bits(), reinterpret_cast<uint64_t>(arr.data)}, arr.ctx, arr.ndim,
- arr.dtype, shape_dev_addr.value(), strides_dev_addr.value(),
- TargetVal{word_size_.bits(), arr.byte_offset});
- CHECK(dev_arr.ctx.device_type == static_cast<DLDeviceType>(kDLMicroDev))
- << "attempt to write DLTensor with non-micro device type";
- // Update the device type to CPU, because from the microcontroller's
- // perspective, it is.
- dev_arr.ctx.device_type = DLDeviceType::kDLCPU;
-
- auto tvm_arr_alloc = encoder->Alloc<T>();
- tvm_arr_alloc->WriteValue(dev_arr);
- return tvm_arr_alloc->start_addr();
-}
+ case MessageType::kLog:
+ uint8_t message[1024];
+ message_size_bytes = buf->ReadAvailable();
+ if (message_size_bytes == 0) {
+ return;
+ } else if (message_size_bytes > sizeof(message) - 1) {
+ LOG(ERROR) << "Remote log message is too long to display: " << message_size_bytes
+ << " bytes";
+ return;
+ }
-// TODO(weberlo): switch over entirely to error codes that expand to error
-// messages on the host side.
-void MicroSession::CheckDeviceError() {
- int32_t last_error = DevSymbolRead<int32_t>(runtime_symbol_map_, "utvm_last_error");
+ CHECK_EQ(buf->Read(message, sizeof(message) - 1), message_size_bytes);
+ message[message_size_bytes] = 0;
+ LOG(INFO) << "remote: " << message;
+ session_.ClearReceiveBuffer();
+ return;
- if (last_error) {
- if (!use_device_timer_ &&
- (last_error == UTVM_ERR_TIMER_OVERFLOW || last_error == UTVM_ERR_TIMER_NOT_IMPLEMENTED)) {
- // these errors don't matter if we're not using the on-device timer
- return;
- }
- std::string err_msg;
- switch (last_error) {
- case UTVM_ERR_NOT_FINISHED:
- err_msg = "execution timed out";
- break;
- case UTVM_ERR_TIMER_NOT_IMPLEMENTED:
- err_msg = "timer is not implemented for the target device";
- break;
- case UTVM_ERR_TIMER_OVERFLOW:
- // TODO(weberlo): this should be remedied by using interrupts to accumulate the
- // timer into a larger datatype (ARM timers are only 24 bits)
- err_msg = "timer overflowed during execution";
- break;
- case UTVM_ERR_WS_DOUBLE_FREE:
- err_msg = "free called with no active workspace allocations";
- break;
- case UTVM_ERR_WS_OUT_OF_SPACE:
- err_msg = "ran out of space in workspace section";
- break;
- case UTVM_ERR_WS_TOO_MANY_ALLOCS:
- err_msg = "exceeded number of allocs the runtime can keep track of";
- break;
- case UTVM_ERR_WS_ZERO_SIZE_ALLOC:
- err_msg = "attempt to allocate scratchpad of size zero";
- break;
- case UTVM_ERR_WS_UNALIGNED_START:
- err_msg = "start of workspace section is not word-aligned";
- break;
- case UTVM_ERR_WS_UNALIGNED_ALLOC_SIZE:
- err_msg = "scratchpad allocation size is not a multiple of the word size";
- break;
- default:
- err_msg = "unknown error code";
+ case MessageType::kNormal:
+ did_receive_message_ = true;
+ message_buffer_ = buf;
break;
}
- LOG(FATAL) << "error during micro function execution:\n"
- << " error ID: " << std::dec << last_error << std::endl
- << " error message: " << err_msg;
- }
-}
-
-void MicroSession::PatchImplHole(const SymbolMap& symbol_map, const std::string& func_name) {
- TargetPtr runtime_impl_addr = runtime_symbol_map_[func_name];
- if (thumb_mode_) {
- runtime_impl_addr += 1;
}
- std::ostringstream func_name_underscore;
- func_name_underscore << func_name << "_";
- DevSymbolWrite(symbol_map, func_name_underscore.str(), runtime_impl_addr);
-}
-std::string MicroSession::ReadString(TargetPtr str_addr) {
- std::ostringstream result;
- const size_t buf_size = 256;
- std::vector<char> buf(buf_size, 0);
- size_t i = buf_size;
- while (i == buf_size) {
- low_level_device()->Read(str_addr, buf.data(), buf_size);
- i = 0;
- while (i < buf_size) {
- if (buf[i] == 0) break;
- result << buf[i];
- i++;
- }
- str_addr = str_addr + i;
- }
- return result.str();
-}
-
-TargetPtr MicroSession::AllocateInSection(SectionKind type, size_t size) {
- return GetAllocator(type)->Allocate(size);
-}
+ CallbackWriteStream write_stream_;
+ Framer framer_;
+ FrameBuffer receive_buffer_;
+ Session session_;
+ Unframer unframer_;
+ bool did_receive_message_;
+ PackedFunc frecv_;
+ FrameBuffer* message_buffer_;
+ std::string pending_chunk_;
+};
-void MicroSession::FreeInSection(SectionKind type, TargetPtr addr) {
- return GetAllocator(type)->Free(addr);
-}
+TVM_REGISTER_GLOBAL("micro._rpc_connect").set_body([](TVMArgs args, TVMRetValue* rv) {
+ MicroTransportChannel* micro_channel = new MicroTransportChannel(args[1], args[2]);
+ micro_channel->StartSession();
+ std::unique_ptr<RPCChannel> channel(micro_channel);
+ auto ep = RPCEndpoint::Create(std::move(channel), args[0], "");
+ auto sess = CreateClientSession(ep);
+ *rv = CreateRPCSessionModule(sess);
+});
-template <typename T>
-T MicroSession::DevSymbolRead(const SymbolMap& symbol_map, const std::string& symbol) {
- TargetPtr sym_addr = symbol_map[symbol];
- T result;
- low_level_device()->Read(sym_addr, &result, sizeof(T));
- return result;
-}
+} // namespace micro_rpc
+} // namespace runtime
+} // namespace tvm
-void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol,
- const TargetPtr& ptr) {
- if (word_size_.bytes() == 4) {
- DevSymbolWrite(symbol_map, symbol, ptr.value().uint32());
- } else if (word_size_.bytes() == 8) {
- DevSymbolWrite(symbol_map, symbol, ptr.value().uint64());
- } else {
- CHECK(false) << "Unsupported word size unexpectedly here";
- }
-}
+extern "C" {
-template <typename T>
-void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol,
- const T& value) {
- TargetPtr sym_addr = symbol_map[symbol];
- low_level_device()->Write(sym_addr, &value, sizeof(T));
+void TVMLogf(const char* fmt, ...) {
+ va_list args;
+ char msg_buf[256];
+ va_start(args, fmt);
+ vsnprintf(msg_buf, sizeof(msg_buf), fmt, args);
+ va_end(args);
+ LOG(INFO) << msg_buf;
}
-PackedFunc MicroSession::GetFunction(const std::string& name,
- const ObjectPtr<Object>& sptr_to_self) {
- if (name == "enter") {
- return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- MicroSession::EnterWithScope(GetObjectPtr<MicroSession>(this));
- });
- } else if (name == "exit") {
- return PackedFunc(
- [sptr_to_self](TVMArgs args, TVMRetValue* rv) { MicroSession::ExitWithScope(); });
- // TODO(weberlo): add a `clear_batch_timer` func
- } else if (name == "get_last_batch_time") {
- return PackedFunc(
- [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLastBatchTime(); });
- // TODO(weberlo): remove this func
- } else if (name == "get_last_batch_cycles") {
- return PackedFunc(
- [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLastBatchCycles(); });
- } else {
- return PackedFunc();
- }
+void TVMPlatformAbort(int error_code) { CHECK(false) << "TVMPlatformAbort: " << error_code; }
}
-
-TVM_REGISTER_GLOBAL("micro._GetMicroTimeEvaluator").set_body([](TVMArgs args, TVMRetValue* rv) {
- PackedFunc pf = args[0];
- TVMContext ctx = args[1];
- uint64_t number = args[2];
- uint64_t repeat = args[3];
-
- auto ftimer = [pf, ctx, number, repeat](TVMArgs args, TVMRetValue* rv) mutable {
- TVMRetValue temp;
- std::ostringstream os;
-
- for (unsigned int i = 0; i < repeat; ++i) {
- // start timing
- CHECK(number < MicroSession::kTaskQueueCapacity)
- << "`number` must be less than uTVM task queue capacity";
- for (unsigned int j = 0; j < number; ++j) {
- pf.CallPacked(args, &temp);
- }
- ObjectPtr<MicroSession> session = MicroSession::Current();
- DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
- double time_per_batch = session->GetLastBatchTime() / number;
- os.write(reinterpret_cast<char*>(&time_per_batch), sizeof(time_per_batch));
- }
- std::string blob = os.str();
- TVMByteArray arr;
- arr.size = blob.length();
- arr.data = blob.data();
- // return the time.
- *rv = arr;
- };
- *rv = PackedFunc(ftimer);
-});
-
-// create micro session and low-level device from Python frontend
-TVM_REGISTER_GLOBAL("micro._CreateSession").set_body([](TVMArgs args, TVMRetValue* rv) {
- const std::string& comms_method = args[0];
- const std::string& binary_path = args[1];
- const std::string& toolchain_prefix = args[2];
- uint64_t text_start = args[3];
- size_t text_size = uint64_t(args[4]);
- uint64_t rodata_start = args[5];
- size_t rodata_size = uint64_t(args[6]);
- uint64_t data_start = args[7];
- size_t data_size = uint64_t(args[8]);
- uint64_t bss_start = args[9];
- size_t bss_size = uint64_t(args[10]);
- uint64_t args_start = args[11];
- size_t args_size = uint64_t(args[12]);
- uint64_t heap_start = args[13];
- size_t heap_size = uint64_t(args[14]);
- uint64_t workspace_start = args[15];
- size_t workspace_size = uint64_t(args[16]);
- uint64_t stack_start = args[17];
- size_t stack_size = uint64_t(args[18]);
- TargetWordSize word_size{uint64_t(args[19])};
- bool thumb_mode = args[20];
- bool use_device_timer = args[21];
- const std::string& server_addr = args[22];
- int port = args[23];
- PackedFunc debug_func = args[24];
- ObjectPtr<MicroSession> session = make_object<MicroSession>(
- comms_method, binary_path, toolchain_prefix, text_start, text_size, rodata_start, rodata_size,
- data_start, data_size, bss_start, bss_size, args_start, args_size, heap_start, heap_size,
- workspace_start, workspace_size, stack_start, stack_size, word_size, thumb_mode,
- use_device_timer, server_addr, port, debug_func);
- *rv = Module(session);
-});
-
-} // namespace runtime
-} // namespace tvm
#ifndef TVM_RUNTIME_MICRO_MICRO_SESSION_H_
#define TVM_RUNTIME_MICRO_MICRO_SESSION_H_
-#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/registry.h>
-
-#include <memory>
-#include <string>
-#include <tuple>
-#include <unordered_map>
-#include <vector>
-
-#include "low_level_device.h"
-#include "micro_common.h"
-#include "micro_section_allocator.h"
-#include "target_data_layout_encoder.h"
-
-namespace tvm {
-namespace runtime {
-
-struct DevTask;
-
-/*!
- * \brief session for facilitating micro device interaction
- */
-class MicroSession : public ModuleNode {
- public:
- /*!
- * \brief Get member function to front-end
- * \param name The name of the function.
- * \param sptr_to_self The pointer to the module node.
- * \return The corresponding member function.
- */
- virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
-
- // todo having this decoupled from the value in utvm_runtime.c gives me stress dreams
- static const size_t kTaskQueueCapacity = 20;
-
- /*!
- * \return The type key of the executor.
- */
- const char* type_key() const final { return "MicroSession"; }
-
- /*!
- * \brief creates session by setting up a low-level device and initting allocators for it
- * \param comms_method method of communication with the device (e.g., "openocd")
- * \param binary_path file system path to the runtime binary
- * \param toolchain_prefix GCC toolchain prefix
- * \param text_start text section start address
- * \param text_size text section size
- * \param rodata_start text section start address
- * \param rodata_size rodata section size
- * \param data_start data section start address
- * \param data_size data section size
- * \param bss_start bss section start address
- * \param bss_size bss section size
- * \param args_start args section start address
- * \param args_size args section size
- * \param heap_start heap section start address
- * \param heap_size heap section size
- * \param workspace_start workspace section start address
- * \param workspace_size workspace section size
- * \param stack_start stack section start address
- * \param stack_size stack section size
- * \param word_size_bytes number of bytes in a word on the target device
- * \param thumb_mode whether the target device requires a thumb-mode bit on function addresses
- * \param server_addr address of the OpenOCD server to connect to (if `comms_method == "openocd"`)
- * \param port port of the OpenOCD server to connect to (if `comms_method == "openocd"`)
- */
- MicroSession(const std::string& comms_method, const std::string& binary_path,
- const std::string& toolchain_prefix, uint64_t text_start, size_t text_size,
- uint64_t rodata_start, size_t rodata_size, uint64_t data_start, size_t data_size,
- uint64_t bss_start, size_t bss_size, uint64_t args_start, size_t args_size,
- uint64_t heap_start, size_t heap_size, uint64_t workspace_start,
- size_t workspace_size, uint64_t stack_start, size_t stack_size,
- TargetWordSize word_size, bool thumb_mode, bool use_device_timer,
- const std::string& server_addr, int port, PackedFunc debug_func);
-
- /*!
- * \brief destructor
- */
- ~MicroSession();
-
- static ObjectPtr<MicroSession>& Current();
-
- /*!
- * \brief sets up runtime metadata for `func` and copies arguments for on-device execution
- * \param func address of the function to be executed
- * \param args args to the packed function
- * \return elapsed time during function execution on the device
- */
- void PushToTaskQueue(TargetPtr func, const TVMArgs& args);
-
- /*!
- * \brief serialize runtime metadata to the device for enqueued tasks and execute
- * \return elapsed time during function execution on the device
- */
- void FlushTaskQueue();
-
- /*!
- * \brief TODO
- */
- template <typename T>
- void FlushTaskQueuePriv();
-
- /*!
- * \brief loads binary onto device
- * \param binary_path path to binary object file
- * \param patch_dylib_pointers whether to patch runtime API function pointers
- * \return info about loaded binary
- */
- BinaryInfo LoadBinary(const std::string& binary_path, bool patch_dylib_pointers);
-
- /*!
- * \brief allocate memory in section
- * \param type type of section to allocate in
- * \param size size of allocated memory in bytes
- * \return pointer to allocated memory region in section, nullptr if out of space
- */
- TargetPtr AllocateInSection(SectionKind type, size_t size);
-
- /*!
- * \brief free prior allocation from section
- * \param type type of section to allocate in
- * \param addr device address of allocated memory
- */
- void FreeInSection(SectionKind type, TargetPtr addr);
-
- /*!
- * \brief read string from device to host
- * \param str_addr device address of first character of string
- * \return host copy of device string that was read
- */
- std::string ReadString(TargetPtr str_addr);
-
- /*!
- * \brief read value of symbol from device memory
- * \param symbol_map symbol map to read location of symbol from
- * \param symbol name of symbol being read from
- * \return value at symbol in memory
- */
- template <typename T>
- T DevSymbolRead(const SymbolMap& symbol_map, const std::string& symbol);
-
- /*!
- * \brief write pointer value into device memory corresponding to symbol
- * \param symbol_map symbol map to read location of symbol from
- * \param symbol name of symbol being written to
- * \param ptr pointer value to write into symbol
- */
- void DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const TargetPtr& ptr);
-
- /*!
- * \brief write value into device memory corresponding to symbol
- * \param symbol_map symbol map to read location of symbol from
- * \param symbol name of symbol being written to
- * \param value value being written into symbol
- */
- template <typename T>
- void DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const T& value);
-
- /*!
- * \brief returns low-level device pointer
- * \note assumes low-level device has been initialized
- */
- const std::shared_ptr<LowLevelDevice>& low_level_device() const {
- CHECK(low_level_device_ != nullptr) << "attempt to get uninitialized low-level device";
- return low_level_device_;
- }
-
- const double GetLastBatchTime() {
- double result = last_batch_time_;
- last_batch_time_ = 0.0;
- return result;
- }
-
- const double GetLastBatchCycles() {
- double result = last_batch_cycles_;
- last_batch_cycles_ = 0.0;
- return result;
- }
-
- private:
- /*! \brief low-level device pointer */
- std::shared_ptr<LowLevelDevice> low_level_device_;
- /*! \brief prefix for binary names in target compiler toolchain */
- std::string toolchain_prefix_;
- /*! \brief array of memory allocators for each on-device section */
- std::shared_ptr<MicroSectionAllocator>
- section_allocators_[static_cast<size_t>(SectionKind::kNumKinds)];
- /*! \brief number of bytes in a word on the target device */
- TargetWordSize word_size_;
- /*! \brief whether the target device requires a thumb-mode bit on function addresses
- *
- * ARM and other manufacturers use the lowest bit of a function address to determine
- * whether it's a "thumb mode" function. The Thumb ISA is more restricted, but
- * results in more compact binaries.
- */
- bool thumb_mode_;
- /*! \brief TODO */
- bool use_device_timer_;
- /*! \brief symbol map for the device runtime */
- SymbolMap runtime_symbol_map_;
- /*! \brief TODO */
- std::vector<DevTask> task_queue_;
- // TODO(weberlo): we don't even need an allocator mechanism for the args
- // section. there's only ever one allocation.
- /*! \brief TODO hack */
- TargetDataLayoutEncoder batch_args_encoder_;
- /*! \brief TODO hack */
- double last_batch_time_;
- /*! \brief TODO hack */
- double last_batch_cycles_;
- /*! \brief the debug function invoked to launch gdb */
- PackedFunc debug_func_;
-
- /*!
- * \brief patches a function pointer in this module to an implementation
- * \param func_name name of the function pointer being patched
- */
- void PatchImplHole(const SymbolMap& symbol_map, const std::string& func_name);
-
- /*!
- * \brief appends arguments to the host-side buffer of `encoder`
- * \param encoder encoder being used to append `args`
- * \param args args to be appended
- * \return device address of the allocated args
- */
- std::tuple<TargetPtr, TargetPtr> EncoderAppend(TargetDataLayoutEncoder* encoder,
- const TVMArgs& args);
-
- /*!
- * \brief appends a `DLTensor` to the host-side buffer of `encoder`
- * \param encoder encoder being used to append `arr`
- * \param arr DLTensor to be appended
- * \return device address of the allocated `DLTensor`
- */
- template <typename T>
- TargetPtr EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr);
-
- /*!
- * \brief checks and logs if there was an error during the device's most recent execution
- */
- void CheckDeviceError();
-
- /*!
- * \brief returns section allocator corresponding to the given section kind
- * \param kind kind of target section
- * \return shared pointer to section allocator
- */
- std::shared_ptr<MicroSectionAllocator> GetAllocator(SectionKind kind) {
- return section_allocators_[static_cast<size_t>(kind)];
- }
-
- /*!
- * \brief Push a new session context onto the thread-local stack.
- * The session on top of the stack is used as the current global session.
- */
- static void EnterWithScope(ObjectPtr<MicroSession> session);
-
- /*!
- * \brief Pop a session off the thread-local context stack,
- * restoring the previous session as the current context.
- */
- static void ExitWithScope();
-};
-
-/*!
- * \brief a device memory region associated with the session that allocated it
- *
- * We use this to store a reference to the session in each allocated object and
- * only deallocate the session once there are no more references to it.
- */
-struct MicroDevSpace {
- /*! \brief data being wrapped */
- TargetPtr data;
- /*! \brief shared ptr to session where this data is valid */
- ObjectPtr<MicroSession> session;
-};
-
-// TODO(weberlo): maybe templatize serialization to reduce redundancy
-
-/*! \brief TVM array for serialization to 32-bit devices */
-struct TVMArray32 {
- TVMArray32(TargetVal data, DLContext ctx, int32_t ndim, DLDataType dtype, TargetVal shape,
- TargetVal strides, TargetVal byte_offset)
- : data{data.uint32()},
- ctx{ctx},
- ndim{ndim},
- dtype{dtype},
- shape{shape.uint32()},
- strides{strides.uint32()},
- byte_offset{byte_offset.uint32()} {}
-
- /*!
- * \brief The opaque data pointer points to the allocated data.
- * This will be CUDA device pointer or cl_mem handle in OpenCL.
- * This pointer is always aligns to 256 bytes as in CUDA.
- */
- uint32_t data;
- /*! \brief The device context of the tensor */
- DLContext ctx;
- /*! \brief Number of dimensions */
- int32_t ndim;
- /*! \brief The data type of the pointer */
- DLDataType dtype;
- /*! \brief The shape of the tensor */
- uint32_t shape;
- /*!
- * \brief strides of the tensor,
- * can be NULL, indicating tensor is compact.
- */
- uint32_t strides;
- /*! \brief The offset in bytes to the beginning pointer to data */
- uint32_t byte_offset;
-};
-
-/*! \brief TVM array for serialization to 64-bit devices */
-struct TVMArray64 {
- TVMArray64(TargetVal data, DLContext ctx, int32_t ndim, DLDataType dtype, TargetVal shape,
- TargetVal strides, TargetVal byte_offset)
- : data(data.uint64()),
- ctx(ctx),
- ndim(ndim),
- dtype(dtype),
- shape(shape.uint64()),
- strides(strides.uint64()),
- byte_offset(byte_offset.uint64()) {}
- /*!
- * \brief The opaque data pointer points to the allocated data.
- * This will be CUDA device pointer or cl_mem handle in OpenCL.
- * This pointer is always aligns to 256 bytes as in CUDA.
- */
- uint64_t data;
- /*! \brief The device context of the tensor */
- DLContext ctx;
- /*! \brief Number of dimensions */
- int32_t ndim;
- /*! \brief The data type of the pointer */
- DLDataType dtype;
- /*! \brief The shape of the tensor */
- uint64_t shape;
- /*!
- * \brief strides of the tensor,
- * can be NULL, indicating tensor is compact.
- */
- uint64_t strides;
- /*! \brief The offset in bytes to the beginning pointer to data */
- uint64_t byte_offset;
-};
-
-/*! \brief MicroTVM task to store in task queue before specializing to word size */
-struct DevTask {
- /*! \brief Pointer to function to call for this task */
- TargetVal func;
- /*! \brief Array of argument values */
- TargetVal arg_values;
- /*! \brief Array of type codes for each argument value */
- TargetVal arg_type_codes;
- /*! \brief Number of arguments */
- int32_t num_args;
-};
-
-/*! \brief MicroTVM task for serialization to 32-bit devices */
-typedef struct StructUTVMTask32 {
- StructUTVMTask32(DevTask task)
- : func(task.func.uint32()),
- arg_values(task.arg_values.uint32()),
- arg_type_codes(task.arg_type_codes.uint32()),
- num_args(task.num_args) {}
-
- /*! \brief Pointer to function to call for this task */
- uint32_t func;
- /*! \brief Array of argument values */
- uint32_t arg_values;
- /*! \brief Array of type codes for each argument value */
- uint32_t arg_type_codes;
- /*! \brief Number of arguments */
- int32_t num_args;
-} StructUTVMTask32;
-
-/*! \brief MicroTVM task for serialization to 64-bit devices */
-typedef struct StructUTVMTask64 {
- StructUTVMTask64(DevTask task)
- : func(task.func.uint64()),
- arg_values(task.arg_values.uint64()),
- arg_type_codes(task.arg_type_codes.uint64()),
- num_args(task.num_args) {}
-
- /*! \brief Pointer to function to call for this task */
- uint64_t func;
- /*! \brief Array of argument values */
- uint64_t arg_values;
- /*! \brief Array of type codes for each argument value */
- uint64_t arg_type_codes;
- /*! \brief Number of arguments */
- int32_t num_args;
-} StructUTVMTask64;
-
-} // namespace runtime
-} // namespace tvm
#endif // TVM_RUNTIME_MICRO_MICRO_SESSION_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file openocd_low_level_device.cc
- */
-#include <iomanip>
-#include <sstream>
-
-#include "low_level_device.h"
-#include "micro_common.h"
-#include "tcl_socket.h"
-
-namespace tvm {
-namespace runtime {
-
-/*!
- * \brief OpenOCD low-level device for uTVM micro devices connected over JTAG
- */
-class OpenOCDLowLevelDevice final : public LowLevelDevice {
- public:
- /*!
- * \brief constructor to initialize connection to openocd device
- * \param server_addr address of the OpenOCD server to connect to
- * \param port port of the OpenOCD server to connect to
- */
- explicit OpenOCDLowLevelDevice(const std::string& server_addr, int port) : socket_() {
- server_addr_ = server_addr;
- port_ = port;
-
- socket_.Connect(tvm::support::SockAddr(server_addr_.c_str(), port_));
- socket_.cmd_builder() << "reset run";
- socket_.SendCommand();
-
- socket_.cmd_builder() << "halt 500";
- socket_.SendCommand();
- }
-
- void Read(TargetPtr addr, void* buf, size_t num_bytes) override {
- if (num_bytes == 0) {
- return;
- }
-
- // TODO(weberlo): Refactor between read and write.
- // Check if we need to chunk this write request.
- if (num_bytes > kMemTransferLimit) {
- char* curr_buf_ptr = reinterpret_cast<char*>(buf);
- while (num_bytes != 0) {
- size_t amount_to_read;
- if (num_bytes > kMemTransferLimit) {
- amount_to_read = kMemTransferLimit;
- } else {
- amount_to_read = num_bytes;
- }
- Read(addr, reinterpret_cast<void*>(curr_buf_ptr), amount_to_read);
- addr += amount_to_read;
- curr_buf_ptr += amount_to_read;
- num_bytes -= amount_to_read;
- }
- return;
- }
- {
- socket_.cmd_builder() << "array unset output";
- socket_.SendCommand();
-
- socket_.cmd_builder() << "mem2array output"
- << " " << std::dec << kWordSize << " "
- << addr.cast_to<void*>()
- // Round up any request sizes under a byte, since OpenOCD doesn't
- // support sub-byte-sized transfers.
- << " " << std::dec << (num_bytes < 8 ? 8 : num_bytes);
- socket_.SendCommand();
- }
-
- {
- socket_.cmd_builder() << "return $output";
- socket_.SendCommand();
- const std::string& reply = socket_.last_reply();
-
- std::istringstream values(reply);
- char* char_buf = reinterpret_cast<char*>(buf);
- ssize_t req_bytes_remaining = num_bytes;
- uint32_t index;
- uint32_t val;
- while (req_bytes_remaining > 0) {
- // The response from this command pairs indices with the contents of the
- // memory at that index.
- values >> index;
- CHECK(index < num_bytes) << "index " << index << " out of bounds (length " << num_bytes
- << ")";
- // Read the value into `curr_val`, instead of reading directly into
- // `buf_iter`, because otherwise it's interpreted as the ASCII value and
- // not the integral value.
- values >> val;
- char_buf[index] = static_cast<uint8_t>(val);
- req_bytes_remaining--;
- }
- if (num_bytes >= 8) {
- uint32_t check_index;
- values >> check_index;
- CHECK(check_index != index) << "more data in response than requested";
- }
- }
- }
-
- void Write(TargetPtr addr, const void* buf, size_t num_bytes) override {
- if (num_bytes == 0) {
- return;
- }
-
- // Check if we need to chunk this write request.
- if (num_bytes > kMemTransferLimit) {
- const char* curr_buf_ptr = reinterpret_cast<const char*>(buf);
- while (num_bytes != 0) {
- size_t amount_to_write;
- if (num_bytes > kMemTransferLimit) {
- amount_to_write = kMemTransferLimit;
- } else {
- amount_to_write = num_bytes;
- }
- Write(addr, reinterpret_cast<const void*>(curr_buf_ptr), amount_to_write);
- addr += amount_to_write;
- curr_buf_ptr += amount_to_write;
- num_bytes -= amount_to_write;
- }
- return;
- }
-
- // Clear `input` array.
- socket_.cmd_builder() << "array unset input";
- socket_.SendCommand();
- // Build a command to set the value of `input`.
- {
- std::ostringstream& cmd_builder = socket_.cmd_builder();
- cmd_builder << "array set input {";
- const char* char_buf = reinterpret_cast<const char*>(buf);
- for (size_t i = 0; i < num_bytes; i++) {
- // In a Tcl `array set` commmand, we need to pair the array indices with
- // their values.
- cmd_builder << i << " ";
- // Need to cast to uint, so the number representation of `buf[i]` is
- // printed, and not the ASCII representation.
- cmd_builder << static_cast<uint32_t>(char_buf[i]) << " ";
- }
- cmd_builder << "}";
- socket_.SendCommand();
- }
- {
- socket_.cmd_builder() << "array2mem input"
- << " " << std::dec << kWordSize << " " << addr.cast_to<void*>() << " "
- << std::dec << num_bytes;
- socket_.SendCommand();
- }
- }
-
- void Execute(TargetPtr func_addr, TargetPtr breakpoint_addr) override {
- socket_.cmd_builder() << "halt 0";
- socket_.SendCommand();
-
- // Set a breakpoint at the beginning of `UTVMDone`.
- socket_.cmd_builder() << "bp " << breakpoint_addr.cast_to<void*>() << " 2";
- socket_.SendCommand();
-
- socket_.cmd_builder() << "resume " << func_addr.cast_to<void*>();
- socket_.SendCommand();
-
- socket_.cmd_builder() << "wait_halt " << kWaitTime;
- socket_.SendCommand();
-
- socket_.cmd_builder() << "halt 0";
- socket_.SendCommand();
-
- // Remove the breakpoint.
- socket_.cmd_builder() << "rbp " << breakpoint_addr.cast_to<void*>();
- socket_.SendCommand();
- }
-
- const char* device_type() const final { return "openocd"; }
-
- private:
- /*! \brief socket used to communicate with the device through Tcl */
- TclSocket socket_;
- /*! \brief address of OpenOCD server */
- std::string server_addr_;
- /*! \brief port of OpenOCD server */
- int port_;
-
- /*! \brief number of bytes in a word on the target device (64-bit) */
- static const constexpr ssize_t kWordSize = 8;
- // NOTE: The OS pipe buffer must be able to handle a line long enough to
- // print this transfer request.
- /*! \brief maximum number of bytes allowed in a single memory transfer */
- static const constexpr ssize_t kMemTransferLimit = 8000;
- /*! \brief number of milliseconds to wait for function execution to halt */
- static const constexpr int kWaitTime = 30000;
-};
-
-const std::shared_ptr<LowLevelDevice> OpenOCDLowLevelDeviceCreate(const std::string& server_addr,
- int port) {
- std::shared_ptr<LowLevelDevice> lld = std::make_shared<OpenOCDLowLevelDevice>(server_addr, port);
- return lld;
-}
-
-} // namespace runtime
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include "target_data_layout_encoder.h"
-
-namespace tvm {
-namespace runtime {
-
-TargetDataLayoutEncoder::Alloc::Alloc(TargetDataLayoutEncoder* parent, size_t start_offset,
- size_t size, TargetPtr start_addr)
- : parent_(parent),
- start_offset_(start_offset),
- curr_offset_(0),
- size_(size),
- start_addr_(start_addr) {
- parent_->live_unchecked_allocs_.insert(this);
-}
-
-TargetDataLayoutEncoder::Alloc::~Alloc() {
- auto it = parent_->live_unchecked_allocs_.find(this);
- if (it != parent_->live_unchecked_allocs_.end()) {
- // alloc was not already checked
- parent_->live_unchecked_allocs_.erase(it);
- if (curr_offset_ != size_) {
- parent_->unchecked_alloc_start_offsets_.push_back(start_addr_.value().uint64());
- }
- }
-}
-
-void TargetDataLayoutEncoder::Alloc::CheckUnfilled() {
- CHECK(curr_offset_ == size_) << "unwritten space in alloc 0x" << std::hex
- << start_addr_.value().uint64() << "; curr_offset=0x" << curr_offset_
- << ", size=0x" << size_;
-}
-
-TargetPtr TargetDataLayoutEncoder::Alloc::start_addr() { return start_addr_; }
-
-size_t TargetDataLayoutEncoder::Alloc::size() { return size_; }
-
-void TargetDataLayoutEncoder::CheckUnfilledAllocs() {
- CHECK(live_unchecked_allocs_.size() > 0) << "No allocs to check";
- if (unchecked_alloc_start_offsets_.size() > 0) {
- LOG(ERROR) << "Unchecked allocs were found:";
- for (size_t alloc_start_addr : unchecked_alloc_start_offsets_) {
- LOG(ERROR) << " * 0x" << std::hex << alloc_start_addr;
- }
- CHECK(false) << "Unchecked allocs found during CheckUnfilledAllocs";
- }
-
- for (class Alloc* s : live_unchecked_allocs_) {
- s->CheckUnfilled();
- }
- live_unchecked_allocs_.clear();
-}
-
-} // namespace runtime
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file target_data_layout_encoder.h
- * \brief uTVM data layout encoder
- */
-#ifndef TVM_RUNTIME_MICRO_TARGET_DATA_LAYOUT_ENCODER_H_
-#define TVM_RUNTIME_MICRO_TARGET_DATA_LAYOUT_ENCODER_H_
-
-#include <memory>
-#include <set>
-#include <vector>
-
-#include "host_driven/utvm_runtime_enum.h"
-#include "micro_common.h"
-
-namespace tvm {
-namespace runtime {
-
-// TODO(weberlo, areusch): Handle endianness.
-
-/*!
- * \brief data encoder for uTVM that builds a host-side buffer
- */
-class TargetDataLayoutEncoder {
- public:
- /*!
- * \brief helper class for writing into `TargetDataLayoutEncoder`
- */
- class Alloc {
- public:
- /*!
- * \brief constructor
- * \param parent pointer to parent encoder
- * \param start_offset start byte offset of the alloc in the backing buffer
- * \param size size (in bytes) of the memory region allocated for this alloc
- * \param start_addr start address of the alloc in the device's memory
- */
- Alloc(TargetDataLayoutEncoder* parent, size_t start_offset, size_t size, TargetPtr start_addr);
-
- ~Alloc();
-
- /*!
- * \brief writes `sizeof(T) * num_elems` bytes of data from `arr`
- * \param arr array to be read from
- * \param num_elems number of elements in array
- */
- template <typename T>
- void WriteArray(const T* arr, size_t num_elems);
-
- /*!
- * \brief writes `val`
- * \param val value to be written
- */
- template <typename T>
- void WriteValue(const T& val);
-
- /*!
- * \brief returns start address of the alloc in device memory
- * \return device start address
- */
- TargetPtr start_addr();
-
- /*!
- * \brief returns number of bytes allocated for this alloc
- * \return size of this alloc
- */
- size_t size();
-
- size_t curr_offset() const { return curr_offset_; }
-
- void CheckUnfilled();
-
- private:
- /*! \brief pointer to parent encoder */
- TargetDataLayoutEncoder* parent_;
- /*! \brief start offset of the alloc in the parent's backing parent_buffer */
- size_t start_offset_;
- /*! \brief current offset relative to the start offset of this alloc */
- size_t curr_offset_;
- /*! \brief size (in bytes) of the memory region allocated for this alloc */
- size_t size_;
- /*! \brief start address of the alloc in the device's memory */
- TargetPtr start_addr_;
- };
-
- /*!
- * \brief constructor
- * \param start_addr start address of the encoder in device memory
- */
- explicit TargetDataLayoutEncoder(size_t capacity, TargetWordSize word_size)
- : buf_(std::vector<uint8_t>()),
- curr_offset_(0),
- start_addr_(word_size, nullptr),
- capacity_(capacity),
- word_size_(word_size) {}
-
- /*!
- * \brief allocates a alloc for `sizeof(T) * num_elems` bytes of data
- * \param num_elems number of elements of type `T` being allocated (defaults to 1)
- * \return alloc of size `sizeof(T) * num_elems` bytes
- */
- template <typename T>
- std::unique_ptr<class Alloc> Alloc(size_t num_elems = 1) {
- curr_offset_ = UpperAlignValue(curr_offset_, word_size_.bytes());
- size_t size = sizeof(T) * num_elems;
- if (curr_offset_ + size > buf_.size()) {
- buf_.resize(curr_offset_ + size);
- }
- CHECK(buf_.size() < capacity_) << "out of space in data encoder";
- size_t alloc_start_offset = curr_offset_;
- curr_offset_ += size;
- class Alloc* alloc =
- new class Alloc(this, alloc_start_offset, size, start_addr() + alloc_start_offset);
- return std::unique_ptr<class Alloc>(alloc);
- }
-
- void Clear() {
- buf_.clear();
- curr_offset_ = 0;
- }
-
- /*!
- * \brief returns the array backing the encoder's buffer
- * \return array backing the encoder's buffer
- */
- uint8_t* data() { return buf_.data(); }
-
- /*!
- * \brief returns current size of the encoder's buffer
- * \return buffer size
- */
- size_t buf_size() const { return buf_.size(); }
-
- TargetPtr start_addr() const {
- CHECK_NE(start_addr_.value().uint64(), 0) << "start addr uninitialized";
- return start_addr_;
- }
-
- void set_start_addr(TargetPtr start_addr) {
- CHECK_EQ(buf_.size(), 0) << "cannot change encoder start addr unless empty";
- start_addr_ =
- TargetPtr(word_size_, UpperAlignValue(start_addr.value().uint64(), word_size_.bytes()));
- }
-
- void CheckUnfilledAllocs();
-
- private:
- /*! \brief in-memory backing buffer */
- std::vector<uint8_t> buf_;
- /*! \brief current offset */
- size_t curr_offset_;
- /*! \brief start address of the encoder in device memory */
- TargetPtr start_addr_;
- /*! \brief number of bytes available in device memory */
- size_t capacity_;
- /*! \brief number of bytes in a word on the target device */
- TargetWordSize word_size_;
- /*! \brief Alloc instances allocated now but not yet checked by CheckUnfilledAllocs */
- std::set<class Alloc*> live_unchecked_allocs_;
- /*! \brief start offsets Alloc instances that were dealloated before CheckUnfilledAllocs ran */
- std::vector<size_t> unchecked_alloc_start_offsets_;
- friend Alloc::~Alloc();
-};
-
-template <typename T>
-void TargetDataLayoutEncoder::Alloc::WriteArray(const T* arr, size_t num_elems) {
- if (num_elems == 0) return;
- size_t size = sizeof(T) * num_elems;
- CHECK(curr_offset_ + size <= size_) << "not enough space in alloc";
- uint8_t* curr_ptr = &(parent_->data())[start_offset_ + curr_offset_];
- std::memcpy(curr_ptr, arr, size);
- curr_offset_ += size;
-}
-
-template <typename T>
-void TargetDataLayoutEncoder::Alloc::WriteValue(const T& val) {
- WriteArray(&val, 1);
-}
-
-} // namespace runtime
-} // namespace tvm
-#endif // TVM_RUNTIME_MICRO_TARGET_DATA_LAYOUT_ENCODER_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tcl_socket.cc
- */
-#include "tcl_socket.h"
-
-#include <string>
-
-namespace tvm {
-namespace runtime {
-
-TclSocket::TclSocket() {
- tcp_socket_.Create();
- tcp_socket_.SetKeepAlive(true);
- reply_buf_.reserve(kReplyBufSize);
-}
-
-TclSocket::~TclSocket() { tcp_socket_.Close(); }
-
-void TclSocket::Connect(tvm::support::SockAddr addr) {
- CHECK(tcp_socket_.Connect(addr)) << "failed to connect";
-}
-
-void TclSocket::SendCommand() {
- const char terminate_token = kCommandTerminateToken;
- cmd_builder_ << terminate_token;
- std::string full_cmd = cmd_builder_.str();
-
- CHECK(tcp_socket_.Send(full_cmd.data(), full_cmd.length()) != -1) << "failed to send command";
- cmd_builder_.str(std::string());
-
- reply_builder_.str(std::string());
- char last_read = '\0';
- // Receive from the socket until we reach a command terminator.
- do {
- ssize_t bytes_read;
- // Recieve from the socket until it's drained.
- do {
- // Leave room at the end of `reply_buf` to tack on a null terminator.
- bytes_read = tcp_socket_.Recv(reply_buf_.data(), kReplyBufSize - 1);
- reply_buf_[bytes_read] = '\0';
- reply_builder_ << reply_buf_.data();
- // Update last read character.
- last_read = reply_buf_[bytes_read - 1];
- } while (bytes_read == kReplyBufSize - 1);
- CHECK(bytes_read != -1) << "failed to read command reply";
- } while (last_read != terminate_token);
- last_reply_ = reply_builder_.str();
- CHECK_EQ(last_reply_[last_reply_.length() - 1], terminate_token) << "missing command terminator";
-}
-
-} // namespace runtime
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tcl_socket.h
- * \brief TCP socket wrapper for communicating using Tcl commands
- */
-#ifndef TVM_RUNTIME_MICRO_TCL_SOCKET_H_
-#define TVM_RUNTIME_MICRO_TCL_SOCKET_H_
-
-#include <string>
-#include <vector>
-
-#include "../../support/socket.h"
-
-namespace tvm {
-namespace runtime {
-
-/*!
- * \brief TCP socket wrapper for communicating using Tcl commands
- *
- * Usage generally involves building a command using the `cmd_builder` stream
- * interface, then sending the command with `SendCommand`, and if necessary,
- * reading the reply.
- */
-class TclSocket {
- public:
- /*!
- * \brief constructor to create the socket
- */
- TclSocket();
-
- /*!
- * \brief destructor to close the socket connection
- */
- ~TclSocket();
-
- /*!
- * \brief open connection with server
- * \param addr server address
- */
- void Connect(tvm::support::SockAddr addr);
-
- /*
- * \brief send the built command to the server and await a reply
- *
- * \return the reply
- */
- void SendCommand();
-
- /*
- * \return string stream for current command being built
- */
- std::ostringstream& cmd_builder() { return cmd_builder_; }
-
- /*
- * \return reply from most recently sent command
- */
- const std::string& last_reply() { return last_reply_; }
-
- private:
- /*! \brief underlying TCP socket being wrapped */
- tvm::support::TCPSocket tcp_socket_;
- /*! \brief buffer used to receive messages from the socket */
- std::vector<uint8_t> reply_buf_;
- /*! \brief string stream used to build current command */
- std::ostringstream cmd_builder_;
- /*! \brief string stream used to receive replies from sent commands */
- std::ostringstream reply_builder_;
- /*! \brief reply from most recently sent command */
- std::string last_reply_;
-
- /*! \brief character denoting the end of a Tcl command */
- static const constexpr char kCommandTerminateToken = '\x1a';
- /*! \brief size of the buffer used to receive messages (in bytes) */
- static const constexpr size_t kReplyBufSize = 4096;
-};
-
-} // namespace runtime
-} // namespace tvm
-#endif // TVM_RUNTIME_MICRO_TCL_SOCKET_H_
* \note This file do not depend on c++ std or c std,
* and only depends on TVM's C runtime API.
*/
-#ifndef TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_
-#define TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_
+#ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_
+#define TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_
-#include <dmlc/endian.h>
+#define DMLC_LITTLE_ENDIAN true
+#include <string.h>
#include <tvm/runtime/c_runtime_api.h>
-#include "../../../support/arena.h"
-#include "../rpc_protocol.h"
+#include "../../support/generic_arena.h"
+#include "rpc_reference.h"
/*! \brief Whether or not to enable glog style DLOG */
#ifndef TVM_MINRPC_ENABLE_LOGGING
* \tparam TIOHandler IO provider to provide io handling.
* An IOHandler needs to provide the following functions:
* - PosixWrite, PosixRead, Close: posix style, read, write, close API.
+ * - MessageStart(num_bytes), MessageDone(): framing APIs.
* - Exit: exit with status code.
*/
template <typename TIOHandler>
* \brief Constructor.
* \param io The IO handler.
*/
- explicit MinRPCServer(TIOHandler io) : io_(io), arena_(PageAllocator(io)) {}
+ explicit MinRPCServer(TIOHandler* io) : io_(io), arena_(PageAllocator(io)) {}
- /*! \brief Run the server loop until shutdown signal is received. */
- void ServerLoop() {
+ /*! \brief Process a single request.
+ *
+ * \return true when the server should continue processing requests. false when it should be
+ * shutdown.
+ */
+ bool ProcessOnePacket() {
RPCCode code;
uint64_t packet_len;
- while (true) {
- arena_.RecycleAll();
- allow_clean_shutdown_ = true;
+ arena_.RecycleAll();
+ allow_clean_shutdown_ = true;
- this->Read(&packet_len);
- if (packet_len == 0) continue;
- this->Read(&code);
+ this->Read(&packet_len);
+ if (packet_len == 0) return true;
+ this->Read(&code);
- allow_clean_shutdown_ = false;
+ allow_clean_shutdown_ = false;
- if (code >= RPCCode::kSyscallCodeStart) {
- this->HandleSyscallFunc(code);
- } else {
- switch (code) {
- case RPCCode::kCallFunc: {
- HandleNormalCallFunc();
- break;
- }
- case RPCCode::kInitServer: {
- HandleInitServer();
- break;
- }
- case RPCCode::kCopyFromRemote: {
- HandleCopyFromRemote();
- break;
- }
- case RPCCode::kCopyToRemote: {
- HandleCopyToRemote();
- break;
- }
- case RPCCode::kShutdown: {
- this->Shutdown();
- return;
- }
- default: {
- this->ThrowError(RPCServerStatus::kUnknownRPCCode);
- break;
- }
+ if (code >= RPCCode::kSyscallCodeStart) {
+ this->HandleSyscallFunc(code);
+ } else {
+ switch (code) {
+ case RPCCode::kCallFunc: {
+ HandleNormalCallFunc();
+ break;
+ }
+ case RPCCode::kInitServer: {
+ HandleInitServer();
+ break;
+ }
+ case RPCCode::kCopyFromRemote: {
+ HandleCopyFromRemote();
+ break;
+ }
+ case RPCCode::kCopyToRemote: {
+ HandleCopyToRemote();
+ break;
+ }
+ case RPCCode::kShutdown: {
+ this->Shutdown();
+ return false;
+ }
+ default: {
+ this->ThrowError(RPCServerStatus::kUnknownRPCCode);
+ break;
}
}
}
+
+ return true;
}
void Shutdown() {
arena_.FreeAll();
- io_.Close();
+ io_->Close();
}
void HandleNormalCallFunc() {
ret_value[2].v_handle = ret_value[1].v_handle;
ret_tcode[2] = kTVMOpaqueHandle;
this->ReturnPackedSeq(ret_value, ret_tcode, 3);
+ } else if (rv_tcode == kTVMBytes) {
+ ret_tcode[1] = kTVMBytes;
+ this->ReturnPackedSeq(ret_value, ret_tcode, 2);
} else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) {
ret_tcode[1] = kTVMOpaqueHandle;
this->ReturnPackedSeq(ret_value, ret_tcode, 2);
RPCCode code = RPCCode::kCopyAck;
uint64_t packet_nbytes = sizeof(code) + num_bytes;
+ io_->MessageStart(packet_nbytes);
this->Write(packet_nbytes);
this->Write(code);
this->WriteArray(data_ptr, num_bytes);
+ io_->MessageDone();
} else {
this->ReturnLastTVMError();
}
}
void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
- io_.Exit(static_cast<int>(code));
+ io_->Exit(static_cast<int>(code));
}
template <typename T>
return this->WriteRawBytes(data, sizeof(T) * count);
}
+ void MessageStart(uint64_t packet_nbytes) { io_->MessageStart(packet_nbytes); }
+
+ void MessageDone() { io_->MessageDone(); }
+
private:
// Internal allocator that redirects alloc to TVM's C API.
class PageAllocator {
public:
using ArenaPageHeader = tvm::support::ArenaPageHeader;
- explicit PageAllocator(TIOHandler io) : io_(io) {}
+ explicit PageAllocator(TIOHandler* io) : io_(io) {}
ArenaPageHeader* allocate(size_t min_size) {
size_t npages = ((min_size + kPageSize - 1) / kPageSize);
if (TVMDeviceAllocDataSpace(DLContext{kDLCPU, 0}, npages * kPageSize, kPageAlign,
DLDataType{kDLInt, 1, 1}, &data) != 0) {
- io_.Exit(static_cast<int>(RPCServerStatus::kAllocError));
+ io_->Exit(static_cast<int>(RPCServerStatus::kAllocError));
}
ArenaPageHeader* header = static_cast<ArenaPageHeader*>(data);
void deallocate(ArenaPageHeader* page) {
if (TVMDeviceFreeDataSpace(DLContext{kDLCPU, 0}, page) != 0) {
- io_.Exit(static_cast<int>(RPCServerStatus::kAllocError));
+ io_->Exit(static_cast<int>(RPCServerStatus::kAllocError));
}
}
static const constexpr int kPageAlign = 8;
private:
- TIOHandler io_;
+ TIOHandler* io_;
};
void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) {
uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
+ io_->MessageStart(packet_nbytes);
this->Write(packet_nbytes);
this->Write(code);
this->Write(num_args);
this->Write(tcode);
+ io_->MessageDone();
}
void ReturnHandle(void* handle) {
int32_t tcode = kTVMOpaqueHandle;
RPCCode code = RPCCode::kReturn;
uint64_t encode_handle = reinterpret_cast<uint64_t>(handle);
-
uint64_t packet_nbytes =
sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle);
+ io_->MessageStart(packet_nbytes);
this->Write(packet_nbytes);
this->Write(code);
this->Write(num_args);
this->Write(tcode);
this->Write(encode_handle);
+ io_->MessageDone();
}
void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); }
uint8_t* buf = reinterpret_cast<uint8_t*>(data);
size_t ndone = 0;
while (ndone < size) {
- ssize_t ret = io_.PosixRead(buf, size - ndone);
+ ssize_t ret = io_->PosixRead(buf, size - ndone);
if (ret == 0) {
if (allow_clean_shutdown_) {
this->Shutdown();
- io_.Exit(0);
+ io_->Exit(0);
} else {
this->ThrowError(RPCServerStatus::kReadError);
}
const uint8_t* buf = reinterpret_cast<const uint8_t*>(data);
size_t ndone = 0;
while (ndone < size) {
- ssize_t ret = io_.PosixWrite(buf, size - ndone);
+ ssize_t ret = io_->PosixWrite(buf, size - ndone);
if (ret == 0 || ret == -1) {
this->ThrowError(RPCServerStatus::kWriteError);
}
}
/*! \brief IO handler. */
- TIOHandler io_;
+ TIOHandler* io_;
/*! \brief internal arena. */
support::GenericArena<PageAllocator> arena_;
/*! \brief Whether we are in a state that allows clean shutdown. */
} // namespace runtime
} // namespace tvm
-#endif // TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_
+#endif // TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_
explicit PosixIOHandler(int read_fd = 0, int write_fd = 1)
: read_fd_(read_fd), write_fd_(write_fd) {}
+ void MessageStart(uint64_t packet_nbytes) {}
+
+ void MessageDone() {}
+
ssize_t PosixRead(void* data, size_t size) { return read(read_fd_, data, size); }
ssize_t PosixWrite(const void* data, size_t size) { return write(write_fd_, data, size); }
if (argc != 3) return -1;
// pass the descriptor via arguments.
tvm::runtime::PosixIOHandler handler(atoi(argv[1]), atoi(argv[2]));
- tvm::runtime::PosixMinRPCServer server(handler);
- server.ServerLoop();
+ tvm::runtime::PosixMinRPCServer server(&handler);
+ bool is_running = true;
+ while (is_running) {
+ is_running = server.ProcessOnePacket();
+ }
+
return 0;
}
*/
/*!
- * \file rpc_procotol.h
+ * \file rpc_reference.h
* \brief Common header defining the communication code used in the RPC protocol.
*/
-#ifndef TVM_RUNTIME_RPC_RPC_PROTOCOL_H_
-#define TVM_RUNTIME_RPC_RPC_PROTOCOL_H_
+#ifndef TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_
+#define TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_
namespace tvm {
namespace runtime {
kAllocError
};
+inline const char* RPCCodeToString(RPCCode code) {
+ switch (code) {
+ case RPCCode::kShutdown:
+ return "kShutdown";
+ case RPCCode::kInitServer:
+ return "kInitServer";
+ case RPCCode::kCallFunc:
+ return "kCallFunc";
+ case RPCCode::kReturn:
+ return "kReturn";
+ case RPCCode::kException:
+ return "kException";
+ case RPCCode::kCopyFromRemote:
+ return "kCopyFromRemote";
+ case RPCCode::kCopyToRemote:
+ return "kCopyToRemote";
+ case RPCCode::kCopyAck:
+ return "kCopyAck";
+ // The following are syscall code that can send over CallRemote
+ case RPCCode::kGetGlobalFunc:
+ return "kGetGlobalFunc";
+ case RPCCode::kFreeHandle:
+ return "kFreeHandle";
+ case RPCCode::kDevSetDevice:
+ return "kDevSetDevice";
+ case RPCCode::kDevGetAttr:
+ return "kDevGetAttr";
+ case RPCCode::kDevAllocData:
+ return "kDevAllocData";
+ case RPCCode::kDevFreeData:
+ return "kDevFreeData";
+ case RPCCode::kDevStreamSync:
+ return "kDevStreamSync";
+ case RPCCode::kCopyAmongRemote:
+ return "kCopyAmongRemote";
+ default:
+ return "";
+ }
+}
+
/*!
* \brief Convert RPC server status to string.
* \param status The status.
uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(len) + len;
+ channel->MessageStart(packet_nbytes);
channel->Write(packet_nbytes);
channel->Write(code);
channel->Write(num_args);
channel->Write(tcode);
channel->Write(len);
channel->WriteArray(msg, len);
+ channel->MessageDone();
}
/*!
uint64_t packet_nbytes =
sizeof(code) + PackedSeqGetNumBytes(arg_values, type_codes, num_args, false, channel);
+ channel->MessageStart(packet_nbytes);
channel->Write(packet_nbytes);
channel->Write(code);
SendPackedSeq(arg_values, type_codes, num_args, false, channel);
+ channel->MessageDone();
}
/*!
uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
+ channel->MessageStart(packet_nbytes);
channel->Write(packet_nbytes);
channel->Write(code);
channel->Write(num_args);
channel->Write(tcode);
+ channel->MessageDone();
}
};
} // namespace runtime
} // namespace tvm
-#endif // TVM_RUNTIME_RPC_RPC_PROTOCOL_H_
+
+#endif // TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_
using Stream::Write;
using Stream::WriteArray;
+ void MessageStart(uint64_t packet_nbytes) {
+ // Unused here, implemented for uTVM framing layer.
+ }
+
bool Read(RPCCode* code) {
int32_t cdata;
if (!this->Read(&cdata)) return false;
this->Write(cdata);
}
+ void MessageDone() {
+ // Unused here, implemented for uTVM framing layer.
+ }
+
template <typename T>
T* ArenaAlloc(int count) {
static_assert(std::is_pod<T>::value, "need to be trival");
#include <utility>
#include "../../support/ring_buffer.h"
+#include "../minrpc/rpc_reference.h"
#include "rpc_channel.h"
-#include "rpc_protocol.h"
#include "rpc_session.h"
namespace tvm {
#include <memory>
#include <string>
-#include "rpc_protocol.h"
+#include "../minrpc/rpc_reference.h"
namespace tvm {
namespace runtime {
#ifndef TVM_SUPPORT_ARENA_H_
#define TVM_SUPPORT_ARENA_H_
-#ifndef TVM_ARENA_HAS_DESTRUCTOR
-#define TVM_ARENA_HAS_DESTRUCTOR 1
-#endif
-
#include <cstddef>
#include <type_traits>
#include <utility>
+#include "generic_arena.h"
+
namespace tvm {
namespace support {
/*!
- * \brief An arena page header.
- */
-struct ArenaPageHeader {
- /*! \brief points to the next page. */
- ArenaPageHeader* next;
- /*!
- * \brief Total size of the page.
- */
- size_t size;
- /*! \brief memory allocator offset inside page. */
- size_t offset;
-};
-
-/*!
* \brief Simple page allocator that uses new and delete.
*/
class SimplePageAllocator {
using Page = std::aligned_storage<kPageSize, kPageAlign>::type;
};
-/*!
- * \brief Arena allocator that allocates memory from continuous
- * chunk and frees them all only during destruction.
- */
-template <typename PageAllocator>
-class GenericArena {
- public:
- explicit GenericArena(PageAllocator alloc = PageAllocator()) : alloc_(alloc) {
- // eagerly allocate the first page.
- head_ = tail_ = alloc_.allocate(1);
- head_->next = nullptr;
- }
-
-#if TVM_ARENA_HAS_DESTRUCTOR
- ~GenericArena() { this->FreeAll(); }
-#endif
-
- /*! \brief Free all pages. */
- void FreeAll() {
- FreePageList(&head_);
- FreePageList(&free_list_);
- }
- /*! \brief Recycle all the pages in the arena */
- void RecycleAll() {
- // put all the current list to the free list.
- tail_->next = free_list_;
- // allocate the first in the free list to head
- free_list_ = head_->next;
- head_->next = nullptr;
- // Reset the head.
- head_->offset = sizeof(ArenaPageHeader);
- tail_ = head_;
- }
- /*!
- * \brief Allocate a space from Arena for type T
- * \param T the data type to be allocated
- * \param count Numberof elements
- * \note The space of T is not initialized.
- */
- template <typename T>
- T* allocate_(int count = 1) {
- static_assert(PageAllocator::kPageAlign % alignof(T) == 0, "To large alignment");
- return static_cast<T*>(Alloc(sizeof(T) * count, alignof(T)));
- }
- /*!
- * \brief Create a new instance of type T.
- * \param args The constructor argument.
- * \tparam T the type to be created.
- * \tparam Args Arguments to the constructor.
- *
- * \return The allocated object.
- * \note The type T must be simple type, or only contain
- * memory allocated from the same arena.
- * Otherwise the destructor needs to be called explicitly.
- */
- template <typename T, typename... Args>
- T* make(Args&&... args) {
- T* ptr = allocate_<T>();
- new (ptr) T(std::forward<Args>(args)...);
- return ptr;
- }
-
- private:
- /*! \brief internal page allocator. */
- PageAllocator alloc_;
- /* \brief The the head of the allocated list. */
- ArenaPageHeader* head_{nullptr};
- /*! \brief The tail of the allocated list. */
- ArenaPageHeader* tail_{nullptr};
- /* \brief List of free pages. */
- ArenaPageHeader* free_list_{nullptr};
- /*!
- * \brief Align ptr by upper bound.
- * \param offset The offset value.
- * \param align The alignment requirement.
- */
- size_t UpperAlign(size_t offset, size_t align) {
- return offset + (align - (offset % align)) % align;
- }
- /*!
- * \brief Internal aligned alloc function.
- * \param size The size of the memory.
- * \param align The alignment requirement.
- */
- void* Alloc(size_t size, size_t align) {
- size_t offset = UpperAlign(head_->offset, align);
- if (offset + size <= head_->size) {
- head_->offset = offset + size;
- return reinterpret_cast<char*>(head_) + offset;
- } else {
- ArenaPageHeader* new_head;
- offset = UpperAlign(sizeof(ArenaPageHeader), align);
- if (free_list_ != nullptr && offset + size <= free_list_->size) {
- new_head = free_list_;
- free_list_ = free_list_->next;
- } else {
- new_head = alloc_.allocate(offset + size);
- }
- new_head->next = head_;
- new_head->offset = offset + size;
- head_ = new_head;
- return reinterpret_cast<char*>(head_) + offset;
- }
- }
- /*!
- * \brief Free all the pages in the list.
- * \param ptr The head ptr.
- */
- void FreePageList(ArenaPageHeader** ptr) {
- // delete all the allocated pages.
- while (ptr[0] != nullptr) {
- ArenaPageHeader* temp = ptr[0];
- ptr[0] = ptr[0]->next;
- alloc_.deallocate(temp);
- }
- }
-};
-
using Arena = GenericArena<SimplePageAllocator>;
/*!
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file arena.h
+ * \brief Arena allocator that allocates memory chunks and frees them all during destruction time.
+ *
+ * NOTE: This file is portable to bare-metal embedded devices. Don't use operator new (without
+ * placement parameters) or malloc.
+ */
+#ifndef TVM_SUPPORT_GENERIC_ARENA_H_
+#define TVM_SUPPORT_GENERIC_ARENA_H_
+
+#ifndef TVM_ARENA_HAS_DESTRUCTOR
+#define TVM_ARENA_HAS_DESTRUCTOR 1
+#endif
+
+#include <stddef.h>
+
+#include <utility>
+
+namespace tvm {
+namespace support {
+
+namespace {
+template <typename T> // For lvalues (T is T&),
+T&& forward(T&& param) { // take/return lvalue refs.
+ return static_cast<T&&>(param); // For rvalues (T is T),
+} // take/return rvalue refs.
+} // namespace
+
+/*!
+ * \brief An arena page header.
+ */
+struct ArenaPageHeader {
+ /*! \brief points to the next page. */
+ ArenaPageHeader* next;
+ /*!
+ * \brief Total size of the page.
+ */
+ size_t size;
+ /*! \brief memory allocator offset inside page. */
+ size_t offset;
+};
+
+/*!
+ * \brief Arena allocator that allocates memory from continuous
+ * chunk and frees them all only during destruction.
+ */
+template <typename PageAllocator>
+class GenericArena {
+ public:
+ explicit GenericArena(PageAllocator alloc = PageAllocator()) : alloc_(alloc) {
+ // eagerly allocate the first page.
+ head_ = tail_ = alloc_.allocate(1);
+ head_->next = nullptr;
+ }
+
+#if TVM_ARENA_HAS_DESTRUCTOR
+ ~GenericArena() { this->FreeAll(); }
+#endif
+
+ /*! \brief Free all pages. */
+ void FreeAll() {
+ FreePageList(&head_);
+ FreePageList(&free_list_);
+ }
+ /*! \brief Recycle all the pages in the arena */
+ void RecycleAll() {
+ // put all the current list to the free list.
+ tail_->next = free_list_;
+ // allocate the first in the free list to head
+ free_list_ = head_->next;
+ head_->next = nullptr;
+ // Reset the head.
+ head_->offset = sizeof(ArenaPageHeader);
+ tail_ = head_;
+ }
+ /*!
+ * \brief Allocate a space from Arena for type T
+ * \param T the data type to be allocated
+ * \param count Numberof elements
+ * \note The space of T is not initialized.
+ */
+ template <typename T>
+ T* allocate_(int count = 1) {
+ static_assert(PageAllocator::kPageAlign % alignof(T) == 0, "To large alignment");
+ return static_cast<T*>(Alloc(sizeof(T) * count, alignof(T)));
+ }
+ /*!
+ * \brief Create a new instance of type T.
+ * \param args The constructor argument.
+ * \tparam T the type to be created.
+ * \tparam Args Arguments to the constructor.
+ *
+ * \return The allocated object.
+ * \note The type T must be simple type, or only contain
+ * memory allocated from the same arena.
+ * Otherwise the destructor needs to be called explicitly.
+ */
+ template <typename T, typename... Args>
+ T* make(Args&&... args) {
+ T* ptr = allocate_<T>();
+ new (ptr) T(forward<Args>(args)...);
+ return ptr;
+ }
+
+ private:
+ /*! \brief internal page allocator. */
+ PageAllocator alloc_;
+ /* \brief The the head of the allocated list. */
+ ArenaPageHeader* head_{nullptr};
+ /*! \brief The tail of the allocated list. */
+ ArenaPageHeader* tail_{nullptr};
+ /* \brief List of free pages. */
+ ArenaPageHeader* free_list_{nullptr};
+ /*!
+ * \brief Align ptr by upper bound.
+ * \param offset The offset value.
+ * \param align The alignment requirement.
+ */
+ size_t UpperAlign(size_t offset, size_t align) {
+ return offset + (align - (offset % align)) % align;
+ }
+ /*!
+ * \brief Internal aligned alloc function.
+ * \param size The size of the memory.
+ * \param align The alignment requirement.
+ */
+ void* Alloc(size_t size, size_t align) {
+ size_t offset = UpperAlign(head_->offset, align);
+ if (offset + size <= head_->size) {
+ head_->offset = offset + size;
+ return reinterpret_cast<char*>(head_) + offset;
+ } else {
+ ArenaPageHeader* new_head;
+ offset = UpperAlign(sizeof(ArenaPageHeader), align);
+ if (free_list_ != nullptr && offset + size <= free_list_->size) {
+ new_head = free_list_;
+ free_list_ = free_list_->next;
+ } else {
+ new_head = alloc_.allocate(offset + size);
+ }
+ new_head->next = head_;
+ new_head->offset = offset + size;
+ head_ = new_head;
+ return reinterpret_cast<char*>(head_) + offset;
+ }
+ }
+ /*!
+ * \brief Free all the pages in the list.
+ * \param ptr The head ptr.
+ */
+ void FreePageList(ArenaPageHeader** ptr) {
+ // delete all the allocated pages.
+ while (ptr[0] != nullptr) {
+ ArenaPageHeader* temp = ptr[0];
+ ptr[0] = ptr[0]->next;
+ alloc_.deallocate(temp);
+ }
+ }
+};
+
+} // namespace support
+} // namespace tvm
+#endif // TVM_SUPPORT_GENERIC_ARENA_H_
TVM_REGISTER_TARGET_KIND("c", kDLCPU)
.add_attr_option<Bool>("system-lib")
.add_attr_option<String>("runtime")
+ .add_attr_option<String>("mcpu")
.set_default_keys({"cpu"});
TVM_REGISTER_TARGET_KIND("cuda", kDLGPU)
#include <tvm/topi/generic/injective.h>
TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
- *rv = topi::generic::schedule_injective(args[0], args[1]);
+ *rv = ::tvm::topi::generic::schedule_injective(args[0], args[1]);
});
TEST(MicroStandaloneRuntime, BuildModule) {
using namespace tvm;
- auto tensor_type = relay::TensorType({2, 3}, ::tvm::Float(32));
+ auto tensor_type = relay::TensorType({2, 3}, ::tvm::runtime::DataType::Float(32));
auto a = relay::Var("a", tensor_type);
auto b = relay::Var("b", tensor_type);
auto add_op = relay::Op::Get("add");
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TESTS_CRT_BUFFER_WRITE_STREAM_H_
+#define TESTS_CRT_BUFFER_WRITE_STREAM_H_
+
+#include <inttypes.h>
+#include <tvm/runtime/crt/rpc_common/frame_buffer.h>
+#include <tvm/runtime/crt/rpc_common/write_stream.h>
+
+using ::tvm::runtime::micro_rpc::FrameBuffer;
+using ::tvm::runtime::micro_rpc::WriteStream;
+
+template <unsigned int N>
+class BufferWriteStream : public WriteStream {
+ public:
+ ssize_t Write(const uint8_t* data, size_t data_size_bytes) override {
+ return buffer_.Write(data, data_size_bytes);
+ }
+
+ void Reset() {
+ buffer_.Clear();
+ packet_done_ = false;
+ }
+
+ inline bool packet_done() { return packet_done_; }
+
+ inline bool is_valid() { return is_valid_; }
+
+ void PacketDone(bool is_valid) override {
+ EXPECT_FALSE(packet_done_);
+ packet_done_ = true;
+ is_valid_ = is_valid;
+ }
+
+ std::string BufferContents() { return std::string((const char*)buffer_data_, buffer_.Size()); }
+
+ static constexpr unsigned int capacity() { return N; };
+
+ private:
+ bool packet_done_{false};
+ bool is_valid_{false};
+ uint8_t buffer_data_[N];
+ FrameBuffer buffer_{buffer_data_, N};
+};
+
+#endif // TESTS_CRT_BUFFER_WRITE_STREAM_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <tvm/runtime/crt/memory.h>
+#include <tvm/runtime/crt/rpc_common/frame_buffer.h>
+#include <tvm/runtime/crt/rpc_common/framing.h>
+
+#include <string>
+#include <vector>
+
+#include "buffer_write_stream.h"
+#include "crt_config.h"
+#include "platform.cc"
+
+using ::tvm::runtime::micro_rpc::Escape;
+using ::tvm::runtime::micro_rpc::FrameBuffer;
+using ::tvm::runtime::micro_rpc::Framer;
+using ::tvm::runtime::micro_rpc::Unframer;
+
+class FramerTest : public ::testing::Test {
+ protected:
+ BufferWriteStream<300> write_stream_;
+ Framer framer_{&write_stream_};
+};
+
+class TestPacket {
+ public:
+ static std::vector<const TestPacket*> instances;
+
+ // NOTE: take payload and wire as arrays to avoid clipping at \0
+ template <int N, int M>
+ TestPacket(const std::string name, const char (&payload)[N], const char (&wire)[M])
+ : name{name}, payload{payload, N - 1}, wire{wire, M - 1} { // omit trailing \0
+ instances.emplace_back(this);
+ }
+
+ inline const uint8_t* payload_data() const {
+ return reinterpret_cast<const uint8_t*>(payload.data());
+ }
+
+ inline const uint8_t* wire_data() const { return reinterpret_cast<const uint8_t*>(wire.data()); }
+
+ std::string name;
+ std::string payload;
+ std::string wire;
+};
+
+void PrintTo(const TestPacket* p, std::ostream* os) {
+ *os << "TestPacket(\"" << p->name << "\", ...)";
+}
+
+void PrintTo(tvm_crt_error_t p, std::ostream* os) {
+ std::ios_base::fmtflags f(os->flags());
+ *os << "tvm_crt_error_t(0x" << std::hex << std::setw(8) << std::setfill('0') << p << ")";
+ os->flags(f);
+}
+
+std::vector<const TestPacket*> TestPacket::instances;
+
+#define TEST_PACKET(name, payload, wire) \
+ static const TestPacket k##name { #name, payload, wire }
+
+// NOTE: golden packet CRCs are generated with this python:
+// import binascii
+// import struct
+// struct.pack('<H', binascii.crc_hqx('\xff\xfd\x05\0\0\0three', 0xffff))
+
+TEST_PACKET(Packet1, "one", "\xff\xfd\3\0\0\0one\x58\xf4");
+TEST_PACKET(Packet2, "two2", "\xff\xfd\4\0\0\0two2\x13\x11");
+TEST_PACKET(Packet3, "three", "\xff\xfd\5\0\0\0threec\x9f");
+TEST_PACKET(EscapeCodeInSizePacket,
+ "this payload is exactly 255 characters long. chunk is 64 bytes. "
+ "this payload is exactly 255 characters long. chunk is 64 bytes. "
+ "this payload is exactly 255 characters long. chunk is 64 bytes. "
+ "this payload is exactly 255 characters long. chunk is 64 bytes.",
+ "\xff\xfd\xff\xff\0\0\0"
+ "this payload is exactly 255 characters long. chunk is 64 bytes. "
+ "this payload is exactly 255 characters long. chunk is 64 bytes. "
+ "this payload is exactly 255 characters long. chunk is 64 bytes. "
+ "this payload is exactly 255 characters long. chunk is 64 bytes."
+ "6~");
+TEST_PACKET(ZeroLengthPacket, "", "\xff\xfd\0\0\0\0\203D");
+
+// Generated with:
+// import binascii
+// import random
+// import string
+// import struct
+// escaped_prefix = b'es_\xff\xff_cape'
+// crc = b''
+// while b'\xff' not in crc:
+// suffix = bytes(''.join(random.choices(string.printable, k=10)), 'utf-8')
+// packet = b'\xff\xfd' + struct.pack('<I', len(escaped_prefix + suffix)) + escaped_prefix +
+// suffix crc = struct.pack('<H', binascii.crc_hqx(packet, 0xffff))
+// print(suffix)
+// print(packet + crc.replace(b'\xff', b'\xff\xff'))
+TEST_PACKET(EscapePacket, "es_\xff_capeir/^>t@\"hr",
+ "\xff\xfd\x13\0\0\0es_\xff\xff_capeir/^>t@\"hr\xb4\xff\xff");
+
+TEST_F(FramerTest, ValidPacketTrain) {
+ EXPECT_EQ(kTvmErrorNoError, framer_.Write(kPacket1.payload_data(), kPacket1.payload.size()));
+ EXPECT_EQ(kTvmErrorNoError, framer_.Write(kPacket2.payload_data(), kPacket2.payload.size()));
+ framer_.Reset();
+ EXPECT_EQ(kTvmErrorNoError, framer_.Write(kPacket3.payload_data(), kPacket3.payload.size()));
+
+ EXPECT_EQ("\xfe" + kPacket1.wire + // packet1 plus nop prefix.
+ kPacket2.wire + // packet2, no prefix.
+ "\xfe" + kPacket3.wire, // packet3 plus nop prefix.
+ write_stream_.BufferContents());
+}
+
+TEST_F(FramerTest, ZeroLengthPacket) {
+ EXPECT_EQ(kTvmErrorNoError,
+ framer_.Write(kZeroLengthPacket.payload_data(), kZeroLengthPacket.payload.size()));
+ EXPECT_EQ("\xfe" + kZeroLengthPacket.wire, write_stream_.BufferContents());
+}
+
+TEST_F(FramerTest, Escapes) {
+ EXPECT_EQ(kTvmErrorNoError,
+ framer_.Write(kEscapePacket.payload_data(), kEscapePacket.payload.size()));
+ EXPECT_EQ("\xfe" + kEscapePacket.wire, write_stream_.BufferContents());
+}
+
+class UnframerTest : public ::testing::Test {
+ protected:
+ BufferWriteStream<300> write_stream_;
+ Unframer unframer_{&write_stream_};
+};
+
+TEST_F(UnframerTest, PacketTooLong) {
+ const uint8_t escape[2] = {uint8_t(Escape::kEscapeStart), uint8_t(Escape::kPacketStart)};
+ uint16_t crc = crc16_compute(escape, sizeof(escape), nullptr);
+ size_t bytes_consumed;
+ EXPECT_EQ(kTvmErrorNoError, unframer_.Write(escape, sizeof(escape), &bytes_consumed));
+ EXPECT_EQ(sizeof(escape), bytes_consumed);
+
+ uint32_t packet_length = write_stream_.capacity() + 1;
+ uint8_t* packet_length_bytes = reinterpret_cast<uint8_t*>(&packet_length);
+ for (size_t i = 0; i < sizeof(packet_length); i++) {
+ ASSERT_NE('\xff', packet_length_bytes[i]);
+ }
+ crc = crc16_compute(packet_length_bytes, sizeof(packet_length), &crc);
+ EXPECT_EQ(kTvmErrorNoError,
+ unframer_.Write(packet_length_bytes, sizeof(packet_length), &bytes_consumed));
+ EXPECT_EQ(sizeof(packet_length), bytes_consumed);
+
+ uint8_t long_payload[decltype(write_stream_)::capacity() + 1];
+ for (size_t i = 0; i < sizeof(long_payload); i++) {
+ long_payload[i] = i & 0xff;
+ if (long_payload[i] == uint8_t(Escape::kEscapeStart)) {
+ long_payload[i] = 0;
+ }
+ }
+ crc = crc16_compute(long_payload, sizeof(long_payload), &crc);
+ EXPECT_EQ(kTvmErrorWriteStreamShortWrite,
+ unframer_.Write(long_payload, sizeof(long_payload), &bytes_consumed));
+ EXPECT_EQ(write_stream_.capacity(), bytes_consumed);
+
+ EXPECT_EQ(kTvmErrorNoError, unframer_.Write((uint8_t*)&crc, sizeof(crc), &bytes_consumed));
+ EXPECT_EQ(2, bytes_consumed); // 2, because framer is now in kFindPacketStart.
+ EXPECT_FALSE(write_stream_.packet_done());
+ EXPECT_FALSE(write_stream_.is_valid());
+ EXPECT_EQ(std::string((char*)long_payload, write_stream_.capacity()),
+ write_stream_.BufferContents());
+
+ // Writing a smaller packet directly afterward should work.
+ write_stream_.Reset();
+ EXPECT_EQ(kTvmErrorNoError,
+ unframer_.Write(kPacket1.wire_data(), kPacket1.wire.size(), &bytes_consumed));
+ EXPECT_EQ(kPacket1.wire.size(), bytes_consumed);
+ EXPECT_TRUE(write_stream_.packet_done());
+ EXPECT_TRUE(write_stream_.is_valid());
+ EXPECT_EQ(kPacket1.payload, write_stream_.BufferContents());
+};
+
+class UnframerTestParameterized : public UnframerTest,
+ public ::testing::WithParamInterface<const TestPacket*> {};
+
+TEST_P(UnframerTestParameterized, TestFullPacket) {
+ size_t bytes_consumed;
+ EXPECT_EQ(kTvmErrorNoError,
+ unframer_.Write(GetParam()->wire_data(), GetParam()->wire.size(), &bytes_consumed));
+ EXPECT_EQ(GetParam()->wire.size(), bytes_consumed);
+ EXPECT_TRUE(write_stream_.packet_done());
+ EXPECT_TRUE(write_stream_.is_valid());
+ EXPECT_EQ(GetParam()->payload, write_stream_.BufferContents());
+}
+
+TEST_P(UnframerTestParameterized, TestByteAtATime) {
+ size_t bytes_consumed;
+ size_t wire_size = GetParam()->wire.size();
+ for (size_t i = 0; i < wire_size; i++) {
+ EXPECT_EQ(kTvmErrorNoError,
+ unframer_.Write(reinterpret_cast<const uint8_t*>(&GetParam()->wire[i]), 1,
+ &bytes_consumed));
+ EXPECT_EQ(1, bytes_consumed);
+ EXPECT_EQ(i == wire_size - 1, write_stream_.packet_done());
+ }
+ EXPECT_TRUE(write_stream_.is_valid());
+ EXPECT_EQ(GetParam()->payload, write_stream_.BufferContents());
+}
+
+TEST_P(UnframerTestParameterized, TestArbitraryBoundary) {
+ size_t bytes_consumed;
+ size_t wire_size = GetParam()->wire.size();
+ for (size_t i = 1; i < wire_size; i++) {
+ unframer_.Reset();
+ write_stream_.Reset();
+ EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), i, &bytes_consumed));
+ EXPECT_EQ(i, bytes_consumed);
+ EXPECT_FALSE(write_stream_.packet_done());
+ EXPECT_EQ(kTvmErrorNoError,
+ unframer_.Write(&GetParam()->wire_data()[i], wire_size - i, &bytes_consumed));
+ EXPECT_EQ(wire_size - i, bytes_consumed);
+ EXPECT_TRUE(write_stream_.packet_done());
+ EXPECT_TRUE(write_stream_.is_valid());
+ EXPECT_EQ(GetParam()->payload, write_stream_.BufferContents());
+ }
+}
+
+TEST_P(UnframerTestParameterized, TestArbitraryPacketReset) {
+ size_t bytes_consumed;
+ size_t wire_size = GetParam()->wire.size();
+
+ // This test interrupts packet transmission at an arbitrary point in the packet and restarts from
+ // the beginning. It simulates handling a device reset in the protocol. The behavior of the framer
+ // depends on how much of the packet had been transmitted, so the test is split into parts:
+
+ // Part 1. Restarting during the initial escape sequence.
+ unframer_.Reset();
+ write_stream_.Reset();
+ EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), 1, &bytes_consumed));
+ EXPECT_EQ(1, bytes_consumed);
+ EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), wire_size, &bytes_consumed));
+ EXPECT_EQ(wire_size, bytes_consumed);
+ EXPECT_TRUE(write_stream_.packet_done());
+ EXPECT_TRUE(write_stream_.is_valid());
+ EXPECT_EQ(GetParam()->payload, write_stream_.BufferContents());
+
+ // Part 2. Restarting after the initial escape sequence.
+ for (size_t i = 2; i < wire_size; i++) {
+ unframer_.Reset();
+ write_stream_.Reset();
+ EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), i, &bytes_consumed));
+ EXPECT_EQ(i, bytes_consumed);
+
+ // First test byte-by-byte interruption.
+ // Interrupt the packet transmission. The first byte will return no error as it is the escape
+ // byte.
+ EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), 1, &bytes_consumed));
+ EXPECT_EQ(1, bytes_consumed);
+ EXPECT_FALSE(write_stream_.packet_done());
+
+ // Secondt byte will return a short packet error.
+ EXPECT_EQ(kTvmErrorFramingShortPacket,
+ unframer_.Write(&GetParam()->wire_data()[1], 1, &bytes_consumed));
+ EXPECT_EQ(0, bytes_consumed);
+ EXPECT_FALSE(write_stream_.packet_done());
+
+ EXPECT_EQ(kTvmErrorNoError,
+ unframer_.Write(&GetParam()->wire_data()[1], wire_size - 1, &bytes_consumed));
+ EXPECT_EQ(wire_size - 1, bytes_consumed);
+ EXPECT_TRUE(write_stream_.packet_done());
+ EXPECT_TRUE(write_stream_.is_valid());
+ EXPECT_EQ(GetParam()->payload, write_stream_.BufferContents());
+
+ // Next, test interruption just by sending the whole payload at once.
+ unframer_.Reset();
+ write_stream_.Reset();
+ EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), i, &bytes_consumed));
+ EXPECT_EQ(i, bytes_consumed);
+
+ // Interrupt the packet transmission. The first Write() call will just consume 1 byte to reset
+ // the internal state.
+ EXPECT_EQ(kTvmErrorFramingShortPacket,
+ unframer_.Write(GetParam()->wire_data(), wire_size, &bytes_consumed));
+ EXPECT_EQ(1, bytes_consumed);
+ EXPECT_FALSE(write_stream_.packet_done());
+ EXPECT_EQ(kTvmErrorNoError,
+ unframer_.Write(&GetParam()->wire_data()[1], wire_size - 1, &bytes_consumed));
+ EXPECT_EQ(wire_size - 1, bytes_consumed);
+ EXPECT_TRUE(write_stream_.packet_done());
+ EXPECT_TRUE(write_stream_.is_valid());
+ EXPECT_EQ(GetParam()->payload, write_stream_.BufferContents());
+
+ break;
+ }
+}
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+INSTANTIATE_TEST_CASE_P(UnframerTests, UnframerTestParameterized,
+ ::testing::ValuesIn(TestPacket::instances));
+#pragma GCC diagnostic pop
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ testing::FLAGS_gtest_death_test_style = "threadsafe";
+ return RUN_ALL_TESTS();
+}
#include <tvm/runtime/crt/func_registry.h>
#include <tvm/runtime/crt/internal/common/func_registry.h>
+#include "platform.cc"
+
typedef struct {
const char* a;
const char* b;
#include <tvm/runtime/crt/memory.h>
#include "crt_config.h"
+#include "platform.cc"
#define ROUND_UP(qty, modulo) (((qty) + ((modulo)-1)) / (modulo) * (modulo))
EXPECT_EQ(vleak_size, 0);
}
-extern "C" {
-void TVMPlatformAbort(int error_code) { FAIL() << "TVMPlatformAbort(" << error_code << ")"; }
-}
-
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
* under the License.
*/
-/*!
- * \file utvm_init.c
- * \brief uTVM init definition for the host emulated device
- */
+#include <dmlc/logging.h>
+#include <gtest/gtest.h>
+#include <stdarg.h>
+#include <tvm/runtime/crt/platform.h>
-#ifdef __cplusplus
extern "C" {
-#endif
-
-#include "utvm_runtime.h"
-
-void UTVMInit() {
- // no init required for the host
- UTVMMain();
+void InternalTVMPlatformAbort(tvm_crt_error_t error_code) {
+ FAIL() << "TVMPlatformAbort(" << error_code << ")";
}
+void TVMPlatformAbort(tvm_crt_error_t error_code) {
+ InternalTVMPlatformAbort(error_code);
+ exit(2); // for __attribute__((noreturn))
+}
+void* TVMSystemLibEntryPoint() { return NULL; }
+void TVMLogf(const char* fmt, ...) {
+ va_list args;
+ char log_buf[1024];
+ va_start(args, fmt);
+ int ret = vsnprintf(log_buf, sizeof(log_buf), fmt, args);
+ va_end(args);
-#ifdef __cplusplus
-} // TVM_EXTERN_C
-#endif
+ if (ret < 0) {
+ LOG(ERROR) << "TVMLogf: error formatting: " << fmt;
+ } else {
+ LOG(INFO) << "TVMLogf: " << std::string(log_buf, ret);
+ }
+}
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <tvm/runtime/crt/memory.h>
+#include <tvm/runtime/crt/rpc_common/frame_buffer.h>
+#include <tvm/runtime/crt/rpc_common/session.h>
+
+#include <string>
+#include <vector>
+
+#include "buffer_write_stream.h"
+#include "crt_config.h"
+#include "platform.cc"
+
+using ::tvm::runtime::micro_rpc::Framer;
+using ::tvm::runtime::micro_rpc::MessageType;
+using ::tvm::runtime::micro_rpc::Session;
+using ::tvm::runtime::micro_rpc::Unframer;
+
+extern "C" {
+void TestSessionMessageReceivedThunk(void* context, MessageType message_type, FrameBuffer* buf);
+}
+
+class ReceivedMessage {
+ public:
+ ReceivedMessage(MessageType type, std::string message) : type{type}, message{message} {}
+
+ bool operator==(const ReceivedMessage& other) const {
+ return other.type == type && other.message == message;
+ }
+
+ MessageType type;
+ std::string message;
+};
+
+class TestSession {
+ public:
+ TestSession(uint8_t initial_nonce)
+ : framer{&framer_write_stream},
+ receive_buffer{receive_buffer_array, sizeof(receive_buffer_array)},
+ sess{initial_nonce, &framer, &receive_buffer, TestSessionMessageReceivedThunk, this},
+ unframer{sess.Receiver()} {}
+
+ void WriteTo(TestSession* other) {
+ auto framer_buffer = framer_write_stream.BufferContents();
+ size_t bytes_to_write = framer_buffer.size();
+ const uint8_t* write_cursor = reinterpret_cast<const uint8_t*>(framer_buffer.data());
+ while (bytes_to_write > 0) {
+ size_t bytes_consumed;
+ auto to_return = other->unframer.Write(write_cursor, bytes_to_write, &bytes_consumed);
+ EXPECT_EQ(to_return, kTvmErrorNoError);
+ bytes_to_write -= bytes_consumed;
+ write_cursor += bytes_consumed;
+ }
+ }
+
+ void ClearBuffers() {
+ framer_write_stream.Reset();
+ messages_received.clear();
+ sess.ClearReceiveBuffer();
+ }
+
+ std::vector<ReceivedMessage> messages_received;
+ BufferWriteStream<300> framer_write_stream;
+ Framer framer;
+ uint8_t receive_buffer_array[300];
+ FrameBuffer receive_buffer;
+ Session sess;
+ Unframer unframer;
+};
+
+#define EXPECT_FRAMED_PACKET(session, expected) \
+ EXPECT_EQ(std::string(expected, sizeof(expected) - 1), \
+ (session).framer_write_stream.BufferContents());
+
+extern "C" {
+void TestSessionMessageReceivedThunk(void* context, MessageType message_type, FrameBuffer* buf) {
+ std::string message;
+ if (message_type != MessageType::kStartSessionReply) {
+ uint8_t message_buf[300];
+ EXPECT_LE(buf->ReadAvailable(), sizeof(message_buf));
+ size_t message_size_bytes = buf->Read(message_buf, sizeof(message_buf));
+ message = std::string(reinterpret_cast<char*>(message_buf), message_size_bytes);
+ }
+
+ static_cast<TestSession*>(context)->messages_received.emplace_back(
+ ReceivedMessage(message_type, message));
+}
+}
+
+void PrintTo(tvm_crt_error_t p, std::ostream* os) {
+ std::ios_base::fmtflags f(os->flags());
+ *os << "tvm_crt_error_t(0x" << std::hex << std::setw(8) << std::setfill('0') << p << ")";
+ os->flags(f);
+}
+
+void PrintTo(ReceivedMessage msg, std::ostream* os) {
+ *os << "ReceivedMessage(" << int(msg.type) << ", \"" << msg.message << "\")";
+}
+
+class SessionTest : public ::testing::Test {
+ public:
+ static constexpr const uint8_t kAliceNonce = 0x3c;
+ static constexpr const uint8_t kBobNonce = 0xab;
+
+ TestSession alice_{kAliceNonce};
+ TestSession bob_{kBobNonce};
+};
+
+TEST_F(SessionTest, NormalExchange) {
+ tvm_crt_error_t err;
+ err = alice_.sess.Initialize();
+ EXPECT_EQ(kTvmErrorNoError, err);
+ EXPECT_FRAMED_PACKET(alice_,
+ "\xfe\xff\xfd\x03\0\0\0\0\0\x02"
+ "fw");
+ alice_.WriteTo(&bob_);
+
+ err = bob_.sess.Initialize();
+ EXPECT_EQ(kTvmErrorNoError, err);
+ EXPECT_FRAMED_PACKET(bob_,
+ "\xfe\xff\xfd\x03\0\0\0\0\0\x02"
+ "fw");
+ alice_.WriteTo(&alice_);
+
+ bob_.ClearBuffers();
+ alice_.ClearBuffers();
+
+ err = alice_.sess.StartSession();
+ EXPECT_EQ(err, kTvmErrorNoError);
+ EXPECT_FRAMED_PACKET(alice_, "\xff\xfd\x04\0\0\0\x82\0\0\x01{\xE9");
+
+ bob_.ClearBuffers();
+ alice_.WriteTo(&bob_);
+ EXPECT_FRAMED_PACKET(bob_,
+ "\xff\xfd\x4\0\0\0\x82"
+ "f\x01\x01\x81\xf3");
+ EXPECT_TRUE(bob_.sess.IsEstablished());
+
+ bob_.WriteTo(&alice_);
+ EXPECT_TRUE(alice_.sess.IsEstablished());
+ ASSERT_EQ(alice_.messages_received.size(), 1);
+ EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kStartSessionReply, ""));
+
+ alice_.ClearBuffers();
+ alice_.sess.SendMessage(MessageType::kNormal, reinterpret_cast<const uint8_t*>("hello"), 5);
+ EXPECT_FRAMED_PACKET(alice_,
+ "\xFF\xFD\b\0\0\0\x82"
+ "f\x10hello\x90(");
+ alice_.WriteTo(&bob_);
+ ASSERT_EQ(bob_.messages_received.size(), 2);
+ EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kStartSessionReply, ""));
+ EXPECT_EQ(bob_.messages_received[1], ReceivedMessage(MessageType::kNormal, "hello"));
+
+ bob_.ClearBuffers();
+ bob_.sess.SendMessage(MessageType::kNormal, reinterpret_cast<const uint8_t*>("olleh"), 5);
+ EXPECT_FRAMED_PACKET(bob_,
+ "\xff\xfd\b\0\0\0\x82"
+ "f\x10ollehLv");
+ bob_.WriteTo(&alice_);
+ ASSERT_EQ(alice_.messages_received.size(), 1);
+ EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kNormal, "olleh"));
+
+ alice_.ClearBuffers();
+ bob_.ClearBuffers();
+
+ alice_.sess.SendMessage(MessageType::kLog, reinterpret_cast<const uint8_t*>("log1"), 4);
+ EXPECT_FRAMED_PACKET(alice_, "\xff\xfd\a\0\0\0\0\0\x03log1\xf0\xd4");
+ alice_.WriteTo(&bob_);
+ ASSERT_EQ(bob_.messages_received.size(), 1);
+ EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kLog, "log1"));
+
+ bob_.sess.SendMessage(MessageType::kLog, reinterpret_cast<const uint8_t*>("zero"), 4);
+ EXPECT_FRAMED_PACKET(bob_, "\xff\xfd\a\0\0\0\0\0\x03zero\xb2h");
+ bob_.WriteTo(&alice_);
+ ASSERT_EQ(alice_.messages_received.size(), 1);
+ EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kLog, "zero"));
+}
+
+TEST_F(SessionTest, LogBeforeSessionStart) {
+ alice_.sess.SendMessage(MessageType::kLog, reinterpret_cast<const uint8_t*>("log1"), 4);
+ EXPECT_FRAMED_PACKET(alice_, "\xfe\xff\xfd\a\0\0\0\0\0\x03log1\xf0\xd4");
+ alice_.WriteTo(&bob_);
+ ASSERT_EQ(bob_.messages_received.size(), 1);
+ EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kLog, "log1"));
+
+ bob_.sess.SendMessage(MessageType::kLog, reinterpret_cast<const uint8_t*>("zero"), 4);
+ EXPECT_FRAMED_PACKET(bob_, "\xfe\xff\xfd\a\0\0\0\0\0\x03zero\xb2h");
+ bob_.WriteTo(&alice_);
+ ASSERT_EQ(alice_.messages_received.size(), 1);
+ EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kLog, "zero"));
+}
+
+static constexpr const char kBobStartPacket[] = "\xff\xfd\x04\0\0\0f\0\0\x01`\xa7";
+
+TEST_F(SessionTest, DoubleStart) {
+ tvm_crt_error_t err;
+ err = alice_.sess.Initialize();
+ EXPECT_EQ(kTvmErrorNoError, err);
+ EXPECT_FRAMED_PACKET(alice_,
+ "\xfe\xff\xfd\x03\0\0\0\0\0\x02"
+ "fw");
+ alice_.WriteTo(&bob_);
+
+ err = bob_.sess.Initialize();
+ EXPECT_EQ(kTvmErrorNoError, err);
+ EXPECT_FRAMED_PACKET(bob_,
+ "\xfe\xff\xfd\x03\0\0\0\0\0\x02"
+ "fw");
+ alice_.WriteTo(&alice_);
+
+ bob_.ClearBuffers();
+ alice_.ClearBuffers();
+
+ EXPECT_EQ(kTvmErrorNoError, alice_.sess.StartSession());
+ EXPECT_FRAMED_PACKET(alice_, "\xff\xfd\x04\0\0\0\x82\0\0\x01{\xe9");
+ EXPECT_FALSE(alice_.sess.IsEstablished());
+
+ EXPECT_EQ(kTvmErrorNoError, bob_.sess.StartSession());
+ EXPECT_FRAMED_PACKET(bob_, kBobStartPacket);
+ EXPECT_FALSE(bob_.sess.IsEstablished());
+
+ // Sending Alice -> Bob should have no effect (regenerated Bob nonce > regenerated Alice nonce).
+ bob_.framer_write_stream.Reset();
+ alice_.WriteTo(&bob_);
+ EXPECT_FRAMED_PACKET(bob_, "");
+ EXPECT_FALSE(bob_.sess.IsEstablished());
+
+ // Sending Bob -> Alice should start the session.
+ alice_.ClearBuffers();
+ size_t bytes_consumed;
+ EXPECT_EQ(kTvmErrorNoError,
+ alice_.unframer.Write(reinterpret_cast<const uint8_t*>(kBobStartPacket),
+ sizeof(kBobStartPacket), &bytes_consumed));
+ EXPECT_EQ(bytes_consumed, sizeof(kBobStartPacket));
+ EXPECT_FRAMED_PACKET(alice_, "\xFF\xFD\x4\0\0\0fE\x01\x01\fb");
+ EXPECT_TRUE(alice_.sess.IsEstablished());
+
+ bob_.ClearBuffers();
+ alice_.WriteTo(&bob_);
+ EXPECT_TRUE(bob_.sess.IsEstablished());
+}
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ testing::FLAGS_gtest_death_test_style = "threadsafe";
+ return RUN_ALL_TESTS();
+}
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import contextlib
+import copy
+import glob
+import os
+import pty
+import sys
+import subprocess
+import textwrap
+
+import numpy as np
+
+import tvm
+import tvm.relay
+import tvm.micro
+from tvm.micro import transport
+
+from tvm.topi.util import get_const_tuple
+from tvm.topi.testing import conv2d_nchw_python
+
+BUILD = True
+DEBUG = False
+
+TARGET = tvm.target.target.micro('host')
+
+def _make_sess_from_op(workspace, op_name, sched, arg_bufs):
+ with tvm.transform.PassContext(opt_level=3, config={'tir.disable_vectorize': True}):
+ mod = tvm.build(sched, arg_bufs, TARGET, target_host=TARGET, name=op_name)
+
+ return _make_session(workspace, mod)
+
+
+def _make_session(workspace, mod):
+ compiler = tvm.micro.DefaultCompiler(target=TARGET)
+ opts = tvm.micro.default_options(os.path.join(tvm.micro.CRT_ROOT_DIR, 'host'))
+
+ micro_binary = tvm.micro.build_static_runtime(
+ # the x86 compiler *expects* you to give the exact same dictionary for both
+ # lib_opts and bin_opts. so the library compiler is mutating lib_opts and
+ # the binary compiler is expecting those mutations to be in bin_opts.
+ # TODO(weberlo) fix this very bizarre behavior
+ workspace, compiler, mod, lib_opts=opts['bin_opts'], bin_opts=opts['bin_opts'])
+
+ flasher_kw = {
+ 'debug': DEBUG,
+ }
+ flasher = compiler.flasher(**flasher_kw)
+ return tvm.micro.Session(binary=micro_binary, flasher=flasher)
+
+
+def _make_add_sess(workspace):
+ A = tvm.te.placeholder((2,), dtype='int8')
+ B = tvm.te.placeholder((1,), dtype='int8')
+ C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name='C')
+ sched = tvm.te.create_schedule(C.op)
+ return _make_sess_from_op(workspace, 'add', sched, [A, B, C])
+
+
+def _make_ident_sess(workspace):
+ A = tvm.te.placeholder((2,), dtype='int8')
+ B = tvm.te.compute(A.shape, lambda i: A[i], name='B')
+ sched = tvm.te.create_schedule(B.op)
+ return _make_sess_from_op(workspace, 'ident', sched, [A, B])
+
+
+def test_compile_runtime():
+ """Test compiling the on-device runtime."""
+ workspace = tvm.micro.Workspace()
+
+ with _make_add_sess(workspace) as sess:
+ A_data = tvm.nd.array(np.array([2, 3], dtype='int8'), ctx=sess.context)
+ assert (A_data.asnumpy() == np.array([2, 3])).all()
+ B_data = tvm.nd.array(np.array([4], dtype='int8'), ctx=sess.context)
+ assert (B_data.asnumpy() == np.array([4])).all()
+ C_data = tvm.nd.array(np.array([0, 0], dtype='int8'), ctx=sess.context)
+ assert (C_data.asnumpy() == np.array([0, 0])).all()
+
+ system_lib = sess.get_system_lib()
+ system_lib.get_function('add')(A_data, B_data, C_data)
+ assert (C_data.asnumpy() == np.array([6, 7])).all()
+
+
+def test_reset():
+ """Test when the remote end resets during a session."""
+ workspace = tvm.micro.Workspace()
+
+ with _make_add_sess(workspace) as sess:
+ try:
+ sess._rpc.get_function('tvm.testing.reset_server')()
+ assert False, 'expected to raise SessionTerminatedError; did not raise'
+ except transport.SessionTerminatedError:
+ pass
+
+
+def test_graph_runtime():
+ """Test use of the graph runtime with microTVM."""
+ workspace = tvm.micro.Workspace()
+ relay_mod = tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) {
+ %0 = %a + %b;
+ %0
+ }""")
+
+ with tvm.transform.PassContext(opt_level=3, config={'tir.disable_vectorize': True}):
+ factory = tvm.relay.build(relay_mod, target=TARGET)
+
+ with _make_session(workspace, factory.get_lib()) as sess:
+ graph_mod = tvm.micro.create_local_graph_runtime(factory.get_json(), sess.get_system_lib(), sess.context)
+ A_data = tvm.nd.array(np.array([2, 3], dtype='uint8'), ctx=sess.context)
+ assert (A_data.asnumpy() == np.array([2, 3])).all()
+ B_data = tvm.nd.array(np.array([4, 7], dtype='uint8'), ctx=sess.context)
+ assert (B_data.asnumpy() == np.array([4, 7])).all()
+
+ graph_mod.run(a=A_data, b=B_data)
+
+ out = graph_mod.get_output(0)
+ assert (out.asnumpy() == np.array([6, 10])).all()
+
+
+if __name__ == '__main__':
+ test_compile_runtime()
+ test_reset()
+ test_graph_runtime()
+++ /dev/null
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import os
-
-import numpy as np
-import tvm
-from tvm import te
-from tvm.contrib import graph_runtime, util
-from tvm import relay
-import tvm.micro as micro
-from tvm.micro import create_micro_mod
-
-# # Use the host emulated micro device.
-DEV_CONFIG_A = micro.device.host.generate_config()
-DEV_CONFIG_B = micro.device.host.generate_config()
-TARGET = "c --runtime=c"
-
-
-def relay_micro_build(func, dev_config, params=None):
- """Create a graph runtime module with a micro device context from a Relay function.
-
- Parameters
- ----------
- func : relay.Function
- function to compile
-
- dev_config : Dict[str, Any]
- MicroTVM config dict for the target device
-
- params : dict
- input parameters that do not change during inference
-
- Return
- ------
- mod : tvm.runtime.Module
- graph runtime module for the target device
- """
- with tvm.transform.PassContext(
- disabled_pass={"FuseOps"}, config={"tir.disable_vectorize": True}
- ):
- graph, c_mod, params = relay.build(func, target=TARGET, params=params)
- micro_mod = micro.create_micro_mod(c_mod, dev_config)
- ctx = tvm.micro_dev(0)
- mod = graph_runtime.create(graph, micro_mod, ctx)
- mod.set_input(**params)
- return mod
-
-
-GDB_INIT_TEMPLATE = """
-layout asm
-target remote localhost:{gdb_port}
-set $pc = UTVMInit
-break UTVMDone
-"""
-
-
-def reset_gdbinit():
- if "server_port" not in DEV_CONFIG_A:
- return
- gdb_init_dir = os.environ["MICRO_GDB_INIT_DIR"]
- with open(f"{gdb_init_dir}/.gdbinit", "w") as f:
- gdb_port = DEV_CONFIG_A["server_port"] - 3333
- f.write(GDB_INIT_TEMPLATE.format(gdb_port=gdb_port))
-
-
-def test_alloc():
- """Test tensor allocation on the device."""
- if not tvm.runtime.enabled("micro_dev"):
- return
- shape = (1024,)
- dtype = "float32"
- with micro.Session(DEV_CONFIG_A):
- ctx = tvm.micro_dev(0)
- np_tensor = np.random.uniform(size=shape).astype(dtype)
- micro_tensor = tvm.nd.array(np_tensor, ctx)
- tvm.testing.assert_allclose(np_tensor, micro_tensor.asnumpy())
-
-
-def test_add():
- """Test a module which performs addition."""
- if not tvm.runtime.enabled("micro_dev"):
- return
- shape = (1024,)
- dtype = "float32"
-
- reset_gdbinit()
-
- # Construct TVM expression.
- tvm_shape = tvm.runtime.convert(shape)
- A = te.placeholder(tvm_shape, name="A", dtype=dtype)
- B = te.placeholder(tvm_shape, name="B", dtype=dtype)
- C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
- s = te.create_schedule(C.op)
-
- func_name = "fadd"
- c_mod = tvm.build(s, [A, B, C], target="c", name=func_name)
-
- with micro.Session(DEV_CONFIG_A) as sess:
- micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A)
- micro_func = micro_mod[func_name]
- ctx = tvm.micro_dev(0)
-
- a_np = np.random.uniform(size=shape).astype(dtype)
- a = tvm.nd.array(a_np, ctx)
- b_np = np.random.uniform(size=shape).astype(dtype)
- b = tvm.nd.array(b_np, ctx)
- c = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
- micro_func(a, b, c)
-
- # ensure inputs weren't corrupted
- tvm.testing.assert_allclose(a.asnumpy(), a_np)
- tvm.testing.assert_allclose(b.asnumpy(), b_np)
- # ensure output is correct
- tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
-
-
-def test_workspace_add():
- """Test a module which uses a workspace to compute an intermediate value."""
- if not tvm.runtime.enabled("micro_dev"):
- return
- shape = (1024,)
- dtype = "float32"
-
- reset_gdbinit()
-
- # Construct TVM expression.
- tvm_shape = tvm.runtime.convert(shape)
- A = te.placeholder(tvm_shape, name="A", dtype=dtype)
- B = te.placeholder(tvm_shape, name="B", dtype=dtype)
- B = te.compute(A.shape, lambda *i: A(*i) + 1, name="B")
- C = te.compute(A.shape, lambda *i: B(*i) + 1, name="C")
- s = te.create_schedule(C.op)
-
- func_name = "fadd_two_workspace"
- c_mod = tvm.build(s, [A, C], target="c", name=func_name)
-
- with micro.Session(DEV_CONFIG_A) as sess:
- micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A)
- micro_func = micro_mod[func_name]
- ctx = tvm.micro_dev(0)
- a_np = np.random.uniform(size=shape).astype(dtype)
- a = tvm.nd.array(a_np, ctx)
- c = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
- micro_func(a, c)
-
- # ensure input wasn't corrupted
- tvm.testing.assert_allclose(a.asnumpy(), a_np)
- # ensure output is correct
- tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 2.0)
-
-
-def test_graph_runtime():
- """Test a program which uses the graph runtime."""
- if not tvm.runtime.enabled("micro_dev"):
- return
- shape = (1024,)
- dtype = "float32"
-
- # Construct Relay program.
- x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype))
- xx = relay.multiply(x, x)
- z = relay.add(xx, relay.const(1.0))
- func = relay.Function([x], z)
-
- with micro.Session(DEV_CONFIG_A):
- mod = relay_micro_build(func, DEV_CONFIG_A)
-
- x_in = np.random.uniform(size=shape[0]).astype(dtype)
- mod.run(x=x_in)
- result = mod.get_output(0).asnumpy()
-
- tvm.testing.assert_allclose(mod.get_input(0).asnumpy(), x_in)
- tvm.testing.assert_allclose(result, x_in * x_in + 1.0)
-
-
-def test_conv2d():
- if not tvm.runtime.enabled("micro_dev"):
- return
-
- from tvm.relay import create_executor
- from tvm.relay import transform
-
- dshape = (1, 4, 16, 16)
- dtype = "int8"
- func_name = "fused_nn_conv2d"
-
- reset_gdbinit()
-
- # Construct Relay program.
- x = relay.var("x", shape=dshape, dtype=dtype)
- conv_expr = relay.nn.conv2d(x, relay.var("w"), kernel_size=(3, 3), padding=(1, 1), channels=4)
- func = relay.Function(relay.analysis.free_vars(conv_expr), conv_expr)
- mod = tvm.IRModule.from_expr(func)
- mod = transform.InferType()(mod)
-
- x_shape = list(map(lambda x: x.value, mod["main"].params[0].checked_type.shape))
- w_shape = list(map(lambda x: x.value, mod["main"].params[1].checked_type.shape))
- out_shape = list(map(lambda x: x.value, mod["main"].ret_type.shape))
-
- with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
- graph, c_mod, params = relay.build(mod, target="c")
-
- with micro.Session(DEV_CONFIG_A):
- micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A)
- candidate_func_name = func_name
- for i in range(100):
- try:
- micro_func = micro_mod[candidate_func_name]
- break
- except tvm.TVMError as e:
- candidate_func_name = f"{func_name}_{i}"
- else:
- assert False
- ctx = tvm.micro_dev(0)
-
- x_data = tvm.nd.array(np.random.uniform(size=x_shape).astype(dtype), ctx)
- w_data = tvm.nd.array(np.random.uniform(size=w_shape).astype(dtype), ctx)
- result = tvm.nd.array(np.zeros(shape=out_shape, dtype=dtype), ctx)
- micro_func(x_data, w_data, result)
-
- out_data = np.zeros(out_shape, dtype=dtype)
- params = {"x": x_data.asnumpy(), "w": w_data.asnumpy()}
- intrp = create_executor("debug")
- expected_result = intrp.evaluate(mod["main"])(x_data, w_data)
-
- tvm.testing.assert_allclose(result.asnumpy(), expected_result.asnumpy())
-
-
-def test_interleave_sessions():
- """Test closing and reopening sessions."""
- if not tvm.runtime.enabled("micro_dev"):
- return
- shape = (1024,)
- dtype = "float32"
-
- # Construct Relay add program.
- x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype))
- ret = relay.add(x, relay.const(1.0))
- add_const_func = relay.Function([x], ret)
-
- sess_a = micro.Session(DEV_CONFIG_A)
- sess_b = micro.Session(DEV_CONFIG_B)
- with sess_a:
- np_tensor_a = np.random.uniform(size=shape).astype(dtype)
- micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0))
- with sess_b:
- np_tensor_b = np.random.uniform(size=shape).astype(dtype)
- micro_tensor_b = tvm.nd.array(np_tensor_b, tvm.micro_dev(0))
- with sess_a:
- add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A)
- add_const_mod.run(x=micro_tensor_a)
- add_result = add_const_mod.get_output(0).asnumpy()
- tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0)
- with sess_b:
- add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_B)
- add_const_mod.run(x=micro_tensor_b)
- add_result = add_const_mod.get_output(0).asnumpy()
- tvm.testing.assert_allclose(add_result, np_tensor_b + 1.0)
-
-
-def test_nested_sessions():
- """Test entering and exiting nested session contexts."""
- if not tvm.runtime.enabled("micro_dev"):
- return
- shape = (1024,)
- dtype = "float32"
-
- # Construct Relay add program.
- x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype))
- ret = relay.add(x, relay.const(1.0))
- add_const_func = relay.Function([x], ret)
-
- sess_a = micro.Session(DEV_CONFIG_A)
- sess_b = micro.Session(DEV_CONFIG_B)
- with sess_a:
- np_tensor_a = np.random.uniform(size=shape).astype(dtype)
- micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0))
- with sess_b:
- np_tensor_b = np.random.uniform(size=shape).astype(dtype)
- micro_tensor_b = tvm.nd.array(np_tensor_b, tvm.micro_dev(0))
- add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A)
- add_const_mod.run(x=micro_tensor_a)
- add_result = add_const_mod.get_output(0).asnumpy()
- tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0)
-
-
-def test_inactive_session_use():
- """Test the use of objects allocated in a session that is no longer active."""
- if not tvm.runtime.enabled("micro_dev"):
- return
- shape = (1024,)
- dtype = "float32"
-
- # Construct Relay add program.
- x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype))
- ret = relay.add(x, relay.const(1.0))
- add_const_func = relay.Function([x], ret)
-
- sess_a = micro.Session(DEV_CONFIG_A)
- sess_b = micro.Session(DEV_CONFIG_B)
- with sess_a:
- np_tensor_a = np.random.uniform(size=shape).astype(dtype)
- micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0))
- add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A)
-
- with sess_b:
- # These objects belong to `sess_a`.
- add_const_mod.run(x=micro_tensor_a)
- add_result = add_const_mod.get_output(0).asnumpy()
- tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0)
-
-
-# TODO add workspace alloc/free stress test
-
-if __name__ == "__main__":
- test_alloc()
- print()
- print("finished alloc test")
- input("[press enter to continue]")
- test_add()
- print()
- print("finished add test")
- input("[press enter to continue]")
- test_workspace_add()
- print()
- print("finished workspace add test")
- input("[press enter to continue]")
- test_graph_runtime()
- print()
- print("finished graph runtime test")
- input("[press enter to continue]")
- test_conv2d()
- print()
- print("finished conv2d test")
- input("[press enter to continue]")
- test_interleave_sessions()
- print()
- print("finished interleaved sessions test")
- input("[press enter to continue]")
- test_nested_sessions()
- print()
- print("finished nested sessions test")
- input("[press enter to continue]")
- test_inactive_session_use()
- print()
- print("finished use inactive session test")
- input("[press enter to continue]")
echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake
echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake
echo set\(USE_ETHOSN /opt/arm/ethosn-driver\) >> config.cmake
-echo set\(USE_ETHOSN_HW OFF\) >> config.cmake
\ No newline at end of file
+echo set\(USE_ETHOSN_HW OFF\) >> config.cmake
echo set\(USE_SORT ON\) >> config.cmake
echo set\(USE_RPC ON\) >> config.cmake
echo set\(USE_GRAPH_RUNTIME_DEBUG ON\) >> config.cmake
+echo set\(USE_MICRO ON\) >> config.cmake
echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake
echo set\(USE_STANDALONE_CRT ON\) >> config.cmake
echo set\(USE_VM_PROFILER ON\) >> config.cmake
# Remove existing testcases
rm -f build/*_test
-make cpptest -j8
-make crttest -j8
+make cpptest -j2
+make crttest # NOTE: don't parallelize, due to issue with build deps.
for test in build/*_test; do
./$test
done
+
+# Test MISRA-C runtime
+cd apps/bundle_deploy
+rm -rf build
+make test_dynamic test_static
+cd ../..
# Test TVM
make cython3
-# Test MISRA-C runtime
-cd apps/bundle_deploy
-rm -rf build
-make test_dynamic test_static
-cd ../..
-
# Test extern package
cd apps/extension
rm -rf lib
============================
**Author**: `Tom Gall <https://github.com/tom-gall>`_
-This tutorial is an introduction to working with MicroTVM and a TFLite
+This tutorial is an introduction to working with MicroTVM and a TFLite
model with Relay.
"""
tflite_model, shape_dict={input_tensor: input_shape}, dtype_dict={input_tensor: input_dtype}
)
-# %%
-# Running on device
-# ----------------------------------------------
-#
-# Setup the device config which is what will be used to communicate
-# with the microcontroller (a STM32F746 Discovery board)
-TARGET = "c --system-lib --runtime=c"
-dev_config = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666)
-
-######################################################################
-# Next with the dev_config, we establish a micro session and create
-# a context
-#
-# .. code-block:: python
-#
-# with micro.Session(dev_config) as sess:
-# ctx = tvm.micro_dev(0)
-
######################################################################
# Now we create a build config for relay. turning off two options
# and then calling relay.build which will result in a C source
#
# .. code-block:: python
#
-# with tvm.transform.PassContext(opt_level=3, config={'tir.disable_vectorize': True},disabled_pass=['FuseOps']):
-# graph, c_mod, params = relay.build(mod, target=TARGET, params=params)
+TARGET = tvm.target.target.micro("host")
-######################################################################
-# With the c_mod that is the handle to our C source code, we create a
-# micro module, followed by a compiled object which behind the scenes
-# is linked to the microTVM runtime for running on the target board
-#
-# .. code-block:: python
-#
-# micro_mod = micro.create_micro_mod(c_mod, dev_config)
-# mod = graph_runtime.create(graph, micro_mod, ctx)
+with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True},disabled_pass=["FuseOps"]):
+ graph, c_mod, c_params = relay.build(mod, target=TARGET, params=params)
-######################################################################
-# Pass the weights to get ready to perform inference
-#
-# .. code-block:: python
-#
-# mod.set_input(**params)
-######################################################################
-# The model consumes a single float32 value and returns a predicted
-# sine value.
-# To pass the input value we construct a tvm.nd.array object
-# with a single contrived number as input. For this model values of
-# 0 to 2Pi are acceptable.
-#
-# .. code-block:: python
+# %%
+# Running on simulated device
+# ----------------------------------------------
#
-# mod.set_input(input_tensor, tvm.nd.array(np.array([0.5], dtype="float32")))
+# First, compile a static microTVM runtime for the targeted device. In this case, the host simulated
+# device is used.
+workspace = tvm.micro.Workspace()
+
+compiler = tvm.micro.DefaultCompiler(target=TARGET)
+opts = tvm.micro.default_options(os.path.join(tvm.micro.CRT_ROOT_DIR, "host"))
+
+micro_binary = tvm.micro.build_static_runtime(
+ # the x86 compiler *expects* you to give the exact same dictionary for both
+ # lib_opts and bin_opts. so the library compiler is mutating lib_opts and
+ # the binary compiler is expecting those mutations to be in bin_opts.
+ # TODO(weberlo) fix this very bizarre behavior
+ workspace, compiler, c_mod, lib_opts=opts["bin_opts"], bin_opts=opts["bin_opts"])
-######################################################################
-# Run the model on device
-#
-# .. code-block:: python
-#
-# mod.run()
######################################################################
-# Get output from the run and print
+# Next, establish a session with the simulated device and run the
+# computation. The `with session` line would typically flash an attached
+# microcontroller, but in this tutorial, it simply launches a subprocess
+# to stand in for an attached microcontroller.
#
# .. code-block:: python
#
-# tvm_output = mod.get_output(0).asnumpy()
-# print("result is: "+str(tvm_output))
+flasher = compiler.flasher()
+with tvm.micro.Session(binary=micro_binary, flasher=flasher) as session:
+ graph_mod = tvm.micro.create_local_graph_runtime(
+ graph, session.get_system_lib(), session.context)
+
+ # Set the model parameters using the lowered parameters produced by `relay.build`.
+ graph_mod.set_input(**c_params)
+
+ # The model consumes a single float32 value and returns a predicted sine value. To pass the
+ # input value we construct a tvm.nd.array object with a single contrived number as input. For
+ # this model values of 0 to 2Pi are acceptable.
+ graph_mod.set_input(input_tensor, tvm.nd.array(np.array([0.5], dtype="float32")))
+ graph_mod.run()
+
+ tvm_output = graph_mod.get_output(0).asnumpy()
+ print("result is: "+str(tvm_output))