[WIP][µTVM] Add OpenOCD Low-Level Device (RISC-V Support) (#3756)
authorLogan Weber <36520469+weberlo@users.noreply.github.com>
Mon, 2 Sep 2019 07:32:52 +0000 (00:32 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 2 Sep 2019 07:32:52 +0000 (15:32 +0800)
12 files changed:
python/tvm/contrib/binutil.py
python/tvm/micro/base.py
src/runtime/micro/low_level_device.h
src/runtime/micro/micro_common.cc
src/runtime/micro/micro_session.cc
src/runtime/micro/micro_session.h
src/runtime/micro/openocd_low_level_device.cc [new file with mode: 0644]
src/runtime/micro/tcl_socket.cc [new file with mode: 0644]
src/runtime/micro/tcl_socket.h [new file with mode: 0644]
tests/python/unittest/test_codegen_c_host.py
tests/python/unittest/test_codegen_c_host_fadd.py [deleted file]
tests/python/unittest/test_runtime_micro.py

index a444cdc..1b8140c 100644 (file)
@@ -75,8 +75,23 @@ def tvm_callback_get_section_size(binary_path, section_name, toolchain_prefix):
         entry_size = int(tokens[1])
         if entry_name in sections_to_sum:
             section_size += entry_size
-    return section_size
 
+    # NOTE: For some reason, the size of the BSS section on the RISC-V
+    # GCC is sometimes reported to be smaller than it is, so we need to adjust
+    # for this.
+    if "riscv" in toolchain_prefix and section_name == 'bss':
+        # TODO(weberlo): Figure out why 32 is the minimum constant that works.
+        #
+        # The current hypothesis is that the last symbols in the ".bss" and
+        # ".sbss" sections may have size zero, since the symbols in these
+        # sections are uninitialized and there's no address that follows that
+        # would enforce a particular size.
+        #
+        # If this is the case, then 32 just happens to be a safe amount of
+        # padding for most cases, but symbols can be arbitrarily large, so this
+        # isn't bulletproof.
+        return section_size + 32
+    return section_size
 
 @register_func("tvm_callback_relocate_binary")
 def tvm_callback_relocate_binary(
@@ -169,6 +184,7 @@ SECTIONS
         msg = "linking error using ld:\n"
         msg += py_str(out)
         raise RuntimeError(msg)
+
     with open(rel_obj_path, "rb") as f:
         rel_bin = bytearray(f.read())
     return rel_bin
index 7cb13c4..cab6f78 100644 (file)
@@ -29,7 +29,7 @@ from tvm.contrib import cc as _cc
 from .._ffi.function import _init_api
 from .._ffi.libinfo import find_include_path
 
-SUPPORTED_DEVICE_TYPES = ["host"]
+SUPPORTED_DEVICE_TYPES = ["host", "openocd"]
 
 class Session:
     """MicroTVM Device Session
@@ -50,15 +50,22 @@ class Session:
     .. code-block:: python
 
       c_mod = ...  # some module generated with "c" as the target
-      device_type = "host"
-      with tvm.micro.Session(device_type) as sess:
-          sess.create_micro_mod(c_mod)
+      device_type = "openocd"
+      toolchain_prefix = "riscv64-unknown-elf-"
+      with tvm.micro.Session(device_type,
+                             toolchain_prefix,
+                             base_addr=0x10010000,
+                             server_addr="127.0.0.1",
+                             port=6666):
+          c_mod.export_library(lib_obj_path, fcompile=tvm.micro.cross_compiler(toolchain_prefix))
+          micro_mod = tvm.module.load(lib_obj_path, "micro_dev")
     """
 
-    def __init__(self, device_type, toolchain_prefix):
+    def __init__(self, device_type, toolchain_prefix, **kwargs):
         if device_type not in SUPPORTED_DEVICE_TYPES:
             raise RuntimeError("unknown micro device type \"{}\"".format(device_type))
         self._check_system()
+        self._check_args(device_type, kwargs)
 
         # First, find and compile runtime library.
         runtime_src_path = os.path.join(_get_micro_device_dir(), "utvm_runtime.c")
@@ -67,7 +74,11 @@ class Session:
         create_micro_lib(
             runtime_obj_path, runtime_src_path, toolchain_prefix, include_dev_lib_header=False)
 
-        self.module = _CreateSession(device_type, runtime_obj_path, toolchain_prefix)
+        base_addr = kwargs.get("base_addr", 0)
+        server_addr = kwargs.get("server_addr", "")
+        port = kwargs.get("port", 0)
+        self.module = _CreateSession(
+            device_type, runtime_obj_path, toolchain_prefix, base_addr, server_addr, port)
         self._enter = self.module["enter"]
         self._exit = self.module["exit"]
 
@@ -83,6 +94,15 @@ class Session:
         if sys.maxsize <= 2**32:
             raise RuntimeError("microTVM is currently only supported on 64-bit platforms")
 
+    def _check_args(self, device_type, args):
+        """Check if the given configuration is valid."""
+        if device_type == "host":
+            pass
+        elif device_type == "openocd":
+            assert "base_addr" in args
+            assert "server_addr" in args
+            assert "port" in args
+
     def __enter__(self):
         self._enter()
 
@@ -181,7 +201,9 @@ def create_micro_lib(
     options = ["-I" + path for path in find_include_path()]
     options += ["-I{}".format(_get_micro_device_dir())]
     options += ["-fno-stack-protector"]
-    if sys.maxsize > 2**32 and sys.platform.startswith("linux"):
+    # TODO(weberlo): Don't rely on the toolchain prefix to identify if this is the host
+    # device.
+    if toolchain_prefix == "" and sys.maxsize > 2**32 and sys.platform.startswith("linux"):
         # Only add this option if the host is a 64-bit Linux.
         options += ["-mcmodel=large"]
     compile_cmd = "{}gcc".format(toolchain_prefix)
index a3b2e35..1285405 100644 (file)
@@ -26,6 +26,7 @@
 #define TVM_RUNTIME_MICRO_LOW_LEVEL_DEVICE_H_
 
 #include <memory>
+#include <string>
 
 #include "micro_common.h"
 
@@ -66,9 +67,21 @@ class LowLevelDevice {
    */
   virtual void Execute(DevBaseOffset func_offset, DevBaseOffset breakpoint) = 0;
 
+  // TODO(weberlo): Should we just give the device the *entire* memory layout
+  // decided by the session?
+
+  /*!
+   * \brief sets the offset of the top of the stack section
+   * \param stack_top offset of the stack top
+   */
+  virtual void SetStackTop(DevBaseOffset stack_top) {
+    LOG(FATAL) << "unimplemented";
+  }
+
   /*!
    * \brief convert from base offset to absolute address
    * \param offset base offset
+   * \return absolute address
    */
   DevPtr ToDevPtr(DevBaseOffset offset) {
     return DevPtr(base_addr() + offset.value());
@@ -77,6 +90,7 @@ class LowLevelDevice {
   /*!
    * \brief convert from absolute address to base offset
    * \param ptr absolute address
+   * \return base offset
    */
   DevBaseOffset ToDevOffset(DevPtr ptr) {
     return DevBaseOffset(ptr.value() - base_addr());
@@ -102,6 +116,14 @@ class LowLevelDevice {
  */
 const std::shared_ptr<LowLevelDevice> HostLowLevelDeviceCreate(size_t num_bytes);
 
+/*!
+ * \brief connect to OpenOCD and create an OpenOCD low-level device
+ * \param port port of the OpenOCD server to connect to
+ */
+const std::shared_ptr<LowLevelDevice> OpenOCDLowLevelDeviceCreate(std::uintptr_t base_addr,
+                                                                  const std::string& addr,
+                                                                  int port);
+
 }  // namespace runtime
 }  // namespace tvm
 #endif  // TVM_RUNTIME_MICRO_LOW_LEVEL_DEVICE_H_
index 459d00d..e50faeb 100644 (file)
@@ -39,7 +39,7 @@ namespace runtime {
 size_t GetDefaultSectionSize(SectionKind kind) {
   switch (kind) {
     case SectionKind::kText:
-      return 0xF0000;
+      return 0xF000;
     case SectionKind::kRodata:
       return 0xF000;
     case SectionKind::kData:
@@ -47,13 +47,13 @@ size_t GetDefaultSectionSize(SectionKind kind) {
     case SectionKind::kBss:
       return 0xF00;
     case SectionKind::kArgs:
-      return 0xF00000;
+      return 0xF0000;
     case SectionKind::kStack:
       return 0xF000;
     case SectionKind::kHeap:
-      return 0xF000000;
+      return 0xF00000;
     case SectionKind::kWorkspace:
-      return 0xF000000;
+      return 0xF0000;
     default:
       LOG(FATAL) << "invalid section " << static_cast<size_t>(kind);
       return 0;
index ca6f446..9790154 100644 (file)
 /*!
  *  Copyright (c) 2019 by Contributors
  * \file micro_session.cc
- * \brief session to manage multiple micro modules
- *
- * Each session consists of an interaction with a *single* logical device.
- * Within that interaction, multiple TVM modules can be loaded on the logical
- * device.
- *
- * Multiple sessions can exist simultaneously, but there is only ever one
- * *active* session. The idea of an active session mainly has implications for
- * the frontend, in that one must make a session active in order to allocate
- * new TVM objects on it. Aside from that, previously allocated objects can be
- * used even if the session which they belong to is not currently active.
  */
 
 #include <dmlc/thread_local.h>
@@ -86,25 +75,38 @@ MicroSession::~MicroSession() {
   for (size_t i = 0; i < static_cast<size_t>(SectionKind::kNumKinds); i++) {
     section_allocators_[i] = nullptr;
   }
-
   low_level_device_ = nullptr;
 }
 
 void MicroSession::CreateSession(const std::string& device_type,
                                  const std::string& binary_path,
-                                 const std::string& toolchain_prefix) {
+                                 const std::string& toolchain_prefix,
+                                 std::uintptr_t base_addr,
+                                 const std::string& server_addr,
+                                 int port) {
   // TODO(weberlo): make device type enum
+  toolchain_prefix_ = toolchain_prefix;
   if (device_type == "host") {
     low_level_device_ = HostLowLevelDeviceCreate(memory_size_);
+  } else if (device_type == "openocd") {
+    // TODO(weberlo): We need a better way of configuring devices.
+    low_level_device_ = OpenOCDLowLevelDeviceCreate(base_addr, server_addr, port);
   } else {
     LOG(FATAL) << "unsupported micro low-level device";
   }
+
   SetRuntimeBinaryPath(binary_path);
   CHECK(!runtime_binary_path_.empty()) << "uTVM runtime not initialized";
   runtime_bin_info_ = LoadBinary(runtime_binary_path_, /* patch_dylib_pointers */ false);
   utvm_main_symbol_ = low_level_device()->ToDevOffset(runtime_symbol_map()["UTVMMain"]);
   utvm_done_symbol_ = low_level_device()->ToDevOffset(runtime_symbol_map()["UTVMDone"]);
 
+  if (device_type == "openocd") {
+    // Set OpenOCD device's stack pointer.
+    auto stack_section = GetAllocator(SectionKind::kStack);
+    low_level_device_->SetStackTop(stack_section->max_end_offset());
+  }
+
   // Patch workspace pointers to the start of the workspace section.
   DevBaseOffset workspace_start_offset = GetAllocator(SectionKind::kWorkspace)->start_offset();
   DevBaseOffset workspace_end_offset = GetAllocator(SectionKind::kWorkspace)->max_end_offset();
@@ -143,6 +145,7 @@ void MicroSession::PushToExecQueue(DevBaseOffset func, const TVMArgs& args) {
   };
   // Write the task.
   DevSymbolWrite(runtime_symbol_map(), "task", task);
+
   low_level_device()->Execute(utvm_main_symbol_, utvm_done_symbol_);
   // Check if there was an error during execution.  If so, log it.
   CheckDeviceError();
@@ -299,7 +302,7 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d
 
 void MicroSession::PatchImplHole(const SymbolMap& symbol_map, const std::string& func_name) {
   void* runtime_impl_addr = runtime_symbol_map()[func_name].cast_to<void*>();
-  std::stringstream func_name_underscore;
+  std::ostringstream func_name_underscore;
   func_name_underscore << func_name << "_";
   DevSymbolWrite(symbol_map, func_name_underscore.str(), runtime_impl_addr);
 }
@@ -309,7 +312,7 @@ void MicroSession::SetRuntimeBinaryPath(std::string path) {
 }
 
 std::string MicroSession::ReadString(DevBaseOffset str_offset) {
-  std::stringstream result;
+  std::ostringstream result;
   const size_t buf_size = 256;
   std::vector<char> buf(buf_size, 0);
   size_t i = buf_size;
@@ -372,8 +375,12 @@ TVM_REGISTER_GLOBAL("micro._CreateSession")
     const std::string& device_type = args[0];
     const std::string& binary_path = args[1];
     const std::string& toolchain_prefix = args[2];
+    uint64_t base_addr = args[3];
+    const std::string& server_addr = args[4];
+    int port = args[5];
     std::shared_ptr<MicroSession> session = std::make_shared<MicroSession>();
-    session->CreateSession(device_type, binary_path, toolchain_prefix);
+    session->CreateSession(
+        device_type, binary_path, toolchain_prefix, base_addr, server_addr, port);
     *rv = Module(session);
     });
 
index e163549..1400f74 100644 (file)
 /*!
  *  Copyright (c) 2019 by Contributors
  * \file micro_session.h
+ * \brief session to manage multiple micro modules
+ *
+ * Each session consists of an interaction with a *single* logical device.
+ * Within that interaction, multiple TVM modules can be loaded on the logical
+ * device.
+ *
+ * Multiple sessions can exist simultaneously, but there is only ever one
+ * *active* session. The idea of an active session mainly has implications for
+ * the frontend, in that one must make a session active in order to allocate
+ * new TVM objects on it. Aside from that, previously allocated objects can be
+ * used even if the session which they belong to is not currently active.
  */
 #ifndef TVM_RUNTIME_MICRO_MICRO_SESSION_H_
 #define TVM_RUNTIME_MICRO_MICRO_SESSION_H_
@@ -82,7 +93,10 @@ class MicroSession : public ModuleNode {
    */
   void CreateSession(const std::string& device_type,
                      const std::string& binary_path,
-                     const std::string& toolchain_prefix);
+                     const std::string& toolchain_prefix,
+                     std::uintptr_t base_addr,
+                     const std::string& server_addr,
+                     int port);
 
   /*!
    * \brief ends the session by destructing the low-level device and its allocators
diff --git a/src/runtime/micro/openocd_low_level_device.cc b/src/runtime/micro/openocd_low_level_device.cc
new file mode 100644 (file)
index 0000000..e0623dd
--- /dev/null
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file openocd_low_level_device.cc
+ */
+#include <sstream>
+#include <iomanip>
+
+#include "micro_common.h"
+#include "low_level_device.h"
+#include "tcl_socket.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief OpenOCD low-level device for uTVM micro devices connected over JTAG
+ */
+class OpenOCDLowLevelDevice final : public LowLevelDevice {
+ public:
+  /*!
+   * \brief constructor to initialize connection to openocd device
+   * \param base_addr base address of the device
+   * \param server_addr address of the OpenOCD server to connect to
+   * \param port port of the OpenOCD server to connect to
+   */
+  explicit OpenOCDLowLevelDevice(std::uintptr_t base_addr,
+                                 const std::string& server_addr,
+                                 int port) : socket_() {
+      socket_.Connect(tvm::common::SockAddr(server_addr.c_str(), port));
+      socket_.cmd_builder() << "reset halt";
+      socket_.SendCommand();
+      base_addr_ = base_addr;
+      CHECK(base_addr_ % 8 == 0) << "base address not aligned to 8 bytes";
+  }
+
+  void Read(DevBaseOffset offset, void* buf, size_t num_bytes) {
+    if (num_bytes == 0) {
+      return;
+    }
+
+    // TODO(weberlo): Refactor between read and write.
+    // Check if we need to chunk this write request.
+    if (num_bytes > kMemTransferLimit) {
+      DevBaseOffset curr_offset = offset;
+      char* curr_buf_ptr = reinterpret_cast<char*>(buf);
+      while (num_bytes != 0) {
+        size_t amount_to_read;
+        if (num_bytes > kMemTransferLimit) {
+          amount_to_read = kMemTransferLimit;
+        } else {
+          amount_to_read = num_bytes;
+        }
+        Read(offset, reinterpret_cast<void*>(curr_buf_ptr), amount_to_read);
+        offset += amount_to_read;
+        curr_buf_ptr += amount_to_read;
+        num_bytes -= amount_to_read;
+      }
+      return;
+    }
+    {
+      socket_.cmd_builder() << "array unset output";
+      socket_.SendCommand();
+
+      DevPtr addr = DevPtr(base_addr_ + offset.value());
+      socket_.cmd_builder()
+        << "mem2array output"
+        << " " << std::dec << kWordSize
+        << " " << addr.cast_to<void*>()
+        // Round up any request sizes under a byte, since OpenOCD doesn't support
+        // sub-byte-sized transfers.
+        << " " << std::dec << (num_bytes < 8 ? 8 : num_bytes);
+      socket_.SendCommand();
+    }
+
+    {
+      socket_.cmd_builder() << "ocd_echo $output";
+      socket_.SendCommand();
+      const std::string& reply = socket_.last_reply();
+
+      std::istringstream values(reply);
+      char* char_buf = reinterpret_cast<char*>(buf);
+      ssize_t req_bytes_remaining = num_bytes;
+      uint32_t index;
+      uint32_t val;
+      while (req_bytes_remaining > 0) {
+        // The response from this command pairs indices with the contents of the
+        // memory at that index.
+        values >> index;
+        CHECK(index < num_bytes)
+          << "index " << index <<
+          " out of bounds (length " << num_bytes << ")";
+        // Read the value into `curr_val`, instead of reading directly into
+        // `buf_iter`, because otherwise it's interpreted as the ASCII value and
+        // not the integral value.
+        values >> val;
+        char_buf[index] = static_cast<uint8_t>(val);
+        req_bytes_remaining--;
+      }
+      if (num_bytes >= 8) {
+        uint32_t check_index;
+        values >> check_index;
+        CHECK(check_index != index) << "more data in response than requested";
+      }
+    }
+  }
+
+  void Write(DevBaseOffset offset, const void* buf, size_t num_bytes) {
+    if (num_bytes == 0) {
+      return;
+    }
+
+    // Check if we need to chunk this write request.
+    if (num_bytes > kMemTransferLimit) {
+      DevBaseOffset curr_offset = offset;
+      const char* curr_buf_ptr = reinterpret_cast<const char*>(buf);
+      while (num_bytes != 0) {
+        size_t amount_to_write;
+        if (num_bytes > kMemTransferLimit) {
+          amount_to_write = kMemTransferLimit;
+        } else {
+          amount_to_write = num_bytes;
+        }
+        Write(offset, reinterpret_cast<const void*>(curr_buf_ptr), amount_to_write);
+        offset += amount_to_write;
+        curr_buf_ptr += amount_to_write;
+        num_bytes -= amount_to_write;
+      }
+      return;
+    }
+
+    // Clear `input` array.
+    socket_.cmd_builder() << "array unset input";
+    socket_.SendCommand();
+    // Build a command to set the value of `input`.
+    {
+      std::ostringstream& cmd_builder = socket_.cmd_builder();
+      cmd_builder << "array set input {";
+      const char* char_buf = reinterpret_cast<const char*>(buf);
+      for (size_t i = 0; i < num_bytes; i++) {
+        // In a Tcl `array set` commmand, we need to pair the array indices with
+        // their values.
+        cmd_builder << i << " ";
+        // Need to cast to uint, so the number representation of `buf[i]` is
+        // printed, and not the ASCII representation.
+        cmd_builder << static_cast<uint32_t>(char_buf[i]) << " ";
+      }
+      cmd_builder << "}";
+      socket_.SendCommand();
+    }
+    {
+      DevPtr addr = DevPtr(base_addr_ + offset.value());
+      socket_.cmd_builder()
+        << "array2mem input"
+        << " " << std::dec << kWordSize
+        << " " << addr.cast_to<void*>()
+        << " " << std::dec << num_bytes;
+      socket_.SendCommand();
+    }
+  }
+
+  void Execute(DevBaseOffset func_offset, DevBaseOffset breakpoint) {
+    socket_.cmd_builder() << "halt 0";
+    socket_.SendCommand();
+
+    // Set up the stack pointer.
+    DevPtr stack_end = stack_top() - 8;
+    socket_.cmd_builder() << "reg sp " << stack_end.cast_to<void*>();
+    socket_.SendCommand();
+
+    // Set a breakpoint at the beginning of `UTVMDone`.
+    socket_.cmd_builder() << "bp " << ToDevPtr(breakpoint).cast_to<void*>() << " 2";
+    socket_.SendCommand();
+
+    DevPtr func_addr = DevPtr(base_addr_ + func_offset.value());
+    socket_.cmd_builder() << "resume " << func_addr.cast_to<void*>();
+    socket_.SendCommand();
+
+    socket_.cmd_builder() << "wait_halt " << kWaitTime;
+    socket_.SendCommand();
+
+    socket_.cmd_builder() << "halt 0";
+    socket_.SendCommand();
+
+    // Remove the breakpoint.
+    socket_.cmd_builder() << "rbp " << ToDevPtr(breakpoint).cast_to<void*>();
+    socket_.SendCommand();
+  }
+
+  void SetStackTop(DevBaseOffset stack_top) {
+    stack_top_ = DevPtr(base_addr_ + stack_top.value());
+  }
+
+  std::uintptr_t base_addr() const final {
+    return base_addr_;
+  }
+
+  DevPtr stack_top() const {
+    CHECK(stack_top_ != nullptr) << "stack top was never initialized";
+    return stack_top_;
+  }
+
+  const char* device_type() const final {
+    return "openocd";
+  }
+
+ private:
+  /*! \brief base address of the micro device memory region */
+  std::uintptr_t base_addr_;
+  /*! \brief top of the stack section */
+  DevPtr stack_top_;
+  /*! \brief socket used to communicate with the device through Tcl */
+  TclSocket socket_;
+
+  /*! \brief number of bytes in a word on the target device (64-bit) */
+  static const constexpr ssize_t kWordSize = 8;
+  // NOTE: OpenOCD will call any request larger than this constant an "absurd
+  // request".
+  /*! \brief maximum number of bytes allowed in a single memory transfer */
+  static const constexpr ssize_t kMemTransferLimit = 64000;
+  /*! \brief number of milliseconds to wait for function execution to halt */
+  static const constexpr int kWaitTime = 10000;
+};
+
+const std::shared_ptr<LowLevelDevice> OpenOCDLowLevelDeviceCreate(std::uintptr_t base_addr,
+                                                                  const std::string& server_addr,
+                                                                  int port) {
+  std::shared_ptr<LowLevelDevice> lld =
+      std::make_shared<OpenOCDLowLevelDevice>(base_addr, server_addr, port);
+  return lld;
+}
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/micro/tcl_socket.cc b/src/runtime/micro/tcl_socket.cc
new file mode 100644 (file)
index 0000000..5422599
--- /dev/null
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file tcl_socket.cc
+ */
+#include <string>
+
+#include "tcl_socket.h"
+
+namespace tvm {
+namespace runtime {
+
+TclSocket::TclSocket() {
+  tcp_socket_.Create();
+  tcp_socket_.SetKeepAlive(true);
+  reply_buf_.reserve(kReplyBufSize);
+}
+
+TclSocket::~TclSocket() {
+  tcp_socket_.Close();
+}
+
+void TclSocket::Connect(tvm::common::SockAddr addr) {
+  CHECK(tcp_socket_.Connect(addr)) << "failed to connect";
+}
+
+void TclSocket::SendCommand() {
+  cmd_builder_ << kCommandTerminateToken;
+  std::string full_cmd = cmd_builder_.str();
+  CHECK(tcp_socket_.Send(full_cmd.data(), full_cmd.length()) != -1)
+    << "failed to send command";
+  cmd_builder_.str(std::string());
+
+  reply_builder_.str(std::string());
+  char last_read = '\0';
+  // Receive from the socket until we reach a command terminator.
+  do {
+    ssize_t bytes_read;
+    // Recieve from the socket until it's drained.
+    do {
+      // Leave room at the end of `reply_buf` to tack on a null terminator.
+      bytes_read = tcp_socket_.Recv(reply_buf_.data(), kReplyBufSize - 1);
+      reply_buf_[bytes_read] = '\0';
+      reply_builder_ << reply_buf_.data();
+      // Update last read character.
+      last_read = reply_buf_[bytes_read - 1];
+    } while (bytes_read == kReplyBufSize - 1);
+    CHECK(bytes_read != -1) << "failed to read command reply";
+  } while (last_read != kCommandTerminateToken);
+  last_reply_ = reply_builder_.str();
+  CHECK_EQ(last_reply_[last_reply_.length()-1], kCommandTerminateToken)
+    << "missing command terminator";
+}
+
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/micro/tcl_socket.h b/src/runtime/micro/tcl_socket.h
new file mode 100644 (file)
index 0000000..80ce185
--- /dev/null
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file tcl_socket.h
+ * \brief TCP socket wrapper for communicating using Tcl commands
+ */
+#ifndef TVM_RUNTIME_MICRO_TCL_SOCKET_H_
+#define TVM_RUNTIME_MICRO_TCL_SOCKET_H_
+
+#include <string>
+#include <vector>
+
+#include "../../common/socket.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief TCP socket wrapper for communicating using Tcl commands
+ *
+ * Usage generally involves building a command using the `cmd_builder` stream
+ * interface, then sending the command with `SendCommand`, and if necessary,
+ * reading the reply.
+ */
+class TclSocket {
+ public:
+  /*!
+   * \brief constructor to create the socket
+   */
+  TclSocket();
+
+  /*!
+   * \brief destructor to close the socket connection
+   */
+  ~TclSocket();
+
+  /*!
+   * \brief open connection with server
+   * \param addr server address
+   */
+  void Connect(tvm::common::SockAddr addr);
+
+  /*
+   * \brief send the built command to the server and await a reply
+   *
+   * \return the reply
+   */
+  void SendCommand();
+
+  /*
+   * \return string stream for current command being built
+  */
+  std::ostringstream& cmd_builder() { return cmd_builder_; }
+
+  /*
+   * \return reply from most recently sent command
+  */
+  const std::string& last_reply() { return last_reply_; }
+
+ private:
+  /*! \brief underlying TCP socket being wrapped */
+  tvm::common::TCPSocket tcp_socket_;
+  /*! \brief buffer used to receive messages from the socket */
+  std::vector<uint8_t> reply_buf_;
+  /*! \brief string stream used to build current command */
+  std::ostringstream cmd_builder_;
+  /*! \brief string stream used to receive replies from sent commands */
+  std::ostringstream reply_builder_;
+  /*! \brief reply from most recently sent command */
+  std::string last_reply_;
+
+  /*! \brief character denoting the end of a Tcl command */
+  static const constexpr char kCommandTerminateToken = '\x1a';
+  /*! \brief size of the buffer used to receive messages (in bytes) */
+  static const constexpr size_t kReplyBufSize = 4096;
+};
+
+}  // namespace runtime
+}  // namespace tvm
+#endif  // TVM_RUNTIME_MICRO_TCL_SOCKET_H_
index 5161c68..1c09a7b 100644 (file)
@@ -44,6 +44,7 @@ def test_add():
             c.asnumpy(), a.asnumpy() + b.asnumpy())
     check_c()
 
+
 def test_add_pipeline():
     nn = 1024
     n = tvm.convert(nn)
@@ -95,6 +96,32 @@ def test_add_pipeline():
     with tvm.build_config(offset_factor=4):
         check_c()
 
+
+def test_reinterpret():
+    nn = 1024
+    n = tvm.convert(nn)
+    A = tvm.placeholder((n,), name='A', dtype="int32")
+    B = tvm.compute(A.shape, lambda *i: tvm.call_pure_intrin("float32", "reinterpret", A(*i)), name='B')
+    s = tvm.create_schedule(B.op)
+
+    def check_c():
+        mhost = tvm.build(s, [A, B], "c", name="reinterpret")
+        temp = util.tempdir()
+        path_dso = temp.relpath("temp.so")
+        mhost.export_library(path_dso)
+        m = tvm.module.load(path_dso)
+        fadd = m['reinterpret']
+        ctx = tvm.cpu(0)
+        n = nn
+        a = tvm.nd.array(np.random.randint(-2 ** 30, 2 ** 30, size=n).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
+        fadd(a, b)
+        tvm.testing.assert_allclose(
+            b.asnumpy(), a.asnumpy().view('float32'))
+    check_c()
+
+
 if __name__ == "__main__":
     test_add()
     test_add_pipeline()
+    test_reinterpret()
diff --git a/tests/python/unittest/test_codegen_c_host_fadd.py b/tests/python/unittest/test_codegen_c_host_fadd.py
deleted file mode 100644 (file)
index f5cde82..0000000
+++ /dev/null
@@ -1,140 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import tvm
-import numpy as np
-from tvm import relay
-from tvm.contrib import util
-
-def test_add():
-    nn = 1024
-    n = tvm.convert(nn)
-    A = tvm.placeholder((n,), name='A')
-    B = tvm.placeholder((n,), name='B')
-    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
-    s = tvm.create_schedule(C.op)
-
-    def check_c():
-        mhost = tvm.build(s, [A, B, C], "c", name="fadd")
-        temp = util.tempdir()
-        path_dso = temp.relpath("temp.so")
-        mhost.export_library(path_dso)
-        print(mhost.get_source())
-        m = tvm.module.load(path_dso)
-        fadd = m['fadd']
-        ctx = tvm.cpu(0)
-        # launch the kernel.
-        n = nn
-        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
-        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
-        fadd(a, b, c)
-        tvm.testing.assert_allclose(
-           c.asnumpy(), a.asnumpy() + b.asnumpy())
-    check_c()
-
-def test_relay_id():
-    # x = relay.var("x")
-    # f = relay.Function([x], x)
-    x = relay.var('x', shape=[])
-    func = relay.Function([x], x)
-    ttype = relay.TensorType([], dtype='float32')
-    relay.FuncType([ttype], ttype)
-    mod = relay.module.Module()
-    func_gvar = relay.GlobalVar("f")
-    mod[func_gvar] = func
-    print(mod)
-
-
-def test_add_pipeline():
-    nn = 1024
-    n = tvm.convert(nn)
-    A = tvm.placeholder((n,), name='A')
-    B = tvm.placeholder((n,), name='B')
-    AA = tvm.compute((n,), lambda *i: A(*i), name='A')
-    BB = tvm.compute((n,), lambda *i: B(*i), name='B')
-    T = tvm.compute(A.shape, lambda *i: AA(*i) + BB(*i), name='T')
-    C = tvm.compute(A.shape, lambda *i: T(*i), name='C')
-    s = tvm.create_schedule(C.op)
-    xo, xi = s[C].split(C.op.axis[0], factor=4)
-    xo1, xo2 = s[C].split(xo, factor=13)
-    s[C].parallel(xo2)
-    s[C].pragma(xo1, "parallel_launch_point")
-    s[C].pragma(xo2, "parallel_stride_pattern")
-    s[C].pragma(xo2, "parallel_barrier_when_finish")
-    s[C].vectorize(xi)
-
-    def check_c():
-        if not tvm.module.enabled("llvm"):
-            return
-        # Specifically allow offset to test codepath when offset is available
-        Ab = tvm.decl_buffer(
-            A.shape, A.dtype,
-            elem_offset=tvm.var('Aoffset'),
-            offset_factor=8,
-            name='A')
-        binds = {A : Ab}
-        # BUILD and invoke the kernel.
-        f1 = tvm.lower(s, [A,B,C], name="fadd_pipeline")
-        fsplits = [x for x in tvm.ir_pass.SplitHostDevice(f1)]
-        fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
-        mhost = tvm.codegen.build_module(fsplits[0], "c")
-        temp = util.tempdir()
-        path_dso = temp.relpath("temp.so")
-        mhost.export_library(path_dso)
-        m = tvm.module.load(path_dso)
-        fadd = m["fadd_pipeline"]
-        ctx = tvm.cpu(0)
-        # launch the kernel.
-        n = nn
-        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
-        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
-        fadd(a, b, c)
-        tvm.testing.assert_allclose(
-            c.asnumpy(), a.asnumpy() + b.asnumpy())
-
-    with tvm.build_config(offset_factor=4):
-        check_c()
-
-
-def test_reinterpret():
-    nn = 1024
-    n = tvm.convert(nn)
-    A = tvm.placeholder((n,), name='A', dtype="int32")
-    B = tvm.compute(A.shape, lambda *i: tvm.call_pure_intrin("float32", "reinterpret", A(*i)), name='B')
-    s = tvm.create_schedule(B.op)
-
-    def check_c():
-        mhost = tvm.build(s, [A, B], "c", name="reinterpret")
-        temp = util.tempdir()
-        path_dso = temp.relpath("temp.so")
-        mhost.export_library(path_dso)
-        m = tvm.module.load(path_dso)
-        fadd = m['reinterpret']
-        ctx = tvm.cpu(0)
-        n = nn
-        a = tvm.nd.array(np.random.randint(-2 ** 30, 2 ** 30, size=n).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
-        fadd(a, b)
-        tvm.testing.assert_allclose(
-            b.asnumpy(), a.asnumpy().view('float32'))
-    check_c()
-
-if __name__ == "__main__":
-    test_add()
-    test_add_pipeline()
-    test_reinterpret()
index 06461bd..f857ce7 100644 (file)
@@ -47,7 +47,9 @@ def create_micro_mod(c_mod, toolchain_prefix):
     """
     temp_dir = util.tempdir()
     lib_obj_path = temp_dir.relpath("dev_lib.obj")
-    c_mod.export_library(lib_obj_path, fcompile=tvm.micro.cross_compiler(toolchain_prefix=""))
+    c_mod.export_library(
+            lib_obj_path,
+            fcompile=tvm.micro.cross_compiler(toolchain_prefix=toolchain_prefix))
     micro_mod = tvm.module.load(lib_obj_path, "micro_dev")
     return micro_mod
 
@@ -78,6 +80,8 @@ def relay_micro_build(func, toolchain_prefix, params=None):
 
 
 # TODO(weberlo): Add example program to test scalar double/int TVMValue serialization.
+# TODO(weberlo): How can we test the OpenOCD device?  The CI would need to have OpenOCD
+# and Spike installed.
 
 def test_alloc():
     """Test tensor allocation on the device."""
@@ -207,6 +211,7 @@ def test_multiple_modules():
         tvm.testing.assert_allclose(
                 sub_result, x_in - 1.0)
 
+
 def test_interleave_sessions():
     """Test closing and reopening sessions."""
     if not tvm.module.enabled("micro_dev"):