# Build ANTLR parser for Relay text format
set(USE_ANTLR OFF)
+
+# Build TSIM for VTA.
+# The VTA build rules test this with `NOT USE_VTA_TSIM STREQUAL "OFF"`, so any
+# value other than the literal string OFF turns the TSIM runtime sources on.
+set(USE_VTA_TSIM OFF)
find_library(__cma_lib NAMES cma PATH /usr/lib)
target_link_libraries(vta ${__cma_lib})
endif()
+
+  # TSIM support: enabled for any USE_VTA_TSIM value other than the literal "OFF".
+  if(NOT USE_VTA_TSIM STREQUAL "OFF")
+    # Make the VTA headers visible (directory-scoped include path).
+    include_directories("vta/include")
+    # file(GLOB) returns an empty list when module.cc does not exist, in which
+    # case nothing is appended to RUNTIME_SRCS below.
+    file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc)
+    # Compile the DPI module as part of the runtime sources.
+    list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS})
+  endif()
+
else()
message(STATUS "Cannot found python in env, VTA build is skipped..")
endif()
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+cmake_minimum_required(VERSION 3.2)
+project(tsim C CXX)
+
+# Root of the TVM checkout (three directories above this example) and the VTA
+# tree inside it. The path is normalised so derived paths carry no "../" or
+# doubled slashes.
+get_filename_component(TVM_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../.." ABSOLUTE)
+set(VTA_DIR ${TVM_DIR}/vta)
+
+# Headers needed by the driver and the simulation library.
+include_directories("${TVM_DIR}/include")
+include_directories("${TVM_DIR}/3rdparty/dlpack/include")
+include_directories("${TVM_DIR}/3rdparty/dmlc-core/include")
+include_directories("${TVM_DIR}/vta/src/dpi")
+
+# Project defaults first, then whatever the user or toolchain already supplied
+# (CFLAGS/CXXFLAGS, -DCMAKE_<LANG>_FLAGS=...). Previously those flags were
+# overwritten; now they are kept and, coming last, can override the defaults.
+set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden ${CMAKE_C_FLAGS}")
+set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden -std=c++11 ${CMAKE_CXX_FLAGS}")
+
+# Newer GCC releases get -faligned-new in addition to the flags above.
+if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND
+    CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
+  set(CMAKE_CXX_FLAGS "-faligned-new ${CMAKE_CXX_FLAGS}")
+endif()
+
+# Module rules
+# tsim.cmake builds the simulation library, driver.cmake the software driver.
+include(cmake/modules/tsim.cmake)
+include(cmake/modules/driver.cmake)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+export PYTHONPATH:=$(PWD)/python:$(PYTHONPATH)
+
+# Interpreter used both for the configuration query and for the test run, so
+# the two steps can no longer pick different Python installations.
+PYTHON ?= python3
+
+# Name of the build directory as reported by config.py. Simple assignment (:=)
+# runs the script once instead of on every expansion.
+BUILD_DIR := $(shell $(PYTHON) python/tsim/config.py --get-build-name)
+
+# An empty name would turn "cd $(BUILD_DIR)" into "cd" (i.e. $HOME), so stop early.
+ifeq ($(strip $(BUILD_DIR)),)
+$(error could not determine the build directory from python/tsim/config.py)
+endif
+
+TVM_DIR = $(abspath ../../../)
+
+TSIM_TARGET = verilog
+TSIM_TOP_NAME = TestAccel
+TSIM_BUILD_NAME = build
+
+# optional
+TSIM_TRACE_NAME = trace.vcd
+
+default: cmake run
+
+# None of these targets produce a file of the same name.
+.PHONY: default cmake run clean
+
+# Configure and build inside the build directory; $(MAKE) keeps flags and the
+# jobserver of the parent make.
+cmake: | $(BUILD_DIR)
+	cd $(BUILD_DIR) && cmake .. && $(MAKE)
+
+$(BUILD_DIR):
+	mkdir -p $@
+
+# Succeeds only if the test output contains PASS (exit status comes from grep).
+run:
+	$(PYTHON) tests/python/test_tsim.py | grep PASS
+
+clean:
+	-rm -rf $(BUILD_DIR)
--- /dev/null
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+VTA TSIM Installation
+======================
+
+*TSIM* is a cycle-accurate hardware simulation environment that can be invoked and managed directly from TVM. It aims to enable cycle-accurate simulation of deep learning accelerators, including VTA.
+This simulation environment can be used on both OSX and Linux.
+Two dependencies are required to make *TSIM* work: [Verilator](https://www.veripool.org/wiki/verilator), and [sbt](https://www.scala-sbt.org/) for accelerators designed in [Chisel3](https://github.com/freechipsproject/chisel3).
+
+## OSX Dependencies
+
+Install `sbt` and `verilator` using [Homebrew](https://brew.sh/).
+
+```bash
+brew install verilator sbt
+```
+
+## Linux Dependencies
+
+Add the `sbt` repository to the package manager (Ubuntu).
+
+```bash
+echo "deb https://dl.bintray.com/sbt/debian /" | sudo tee -a /etc/apt/sources.list.d/sbt.list
+sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 2EE0EA64E40A89B84B2DF73499E82A75642AC823
+sudo apt-get update
+```
+
+Install `sbt` and `verilator`.
+
+```bash
+sudo apt install verilator sbt
+```
+
+## Setup in TVM
+
+1. Install `verilator` and `sbt` as described above
+2. Enable VTA TSIM by turning on the switch `USE_VTA_TSIM` in config.cmake
+3. Build tvm
+
+## How to run VTA TSIM examples
+
+There are two sample VTA accelerators (add-by-one) designed in Chisel3 and Verilog to show how *TSIM* works.
+These examples are located at `<tvm-root>/vta/apps/tsim_example`.
+
+* Instructions
+    * Open `<tvm-root>/vta/apps/tsim_example/python/tsim/config.json`
+    * Set `TARGET` to `verilog` or `chisel`, depending on which language backend you would like to test
+    * Go to `<tvm-root>/vta/apps/tsim_example`
+    * Run `make`
+
+* Some pointers
+ * Build cmake script for driver `<tvm-root>/vta/apps/tsim_example/cmake/modules/driver.cmake`
+ * Build cmake script for tsim `<tvm-root>/vta/apps/tsim_example/cmake/modules/tsim.cmake`
+ * Software driver that handles the VTA accelerator `<tvm-root>/vta/apps/tsim_example/src/driver.cc`
+ * VTA add-by-one accelerator (Verilog) `<tvm-root>/vta/apps/tsim_example/hardware/verilog`
+ * VTA add-by-one accelerator (Chisel) `<tvm-root>/vta/apps/tsim_example/hardware/chisel`
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Software driver, built as the shared library `driver` from src/driver.cc.
+# The source is named explicitly rather than globbed: a glob over one file
+# silently yields an empty list when the file is missing, which only shows up
+# later as an add_library() call without sources.
+set(TSIM_SW_SRC ${CMAKE_CURRENT_SOURCE_DIR}/src/driver.cc)
+add_library(driver SHARED ${TSIM_SW_SRC})
+target_include_directories(driver PRIVATE ${VTA_DIR}/include)
+
+# On macOS, let undefined symbols in the library be looked up when it is loaded.
+if(APPLE)
+  set_target_properties(driver PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
+endif()
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Builds the shared library `tsim`:
+#   1. read target/top/build/trace settings from python/tsim/config.py,
+#   2. for the chisel target, publish the VTA Chisel package and elaborate the
+#      design to Verilog with sbt; for the verilog target, collect the .v files,
+#   3. translate the Verilog to C++ with Verilator (at configure time),
+#   4. compile the generated C++ together with the Verilator runtime sources
+#      and the TSIM device source.
+# Skipped on MSVC and when Python or Verilator cannot be found.
+if(MSVC)
+  message(STATUS "TSIM build is skipped in Windows..")
+else()
+  find_program(PYTHON NAMES python python3 python3.6)
+  find_program(VERILATOR NAMES verilator)
+
+  if (VERILATOR AND PYTHON)
+
+    set(TSIM_CONFIG ${PYTHON} ${CMAKE_CURRENT_SOURCE_DIR}/python/tsim/config.py)
+
+    # Run config.py with one query flag and store its whitespace-stripped
+    # output in <out_var> in the caller's scope. Configuration stops with a
+    # clear message if the script exits with a non-zero status.
+    function(tsim_read_config flag out_var)
+      execute_process(COMMAND ${TSIM_CONFIG} ${flag}
+        OUTPUT_VARIABLE __tsim_value
+        RESULT_VARIABLE __tsim_retcode)
+      if (NOT __tsim_retcode STREQUAL "0")
+        message(FATAL_ERROR "[TSIM] config.py ${flag} failed (${__tsim_retcode})")
+      endif()
+      # Quoted: empty output must stay one (empty) argument for string(STRIP).
+      string(STRIP "${__tsim_value}" __tsim_value)
+      set(${out_var} "${__tsim_value}" PARENT_SCOPE)
+    endfunction()
+
+    tsim_read_config(--get-target TSIM_TARGET)
+    tsim_read_config(--get-top-name TSIM_TOP_NAME)
+    tsim_read_config(--get-build-name TSIM_BUILD_NAME)
+    tsim_read_config(--get-use-trace TSIM_USE_TRACE)
+    tsim_read_config(--get-trace-name TSIM_TRACE_NAME)
+
+    # Validate only after the values have been read; checking earlier compares
+    # an unset variable and can never trigger.
+    if ("${TSIM_TOP_NAME}" STREQUAL "")
+      message(FATAL_ERROR "TSIM_TOP_NAME should be defined")
+    endif()
+
+    if ("${TSIM_BUILD_NAME}" STREQUAL "")
+      message(FATAL_ERROR "TSIM_BUILD_NAME should be defined")
+    endif()
+
+    set(TSIM_BUILD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/${TSIM_BUILD_NAME})
+
+    if (TSIM_TARGET STREQUAL "chisel")
+
+      find_program(SBT NAMES sbt)
+
+      if (SBT)
+
+        # Install Chisel VTA package for DPI modules
+        set(VTA_CHISEL_DIR ${VTA_DIR}/hardware/chisel)
+
+        execute_process(WORKING_DIRECTORY ${VTA_CHISEL_DIR}
+          COMMAND ${SBT} publishLocal RESULT_VARIABLE RETCODE)
+
+        if (NOT RETCODE STREQUAL "0")
+          message(FATAL_ERROR "[TSIM] sbt failed to install VTA scala package")
+        endif()
+
+        # Chisel - Scala to Verilog compilation
+        set(TSIM_CHISEL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/hardware/chisel)
+        set(CHISEL_TARGET_DIR ${TSIM_BUILD_DIR}/chisel)
+        # Deliberately one quoted string: sbt receives the whole runMain
+        # invocation as a single command argument.
+        set(CHISEL_OPT "test:runMain test.Elaborate --target-dir ${CHISEL_TARGET_DIR} --top-name ${TSIM_TOP_NAME}")
+
+        execute_process(WORKING_DIRECTORY ${TSIM_CHISEL_DIR} COMMAND ${SBT} ${CHISEL_OPT} RESULT_VARIABLE RETCODE)
+
+        if (NOT RETCODE STREQUAL "0")
+          message(FATAL_ERROR "[TSIM] sbt failed to compile from Chisel to Verilog.")
+        endif()
+
+        # The Verilog files exist only after the sbt run above, hence the glob.
+        file(GLOB VERILATOR_RTL_SRC ${CHISEL_TARGET_DIR}/*.v)
+
+      else()
+        message(FATAL_ERROR "[TSIM] sbt should be installed for Chisel")
+      endif() # sbt
+
+    elseif (TSIM_TARGET STREQUAL "verilog")
+
+      set(VTA_VERILOG_DIR ${VTA_DIR}/hardware/chisel/src/main/resources/verilog)
+      set(TSIM_VERILOG_DIR ${CMAKE_CURRENT_SOURCE_DIR}/hardware/verilog)
+      file(GLOB VERILATOR_RTL_SRC ${VTA_VERILOG_DIR}/*.v ${TSIM_VERILOG_DIR}/*.v)
+
+    else()
+      message(STATUS "[TSIM] target language can be only verilog or chisel...")
+    endif() # TSIM_TARGET
+
+    if (TSIM_TARGET STREQUAL "chisel" OR TSIM_TARGET STREQUAL "verilog")
+
+      # Check if tracing can be enabled
+      if (NOT TSIM_USE_TRACE STREQUAL "OFF")
+        message(STATUS "[TSIM] Verilog enable tracing")
+      else()
+        message(STATUS "[TSIM] Verilator disable tracing")
+      endif()
+
+      # Verilator - Verilog to C++ compilation
+      set(VERILATOR_TARGET_DIR ${TSIM_BUILD_DIR}/verilator)
+      set(VERILATOR_OPT +define+RANDOMIZE_GARBAGE_ASSIGN +define+RANDOMIZE_REG_INIT)
+      list(APPEND VERILATOR_OPT +define+RANDOMIZE_MEM_INIT --x-assign unique)
+      list(APPEND VERILATOR_OPT --output-split 20000 --output-split-cfuncs 20000)
+      list(APPEND VERILATOR_OPT --top-module ${TSIM_TOP_NAME} -Mdir ${VERILATOR_TARGET_DIR})
+      list(APPEND VERILATOR_OPT --cc ${VERILATOR_RTL_SRC})
+
+      if (NOT TSIM_USE_TRACE STREQUAL "OFF")
+        list(APPEND VERILATOR_OPT --trace)
+      endif()
+
+      execute_process(COMMAND ${VERILATOR} ${VERILATOR_OPT} RESULT_VARIABLE RETCODE)
+
+      if (NOT RETCODE STREQUAL "0")
+        message(FATAL_ERROR "[TSIM] Verilator failed to compile Verilog to C++...")
+      endif()
+
+      # Build shared library (.so)
+      set(VTA_HW_DPI_DIR ${VTA_DIR}/hardware/dpi)
+
+      # Locate the Verilator runtime sources instead of assuming a
+      # /usr/local install: try $VERILATOR_ROOT, then the prefix the
+      # verilator executable lives in, then the usual system locations.
+      get_filename_component(__verilator_bin_dir "${VERILATOR}" DIRECTORY)
+      get_filename_component(__verilator_prefix "${__verilator_bin_dir}" DIRECTORY)
+      set(__verilator_hints "${__verilator_prefix}/share/verilator/include")
+      if (NOT "$ENV{VERILATOR_ROOT}" STREQUAL "")
+        list(INSERT __verilator_hints 0 "$ENV{VERILATOR_ROOT}/include")
+      endif()
+      find_path(VERILATOR_INC_DIR NAMES verilated.h
+        HINTS ${__verilator_hints}
+        PATHS /usr/local/share/verilator/include /usr/share/verilator/include
+        NO_DEFAULT_PATH)
+      if (NOT VERILATOR_INC_DIR)
+        message(FATAL_ERROR "[TSIM] could not find the Verilator include directory (verilated.h), set VERILATOR_ROOT")
+      endif()
+
+      set(VERILATOR_LIB_SRC ${VERILATOR_INC_DIR}/verilated.cpp ${VERILATOR_INC_DIR}/verilated_dpi.cpp)
+
+      if (NOT TSIM_USE_TRACE STREQUAL "OFF")
+        list(APPEND VERILATOR_LIB_SRC ${VERILATOR_INC_DIR}/verilated_vcd_c.cpp)
+      endif()
+
+      # Generated by the Verilator run above, so the file set is not known in advance.
+      file(GLOB VERILATOR_GEN_SRC ${VERILATOR_TARGET_DIR}/*.cpp)
+      file(GLOB VERILATOR_SRC ${VTA_HW_DPI_DIR}/tsim_device.cc)
+      add_library(tsim SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC})
+
+      set(VERILATOR_DEF VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
+      if (NOT TSIM_USE_TRACE STREQUAL "OFF")
+        list(APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=${TSIM_BUILD_DIR}/${TSIM_TRACE_NAME}.vcd)
+      else()
+        list(APPEND VERILATOR_DEF VM_TRACE=0)
+      endif()
+      target_compile_definitions(tsim PRIVATE ${VERILATOR_DEF})
+      # The generated top-level header is force-included into every source.
+      target_compile_options(tsim PRIVATE -Wno-sign-compare -include V${TSIM_TOP_NAME}.h)
+      target_include_directories(tsim PRIVATE ${VERILATOR_TARGET_DIR} ${VERILATOR_INC_DIR} ${VERILATOR_INC_DIR}/vltstd ${VTA_DIR}/include)
+
+      if(APPLE)
+        set_target_properties(tsim PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
+      endif()
+
+    endif() # TSIM_TARGET STREQUAL "chisel" OR TSIM_TARGET STREQUAL "verilog"
+
+  else()
+    message(STATUS "[TSIM] could not find Python or Verilator, build is skipped...")
+  endif() # VERILATOR
+endif() # MSVC
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# `clean` produces no file, so mark it phony; otherwise a file or directory
+# named "clean" would make the target look up to date.
+.PHONY: clean
+
+# Remove the sbt build outputs; the leading '-' ignores a failing rm.
+clean:
+	-rm -rf target project/target project/project
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+// Project identity: artifact name, version and organization.
+name := "accel"
+version := "0.1.0-SNAPSHOT"
+organization := "edu.washington.cs"
+
+/** Scala compiler options for the given Scala version.
+ *
+ *  Scala 2.x releases with a minor version below 12 get no extra options.
+ *  Every other version gets the option switch needed for anonymous Bundle
+ *  definitions (https://github.com/scala/bug/issues/10047) together with the
+ *  language-feature and lint flags listed below.
+ */
+def scalacOptionsVersion(scalaVersion: String): Seq[String] =
+  CrossVersion.partialVersion(scalaVersion) match {
+    case Some((2, minor: Long)) if minor < 12 =>
+      Nil
+    case _ =>
+      Seq(
+        "-Xsource:2.11",
+        "-language:reflectiveCalls",
+        "-language:implicitConversions",
+        "-deprecation",
+        "-Xlint",
+        "-Ywarn-unused"
+      )
+  }
+
+/** Java compiler options for the given Scala version.
+ *
+ *  Scala 2.x releases with a minor version below 12 keep generating Java 7
+ *  compatible code for old clients; everything else (Scala 2.12 requires
+ *  Java 8) targets Java 8.
+ */
+def javacOptionsVersion(scalaVersion: String): Seq[String] = {
+  val javaLevel = CrossVersion.partialVersion(scalaVersion) match {
+    case Some((2, minor: Long)) if minor < 12 => "1.7"
+    case _ => "1.8"
+  }
+  Seq("-source", javaLevel, "-target", javaLevel)
+}
+
+// Scala version this project is compiled with.
+scalaVersion := "2.11.12"
+
+// Additional repositories: Sonatype snapshots and releases.
+resolvers ++= Seq(
+  Resolver.sonatypeRepo("snapshots"),
+  Resolver.sonatypeRepo("releases"))
+
+// chisel3 plus the "vta" package from organization edu.washington.cs.
+// NOTE(review): the vta snapshot is presumably the locally published VTA
+// Chisel package — confirm it is installed before building.
+libraryDependencies ++= Seq(
+  "edu.berkeley.cs" %% "chisel3" % "3.1.7",
+  "edu.washington.cs" %% "vta" % "0.1.0-SNAPSHOT",
+)
+
+// Compiler flags are selected from scalaVersion by the two helpers above.
+scalacOptions ++= scalacOptionsVersion(scalaVersion.value)
+javacOptions ++= javacOptionsVersion(scalaVersion.value)
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+# sbt release used to build this project.
+sbt.version = 1.1.1
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+// Only log messages at warning level or above.
+logLevel := Level.Warn
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package accel
+
+import chisel3._
+import vta.dpi._
+
+/** Add-by-one accelerator.
+ *
+ * ___________ ___________
+ * | | | |
+ * | HostDPI | <--> | RegFile | <->|
+ * |_________| |_________| |
+ * |
+ * ___________ ___________ |
+ * | | | | |
+ * | MemDPI | <--> | Compute | <->|
+ * |_________| |_________|
+ *
+ */
+class Accel extends Module {
+  val io = IO(new Bundle {
+    val host = new VTAHostDPIClient
+    val mem = new VTAMemDPIMaster
+  })
+
+  // Sub-blocks: register file on the host side, compute engine on the memory side.
+  val rf = Module(new RegFile)
+  val ce = Module(new Compute)
+
+  // External DPI ports go straight to the block that owns them.
+  io.mem <> ce.io.mem
+  rf.io.host <> io.host
+
+  // Configuration and start signal: RegFile -> Compute.
+  ce.io.inp_baddr := rf.io.inp_baddr
+  ce.io.out_baddr := rf.io.out_baddr
+  ce.io.length := rf.io.length
+  ce.io.launch := rf.io.launch
+
+  // Completion: Compute -> RegFile.
+  rf.io.finish := ce.io.finish
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package accel
+
+import chisel3._
+import chisel3.util._
+import vta.dpi._
+
+/** Compute
+ *
+ * Add-by-one procedure:
+ *
+ * 1. Wait for launch to be asserted
+ * 2. Issue a read request for 8-byte value at inp_baddr address
+ * 3. Wait for the value
+ * 4. Issue a write request for 8-byte value at out_baddr address
+ * 5. Increment read-address and write-address for next value
+ * 6. Check if counter (cnt) is equal to length to assert finish,
+ * otherwise go to step 2.
+ */
+class Compute extends Module {
+  val io = IO(new Bundle {
+    val launch = Input(Bool())
+    val finish = Output(Bool())
+    val length = Input(UInt(32.W))
+    val inp_baddr = Input(UInt(64.W))
+    val out_baddr = Input(UInt(64.W))
+    val mem = new VTAMemDPIMaster
+  })
+  val sIdle :: sReadReq :: sReadData :: sWriteReq :: sWriteData :: Nil = Enum(5)
+  val state = RegInit(sIdle)
+  val reg = Reg(chiselTypeOf(io.mem.rd.bits))
+  val cnt = Reg(chiselTypeOf(io.length))
+  val raddr = Reg(chiselTypeOf(io.inp_baddr))
+  val waddr = Reg(chiselTypeOf(io.out_baddr))
+
+  // One decoded flag per state, shared by all the logic below.
+  val inIdle = state === sIdle
+  val inReadReq = state === sReadReq
+  val inReadData = state === sReadData
+  val inWriteReq = state === sWriteReq
+  val inWriteData = state === sWriteData
+
+  // True while the word being written is the last one of the transfer.
+  val lastWord = cnt === (io.length - 1.U)
+
+  // State transitions. The request states always advance after one cycle;
+  // sReadData waits for rd.valid; sWriteData either finishes or loops back.
+  when (inIdle) {
+    when (io.launch) {
+      state := sReadReq
+    }
+  } .elsewhen (inReadReq) {
+    state := sReadData
+  } .elsewhen (inReadData) {
+    when (io.mem.rd.valid) {
+      state := sWriteReq
+    }
+  } .elsewhen (inWriteReq) {
+    state := sWriteData
+  } .elsewhen (inWriteData) {
+    state := Mux(lastWord, sIdle, sReadReq)
+  }
+
+  // Addresses and word counter: loaded/cleared while idle, advanced once per
+  // written word (addresses step by 8 bytes).
+  when (inIdle) {
+    raddr := io.inp_baddr
+    waddr := io.out_baddr
+    cnt := 0.U
+  } .elsewhen (inWriteData) {
+    raddr := raddr + 8.U
+    waddr := waddr + 8.U
+    cnt := cnt + 1.U
+  }
+
+  // Memory request channel: opcode is high for the write request, and the
+  // address follows the kind of request being issued.
+  io.mem.req.valid := inReadReq || inWriteReq
+  io.mem.req.opcode := inWriteReq
+  io.mem.req.len := 0.U // one-word-per-request
+  io.mem.req.addr := Mux(inReadReq, raddr, waddr)
+
+  // Read channel: capture the incoming word plus one.
+  io.mem.rd.ready := inReadData
+  when (inReadData && io.mem.rd.valid) {
+    reg := io.mem.rd.bits + 1.U
+  }
+
+  // Write channel: send the incremented word.
+  io.mem.wr.valid := inWriteData
+  io.mem.wr.bits := reg
+
+  // Finish is raised during the write of the last word.
+  io.finish := inWriteData && lastWord
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package accel
+
+import chisel3._
+import chisel3.util._
+import vta.dpi._
+
+/** Register File.
+ *
+ * Six 32-bit register file.
+ *
+ * -------------------------------
+ * Register description | addr
+ * -------------------------|-----
+ * Control status register | 0x00
+ * Length value register | 0x04
+ * Input pointer lsb | 0x08
+ * Input pointer msb | 0x0c
+ * Output pointer lsb | 0x10
+ * Output pointer msb | 0x14
+ * -------------------------------
+
+ * ------------------------------
+ * Control status register | bit
+ * ------------------------------
+ * Launch | 0
+ * Finish | 1
+ * ------------------------------
+ */
+class RegFile extends Module {
+  val io = IO(new Bundle {
+    val launch = Output(Bool())
+    val finish = Input(Bool())
+    val length = Output(UInt(32.W))
+    val inp_baddr = Output(UInt(64.W))
+    val out_baddr = Output(UInt(64.W))
+    val host = new VTAHostDPIClient
+  })
+  val sIdle :: sRead :: Nil = Enum(2)
+  val state = RegInit(sIdle)
+
+  // A host request is accepted only while idle; opcode high is a write,
+  // opcode low is a read.
+  val idleReq = state === sIdle && io.host.req.valid
+  val wrReq = idleReq && io.host.req.opcode
+  val rdReq = idleReq && !io.host.req.opcode
+
+  // A read spends exactly one cycle in sRead to return its data.
+  when (rdReq) {
+    state := sRead
+  } .elsewhen (state === sRead) {
+    state := sIdle
+  }
+
+  io.host.req.deq := idleReq
+
+  // Six registers at byte addresses 0x00, 0x04, ..., 0x14.
+  val reg = Seq.fill(6)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
+  val addr = Seq.tabulate(6)(_ * 4)
+  val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
+
+  for ((r, a) <- reg zip addr) {
+    // Host write aimed at this register.
+    val hit = wrReq && a.U === io.host.req.addr
+    if (a == 0) {
+      // Control status register: finish takes priority over a host write and
+      // leaves only the finish bit set.
+      when (io.finish) {
+        r := "b_10".U
+      } .elsewhen (hit) {
+        r := io.host.req.value
+      }
+    } else {
+      when (hit) {
+        r := io.host.req.value
+      }
+    }
+  }
+
+  // Read data is latched when the read request is accepted; unmapped
+  // addresses read as zero.
+  val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))
+  when (rdReq) {
+    rdata := MuxLookup(io.host.req.addr, 0.U, reg_map)
+  }
+
+  io.host.resp.valid := state === sRead
+  io.host.resp.bits := rdata
+
+  // Outputs: launch is bit 0 of the control register; the 64-bit pointers
+  // are assembled as {msb register, lsb register}.
+  io.launch := reg(0)(0)
+  io.length := reg(1)
+  io.inp_baddr := Cat(reg(3), reg(2))
+  io.out_baddr := Cat(reg(5), reg(4))
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package test
+
+import chisel3._
+import chisel3.experimental.{RawModule, withClockAndReset}
+import vta.dpi._
+import accel._
+
+/** VTA simulation shell.
+ *
+ * Instantiate Host and Memory DPI modules.
+ *
+ */
+class VTASimShell extends RawModule {
+  val io = IO(new Bundle {
+    val clock = Input(Clock())
+    val reset = Input(Bool())
+    val host = new VTAHostDPIMaster
+    val mem = new VTAMemDPIClient
+  })
+  val host = Module(new VTAHostDPI)
+  val mem = Module(new VTAMemDPI)
+
+  // Host DPI module: clock, reset, then its DPI bundle.
+  host.io.clock := io.clock
+  host.io.reset := io.reset
+  io.host <> host.io.dpi
+
+  // Memory DPI module, wired the same way.
+  mem.io.clock := io.clock
+  mem.io.reset := io.reset
+  io.mem <> mem.io.dpi
+}
+
+/** Test accelerator.
+ *
+ * Instantiate and connect the simulation-shell and the accelerator.
+ *
+ */
+class TestAccel extends RawModule {
+  val clock = IO(Input(Clock()))
+  val reset = IO(Input(Bool()))
+
+  val sim_shell = Module(new VTASimShell)
+  // Accel is a Module, so it is built under the explicit top-level clock/reset.
+  val vta_accel = withClockAndReset(clock, reset) { Module(new Accel) }
+
+  // The shell takes clock and reset through its io bundle.
+  sim_shell.io.reset := reset
+  sim_shell.io.clock := clock
+
+  // Join the accelerator to the shell on both DPI interfaces.
+  sim_shell.io.mem <> vta_accel.io.mem
+  vta_accel.io.host <> sim_shell.io.host
+}
+
+/** Generate TestAccel as top module */
+object Elaborate extends App {
+  // Command-line arguments are handed unchanged to the Chisel driver, which
+  // elaborates a fresh TestAccel.
+  chisel3.Driver.execute(args, () => new TestAccel)
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/** Add-by-one accelerator.
+ *
+ * ___________ ___________
+ * | | | |
+ * | HostDPI | <--> | RegFile | <->|
+ * |_________| |_________| |
+ * |
+ * ___________ ___________ |
+ * | | | | |
+ * | MemDPI | <--> | Compute | <->|
+ * |_________| |_________|
+ *
+ */
+// Top level of the add-by-one accelerator: instantiates RegFile and Compute
+// and connects them through launch/finish/length/inp_baddr/out_baddr.
+module Accel #
+( parameter HOST_ADDR_BITS = 8,   // width of host_req_addr
+  parameter HOST_DATA_BITS = 32,  // width of host_req_value / host_resp_bits / length
+  parameter MEM_LEN_BITS = 8,     // width of mem_req_len
+  parameter MEM_ADDR_BITS = 64,   // width of mem_req_addr and the base addresses
+  parameter MEM_DATA_BITS = 64    // width of mem_wr_bits / mem_rd_bits
+)
+(
+  input                        clock,
+  input                        reset,
+
+  // Host interface (passed through to RegFile)
+  input                        host_req_valid,
+  input                        host_req_opcode,
+  input   [HOST_ADDR_BITS-1:0] host_req_addr,
+  input   [HOST_DATA_BITS-1:0] host_req_value,
+  output                       host_req_deq,
+  output                       host_resp_valid,
+  output  [HOST_DATA_BITS-1:0] host_resp_bits,
+
+  // Memory interface (passed through to Compute)
+  output                       mem_req_valid,
+  output                       mem_req_opcode,
+  output    [MEM_LEN_BITS-1:0] mem_req_len,
+  output   [MEM_ADDR_BITS-1:0] mem_req_addr,
+  output                       mem_wr_valid,
+  output   [MEM_DATA_BITS-1:0] mem_wr_bits,
+  input                        mem_rd_valid,
+  input    [MEM_DATA_BITS-1:0] mem_rd_bits,
+  output                       mem_rd_ready
+);
+
+  // Internal signals shared by the rf and comp instances.
+  logic                      launch;
+  logic                      finish;
+  logic [HOST_DATA_BITS-1:0] length;
+  logic  [MEM_ADDR_BITS-1:0] inp_baddr;
+  logic  [MEM_ADDR_BITS-1:0] out_baddr;
+
+  // Register file: owns the host request/response ports.
+  RegFile #
+  (
+    .MEM_ADDR_BITS(MEM_ADDR_BITS),
+    .HOST_ADDR_BITS(HOST_ADDR_BITS),
+    .HOST_DATA_BITS(HOST_DATA_BITS)
+  )
+  rf
+  (
+    .clock           (clock),
+    .reset           (reset),
+
+    .host_req_valid  (host_req_valid),
+    .host_req_opcode (host_req_opcode),
+    .host_req_addr   (host_req_addr),
+    .host_req_value  (host_req_value),
+    .host_req_deq    (host_req_deq),
+    .host_resp_valid (host_resp_valid),
+    .host_resp_bits  (host_resp_bits),
+
+    .launch          (launch),
+    .finish          (finish),
+    .length          (length),
+    .inp_baddr       (inp_baddr),
+    .out_baddr       (out_baddr)
+  );
+
+  // Compute engine: owns the memory request/read/write ports.
+  Compute #
+  (
+    .MEM_LEN_BITS(MEM_LEN_BITS),
+    .MEM_ADDR_BITS(MEM_ADDR_BITS),
+    .MEM_DATA_BITS(MEM_DATA_BITS),
+    .HOST_DATA_BITS(HOST_DATA_BITS)
+  )
+  comp
+  (
+    .clock           (clock),
+    .reset           (reset),
+
+    .mem_req_valid   (mem_req_valid),
+    .mem_req_opcode  (mem_req_opcode),
+    .mem_req_len     (mem_req_len),
+    .mem_req_addr    (mem_req_addr),
+    .mem_wr_valid    (mem_wr_valid),
+    .mem_wr_bits     (mem_wr_bits),
+    .mem_rd_valid    (mem_rd_valid),
+    .mem_rd_bits     (mem_rd_bits),
+    .mem_rd_ready    (mem_rd_ready),
+
+    .launch          (launch),
+    .finish          (finish),
+    .length          (length),
+    .inp_baddr       (inp_baddr),
+    .out_baddr       (out_baddr)
+  );
+
+endmodule
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/** Compute
+ *
+ * Add-by-one procedure:
+ *
+ * 1. Wait for launch to be asserted
+ * 2. Issue a read request for one memory word (MEM_DATA_BITS wide, i.e.
+ *    8 bytes with the default parameters) at the current read address,
+ *    which starts at inp_baddr
+ * 3. Wait for the value and add one to it
+ * 4. Issue a write request for the incremented word at the current write
+ *    address, which starts at out_baddr
+ * 5. Increment read-address and write-address for next value
+ * 6. Check if counter (cnt) is equal to length - 1 to assert finish,
+ *    otherwise go to step 2.
+ *
+ * Note: length is expected to be non-zero. With length == 0 the
+ * expression (length - 1) wraps to all ones, so the loop only terminates
+ * after cnt has counted up to that value.
+ */
+module Compute #
+(
+  parameter MEM_LEN_BITS = 8,
+  parameter MEM_ADDR_BITS = 64,
+  parameter MEM_DATA_BITS = 64,
+  parameter HOST_DATA_BITS = 32
+)
+(
+  input                         clock,
+  input                         reset,
+
+  output                        mem_req_valid,
+  output                        mem_req_opcode,
+  output [MEM_LEN_BITS-1:0]     mem_req_len,
+  output [MEM_ADDR_BITS-1:0]    mem_req_addr,
+  output                        mem_wr_valid,
+  output [MEM_DATA_BITS-1:0]    mem_wr_bits,
+  input                         mem_rd_valid,
+  input  [MEM_DATA_BITS-1:0]    mem_rd_bits,
+  output                        mem_rd_ready,
+
+  input                         launch,
+  output                        finish,
+  input  [HOST_DATA_BITS-1:0]   length,
+  input  [MEM_ADDR_BITS-1:0]    inp_baddr,
+  input  [MEM_ADDR_BITS-1:0]    out_baddr
+);
+
+  // Byte distance between two consecutive words; derived from the data
+  // width instead of being hard-coded so the stride follows MEM_DATA_BITS.
+  localparam WORD_BYTES = MEM_DATA_BITS / 8;
+
+  typedef enum logic [2:0] {IDLE,
+                            READ_REQ,
+                            READ_DATA,
+                            WRITE_REQ,
+                            WRITE_DATA} state_t;
+
+  state_t state_n, state_r;
+
+  // cnt is compared against length, so it uses the same width.
+  logic [HOST_DATA_BITS-1:0] cnt;
+  logic [MEM_DATA_BITS-1:0] data;
+  logic [MEM_ADDR_BITS-1:0] raddr;
+  logic [MEM_ADDR_BITS-1:0] waddr;
+
+  // High while the word currently being processed is the last one.
+  logic last_word;
+  assign last_word = (cnt == (length - 1'b1));
+
+  // state register
+  always_ff @(posedge clock) begin
+    if (reset) begin
+      state_r <= IDLE;
+    end else begin
+      state_r <= state_n;
+    end
+  end
+
+  // next-state logic
+  always_comb begin
+    state_n = IDLE;
+    case (state_r)
+      IDLE: begin
+        if (launch) begin
+          state_n = READ_REQ;
+        end
+      end
+
+      READ_REQ: begin
+        state_n = READ_DATA;
+      end
+
+      READ_DATA: begin
+        // stay here until the memory returns the requested word
+        if (mem_rd_valid) begin
+          state_n = WRITE_REQ;
+        end else begin
+          state_n = READ_DATA;
+        end
+      end
+
+      WRITE_REQ: begin
+        state_n = WRITE_DATA;
+      end
+
+      WRITE_DATA: begin
+        if (last_word) begin
+          state_n = IDLE;
+        end else begin
+          state_n = READ_REQ;
+        end
+      end
+
+      default: begin
+      end
+    endcase
+  end
+
+  // calculate next address: reloaded from the base addresses while idle,
+  // advanced by one word after every completed write
+  always_ff @(posedge clock) begin
+    if (reset | (state_r == IDLE)) begin
+      raddr <= inp_baddr;
+      waddr <= out_baddr;
+    end else if (state_r == WRITE_DATA) begin
+      raddr <= raddr + MEM_ADDR_BITS'(WORD_BYTES);
+      waddr <= waddr + MEM_ADDR_BITS'(WORD_BYTES);
+    end
+  end
+
+  // create request (opcode is high for a write request, low for a read)
+  assign mem_req_valid = (state_r == READ_REQ) | (state_r == WRITE_REQ);
+  assign mem_req_opcode = state_r == WRITE_REQ;
+  assign mem_req_len = 'd0; // one-word-per-request
+  assign mem_req_addr = (state_r == READ_REQ)? raddr : waddr;
+
+  // read: capture the incoming word already incremented by one
+  always_ff @(posedge clock) begin
+    if ((state_r == READ_DATA) & mem_rd_valid) begin
+      data <= mem_rd_bits + 1'b1;
+    end
+  end
+  assign mem_rd_ready = state_r == READ_DATA;
+
+  // write
+  assign mem_wr_valid = state_r == WRITE_DATA;
+  assign mem_wr_bits = data;
+
+  // count read/write
+  always_ff @(posedge clock) begin
+    if (reset | (state_r == IDLE)) begin
+      cnt <= 'd0;
+    end else if (state_r == WRITE_DATA) begin
+      cnt <= cnt + 1'b1;
+    end
+  end
+
+  // done when the last word is being written
+  assign finish = (state_r == WRITE_DATA) & last_word;
+endmodule
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/** Register File.
+ *
+ * Register file with six 32-bit registers.
+ *
+ * -------------------------------
+ * Register description    | addr
+ * -------------------------|-----
+ * Control status register | 0x00
+ * Length value register   | 0x04
+ * Input pointer lsb       | 0x08
+ * Input pointer msb       | 0x0c
+ * Output pointer lsb      | 0x10
+ * Output pointer msb      | 0x14
+ * -------------------------------
+
+ * ------------------------------
+ * Control status register | bit
+ * ------------------------------
+ * Launch                  | 0
+ * Finish                  | 1
+ * ------------------------------
+ *
+ * Host protocol: a request with host_req_opcode high is a write and is
+ * accepted in the same cycle; a request with host_req_opcode low is a read
+ * whose data is returned one cycle later with host_resp_valid asserted.
+ * All registers are cleared on reset.
+ */
+module RegFile #
+ (parameter MEM_ADDR_BITS = 64,
+  parameter HOST_ADDR_BITS = 8,
+  parameter HOST_DATA_BITS = 32
+)
+(
+  input                        clock,
+  input                        reset,
+
+  input                        host_req_valid,
+  input                        host_req_opcode,
+  input  [HOST_ADDR_BITS-1:0]  host_req_addr,
+  input  [HOST_DATA_BITS-1:0]  host_req_value,
+  output                       host_req_deq,
+  output                       host_resp_valid,
+  output [HOST_DATA_BITS-1:0]  host_resp_bits,
+
+  output                       launch,
+  input                        finish,
+  output [HOST_DATA_BITS-1:0]  length,
+  output [MEM_ADDR_BITS-1:0]   inp_baddr,
+  output [MEM_ADDR_BITS-1:0]   out_baddr
+);
+
+  typedef enum logic {IDLE, READ} state_t;
+  state_t state_n, state_r;
+
+  always_ff @(posedge clock) begin
+    if (reset) begin
+      state_r <= IDLE;
+    end else begin
+      state_r <= state_n;
+    end
+  end
+
+  // A read request moves to READ for exactly one cycle (the response
+  // cycle); writes complete without leaving IDLE.
+  always_comb begin
+    state_n = IDLE;
+    case (state_r)
+      IDLE: begin
+        if (host_req_valid & ~host_req_opcode) begin
+          state_n = READ;
+        end
+      end
+
+      READ: begin
+        state_n = IDLE;
+      end
+    endcase
+  end
+
+  // Requests are only dequeued while idle.
+  assign host_req_deq = (state_r == IDLE) ? host_req_valid : 1'b0;
+
+  logic [HOST_DATA_BITS-1:0] rf [5:0];
+
+  genvar i;
+  for (i = 0; i < 6; i++) begin
+    // Write enable for register i (byte address i*4). It is driven by a
+    // continuous assignment: an initializer on the declaration
+    // (`logic wen = ...`) is only evaluated once at time zero.
+    logic wen;
+    assign wen = (state_r == IDLE) & host_req_valid & host_req_opcode &
+                 (host_req_addr == i*4);
+
+    if (i == 0) begin
+      // Control/status register: finish from the accelerator takes
+      // priority over a host write and stores 2 (finish bit set, launch
+      // bit cleared).
+      always_ff @(posedge clock) begin
+        if (reset) begin
+          rf[i] <= 'd0; // launch must be deasserted after reset
+        end else if (finish) begin
+          rf[i] <= 'd2;
+        end else if (wen) begin
+          rf[i] <= host_req_value;
+        end
+      end
+    end else begin
+      always_ff @(posedge clock) begin
+        if (reset) begin
+          rf[i] <= 'd0;
+        end else if (wen) begin
+          rf[i] <= host_req_value;
+        end
+      end
+    end
+  end
+
+  // Read data is captured in the cycle the read request is accepted and
+  // presented in the following (READ) cycle; unmapped addresses read zero.
+  logic [HOST_DATA_BITS-1:0] rdata;
+  always_ff @(posedge clock) begin
+    if (reset) begin
+      rdata <= 'd0;
+    end else if ((state_r == IDLE) & host_req_valid & ~host_req_opcode) begin
+      case (host_req_addr)
+        'h00: rdata <= rf[0];
+        'h04: rdata <= rf[1];
+        'h08: rdata <= rf[2];
+        'h0c: rdata <= rf[3];
+        'h10: rdata <= rf[4];
+        'h14: rdata <= rf[5];
+        default: rdata <= 'd0;
+      endcase
+    end
+  end
+
+  assign host_resp_valid = (state_r == READ);
+  assign host_resp_bits = rdata;
+
+  assign launch = rf[0][0];
+  assign length = rf[1];
+  assign inp_baddr = {rf[3], rf[2]};
+  assign out_baddr = {rf[5], rf[4]};
+
+endmodule
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/** Test accelerator.
+ *
+ * Instantiate host/memory DPI modules and connect them to the accelerator.
+ *
+ * The interface widths are defined once as localparams and passed to every
+ * instance, so the DPI modules and the accelerator cannot drift apart if a
+ * width is changed here.
+ */
+module TestAccel
+(
+  input clock,
+  input reset
+);
+
+  // host (register access) interface
+  localparam HOST_ADDR_BITS = 8;
+  localparam HOST_DATA_BITS = 32;
+
+  logic                      host_req_valid;
+  logic                      host_req_opcode;
+  logic [HOST_ADDR_BITS-1:0] host_req_addr;
+  logic [HOST_DATA_BITS-1:0] host_req_value;
+  logic                      host_req_deq;
+  logic                      host_resp_valid;
+  logic [HOST_DATA_BITS-1:0] host_resp_bits;
+
+  // memory interface
+  localparam MEM_LEN_BITS = 8;
+  localparam MEM_ADDR_BITS = 64;
+  localparam MEM_DATA_BITS = 64;
+
+  logic                     mem_req_valid;
+  logic                     mem_req_opcode;
+  logic [MEM_LEN_BITS-1:0]  mem_req_len;
+  logic [MEM_ADDR_BITS-1:0] mem_req_addr;
+  logic                     mem_wr_valid;
+  logic [MEM_DATA_BITS-1:0] mem_wr_bits;
+  logic                     mem_rd_valid;
+  logic [MEM_DATA_BITS-1:0] mem_rd_bits;
+  logic                     mem_rd_ready;
+
+  VTAHostDPI #
+  (
+    .ADDR_BITS(HOST_ADDR_BITS),
+    .DATA_BITS(HOST_DATA_BITS)
+  )
+  host
+  (
+    .clock          (clock),
+    .reset          (reset),
+
+    .dpi_req_valid  (host_req_valid),
+    .dpi_req_opcode (host_req_opcode),
+    .dpi_req_addr   (host_req_addr),
+    .dpi_req_value  (host_req_value),
+    .dpi_req_deq    (host_req_deq),
+    .dpi_resp_valid (host_resp_valid),
+    .dpi_resp_bits  (host_resp_bits)
+  );
+
+  VTAMemDPI #
+  (
+    .LEN_BITS(MEM_LEN_BITS),
+    .ADDR_BITS(MEM_ADDR_BITS),
+    .DATA_BITS(MEM_DATA_BITS)
+  )
+  mem
+  (
+    .clock          (clock),
+    .reset          (reset),
+
+    .dpi_req_valid  (mem_req_valid),
+    .dpi_req_opcode (mem_req_opcode),
+    .dpi_req_len    (mem_req_len),
+    .dpi_req_addr   (mem_req_addr),
+    .dpi_wr_valid   (mem_wr_valid),
+    .dpi_wr_bits    (mem_wr_bits),
+    .dpi_rd_valid   (mem_rd_valid),
+    .dpi_rd_bits    (mem_rd_bits),
+    .dpi_rd_ready   (mem_rd_ready)
+  );
+
+  Accel #
+  (
+    .HOST_ADDR_BITS(HOST_ADDR_BITS),
+    .HOST_DATA_BITS(HOST_DATA_BITS),
+    .MEM_LEN_BITS(MEM_LEN_BITS),
+    .MEM_ADDR_BITS(MEM_ADDR_BITS),
+    .MEM_DATA_BITS(MEM_DATA_BITS)
+  )
+  accel
+  (
+    .clock           (clock),
+    .reset           (reset),
+
+    .host_req_valid  (host_req_valid),
+    .host_req_opcode (host_req_opcode),
+    .host_req_addr   (host_req_addr),
+    .host_req_value  (host_req_value),
+    .host_req_deq    (host_req_deq),
+    .host_resp_valid (host_resp_valid),
+    .host_resp_bits  (host_resp_bits),
+
+    .mem_req_valid   (mem_req_valid),
+    .mem_req_opcode  (mem_req_opcode),
+    .mem_req_len     (mem_req_len),
+    .mem_req_addr    (mem_req_addr),
+    .mem_wr_valid    (mem_wr_valid),
+    .mem_wr_bits     (mem_wr_bits),
+    .mem_rd_valid    (mem_rd_valid),
+    .mem_rd_bits     (mem_rd_bits),
+    .mem_rd_ready    (mem_rd_ready)
+  );
+endmodule
--- /dev/null
+{
+ "TARGET" : "verilog",
+ "TOP_NAME" : "TestAccel",
+ "BUILD_NAME" : "build",
+ "USE_TRACE" : "off",
+ "TRACE_NAME" : "trace"
+}
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os.path as osp
+import sys
+import json
+import argparse
+
+cur = osp.abspath(osp.dirname(__file__))
+cfg = json.load(open(osp.join(cur, 'config.json')))
+
+def main():
+ """Main function"""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--get-target", action="store_true",
+ help="Get target language, i.e. verilog or chisel")
+ parser.add_argument("--get-top-name", action="store_true",
+ help="Get hardware design top name")
+ parser.add_argument("--get-build-name", action="store_true",
+ help="Get build folder name")
+ parser.add_argument("--get-use-trace", action="store_true",
+ help="Get use trace")
+ parser.add_argument("--get-trace-name", action="store_true",
+ help="Get trace filename")
+ args = parser.parse_args()
+
+ if len(sys.argv) == 1:
+ parser.print_help()
+ return
+
+ if args.get_target:
+ print(cfg['TARGET'])
+
+ if args.get_top_name:
+ print(cfg['TOP_NAME'])
+
+ if args.get_build_name:
+ print(cfg['BUILD_NAME'])
+
+ if args.get_use_trace:
+ print(cfg['USE_TRACE'])
+
+ if args.get_trace_name:
+ print(cfg['TRACE_NAME'])
+
+# Run the command-line interface only when this file is executed as a
+# script, not when it is imported.
+if __name__ == "__main__":
+ main()
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import ctypes
+import json
+import os.path as osp
+from sys import platform
+
+def get_build_path():
+ curr_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
+ cfg = json.load(open(osp.join(curr_path, 'config.json')))
+ return osp.join(curr_path, "..", "..", cfg['BUILD_NAME'])
+
+def get_lib_ext():
+ if platform == "darwin":
+ ext = ".dylib"
+ else:
+ ext = ".so"
+ return ext
+
+def get_lib_path(name):
+ build_path = get_build_path()
+ ext = get_lib_ext()
+ libname = name + ext
+ return osp.join(build_path, libname)
+
+def _load_driver_lib():
+ lib = get_lib_path("libdriver")
+ try:
+ return [ctypes.CDLL(lib, ctypes.RTLD_GLOBAL)]
+ except OSError:
+ return []
+
+def load_driver():
+ return tvm.get_global_func("tvm.vta.driver")
+
+def load_tsim():
+ lib = get_lib_path("libtsim")
+ return tvm.module.load(lib, "vta-tsim")
+
+# Attempt to load libdriver when this module is imported and keep the
+# resulting handle list (empty if loading failed) in a module-level name.
+LIBS = _load_driver_lib()
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+#include <vta/dpi/module.h>
+
+namespace vta {
+namespace driver {
+
+// Return one 32-bit half of the address stored in `p`: bits [63:32] when
+// `upper` is true, bits [31:0] otherwise.
+uint32_t get_half_addr(void *p, bool upper) {
+  const uint64_t addr = reinterpret_cast<uint64_t>(p);
+  return static_cast<uint32_t>(upper ? (addr >> 32) : addr);
+}
+
+using vta::dpi::DPIModuleNode;
+using tvm::runtime::Module;
+
+/*!
+ * \brief Host-side driver for the add-by-one test accelerator.
+ *
+ * Programs the accelerator registers through the DPI module, launches the
+ * computation and polls the control/status register until completion.
+ */
+class TestDriver {
+ public:
+  TestDriver(Module module)
+      : dpi_(static_cast<DPIModuleNode*>(module.operator->())),
+        module_(module) {}
+
+  /*!
+   * \brief Run the accelerator over `length` words.
+   * \param length Number of words to process.
+   * \param inp Address of the input buffer.
+   * \param out Address of the output buffer.
+   * \return 0 when the accelerator reported completion, -1 when the status
+   *         register never showed the finish value within the polling budget.
+   */
+  int Run(uint32_t length, void* inp, void* out) {
+    uint32_t wait_cycles = 100000000;
+    this->Launch(wait_cycles, length, inp, out);
+    const bool finished = this->WaitForCompletion(wait_cycles);
+    // Always shut the simulation down, even after a timeout.
+    dpi_->Finish();
+    return finished ? 0 : -1;
+  }
+
+ private:
+  // Start the simulation and program the register file; the control
+  // register is written last because bit 0 launches the computation.
+  void Launch(uint32_t wait_cycles, uint32_t length, void* inp, void* out) {
+    dpi_->Launch(wait_cycles);
+    // write registers
+    dpi_->WriteReg(0x04, length);
+    dpi_->WriteReg(0x08, get_half_addr(inp, false));
+    dpi_->WriteReg(0x0c, get_half_addr(inp, true));
+    dpi_->WriteReg(0x10, get_half_addr(out, false));
+    dpi_->WriteReg(0x14, get_half_addr(out, true));
+    dpi_->WriteReg(0x00, 0x1);  // launch
+  }
+
+  // Poll the control/status register (0x00) up to `wait_cycles` times.
+  // Returns true as soon as it reads 2 (finish), false if the budget is
+  // exhausted first.
+  bool WaitForCompletion(uint32_t wait_cycles) {
+    for (uint32_t i = 0; i < wait_cycles; i++) {
+      uint32_t val = dpi_->ReadReg(0x00);
+      if (val == 2) return true;  // finish
+    }
+    return false;
+  }
+
+ private:
+  // Non-owning view of the node held by module_.
+  DPIModuleNode* dpi_;
+  // Keeps a reference to the module for the lifetime of the driver.
+  Module module_;
+};
+
+using tvm::runtime::TVMRetValue;
+using tvm::runtime::TVMArgs;
+
+// Register the driver entry point as the global function "tvm.vta.driver".
+// Arguments: (simulation module, input tensor, output tensor). The first
+// dimension of the input tensor is passed to the driver as the length.
+TVM_REGISTER_GLOBAL("tvm.vta.driver")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    Module sim_mod = args[0];
+    DLTensor* inp = args[1];
+    DLTensor* out = args[2];
+    TestDriver driver(sim_mod);
+    driver.Run(inp->shape[0], inp->data, out->data);
+  });
+
+} // namespace driver
+} // namespace vta
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import numpy as np
+from tsim.load import load_driver, load_tsim
+
+def test_tsim(i):
+ rmin = 1 # min vector size of 1
+ rmax = 64
+ n = np.random.randint(rmin, rmax)
+ ctx = tvm.cpu(0)
+ a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx)
+ b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx)
+ tsim = load_tsim()
+ f = load_driver()
+ f(tsim, a, b)
+ emsg = "[FAIL] test number:{} n:{}".format(i, n)
+ np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1, err_msg=emsg)
+ print("[PASS] test number:{} n:{}".format(i, n))
+
+# When executed as a script, run ten independent randomized tests.
+if __name__ == "__main__":
+ for i in range(10):
+ test_tsim(i)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# `clean` names an action, not a file. Declaring it phony makes the recipe
+# run even if a file or directory called "clean" exists.
+.PHONY: clean
+
+# Remove the build output directories; the leading '-' ignores rm errors.
+clean:
+	-rm -rf target project/target project/project
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+// Project coordinates: artifact name, version and organization.
+name := "vta"
+version := "0.1.0-SNAPSHOT"
+organization := "edu.washington.cs"
+
+/** Scala compiler options for the given Scala version.
+  *
+  * Scala 2.x versions below 2.12 get no extra options. Every other version
+  * gets the 2.11 source-compatibility switch needed by the anonymous Bundle
+  * definitions (https://github.com/scala/bug/issues/10047) plus the
+  * language-feature and lint flags.
+  */
+def scalacOptionsVersion(scalaVersion: String): Seq[String] = {
+  val isBefore212 = CrossVersion.partialVersion(scalaVersion) match {
+    case Some((2, minor)) => minor < 12
+    case _ => false
+  }
+  if (isBefore212) {
+    Seq.empty[String]
+  } else {
+    Seq(
+      "-Xsource:2.11",
+      "-language:reflectiveCalls",
+      "-language:implicitConversions",
+      "-deprecation",
+      "-Xlint",
+      "-Ywarn-unused"
+    )
+  }
+}
+
+/** Java compiler options for the given Scala version.
+  *
+  * Scala 2.12 requires Java 8, so source/target level 1.8 is used there;
+  * Scala 2.x versions below 2.12 keep generating Java 7 compatible code
+  * for compatibility with old clients.
+  */
+def javacOptionsVersion(scalaVersion: String): Seq[String] = {
+  val javaLevel = CrossVersion.partialVersion(scalaVersion) match {
+    case Some((2, minor)) if minor < 12 => "1.7"
+    case _ => "1.8"
+  }
+  Seq("-source", javaLevel, "-target", javaLevel)
+}
+
+// Scala version used to compile the project.
+scalaVersion := "2.11.12"
+
+// Additional repositories searched when resolving dependencies.
+resolvers ++= Seq(
+ Resolver.sonatypeRepo("snapshots"),
+ Resolver.sonatypeRepo("releases"))
+
+// Chisel3 is the only library dependency declared here.
+libraryDependencies ++= Seq(
+ "edu.berkeley.cs" %% "chisel3" % "3.1.7",
+)
+
+// Compiler options are derived from the Scala version by the helper
+// functions defined earlier in this file.
+scalacOptions ++= scalacOptionsVersion(scalaVersion.value)
+javacOptions ++= javacOptionsVersion(scalaVersion.value)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# sbt version used for this build.
+sbt.version = 1.1.1
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+// Set the sbt log level to Warn.
+logLevel := Level.Warn
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+// Host DPI module.
+//
+// Calls the imported DPI-C function VTAHostDPI on every rising clock edge
+// outside of reset. The request values returned by the function are driven
+// onto the dpi_req_* outputs one cycle later; dpi_req_deq and dpi_resp_*
+// are passed to the function as inputs. When the function sets its exit
+// output to 1 the simulation is finished.
+module VTAHostDPI #
+( parameter ADDR_BITS = 8,
+ parameter DATA_BITS = 32
+)
+(
+ input clock,
+ input reset,
+ output logic dpi_req_valid,
+ output logic dpi_req_opcode,
+ output logic [ADDR_BITS-1:0] dpi_req_addr,
+ output logic [DATA_BITS-1:0] dpi_req_value,
+ input dpi_req_deq,
+ input dpi_resp_valid,
+ input [DATA_BITS-1:0] dpi_resp_bits
+);
+
+ // Argument widths are fixed by the DPI types (byte = 8 bits, int = 32
+ // bits) regardless of ADDR_BITS/DATA_BITS.
+ import "DPI-C" function void VTAHostDPI
+ (
+ output byte unsigned exit,
+ output byte unsigned req_valid,
+ output byte unsigned req_opcode,
+ output byte unsigned req_addr,
+ output int unsigned req_value,
+ input byte unsigned req_deq,
+ input byte unsigned resp_valid,
+ input int unsigned resp_value
+ );
+
+ // Local types matching the widths of the DPI arguments.
+ typedef logic dpi1_t;
+ typedef logic [7:0] dpi8_t;
+ typedef logic [31:0] dpi32_t;
+
+ // Variables passed to / written by the DPI function.
+ dpi1_t __reset;
+ dpi8_t __exit;
+ dpi8_t __req_valid;
+ dpi8_t __req_opcode;
+ dpi8_t __req_addr;
+ dpi32_t __req_value;
+ dpi8_t __req_deq;
+ dpi8_t __resp_valid;
+ dpi32_t __resp_bits;
+
+ // One-cycle delayed copy of reset; together with reset it keeps the DPI
+ // call and the cycle counter in reset for one extra cycle.
+ always_ff @(posedge clock) begin
+ __reset <= reset;
+ end
+
+ // delaying outputs by one-cycle
+ // since verilator does not support delays
+ // (the 1-bit casts keep only bit 0 of the byte-wide DPI values)
+ always_ff @(posedge clock) begin
+ dpi_req_valid <= dpi1_t ' (__req_valid);
+ dpi_req_opcode <= dpi1_t ' (__req_opcode);
+ dpi_req_addr <= __req_addr;
+ dpi_req_value <= __req_value;
+ end
+
+ // Widen the 1-bit handshake inputs to the byte-wide DPI arguments.
+ assign __req_deq = dpi8_t ' (dpi_req_deq);
+ assign __resp_valid = dpi8_t ' (dpi_resp_valid);
+ assign __resp_bits = dpi_resp_bits;
+
+ // evaluate DPI function
+ // The DPI outputs are cleared during reset; otherwise they are written
+ // directly by the function call (blocking assignment semantics).
+ always_ff @(posedge clock) begin
+ if (reset | __reset) begin
+ __exit = 0;
+ __req_valid = 0;
+ __req_opcode = 0;
+ __req_addr = 0;
+ __req_value = 0;
+ end
+ else begin
+ VTAHostDPI(
+ __exit,
+ __req_valid,
+ __req_opcode,
+ __req_addr,
+ __req_value,
+ __req_deq,
+ __resp_valid,
+ __resp_bits);
+ end
+ end
+
+ // Counts clock cycles since reset was released; only used for the
+ // message printed when the simulation finishes.
+ logic [63:0] cycles;
+
+ always_ff @(posedge clock) begin
+ if (reset | __reset) begin
+ cycles <= 'd0;
+ end
+ else begin
+ cycles <= cycles + 1'b1;
+ end
+ end
+
+ // Stop the simulation once the DPI function reports exit == 1.
+ always_ff @(posedge clock) begin
+ if (__exit == 'd1) begin
+ $display("[DONE] at cycle:%016d", cycles);
+ $finish;
+ end
+ end
+
+endmodule
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+// Memory DPI module.
+//
+// Calls the imported DPI-C function VTAMemDPI on every rising clock edge
+// outside of reset. The request, write and rd_ready ports are passed to
+// the function as inputs; the read valid/value it returns are driven onto
+// dpi_rd_valid/dpi_rd_bits one cycle later.
+module VTAMemDPI #
+( parameter LEN_BITS = 8,
+ parameter ADDR_BITS = 64,
+ parameter DATA_BITS = 64
+)
+(
+ input clock,
+ input reset,
+ input dpi_req_valid,
+ input dpi_req_opcode,
+ input [LEN_BITS-1:0] dpi_req_len,
+ input [ADDR_BITS-1:0] dpi_req_addr,
+ input dpi_wr_valid,
+ input [DATA_BITS-1:0] dpi_wr_bits,
+ output logic dpi_rd_valid,
+ output logic [DATA_BITS-1:0] dpi_rd_bits,
+ input dpi_rd_ready
+);
+
+ // Argument widths are fixed by the DPI types (byte = 8 bits, longint =
+ // 64 bits) regardless of the module parameters.
+ import "DPI-C" function void VTAMemDPI
+ (
+ input byte unsigned req_valid,
+ input byte unsigned req_opcode,
+ input byte unsigned req_len,
+ input longint unsigned req_addr,
+ input byte unsigned wr_valid,
+ input longint unsigned wr_value,
+ output byte unsigned rd_valid,
+ output longint unsigned rd_value,
+ input byte unsigned rd_ready
+ );
+
+ // Local types matching the widths of the DPI arguments.
+ typedef logic dpi1_t;
+ typedef logic [7:0] dpi8_t;
+ typedef logic [31:0] dpi32_t;
+ typedef logic [63:0] dpi64_t;
+
+ // Variables passed to / written by the DPI function.
+ dpi1_t __reset;
+ dpi8_t __req_valid;
+ dpi8_t __req_opcode;
+ dpi8_t __req_len;
+ dpi64_t __req_addr;
+ dpi8_t __wr_valid;
+ dpi64_t __wr_value;
+ dpi8_t __rd_valid;
+ dpi64_t __rd_value;
+ dpi8_t __rd_ready;
+
+ // One-cycle delayed copy of reset; together with reset it keeps the DPI
+ // call in reset for one extra cycle.
+ always_ff @(posedge clock) begin
+ __reset <= reset;
+ end
+
+ // delaying outputs by one-cycle
+ // since verilator does not support delays
+ // (the 1-bit cast keeps only bit 0 of the byte-wide rd_valid)
+ always_ff @(posedge clock) begin
+ dpi_rd_valid <= dpi1_t ' (__rd_valid);
+ dpi_rd_bits <= __rd_value;
+ end
+
+ // Widen the 1-bit control inputs to the byte-wide DPI arguments.
+ assign __req_valid = dpi8_t ' (dpi_req_valid);
+ assign __req_opcode = dpi8_t ' (dpi_req_opcode);
+ assign __req_len = dpi_req_len;
+ assign __req_addr = dpi_req_addr;
+ assign __wr_valid = dpi8_t ' (dpi_wr_valid);
+ assign __wr_value = dpi_wr_bits;
+ assign __rd_ready = dpi8_t ' (dpi_rd_ready);
+
+ // evaluate DPI function
+ // The read outputs are cleared during reset; otherwise they are written
+ // directly by the function call (blocking assignment semantics).
+ always_ff @(posedge clock) begin
+ if (reset | __reset) begin
+ __rd_valid = 0;
+ __rd_value = 0;
+ end
+ else begin
+ VTAMemDPI(
+ __req_valid,
+ __req_opcode,
+ __req_len,
+ __req_addr,
+ __wr_valid,
+ __wr_value,
+ __rd_valid,
+ __rd_value,
+ __rd_ready);
+ end
+ end
+endmodule
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.dpi
+
+import chisel3._
+import chisel3.util._
+
+/** Host DPI parameters.
+ *
+ * Bit widths of the address and data fields of the host DPI interface.
+ */
+trait VTAHostDPIParams {
+ val dpiAddrBits = 8
+ val dpiDataBits = 32
+}
+
+/** Host master interface.
+ *
+ * This interface is typically used by the Host.
+ *
+ * The request fields are outputs and `deq` is an input; `resp` is a
+ * flipped ValidIO, i.e. the response is received on this side.
+ */
+class VTAHostDPIMaster extends Bundle with VTAHostDPIParams {
+ val req = new Bundle {
+ val valid = Output(Bool())
+ val opcode = Output(Bool())
+ val addr = Output(UInt(dpiAddrBits.W))
+ val value = Output(UInt(dpiDataBits.W))
+ val deq = Input(Bool())
+ }
+ val resp = Flipped(ValidIO(UInt(dpiDataBits.W)))
+}
+
+/** Host client interface.
+ *
+ * This interface is typically used by the Accelerator.
+ *
+ * Mirror image of VTAHostDPIMaster: the request fields are inputs, `deq`
+ * is an output and `resp` is driven from this side.
+ */
+class VTAHostDPIClient extends Bundle with VTAHostDPIParams {
+ val req = new Bundle {
+ val valid = Input(Bool())
+ val opcode = Input(Bool())
+ val addr = Input(UInt(dpiAddrBits.W))
+ val value = Input(UInt(dpiDataBits.W))
+ val deq = Output(Bool())
+ }
+ val resp = ValidIO(UInt(dpiDataBits.W))
+}
+
+/** Host DPI module.
+ *
+ * Wrapper for Host Verilog DPI module.
+ *
+ * A BlackBox whose implementation is the Verilog resource
+ * /verilog/VTAHostDPI.v; it exposes clock, reset and a host master
+ * interface named `dpi`.
+ */
+class VTAHostDPI extends BlackBox with HasBlackBoxResource {
+ val io = IO(new Bundle {
+ val clock = Input(Clock())
+ val reset = Input(Bool())
+ val dpi = new VTAHostDPIMaster
+ })
+ setResource("/verilog/VTAHostDPI.v")
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.dpi
+
+import chisel3._
+import chisel3.util._
+
+/** Memory DPI parameters.
+ *
+ * Bit widths of the length, address and data fields of the memory DPI
+ * interface.
+ */
+trait VTAMemDPIParams {
+ val dpiLenBits = 8
+ val dpiAddrBits = 64
+ val dpiDataBits = 64
+}
+
+/** Memory master interface.
+ *
+ * This interface is typically used by the Accelerator.
+ *
+ * The request fields and the write channel (`wr`) are driven from this
+ * side; the read channel (`rd`) is a flipped Decoupled, i.e. read data is
+ * received here.
+ */
+class VTAMemDPIMaster extends Bundle with VTAMemDPIParams {
+ val req = new Bundle {
+ val valid = Output(Bool())
+ val opcode = Output(Bool())
+ val len = Output(UInt(dpiLenBits.W))
+ val addr = Output(UInt(dpiAddrBits.W))
+ }
+ val wr = ValidIO(UInt(dpiDataBits.W))
+ val rd = Flipped(Decoupled(UInt(dpiDataBits.W)))
+}
+
+/** Memory client interface.
+ *
+ * This interface is typically used by the Host.
+ *
+ * Mirror image of VTAMemDPIMaster: the request fields are inputs, the
+ * write channel (`wr`) is flipped and the read channel (`rd`) is driven
+ * from this side.
+ */
+class VTAMemDPIClient extends Bundle with VTAMemDPIParams {
+ val req = new Bundle {
+ val valid = Input(Bool())
+ val opcode = Input(Bool())
+ val len = Input(UInt(dpiLenBits.W))
+ val addr = Input(UInt(dpiAddrBits.W))
+ }
+ val wr = Flipped(ValidIO(UInt(dpiDataBits.W)))
+ val rd = Decoupled(UInt(dpiDataBits.W))
+}
+
+/** Memory DPI module.
+ *
+ * Wrapper for Memory Verilog DPI module. This is a Chisel BlackBox: it only
+ * declares the ports below, and its body is the Verilog file registered
+ * with setResource.
+ */
+class VTAMemDPI extends BlackBox with HasBlackBoxResource {
+ val io = IO(new Bundle {
+ // Clock and reset are declared explicitly as ordinary input ports.
+ val clock = Input(Clock())
+ val reset = Input(Bool())
+ // Client side of the memory interface (receives requests, returns read data).
+ val dpi = new VTAMemDPIClient
+ })
+ // Verilog resource that provides the implementation of this black box.
+ setResource("/verilog/VTAMemDPI.v")
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <vta/dpi/tsim.h>
+
+#if VM_TRACE
+#include <verilated_vcd_c.h>
+#endif
+
+#if VM_TRACE
+#define STRINGIZE(x) #x
+#define STRINGIZE_VALUE_OF(x) STRINGIZE(x)
+#endif
+
+// State installed by VTADPIInit and consumed by the DPI trampolines below.
+// _ctx is passed back as the first argument of every callback invocation.
+static VTAContextHandle _ctx = nullptr;
+// Callback invoked by VTAMemDPI; must be set before the simulation runs.
+static VTAMemDPIFunc _mem_dpi = nullptr;
+// Callback invoked by VTAHostDPI; must be set before the simulation runs.
+static VTAHostDPIFunc _host_dpi = nullptr;
+
+// Host DPI entry point: forwards every argument, prefixed with the stored
+// context handle, to the host callback registered through VTADPIInit.
+void VTAHostDPI(dpi8_t* exit,
+                dpi8_t* req_valid,
+                dpi8_t* req_opcode,
+                dpi8_t* req_addr,
+                dpi32_t* req_value,
+                dpi8_t req_deq,
+                dpi8_t resp_valid,
+                dpi32_t resp_value) {
+  // The callback has to be installed before the first simulated cycle.
+  assert(_host_dpi != nullptr);
+  VTAHostDPIFunc const host_cb = _host_dpi;
+  host_cb(_ctx,
+          exit, req_valid, req_opcode, req_addr, req_value,
+          req_deq, resp_valid, resp_value);
+}
+
+// Memory DPI entry point: forwards every argument, prefixed with the stored
+// context handle, to the memory callback registered through VTADPIInit.
+void VTAMemDPI(dpi8_t req_valid,
+               dpi8_t req_opcode,
+               dpi8_t req_len,
+               dpi64_t req_addr,
+               dpi8_t wr_valid,
+               dpi64_t wr_value,
+               dpi8_t* rd_valid,
+               dpi64_t* rd_value,
+               dpi8_t rd_ready) {
+  // The callback has to be installed before the first simulated cycle.
+  assert(_mem_dpi != nullptr);
+  VTAMemDPIFunc const mem_cb = _mem_dpi;
+  mem_cb(_ctx,
+         req_valid, req_opcode, req_len, req_addr,
+         wr_valid, wr_value,
+         rd_valid, rd_value, rd_ready);
+}
+
+// Record the two callbacks and the opaque context that is handed back to
+// them on every invocation. Nothing is validated here; the trampolines
+// assert that the callbacks are non-null when they are first used.
+void VTADPIInit(VTAContextHandle handle,
+                VTAHostDPIFunc host_dpi,
+                VTAMemDPIFunc mem_dpi) {
+  _host_dpi = host_dpi;
+  _mem_dpi = mem_dpi;
+  _ctx = handle;
+}
+
+// Instantiate the Verilated design, hold it in reset for 10 clock cycles and
+// then toggle the clock until the design signals finish or the cycle budget
+// is exhausted.
+//
+// max_cycles: upper bound on the cycle counter. The counter also counts the
+//             10 reset cycles, so at most (max_cycles - 10) cycles run with
+//             reset de-asserted.
+// Returns 0 in all cases.
+//
+// When VM_TRACE is set, every clock edge is dumped to the VCD file named by
+// TSIM_TRACE_FILE, using timestamps 2*cycle (falling) and 2*cycle+1 (rising).
+int VTADPISim(uint64_t max_cycles) {
+  uint64_t trace_count = 0;
+
+#if VM_TRACE
+  // First cycle at which dumping starts; 0 means "trace everything".
+  uint64_t start = 0;
+#endif
+
+  VL_TSIM_NAME* top = new VL_TSIM_NAME;
+
+#if VM_TRACE
+  Verilated::traceEverOn(true);
+  VerilatedVcdC* tfp = new VerilatedVcdC;
+  top->trace(tfp, 99);
+  tfp->open(STRINGIZE_VALUE_OF(TSIM_TRACE_FILE));
+#endif
+
+  // reset
+  for (int i = 0; i < 10; i++) {
+    top->reset = 1;
+    top->clock = 0;
+    top->eval();
+#if VM_TRACE
+    if (trace_count >= start)
+      tfp->dump(static_cast<vluint64_t>(trace_count * 2));
+#endif
+    top->clock = 1;
+    top->eval();
+#if VM_TRACE
+    if (trace_count >= start)
+      tfp->dump(static_cast<vluint64_t>(trace_count * 2 + 1));
+#endif
+    trace_count++;
+  }
+  top->reset = 0;
+
+  // start simulation
+  while (!Verilated::gotFinish() && trace_count < max_cycles) {
+    top->clock = 0;
+    top->eval();
+#if VM_TRACE
+    if (trace_count >= start)
+      tfp->dump(static_cast<vluint64_t>(trace_count * 2));
+#endif
+    top->clock = 1;
+    top->eval();
+#if VM_TRACE
+    if (trace_count >= start)
+      tfp->dump(static_cast<vluint64_t>(trace_count * 2 + 1));
+#endif
+    trace_count++;
+  }
+
+#if VM_TRACE
+  tfp->close();
+  // The VCD writer was allocated with new above; release it so repeated
+  // calls to VTADPISim do not leak one writer per run.
+  delete tfp;
+#endif
+
+  delete top;
+
+  return 0;
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef VTA_DPI_MODULE_H_
+#define VTA_DPI_MODULE_H_
+
+#include <tvm/runtime/module.h>
+#include <mutex>
+#include <queue>
+#include <condition_variable>
+#include <string>
+
+namespace vta {
+namespace dpi {
+
+/*!
+ * \brief DPI driver module for managing the accelerator
+ *
+ * Abstract interface; the concrete implementation is created by Load().
+ */
+class DPIModuleNode : public tvm::runtime::ModuleNode {
+ public:
+/*!
+ * \brief Launch accelerator until it finishes or reaches max_cycles
+ * \param max_cycles The maximum number of cycles to wait
+ */
+ virtual void Launch(uint64_t max_cycles) = 0;
+
+/*!
+ * \brief Write an accelerator register
+ * \param addr The register address
+ * \param value The register value
+ */
+ virtual void WriteReg(int addr, uint32_t value) = 0;
+
+/*!
+ * \brief Read an accelerator register
+ * \param addr The register address
+ * \return The value read from the register
+ */
+ virtual uint32_t ReadReg(int addr) = 0;
+
+/*! \brief Kill or Exit() the accelerator */
+ virtual void Finish() = 0;
+
+/*!
+ * \brief Create a DPI module backed by a dynamic shared library
+ * \param dll_name Name or path of the shared library to load
+ * \return The runtime module wrapping the loaded library
+ */
+ static tvm::runtime::Module Load(std::string dll_name);
+};
+
+} // namespace dpi
+} // namespace vta
+#endif // VTA_DPI_MODULE_H_
+
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef VTA_DPI_TSIM_H_
+#define VTA_DPI_TSIM_H_
+
+#include <tvm/runtime/c_runtime_api.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*! \brief Scalar type used for DPI arguments of up to 8 bits */
+typedef unsigned char dpi8_t;
+
+/*! \brief Scalar type used for 32-bit DPI arguments (assumes unsigned int is 32 bits wide) */
+typedef unsigned int dpi32_t;
+
+/*! \brief Scalar type used for 64-bit DPI arguments (assumes unsigned long long is 64 bits wide) */
+typedef unsigned long long dpi64_t; // NOLINT(*)
+
+/*! \brief the context handle */
+typedef void* VTAContextHandle;
+
+/*!
+ * \brief Host DPI callback function that is invoked in VTAHostDPI.v every clock cycle
+ *
+ * Pointer arguments are outputs written by the callback; by-value arguments
+ * are inputs coming from the Accel. The callback does not return a value.
+ *
+ * \param self The context handle that was passed to VTADPIInit
+ * \param exit Host kill simulation
+ * \param req_valid Host has a valid request for read or write a register in Accel
+ * \param req_opcode Host request type, opcode=0 for read and opcode=1 for write
+ * \param req_addr Host request register address
+ * \param req_value Host request value to be written to a register
+ * \param req_deq Accel is ready to dequeue Host request
+ * \param resp_valid Accel has a valid response for Host
+ * \param resp_value Accel response value for Host
+ */
+typedef void (*VTAHostDPIFunc)(
+ VTAContextHandle self,
+ dpi8_t* exit,
+ dpi8_t* req_valid,
+ dpi8_t* req_opcode,
+ dpi8_t* req_addr,
+ dpi32_t* req_value,
+ dpi8_t req_deq,
+ dpi8_t resp_valid,
+ dpi32_t resp_value);
+
+/*!
+ * \brief Memory DPI callback function that is invoked in VTAMemDPI.v every clock cycle
+ *
+ * Pointer arguments are outputs written by the callback; by-value arguments
+ * are inputs coming from the Accel.
+ *
+ * \param self The context handle that was passed to VTADPIInit
+ * \param req_valid Accel has a valid request for Host
+ * \param req_opcode Accel request type, opcode=0 (read) and opcode=1 (write)
+ * \param req_len Accel request length of size 8-byte and starts at 0
+ * \param req_addr Accel request base address
+ * \param wr_valid Accel has a valid value for Host
+ * \param wr_value Accel has a value to be written Host
+ * \param rd_valid Host has a valid value for Accel
+ * \param rd_value Host has a value to be read by Accel
+ * \param rd_ready Accel is ready to accept the value offered in rd_value
+ */
+typedef void (*VTAMemDPIFunc)(
+ VTAContextHandle self,
+ dpi8_t req_valid,
+ dpi8_t req_opcode,
+ dpi8_t req_len,
+ dpi64_t req_addr,
+ dpi8_t wr_valid,
+ dpi64_t wr_value,
+ dpi8_t* rd_valid,
+ dpi64_t* rd_value,
+ dpi8_t rd_ready);
+
+/*! \brief The type of VTADPIInit function pointer */
+typedef void (*VTADPIInitFunc)(VTAContextHandle handle,
+ VTAHostDPIFunc host_dpi,
+ VTAMemDPIFunc mem_dpi);
+
+
+/*! \brief The type of VTADPISim function pointer */
+typedef int (*VTADPISimFunc)(uint64_t max_cycles);
+
+/*!
+ * \brief Set Host and Memory DPI functions
+ * \param handle DPI Context handle
+ * \param host_dpi Host DPI function
+ * \param mem_dpi Memory DPI function
+ */
+TVM_DLL void VTADPIInit(VTAContextHandle handle,
+ VTAHostDPIFunc host_dpi,
+ VTAMemDPIFunc mem_dpi);
+
+/*!
+ * \brief Instantiate VTA design and generate clock/reset
+ * \param max_cycles The maximum number of simulation cycles
+ * \return 0 when the simulation loop has ended
+ */
+TVM_DLL int VTADPISim(uint64_t max_cycles);
+
+#ifdef __cplusplus
+}
+#endif
+#endif // VTA_DPI_TSIM_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <vta/dpi/module.h>
+#include <vta/dpi/tsim.h>
+#if defined(_WIN32)
+#include <windows.h>
+#else
+#include <dlfcn.h>
+#endif
+
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <condition_variable>
+
+namespace vta {
+namespace dpi {
+
+using namespace tvm::runtime;
+
+typedef void* DeviceHandle;
+
+// Register access request queued by the host for the accelerator.
+struct HostRequest {
+ // 0 for a register read, 1 for a register write.
+ uint8_t opcode;
+ // Register address.
+ uint8_t addr;
+ // Value to write; the read path pushes 0 here.
+ uint32_t value;
+};
+
+// Response returned by the accelerator for a host request.
+struct HostResponse {
+ // Value reported by the accelerator.
+ uint32_t value;
+};
+
+// Result of one read-channel poll of the memory device.
+struct MemResponse {
+ // Non-zero when `value` holds real data from a pending read request.
+ uint8_t valid;
+ // Data word read from memory, or a filler pattern when not valid.
+ uint64_t value;
+};
+
+/*!
+ * \brief Minimal FIFO queue whose operations are serialized by a mutex.
+ * \tparam T The element type; must be assignable and movable or copyable.
+ */
+template <typename T>
+class ThreadSafeQueue {
+ public:
+  /*!
+   * \brief Append an item and wake up one thread blocked in WaitPop.
+   * \param item The element to enqueue. Taken by value so that it can be
+   *        moved into the queue (a const parameter would force a copy).
+   */
+  void Push(T item) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    queue_.push(std::move(item));
+    cond_.notify_one();
+  }
+
+  /*!
+   * \brief Block until the queue is non-empty, then remove the front element.
+   * \param item Receives the removed element.
+   */
+  void WaitPop(T* item) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    cond_.wait(lock, [this]{return !queue_.empty();});
+    *item = std::move(queue_.front());
+    queue_.pop();
+  }
+
+  /*!
+   * \brief Read the front element without blocking.
+   * \param item Receives the front element; left untouched if the queue is empty.
+   * \param pop When true the element is removed; when false it is only peeked.
+   * \return false if the queue is empty, true otherwise.
+   */
+  bool TryPop(T* item, bool pop) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (queue_.empty()) return false;
+    if (pop) {
+      *item = std::move(queue_.front());
+      queue_.pop();
+    } else {
+      // Peek: copy instead of move, otherwise the element that stays in the
+      // queue would be left in a moved-from state for the next reader.
+      *item = queue_.front();
+    }
+    return true;
+  }
+
+ private:
+  // Guards queue_; mutable so that it could also be taken in const methods.
+  mutable std::mutex mutex_;
+  std::queue<T> queue_;
+  // Signalled by Push, waited on by WaitPop.
+  std::condition_variable cond_;
+};
+
+// Host side of the register interface: a request queue, a response queue
+// and an exit flag.
+class HostDevice {
+ public:
+ // Enqueue a register request (opcode, address, value).
+ void PushRequest(uint8_t opcode, uint8_t addr, uint32_t value);
+ // Fetch the front request without blocking; it is removed only when `pop`
+ // is true. Returns false when no request is queued.
+ bool TryPopRequest(HostRequest* r, bool pop);
+ // Enqueue a response value.
+ void PushResponse(uint32_t value);
+ // Block until a response is available and remove it from the queue.
+ void WaitPopResponse(HostResponse* r);
+ // Set the exit flag.
+ void Exit();
+ // Return the current exit flag (0 until Exit() has been called).
+ uint8_t GetExitStatus();
+
+ private:
+ // Exit flag, guarded by mutex_.
+ uint8_t exit_{0};
+ mutable std::mutex mutex_;
+ // Pending requests.
+ ThreadSafeQueue<HostRequest> req_;
+ // Pending responses.
+ ThreadSafeQueue<HostResponse> resp_;
+};
+
+// Memory side of the DPI interface. Request addresses are used directly as
+// host pointers to 64-bit words; one read and one write transfer can be
+// outstanding at a time.
+class MemDevice {
+ public:
+ // Start a transfer of `len + 1` words at `addr`; opcode 1 selects the
+ // write channel, any other value the read channel.
+ void SetRequest(uint8_t opcode, uint64_t addr, uint32_t len);
+ // Return the current read word (if any); advance to the next word only
+ // when `ready` is 1.
+ MemResponse ReadData(uint8_t ready);
+ // Store `value` at the current write position and advance; ignored when
+ // no write transfer is pending.
+ void WriteData(uint64_t value);
+
+ private:
+ // Next word to read / write.
+ uint64_t* raddr_{0};
+ uint64_t* waddr_{0};
+ // Number of words left in the pending read / write transfer.
+ uint32_t rlen_{0};
+ uint32_t wlen_{0};
+ // Guards all of the fields above.
+ std::mutex mutex_;
+};
+
+// Bundle the three request fields into one entry and append it to the
+// request queue.
+void HostDevice::PushRequest(uint8_t opcode, uint8_t addr, uint32_t value) {
+  HostRequest request{opcode, addr, value};
+  req_.Push(request);
+}
+
+// Non-blocking fetch of the front request. `*r` is first filled with
+// recognizable placeholder values so that the caller sees them whenever
+// the queue turns out to be empty (return value false).
+bool HostDevice::TryPopRequest(HostRequest* r, bool pop) {
+  *r = HostRequest{0xad, 0xad, 0xbad};
+  const bool has_request = req_.TryPop(r, pop);
+  return has_request;
+}
+
+// Wrap the value in a HostResponse and append it to the response queue.
+void HostDevice::PushResponse(uint32_t value) {
+  resp_.Push(HostResponse{value});
+}
+
+// Block the calling thread until a response has been pushed, then move it
+// into *r and remove it from the response queue.
+void HostDevice::WaitPopResponse(HostResponse* r) {
+ resp_.WaitPop(r);
+}
+
+// Raise the exit flag under the device mutex.
+void HostDevice::Exit() {
+  std::lock_guard<std::mutex> guard(mutex_);
+  exit_ = 1;
+}
+
+// Read the exit flag under the device mutex.
+uint8_t HostDevice::GetExitStatus() {
+  std::lock_guard<std::mutex> guard(mutex_);
+  const uint8_t status = exit_;
+  return status;
+}
+
+// Arm a new transfer. `len` is zero-based, so `len + 1` words are moved.
+// The address is reinterpreted as a host pointer to 64-bit words.
+// opcode == 1 arms the write channel, anything else the read channel.
+void MemDevice::SetRequest(uint8_t opcode, uint64_t addr, uint32_t len) {
+  std::lock_guard<std::mutex> guard(mutex_);
+  uint64_t* const base = reinterpret_cast<uint64_t*>(addr);
+  const uint32_t words = len + 1;
+  const bool is_write = (opcode == 1);
+  if (is_write) {
+    waddr_ = base;
+    wlen_ = words;
+  } else {
+    raddr_ = base;
+    rlen_ = words;
+  }
+}
+
+// Report the current read word. While a read transfer is pending the
+// response is valid and carries the word at the read pointer; otherwise it
+// is invalid and carries a filler pattern. The pointer only advances when
+// the consumer signals ready == 1.
+MemResponse MemDevice::ReadData(uint8_t ready) {
+  std::lock_guard<std::mutex> guard(mutex_);
+  const bool pending = rlen_ > 0;
+  MemResponse resp;
+  resp.valid = pending;
+  if (pending) {
+    resp.value = *raddr_;
+  } else {
+    resp.value = 0xdeadbeefdeadbeef;
+  }
+  if (pending && ready == 1) {
+    ++raddr_;
+    --rlen_;
+  }
+  return resp;
+}
+
+// Store one word at the write pointer and advance it. Data arriving while
+// no write transfer is pending is dropped.
+void MemDevice::WriteData(uint64_t value) {
+  std::lock_guard<std::mutex> guard(mutex_);
+  if (wlen_ == 0) return;
+  *waddr_++ = value;
+  --wlen_;
+}
+
+/*!
+ * \brief Runtime module that loads a simulation shared library and drives it
+ *        through the VTADPIInit / VTADPISim entry points.
+ *
+ * The simulation runs in its own thread (see Launch). It calls back into
+ * this object through the static VTAHostDPI / VTAMemDPI trampolines, which
+ * are registered together with `this` as the context handle in Init.
+ */
+class DPIModule final : public DPIModuleNode {
+ public:
+  ~DPIModule() {
+    // A still-running simulation thread executes code from the library and
+    // uses this object; stop and join it before unloading. Destroying a
+    // joinable std::thread would otherwise call std::terminate.
+    if (vsim_thread_.joinable()) Finish();
+    if (lib_handle_) Unload();
+  }
+
+  const char* type_key() const final {
+    return "vta-tsim";
+  }
+
+  /*!
+   * \brief Look up a member function by name.
+   * \param name The function name; only "WriteReg" is available.
+   * \param sptr_to_self Shared pointer to this module; captured by the
+   *        returned function so the module outlives it.
+   * \return The packed function; any other name is a fatal error.
+   */
+  PackedFunc GetFunction(
+      const std::string& name,
+      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+    if (name == "WriteReg") {
+      return TypedPackedFunc<void(int, int)>(
+          [this, sptr_to_self](int addr, int value) {
+            this->WriteReg(addr, value);
+          });
+    } else {
+      LOG(FATAL) << "Member " << name << " does not exist";
+      return nullptr;
+    }
+  }
+
+  /*!
+   * \brief Load the library, register the DPI callbacks and resolve the
+   *        simulation entry point.
+   * \param name Name or path of the shared library.
+   */
+  void Init(const std::string& name) {
+    Load(name);
+    VTADPIInitFunc finit = reinterpret_cast<VTADPIInitFunc>(
+        GetSymbol("VTADPIInit"));
+    CHECK(finit != nullptr);
+    // `this` is handed back as `self` to the static trampolines below.
+    finit(this, VTAHostDPI, VTAMemDPI);
+    fvsim_ = reinterpret_cast<VTADPISimFunc>(GetSymbol("VTADPISim"));
+    CHECK(fvsim_ != nullptr);
+  }
+
+  /*!
+   * \brief Start the simulation in a background thread.
+   * \param max_cycles Cycle budget forwarded to VTADPISim.
+   */
+  void Launch(uint64_t max_cycles) {
+    CHECK(fvsim_ != nullptr) << "Init must be called before Launch";
+    // Assigning to a joinable std::thread would call std::terminate.
+    CHECK(!vsim_thread_.joinable())
+        << "Simulation is already running, call Finish first";
+    auto frun = [this, max_cycles]() {
+      (*fvsim_)(max_cycles);
+    };
+    vsim_thread_ = std::thread(frun);
+  }
+
+  /*! \brief Queue a register write request (opcode 1). */
+  void WriteReg(int addr, uint32_t value) {
+    host_device_.PushRequest(1, addr, value);
+  }
+
+  /*!
+   * \brief Queue a register read request (opcode 0) and block until the
+   *        response arrives.
+   * \return The value carried by the response.
+   */
+  uint32_t ReadReg(int addr) {
+    HostResponse r;
+    host_device_.PushRequest(0, addr, 0);
+    host_device_.WaitPopResponse(&r);
+    return r.value;
+  }
+
+  /*!
+   * \brief Raise the exit flag and wait for the simulation thread to end.
+   *
+   * Safe to call when no simulation was launched or after a previous
+   * Finish; in that case only the exit flag is set.
+   */
+  void Finish() {
+    host_device_.Exit();
+    if (vsim_thread_.joinable()) vsim_thread_.join();
+  }
+
+ protected:
+  // Simulation entry point resolved from the library in Init.
+  VTADPISimFunc fvsim_{nullptr};
+  HostDevice host_device_;
+  MemDevice mem_device_;
+  // Thread running fvsim_; joinable between Launch and Finish.
+  std::thread vsim_thread_;
+
+  // Host callback body: publish the exit flag and the front request (it is
+  // dequeued only when req_deq is set), and forward a valid response to
+  // the response queue.
+  void HostDPI(dpi8_t* exit,
+               dpi8_t* req_valid,
+               dpi8_t* req_opcode,
+               dpi8_t* req_addr,
+               dpi32_t* req_value,
+               dpi8_t req_deq,
+               dpi8_t resp_valid,
+               dpi32_t resp_value) {
+    // Stack storage: this runs on every callback invocation.
+    HostRequest r;
+    *exit = host_device_.GetExitStatus();
+    *req_valid = host_device_.TryPopRequest(&r, req_deq);
+    *req_opcode = r.opcode;
+    *req_addr = r.addr;
+    *req_value = r.value;
+    if (resp_valid) {
+      host_device_.PushResponse(resp_value);
+    }
+  }
+
+  // Memory callback body. Order matters: the read word is sampled first,
+  // then write data is stored, and only then is a new request armed.
+  void MemDPI(
+      dpi8_t req_valid,
+      dpi8_t req_opcode,
+      dpi8_t req_len,
+      dpi64_t req_addr,
+      dpi8_t wr_valid,
+      dpi64_t wr_value,
+      dpi8_t* rd_valid,
+      dpi64_t* rd_value,
+      dpi8_t rd_ready) {
+    MemResponse r = mem_device_.ReadData(rd_ready);
+    *rd_valid = r.valid;
+    *rd_value = r.value;
+    if (wr_valid) {
+      mem_device_.WriteData(wr_value);
+    }
+    if (req_valid) {
+      mem_device_.SetRequest(req_opcode, req_addr, req_len);
+    }
+  }
+
+  // Trampoline with the VTAHostDPIFunc signature; `self` is the DPIModule
+  // registered in Init.
+  static void VTAHostDPI(
+      VTAContextHandle self,
+      dpi8_t* exit,
+      dpi8_t* req_valid,
+      dpi8_t* req_opcode,
+      dpi8_t* req_addr,
+      dpi32_t* req_value,
+      dpi8_t req_deq,
+      dpi8_t resp_valid,
+      dpi32_t resp_value) {
+    static_cast<DPIModule*>(self)->HostDPI(
+        exit, req_valid, req_opcode, req_addr,
+        req_value, req_deq, resp_valid, resp_value);
+  }
+
+  // Trampoline with the VTAMemDPIFunc signature; `self` is the DPIModule
+  // registered in Init.
+  static void VTAMemDPI(
+      VTAContextHandle self,
+      dpi8_t req_valid,
+      dpi8_t req_opcode,
+      dpi8_t req_len,
+      dpi64_t req_addr,
+      dpi8_t wr_valid,
+      dpi64_t wr_value,
+      dpi8_t* rd_valid,
+      dpi64_t* rd_value,
+      dpi8_t rd_ready) {
+    static_cast<DPIModule*>(self)->MemDPI(
+        req_valid, req_opcode, req_len,
+        req_addr, wr_valid, wr_value,
+        rd_valid, rd_value, rd_ready);
+  }
+
+ private:
+  // Platform dependent handling.
+#if defined(_WIN32)
+  // library handle
+  HMODULE lib_handle_{nullptr};
+  // Load the library
+  void Load(const std::string& name) {
+    // LoadLibraryW takes a wide-character path.
+    std::wstring wname(name.begin(), name.end());
+    lib_handle_ = LoadLibraryW(wname.c_str());
+    CHECK(lib_handle_ != nullptr)
+        << "Failed to load dynamic shared library " << name;
+  }
+  // Resolve an exported symbol; returns nullptr when it is not found.
+  void* GetSymbol(const char* name) {
+    return reinterpret_cast<void*>(
+        GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
+  }
+  void Unload() {
+    FreeLibrary(lib_handle_);
+  }
+#else
+  // Library handle
+  void* lib_handle_{nullptr};
+  // load the library
+  void Load(const std::string& name) {
+    lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
+    CHECK(lib_handle_ != nullptr)
+        << "Failed to load dynamic shared library " << name
+        << " " << dlerror();
+  }
+  // Resolve an exported symbol; returns nullptr when it is not found.
+  void* GetSymbol(const char* name) {
+    return dlsym(lib_handle_, name);
+  }
+  void Unload() {
+    dlclose(lib_handle_);
+  }
+#endif
+};
+
+// Factory: build a DPIModule, initialize it from the named shared library
+// and hand it out wrapped in a runtime Module.
+Module DPIModuleNode::Load(std::string dll_name) {
+  auto node = std::make_shared<DPIModule>();
+  node->Init(dll_name);
+  return Module(node);
+}
+
+// Register a global function that builds a DPI module from a library name:
+// args[0] is forwarded to DPIModuleNode::Load and the resulting module is
+// returned. NOTE(review): the name presumably follows TVM's
+// "module.loadfile_<format>" loader convention - confirm.
+TVM_REGISTER_GLOBAL("module.loadfile_vta-tsim")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+ *rv = DPIModuleNode::Load(args[0]);
+ });
+} // namespace dpi
+} // namespace vta