--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+cmake_minimum_required(VERSION 3.2)
+project(tsim C CXX)
+
+set(TVM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../)
+set(VTA_DIR ${TVM_DIR}/vta)
+
+include_directories("${TVM_DIR}/include")
+include_directories("${TVM_DIR}/3rdparty/dlpack/include")
+include_directories("${TVM_DIR}/3rdparty/dmlc-core/include")
+include_directories("${TVM_DIR}/vta/src/dpi")
+
+set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden")
+set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden -std=c++11")
+
+if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND
+ CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
+ set(CMAKE_CXX_FLAGS "-faligned-new ${CMAKE_CXX_FLAGS}")
+endif()
+
+file(GLOB TSIM_SW_SRC src/driver.cc)
+list(APPEND TSIM_SW_SRC ${VTA_DIR}/src/vmem/virtual_memory.cc)
+list(APPEND TSIM_SW_SRC ${VTA_DIR}/src/dpi/module.cc)
+
+add_library(sw SHARED ${TSIM_SW_SRC})
+target_include_directories(sw PRIVATE ${VTA_DIR}/include ${VTA_DIR}/src)
+
+if(APPLE)
+ set_target_properties(sw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
+endif(APPLE)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+export PYTHONPATH:=$(PWD)/python:$(PYTHONPATH)
+
+BUILD_NAME = build
+build_dir = $(abspath .)/$(BUILD_NAME)
+
+default: chisel driver
+ python3 tests/python/chisel_accel.py serial
+
+serial:
+ python3 tests/python/chisel_accel.py serial
+
+parallel:
+ python3 tests/python/chisel_accel.py parallel
+
+driver: | $(build_dir)
+ cd $(build_dir) && cmake .. && make
+
+$(build_dir):
+ mkdir -p $@
+
+chisel:
+ make -C hardware/chisel
+
+clean:
+ -rm -rf $(build_dir)
+ make -C hardware/chisel clean
--- /dev/null
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+VTA TSIM Application
+======================
+Prior to this application, please take a look at `<tvm-root>/vta/apps/tsim_example` for installation
+This is an application that performs Bit Serial Multiplication for GEMM utilizing TSIM.
+
+**Bit Serial Multiplication for GEMM:**
+
+General Matrix Multiplications (GEMM), are mostly calculated by repeatly calculating the dot product for each pair of vectors.
+The dot product is calculated by summing every product of the vector pair.
+We approach this operation with slicing and shifting, like how basic multiplication works, each vector elements before we accumulate them.
+We can sufficiently reduce the cycles required to perform a gemm given that the data bit width is small. This GEMM application uses TSIM for future accerlerator prototypes.
+
+* Test Chisel3 backend with bit serial GEMM
+ * Go to `<tvm-root>/vta/apps/gemm`
+ * Run `make`
+
+* If you have already compiled chisel backend (i.e. ran `make`)
+ * Bit Serial test with another input set, run `make serial`
+ * Bit parallel test with another input set, run `make parallel`
+
+* Some steps for creating your own custom TSIM application
+ * Go to `<tvm-root>/vta/apps/gemm`
+ * Create custom circuit within `./hardware/chisel/src/scala.main/accel/Compute.scala`
+ * Map the according Registers in `./hardware/chisel/src/scala.main/accel/RegFile.scala`
+ * Create your test script
+ * Map the registers in `./src/driver.cc` and link it with both `RegFile.scala` and the test script
+ * Understanding of `<tvm-root>/vta/apps/tsim_example`, which performs add by one to a vector, is highly encouraged to create a more complex application
+
+* Some pointers
+ * Chisel3 tests in `<tvm-root>/vta/apps/gemm/tests/python`
+ * Chisel3 accelerator backend `<tvm-root>/vta/apps/gemm/hardware/chisel`
+ * Software C++ driver (backend) that handles the accelerator `<tvm-root>/vta/apps/gemm/src/driver.cc`
+ * Software Python driver (frontend) that handles the accelerator `<tvm-root>/vta/apps/gemm/python/accel`
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+ifeq (, $(shell which verilator))
+ $(error "No Verilator in $(PATH), consider doing apt-get install verilator")
+endif
+
+# Change VERILATOR_INC_DIR if Verilator is installed on a different location
+ifeq (, $(VERILATOR_INC_DIR))
+ ifeq (, $(wildcard /usr/local/share/verilator/include/*))
+ ifeq (, $(wildcard /usr/share/verilator/include/*))
+ $(error "Verilator include directory is not set properly")
+ else
+ VERILATOR_INC_DIR := /usr/share/verilator/include
+ endif
+ else
+ VERILATOR_INC_DIR := /usr/local/share/verilator/include
+ endif
+endif
+
+TOP = TestAccel
+BUILD_NAME = build
+USE_TRACE = 1
+LIBNAME = libhw
+
+vta_dir = $(abspath ../../../../)
+tvm_dir = $(abspath ../../../../../)
+build_dir = $(abspath .)/$(BUILD_NAME)
+verilator_build_dir = $(build_dir)/verilator
+chisel_build_dir = $(build_dir)/chisel
+
+verilator_opt = --cc
+verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
+verilator_opt += +define+RANDOMIZE_REG_INIT
+verilator_opt += +define+RANDOMIZE_MEM_INIT
+verilator_opt += --x-assign unique
+verilator_opt += --output-split 20000
+verilator_opt += --output-split-cfuncs 20000
+verilator_opt += --top-module ${TOP}
+verilator_opt += -Mdir ${verilator_build_dir}
+verilator_opt += -I$(chisel_build_dir)
+
+cxx_flags = -O2 -Wall -fPIC -shared
+cxx_flags += -fvisibility=hidden -std=c++11
+cxx_flags += -DVL_TSIM_NAME=V$(TOP)
+cxx_flags += -DVL_PRINTF=printf
+cxx_flags += -DVL_USER_FINISH
+cxx_flags += -DVM_COVERAGE=0
+cxx_flags += -DVM_SC=0
+cxx_flags += -Wno-sign-compare
+cxx_flags += -include V$(TOP).h
+cxx_flags += -I$(verilator_build_dir)
+cxx_flags += -I$(VERILATOR_INC_DIR)
+cxx_flags += -I$(VERILATOR_INC_DIR)/vltstd
+cxx_flags += -I$(vta_dir)/include
+cxx_flags += -I$(tvm_dir)/include
+cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include
+
+cxx_files = $(VERILATOR_INC_DIR)/verilated.cpp
+cxx_files += $(VERILATOR_INC_DIR)/verilated_dpi.cpp
+cxx_files += $(wildcard $(verilator_build_dir)/*.cpp)
+cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc
+
+ifneq ($(USE_TRACE), 0)
+ verilator_opt += --trace
+ cxx_flags += -DVM_TRACE=1
+ cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP).vcd
+ cxx_files += $(VERILATOR_INC_DIR)/verilated_vcd_c.cpp
+else
+ cxx_flags += -DVM_TRACE=0
+endif
+
+# The following is to be consistent with cmake
+ifeq ($(shell uname), Darwin)
+ lib_path = $(build_dir)/$(LIBNAME).dylib
+else
+ lib_path = $(build_dir)/$(LIBNAME).so
+endif
+
+default: lib
+
+lib: $(lib_path)
+$(lib_path): $(verilator_build_dir)/V$(TOP).cpp
+ g++ $(cxx_flags) $(cxx_files) -o $@
+
+verilator: $(verilator_build_dir)/V$(TOP).cpp
+$(verilator_build_dir)/V$(TOP).cpp: $(chisel_build_dir)/$(TOP).v
+ verilator $(verilator_opt) $<
+
+verilog: $(chisel_build_dir)/$(TOP).v
+$(chisel_build_dir)/$(TOP).v: install_vta_package
+ sbt 'test:runMain test.Elaborate --target-dir $(chisel_build_dir) --top-name $(TOP)'
+
+install_vta_package:
+ cd $(vta_dir)/hardware/chisel && sbt publishLocal
+
+clean:
+ -rm -rf $(build_dir) target project/target project/project
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+name := "accel"
+version := "0.1.0-SNAPSHOT"
+organization := "edu.washington.cs"
+
+def scalacOptionsVersion(scalaVersion: String): Seq[String] = {
+ Seq() ++ {
+ // If we're building with Scala > 2.11, enable the compile option
+ // switch to support our anonymous Bundle definitions:
+ // https://github.com/scala/bug/issues/10047
+ CrossVersion.partialVersion(scalaVersion) match {
+ case Some((2, scalaMajor: Long)) if scalaMajor < 12 => Seq()
+ case _ => Seq(
+ "-Xsource:2.11",
+ "-language:reflectiveCalls",
+ "-language:implicitConversions",
+ "-deprecation",
+ "-Xlint",
+ "-Ywarn-unused",
+ )
+ }
+ }
+}
+
+def javacOptionsVersion(scalaVersion: String): Seq[String] = {
+ Seq() ++ {
+ // Scala 2.12 requires Java 8. We continue to generate
+ // Java 7 compatible code for Scala 2.11
+ // for compatibility with old clients.
+ CrossVersion.partialVersion(scalaVersion) match {
+ case Some((2, scalaMajor: Long)) if scalaMajor < 12 =>
+ Seq("-source", "1.7", "-target", "1.7")
+ case _ =>
+ Seq("-source", "1.8", "-target", "1.8")
+ }
+ }
+}
+
+scalaVersion := "2.11.12"
+
+resolvers ++= Seq(
+ Resolver.sonatypeRepo("snapshots"),
+ Resolver.sonatypeRepo("releases"))
+
+libraryDependencies ++= Seq(
+ "edu.berkeley.cs" %% "chisel3" % "3.1.7",
+ "edu.washington.cs" %% "vta" % "0.1.0-SNAPSHOT",
+)
+
+scalacOptions ++= scalacOptionsVersion(scalaVersion.value)
+javacOptions ++= javacOptionsVersion(scalaVersion.value)
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+sbt.version = 1.1.1
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+logLevel := Level.Warn
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package accel
+
+import chisel3._
+import vta.dpi._
+
+/** Add-by-one accelerator.
+ *
+ * ___________ ___________
+ * | | | |
+ * | HostDPI | <--> | RegFile | <->|
+ * |_________| |_________| |
+ * |
+ * ___________ ___________ |
+ * | | | | |
+ * | MemDPI | <--> | Compute | <->|
+ * |_________| |_________|
+ *
+ */
+case class AccelConfig() {
+ val nCtrl = 1
+ val nECnt = 1
+ val nVals = 4
+ val nPtrs = 3
+ val regBits = 32
+ val ptrBits = 2*regBits
+}
+
+class Accel extends Module {
+ val io = IO(new Bundle {
+ val host = new VTAHostDPIClient
+ val mem = new VTAMemDPIMaster
+ })
+ implicit val config = AccelConfig()
+ val rf = Module(new RegFile)
+ val ce = Module(new Compute)
+ rf.io.host <> io.host
+ io.mem <> ce.io.mem
+ ce.io.launch := rf.io.launch
+ rf.io.finish := ce.io.finish
+ rf.io.ecnt <> ce.io.ecnt
+ ce.io.vals <> rf.io.vals
+ ce.io.ptrs <> rf.io.ptrs
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package accel
+
+import chisel3._
+import chisel3.util._
+import vta.dpi._
+
+/** Compute
+ *
+ * Bit Slice GEMM:
+ *
+ * 1. Wait for launch to be asserted
+ * 2. Issue 2 read request for 8-byte value at inp1_baddr address and inp2_baddr address
+ * 3. Wait for the value
+ * 4. Increment read-address for next value
+ * 5. Wait for sliced accumulator
+ * 6. Check if counter (cnt) is equal to length process,
+ otherwise goto step 2
+ * 7. Check if reset slice accumulator
+ * 8. Wait for overall accumulator
+ * 8. Issue a write request for 8-byte value at out_baddr address
+ */
+class Compute(implicit config: AccelConfig) extends Module {
+ val io = IO(new Bundle {
+ val launch = Input(Bool())
+ val finish = Output(Bool())
+ val ecnt = Vec(config.nECnt, ValidIO(UInt(config.regBits.W)))
+ val vals = Input(Vec(config.nVals, UInt(config.regBits.W)))
+ val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W)))
+ val mem = new VTAMemDPIMaster
+ })
+ val sIdle :: sReadAReq :: sReadAData :: sReadBReq :: sReadBData :: sWriteReq :: sWriteData :: Nil = Enum(7)
+ val state = RegInit(sIdle)
+ val shift = io.vals(0)
+ val length = io.vals(1)
+ val rstAccum = io.vals(2)
+ val startDot = io.vals(3)
+ val cycles = RegInit(0.U(config.regBits.W))
+ val reg1 = Reg(chiselTypeOf(io.mem.rd.bits))
+ val reg2 = Reg(chiselTypeOf(io.mem.rd.bits))
+ val cnt = Reg(UInt(config.regBits.W))
+ val raddr1 = Reg(UInt(config.ptrBits.W))
+ val raddr2 = Reg(UInt(config.ptrBits.W))
+ val waddr = Reg(UInt(config.ptrBits.W))
+
+ switch (state) {
+ is (sIdle) {
+ when (io.launch) {
+ state := sReadAReq
+ }
+ }
+ // Read
+ is (sReadAReq) {
+ state := sReadAData
+ }
+ is (sReadAData) {
+ when (io.mem.rd.valid) {
+ state := sReadBReq
+ }
+ }
+ is (sReadBReq) {
+ state := sReadBData
+ }
+ is (sReadBData) {
+ when (io.mem.rd.valid) {
+ state := sWriteReq
+ }
+ }
+ // Write
+ is (sWriteReq) {
+ state := sWriteData
+ }
+ is (sWriteData) {
+ when (cnt === (length - 1.U)) {
+ state := sIdle
+ } .otherwise {
+ state := sReadAReq
+ }
+ }
+ }
+
+ val last = state === sWriteData && cnt === (length - 1.U)
+
+ // cycle counter
+ when (state === sIdle) {
+ cycles := 0.U
+ } .otherwise {
+ cycles := cycles + 1.U
+ }
+
+ io.ecnt(0).valid := last
+ io.ecnt(0).bits := cycles
+
+ // calculate next address
+ when (state === sIdle) {
+ raddr1 := io.ptrs(0)
+ raddr2 := io.ptrs(1)
+ waddr := io.ptrs(2)
+ } .elsewhen (state === sWriteData) { // increment input array by 1-byte
+ raddr1 := raddr1 + 1.U
+ raddr2 := raddr2 + 1.U
+ waddr := waddr
+ }
+
+ // create request
+ io.mem.req.valid := state === sReadAReq | state === sReadBReq | state === sWriteReq
+ io.mem.req.opcode := state === sWriteReq
+ io.mem.req.len := 0.U // one-word-per-request
+ io.mem.req.addr := Mux(state === sReadAReq | state === sReadBReq, Mux(state === sReadAReq, raddr1, raddr2), waddr)
+
+ // read
+ when (state === sReadAData && io.mem.rd.valid) {
+ reg1 := io.mem.rd.bits(7, 0)
+ }
+
+ when (state === sReadBData && io.mem.rd.valid) {
+ reg2 := io.mem.rd.bits(7, 0)
+ }
+
+ io.mem.rd.ready := state === sReadAData | state === sReadBData
+
+
+ val sliceAccum = Module(new Accumulator(63))
+ val overallAccum = Module(new Accumulator(64))
+
+ sliceAccum.io.valid := state === sWriteReq // 2 inputs have been processed
+ sliceAccum.io.in := reg1 * reg2
+ sliceAccum.io.clear := startDot
+ overallAccum.io.clear := rstAccum
+ overallAccum.io.valid := last // last element has been processed
+ overallAccum.io.in := sliceAccum.io.sum << shift(7,0) // limit to 8 bits
+
+ // write
+ io.mem.wr.valid := overallAccum.io.ready
+ io.mem.wr.bits := overallAccum.io.sum
+
+
+ // count read/write
+ when (state === sIdle) {
+ cnt := 0.U
+ } .elsewhen (state === sWriteData) {
+ cnt := cnt + 1.U
+ }
+
+ io.finish := overallAccum.io.ready // data has been added
+}
+
+
+class Accumulator(dataBits: Int = 8) extends Module {
+ val io = IO(new Bundle {
+ val clear = Input(Bool())
+ val valid = Input(Bool())
+ val ready = Output(Bool())
+ val in = Input(UInt(dataBits.W))
+ val sum = Output(UInt((dataBits).W))
+ })
+
+ val reg = RegInit(0.U((dataBits).W))
+ val ready = RegNext(io.valid)
+ when (io.clear) {
+ reg := 0.U
+ } .elsewhen (io.valid) {
+ reg := reg + io.in
+ }
+ io.ready := ready
+ io.sum := reg
+}
+
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package accel
+
+import chisel3._
+import chisel3.util._
+import vta.dpi._
+
+/** Register File.
+ *
+ * Six 32-bit register file.
+ *
+ * -------------------------------
+ * Register description | addr
+ * -------------------------|-----
+ * Control status register | 0x00
+ * Cycle counter | 0x04
+ * Shift value | 0x08
+ * Vector length | 0x0c
+ * Reset Accumulator | 0x10
+ * Reset Dot Module | 0x14
+ * Input1 pointer lsb | 0x18
+ * Input1 pointer msb | 0x1c
+ * Input2 pointer lsb | 0x20
+ * Input2 pointer msb | 0x24
+ * Output pointer lsb | 0x28
+ * Output pointer msb | 0x2c
+ * -------------------------------
+
+ * ------------------------------
+ * Control status register | bit
+ * ------------------------------
+ * Launch | 0
+ * Finish | 1
+ * ------------------------------
+ */
+class RegFile(implicit config: AccelConfig) extends Module {
+ val io = IO(new Bundle {
+ val launch = Output(Bool())
+ val finish = Input(Bool())
+ val ecnt = Vec(config.nECnt, Flipped(ValidIO(UInt(config.regBits.W))))
+ val vals = Output(Vec(config.nVals, UInt(config.regBits.W)))
+ val ptrs = Output(Vec(config.nPtrs, UInt(config.ptrBits.W)))
+ val host = new VTAHostDPIClient
+ })
+ val sIdle :: sRead :: Nil = Enum(2)
+ val state = RegInit(sIdle)
+
+ switch (state) {
+ is (sIdle) {
+ when (io.host.req.valid && !io.host.req.opcode) {
+ state := sRead
+ }
+ }
+ is (sRead) {
+ state := sIdle
+ }
+ }
+
+ io.host.req.deq := state === sIdle & io.host.req.valid
+
+ val nTotal = config.nCtrl + config.nECnt + config.nVals + (2*config.nPtrs)
+ val reg = Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
+ val addr = Seq.tabulate(nTotal)(_ * 4)
+ val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
+ val eo = config.nCtrl
+ val vo = eo + config.nECnt
+ val po = vo + config.nVals
+
+ when (io.finish) {
+ reg(0) := "b_10".U
+ } .elsewhen (state === sIdle && io.host.req.valid &&
+ io.host.req.opcode && addr(0).U === io.host.req.addr) {
+ reg(0) := io.host.req.value
+ }
+
+ for (i <- 0 until config.nECnt) {
+ when (io.ecnt(i).valid) {
+ reg(eo + i) := io.ecnt(i).bits
+ } .elsewhen (state === sIdle && io.host.req.valid &&
+ io.host.req.opcode && addr(eo + i).U === io.host.req.addr) {
+ reg(eo + i) := io.host.req.value
+ }
+ }
+
+ for (i <- 0 until (config.nVals + (2*config.nPtrs))) {
+ when (state === sIdle && io.host.req.valid &&
+ io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
+ reg(vo + i) := io.host.req.value
+ }
+ }
+
+ val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))
+ when (state === sIdle && io.host.req.valid && !io.host.req.opcode) {
+ rdata := MuxLookup(io.host.req.addr, 0.U, reg_map)
+ }
+
+ io.host.resp.valid := state === sRead
+ io.host.resp.bits := rdata
+
+ io.launch := reg(0)(0)
+
+ for (i <- 0 until config.nVals) {
+ io.vals(i) := reg(vo + i)
+ }
+
+ for (i <- 0 until config.nPtrs) {
+ io.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i))
+ }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package test
+
+import chisel3._
+import chisel3.experimental.MultiIOModule
+import vta.dpi._
+import accel._
+
+/** VTA simulation shell.
+ *
+ * Instantiate Host and Memory DPI modules.
+ *
+ */
+class VTASimShell extends MultiIOModule {
+ val host = IO(new VTAHostDPIMaster)
+ val mem = IO(new VTAMemDPIClient)
+ val sim_clock = IO(Input(Clock()))
+ val sim_wait = IO(Output(Bool()))
+ val mod_sim = Module(new VTASimDPI)
+ val mod_host = Module(new VTAHostDPI)
+ val mod_mem = Module(new VTAMemDPI)
+ mod_mem.io.clock := clock
+ mod_mem.io.reset := reset
+ mod_mem.io.dpi <> mem
+ mod_host.io.clock := clock
+ mod_host.io.reset := reset
+ host <> mod_host.io.dpi
+ mod_sim.io.clock := sim_clock
+ mod_sim.io.reset := reset
+ sim_wait := mod_sim.io.dpi_wait
+}
+
+/** Test accelerator.
+ *
+ * Instantiate and connect the simulation-shell and the accelerator.
+ *
+ */
+class TestAccel extends MultiIOModule {
+ val sim_clock = IO(Input(Clock()))
+ val sim_wait = IO(Output(Bool()))
+ val sim_shell = Module(new VTASimShell)
+ val vta_accel = Module(new Accel)
+ sim_shell.sim_clock := sim_clock
+ sim_wait := sim_shell.sim_wait
+ sim_shell.mem <> vta_accel.io.mem
+ vta_accel.io.host <> sim_shell.host
+}
+
+/** Generate TestAccel as top module */
+object Elaborate extends App {
+ chisel3.Driver.execute(args, () => new TestAccel)
+}
--- /dev/null
+from . import tsim
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import ctypes
+import os.path as osp
+from sys import platform
+
+def get_ext():
+ """Return shared library extension"""
+ return ".dylib" if platform == "darwin" else ".so"
+
+def load_dll(dll):
+ """Load shared library
+
+ Parameters
+ ------------
+ dll : str
+ Path for shared library
+
+ Returns
+ ------------
+ The shared library
+ """
+ try:
+ return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
+ except OSError:
+ return []
+
+def load_sw():
+ """Load all software shared libraries"""
+ cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
+ sw_libname = "libsw" + get_ext()
+ sw_lib = osp.join(cur_path, "..", "build", sw_libname)
+ load_dll(sw_lib)
+
+def init(hw_backend):
+ """Init hardware and software shared library for accelerator
+
+ Parameters
+ ------------
+ hw_backend : str
+ Hardware backend can be verilog or chisel
+
+ """
+ cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
+ hw_libname = "libhw" + get_ext()
+ if hw_backend in ("verilog", "chisel"):
+ hw_lib = osp.join(cur_path, "..", "hardware", hw_backend, "build", hw_libname)
+ load_sw()
+ m = tvm.module.load(hw_lib, "vta-tsim")
+ f = tvm.get_global_func("tvm.vta.tsim.init")
+ f(m)
+
+def load_module():
+ """Return driver function"""
+ load_sw()
+ return tvm.get_global_func("tvm.vta.driver")
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+#include <vta/dpi/module.h>
+
+#include "vmem/virtual_memory.h"
+
+namespace vta {
+namespace driver {
+
+using vta::dpi::DPIModuleNode;
+using tvm::runtime::Module;
+
+class DPILoader {
+ public:
+ ~DPILoader() {
+ dpi_->SimResume();
+ dpi_->SimFinish();
+ }
+
+ void Init(Module module) {
+ mod_ = module;
+ dpi_ = this->Get();
+ dpi_->SimLaunch();
+ dpi_->SimWait();
+ }
+
+ DPIModuleNode* Get() {
+ return static_cast<DPIModuleNode*>(mod_.operator->());
+ }
+
+ static DPILoader* Global() {
+ static DPILoader inst;
+ return &inst;
+ }
+
+ // TVM module
+ Module mod_;
+ // DPI Module
+ DPIModuleNode* dpi_{nullptr};
+};
+
+class Device {
+ public:
+ Device() {
+ loader_ = DPILoader::Global();
+ }
+
+ uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) {
+ uint32_t cycles;
+ uint32_t length = inp1->shape[0];
+ size_t size1 = (inp1->dtype.bits >> 3) * length;
+ size_t size2 = (inp2->dtype.bits >> 3) * length;
+ size_t size3 = (64 >> 3);
+ inp1_ = this->MemAlloc(size1);
+ inp2_ = this->MemAlloc(size2);
+ out_ = this->MemAlloc(size3);
+ this->MemCopyFromHost(inp1_, inp1->data, size1);
+ this->MemCopyFromHost(inp2_, inp2->data, size2);
+ this->Init();
+ this->Launch(length, shiftVal, reset);
+ cycles = this->WaitForCompletion();
+ this->MemCopyToHost(out->data, out_, size3);
+ this->MemFree(inp1_);
+ this->MemFree(inp2_);
+ this->MemFree(out_);
+ return cycles;
+ }
+
+ private:
+ void Init() {
+ dpi_ = loader_->Get();
+ dpi_->SimResume();
+ }
+
+ void* MemAlloc(size_t size) {
+ void * addr = vta::vmem::VirtualMemoryManager::Global()->Alloc(size);
+ return reinterpret_cast<void*>(vta::vmem::VirtualMemoryManager::Global()->GetPhyAddr(addr));
+ }
+
+ void MemFree(void* buf) {
+ void * addr = vta::vmem::VirtualMemoryManager::Global()->GetAddr(reinterpret_cast<uint64_t>(buf));
+ vta::vmem::VirtualMemoryManager::Global()->Free(addr);
+ }
+
+ vta_phy_addr_t MemGetPhyAddr(void* buf) {
+ return reinterpret_cast<uint64_t>(reinterpret_cast<uint64_t*>(buf));
+ }
+
+ void MemCopyFromHost(void* dst, const void* src, size_t size) {
+ vta::vmem::VirtualMemoryManager::Global()->MemCopyFromHost(dst, src, size);
+ }
+
+ void MemCopyToHost(void* dst, const void* src, size_t size) {
+ vta::vmem::VirtualMemoryManager::Global()->MemCopyToHost(dst, src, size);
+ }
+
+ void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) {
+ dpi_->WriteReg(0x08, shiftVal);
+ dpi_->WriteReg(0x0c, length); // vector length
+ dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_));
+ dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_));
+ dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_));
+ dpi_->WriteReg(0x00, 0x1); // launch
+ dpi_->WriteReg(0x00, 0x0); // launch
+
+ if (reset == 1) {
+ dpi_->WriteReg(0x10, 0x1); // reset accum
+ dpi_->WriteReg(0x10, 0x0); // stop reset accum
+ }
+ dpi_->WriteReg(0x14, 0x1); // reset dot
+ dpi_->WriteReg(0x14, 0x0); // stop reset dot
+ }
+
+ uint32_t WaitForCompletion() {
+ uint32_t i, val;
+ for (i = 0; i < wait_cycles_; i++) {
+ val = dpi_->ReadReg(0x00);
+ if (val == 2) break; // finish
+ }
+ val = dpi_->ReadReg(0x04);
+ dpi_->SimWait();
+ return val;
+ }
+
+ // wait cycles
+ uint32_t wait_cycles_{100000000};
+ // DPI loader
+ DPILoader* loader_{nullptr};
+ // DPI Module
+ DPIModuleNode* dpi_{nullptr};
+ // input vm ptr
+ void* inp1_{nullptr};
+ void* inp2_{nullptr};
+ // output vm ptr
+ void* out_{nullptr};
+};
+
+using tvm::runtime::TVMRetValue;
+using tvm::runtime::TVMArgs;
+
+TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+ Module m = args[0];
+ DPILoader::Global()->Init(m);
+ });
+
+TVM_REGISTER_GLOBAL("tvm.vta.driver")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+ DLTensor* A = args[0];
+ DLTensor* B = args[1];
+ DLTensor* C = args[3];
+ Device dev_;
+ uint32_t cycles = dev_.Run(A, B, static_cast<int>(args[2]), C, static_cast<int>(args[4]));
+ *rv = static_cast<int>(cycles);
+ });
+
+} // namespace driver
+} // namespace vta
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import numpy as np
+import tsim
+import sys
+
+""" Vector Bit Slice and Pack Function
+Parameters
+----------
+A : Vector to be sliced and packed
+slice_width : slice width
+
+Returnsi
+---------
+C: 2d matrix where each cloumn (because of bit packing) represents each bit slice of A
+"""
+def slice(A, slice_width):
+ assert np.log2(slice_width) % 1 == 0, "only power of 2 is supported"
+ dtype = type(A[0])
+ row = 0
+ # currently only supports uint
+ if dtype is np.uint8: row = 8 // slice_width
+ elif dtype is np.uint16: row = 16 // slice_width
+ elif dtype is np.uint32: row = 32 // slice_width
+ elif dtype is np.uint64: row = 64 // slice_width
+ else: raise ValueError("datatype " + str(dtype) + "currently not supported")
+ if (row >= 8):
+ dtype = 'uint' + str(row)
+ else:
+ dtype = 'uint8'
+
+ C = np.zeros((row, len(A))).astype(dtype) # sliced and transform
+
+ # create mask
+ slice_mask = 2**(slice_width)-1
+ # slice and pack
+ for x in range(len(A)):
+ for y in range(row):
+ C[y][x] = (np.uint64(A[x]) >> np.uint64(slice_width * y)) & np.uint64(slice_mask)
+ return C
+
+""" Matrix Multiplication Function
+Parameters
+----------
+A : Matrix A
+B: Matrix B
+w_width : weight slice width
+a_width : activation slice width
+
+Returns
+---------
+C: result of A * B
+"""
+# A is a n*m matrix, B is a m*p matrix(not transposed yet)
+def matrix_multiply(A, B, w_width, a_width):
+ assert A.shape[1] == B.shape[0], "can't perform multiplication"
+ BT = B.transpose()
+ cycles = 0
+ C = np.zeros((A.shape[0], B.shape[1])).astype('uint64')
+ for i in range(A.shape[0]):
+ for j in range(B.shape[1]):
+ # C[i, j] = A[i].dot(BT[j])
+ A_sliced = slice(A[i], w_width)
+ B_sliced = slice(BT[j], a_width)
+
+ C[i, j] = compute(A_sliced, B_sliced, w_width, a_width)
+ test = test_accel(A_sliced, B_sliced, w_width, a_width)
+ cycles += test[1]
+ np.testing.assert_equal(C[i,j], A[i].astype('uint64').dot(BT[j]))
+ print("PASS SW serial & parallel")
+
+ np.testing.assert_equal(test[0], C[i, j])
+ print("PASS SW & HW bit serial")
+
+ np.testing.assert_equal(test[0], A[i].astype('uint64').dot(BT[j]))
+ print("PASS SW bit parallel & HW bit parallel")
+
+ print("result: ")
+ print(C)
+ print("ALL TESTS PASSED, cycles: " + str(cycles))
+ return C
+
+""" Software Verification Function"""
+# takes 2 matrix input (sliced and packed)
+def compute(A, B, w_width, a_width):
+ assert A.shape[1] == B.shape[1], "sliced shape not match"
+ # reset hardware accumulator
+ accum = 0
+ for x in range(A.shape[0]):
+ for y in range(B.shape[0]):
+ # hardware implementation
+ accum += np.uint64(A[x]).dot(np.uint64(B[y])) << np.uint64(x*w_width + y*a_width)
+ # get value from accumulator
+ return accum
+
+"""Testing Function for Dot Product"""
+def test_accel(A, B, w_width, a_width):
+ assert A.shape[1] == B.shape[1], "sliced shape not match"
+
+ dtype = A.dtype
+ ctx = tvm.cpu(0)
+ f = tsim.load_module()
+
+ a_arr = []
+ b_arr = []
+ for i in range(A.shape[0]):
+ list_a = np.zeros(A.shape[1]).astype(dtype)
+ for j in range(A.shape[1]):
+ list_a[j] = A[i][j]
+ a_arr.append(tvm.nd.array(list_a.astype(dtype), ctx))
+
+ for i in range(B.shape[0]):
+ list_b = np.zeros(B.shape[1]).astype(dtype)
+ for j in range(B.shape[1]):
+ list_b[j] = B[i][j]
+ b_arr.append(tvm.nd.array(list_b.astype(dtype), ctx))
+
+ cycles = 0
+
+ accum = tvm.nd.array(np.array([0]).astype("uint64"), ctx)
+ for i in range(len(a_arr)):
+ for j in range(len(b_arr)):
+ shift = np.uint8(i*w_width + j*a_width)
+ if i == 0 and j == 0:
+ cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(1)) # reset accumulator
+ else:
+ cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(0)) # no reset
+
+ return (accum.asnumpy()[0], cycles)
+
+""" Matrix Generator
+Parameters
+----------
+dtype : String, datatype generated (supports only uint)
+w_width : weight bit slices(needs to be less than actual bit width)
+a_width : activation bit slices(needs to be less than actual bit width)
+"""
+def top_test(dtype, w_width, a_width):
+
+ rmax = np.random.randint(256)
+ # random matrix generation (dimension up to 8)
+ rrow = np.random.randint(7) + 1
+ rclmn = np.random.randint(7) + 1
+ rrow2 = np.random.randint(7) + 1
+ A = np.random.randint(rmax, size=(rrow,rclmn)).astype(dtype)
+ B = np.random.randint(rmax, size=(rclmn,rrow2)).astype(dtype)
+
+ print("A: ")
+ print(A)
+ print("\n")
+ print("B: ")
+ print(B)
+ print("\n")
+ matrix_multiply(A, B, w_width, a_width)
+
+
+if __name__ == "__main__":
+ tsim.init("chisel")
+ for i in range(1):
+ # reg1 and reg2 bits in Compute.scala must be modified for slices greater than 8 bits
+ if sys.argv[1] == 'serial':
+ # generates a random uint8 GEMM with 2-bit(8/4) weight and 4-bit(8/2) activation
+ top_test("uint8",4, 2)
+ elif sys.argv[1] == 'parallel':
+ # generates a random uint8 GEMM with 8-bit weight and 8-bit activation (bit parallel)
+ top_test('uint8', 1, 1)