# Build ANTLR parser for Relay text format
set(USE_ANTLR OFF)
-# Build TSIM for VTA
-set(USE_VTA_TSIM OFF)
-
# Whether use Relay debug mode
set(USE_RELAY_DEBUG OFF)
--use-cfg=${CMAKE_CURRENT_BINARY_DIR}/vta_config.json)
endif()
- execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE __vta_target)
- string(STRIP ${__vta_target} VTA_TARGET)
+ execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE VTA_TARGET OUTPUT_STRIP_TRAILING_WHITESPACE)
message(STATUS "Build VTA runtime with target: " ${VTA_TARGET})
add_library(vta SHARED ${VTA_RUNTIME_SRCS})
+ if(${VTA_TARGET} STREQUAL "tsim")
+ target_compile_definitions(vta PUBLIC USE_TSIM)
+ include_directories("vta/include")
+ file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc)
+ list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS})
+ endif()
+
target_include_directories(vta PUBLIC vta/include)
foreach(__def ${VTA_DEFINITIONS})
target_link_libraries(vta ${__cma_lib})
endif()
- if(NOT USE_VTA_TSIM STREQUAL "OFF")
- include_directories("vta/include")
- file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc)
- list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS})
- endif()
-
else()
message(STATUS "Cannot found python in env, VTA build is skipped..")
endif()
## Setup in TVM
1. Install `verilator` and `sbt` as described above
-2. Enable VTA TSIM by turning on the switch `USE_VTA_TSIM` in config.cmake
+2. Set the VTA TARGET to `tsim` in `<tvm-root>/vta/config/vta_config.json`
3. Build tvm
## How to run VTA TSIM examples
file(GLOB VERILATOR_SRC ${VTA_HW_DPI_DIR}/tsim_device.cc)
add_library(hw SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC})
- set(VERILATOR_DEF VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
+ set(VERILATOR_DEF VL_USER_FINISH VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
if (NOT TSIM_USE_TRACE STREQUAL "OFF")
list(APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=${TSIM_BUILD_DIR}/${TSIM_TRACE_NAME}.vcd)
else()
# specific language governing permissions and limitations
# under the License.
+# --- User-configurable knobs --------------------------------------------
+# CONFIG selects the Chisel elaboration config; TOP/TOP_TEST are the
+# hardware and testbench top modules; set USE_TRACE=1 for VCD tracing.
+CONFIG = DefaultF1Config
+TOP = VTA
+TOP_TEST = Test
+BUILD_NAME = build
+USE_TRACE = 0
+VTA_LIBNAME = libvta_hw
+
+config_test = $(TOP_TEST)$(CONFIG)
+vta_dir = $(abspath ../../)
+tvm_dir = $(abspath ../../../)
+# NOTE(review): hardcoded Verilator install path; consider deriving it
+# from `verilator --getenv VERILATOR_ROOT` instead.
+verilator_inc_dir = /usr/local/share/verilator/include
+verilator_build_dir = $(vta_dir)/$(BUILD_NAME)/verilator
+chisel_build_dir = $(vta_dir)/$(BUILD_NAME)/chisel
+
+# Verilator flags: C++ backend, deterministic X-assignment, and output
+# splitting to keep the generated compile units manageable.
+verilator_opt = --cc
+verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
+verilator_opt += +define+RANDOMIZE_REG_INIT
+verilator_opt += +define+RANDOMIZE_MEM_INIT
+verilator_opt += --x-assign unique
+verilator_opt += --output-split 20000
+verilator_opt += --output-split-cfuncs 20000
+verilator_opt += --top-module ${TOP_TEST}
+verilator_opt += -Mdir ${verilator_build_dir}
+verilator_opt += -I$(chisel_build_dir)
+
+# Flags for building the TSIM shared library around the verilated model;
+# the VL_* defines mirror the CMake TSIM build of the runtime.
+cxx_flags = -O2 -Wall -fPIC -shared
+cxx_flags += -fvisibility=hidden -std=c++11
+cxx_flags += -DVL_TSIM_NAME=V$(TOP_TEST)
+cxx_flags += -DVL_PRINTF=printf
+cxx_flags += -DVL_USER_FINISH
+cxx_flags += -DVM_COVERAGE=0
+cxx_flags += -DVM_SC=0
+cxx_flags += -Wno-sign-compare
+cxx_flags += -include V$(TOP_TEST).h
+cxx_flags += -I$(verilator_build_dir)
+cxx_flags += -I$(verilator_inc_dir)
+cxx_flags += -I$(verilator_inc_dir)/vltstd
+cxx_flags += -I$(vta_dir)/include
+cxx_flags += -I$(tvm_dir)/include
+cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include
+
+cxx_files = $(verilator_inc_dir)/verilated.cpp
+cxx_files += $(verilator_inc_dir)/verilated_dpi.cpp
+cxx_files += $(wildcard $(verilator_build_dir)/*.cpp)
+cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc
+
+ifneq ($(USE_TRACE), 0)
+  verilator_opt += --trace
+  cxx_flags += -DVM_TRACE=1
+  cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP_TEST).vcd
+  cxx_files += $(verilator_inc_dir)/verilated_vcd_c.cpp
+else
+  cxx_flags += -DVM_TRACE=0
+endif
+
+# Default target: build the TSIM hardware shared library.
+default: lib
+
+lib: $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so
+$(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so: $(verilator_build_dir)/V$(TOP_TEST).cpp
+	g++ $(cxx_flags) $(cxx_files) -o $@
+
+# Run Verilator on the generated Verilog to produce C++ model sources.
+verilator: $(verilator_build_dir)/V$(TOP_TEST).cpp
+$(verilator_build_dir)/V$(TOP_TEST).cpp: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
+	verilator $(verilator_opt) $<
+
+# Elaborate the Chisel design (via sbt) into Verilog.
+verilog: $(chisel_build_dir)/$(TOP).$(CONFIG).v
+$(chisel_build_dir)/$(TOP).$(CONFIG).v:
+	sbt 'runMain vta.$(CONFIG) --target-dir $(chisel_build_dir) --top-name $(TOP).$(CONFIG)'
+
+verilog_test: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
+$(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v:
+	sbt 'runMain vta.$(config_test) --target-dir $(chisel_build_dir) --top-name $(TOP_TEST).$(CONFIG)'
+
clean:
-rm -rf target project/target project/project
+
+cleanall:
+ -rm -rf $(vta_dir)/$(BUILD_NAME)
always_ff @(posedge clock) begin
if (__exit == 'd1) begin
- $display("[DONE] at cycle:%016d", cycles);
+ $display("[TSIM] Verilog $finish called at cycle:%016d", cycles);
$finish;
end
end
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** Compute.
+ *
+ * The compute unit is in charge of the following:
+ * - Loading micro-ops from memory (loadUop module)
+ * - Loading biases (acc) from memory (tensorAcc module)
+ * - Compute ALU instructions (tensorAlu module)
+ * - Compute GEMM instructions (tensorGemm module)
+ */
+class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val i_post = Vec(2, Input(Bool()))
+    val o_post = Vec(2, Output(Bool()))
+    val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
+    val uop_baddr = Input(UInt(mp.addrBits.W))
+    val acc_baddr = Input(UInt(mp.addrBits.W))
+    val vme_rd = Vec(2, new VMEReadMaster)
+    val inp = new TensorMaster(tensorType = "inp")
+    val wgt = new TensorMaster(tensorType = "wgt")
+    val out = new TensorMaster(tensorType = "out")
+    val finish = Output(Bool())
+  })
+  // FSM: sIdle waits for a dispatchable instruction, sSync retires a
+  // dependency-only (sync) instruction in one cycle, sExe waits until the
+  // selected submodule asserts done.
+  val sIdle :: sSync :: sExe :: Nil = Enum(3)
+  val state = RegInit(sIdle)
+
+  // Dependency semaphores: s(0) tracks the previous (load) stage,
+  // s(1) the next (store) stage.
+  val s = Seq.tabulate(2)(_ => Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
+
+  val loadUop = Module(new LoadUop)
+  val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
+  val tensorGemm = Module(new TensorGemm)
+  val tensorAlu = Module(new TensorAlu)
+
+  val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
+
+  // decode
+  val dec = Module(new ComputeDecode)
+  dec.io.inst := inst_q.io.deq.bits
+
+  // One-hot instruction class: bit0=LoadUop, bit1=LoadAcc, bit2=Gemm,
+  // bit3=Alu, bit4=Finish (Cat places isFinish in the MSB).
+  val inst_type = Cat(dec.io.isFinish,
+                      dec.io.isAlu,
+                      dec.io.isGemm,
+                      dec.io.isLoadAcc,
+                      dec.io.isLoadUop).asUInt
+
+  // Issue gating: an instruction may start only when the queue head is
+  // valid and any semaphore it pops has a credit available.
+  val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B)
+  val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
+  val start = snext & sprev
+  // Completion signal selected by the one-hot inst_type above.
+  val done =
+    MuxLookup(inst_type,
+      false.B, // default
+      Array(
+        "h_01".U -> loadUop.io.done,
+        "h_02".U -> tensorAcc.io.done,
+        "h_04".U -> tensorGemm.io.done,
+        "h_08".U -> tensorAlu.io.done,
+        "h_10".U -> true.B // Finish
+      )
+    )
+
+  // control
+  switch (state) {
+    is (sIdle) {
+      when (start) {
+        when (dec.io.isSync) {
+          state := sSync
+        } .elsewhen (inst_type.orR) {
+          state := sExe
+        }
+      }
+    }
+    is (sSync) {
+      state := sIdle
+    }
+    is (sExe) {
+      when (done) {
+        state := sIdle
+      }
+    }
+  }
+
+  // instructions
+  inst_q.io.enq <> io.inst
+  inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
+
+  // uop
+  loadUop.io.start := state === sIdle & start & dec.io.isLoadUop
+  loadUop.io.inst := inst_q.io.deq.bits
+  loadUop.io.baddr := io.uop_baddr
+  io.vme_rd(0) <> loadUop.io.vme_rd
+  // The uop read index is driven by whichever engine is active (gemm/alu).
+  loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)
+
+  // acc
+  tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
+  tensorAcc.io.inst := inst_q.io.deq.bits
+  tensorAcc.io.baddr := io.acc_baddr
+  tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
+  tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
+  io.vme_rd(1) <> tensorAcc.io.vme_rd
+
+  // gemm
+  // Read-data valids are qualified by the decoded type so only the active
+  // engine observes uop/acc/out responses.
+  tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
+  tensorGemm.io.inst := inst_q.io.deq.bits
+  tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
+  tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
+  tensorGemm.io.inp <> io.inp
+  tensorGemm.io.wgt <> io.wgt
+  tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isGemm
+  tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
+  tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm
+  tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits
+
+  // alu
+  tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
+  tensorAlu.io.inst := inst_q.io.deq.bits
+  tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
+  tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
+  tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isAlu
+  tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
+  tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu
+  tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits
+
+  // out
+  io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx)
+  io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
+
+  // semaphore
+  // swait pops a dependency credit when the issuing instruction requires
+  // one; o_post pushes a credit to the neighbor stage when an instruction
+  // retires (exec done or sync).
+  s(0).io.spost := io.i_post(0)
+  s(1).io.spost := io.i_post(1)
+  s(0).io.swait := dec.io.pop_prev & (state === sIdle & start)
+  s(1).io.swait := dec.io.pop_next & (state === sIdle & start)
+  io.o_post(0) := dec.io.push_prev & ((state === sExe & done) | (state === sSync))
+  io.o_post(1) := dec.io.push_next & ((state === sExe & done) | (state === sSync))
+
+  // finish
+  io.finish := state === sExe & done & dec.io.isFinish
+
+  // debug
+  if (debug) {
+    // start
+    when (state === sIdle && start) {
+      when (dec.io.isSync) {
+        printf("[Compute] start sync\n")
+      } .elsewhen (dec.io.isLoadUop) {
+        printf("[Compute] start load uop\n")
+      } .elsewhen (dec.io.isLoadAcc) {
+        printf("[Compute] start load acc\n")
+      } .elsewhen (dec.io.isGemm) {
+        printf("[Compute] start gemm\n")
+      } .elsewhen (dec.io.isAlu) {
+        printf("[Compute] start alu\n")
+      } .elsewhen (dec.io.isFinish) {
+        printf("[Compute] start finish\n")
+      }
+    }
+    // done
+    when (state === sSync) {
+      printf("[Compute] done sync\n")
+    }
+    when (state === sExe) {
+      when (done) {
+        when (dec.io.isLoadUop) {
+          printf("[Compute] done load uop\n")
+        } .elsewhen (dec.io.isLoadAcc) {
+          printf("[Compute] done load acc\n")
+        } .elsewhen (dec.io.isGemm) {
+          printf("[Compute] done gemm\n")
+        } .elsewhen (dec.io.isAlu) {
+          printf("[Compute] done alu\n")
+        } .elsewhen (dec.io.isFinish) {
+          printf("[Compute] done finish\n")
+        }
+      }
+    }
+  }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import vta.util.config._
+
+/** CoreConfig.
+ *
+ * This is one supported configuration for VTA. This file will
+ * be eventually filled out with class configurations that can be
+ * mixed/matched with Shell configurations for different backends.
+ */
+class CoreConfig extends Config((site, here, up) => {
+  // Overrides the CoreParams defaults: deeper scratchpads (2048/1024
+  // entries vs 512) and a larger instruction queue (512 vs 32).
+  case CoreKey => CoreParams(
+    batch = 1,
+    blockOut = 16,
+    blockIn = 16,
+    inpBits = 8,
+    wgtBits = 8,
+    uopBits = 32,
+    accBits = 32,
+    outBits = 8,
+    uopMemDepth = 2048,
+    inpMemDepth = 2048,
+    wgtMemDepth = 1024,
+    accMemDepth = 2048,
+    outMemDepth = 2048,
+    instQueueEntries = 512)
+})
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import vta.util.config._
+import vta.shell._
+
+/** Core parameters.
+ *
+ * batch/blockIn/blockOut set the GEMM tile shape; the *Bits fields are
+ * element bit-widths (inp/wgt/out 8-bit, acc/uop 32-bit); the *MemDepth
+ * fields are scratchpad entry counts; instQueueEntries sizes the
+ * per-module instruction queues.
+ * NOTE(review): depth/width semantics inferred from names — confirm
+ * against the scratchpad (TensorLoad) implementation.
+ */
+case class CoreParams (
+  batch: Int = 1,
+  blockOut: Int = 16,
+  blockIn: Int = 16,
+  inpBits: Int = 8,
+  wgtBits: Int = 8,
+  uopBits: Int = 32,
+  accBits: Int = 32,
+  outBits: Int = 8,
+  uopMemDepth: Int = 512,
+  inpMemDepth: Int = 512,
+  wgtMemDepth: Int = 512,
+  accMemDepth: Int = 512,
+  outMemDepth: Int = 512,
+  instQueueEntries: Int = 32
+)
+
+case object CoreKey extends Field[CoreParams]
+
+/** Core.
+ *
+ * The core defines the current VTA architecture by connecting memory and
+ * compute modules together such as load/store and compute. Most of the
+ * connections in the core are bulk (<>), and we should try to keep it this
+ * way, because it is easier to understand what is going on.
+ *
+ * Also, the core must be instantiated by a shell using the
+ * VTA Control Register (VCR) and the VTA Memory Engine (VME) interfaces.
+ * More info about these interfaces and modules can be found in the shell
+ * directory.
+ */
+class Core(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val vcr = new VCRClient
+    val vme = new VMEMaster
+  })
+  val fetch = Module(new Fetch)
+  val load = Module(new Load)
+  val compute = Module(new Compute)
+  val store = Module(new Store)
+
+  // Read(rd) and write(wr) from/to memory (i.e. DRAM)
+  // Read-port map: rd(0)=fetch instr, rd(1)=compute uop, rd(2)=load inp,
+  // rd(3)=load wgt, rd(4)=compute acc. Single write port for store.
+  io.vme.rd(0) <> fetch.io.vme_rd
+  io.vme.rd(1) <> compute.io.vme_rd(0)
+  io.vme.rd(2) <> load.io.vme_rd(0)
+  io.vme.rd(3) <> load.io.vme_rd(1)
+  io.vme.rd(4) <> compute.io.vme_rd(1)
+  io.vme.wr(0) <> store.io.vme_wr
+
+  // Fetch instructions (tasks) from memory (DRAM) into queues (SRAMs)
+  // Host-programmed VCR map used below: ptrs(0)=instr, ptrs(1)=uop,
+  // ptrs(2)=inp, ptrs(3)=wgt, ptrs(4)=acc, ptrs(5)=out; vals(0)=ins_count.
+  fetch.io.launch := io.vcr.launch
+  fetch.io.ins_baddr := io.vcr.ptrs(0)
+  fetch.io.ins_count := io.vcr.vals(0)
+
+  // Load inputs and weights from memory (DRAM) into scratchpads (SRAMs)
+  load.io.i_post := compute.io.o_post(0)
+  load.io.inst <> fetch.io.inst.ld
+  load.io.inp_baddr := io.vcr.ptrs(2)
+  load.io.wgt_baddr := io.vcr.ptrs(3)
+
+  // The compute module performs the following:
+  // - Load micro-ops (uops) and accumulations (acc)
+  // - Compute dense and ALU instructions (tasks)
+  compute.io.i_post(0) := load.io.o_post
+  compute.io.i_post(1) := store.io.o_post
+  compute.io.inst <> fetch.io.inst.co
+  compute.io.uop_baddr := io.vcr.ptrs(1)
+  compute.io.acc_baddr := io.vcr.ptrs(4)
+  compute.io.inp <> load.io.inp
+  compute.io.wgt <> load.io.wgt
+
+  // The store module performs the following:
+  // - Writes results from compute into scratchpads (SRAMs)
+  // - Store results from scratchpads (SRAMs) to memory (DRAM)
+  store.io.i_post := compute.io.o_post(1)
+  store.io.inst <> fetch.io.inst.st
+  store.io.out_baddr := io.vcr.ptrs(5)
+  store.io.out <> compute.io.out
+
+  // Finish instruction is executed and asserts the VCR finish flag
+  // (registered once before driving the VCR).
+  val finish = RegNext(compute.io.finish)
+  io.vcr.finish := finish
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+
+import ISA._
+
+/** MemDecode.
+ *
+ * Decode memory instructions with a Bundle. This is similar to an union,
+ * therefore order matters when declaring fields. These are the instructions
+ * decoded with this bundle:
+ * - LUOP
+ * - LWGT
+ * - LINP
+ * - LACC
+ * - SOUT
+ */
+class MemDecode extends Bundle {
+  // Fields are declared MSB-first: asTypeOf maps the first field (xpad_1)
+  // to the most-significant bits of the 128-bit instruction and op to the
+  // least-significant bits (the 3-bit opcode).
+  val xpad_1 = UInt(M_PAD_BITS.W)
+  val xpad_0 = UInt(M_PAD_BITS.W)
+  val ypad_1 = UInt(M_PAD_BITS.W)
+  val ypad_0 = UInt(M_PAD_BITS.W)
+  val xstride = UInt(M_STRIDE_BITS.W)
+  val xsize = UInt(M_SIZE_BITS.W)
+  val ysize = UInt(M_SIZE_BITS.W)
+  val empty_0 = UInt(7.W) // derive this
+  val dram_offset = UInt(M_DRAM_OFFSET_BITS.W)
+  val sram_offset = UInt(M_SRAM_OFFSET_BITS.W)
+  val id = UInt(M_ID_BITS.W)
+  val push_next = Bool()
+  val push_prev = Bool()
+  val pop_next = Bool()
+  val pop_prev = Bool()
+  val op = UInt(OP_BITS.W)
+}
+
+/** GemmDecode.
+ *
+ * Decode GEMM instruction with a Bundle. This is similar to an union,
+ * therefore order matters when declaring fields.
+ */
+class GemmDecode extends Bundle {
+  // MSB-first layout, like MemDecode. uop_begin/uop_end bound the micro-op
+  // range and lp_0/lp_1 are the two loop iteration counts (C_ITER_BITS
+  // wide). reset presumably zeroes the accumulator — confirm in TensorGemm.
+  val wgt_1 = UInt(C_WIDX_BITS.W)
+  val wgt_0 = UInt(C_WIDX_BITS.W)
+  val inp_1 = UInt(C_IIDX_BITS.W)
+  val inp_0 = UInt(C_IIDX_BITS.W)
+  val acc_1 = UInt(C_AIDX_BITS.W)
+  val acc_0 = UInt(C_AIDX_BITS.W)
+  val empty_0 = Bool()
+  val lp_1 = UInt(C_ITER_BITS.W)
+  val lp_0 = UInt(C_ITER_BITS.W)
+  val uop_end = UInt(C_UOP_END_BITS.W)
+  val uop_begin = UInt(C_UOP_BGN_BITS.W)
+  val reset = Bool()
+  val push_next = Bool()
+  val push_prev = Bool()
+  val pop_next = Bool()
+  val pop_prev = Bool()
+  val op = UInt(OP_BITS.W)
+}
+
+/** AluDecode.
+ *
+ * Decode ALU instructions with a Bundle. This is similar to an union,
+ * therefore order matters when declaring fields. These are the instructions
+ * decoded with this bundle:
+ * - VMIN
+ * - VMAX
+ * - VADD
+ * - VSHX
+ */
+class AluDecode extends Bundle {
+  // MSB-first layout, like MemDecode. alu_op (C_ALU_DEC_BITS wide) selects
+  // among VMIN/VMAX/VADD/VSHX; when alu_use_imm is set, alu_imm presumably
+  // replaces the src operand — confirm in TensorAlu.
+  val empty_1 = Bool()
+  val alu_imm = UInt(C_ALU_IMM_BITS.W)
+  val alu_use_imm = Bool()
+  val alu_op = UInt(C_ALU_DEC_BITS.W)
+  val src_1 = UInt(C_IIDX_BITS.W)
+  val src_0 = UInt(C_IIDX_BITS.W)
+  val dst_1 = UInt(C_AIDX_BITS.W)
+  val dst_0 = UInt(C_AIDX_BITS.W)
+  val empty_0 = Bool()
+  val lp_1 = UInt(C_ITER_BITS.W)
+  val lp_0 = UInt(C_ITER_BITS.W)
+  val uop_end = UInt(C_UOP_END_BITS.W)
+  val uop_begin = UInt(C_UOP_BGN_BITS.W)
+  val reset = Bool()
+  val push_next = Bool()
+  val push_prev = Bool()
+  val pop_next = Bool()
+  val pop_prev = Bool()
+  val op = UInt(OP_BITS.W)
+}
+
+/** UopDecode.
+ *
+ * Decode micro-ops (uops).
+ */
+class UopDecode extends Bundle {
+  // Widths line up with C_WIDX_BITS (10), C_IIDX_BITS (11) and
+  // C_AIDX_BITS (11) — presumably wgt, inp and acc indices; confirm
+  // against the uop consumers (TensorGemm/TensorAlu).
+  val u2 = UInt(10.W)
+  val u1 = UInt(11.W)
+  val u0 = UInt(11.W)
+}
+
+/** FetchDecode.
+ *
+ * Partial decoding for dispatching instructions to Load, Compute, and Store.
+ */
+class FetchDecode extends Module {
+  val io = IO(new Bundle {
+    val inst = Input(UInt(INST_BITS.W))
+    val isLoad = Output(Bool())
+    val isCompute = Output(Bool())
+    val isStore = Output(Bool())
+  })
+  // Lookup yields (valid-instruction, dispatch class). Note that LUOP and
+  // LACC dispatch to Compute (OP_G): uops and accumulations are loaded by
+  // the compute unit, not the load unit.
+  val csignals =
+    ListLookup(io.inst,
+      List(N, OP_X),
+      Array(
+        LUOP -> List(Y, OP_G),
+        LWGT -> List(Y, OP_L),
+        LINP -> List(Y, OP_L),
+        LACC -> List(Y, OP_G),
+        SOUT -> List(Y, OP_S),
+        GEMM -> List(Y, OP_G),
+        FNSH -> List(Y, OP_G),
+        VMIN -> List(Y, OP_G),
+        VMAX -> List(Y, OP_G),
+        VADD -> List(Y, OP_G),
+        VSHX -> List(Y, OP_G)
+      )
+    )
+
+  val (cs_val_inst: Bool) :: cs_op_type :: Nil = csignals
+
+  io.isLoad := cs_val_inst & cs_op_type === OP_L
+  io.isCompute := cs_val_inst & cs_op_type === OP_G
+  io.isStore := cs_val_inst & cs_op_type === OP_S
+}
+
+/** LoadDecode.
+ *
+ * Decode dependencies, type and sync for Load module.
+ */
+class LoadDecode extends Module {
+  val io = IO(new Bundle {
+    val inst = Input(UInt(INST_BITS.W))
+    val push_next = Output(Bool())
+    val pop_next = Output(Bool())
+    val isInput = Output(Bool())
+    val isWeight = Output(Bool())
+    val isSync = Output(Bool())
+  })
+  val dec = io.inst.asTypeOf(new MemDecode)
+  io.push_next := dec.push_next
+  io.pop_next := dec.pop_next
+  // A load with xsize == 0 moves no data: it only carries dependency
+  // flags and is treated as a sync token.
+  io.isInput := io.inst === LINP & dec.xsize =/= 0.U
+  io.isWeight := io.inst === LWGT & dec.xsize =/= 0.U
+  io.isSync := (io.inst === LINP | io.inst === LWGT) & dec.xsize === 0.U
+}
+
+/** ComputeDecode.
+ *
+ * Decode dependencies, type and sync for Compute module.
+ */
+class ComputeDecode extends Module {
+  val io = IO(new Bundle {
+    val inst = Input(UInt(INST_BITS.W))
+    val push_next = Output(Bool())
+    val push_prev = Output(Bool())
+    val pop_next = Output(Bool())
+    val pop_prev = Output(Bool())
+    val isLoadAcc = Output(Bool())
+    val isLoadUop = Output(Bool())
+    val isSync = Output(Bool())
+    val isAlu = Output(Bool())
+    val isGemm = Output(Bool())
+    val isFinish = Output(Bool())
+  })
+  val dec = io.inst.asTypeOf(new MemDecode)
+  io.push_next := dec.push_next
+  io.push_prev := dec.push_prev
+  io.pop_next := dec.pop_next
+  io.pop_prev := dec.pop_prev
+  // A LACC/LUOP with xsize == 0 moves no data: it only carries dependency
+  // flags and is treated as a sync token.
+  io.isLoadAcc := io.inst === LACC & dec.xsize =/= 0.U
+  io.isLoadUop := io.inst === LUOP & dec.xsize =/= 0.U
+  io.isSync := (io.inst === LACC | io.inst === LUOP) & dec.xsize === 0.U
+  io.isAlu := io.inst === VMIN | io.inst === VMAX | io.inst === VADD | io.inst === VSHX
+  io.isGemm := io.inst === GEMM
+  io.isFinish := io.inst === FNSH
+}
+
+/** StoreDecode.
+ *
+ * Decode dependencies, type and sync for Store module.
+ */
+class StoreDecode extends Module {
+  val io = IO(new Bundle {
+    val inst = Input(UInt(INST_BITS.W))
+    val push_prev = Output(Bool())
+    val pop_prev = Output(Bool())
+    val isStore = Output(Bool())
+    val isSync = Output(Bool())
+  })
+  val dec = io.inst.asTypeOf(new MemDecode)
+  io.push_prev := dec.push_prev
+  io.pop_prev := dec.pop_prev
+  // A SOUT with xsize == 0 moves no data and acts as a sync token.
+  io.isStore := io.inst === SOUT & dec.xsize =/= 0.U
+  io.isSync := io.inst === SOUT & dec.xsize === 0.U
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** Fetch.
+ *
+ * The fetch unit reads instructions (tasks) from memory (i.e. DRAM), using the
+ * VTA Memory Engine (VME), and push them into an instruction queue called
+ * inst_q. Once the instruction queue is full, instructions are dispatched to
+ * the Load, Compute and Store module queues based on the instruction opcode.
+ * After draining the queue, the fetch unit checks if there are more instructions
+ * via the ins_count register which is written by the host.
+ *
+ * Additionally, instructions are read into two chunks (see sReadLSB and sReadMSB)
+ * because we are using a DRAM payload of 8-bytes or half of a VTA instruction.
+ * This should be configurable for larger payloads, i.e. 64-bytes, which can load
+ * more than one instruction at the time. Finally, the instruction queue is
+ * sized (entries_q), depending on the maximum burst allowed in the memory.
+ */
+class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val vp = p(ShellKey).vcrParams
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val launch = Input(Bool())
+    val ins_baddr = Input(UInt(mp.addrBits.W))
+    val ins_count = Input(UInt(vp.regBits.W))
+    val vme_rd = new VMEReadMaster
+    val inst = new Bundle {
+      val ld = Decoupled(UInt(INST_BITS.W))
+      val co = Decoupled(UInt(INST_BITS.W))
+      val st = Decoupled(UInt(INST_BITS.W))
+    }
+  })
+  val entries_q = 1 << (mp.lenBits - 1) // one-instr-every-two-vme-word
+  val inst_q = Module(new Queue(UInt(INST_BITS.W), entries_q))
+  val dec = Module(new FetchDecode)
+
+  // One-cycle pulse on the rising edge of launch.
+  val s1_launch = RegNext(io.launch)
+  val pulse = io.launch & ~s1_launch
+
+  // Current DRAM read address, burst length and instructions-per-burst.
+  val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr))
+  val rlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
+  val ilen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
+
+  // Each 128-bit instruction takes two data beats, hence the <<1. The
+  // len/size values encode (count - 1) — note the -1 adjustments below —
+  // presumably matching the VME/AXI len convention; confirm in the shell.
+  val xrem = Reg(chiselTypeOf(io.ins_count))
+  val xsize = (io.ins_count << 1.U) - 1.U
+  val xmax = (1 << mp.lenBits).U
+  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+
+  val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5)
+  val state = RegInit(sIdle)
+
+  // control
+  switch (state) {
+    is (sIdle) {
+      when (pulse) {
+        state := sReadCmd
+        when (xsize < xmax) {
+          rlen := xsize
+          ilen := xsize >> 1.U
+          xrem := 0.U
+        } .otherwise {
+          rlen := xmax - 1.U
+          ilen := (xmax >> 1.U) - 1.U
+          xrem := xsize - xmax
+        }
+      }
+    }
+    is (sReadCmd) {
+      when (io.vme_rd.cmd.ready) {
+        state := sReadLSB
+      }
+    }
+    is (sReadLSB) {
+      when (io.vme_rd.data.valid) {
+        state := sReadMSB
+      }
+    }
+    is (sReadMSB) {
+      when (io.vme_rd.data.valid) {
+        when (inst_q.io.count === ilen) {
+          state := sDrain
+        } .otherwise {
+          state := sReadLSB
+        }
+      }
+    }
+    // Drain the queue into the ld/co/st ports, then either finish or
+    // issue the next burst if instructions remain (xrem).
+    is (sDrain) {
+      when (inst_q.io.count === 0.U) {
+        when (xrem === 0.U) {
+          state := sIdle
+        } .elsewhen (xrem < xmax) {
+          state := sReadCmd
+          rlen := xrem
+          ilen := xrem >> 1.U
+          xrem := 0.U
+        } .otherwise {
+          state := sReadCmd
+          rlen := xmax - 1.U
+          ilen := (xmax >> 1.U) - 1.U
+          xrem := xrem - xmax
+        }
+      }
+    }
+  }
+
+  // read instructions from dram
+  when (state === sIdle) {
+    raddr := io.ins_baddr
+  } .elsewhen (state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) {
+    raddr := raddr + xmax_bytes
+  }
+
+  io.vme_rd.cmd.valid := state === sReadCmd
+  io.vme_rd.cmd.bits.addr := raddr
+  io.vme_rd.cmd.bits.len := rlen
+
+  io.vme_rd.data.ready := inst_q.io.enq.ready
+
+  // Reassemble one 128-bit instruction from two beats: the LSB half is
+  // captured in sReadLSB, the MSB half arrives in sReadMSB.
+  val lsb = Reg(chiselTypeOf(io.vme_rd.data.bits))
+  val msb = io.vme_rd.data.bits
+  val inst = Cat(msb, lsb)
+
+  when (state === sReadLSB) { lsb := io.vme_rd.data.bits }
+
+  inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB
+  inst_q.io.enq.bits := inst
+
+  // decode
+  dec.io.inst := inst_q.io.deq.bits
+
+  // instruction queues
+  io.inst.ld.valid := dec.io.isLoad & inst_q.io.deq.valid & state === sDrain
+  io.inst.co.valid := dec.io.isCompute & inst_q.io.deq.valid & state === sDrain
+  io.inst.st.valid := dec.io.isStore & inst_q.io.deq.valid & state === sDrain
+
+  io.inst.ld.bits := inst_q.io.deq.bits
+  io.inst.co.bits := inst_q.io.deq.bits
+  io.inst.st.bits := inst_q.io.deq.bits
+
+  // check if selected queue is ready
+  val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt
+  val deq_ready =
+    MuxLookup(deq_sel,
+      false.B, // default
+      Array(
+        "h_01".U -> io.inst.ld.ready,
+        "h_02".U -> io.inst.st.ready,
+        "h_04".U -> io.inst.co.ready
+      )
+    )
+
+  // dequeue instruction
+  inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain
+
+
+  // debug
+  if (debug) {
+    when (state === sIdle && pulse) {
+      printf("[Fetch] Launch\n")
+    }
+    // instruction
+    when (inst_q.io.deq.fire()) {
+      when (dec.io.isLoad) {
+        printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits)
+      }
+      when (dec.io.isCompute) {
+        printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits)
+      }
+      when (dec.io.isStore) {
+        printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits)
+      }
+    }
+  }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+
+/** ISAConstants.
+ *
+ * These constants are used for decoding (parsing) fields on instructions.
+ */
+trait ISAConstants
+{
+  val INST_BITS = 128
+
+  val OP_BITS = 3
+
+  val M_DEP_BITS = 4
+  val M_ID_BITS = 2
+  val M_SRAM_OFFSET_BITS = 16
+  val M_DRAM_OFFSET_BITS = 32
+  val M_SIZE_BITS = 16
+  val M_STRIDE_BITS = 16
+  val M_PAD_BITS = 4
+
+  val C_UOP_BGN_BITS = 13
+  val C_UOP_END_BITS = 14
+  val C_ITER_BITS = 14
+  val C_AIDX_BITS = 11
+  val C_IIDX_BITS = 11
+  val C_WIDX_BITS = 10
+  val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction
+  val C_ALU_OP_BITS = 3
+  val C_ALU_IMM_BITS = 16
+
+  val Y = true.B
+  val N = false.B
+
+  // 3-bit opcode values; these match the low 3 bits of the ISA bit
+  // patterns (load=0, store=1, gemm=2, finish=3, alu=4). OP_X marks an
+  // invalid/don't-care opcode in decode tables.
+  val OP_L = 0.asUInt(OP_BITS.W)
+  val OP_S = 1.asUInt(OP_BITS.W)
+  val OP_G = 2.asUInt(OP_BITS.W)
+  val OP_F = 3.asUInt(OP_BITS.W)
+  val OP_A = 4.asUInt(OP_BITS.W)
+  val OP_X = 5.asUInt(OP_BITS.W)
+
+  val ALU_OP_NUM = 5
+  val ALU_OP = Enum(ALU_OP_NUM)
+
+  // Memory client ids used by load instructions (bits 8:7 of the ISA
+  // patterns): uop, weight, input, accumulator.
+  val M_ID_U = 0.asUInt(M_ID_BITS.W)
+  val M_ID_W = 1.asUInt(M_ID_BITS.W)
+  val M_ID_I = 2.asUInt(M_ID_BITS.W)
+  val M_ID_A = 3.asUInt(M_ID_BITS.W)
+}
+
+/** ISA.
+ *
+ * This is the VTA ISA, here we specify the cares and dont-cares that makes
+ * decoding easier. Since instructions are quite long 128-bit, we could generate
+ * these based on ISAConstants.
+ *
+ * FIXME: VSHX should be replaced by VSHR and VSHL once we modify the compiler
+ * TODO: Add VXOR to clear accumulator
+ */
+object ISA {
+  // The low 3 bits are the opcode: 000=load, 001=store, 010=gemm,
+  // 011=finish, 100=alu. Load patterns additionally fix bits 8:7, the
+  // memory client id (00=uop, 01=wgt, 10=inp, 11=acc — matches M_ID_* in
+  // ISAConstants). The ALU patterns fix the 2-bit alu_op field to select
+  // VMIN/VMAX/VADD/VSHX (see AluDecode).
+  def LUOP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000")
+  def LWGT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000")
+  def LINP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000")
+  def LACC = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000")
+  def SOUT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001")
+  def GEMM = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010")
+  def VMIN = BitPat("b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+  def VMAX = BitPat("b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+  def VADD = BitPat("b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+  def VSHX = BitPat("b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+  def FNSH = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011")
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** Load.
+ *
+ * Load inputs and weights from memory (DRAM) into scratchpads (SRAMs).
+ * This module instantiates the TensorLoad unit which is in charge of
+ * loading 1D and 2D tensors to scratchpads, so it can be used by
+ * other modules such as Compute.
+ */
+class Load(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val i_post = Input(Bool())
+    val o_post = Output(Bool())
+    val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
+    val inp_baddr = Input(UInt(mp.addrBits.W))
+    val wgt_baddr = Input(UInt(mp.addrBits.W))
+    val vme_rd = Vec(2, new VMEReadMaster)
+    val inp = new TensorClient(tensorType = "inp")
+    val wgt = new TensorClient(tensorType = "wgt")
+  })
+  // Three-state controller: idle, a one-cycle sync (dependency token only),
+  // or executing a tensor load until the selected TensorLoad reports done.
+  val sIdle :: sSync :: sExe :: Nil = Enum(3)
+  val state = RegInit(sIdle)
+
+  val s = Module(new Semaphore(counterBits = 8, counterInitValue = 0))
+  val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
+
+  val dec = Module(new LoadDecode)
+  dec.io.inst := inst_q.io.deq.bits
+
+  // Two TensorLoad units: index 0 handles inputs, index 1 handles weights.
+  val tensorType = Seq("inp", "wgt")
+  val tensorDec = Seq(dec.io.isInput, dec.io.isWeight)
+  val tensorLoad = Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i))))
+
+  // An instruction starts when one is queued and, if it pops a dependency
+  // token from the next stage, when the semaphore holds a token.
+  val start = inst_q.io.deq.valid & Mux(dec.io.pop_next, s.io.sready, true.B)
+  val done = Mux(dec.io.isInput, tensorLoad(0).io.done, tensorLoad(1).io.done)
+
+  // control
+  switch (state) {
+    is (sIdle) {
+      when (start) {
+        when (dec.io.isSync) {
+          state := sSync
+        } .elsewhen (dec.io.isInput || dec.io.isWeight) {
+          state := sExe
+        }
+      }
+    }
+    is (sSync) {
+      state := sIdle
+    }
+    is (sExe) {
+      when (done) {
+        state := sIdle
+      }
+    }
+  }
+
+  // instructions: dequeue once the current one completes (or sync retires).
+  inst_q.io.enq <> io.inst
+  inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
+
+  // load tensor
+  // [0] input (inp)
+  // [1] weight (wgt)
+  val ptr = Seq(io.inp_baddr, io.wgt_baddr)
+  val tsor = Seq(io.inp, io.wgt)
+  for (i <- 0 until 2) {
+    tensorLoad(i).io.start := state === sIdle & start & tensorDec(i)
+    tensorLoad(i).io.inst := inst_q.io.deq.bits
+    tensorLoad(i).io.baddr := ptr(i)
+    tensorLoad(i).io.tensor <> tsor(i)
+    io.vme_rd(i) <> tensorLoad(i).io.vme_rd
+  }
+
+  // semaphore: consume a token on start when pop_next is set; post one to
+  // the next stage on completion when push_next is set.
+  s.io.spost := io.i_post
+  s.io.swait := dec.io.pop_next & (state === sIdle & start)
+  io.o_post := dec.io.push_next & ((state === sExe & done) | (state === sSync))
+
+  // debug
+  if (debug) {
+    // start
+    when (state === sIdle && start) {
+      when (dec.io.isSync) {
+        printf("[Load] start sync\n")
+      } .elsewhen (dec.io.isInput) {
+        printf("[Load] start input\n")
+      } .elsewhen (dec.io.isWeight) {
+        printf("[Load] start weight\n")
+      }
+    }
+    // done
+    when (state === sSync) {
+      printf("[Load] done sync\n")
+    }
+    when (state === sExe) {
+      when (done) {
+        when (dec.io.isInput) {
+          printf("[Load] done input\n")
+        } .elsewhen (dec.io.isWeight) {
+          printf("[Load] done weight\n")
+        }
+      }
+    }
+  }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** UopMaster.
+ *
+ * Uop interface used by a master module, i.e. TensorAlu or TensorGemm,
+ * to request a micro-op (uop) from the uop-scratchpad. The index (idx) is
+ * used as an address to find the uop in the uop-scratchpad.
+ */
+class UopMaster(implicit p: Parameters) extends Bundle {
+  // Address width derived from the configured uop-scratchpad depth.
+  val addrBits = log2Ceil(p(CoreKey).uopMemDepth)
+  val idx = ValidIO(UInt(addrBits.W))
+  val data = Flipped(ValidIO(new UopDecode))
+  // Required by Chisel 3 for bundles built with implicit parameters.
+  override def cloneType = new UopMaster().asInstanceOf[this.type]
+}
+
+/** UopClient.
+ *
+ * Uop interface used by a client module, i.e. LoadUop, to receive
+ * a request from a master module, i.e. TensorAlu or TensorGemm.
+ * The index (idx) is used as an address to find the uop in the uop-scratchpad.
+ */
+class UopClient(implicit p: Parameters) extends Bundle {
+  // Address width derived from the configured uop-scratchpad depth.
+  val addrBits = log2Ceil(p(CoreKey).uopMemDepth)
+  val idx = Flipped(ValidIO(UInt(addrBits.W)))
+  val data = ValidIO(new UopDecode)
+  // Required by Chisel 3 for bundles built with implicit parameters.
+  override def cloneType = new UopClient().asInstanceOf[this.type]
+}
+
+/** LoadUop.
+ *
+ * Load micro-ops (uops) from memory, i.e. DRAM, and store them in the
+ * uop-scratchpad. Currently, micro-ops are 32-bit wide and loaded in
+ * group of 2 given the fact that the DRAM payload is 8-bytes. This module
+ * should be modified later on to support different DRAM sizes efficiently.
+ */
+class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val baddr = Input(UInt(mp.addrBits.W))
+    val vme_rd = new VMEReadMaster
+    val uop = new UopClient
+  })
+  val numUop = 2 // store two uops per sram word
+  val uopBits = p(CoreKey).uopBits
+  val uopDepth = p(CoreKey).uopMemDepth / numUop
+
+  val dec = io.inst.asTypeOf(new MemDecode)
+  val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr))
+  val xcnt = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
+  val xlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
+  val xrem = Reg(chiselTypeOf(dec.xsize))
+  // Total memory beats minus one (VME len encoding): uops are packed two per
+  // beat, so this is ceil(xsize / 2) - 1 (lsb of xsize rounds the halving up).
+  // NOTE(review): when sram_offset is odd the first beat carries only one
+  // useful uop; confirm this count still covers the trailing uop.
+  val xsize = dec.xsize(0) + (dec.xsize >> log2Ceil(numUop)) - 1.U
+  // Maximum beats per VME command, and the bytes those beats cover.
+  val xmax = (1 << mp.lenBits).U
+  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+
+  // Parity of the scratchpad offset and of the transfer size; these decide
+  // which half of the first/last beat holds valid uops (see wmask below).
+  val offsetIsEven = (dec.sram_offset % 2.U) === 0.U
+  val sizeIsEven = (dec.xsize % 2.U) === 0.U
+
+  val sIdle :: sReadCmd :: sReadData :: Nil = Enum(3)
+  val state = RegInit(sIdle)
+
+  // control: split the transfer into commands of at most xmax beats each,
+  // tracking the remaining beat count in xrem.
+  switch (state) {
+    is (sIdle) {
+      when (io.start) {
+        state := sReadCmd
+        when (xsize < xmax) {
+          xlen := xsize
+          xrem := 0.U
+        } .otherwise {
+          xlen := xmax - 1.U
+          xrem := xsize - xmax
+        }
+      }
+    }
+    is (sReadCmd) {
+      when (io.vme_rd.cmd.ready) {
+        state := sReadData
+      }
+    }
+    is (sReadData) {
+      when (io.vme_rd.data.valid) {
+        when(xcnt === xlen) {
+          when (xrem === 0.U) {
+            state := sIdle
+          } .elsewhen (xrem < xmax) {
+            state := sReadCmd
+            xlen := xrem
+            xrem := 0.U
+          } .otherwise {
+            state := sReadCmd
+            xlen := xmax - 1.U
+            xrem := xrem - xmax
+          }
+        }
+      }
+    }
+  }
+
+  // read-from-dram
+  // With an odd scratchpad offset the DRAM address is backed up by one uop
+  // (4 bytes) so the first valid uop lands in the upper half of the beat.
+  when (state === sIdle) {
+    when (offsetIsEven) {
+      raddr := io.baddr + dec.dram_offset
+    } .otherwise {
+      raddr := io.baddr + dec.dram_offset - 4.U
+    }
+  } .elsewhen (state === sReadData && xcnt === xlen && xrem =/= 0.U) {
+    raddr := raddr + xmax_bytes
+  }
+
+  io.vme_rd.cmd.valid := state === sReadCmd
+  io.vme_rd.cmd.bits.addr := raddr
+  io.vme_rd.cmd.bits.len := xlen
+
+  io.vme_rd.data.ready := state === sReadData
+
+  // Beat counter within the current command.
+  when (state =/= sReadData) {
+    xcnt := 0.U
+  } .elsewhen (io.vme_rd.data.fire()) {
+    xcnt := xcnt + 1.U
+  }
+
+  // Scratchpad write pointer, in words of numUop uops.
+  val waddr = Reg(UInt(log2Ceil(uopDepth).W))
+  when (state === sIdle) {
+    waddr := dec.sram_offset >> log2Ceil(numUop)
+  } .elsewhen (io.vme_rd.data.fire()) {
+    waddr := waddr + 1.U
+  }
+
+  val wdata = Wire(Vec(numUop, UInt(uopBits.W)))
+  val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata))
+  val wmask = Reg(Vec(numUop, Bool()))
+
+  // Write mask: selects which of the two uop slots of a word are written,
+  // based on offset/size parity and whether this is the first/last beat.
+  when (offsetIsEven) {
+    when (sizeIsEven) {
+      wmask := "b_11".U.asTypeOf(wmask)
+    } .elsewhen (io.vme_rd.cmd.fire()) {
+      when (dec.xsize === 1.U) {
+        wmask := "b_01".U.asTypeOf(wmask)
+      } .otherwise {
+        wmask := "b_11".U.asTypeOf(wmask)
+      }
+    } .elsewhen (io.vme_rd.data.fire()) {
+      when (xcnt === xlen - 1.U) {
+        wmask := "b_01".U.asTypeOf(wmask)
+      } .otherwise {
+        wmask := "b_11".U.asTypeOf(wmask)
+      }
+    }
+  } .otherwise {
+    when (io.vme_rd.cmd.fire()) {
+      wmask := "b_10".U.asTypeOf(wmask)
+    } .elsewhen (io.vme_rd.data.fire()) {
+      when (sizeIsEven && xcnt === xlen - 1.U) {
+        wmask := "b_01".U.asTypeOf(wmask)
+      } .otherwise {
+        wmask := "b_11".U.asTypeOf(wmask)
+      }
+    }
+  }
+
+  wdata := io.vme_rd.data.bits.asTypeOf(wdata)
+  when (io.vme_rd.data.fire()) {
+    mem.write(waddr, wdata, wmask)
+  }
+
+  // read-from-sram: one-cycle latency; idx/2 selects the word, idx%2 the half.
+  io.uop.data.valid := RegNext(io.uop.idx.valid)
+
+  val sIdx = io.uop.idx.bits % numUop.U
+  val rIdx = io.uop.idx.bits >> log2Ceil(numUop)
+  val memRead = mem.read(rIdx, io.uop.idx.valid)
+  val sWord = memRead.asUInt.asTypeOf(wdata)
+  val sUop = sWord(sIdx).asTypeOf(io.uop.data.bits)
+
+  io.uop.data.bits <> sUop
+
+  // done: pulses on the final beat of the final command.
+  io.done := state === sReadData & io.vme_rd.data.valid & xcnt === xlen & xrem === 0.U
+
+  // debug
+  if (debug) {
+    when (io.vme_rd.cmd.fire()) {
+      printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem)
+    }
+  }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+
+/** Semaphore.
+ *
+ * Saturating up/down token counter used instead of the push/pop fifo of the
+ * initial VTA version. It is incremented (spost) or decremented (swait)
+ * according to the push and pop fields on instructions, preventing RAW and
+ * WAR hazards between pipeline stages.
+ */
+class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) extends Module {
+  val io = IO(new Bundle {
+    val spost = Input(Bool())
+    val swait = Input(Bool())
+    val sready = Output(Bool())
+  })
+  // Counter saturates at zero and at its maximum; a simultaneous post and
+  // wait cancel out and leave the count unchanged.
+  val count = RegInit(counterInitValue.U(counterBits.W))
+  val doInc = io.spost && !io.swait && count =/= ((1 << counterBits) - 1).asUInt
+  val doDec = io.swait && !io.spost && count =/= 0.U
+  when (doInc) { count := count + 1.U }
+  when (doDec) { count := count - 1.U }
+  io.sready := count =/= 0.U
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** Store.
+ *
+ * Store results back to memory (DRAM) from scratchpads (SRAMs).
+ * This module instantiates the TensorStore unit which is in charge
+ * of storing 1D and 2D tensors to main memory.
+ */
+class Store(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val i_post = Input(Bool())
+    val o_post = Output(Bool())
+    val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
+    val out_baddr = Input(UInt(mp.addrBits.W))
+    val vme_wr = new VMEWriteMaster
+    val out = new TensorClient(tensorType = "out")
+  })
+  // Three-state controller: idle, a one-cycle sync (dependency token only),
+  // or executing a tensor store until TensorStore reports done.
+  val sIdle :: sSync :: sExe :: Nil = Enum(3)
+  val state = RegInit(sIdle)
+
+  val s = Module(new Semaphore(counterBits = 8, counterInitValue = 0))
+  val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
+
+  val dec = Module(new StoreDecode)
+  dec.io.inst := inst_q.io.deq.bits
+
+  val tensorStore = Module(new TensorStore(tensorType = "out"))
+
+  // An instruction starts when one is queued and, if it pops a dependency
+  // token from the previous stage, when the semaphore holds a token.
+  val start = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s.io.sready, true.B)
+  val done = tensorStore.io.done
+
+  // control
+  switch (state) {
+    is (sIdle) {
+      when (start) {
+        when (dec.io.isSync) {
+          state := sSync
+        } .elsewhen (dec.io.isStore) {
+          state := sExe
+        }
+      }
+    }
+    is (sSync) {
+      state := sIdle
+    }
+    is (sExe) {
+      when (done) {
+        state := sIdle
+      }
+    }
+  }
+
+  // instructions: dequeue once the current one completes (or sync retires).
+  inst_q.io.enq <> io.inst
+  inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
+
+  // store
+  tensorStore.io.start := state === sIdle & start & dec.io.isStore
+  tensorStore.io.inst := inst_q.io.deq.bits
+  tensorStore.io.baddr := io.out_baddr
+  io.vme_wr <> tensorStore.io.vme_wr
+  tensorStore.io.tensor <> io.out
+
+  // semaphore: consume a token on start when pop_prev is set; post one to
+  // the previous stage on completion when push_prev is set.
+  s.io.spost := io.i_post
+  s.io.swait := dec.io.pop_prev & (state === sIdle & start)
+  io.o_post := dec.io.push_prev & ((state === sExe & done) | (state === sSync))
+
+  // debug
+  if (debug) {
+    // start
+    when (state === sIdle && start) {
+      when (dec.io.isSync) {
+        printf("[Store] start sync\n")
+      } .elsewhen (dec.io.isStore) {
+        printf("[Store] start\n")
+      }
+    }
+    // done
+    when (state === sSync) {
+      printf("[Store] done sync\n")
+    }
+    when (state === sExe) {
+      when (done) {
+        printf("[Store] done\n")
+      }
+    }
+  }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+
+/** ALU datapath: combinational two-operand ALU over accBits-wide values. */
+class Alu(implicit p: Parameters) extends Module {
+  val aluBits = p(CoreKey).accBits
+  val io = IO(new Bundle {
+    val opcode = Input(UInt(C_ALU_OP_BITS.W))
+    val a = Input(SInt(aluBits.W))
+    val b = Input(SInt(aluBits.W))
+    val y = Output(SInt(aluBits.W))
+  })
+
+  // FIXME: the following three will change once we support properly SHR and SHL
+  val ub = io.b.asUInt
+  val width = log2Ceil(aluBits)
+  // Two's complement of the low shift bits of b, used as the left-shift amount.
+  val m = ~ub(width - 1, 0) + 1.U
+
+  // Low bits of b used as the right-shift amount.
+  val n = ub(width - 1, 0)
+  // Operation table indexed by opcode: 0=min, 1=max, 2=add, 3=shift-right,
+  // 4=shift-left (by the negated amount).
+  val fop = Seq(Mux(io.a < io.b, io.a, io.b),
+    Mux(io.a < io.b, io.b, io.a),
+    io.a + io.b,
+    io.a >> n,
+    io.a << m)
+
+  // Unknown opcodes pass a through unchanged.
+  val opmux = Seq.tabulate(ALU_OP_NUM)(i => ALU_OP(i) -> fop(i))
+  io.y := MuxLookup(io.opcode, io.a, opmux)
+}
+
+/** Pipelined ALU: registers both operands, one-cycle latency on valid. */
+class AluReg(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val opcode = Input(UInt(C_ALU_OP_BITS.W))
+    val a = Flipped(ValidIO(UInt(p(CoreKey).accBits.W)))
+    val b = Flipped(ValidIO(UInt(p(CoreKey).accBits.W)))
+    val y = ValidIO(UInt(p(CoreKey).accBits.W))
+  })
+  val alu = Module(new Alu)
+
+  // Capture each operand when its valid is asserted; the result valid
+  // tracks operand b with one cycle of delay.
+  val regA = RegEnable(io.a.bits, io.a.valid)
+  val regB = RegEnable(io.b.bits, io.b.valid)
+  val validReg = RegNext(io.b.valid)
+
+  // Drive the combinational ALU from the registered operands.
+  alu.io.opcode := io.opcode
+  alu.io.a := regA.asSInt
+  alu.io.b := regB.asSInt
+
+  // Registered result.
+  io.y.valid := validReg
+  io.y.bits := alu.io.y.asUInt
+}
+
+/** Vector of pipeline ALUs: one AluReg lane per output element. */
+class AluVector(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val opcode = Input(UInt(C_ALU_OP_BITS.W))
+    val acc_a = new TensorMasterData(tensorType = "acc")
+    val acc_b = new TensorMasterData(tensorType = "acc")
+    val acc_y = new TensorClientData(tensorType = "acc")
+    val out = new TensorClientData(tensorType = "out")
+  })
+  val blockOut = p(CoreKey).blockOut
+  // All lanes share the same opcode and operate element-wise on the vectors.
+  val f = Seq.fill(blockOut)(Module(new AluReg))
+  val valid = Wire(Vec(blockOut, Bool()))
+  for (i <- 0 until blockOut) {
+    f(i).io.opcode := io.opcode
+    f(i).io.a.valid := io.acc_a.data.valid
+    f(i).io.a.bits := io.acc_a.data.bits(0)(i)
+    f(i).io.b.valid := io.acc_b.data.valid
+    f(i).io.b.bits := io.acc_b.data.bits(0)(i)
+    valid(i) := f(i).io.y.valid
+    // Each result is written to both the acc and out data paths.
+    io.acc_y.data.bits(0)(i) := f(i).io.y.bits
+    io.out.data.bits(0)(i) := f(i).io.y.bits
+  }
+  // Outputs are valid only once every lane has produced its result.
+  io.acc_y.data.valid := valid.asUInt.andR
+  io.out.data.valid := valid.asUInt.andR
+}
+
+/** TensorAlu.
+ *
+ * This unit instantiates the ALU vector unit (AluVector) and goes over the
+ * micro-ops (uops) which are used to read the source operands (vectors)
+ * from the acc-scratchpad and then they are written back the same
+ * acc-scratchpad.
+ */
+class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val uop = new UopMaster
+    val acc = new TensorMaster(tensorType = "acc")
+    val out = new TensorMaster(tensorType = "out")
+  })
+  // Per-uop pipeline: fetch uop, compute addresses, read operand A, read
+  // operand B (or immediate), then execute; repeat over the loop nest.
+  val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil = Enum(6)
+  val state = RegInit(sIdle)
+  val alu = Module(new AluVector)
+  val dec = io.inst.asTypeOf(new AluDecode)
+  // Two-level loop nest (lp_0 outer x lp_1 inner) over uops; each level keeps
+  // its own destination/source offsets advanced by the decoded strides.
+  val uop_idx = Reg(chiselTypeOf(dec.uop_end))
+  val uop_end = dec.uop_end
+  val uop_dst = Reg(chiselTypeOf(dec.uop_end))
+  val uop_src = Reg(chiselTypeOf(dec.uop_end))
+  val cnt_o = Reg(chiselTypeOf(dec.lp_0))
+  val dst_o = Reg(chiselTypeOf(dec.uop_end))
+  val src_o = Reg(chiselTypeOf(dec.uop_end))
+  val cnt_i = Reg(chiselTypeOf(dec.lp_1))
+  val dst_i = Reg(chiselTypeOf(dec.uop_end))
+  val src_i = Reg(chiselTypeOf(dec.uop_end))
+  // Done when the last ALU result of the last uop of the last inner and
+  // outer iteration becomes valid.
+  val done =
+    state === sExe &
+    alu.io.out.data.valid &
+    (cnt_o === dec.lp_0 - 1.U) &
+    (cnt_i === dec.lp_1 - 1.U) &
+    (uop_idx === uop_end - 1.U)
+
+  switch (state) {
+    is (sIdle) {
+      when (io.start) {
+        state := sReadUop
+      }
+    }
+    is (sReadUop) {
+      state := sComputeIdx
+    }
+    is (sComputeIdx) {
+      state := sReadTensorA
+    }
+    is (sReadTensorA) {
+      state := sReadTensorB
+    }
+    is (sReadTensorB) {
+      state := sExe
+    }
+    is (sExe) {
+      when (alu.io.out.data.valid) {
+        when ((cnt_o === dec.lp_0 - 1.U) &&
+              (cnt_i === dec.lp_1 - 1.U) &&
+              (uop_idx === uop_end - 1.U)) {
+          state := sIdle
+        } .otherwise {
+          state := sReadUop
+        }
+      }
+    }
+  }
+
+  // uop index: restarts at uop_begin at the end of each uop sweep.
+  when (state === sIdle ||
+        (state === sExe &&
+         alu.io.out.data.valid &&
+         uop_idx === uop_end - 1.U)) {
+    uop_idx := dec.uop_begin
+  } .elsewhen (state === sExe && alu.io.out.data.valid) {
+    uop_idx := uop_idx + 1.U
+  }
+
+  // outer loop: counter and offsets advance by the level-0 strides when the
+  // inner loop completes a full sweep.
+  when (state === sIdle) {
+    cnt_o := 0.U
+    dst_o := 0.U
+    src_o := 0.U
+  } .elsewhen (state === sExe &&
+               alu.io.out.data.valid &&
+               uop_idx === uop_end - 1.U &&
+               cnt_i === dec.lp_1 - 1.U) {
+    cnt_o := cnt_o + 1.U
+    dst_o := dst_o + dec.dst_0
+    src_o := src_o + dec.src_0
+  }
+
+  // inner loop: offsets advance by the level-1 strides per uop sweep, and
+  // reload from the outer offsets when the inner counter wraps.
+  when (state === sIdle) {
+    cnt_i := 0.U
+    dst_i := 0.U
+    src_i := 0.U
+  } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
+    cnt_i := 0.U
+    dst_i := dst_o
+    src_i := src_o
+  } .elsewhen (state === sExe &&
+               alu.io.out.data.valid &&
+               uop_idx === uop_end - 1.U) {
+    cnt_i := cnt_i + 1.U
+    dst_i := dst_i + dec.dst_1
+    src_i := src_i + dec.src_1
+  }
+
+  // Resolve scratchpad addresses: fetched uop fields plus loop offsets.
+  when (state === sComputeIdx && io.uop.data.valid) {
+    uop_dst := io.uop.data.bits.u0 + dst_i
+    uop_src := io.uop.data.bits.u1 + src_i
+  }
+
+  // uop
+  io.uop.idx.valid := state === sReadUop
+  io.uop.idx.bits := uop_idx
+
+  // acc_i: read A at uop_dst, then B at uop_src unless B is the immediate.
+  io.acc.rd.idx.valid := state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm)
+  io.acc.rd.idx.bits := Mux(state === sReadTensorA, uop_dst, uop_src)
+
+  // imm: immediate operand broadcast to every element of the vector.
+  val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
+  tensorImm.data.valid := state === sReadTensorB
+  tensorImm.data.bits.foreach { b => b.foreach { c => c := dec.alu_imm } }
+
+  // alu
+  // A SHR with a negative immediate is remapped to the shift-left entry of
+  // the ALU operation table: Cat sets the MSB (opcode becomes 4) and clears
+  // the low opcode bits.
+  val isSHR = dec.alu_op === ALU_OP(3)
+  val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS-1)
+  val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op))
+  alu.io.opcode := fixme_alu_op
+  alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
+  alu.io.acc_a.data.bits <> io.acc.rd.data.bits
+  alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, tensorImm.data.valid, io.acc.rd.data.valid & state === sExe)
+  alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, tensorImm.data.bits, io.acc.rd.data.bits)
+
+  // acc_o: result written back to the destination address.
+  io.acc.wr.valid := alu.io.acc_y.data.valid
+  io.acc.wr.bits.idx := uop_dst
+  io.acc.wr.bits.data <> alu.io.acc_y.data.bits
+
+  // out
+  io.out.wr.valid := alu.io.out.data.valid
+  io.out.wr.bits.idx := uop_dst
+  io.out.wr.bits.data <> alu.io.out.data.bits
+  io.out.tieoffRead() // write-only
+
+  io.done := done
+
+  if (debug) {
+
+    when (state === sReadUop) {
+      printf("[TensorAlu] [uop] idx:%x\n", uop_idx)
+    }
+
+    when (state === sReadTensorA) {
+      printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src)
+    }
+
+    when (state === sIdle && io.start) {
+      printf(p"[TensorAlu] decode:$dec\n")
+    }
+
+    alu.io.acc_a.data.bits.foreach { tensor =>
+      tensor.zipWithIndex.foreach { case(elem, i) =>
+        when (alu.io.acc_a.data.valid) {
+          printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem)
+        }
+      }
+    }
+
+    alu.io.acc_b.data.bits.foreach { tensor =>
+      tensor.zipWithIndex.foreach { case(elem, i) =>
+        when (alu.io.acc_b.data.valid) {
+          printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem)
+        }
+      }
+    }
+
+    alu.io.acc_y.data.bits.foreach { tensor =>
+      tensor.zipWithIndex.foreach { case(elem, i) =>
+        when (alu.io.acc_y.data.valid) {
+          printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem)
+        }
+      }
+    }
+
+    alu.io.out.data.bits.foreach { tensor =>
+      tensor.zipWithIndex.foreach { case(elem, i) =>
+        when (alu.io.out.data.valid) {
+          printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem)
+        }
+      }
+    }
+  }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import chisel3.experimental._
+import vta.util.config._
+import scala.math.pow
+
+/** Pipelined multiply-accumulate: y = a * b + c, with registered inputs. */
+class MAC(dataBits: Int = 8, cBits: Int = 16, outBits: Int = 17) extends Module {
+  require (cBits >= dataBits * 2)
+  require (outBits >= dataBits * 2)
+  val io = IO(new Bundle {
+    val a = Input(SInt(dataBits.W))
+    val b = Input(SInt(dataBits.W))
+    val c = Input(SInt(cBits.W))
+    val y = Output(SInt(outBits.W))
+  })
+  // One register stage on every operand.
+  val regA = RegNext(io.a)
+  val regB = RegNext(io.b)
+  val regC = RegNext(io.c)
+  // Product sized to cBits, sum sized to outBits.
+  val product = Wire(SInt(cBits.W))
+  val sum = Wire(SInt(outBits.W))
+  product := regA * regB
+  sum := regC + product
+  io.y := sum
+}
+
+/** Pipelined two-input adder with registered operands. */
+class Adder(dataBits: Int = 8, outBits: Int = 17) extends Module {
+  require (outBits >= dataBits)
+  val io = IO(new Bundle {
+    val a = Input(SInt(dataBits.W))
+    val b = Input(SInt(dataBits.W))
+    val y = Output(SInt(outBits.W))
+  })
+  // One register stage on each operand before the add.
+  val regA = RegNext(io.a)
+  val regB = RegNext(io.b)
+  val sum = Wire(SInt(outBits.W))
+  sum := regA + regB
+  io.y := sum
+}
+
+/** Pipelined DotProduct based on MAC and Adder */
+class DotProduct(dataBits: Int = 8, size: Int = 16) extends Module {
+  val errMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
+  require(size >= 4 && isPow2(size), errMsg)
+  val b = dataBits * 2
+  val outBits = b + log2Ceil(size) + 1
+  val io = IO(new Bundle {
+    val a = Input(Vec(size, SInt(dataBits.W)))
+    val b = Input(Vec(size, SInt(dataBits.W)))
+    val y = Output(SInt(outBits.W))
+  })
+  // s(i) is the number of units at reduction level i: size/2, size/4, ..., 1.
+  val p = log2Ceil(size/2)
+  val s = Seq.tabulate(log2Ceil(size))(i => pow(2, p - i).toInt)
+  // Second-half operands are registered to line up with MAC row 0's latency.
+  val da = Seq.tabulate(s(0))(i => RegNext(io.a(s(0) + i)))
+  val db = Seq.tabulate(s(0))(i => RegNext(io.b(s(0) + i)))
+  // Two MAC rows: row 0 multiplies the first half of the vectors, row 1
+  // multiplies the second half and adds row 0's product (c input). Widths
+  // grow by one bit per level to avoid overflow.
+  val m = Seq.tabulate(2)(i =>
+    Seq.fill(s(0))(Module(new MAC(dataBits = dataBits, cBits = b + i, outBits = b + i + 1)))
+  )
+  // Binary adder tree reducing the MAC row-1 outputs down to one value.
+  val a = Seq.tabulate(p)(i =>
+    Seq.fill(s(i + 1))(Module(new Adder(dataBits = b + i + 2, outBits = b + i + 3)))
+  )
+
+  for (i <- 0 until log2Ceil(size)) {
+    for (j <- 0 until s(i)) {
+      if (i == 0) {
+        m(i)(j).io.a := io.a(j)
+        m(i)(j).io.b := io.b(j)
+        m(i)(j).io.c := 0.S
+        m(i + 1)(j).io.a := da(j)
+        m(i + 1)(j).io.b := db(j)
+        m(i + 1)(j).io.c := m(i)(j).io.y
+      } else if (i == 1) {
+        a(i - 1)(j).io.a := m(i)(2*j).io.y
+        a(i - 1)(j).io.b := m(i)(2*j + 1).io.y
+      } else {
+        a(i - 1)(j).io.a := a(i - 2)(2*j).io.y
+        a(i - 1)(j).io.b := a(i - 2)(2*j + 1).io.y
+      }
+    }
+  }
+  // Root of the adder tree holds the full dot product.
+  io.y := a(p-1)(0).io.y
+}
+
+/** Perform matrix-vector-multiplication based on DotProduct */
+class MatrixVectorCore(implicit p: Parameters) extends Module {
+  val accBits = p(CoreKey).accBits
+  val size = p(CoreKey).blockOut
+  val dataBits = p(CoreKey).inpBits
+  val io = IO(new Bundle{
+    val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
+    val inp = new TensorMasterData(tensorType = "inp")
+    val wgt = new TensorMasterData(tensorType = "wgt")
+    val acc_i = new TensorMasterData(tensorType = "acc")
+    val acc_o = new TensorClientData(tensorType = "acc")
+    val out = new TensorClientData(tensorType = "out")
+  })
+  // One dot-product per output element; the Pipe delays the accumulator input
+  // to match the dot-product pipeline depth (log2Ceil(size) + 1 stages).
+  val dot = Seq.fill(size)(Module(new DotProduct(dataBits, size)))
+  val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
+  val add = Seq.fill(size)(Wire(SInt(accBits.W)))
+  val vld = Wire(Vec(size, Bool()))
+
+  for (i <- 0 until size) {
+    acc(i).io.enq.valid := io.inp.data.valid & io.wgt.data.valid & io.acc_i.data.valid & ~io.reset
+    acc(i).io.enq.bits := io.acc_i.data.bits(0)(i)
+    for (j <- 0 until size) {
+      dot(i).io.a(j) := io.inp.data.bits(0)(j).asSInt
+      dot(i).io.b(j) := io.wgt.data.bits(i)(j).asSInt
+    }
+    // Accumulate: new acc = delayed old acc + (inp . wgt row i); when reset
+    // is asserted the acc output is forced to zero instead.
+    add(i) := acc(i).io.deq.bits.asSInt + dot(i).io.y
+    io.acc_o.data.bits(0)(i) := Mux(io.reset, 0.U, add(i).asUInt)
+    io.out.data.bits(0)(i) := add(i).asUInt
+    vld(i) := acc(i).io.deq.valid
+  }
+  // Outputs are valid once every lane's delayed accumulator arrives; the
+  // acc output is also valid immediately under reset.
+  io.acc_o.data.valid := vld.asUInt.andR | io.reset
+  io.out.data.valid := vld.asUInt.andR
+}
+
+/** TensorGemm.
+ *
+ * This unit instantiate the MatrixVectorCore and go over the
+ * micro-ops (uops) which are used to read inputs, weights and biases,
+ * and writes results back to the acc and out scratchpads.
+ *
+ * Also, the TensorGemm uses the reset field in the Gemm instruction to
+ * clear or zero-out the acc-scratchpad locations based on the micro-ops.
+ */
+class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module {
+ val io = IO(new Bundle {
+ val start = Input(Bool())
+ val done = Output(Bool())
+ val inst = Input(UInt(INST_BITS.W))
+ val uop = new UopMaster
+ val inp = new TensorMaster(tensorType = "inp")
+ val wgt = new TensorMaster(tensorType = "wgt")
+ val acc = new TensorMaster(tensorType = "acc")
+ val out = new TensorMaster(tensorType = "out")
+ })
+ val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6)
+ val state = RegInit(sIdle)
+ val mvc = Module(new MatrixVectorCore)
+ val dec = io.inst.asTypeOf(new GemmDecode)
+ val uop_idx = Reg(chiselTypeOf(dec.uop_end))
+ val uop_end = dec.uop_end
+ val uop_acc = Reg(chiselTypeOf(dec.uop_end))
+ val uop_inp = Reg(chiselTypeOf(dec.uop_end))
+ val uop_wgt = Reg(chiselTypeOf(dec.uop_end))
+ val cnt_o = Reg(chiselTypeOf(dec.lp_0))
+ val acc_o = Reg(chiselTypeOf(dec.uop_end))
+ val inp_o = Reg(chiselTypeOf(dec.uop_end))
+ val wgt_o = Reg(chiselTypeOf(dec.uop_end))
+ val cnt_i = Reg(chiselTypeOf(dec.lp_1))
+ val acc_i = Reg(chiselTypeOf(dec.uop_end))
+ val inp_i = Reg(chiselTypeOf(dec.uop_end))
+ val wgt_i = Reg(chiselTypeOf(dec.uop_end))
+ val pBits = log2Ceil(p(CoreKey).blockOut) + 1
+ val inflight = Reg(UInt(pBits.W))
+ val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits))
+ val done = inflight === 0.U &
+ ((state === sExe &
+ cnt_o === dec.lp_0 - 1.U &
+ cnt_i === dec.lp_1 - 1.U &
+ uop_idx === uop_end - 1.U &
+ inflight === 0.U) |
+ state === sWait)
+
+ switch (state) {
+ is (sIdle) {
+ when (io.start) {
+ state := sReadUop
+ }
+ }
+ is (sReadUop) {
+ state := sComputeIdx
+ }
+ is (sComputeIdx) {
+ state := sReadTensor
+ }
+ is (sReadTensor) {
+ state := sExe
+ }
+ is (sExe) {
+ when ((cnt_o === dec.lp_0 - 1.U) &&
+ (cnt_i === dec.lp_1 - 1.U) &&
+ (uop_idx === uop_end - 1.U)) {
+ when (inflight =/= 0.U) {
+ state := sWait
+ } .otherwise {
+ state := sIdle
+ }
+ } .otherwise {
+ state := sReadUop
+ }
+ }
+ is (sWait) {
+ when (inflight === 0.U) {
+ state := sIdle
+ }
+ }
+ }
+
+ when (state === sIdle) {
+ inflight := 0.U
+ } .elsewhen (!dec.reset) {
+ when (state === sExe && inflight =/= ((1 << pBits) - 1).asUInt) { // overflow check
+ inflight := inflight + 1.U
+ } .elsewhen (mvc.io.acc_o.data.valid && inflight =/= 0.U) { // underflow check
+ inflight := inflight - 1.U
+ }
+ }
+
+ when (state === sIdle ||
+ (state === sExe &&
+ uop_idx === uop_end - 1.U)) {
+ uop_idx := dec.uop_begin
+ } .elsewhen (state === sExe) {
+ uop_idx := uop_idx + 1.U
+ }
+
+ when (state === sIdle) {
+ cnt_o := 0.U
+ acc_o := 0.U
+ inp_o := 0.U
+ wgt_o := 0.U
+ } .elsewhen (state === sExe &&
+ uop_idx === uop_end - 1.U &&
+ cnt_i === dec.lp_1 - 1.U) {
+ cnt_o := cnt_o + 1.U
+ acc_o := acc_o + dec.acc_0
+ inp_o := inp_o + dec.inp_0
+ wgt_o := wgt_o + dec.wgt_0
+ }
+
+ when (state === sIdle) {
+ cnt_i := 0.U
+ acc_i := 0.U
+ inp_i := 0.U
+ wgt_i := 0.U
+ } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
+ cnt_i := 0.U
+ acc_i := acc_o
+ inp_i := inp_o
+ wgt_i := wgt_o
+ } .elsewhen (state === sExe &&
+ uop_idx === uop_end - 1.U) {
+ cnt_i := cnt_i + 1.U
+ acc_i := acc_i + dec.acc_1
+ inp_i := inp_i + dec.inp_1
+ wgt_i := wgt_i + dec.wgt_1
+ }
+
+ when (state === sComputeIdx && io.uop.data.valid) {
+ uop_acc := io.uop.data.bits.u0 + acc_i
+ uop_inp := io.uop.data.bits.u1 + inp_i
+ uop_wgt := io.uop.data.bits.u2 + wgt_i
+ }
+
+ wrpipe.io.enq.valid := state === sExe & ~dec.reset
+ wrpipe.io.enq.bits := uop_acc
+
+ // uop
+ io.uop.idx.valid := state === sReadUop
+ io.uop.idx.bits := uop_idx
+
+ // inp
+ io.inp.rd.idx.valid := state === sReadTensor
+ io.inp.rd.idx.bits := uop_inp
+ io.inp.tieoffWrite() // read-only
+
+ // wgt
+ io.wgt.rd.idx.valid := state === sReadTensor
+ io.wgt.rd.idx.bits := uop_wgt
+ io.wgt.tieoffWrite() // read-only
+
+ // acc_i
+ io.acc.rd.idx.valid := state === sReadTensor
+ io.acc.rd.idx.bits := uop_acc
+
+ // mvc
+ mvc.io.reset := dec.reset & state === sExe
+ mvc.io.inp.data <> io.inp.rd.data
+ mvc.io.wgt.data <> io.wgt.rd.data
+ mvc.io.acc_i.data <> io.acc.rd.data
+
+ // acc_o
+ io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset, true.B, wrpipe.io.deq.valid)
+ io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
+ io.acc.wr.bits.data <> mvc.io.acc_o.data.bits
+
+ // out
+ io.out.wr.valid := mvc.io.out.data.valid & wrpipe.io.deq.valid
+ io.out.wr.bits.idx := wrpipe.io.deq.bits
+ io.out.wr.bits.data <> mvc.io.out.data.bits
+ io.out.tieoffRead() // write-only
+
+ io.done := done
+
+ if (debug) {
+ when (state === sReadUop && ~dec.reset) {
+ printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
+ }
+
+ when (state === sReadTensor && ~dec.reset) {
+ printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt)
+ }
+
+ io.inp.rd.data.bits.zipWithIndex.foreach { case(r, i) =>
+ when (io.inp.rd.data.valid && ~dec.reset) {
+ printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
+ }
+ }
+
+ io.wgt.rd.data.bits.zipWithIndex.foreach { case(r, i) =>
+ when (io.wgt.rd.data.valid && ~dec.reset) {
+ printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
+ }
+ }
+
+ io.acc.rd.data.bits.foreach { tensor =>
+ tensor.zipWithIndex.foreach { case(elem, i) =>
+ when (io.acc.rd.data.valid && ~dec.reset) {
+ printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
+ }
+ }
+ }
+
+ mvc.io.acc_o.data.bits.foreach { tensor =>
+ tensor.zipWithIndex.foreach { case(elem, i) =>
+ when (mvc.io.acc_o.data.valid && ~dec.reset) {
+ printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
+ }
+ }
+ }
+
+ mvc.io.out.data.bits.foreach { tensor =>
+ tensor.zipWithIndex.foreach { case(elem, i) =>
+ when (mvc.io.out.data.valid && ~dec.reset) {
+ printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
+ }
+ }
+ }
+ }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** TensorLoad.
+ *
+ * Load 1D and 2D tensors from main memory (DRAM) to input/weight
+ * scratchpads (SRAM). Also, there is support for zero padding, while
+ * doing the load. Zero-padding works on the y and x axis, and it is
+ * managed by TensorPadCtrl. The TensorDataCtrl is in charge of
+ * handling the way tensors are stored on the scratchpads.
+ */
+class TensorLoad(tensorType: String = "none", debug: Boolean = false)
+  (implicit p: Parameters) extends Module {
+  val tp = new TensorParams(tensorType)
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val baddr = Input(UInt(mp.addrBits.W))
+    val vme_rd = new VMEReadMaster
+    val tensor = new TensorClient(tensorType)
+  })
+  // sizeFactor: memory blocks per tensor (x-request granularity);
+  // strideFactor: elements per tensor (DRAM row-stride granularity).
+  val sizeFactor = tp.tensorLength * tp.numMemBlock
+  val strideFactor = tp.tensorLength * tp.tensorWidth
+
+  val dec = io.inst.asTypeOf(new MemDecode)
+  val dataCtrl = Module(new TensorDataCtrl(sizeFactor, strideFactor))
+  val dataCtrlDone = RegInit(false.B)
+  // one pad controller per edge: top (YPad0), bottom (YPad1), left (XPad0), right (XPad1)
+  val yPadCtrl0 = Module(new TensorPadCtrl(padType = "YPad0", sizeFactor))
+  val yPadCtrl1 = Module(new TensorPadCtrl(padType = "YPad1", sizeFactor))
+  val xPadCtrl0 = Module(new TensorPadCtrl(padType = "XPad0", sizeFactor))
+  val xPadCtrl1 = Module(new TensorPadCtrl(padType = "XPad1", sizeFactor))
+
+  // tag selects the memory block within a tensor row; set selects the row
+  // (i.e. which SyncReadMem bank) currently being filled.
+  val tag = Reg(UInt(8.W))
+  val set = Reg(UInt(8.W))
+
+  val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil = Enum(7)
+  val state = RegInit(sIdle)
+
+  // control
+  switch (state) {
+    is (sIdle) {
+      when (io.start) {
+        when (dec.ypad_0 =/= 0.U) {
+          state := sYPad0
+        } .elsewhen (dec.xpad_0 =/= 0.U) {
+          state := sXPad0
+        } .otherwise {
+          state := sReadCmd
+        }
+      }
+    }
+    is (sYPad0) {
+      when (yPadCtrl0.io.done) {
+        when (dec.xpad_0 =/= 0.U) {
+          state := sXPad0
+        } .otherwise {
+          state := sReadCmd
+        }
+      }
+    }
+    is (sXPad0) {
+      when (xPadCtrl0.io.done) {
+        state := sReadCmd
+      }
+    }
+    is (sReadCmd) {
+      when (io.vme_rd.cmd.ready) {
+        state := sReadData
+      }
+    }
+    is (sReadData) {
+      when (io.vme_rd.data.valid) {
+        when (dataCtrl.io.done) {
+          when (dec.xpad_1 =/= 0.U) {
+            state := sXPad1
+          } .elsewhen (dec.ypad_1 =/= 0.U) {
+            state := sYPad1
+          } .otherwise {
+            state := sIdle
+          }
+        } .elsewhen (dataCtrl.io.stride || dataCtrl.io.split) {
+          // end of a row (stride) or of a maximum-length burst (split):
+          // pad the right edge first if requested, else restart on the left
+          when (dec.xpad_1 =/= 0.U) {
+            state := sXPad1
+          } .elsewhen (dec.xpad_0 =/= 0.U) {
+            state := sXPad0
+          } .otherwise {
+            state := sReadCmd
+          }
+        }
+      }
+    }
+    is (sXPad1) {
+      when (xPadCtrl1.io.done) {
+        when (dataCtrlDone) {
+          when (dec.ypad_1 =/= 0.U) {
+            state := sYPad1
+          } .otherwise {
+            state := sIdle
+          }
+        } .otherwise {
+          when (dec.xpad_0 =/= 0.U) {
+            state := sXPad0
+          } .otherwise {
+            state := sReadCmd
+          }
+        }
+      }
+    }
+    is (sYPad1) {
+      when (yPadCtrl1.io.done && dataCtrlDone) {
+        state := sIdle
+      }
+    }
+  }
+
+  // data controller
+  dataCtrl.io.start := state === sIdle & io.start
+  dataCtrl.io.inst := io.inst
+  dataCtrl.io.baddr := io.baddr
+  dataCtrl.io.xinit := io.vme_rd.cmd.fire()
+  dataCtrl.io.xupdate := io.vme_rd.data.fire()
+  dataCtrl.io.yupdate := io.vme_rd.data.fire()
+
+  // latch done so trailing pad states know no further DRAM reads remain
+  when (state === sIdle) {
+    dataCtrlDone := false.B
+  } .elsewhen (io.vme_rd.data.fire() && dataCtrl.io.done) {
+    dataCtrlDone := true.B
+  }
+
+  // pad
+  yPadCtrl0.io.start := dec.ypad_0 =/= 0.U & state === sIdle & io.start
+
+  yPadCtrl1.io.start := dec.ypad_1 =/= 0.U &
+    ((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) |
+     (state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone))
+
+  xPadCtrl0.io.start := dec.xpad_0 =/= 0.U &
+    ((state === sIdle & io.start) |
+     (state === sYPad0 & yPadCtrl0.io.done) |
+     (io.vme_rd.data.fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) |
+     (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))
+
+  xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() &
+    ((dataCtrl.io.done) |
+     (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U))
+
+  yPadCtrl0.io.inst := io.inst
+  yPadCtrl1.io.inst := io.inst
+  xPadCtrl0.io.inst := io.inst
+  xPadCtrl1.io.inst := io.inst
+
+  // read-from-dram
+  io.vme_rd.cmd.valid := state === sReadCmd
+  io.vme_rd.cmd.bits.addr := dataCtrl.io.addr
+  io.vme_rd.cmd.bits.len := dataCtrl.io.len
+
+  io.vme_rd.data.ready := state === sReadData
+
+  // write-to-sram
+  // zero-padding states write zeros instead of DRAM data (see wdata mux below)
+  val isZeroPad = state === sYPad0 |
+                  state === sXPad0 |
+                  state === sXPad1 |
+                  state === sYPad1
+
+  when (state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
+    tag := 0.U
+  } .elsewhen (io.vme_rd.data.fire() || isZeroPad) {
+    tag := tag + 1.U
+  }
+
+  when (state === sIdle || state === sReadCmd || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
+    set := 0.U
+  } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) {
+    set := set + 1.U
+  }
+
+  // waddr_cur: current SRAM write address; waddr_nxt: start address of the
+  // current row, used to compute the next row's start on a stride.
+  val waddr_cur = Reg(UInt(tp.memAddrBits.W))
+  val waddr_nxt = Reg(UInt(tp.memAddrBits.W))
+  when (state === sIdle) {
+    waddr_cur := dec.sram_offset
+    waddr_nxt := dec.sram_offset
+  } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) {
+    waddr_cur := waddr_cur + 1.U
+  } .elsewhen (dataCtrl.io.stride) {
+    waddr_cur := waddr_nxt + dec.xsize
+    waddr_nxt := waddr_nxt + dec.xsize
+  }
+
+  val tensorFile = Seq.fill(tp.tensorLength) { SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
+  val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) }
+  val wdata = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
+  val no_mask = Wire(Vec(tp.numMemBlock, Bool()))
+  no_mask.foreach { m => m := true.B }
+
+  for (i <- 0 until tp.tensorLength) {
+    for (j <- 0 until tp.numMemBlock) {
+      // only the block addressed by tag is enabled for writing
+      wmask(i)(j) := tag === j.U
+      wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits)
+    }
+    val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i))
+    // when idle, the scratchpad serves external client writes; otherwise
+    // the load path (DRAM data or zero-pad) owns the write port
+    val muxWen = Mux(state === sIdle, io.tensor.wr.valid, (io.vme_rd.data.fire() | isZeroPad) & set === i.U)
+    val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur)
+    val muxWdata = Mux(state === sIdle, tdata, wdata(i))
+    val muxWmask = Mux(state === sIdle, no_mask, wmask(i))
+    when (muxWen) {
+      tensorFile(i).write(muxWaddr, muxWdata, muxWmask)
+    }
+  }
+
+  // read-from-sram
+  // data is valid one cycle after the index (SyncReadMem read latency)
+  val rvalid = RegNext(io.tensor.rd.idx.valid)
+  io.tensor.rd.data.valid := rvalid
+
+  val rdata = tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid))
+  rdata.zipWithIndex.foreach { case(r, i) =>
+    io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i))
+  }
+
+  // done
+  val done_no_pad = io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U
+  val done_x_pad = state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone & dec.ypad_1 === 0.U
+  val done_y_pad = state === sYPad1 & dataCtrlDone & yPadCtrl1.io.done
+  io.done := done_no_pad | done_x_pad | done_y_pad
+
+  // debug
+  if (debug) {
+    if (tensorType == "inp") {
+      when (io.vme_rd.cmd.fire()) {
+        printf("[TensorLoad] [inp] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
+      }
+      when (state === sYPad0) {
+        printf("[TensorLoad] [inp] sYPad0\n")
+      }
+      when (state === sYPad1) {
+        printf("[TensorLoad] [inp] sYPad1\n")
+      }
+      when (state === sXPad0) {
+        printf("[TensorLoad] [inp] sXPad0\n")
+      }
+      when (state === sXPad1) {
+        printf("[TensorLoad] [inp] sXPad1\n")
+      }
+    } else if (tensorType == "wgt") {
+      when (io.vme_rd.cmd.fire()) {
+        printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
+      }
+    } else if (tensorType == "acc") {
+      when (io.vme_rd.cmd.fire()) {
+        printf("[TensorLoad] [acc] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
+      }
+    }
+  }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** TensorStore.
+ *
+ * Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM).
+ *
+ * NOTE: the default tensorType is "none" (callers must pass a valid type);
+ * the previous default of "true" was never a usable value, since TensorParams
+ * only accepts inp, wgt, acc, or out and rejects anything else at elaboration.
+ */
+class TensorStore(tensorType: String = "none", debug: Boolean = false)
+  (implicit p: Parameters) extends Module {
+  val tp = new TensorParams(tensorType)
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val baddr = Input(UInt(mp.addrBits.W))
+    val vme_wr = new VMEWriteMaster
+    val tensor = new TensorClient(tensorType)
+  })
+  val tensorLength = tp.tensorLength
+  val tensorWidth = tp.tensorWidth
+  val tensorElemBits = tp.tensorElemBits
+  val memBlockBits = tp.memBlockBits
+  val memDepth = tp.memDepth
+  val numMemBlock = tp.numMemBlock
+
+  val dec = io.inst.asTypeOf(new MemDecode)
+  val waddr_cur = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
+  val waddr_nxt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
+  val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
+  val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
+  val xrem = Reg(chiselTypeOf(dec.xsize))
+  // total memory blocks per row, minus one (burst-length encoding)
+  val xsize = (dec.xsize << log2Ceil(tensorLength*numMemBlock)) - 1.U
+  val xmax = (1 << mp.lenBits).U
+  // bytes covered by one maximum-length burst
+  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+  val ycnt = Reg(chiselTypeOf(dec.ysize))
+  val ysize = dec.ysize
+  // tag: memory block within a row; set: row (scratchpad bank) being drained
+  val tag = Reg(UInt(8.W))
+  val set = Reg(UInt(8.W))
+
+  val sIdle :: sWriteCmd :: sWriteData :: sReadMem :: sWriteAck :: Nil = Enum(5)
+  val state = RegInit(sIdle)
+
+  // control: rows are streamed as one or more bursts; xrem carries the
+  // remainder when a row exceeds the maximum burst length
+  switch (state) {
+    is (sIdle) {
+      when (io.start) {
+        state := sWriteCmd
+        when (xsize < xmax) {
+          xlen := xsize
+          xrem := 0.U
+        } .otherwise {
+          xlen := xmax - 1.U
+          xrem := xsize - xmax
+        }
+      }
+    }
+    is (sWriteCmd) {
+      when (io.vme_wr.cmd.ready) {
+        state := sWriteData
+      }
+    }
+    is (sWriteData) {
+      when (io.vme_wr.data.ready) {
+        when (xcnt === xlen) {
+          state := sWriteAck
+        } .elsewhen (tag === (numMemBlock - 1).U) {
+          // exhausted the current SRAM word; take one cycle to fetch the next
+          state := sReadMem
+        }
+      }
+    }
+    is (sReadMem) {
+      state := sWriteData
+    }
+    is (sWriteAck) {
+      when (io.vme_wr.ack) {
+        when (xrem === 0.U) {
+          when (ycnt === ysize - 1.U) {
+            state := sIdle
+          } .otherwise {
+            // next row: recompute the first burst length
+            state := sWriteCmd
+            when (xsize < xmax) {
+              xlen := xsize
+              xrem := 0.U
+            } .otherwise {
+              xlen := xmax - 1.U
+              xrem := xsize - xmax
+            }
+          }
+        } .elsewhen (xrem < xmax) {
+          state := sWriteCmd
+          xlen := xrem
+          xrem := 0.U
+        } .otherwise {
+          state := sWriteCmd
+          xlen := xmax - 1.U
+          xrem := xrem - xmax
+        }
+      }
+    }
+  }
+
+  // write-to-sram: external client fills the out-scratchpad
+  val tensorFile = Seq.fill(tensorLength) { SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W))) }
+  val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W)))
+  val no_mask = Wire(Vec(numMemBlock, Bool()))
+
+  wdata_t := DontCare
+  no_mask.foreach { m => m := true.B }
+
+  for (i <- 0 until tensorLength) {
+    val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t)
+    when (io.tensor.wr.valid) {
+      tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask)
+    }
+  }
+
+  // read-from-sram
+  // stride pulses once per completed row when more rows remain
+  val stride = state === sWriteAck &
+    io.vme_wr.ack &
+    xcnt === xlen + 1.U &
+    xrem === 0.U &
+    ycnt =/= ysize - 1.U
+
+  when (state === sIdle) {
+    ycnt := 0.U
+  } .elsewhen (stride) {
+    ycnt := ycnt + 1.U
+  }
+
+  when (state === sWriteCmd || tag === (numMemBlock - 1).U) {
+    tag := 0.U
+  } .elsewhen (io.vme_wr.data.fire()) {
+    tag := tag + 1.U
+  }
+
+  when (state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
+    set := 0.U
+  } .elsewhen (io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
+    set := set + 1.U
+  }
+
+  // raddr_cur: current SRAM read address; raddr_nxt: start of the current
+  // row, used to step to the next row on a stride
+  val raddr_cur = Reg(UInt(tp.memAddrBits.W))
+  val raddr_nxt = Reg(UInt(tp.memAddrBits.W))
+  when (state === sIdle) {
+    raddr_cur := dec.sram_offset
+    raddr_nxt := dec.sram_offset
+  } .elsewhen (io.vme_wr.data.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) {
+    raddr_cur := raddr_cur + 1.U
+  } .elsewhen (stride) {
+    raddr_cur := raddr_nxt + dec.xsize
+    raddr_nxt := raddr_nxt + dec.xsize
+  }
+
+  val tread = Seq.tabulate(tensorLength) { i => i.U ->
+    tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem) }
+  val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread)
+
+  // write-to-dram
+  when (state === sIdle) {
+    waddr_cur := io.baddr + dec.dram_offset
+    waddr_nxt := io.baddr + dec.dram_offset
+  } .elsewhen (state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
+    // continue the split row at the next burst-aligned address
+    waddr_cur := waddr_cur + xmax_bytes
+  } .elsewhen (stride) {
+    waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
+    waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
+  }
+
+  io.vme_wr.cmd.valid := state === sWriteCmd
+  io.vme_wr.cmd.bits.addr := waddr_cur
+  io.vme_wr.cmd.bits.len := xlen
+
+  io.vme_wr.data.valid := state === sWriteData
+  io.vme_wr.data.bits := mdata(tag)
+
+  when (state === sWriteCmd) {
+    xcnt := 0.U
+  } .elsewhen (io.vme_wr.data.fire()) {
+    xcnt := xcnt + 1.U
+  }
+
+  // disable external read-from-sram requests
+  io.tensor.tieoffRead()
+
+  // done: last row acknowledged with no burst remainder pending
+  io.done := state === sWriteAck & io.vme_wr.ack & xrem === 0.U & ycnt === ysize - 1.U
+
+  // debug
+  if (debug) {
+    when (io.vme_wr.cmd.fire()) {
+      printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem)
+    }
+    when (io.vme_wr.data.fire()) {
+      printf("[TensorStore] data:%x\n", io.vme_wr.data.bits)
+    }
+    when (io.vme_wr.ack) {
+      printf("[TensorStore] ack\n")
+    }
+  }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** TensorParams.
+ *
+ * Derives the per-tensorType geometry used across the tensor units:
+ * inputs (inp), weights (wgt), biases (acc), and outputs (out). Keeping
+ * the lookup here avoids repeating these calculations in every module.
+ */
+class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends Bundle {
+  val errorMsg = s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
+
+  require (tensorType == "inp" || tensorType == "wgt"
+    || tensorType == "acc" || tensorType == "out", errorMsg)
+
+  // (rows, columns, bits-per-element) for one tensor of this type
+  val (tensorLength, tensorWidth, tensorElemBits) = tensorType match {
+    case "inp" => (p(CoreKey).batch, p(CoreKey).blockIn, p(CoreKey).inpBits)
+    case "wgt" => (p(CoreKey).blockOut, p(CoreKey).blockIn, p(CoreKey).wgtBits)
+    case "acc" => (p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).accBits)
+    case _ => (p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).outBits)
+  }
+
+  // how many shell memory-bus words one tensor row occupies
+  val memBlockBits = p(ShellKey).memParams.dataBits
+  val numMemBlock = (tensorWidth * tensorElemBits) / memBlockBits
+
+  // scratchpad depth for this tensor type
+  val memDepth = tensorType match {
+    case "inp" => p(CoreKey).inpMemDepth
+    case "wgt" => p(CoreKey).wgtMemDepth
+    case "acc" => p(CoreKey).accMemDepth
+    case _ => p(CoreKey).outMemDepth
+  }
+
+  val memAddrBits = log2Ceil(memDepth)
+}
+
+/** TensorMaster.
+ *
+ * This interface issues read and write tensor-requests to scratchpads. For example,
+ * The TensorGemm unit uses this interface for managing the inputs (inp), weights (wgt),
+ * biases (acc), and outputs (out).
+ *
+ */
+class TensorMaster(tensorType: String = "none")
+  (implicit p: Parameters) extends TensorParams(tensorType) {
+  // read channel: issue an index, receive the tensor one cycle later
+  val rd = new Bundle {
+    val idx = ValidIO(UInt(memAddrBits.W))
+    val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
+  }
+  // write channel: index plus the full tensor payload
+  val wr = ValidIO(new Bundle {
+    val idx = UInt(memAddrBits.W)
+    val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
+  })
+  // drive the read-request port inactive (for write-only masters)
+  def tieoffRead() {
+    rd.idx.valid := false.B
+    rd.idx.bits := 0.U
+  }
+  // drive the write port inactive (for read-only masters)
+  def tieoffWrite() {
+    wr.valid := false.B
+    wr.bits.idx := 0.U
+    wr.bits.data.foreach { b => b.foreach { c => c := 0.U } }
+  }
+  override def cloneType =
+    new TensorMaster(tensorType).asInstanceOf[this.type]
+}
+
+/** TensorClient.
+ *
+ * This interface receives read and write tensor-requests to scratchpads. For example,
+ * the TensorLoad unit uses this interface for receiving read and write requests from
+ * the TensorGemm unit. Directions mirror TensorMaster.
+ */
+class TensorClient(tensorType: String = "none")
+  (implicit p: Parameters) extends TensorParams(tensorType) {
+  val rd = new Bundle {
+    val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
+    val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
+  }
+  val wr = Flipped(ValidIO(new Bundle {
+    val idx = UInt(memAddrBits.W)
+    val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
+  }))
+  // drive the read-response port inactive (for write-only clients)
+  def tieoffRead() {
+    rd.data.valid := false.B
+    rd.data.bits.foreach { b => b.foreach { c => c := 0.U } }
+  }
+  override def cloneType =
+    new TensorClient(tensorType).asInstanceOf[this.type]
+}
+
+/** TensorMasterData.
+ *
+ * This interface is only used for datapath only purposes and the direction convention
+ * is based on the TensorMaster interface, which means this is an input. This interface
+ * is used on datapath only modules such as MatrixVectorCore or AluVector.
+ */
+class TensorMasterData(tensorType: String = "none")
+  (implicit p: Parameters) extends TensorParams(tensorType) {
+  val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
+  override def cloneType =
+    new TensorMasterData(tensorType).asInstanceOf[this.type]
+}
+
+/** TensorClientData.
+ *
+ * This interface is only used for datapath only purposes and the direction convention
+ * is based on the TensorClient interface, which means this is an output. This interface
+ * is used on datapath only modules such as MatrixVectorCore or AluVector.
+ */
+class TensorClientData(tensorType: String = "none")
+  (implicit p: Parameters) extends TensorParams(tensorType) {
+  val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
+  override def cloneType =
+    new TensorClientData(tensorType).asInstanceOf[this.type]
+}
+
+/** TensorPadCtrl. Zero-padding controller for TensorLoad.
+ *
+ * Counts the zero-fill pulses needed for one padding region. Asserts done
+ * on the last pulse of the region (xcnt == xmax and ycnt == ymax).
+ */
+class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Module {
+  val errorMsg = s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"
+  require (padType == "YPad0" || padType == "YPad1"
+    || padType == "XPad0" || padType == "XPad1", errorMsg)
+
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+  })
+
+  val dec = io.inst.asTypeOf(new MemDecode)
+
+  val xmax = Reg(chiselTypeOf(dec.xsize))
+  val ymax = Reg(chiselTypeOf(dec.ypad_0))
+  val xcnt = Reg(chiselTypeOf(dec.xsize))
+  val ycnt = Reg(chiselTypeOf(dec.ypad_0))
+
+  // pulses per padded row, minus one: y-padding spans a full padded row
+  // (xpad_0 + xsize + xpad_1); x-padding spans only its own edge
+  val xval =
+    if (padType == "YPad0" || padType == "YPad1")
+      ((dec.xpad_0 + dec.xsize + dec.xpad_1) << log2Ceil(sizeFactor)) - 1.U
+    else if (padType == "XPad0")
+      (dec.xpad_0 << log2Ceil(sizeFactor)) - 1.U
+    else
+      (dec.xpad_1 << log2Ceil(sizeFactor)) - 1.U
+
+  // padded rows, minus one (0 for x-padding, which is a single row)
+  val yval =
+    if (padType == "YPad0")
+      Mux(dec.ypad_0 =/= 0.U, dec.ypad_0 - 1.U, 0.U)
+    else if (padType == "YPad1")
+      Mux(dec.ypad_1 =/= 0.U, dec.ypad_1 - 1.U, 0.U)
+    else
+      0.U
+
+  val sIdle :: sActive :: Nil = Enum(2)
+  val state = RegInit(sIdle)
+
+  switch (state) {
+    is (sIdle) {
+      when (io.start) {
+        state := sActive
+      }
+    }
+    is (sActive) {
+      when (ycnt === ymax && xcnt === xmax) {
+        state := sIdle
+      }
+    }
+  }
+
+  // latch the counter bounds while idle so they are stable during sActive
+  when (state === sIdle) {
+    xmax := xval
+    ymax := yval
+  }
+
+  when (state === sIdle || xcnt === xmax) {
+    xcnt := 0.U
+  } .elsewhen (state === sActive) {
+    xcnt := xcnt + 1.U
+  }
+
+  when (state === sIdle || ymax === 0.U) {
+    ycnt := 0.U
+  } .elsewhen (state === sActive && xcnt === xmax) {
+    ycnt := ycnt + 1.U
+  }
+
+  io.done := state === sActive & ycnt === ymax & xcnt === xmax
+}
+
+/** TensorDataCtrl. Data controller for TensorLoad.
+ *
+ * Tracks DRAM read addresses and burst lengths for a 2D tensor transfer.
+ * Rows longer than the maximum burst are split into multiple bursts, with
+ * xrem holding the remainder still to be transferred.
+ */
+class TensorDataCtrl(sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module {
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val baddr = Input(UInt(mp.addrBits.W))
+    val xinit = Input(Bool())
+    val xupdate = Input(Bool())
+    val yupdate = Input(Bool())
+    val stride = Output(Bool())
+    val split = Output(Bool())
+    val commit = Output(Bool())
+    val addr = Output(UInt(mp.addrBits.W))
+    val len = Output(UInt(mp.lenBits.W))
+  })
+
+  val dec = io.inst.asTypeOf(new MemDecode)
+
+  // caddr: address of the current burst; baddr: start address of the
+  // current row (base for the next row's stride step)
+  val caddr = Reg(UInt(mp.addrBits.W))
+  val baddr = Reg(UInt(mp.addrBits.W))
+
+  val len = Reg(UInt(mp.lenBits.W))
+
+  // bytes covered by one maximum-length burst
+  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+  val xcnt = Reg(UInt(mp.lenBits.W))
+  val xrem = Reg(chiselTypeOf(dec.xsize))
+  // memory blocks per row, minus one (burst-length encoding)
+  val xsize = (dec.xsize << log2Ceil(sizeFactor)) - 1.U
+  val xmax = (1 << mp.lenBits).U
+  val ycnt = Reg(chiselTypeOf(dec.ysize))
+
+  // stride: current row finished with no remainder and more rows remain
+  val stride = xcnt === len &
+               xrem === 0.U &
+               ycnt =/= dec.ysize - 1.U
+
+  // split: burst hit the maximum transfer length with remainder pending
+  val split = xcnt === len & xrem =/= 0.U
+
+  when (io.start || (io.xupdate && stride)) {
+    when (xsize < xmax) {
+      len := xsize
+      xrem := 0.U
+    } .otherwise {
+      len := xmax - 1.U
+      xrem := xsize - xmax
+    }
+  } .elsewhen (io.xupdate && split) {
+    when (xrem < xmax) {
+      len := xrem
+      xrem := 0.U
+    } .otherwise {
+      len := xmax - 1.U
+      xrem := xrem - xmax
+    }
+  }
+
+  when (io.xinit) {
+    xcnt := 0.U
+  } .elsewhen (io.xupdate) {
+    xcnt := xcnt + 1.U
+  }
+
+  when (io.start) {
+    ycnt := 0.U
+  } .elsewhen (io.yupdate && stride) {
+    ycnt := ycnt + 1.U
+  }
+
+  when (io.start) {
+    caddr := io.baddr + dec.dram_offset
+    baddr := io.baddr + dec.dram_offset
+  } .elsewhen (io.yupdate) {
+    when (split) {
+      // continue the split row at the next burst-aligned address
+      caddr := caddr + xmax_bytes
+    } .elsewhen (stride) {
+      caddr := baddr + (dec.xstride << log2Ceil(strideFactor))
+      baddr := baddr + (dec.xstride << log2Ceil(strideFactor))
+    }
+  }
+
+  io.stride := stride
+  io.split := split
+  io.commit := xcnt === len
+  io.addr := caddr
+  io.len := len
+  io.done := xcnt === len &
+             xrem === 0.U &
+             ycnt === dec.ysize - 1.U
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta
+
+/** Package object for vta.core: mixing in ISAConstants makes its members
+ * visible throughout the package without explicit imports.
+ */
+package object core extends vta.core.ISAConstants
import chisel3._
import chisel3.util._
+import vta.util.config._
+import vta.interface.axi._
+import vta.shell._
/** Host DPI parameters */
trait VTAHostDPIParams {
})
setResource("/verilog/VTAHostDPI.v")
}
+
+/** Host DPI to AXI Converter.
+ *
+ * Convert Host DPI to AXI for VTAShell. Each DPI request is turned into one
+ * AXI-Lite transaction: reads go AR -> R, writes go AW -> W -> B.
+ */
+
+class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val dpi = new VTAHostDPIClient
+    val axi = new AXILiteMaster(p(ShellKey).hostParams)
+  })
+  // request address/value latched when a DPI request is accepted in sIdle
+  val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
+  val data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.value)))
+  val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
+  val state = RegInit(sIdle)
+
+  switch (state) {
+    is (sIdle) {
+      // opcode selects the transaction type (set -> write, clear -> read)
+      when (io.dpi.req.valid) {
+        when (io.dpi.req.opcode) {
+          state := sWriteAddress
+        } .otherwise {
+          state := sReadAddress
+        }
+      }
+    }
+    is (sReadAddress) {
+      when (io.axi.ar.ready) {
+        state := sReadData
+      }
+    }
+    is (sReadData) {
+      when (io.axi.r.valid) {
+        state := sIdle
+      }
+    }
+    is (sWriteAddress) {
+      when (io.axi.aw.ready) {
+        state := sWriteData
+      }
+    }
+    is (sWriteData) {
+      when (io.axi.w.ready) {
+        state := sWriteResponse
+      }
+    }
+    is (sWriteResponse) {
+      when (io.axi.b.valid) {
+        state := sIdle
+      }
+    }
+  }
+
+  when (state === sIdle && io.dpi.req.valid) {
+    addr := io.dpi.req.addr
+    data := io.dpi.req.value
+  }
+
+  // write path: full word strobe
+  io.axi.aw.valid := state === sWriteAddress
+  io.axi.aw.bits.addr := addr
+  io.axi.w.valid := state === sWriteData
+  io.axi.w.bits.data := data
+  io.axi.w.bits.strb := "h_f".U
+  io.axi.b.ready := state === sWriteResponse
+
+  // read path
+  io.axi.ar.valid := state === sReadAddress
+  io.axi.ar.bits.addr := addr
+  io.axi.r.ready := state === sReadData
+
+  // pop the DPI request once the matching AXI address channel handshakes
+  io.dpi.req.deq := (state === sReadAddress & io.axi.ar.ready) | (state === sWriteAddress & io.axi.aw.ready)
+  io.dpi.resp.valid := io.axi.r.valid
+  io.dpi.resp.bits := io.axi.r.bits.data
+
+  if (debug) {
+    when (state === sWriteAddress && io.axi.aw.ready) { printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr) }
+    when (state === sReadAddress && io.axi.ar.ready) { printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr) }
+    when (io.axi.r.fire()) { printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data) }
+    when (io.axi.w.fire()) { printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data) }
+  }
+}
import chisel3._
import chisel3.util._
+import vta.util.config._
+import vta.interface.axi._
+import vta.shell._
/** Memory DPI parameters */
trait VTAMemDPIParams {
})
setResource("/verilog/VTAMemDPI.v")
}
+
+/** VTAMemDPIToAXI.
+ *
+ * Bridges the memory DPI request/response interface to an AXI client port,
+ * so a DPI-driven simulated memory can service AXI read/write bursts.
+ * Serves one AXI command at a time.
+ */
+class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val dpi = new VTAMemDPIMaster
+    val axi = new AXIClient(p(ShellKey).memParams)
+  })
+  // Latched fields of the AXI command currently being served.
+  val opcode = RegInit(false.B) // false = read, true = write (set in sIdle below)
+  val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len)))
+  val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
+  // FSM: address phase, data phase, and (for writes) response phase.
+  val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
+  val state = RegInit(sIdle)
+
+  switch (state) {
+    is (sIdle) {
+      // Reads take priority over writes when both AR and AW are pending.
+      when (io.axi.ar.valid) {
+        state := sReadAddress
+      } .elsewhen (io.axi.aw.valid) {
+        state := sWriteAddress
+      }
+    }
+    is (sReadAddress) {
+      when (io.axi.ar.valid) {
+        state := sReadData
+      }
+    }
+    is (sReadData) {
+      // Last beat: DPI data available, AXI side ready, burst counter exhausted.
+      when (io.axi.r.ready && io.dpi.rd.valid && len === 0.U) {
+        state := sIdle
+      }
+    }
+    is (sWriteAddress) {
+      when (io.axi.aw.valid) {
+        state := sWriteData
+      }
+    }
+    is (sWriteData) {
+      when (io.axi.w.valid && io.axi.w.bits.last) {
+        state := sWriteResponse
+      }
+    }
+    is (sWriteResponse) {
+      when (io.axi.b.ready) {
+        state := sIdle
+      }
+    }
+  }
+
+  // Capture the command in sIdle; decrement the beat counter on each read beat.
+  when (state === sIdle) {
+    when (io.axi.ar.valid) {
+      opcode := false.B
+      len := io.axi.ar.bits.len
+      addr := io.axi.ar.bits.addr
+    } .elsewhen (io.axi.aw.valid) {
+      opcode := true.B
+      len := io.axi.aw.bits.len
+      addr := io.axi.aw.bits.addr
+    }
+  } .elsewhen (state === sReadData) {
+    when (io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) {
+      len := len - 1.U
+    }
+  }
+
+  // Issue exactly one DPI request per AXI command, during its address phase.
+  io.dpi.req.valid := (state === sReadAddress & io.axi.ar.valid) | (state === sWriteAddress & io.axi.aw.valid)
+  io.dpi.req.opcode := opcode
+  io.dpi.req.len := len
+  io.dpi.req.addr := addr
+
+  io.axi.ar.ready := state === sReadAddress
+  io.axi.aw.ready := state === sWriteAddress
+
+  // Read data path: forward DPI read data as AXI R beats.
+  io.axi.r.valid := state === sReadData & io.dpi.rd.valid
+  io.axi.r.bits.data := io.dpi.rd.bits
+  io.axi.r.bits.last := len === 0.U
+  io.axi.r.bits.resp := 0.U // always OKAY
+  io.axi.r.bits.user := 0.U
+  io.axi.r.bits.id := 0.U
+  io.dpi.rd.ready := state === sReadData & io.axi.r.ready
+
+  // Write data path: forward AXI W beats to the DPI write port.
+  io.dpi.wr.valid := state === sWriteData & io.axi.w.valid
+  io.dpi.wr.bits := io.axi.w.bits.data
+  io.axi.w.ready := state === sWriteData
+
+  // Write response: single B beat, always OKAY.
+  io.axi.b.valid := state === sWriteResponse
+  io.axi.b.bits.resp := 0.U
+  io.axi.b.bits.user := 0.U
+  io.axi.b.bits.id := 0.U
+
+  // Simulation-only tracing of channel handshakes.
+  if (debug) {
+    when (state === sReadAddress && io.axi.ar.valid) { printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len) }
+    when (state === sWriteAddress && io.axi.aw.valid) { printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len) }
+    when (io.axi.r.fire()) { printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", io.axi.r.bits.last, io.axi.r.bits.data) }
+    when (io.axi.w.fire()) { printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", io.axi.w.bits.last, io.axi.w.bits.data) }
+  }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.interface.axi
+
+import chisel3._
+import chisel3.util._
+import vta.util.genericbundle._
+
+/** AXIParams.
+ *
+ * Field widths and constant field values shared by the AXI and AXI-Lite
+ * bundles below. The *Const values are driven by AXIMaster.setConst().
+ */
+case class AXIParams(
+  addrBits: Int = 32,
+  dataBits: Int = 64
+)
+{
+  require (addrBits > 0)
+  // NOTE(review): strbBits = dataBits/8 assumes dataBits is a byte multiple;
+  // the `% 2 == 0` check looks weaker than that -- confirm `% 8` was intended.
+  require (dataBits >= 8 && dataBits % 2 == 0)
+
+  val idBits = 1
+  val userBits = 1
+  val strbBits = dataBits/8 // one write-strobe bit per byte lane
+  val lenBits = 8
+  val sizeBits = 3
+  val burstBits = 2
+  val lockBits = 2
+  val cacheBits = 4
+  val protBits = 3
+  val qosBits = 4
+  val regionBits = 4
+  val respBits = 2
+  val sizeConst = log2Ceil(dataBits/8) // full-data-width beats
+  val idConst = 0
+  val userConst = 0
+  val burstConst = 1 // NOTE(review): presumably INCR burst encoding -- confirm
+  val lockConst = 0
+  val cacheConst = 3
+  val protConst = 0
+  val qosConst = 0
+  val regionConst = 0
+}
+
+/** AXIBase. Common base bundle carrying the AXIParams configuration. */
+abstract class AXIBase(params: AXIParams)
+  extends GenericParameterizedBundle(params)
+
+// AXILite
+
+/** AXI-Lite address channel payload (AW/AR). */
+class AXILiteAddress(params: AXIParams) extends AXIBase(params) {
+  val addr = UInt(params.addrBits.W)
+}
+
+/** AXI-Lite write-data channel payload (W): data plus per-byte strobes. */
+class AXILiteWriteData(params: AXIParams) extends AXIBase(params) {
+  val data = UInt(params.dataBits.W)
+  val strb = UInt(params.strbBits.W)
+}
+
+/** AXI-Lite write-response channel payload (B). */
+class AXILiteWriteResponse(params: AXIParams) extends AXIBase(params) {
+  val resp = UInt(params.respBits.W)
+}
+
+/** AXI-Lite read-data channel payload (R): data plus response code. */
+class AXILiteReadData(params: AXIParams) extends AXIBase(params) {
+  val data = UInt(params.dataBits.W)
+  val resp = UInt(params.respBits.W)
+}
+
+/** AXI-Lite master interface: drives AW/W/AR, receives B/R. */
+class AXILiteMaster(params: AXIParams) extends AXIBase(params) {
+  val aw = Decoupled(new AXILiteAddress(params))
+  val w = Decoupled(new AXILiteWriteData(params))
+  val b = Flipped(Decoupled(new AXILiteWriteResponse(params)))
+  val ar = Decoupled(new AXILiteAddress(params))
+  val r = Flipped(Decoupled(new AXILiteReadData(params)))
+
+  // Drive every output to an inactive default: no transactions issued,
+  // no responses accepted.
+  def tieoff() {
+    aw.valid := false.B
+    aw.bits.addr := 0.U
+    w.valid := false.B
+    w.bits.data := 0.U
+    w.bits.strb := 0.U
+    b.ready := false.B
+    ar.valid := false.B
+    ar.bits.addr := 0.U
+    r.ready := false.B
+  }
+}
+
+/** AXI-Lite client (slave) interface: receives AW/W/AR, drives B/R. */
+class AXILiteClient(params: AXIParams) extends AXIBase(params) {
+  val aw = Flipped(Decoupled(new AXILiteAddress(params)))
+  val w = Flipped(Decoupled(new AXILiteWriteData(params)))
+  val b = Decoupled(new AXILiteWriteResponse(params))
+  val ar = Flipped(Decoupled(new AXILiteAddress(params)))
+  val r = Decoupled(new AXILiteReadData(params))
+
+  // Drive every output to an inactive default: no requests accepted,
+  // no responses produced.
+  def tieoff() {
+    aw.ready := false.B
+    w.ready := false.B
+    b.valid := false.B
+    b.bits.resp := 0.U
+    ar.ready := false.B
+    r.valid := false.B
+    r.bits.resp := 0.U
+    r.bits.data := 0.U
+  }
+}
+
+// AXI extends AXILite
+
+/** Full AXI address channel payload (AW/AR): AXI-Lite address plus burst
+ * and transaction metadata fields.
+ */
+class AXIAddress(params: AXIParams) extends AXILiteAddress(params) {
+  val id = UInt(params.idBits.W)
+  val user = UInt(params.userBits.W)
+  val len = UInt(params.lenBits.W)
+  val size = UInt(params.sizeBits.W)
+  val burst = UInt(params.burstBits.W)
+  val lock = UInt(params.lockBits.W)
+  val cache = UInt(params.cacheBits.W)
+  val prot = UInt(params.protBits.W)
+  val qos = UInt(params.qosBits.W)
+  val region = UInt(params.regionBits.W)
+}
+
+/** Full AXI write-data payload (W): adds last/id/user to the AXI-Lite form. */
+class AXIWriteData(params: AXIParams) extends AXILiteWriteData(params) {
+  val last = Bool()
+  val id = UInt(params.idBits.W)
+  val user = UInt(params.userBits.W)
+}
+
+/** Full AXI write-response payload (B): adds id/user to the AXI-Lite form. */
+class AXIWriteResponse(params: AXIParams) extends AXILiteWriteResponse(params) {
+  val id = UInt(params.idBits.W)
+  val user = UInt(params.userBits.W)
+}
+
+/** Full AXI read-data payload (R): adds last/id/user to the AXI-Lite form. */
+class AXIReadData(params: AXIParams) extends AXILiteReadData(params) {
+  val last = Bool()
+  val id = UInt(params.idBits.W)
+  val user = UInt(params.userBits.W)
+}
+
+/** Full AXI master interface: drives AW/W/AR, receives B/R. */
+class AXIMaster(params: AXIParams) extends AXIBase(params) {
+  val aw = Decoupled(new AXIAddress(params))
+  val w = Decoupled(new AXIWriteData(params))
+  val b = Flipped(Decoupled(new AXIWriteResponse(params)))
+  val ar = Decoupled(new AXIAddress(params))
+  val r = Flipped(Decoupled(new AXIReadData(params)))
+
+  // Drive every output to an inactive default: no transactions issued,
+  // no responses accepted.
+  def tieoff() {
+    aw.valid := false.B
+    aw.bits.addr := 0.U
+    aw.bits.id := 0.U
+    aw.bits.user := 0.U
+    aw.bits.len := 0.U
+    aw.bits.size := 0.U
+    aw.bits.burst := 0.U
+    aw.bits.lock := 0.U
+    aw.bits.cache := 0.U
+    aw.bits.prot := 0.U
+    aw.bits.qos := 0.U
+    aw.bits.region := 0.U
+    w.valid := false.B
+    w.bits.data := 0.U
+    w.bits.strb := 0.U
+    w.bits.last := false.B
+    w.bits.id := 0.U
+    w.bits.user := 0.U
+    b.ready := false.B
+    ar.valid := false.B
+    ar.bits.addr := 0.U
+    ar.bits.id := 0.U
+    ar.bits.user := 0.U
+    ar.bits.len := 0.U
+    ar.bits.size := 0.U
+    ar.bits.burst := 0.U
+    ar.bits.lock := 0.U
+    ar.bits.cache := 0.U
+    ar.bits.prot := 0.U
+    ar.bits.qos := 0.U
+    ar.bits.region := 0.U
+    r.ready := false.B
+  }
+
+  // Drive the protocol fields that never change to the constants defined in
+  // AXIParams. Write strobes default to all-ones (full-width writes).
+  def setConst() {
+    aw.bits.user := params.userConst.U
+    aw.bits.burst := params.burstConst.U
+    aw.bits.lock := params.lockConst.U
+    aw.bits.cache := params.cacheConst.U
+    aw.bits.prot := params.protConst.U
+    aw.bits.qos := params.qosConst.U
+    aw.bits.region := params.regionConst.U
+    aw.bits.size := params.sizeConst.U
+    aw.bits.id := params.idConst.U
+    w.bits.id := params.idConst.U
+    w.bits.user := params.userConst.U
+    w.bits.strb := Fill(params.strbBits, true.B)
+    ar.bits.user := params.userConst.U
+    ar.bits.burst := params.burstConst.U
+    ar.bits.lock := params.lockConst.U
+    ar.bits.cache := params.cacheConst.U
+    ar.bits.prot := params.protConst.U
+    ar.bits.qos := params.qosConst.U
+    ar.bits.region := params.regionConst.U
+    ar.bits.size := params.sizeConst.U
+    ar.bits.id := params.idConst.U
+  }
+}
+
+/** Full AXI client (slave) interface: receives AW/W/AR, drives B/R. */
+class AXIClient(params: AXIParams) extends AXIBase(params) {
+  val aw = Flipped(Decoupled(new AXIAddress(params)))
+  val w = Flipped(Decoupled(new AXIWriteData(params)))
+  val b = Decoupled(new AXIWriteResponse(params))
+  val ar = Flipped(Decoupled(new AXIAddress(params)))
+  val r = Decoupled(new AXIReadData(params))
+
+  // Drive every output to an inactive default: no requests accepted,
+  // no responses produced.
+  def tieoff() {
+    aw.ready := false.B
+    w.ready := false.B
+    b.valid := false.B
+    b.bits.resp := 0.U
+    b.bits.user := 0.U
+    b.bits.id := 0.U
+    ar.ready := false.B
+    r.valid := false.B
+    r.bits.resp := 0.U
+    r.bits.data := 0.U
+    r.bits.user := 0.U
+    r.bits.last := false.B
+    r.bits.id := 0.U
+  }
+}
+
+// XilinxAXILiteClient and XilinxAXIMaster bundles are needed
+// for wrapper purposes, because the package RTL tool in Xilinx Vivado
+// only allows certain name formats
+
+/** AXI-Lite client with flat, Xilinx-convention port names (for IP packaging). */
+class XilinxAXILiteClient(params: AXIParams) extends AXIBase(params) {
+  // Write address channel
+  val AWVALID = Input(Bool())
+  val AWREADY = Output(Bool())
+  val AWADDR = Input(UInt(params.addrBits.W))
+  // Write data channel
+  val WVALID = Input(Bool())
+  val WREADY = Output(Bool())
+  val WDATA = Input(UInt(params.dataBits.W))
+  val WSTRB = Input(UInt(params.strbBits.W))
+  // Write response channel
+  val BVALID = Output(Bool())
+  val BREADY = Input(Bool())
+  val BRESP = Output(UInt(params.respBits.W))
+  // Read address channel
+  val ARVALID = Input(Bool())
+  val ARREADY = Output(Bool())
+  val ARADDR = Input(UInt(params.addrBits.W))
+  // Read data channel
+  val RVALID = Output(Bool())
+  val RREADY = Input(Bool())
+  val RDATA = Output(UInt(params.dataBits.W))
+  val RRESP = Output(UInt(params.respBits.W))
+}
+
+/** Full AXI master with flat, Xilinx-convention port names (for IP packaging). */
+class XilinxAXIMaster(params: AXIParams) extends AXIBase(params) {
+  // Write address channel
+  val AWVALID = Output(Bool())
+  val AWREADY = Input(Bool())
+  val AWADDR = Output(UInt(params.addrBits.W))
+  val AWID = Output(UInt(params.idBits.W))
+  val AWUSER = Output(UInt(params.userBits.W))
+  val AWLEN = Output(UInt(params.lenBits.W))
+  val AWSIZE = Output(UInt(params.sizeBits.W))
+  val AWBURST = Output(UInt(params.burstBits.W))
+  val AWLOCK = Output(UInt(params.lockBits.W))
+  val AWCACHE = Output(UInt(params.cacheBits.W))
+  val AWPROT = Output(UInt(params.protBits.W))
+  val AWQOS = Output(UInt(params.qosBits.W))
+  val AWREGION = Output(UInt(params.regionBits.W))
+  // Write data channel
+  val WVALID = Output(Bool())
+  val WREADY = Input(Bool())
+  val WDATA = Output(UInt(params.dataBits.W))
+  val WSTRB = Output(UInt(params.strbBits.W))
+  val WLAST = Output(Bool())
+  val WID = Output(UInt(params.idBits.W))
+  val WUSER = Output(UInt(params.userBits.W))
+  // Write response channel
+  val BVALID = Input(Bool())
+  val BREADY = Output(Bool())
+  val BRESP = Input(UInt(params.respBits.W))
+  val BID = Input(UInt(params.idBits.W))
+  val BUSER = Input(UInt(params.userBits.W))
+  // Read address channel
+  val ARVALID = Output(Bool())
+  val ARREADY = Input(Bool())
+  val ARADDR = Output(UInt(params.addrBits.W))
+  val ARID = Output(UInt(params.idBits.W))
+  val ARUSER = Output(UInt(params.userBits.W))
+  val ARLEN = Output(UInt(params.lenBits.W))
+  val ARSIZE = Output(UInt(params.sizeBits.W))
+  val ARBURST = Output(UInt(params.burstBits.W))
+  val ARLOCK = Output(UInt(params.lockBits.W))
+  val ARCACHE = Output(UInt(params.cacheBits.W))
+  val ARPROT = Output(UInt(params.protBits.W))
+  val ARQOS = Output(UInt(params.qosBits.W))
+  val ARREGION = Output(UInt(params.regionBits.W))
+  // Read data channel
+  val RVALID = Input(Bool())
+  val RREADY = Output(Bool())
+  val RDATA = Input(UInt(params.dataBits.W))
+  val RRESP = Input(UInt(params.respBits.W))
+  val RLAST = Input(Bool())
+  val RID = Input(UInt(params.idBits.W))
+  val RUSER = Input(UInt(params.userBits.W))
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.shell
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.interface.axi._
+
+/** PynqConfig. Shell configuration for Pynq */
+class PynqConfig extends Config((site, here, up) => {
+  case ShellKey => ShellParams(
+    // host (AXI-Lite control) port widths
+    hostParams = AXIParams(
+      addrBits = 16,
+      dataBits = 32),
+    // memory (full AXI) port widths
+    memParams = AXIParams(
+      addrBits = 32,
+      dataBits = 64),
+    vcrParams = VCRParams(),
+    vmeParams = VMEParams())
+})
+
+/** F1Config. Shell configuration for F1 */
+class F1Config extends Config((site, here, up) => {
+  case ShellKey => ShellParams(
+    // host (AXI-Lite control) port widths
+    hostParams = AXIParams(
+      addrBits = 16,
+      dataBits = 32),
+    // memory (full AXI) port widths; 64-bit addresses on F1
+    memParams = AXIParams(
+      addrBits = 64,
+      dataBits = 64),
+    vcrParams = VCRParams(),
+    vmeParams = VMEParams())
+})
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.shell
+
+import chisel3._
+import vta.util.config._
+import vta.interface.axi._
+import vta.shell._
+import vta.dpi._
+
+/** VTAHost.
+ *
+ * This module translates the DPI protocol into AXI. This is a simulation-only
+ * module used to test host-to-VTA communication. This module should be updated
+ * for testing hosts that use a bus protocol other than AXI.
+ */
+class VTAHost(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val axi = new AXILiteMaster(p(ShellKey).hostParams)
+  })
+  // DPI host driver feeding a DPI-to-AXI bridge; the bridge's AXI-Lite
+  // master side is exported as this module's io.axi.
+  val host_dpi = Module(new VTAHostDPI)
+  val host_axi = Module(new VTAHostDPIToAXI)
+  // The DPI black box takes explicit clock/reset ports.
+  host_dpi.io.reset := reset
+  host_dpi.io.clock := clock
+  host_axi.io.dpi <> host_dpi.io.dpi
+  io.axi <> host_axi.io.axi
+}
+
+/** VTAMem.
+ *
+ * This module translates the DPI protocol into AXI. This is a simulation-only
+ * module used to test VTA-to-memory communication. This module should be updated
+ * for testing memories that use a bus protocol other than AXI.
+ */
+class VTAMem(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val axi = new AXIClient(p(ShellKey).memParams)
+  })
+  // DPI memory model behind an AXI-to-DPI bridge; the bridge's AXI client
+  // side is exported as this module's io.axi.
+  val mem_dpi = Module(new VTAMemDPI)
+  val mem_axi = Module(new VTAMemDPIToAXI)
+  // The DPI black box takes explicit clock/reset ports.
+  mem_dpi.io.reset := reset
+  mem_dpi.io.clock := clock
+  mem_dpi.io.dpi <> mem_axi.io.dpi
+  mem_axi.io.axi <> io.axi
+}
+
+/** SimShell.
+ *
+ * The simulation shell instantiates host and memory simulation modules; it is
+ * intended to be connected to the VTAShell.
+ */
+class SimShell(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val mem = new AXIClient(p(ShellKey).memParams)
+    val host = new AXILiteMaster(p(ShellKey).hostParams)
+  })
+  // Wrap the DPI-backed host driver and memory model and export their
+  // AXI-facing sides for connection to the VTAShell.
+  val host = Module(new VTAHost)
+  val mem = Module(new VTAMem)
+  io.mem <> mem.io.axi
+  io.host <> host.io.axi
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.shell
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.util.genericbundle._
+import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.LinkedHashMap
+import vta.interface.axi._
+
+/** VCR parameters.
+ *
+ * These parameters are used on VCR interfaces and modules.
+ */
+case class VCRParams()
+{
+  val nValsReg: Int = 1      // number of value registers (regBits wide)
+  val nPtrsReg: Int = 6      // number of pointer registers (width follows memParams.addrBits)
+  val regBits: Int = 32      // width of a single register
+  val nCtrlReg: Int = 4      // number of control registers
+  val ctrlBaseAddr: Int = 0  // byte offset of the first control register
+
+  require (nValsReg > 0)
+  require (nPtrsReg > 0)
+}
+
+/** VCRBase. Parametrize base class. */
+/** VCRBase. Parametrize base class. */
+abstract class VCRBase(implicit p: Parameters)
+  extends GenericParameterizedBundle(p)
+
+/** VCRMaster.
+ *
+ * This is the master interface used by VCR in the VTAShell to control
+ * the Core unit.
+ */
+class VCRMaster(implicit p: Parameters) extends VCRBase {
+  val vp = p(ShellKey).vcrParams
+  val mp = p(ShellKey).memParams
+  val launch = Output(Bool())  // asserted while the host-set start bit is high
+  val finish = Input(Bool())   // core signals completion; clears the start bit
+  val irq = Output(Bool())
+  val ptrs = Output(Vec(vp.nPtrsReg, UInt(mp.addrBits.W))) // pointer registers
+  val vals = Output(Vec(vp.nValsReg, UInt(vp.regBits.W)))  // value registers
+}
+
+/** VCRClient.
+ *
+ * This is the client interface used by the Core module to communicate
+ * to the VCR in the VTAShell.
+ */
+class VCRClient(implicit p: Parameters) extends VCRBase {
+  val vp = p(ShellKey).vcrParams
+  val mp = p(ShellKey).memParams
+  val launch = Input(Bool())   // start signal from the VCR
+  val finish = Output(Bool())  // completion signal back to the VCR
+  val irq = Input(Bool())
+  val ptrs = Input(Vec(vp.nPtrsReg, UInt(mp.addrBits.W))) // pointer registers
+  val vals = Input(Vec(vp.nValsReg, UInt(vp.regBits.W)))  // value registers
+}
+
+/** VTA Control Registers (VCR).
+ *
+ * This unit provides control registers (32 and 64 bits) to be used by a control
+ * unit, typically a host processor. These registers are read-only by the core
+ * at the moment but this will likely change once we add support to general purpose
+ * registers that could be used as event counters by the Core unit.
+ */
+class VCR(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle{
+    val host = new AXILiteClient(p(ShellKey).hostParams)
+    val vcr = new VCRMaster
+  })
+
+  val vp = p(ShellKey).vcrParams
+  val mp = p(ShellKey).memParams
+  val hp = p(ShellKey).hostParams
+
+  // Write control (AW, W, B)
+  val waddr = RegInit("h_ffff".U(hp.addrBits.W)) // init with invalid address
+  val wdata = io.host.w.bits.data
+  val wstrb = io.host.w.bits.strb
+  // Expand the 4 byte strobes into a 32-bit write mask (one bit per data bit).
+  val wmask = Cat(Fill(8, wstrb(3)), Fill(8, wstrb(2)), Fill(8, wstrb(1)), Fill(8, wstrb(0)))
+  // Write FSM: one AW, then one W, then one B per transaction.
+  val sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(3)
+  val wstate = RegInit(sWriteAddress)
+  switch (wstate) {
+    is (sWriteAddress) {
+      when (io.host.aw.valid) {
+        wstate := sWriteData
+      }
+    }
+    is (sWriteData) {
+      when (io.host.w.valid) {
+        wstate := sWriteResponse
+      }
+    }
+    is (sWriteResponse) {
+      when (io.host.b.ready) {
+        wstate := sWriteAddress
+      }
+    }
+  }
+
+  // Latch the write address for use during the data phase.
+  when (io.host.aw.fire()) { waddr := io.host.aw.bits.addr }
+
+  io.host.aw.ready := wstate === sWriteAddress
+  io.host.w.ready := wstate === sWriteData
+  io.host.b.valid := wstate === sWriteResponse
+  io.host.b.bits.resp := "h_0".U // always OKAY
+
+  // read control (AR, R)
+  val sReadAddress :: sReadData :: Nil = Enum(2)
+  val rstate = RegInit(sReadAddress)
+
+  switch (rstate) {
+    is (sReadAddress) {
+      when (io.host.ar.valid) {
+        rstate := sReadData
+      }
+    }
+    is (sReadData) {
+      when (io.host.r.ready) {
+        rstate := sReadAddress
+      }
+    }
+  }
+
+  io.host.ar.ready := rstate === sReadAddress
+  io.host.r.valid := rstate === sReadData
+
+  // Register map layout: [ctrl regs][value regs][pointer regs], each region
+  // packed contiguously in regBits-sized slots.
+  val nPtrsReg = vp.nPtrsReg
+  val nValsReg = vp.nValsReg
+  val regBits = vp.regBits
+  val ptrsBits = mp.addrBits
+  val nCtrlReg = vp.nCtrlReg
+  val rStride = regBits/8   // byte stride of a regBits-wide register
+  val pStride = ptrsBits/8  // byte stride of a pointer register
+  val ctrlBaseAddr = vp.ctrlBaseAddr
+  val valsBaseAddr = ctrlBaseAddr + nCtrlReg*rStride
+  val ptrsBaseAddr = valsBaseAddr + nValsReg*rStride
+
+  val ctrlAddr = Seq.tabulate(nCtrlReg)(i => i*rStride + ctrlBaseAddr)
+  val valsAddr = Seq.tabulate(nValsReg)(i => i*rStride + valsBaseAddr)
+
+  // 64-bit pointers occupy two consecutive regBits-wide slots (low, high).
+  val ptrsAddr = new ListBuffer[Int]()
+  for (i <- 0 until nPtrsReg) {
+    ptrsAddr += i*pStride + ptrsBaseAddr
+    if (ptrsBits == 64) {
+      ptrsAddr += i*pStride + rStride + ptrsBaseAddr
+    }
+  }
+
+  // AP register
+  val c0 = RegInit(VecInit(Seq.fill(regBits)(false.B)))
+
+  // ap start: set by a host write of 1 to bit 0; cleared when the core finishes.
+  when (io.host.w.fire() && waddr === ctrlAddr(0).asUInt && wstrb(0) && wdata(0)) {
+    c0(0) := true.B
+  } .elsewhen (io.vcr.finish) {
+    c0(1) := false.B
+  }
+
+  // ap done = finish; cleared on read of the control register (read-to-clear).
+  when (io.vcr.finish) {
+    c0(1) := true.B
+  } .elsewhen (io.host.ar.fire() && io.host.ar.bits.addr === ctrlAddr(0).asUInt) {
+    c0(1) := false.B
+  }
+
+  // Remaining control registers are hardwired to zero.
+  val c1 = 0.U
+  val c2 = 0.U
+  val c3 = 0.U
+
+  val ctrlRegList = List(c0, c1, c2, c3)
+
+  io.vcr.launch := c0(0)
+
+  // interrupts not supported atm
+  io.vcr.irq := false.B
+
+  // Write pointer and value registers (host-writable, byte-strobe masked).
+  val pvAddr = valsAddr ++ ptrsAddr
+  val pvNumReg = if (ptrsBits == 64) nValsReg + nPtrsReg*2 else nValsReg + nPtrsReg
+  val pvReg = RegInit(VecInit(Seq.fill(pvNumReg)(0.U(regBits.W))))
+  val pvRegList = new ListBuffer[UInt]()
+
+  for (i <- 0 until pvNumReg) {
+    when (io.host.w.fire() && (waddr === pvAddr(i).U)) {
+      // Merge new bytes per the strobe mask, keep the rest.
+      pvReg(i) := (wdata & wmask) | (pvReg(i) & ~wmask)
+    }
+    pvRegList += pvReg(i)
+  }
+
+  for (i <- 0 until nValsReg) {
+    io.vcr.vals(i) := pvReg(i)
+  }
+
+  // Reassemble 64-bit pointers from their (high, low) 32-bit halves.
+  for (i <- 0 until nPtrsReg) {
+    if (ptrsBits == 64) {
+      io.vcr.ptrs(i) := Cat(pvReg(nValsReg + i*2 + 1), pvReg(nValsReg + i*2))
+    } else {
+      io.vcr.ptrs(i) := pvReg(nValsReg + i)
+    }
+  }
+
+  // Read pointer and value registers: one-hot address decode into a mux.
+  val mapAddr = ctrlAddr ++ valsAddr ++ ptrsAddr
+  val mapRegList = ctrlRegList ++ pvRegList
+
+  val rdata = RegInit(0.U(regBits.W))
+  val rmap = LinkedHashMap[Int,UInt]()
+
+  val totalReg = mapRegList.length
+  for (i <- 0 until totalReg) { rmap += mapAddr(i) -> mapRegList(i).asUInt }
+
+  val decodeAddr = rmap map { case (k, _) => k -> (io.host.ar.bits.addr === k.asUInt) }
+
+  // Latch the selected register on the AR handshake; held through the R phase.
+  when (io.host.ar.fire()) {
+    rdata := Mux1H(for ((k, v) <- rmap) yield decodeAddr(k) -> v)
+  }
+
+  io.host.r.bits.resp := 0.U // always OKAY
+  io.host.r.bits.data := rdata
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.shell
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.util.genericbundle._
+import vta.interface.axi._
+
+/** VME parameters.
+ *
+ * These parameters are used on VME interfaces and modules.
+ */
+case class VMEParams() {
+  val nReadClients: Int = 5   // read ports arbitrated by the VME
+  val nWriteClients: Int = 1  // single write port (only one supported)
+  require (nReadClients > 0, s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
+  require (nWriteClients == 1, s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
+}
+
+/** VMEBase. Parametrize base class. */
+/** VMEBase. Parametrize base class. */
+abstract class VMEBase(implicit p: Parameters)
+  extends GenericParameterizedBundle(p)
+
+/** VMECmd.
+ *
+ * This interface is used for creating write and read requests to memory.
+ */
+/** VMECmd: a memory request -- start address plus burst length. */
+class VMECmd(implicit p: Parameters) extends VMEBase {
+  val addrBits = p(ShellKey).memParams.addrBits
+  val lenBits = p(ShellKey).memParams.lenBits
+  val addr = UInt(addrBits.W)
+  val len = UInt(lenBits.W)
+}
+
+/** VMEReadMaster.
+ *
+ * This interface is used by modules inside the core to generate read requests
+ * and receive responses from VME.
+ */
+/** VMEReadMaster: issues read commands, receives read data beats. */
+class VMEReadMaster(implicit p: Parameters) extends Bundle {
+  val dataBits = p(ShellKey).memParams.dataBits
+  val cmd = Decoupled(new VMECmd)
+  val data = Flipped(Decoupled(UInt(dataBits.W)))
+  override def cloneType =
+    new VMEReadMaster().asInstanceOf[this.type]
+}
+
+/** VMEReadClient.
+ *
+ * This interface is used by the VME to receive read requests and generate
+ * responses to modules inside the core.
+ */
+/** VMEReadClient: receives read commands, produces read data beats. */
+class VMEReadClient(implicit p: Parameters) extends Bundle {
+  val dataBits = p(ShellKey).memParams.dataBits
+  val cmd = Flipped(Decoupled(new VMECmd))
+  val data = Decoupled(UInt(dataBits.W))
+  override def cloneType =
+    new VMEReadClient().asInstanceOf[this.type]
+}
+
+/** VMEWriteMaster.
+ *
+ * This interface is used by modules inside the core to generate write requests
+ * to the VME.
+ */
+/** VMEWriteMaster: issues write commands and data; ack pulses on completion. */
+class VMEWriteMaster(implicit p: Parameters) extends Bundle {
+  val dataBits = p(ShellKey).memParams.dataBits
+  val cmd = Decoupled(new VMECmd)
+  val data = Decoupled(UInt(dataBits.W))
+  val ack = Input(Bool())
+  override def cloneType =
+    new VMEWriteMaster().asInstanceOf[this.type]
+}
+
+/** VMEWriteClient.
+ *
+ * This interface is used by the VME to handle write requests from modules inside
+ * the core.
+ */
+/** VMEWriteClient: receives write commands and data; drives the ack pulse. */
+class VMEWriteClient(implicit p: Parameters) extends Bundle {
+  val dataBits = p(ShellKey).memParams.dataBits
+  val cmd = Flipped(Decoupled(new VMECmd))
+  val data = Flipped(Decoupled(UInt(dataBits.W)))
+  val ack = Output(Bool())
+  override def cloneType =
+    new VMEWriteClient().asInstanceOf[this.type]
+}
+
+/** VMEMaster.
+ *
+ * Pack nRd number of VMEReadMaster interfaces and nWr number of VMEWriteMaster
+ * interfaces.
+ */
+/** VMEMaster: nRd read-master ports plus nWr write-master ports, bundled. */
+class VMEMaster(implicit p: Parameters) extends Bundle {
+  val nRd = p(ShellKey).vmeParams.nReadClients
+  val nWr = p(ShellKey).vmeParams.nWriteClients
+  val rd = Vec(nRd, new VMEReadMaster)
+  val wr = Vec(nWr, new VMEWriteMaster)
+}
+
+/** VMEClient.
+ *
+ * Pack nRd number of VMEReadClient interfaces and nWr number of VMEWriteClient
+ * interfaces.
+ */
+/** VMEClient: nRd read-client ports plus nWr write-client ports, bundled. */
+class VMEClient(implicit p: Parameters) extends Bundle {
+  val nRd = p(ShellKey).vmeParams.nReadClients
+  val nWr = p(ShellKey).vmeParams.nWriteClients
+  val rd = Vec(nRd, new VMEReadClient)
+  val wr = Vec(nWr, new VMEWriteClient)
+}
+
+/** VTA Memory Engine (VME).
+ *
+ * This unit multiplexes the memory controller interface for the Core. Currently,
+ * it supports single-writer and multiple-reader mode and it is also based on AXI.
+ */
+class VME(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val mem = new AXIMaster(p(ShellKey).memParams)
+    val vme = new VMEClient
+  })
+
+  // Round-robin arbitration among read clients; the winner is latched for
+  // the duration of its burst so data beats route back to the right client.
+  val nReadClients = p(ShellKey).vmeParams.nReadClients
+  val rd_arb = Module(new Arbiter(new VMECmd, nReadClients))
+  val rd_arb_chosen = RegEnable(rd_arb.io.chosen, rd_arb.io.out.fire())
+
+  for (i <- 0 until nReadClients) { rd_arb.io.in(i) <> io.vme.rd(i).cmd }
+
+  // Read FSM: accept a command, issue the AXI AR, stream R beats until last.
+  val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3)
+  val rstate = RegInit(sReadIdle)
+
+  switch (rstate) {
+    is (sReadIdle) {
+      when (rd_arb.io.out.valid) {
+        rstate := sReadAddr
+      }
+    }
+    is (sReadAddr) {
+      when (io.mem.ar.ready) {
+        rstate := sReadData
+      }
+    }
+    is (sReadData) {
+      when (io.mem.r.fire() && io.mem.r.bits.last) {
+        rstate := sReadIdle
+      }
+    }
+  }
+
+  // Write FSM: accept a command, issue the AXI AW, stream W beats, wait for B.
+  val sWriteIdle :: sWriteAddr :: sWriteData :: sWriteResp :: Nil = Enum(4)
+  val wstate = RegInit(sWriteIdle)
+  val addrBits = p(ShellKey).memParams.addrBits
+  val lenBits = p(ShellKey).memParams.lenBits
+  val wr_cnt = RegInit(0.U(lenBits.W))
+
+  // Count write beats; reset between bursts.
+  when (wstate === sWriteIdle) {
+    wr_cnt := 0.U
+  } .elsewhen (io.mem.w.fire()) {
+    wr_cnt := wr_cnt + 1.U
+  }
+
+  switch (wstate) {
+    is (sWriteIdle) {
+      when (io.vme.wr(0).cmd.valid) {
+        wstate := sWriteAddr
+      }
+    }
+    is (sWriteAddr) {
+      when (io.mem.aw.ready) {
+        wstate := sWriteData
+      }
+    }
+    is (sWriteData) {
+      // NOTE(review): compares against the live cmd.bits.len rather than the
+      // registered wr_len -- relies on the client holding cmd.bits stable; confirm.
+      when (io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) {
+        wstate := sWriteResp
+      }
+    }
+    is (sWriteResp) {
+      when (io.mem.b.valid) {
+        wstate := sWriteIdle
+      }
+    }
+  }
+
+  // registers storing read/write cmds
+
+  val rd_len = RegInit(0.U(lenBits.W))
+  val wr_len = RegInit(0.U(lenBits.W))
+  val rd_addr = RegInit(0.U(addrBits.W))
+  val wr_addr = RegInit(0.U(addrBits.W))
+
+  when (rd_arb.io.out.fire()) {
+    rd_len := rd_arb.io.out.bits.len
+    rd_addr := rd_arb.io.out.bits.addr
+  }
+
+  when (io.vme.wr(0).cmd.fire()) {
+    wr_len := io.vme.wr(0).cmd.bits.len
+    wr_addr := io.vme.wr(0).cmd.bits.addr
+  }
+
+  // rd arb
+  rd_arb.io.out.ready := rstate === sReadIdle
+
+  // vme: fan read data out to every client, valid only for the burst winner.
+  for (i <- 0 until nReadClients) {
+    io.vme.rd(i).data.valid := rd_arb_chosen === i.asUInt & io.mem.r.valid
+    io.vme.rd(i).data.bits := io.mem.r.bits.data
+  }
+
+  io.vme.wr(0).cmd.ready := wstate === sWriteIdle
+  io.vme.wr(0).ack := io.mem.b.fire() // write completion pulse
+  io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready
+
+  // mem
+  io.mem.aw.valid := wstate === sWriteAddr
+  io.mem.aw.bits.addr := wr_addr
+  io.mem.aw.bits.len := wr_len
+
+  io.mem.w.valid := wstate === sWriteData & io.vme.wr(0).data.valid
+  io.mem.w.bits.data := io.vme.wr(0).data.bits
+  io.mem.w.bits.last := wr_cnt === io.vme.wr(0).cmd.bits.len
+
+  io.mem.b.ready := wstate === sWriteResp
+
+  io.mem.ar.valid := rstate === sReadAddr
+  io.mem.ar.bits.addr := rd_addr
+  io.mem.ar.bits.len := rd_len
+
+  // Backpressure R with the winning client's readiness.
+  io.mem.r.ready := rstate === sReadData & io.vme.rd(rd_arb_chosen).data.ready
+
+  // AXI constants - statically defined
+  io.mem.setConst()
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.shell
+
+import chisel3._
+import vta.util.config._
+import vta.interface.axi._
+import vta.core._
+
+/** Shell parameters. */
+case class ShellParams(
+  hostParams: AXIParams, // host-facing (AXI-Lite control) port widths
+  memParams: AXIParams,  // memory-facing (full AXI) port widths
+  vcrParams: VCRParams,  // control-register block configuration
+  vmeParams: VMEParams   // memory-engine client counts
+)
+
+/** Configuration key used to look up ShellParams via p(ShellKey). */
+case object ShellKey extends Field[ShellParams]
+
+/** VTAShell.
+ *
+ * The VTAShell is based on a VME, VCR and core. This creates a complete VTA
+ * system that can be used for simulation or real hardware.
+ */
+class VTAShell(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle{
+    // AXI-Lite slave port for host register access; AXI master for DRAM.
+    val host = new AXILiteClient(p(ShellKey).hostParams)
+    val mem = new AXIMaster(p(ShellKey).memParams)
+  })
+
+  // Control registers, memory engine, and the VTA core itself.
+  val vcr = Module(new VCR)
+  val vme = Module(new VME)
+  val core = Module(new Core)
+
+  // Core talks to the host through the VCR and to DRAM through the VME.
+  core.io.vcr <> vcr.io.vcr
+  vme.io.vme <> core.io.vme
+
+  // Expose the VCR host port and the VME memory port at the shell boundary.
+  vcr.io.host <> io.host
+  io.mem <> vme.io.mem
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.shell
+
+import chisel3._
+import chisel3.experimental.{RawModule, withClockAndReset}
+import vta.util.config._
+import vta.interface.axi._
+
+/** XilinxShell.
+ *
+ * This is a wrapper shell mostly used to match Xilinx convention naming,
+ * therefore we can pack VTA as an IP for IPI based flows.
+ */
+class XilinxShell(implicit p: Parameters) extends RawModule {
+
+  val hp = p(ShellKey).hostParams
+  val mp = p(ShellKey).memParams
+
+  val ap_clk = IO(Input(Clock()))
+  val ap_rst_n = IO(Input(Bool()))
+  val m_axi_gmem = IO(new XilinxAXIMaster(mp))
+  val s_axi_control = IO(new XilinxAXILiteClient(hp))
+
+  // ap_rst_n is active-low per Xilinx convention; invert it for the shell.
+  val shell = withClockAndReset (clock = ap_clk, reset = ~ap_rst_n) { Module(new VTAShell) }
+
+  // memory: flatten the shell AXI bundle onto Xilinx-named ports
+  m_axi_gmem.AWVALID := shell.io.mem.aw.valid
+  shell.io.mem.aw.ready := m_axi_gmem.AWREADY
+  m_axi_gmem.AWADDR := shell.io.mem.aw.bits.addr
+  m_axi_gmem.AWID := shell.io.mem.aw.bits.id
+  m_axi_gmem.AWUSER := shell.io.mem.aw.bits.user
+  m_axi_gmem.AWLEN := shell.io.mem.aw.bits.len
+  m_axi_gmem.AWSIZE := shell.io.mem.aw.bits.size
+  m_axi_gmem.AWBURST := shell.io.mem.aw.bits.burst
+  m_axi_gmem.AWLOCK := shell.io.mem.aw.bits.lock
+  m_axi_gmem.AWCACHE := shell.io.mem.aw.bits.cache
+  m_axi_gmem.AWPROT := shell.io.mem.aw.bits.prot
+  m_axi_gmem.AWQOS := shell.io.mem.aw.bits.qos
+  m_axi_gmem.AWREGION := shell.io.mem.aw.bits.region
+
+  m_axi_gmem.WVALID := shell.io.mem.w.valid
+  shell.io.mem.w.ready := m_axi_gmem.WREADY
+  m_axi_gmem.WDATA := shell.io.mem.w.bits.data
+  m_axi_gmem.WSTRB := shell.io.mem.w.bits.strb
+  m_axi_gmem.WLAST := shell.io.mem.w.bits.last
+  m_axi_gmem.WID := shell.io.mem.w.bits.id
+  m_axi_gmem.WUSER := shell.io.mem.w.bits.user
+
+  shell.io.mem.b.valid := m_axi_gmem.BVALID
+  // BREADY must come from the shell's b.ready (the original drove it from
+  // b.valid, acknowledging write responses the shell was not ready for).
+  m_axi_gmem.BREADY := shell.io.mem.b.ready
+  shell.io.mem.b.bits.resp := m_axi_gmem.BRESP
+  shell.io.mem.b.bits.id := m_axi_gmem.BID
+  shell.io.mem.b.bits.user := m_axi_gmem.BUSER
+
+  m_axi_gmem.ARVALID := shell.io.mem.ar.valid
+  shell.io.mem.ar.ready := m_axi_gmem.ARREADY
+  m_axi_gmem.ARADDR := shell.io.mem.ar.bits.addr
+  m_axi_gmem.ARID := shell.io.mem.ar.bits.id
+  m_axi_gmem.ARUSER := shell.io.mem.ar.bits.user
+  m_axi_gmem.ARLEN := shell.io.mem.ar.bits.len
+  m_axi_gmem.ARSIZE := shell.io.mem.ar.bits.size
+  m_axi_gmem.ARBURST := shell.io.mem.ar.bits.burst
+  m_axi_gmem.ARLOCK := shell.io.mem.ar.bits.lock
+  m_axi_gmem.ARCACHE := shell.io.mem.ar.bits.cache
+  m_axi_gmem.ARPROT := shell.io.mem.ar.bits.prot
+  m_axi_gmem.ARQOS := shell.io.mem.ar.bits.qos
+  m_axi_gmem.ARREGION := shell.io.mem.ar.bits.region
+
+  shell.io.mem.r.valid := m_axi_gmem.RVALID
+  m_axi_gmem.RREADY := shell.io.mem.r.ready
+  shell.io.mem.r.bits.data := m_axi_gmem.RDATA
+  shell.io.mem.r.bits.resp := m_axi_gmem.RRESP
+  shell.io.mem.r.bits.last := m_axi_gmem.RLAST
+  shell.io.mem.r.bits.id := m_axi_gmem.RID
+  shell.io.mem.r.bits.user := m_axi_gmem.RUSER
+
+  // host: AXI-Lite control port
+  shell.io.host.aw.valid := s_axi_control.AWVALID
+  s_axi_control.AWREADY := shell.io.host.aw.ready
+  shell.io.host.aw.bits.addr := s_axi_control.AWADDR
+
+  shell.io.host.w.valid := s_axi_control.WVALID
+  s_axi_control.WREADY := shell.io.host.w.ready
+  shell.io.host.w.bits.data := s_axi_control.WDATA
+  shell.io.host.w.bits.strb := s_axi_control.WSTRB
+
+  s_axi_control.BVALID := shell.io.host.b.valid
+  shell.io.host.b.ready := s_axi_control.BREADY
+  s_axi_control.BRESP := shell.io.host.b.bits.resp
+
+  shell.io.host.ar.valid := s_axi_control.ARVALID
+  s_axi_control.ARREADY := shell.io.host.ar.ready
+  shell.io.host.ar.bits.addr := s_axi_control.ARADDR
+
+  s_axi_control.RVALID := shell.io.host.r.valid
+  shell.io.host.r.ready := s_axi_control.RREADY
+  s_axi_control.RDATA := shell.io.host.r.bits.data
+  s_axi_control.RRESP := shell.io.host.r.bits.resp
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.test
+
+import chisel3._
+import vta.util.config._
+import vta.shell._
+
+/** Test. This generates a testbench file for simulation */
+class Test(implicit p: Parameters) extends Module {
+  // No external IO: the SimShell drives host and memory via DPI.
+  val io = IO(new Bundle {})
+  val sim_shell = Module(new SimShell)
+  val vta_shell = Module(new VTAShell)
+  // Wire the simulated host and memory models to the VTA shell.
+  vta_shell.io.host <> sim_shell.io.host
+  sim_shell.io.mem <> vta_shell.io.mem
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.util.config
+
+// taken from https://github.com/freechipsproject/rocket-chip
+
+// A typed configuration key, optionally carrying a default value.
+abstract class Field[T] private (val default: Option[T])
+{
+  def this() = this(None)
+  def this(default: T) = this(Some(default))
+}
+
+// Read-only view of a configuration: strict lookup via apply (fails when the
+// key is undefined) or optional lookup via lift.
+abstract class View {
+  final def apply[T](pname: Field[T]): T = apply(pname, this)
+  final def apply[T](pname: Field[T], site: View): T = {
+    val out = find(pname, site)
+    require (out.isDefined, s"Key ${pname} is not defined in Parameters")
+    out.get
+  }
+
+  final def lift[T](pname: Field[T]): Option[T] = lift(pname, this)
+  final def lift[T](pname: Field[T], site: View): Option[T] = find(pname, site).map(_.asInstanceOf[T])
+
+  protected[config] def find[T](pname: Field[T], site: View): Option[T]
+}
+
+// Composable configuration. ++ chains two Parameters: lookups try the left
+// side first, then fall through to the right (and finally to Field defaults).
+abstract class Parameters extends View {
+  final def ++ (x: Parameters): Parameters =
+    new ChainParameters(this, x)
+
+  final def alter(f: (View, View, View) => PartialFunction[Any,Any]): Parameters =
+    Parameters(f) ++ this
+
+  final def alterPartial(f: PartialFunction[Any,Any]): Parameters =
+    Parameters((_,_,_) => f) ++ this
+
+  final def alterMap(m: Map[Any,Any]): Parameters =
+    new MapParameters(m) ++ this
+
+  protected[config] def chain[T](site: View, tail: View, pname: Field[T]): Option[T]
+  protected[config] def find[T](pname: Field[T], site: View) = chain(site, new TerminalView, pname)
+}
+
+object Parameters {
+  def empty: Parameters = new EmptyParameters
+  def apply(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = new PartialParameters(f)
+}
+
+// User-facing wrapper for a chained configuration (see the *Config classes).
+class Config(p: Parameters) extends Parameters {
+  def this(f: (View, View, View) => PartialFunction[Any,Any]) = this(Parameters(f))
+
+  protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = p.chain(site, tail, pname)
+  override def toString = this.getClass.getSimpleName
+  def toInstance = this
+}
+
+// Internal implementation:
+
+// End of a lookup chain: fall back to the Field's declared default.
+private class TerminalView extends View {
+  def find[T](pname: Field[T], site: View): Option[T] = pname.default
+}
+
+private class ChainView(head: Parameters, tail: View) extends View {
+  def find[T](pname: Field[T], site: View) = head.chain(site, tail, pname)
+}
+
+private class ChainParameters(x: Parameters, y: Parameters) extends Parameters {
+  def chain[T](site: View, tail: View, pname: Field[T]) = x.chain(site, new ChainView(y, tail), pname)
+}
+
+private class EmptyParameters extends Parameters {
+  def chain[T](site: View, tail: View, pname: Field[T]) = tail.find(pname, site)
+}
+
+// Parameters backed by a user partial function of (site, here, up) views.
+private class PartialParameters(f: (View, View, View) => PartialFunction[Any,Any]) extends Parameters {
+  protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = {
+    val g = f(site, this, tail)
+    if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T]) else tail.find(pname, site)
+  }
+}
+
+// Parameters backed by a plain key/value map.
+private class MapParameters(map: Map[Any, Any]) extends Parameters {
+  protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = {
+    val g = map.get(pname)
+    if (g.isDefined) Some(g.get.asInstanceOf[T]) else tail.find(pname, site)
+  }
+}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.util.genericbundle
+
+// taken from https://github.com/freechipsproject/rocket-chip
+
+import chisel3._
+
+// Bundle whose cloneType works for any single-argument subclass: it
+// re-invokes the (reflected) constructor with the stored params object.
+abstract class GenericParameterizedBundle[+T <: Object](val params: T) extends Bundle
+{
+  override def cloneType = {
+    try {
+      // Reflectively call the first constructor; only valid when the
+      // subclass takes exactly one constructor argument (params).
+      this.getClass.getConstructors.head.newInstance(params).asInstanceOf[this.type]
+    } catch {
+      case e: java.lang.IllegalArgumentException =>
+        throw new Exception("Unable to use GenericParameterizedBundle.cloneType on " +
+          this.getClass + ", probably because " + this.getClass +
+          "() takes more than one argument. Consider overriding " +
+          "cloneType() on " + this.getClass, e)
+    }
+  }
+}
+
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta
+
+import chisel3._
+import vta.util.config._
+import vta.shell._
+import vta.core._
+import vta.test._
+
+/** VTA.
+ *
+ * This file contains all the configurations supported by VTA.
+ * These configurations are built in a mix/match form based on core
+ * and shell configurations.
+ */
+
+// Mix/match core and shell configurations for the supported targets.
+class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig)
+class DefaultF1Config extends Config(new CoreConfig ++ new F1Config)
+
+// Elaborate the Xilinx IP wrapper with the Pynq configuration.
+object DefaultPynqConfig extends App {
+  implicit val p: Parameters = new DefaultPynqConfig
+  chisel3.Driver.execute(args, () => new XilinxShell)
+}
+
+// Elaborate the Xilinx IP wrapper with the F1 configuration.
+object DefaultF1Config extends App {
+  implicit val p: Parameters = new DefaultF1Config
+  chisel3.Driver.execute(args, () => new XilinxShell)
+}
+
+// Elaborate the simulation testbench with the F1 configuration.
+object TestDefaultF1Config extends App {
+  implicit val p: Parameters = new DefaultF1Config
+  chisel3.Driver.execute(args, () => new Test)
+}
_mem_dpi = mem_dpi;
}
+
+// Override Verilator's $finish handler (active only when the model is
+// compiled with VL_USER_FINISH): record the finish request so the simulation
+// loop can stop cleanly instead of exiting the host process.
+void vl_finish(const char* filename, int linenum, const char* hier) {
+  Verilated::gotFinish(true);
+  VL_PRINTF("[TSIM] exiting simulation\n");
+}
+
int VTADPISim(uint64_t max_cycles) {
uint64_t trace_count = 0;
+ Verilated::flushCall();
+ Verilated::gotFinish(false);
#if VM_TRACE
uint64_t start = 0;
typedef void * VTADeviceHandle;
/*! \brief physical address */
+#ifdef USE_TSIM
+typedef uint64_t vta_phy_addr_t;
+#else
typedef uint32_t vta_phy_addr_t;
+#endif
/*!
* \brief Allocate a device resource handle
*
* \return 0 if running is successful, 1 if timeout.
*/
+#ifdef USE_TSIM
+int VTADeviceRun(VTADeviceHandle device,
+ vta_phy_addr_t insn_phy_addr,
+ vta_phy_addr_t uop_phy_addr,
+ vta_phy_addr_t inp_phy_addr,
+ vta_phy_addr_t wgt_phy_addr,
+ vta_phy_addr_t acc_phy_addr,
+ vta_phy_addr_t out_phy_addr,
+ uint32_t insn_count,
+ uint32_t wait_cycles);
+#else
int VTADeviceRun(VTADeviceHandle device,
vta_phy_addr_t insn_phy_addr,
uint32_t insn_count,
uint32_t wait_cycles);
+#endif
/*!
* \brief Allocates physically contiguous region in memory (limited by MAX_XFER).
"""The target host"""
if self.TARGET == "pynq":
return "llvm -target=armv7-none-linux-gnueabihf"
- if self.TARGET == "sim":
+ if self.TARGET == "sim" or self.TARGET == "tsim":
return "llvm"
raise ValueError("Unknown target %s" % self.TARGET)
"""Utilities to start simulator."""
import ctypes
import json
+import sys
+import os
import tvm
from ..libinfo import find_libvta
x = tvm.get_global_func("vta.simulator.profiler_status")()
return json.loads(x)
+def tsim_init(hw_lib):
+ """Init hardware shared library for TSIM
+
+ Parameters
+ ------------
+ hw_lib : str
+ Name of hardware shared library
+ """
+ cur_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+ vta_build_path = os.path.join(cur_path, "..", "..", "..", "build")
+ if not hw_lib.endswith(("dylib", "so")):
+ hw_lib += ".dylib" if sys.platform == "darwin" else ".so"
+ lib = os.path.join(vta_build_path, hw_lib)
+ f = tvm.get_global_func("tvm.vta.tsim.init")
+ m = tvm.module.load(lib, "vta-tsim")
+ f(m)
+
LIBS = _load_lib()
"""
env = get_env()
- if env.TARGET == "sim":
+ if env.TARGET in ["sim", "tsim"]:
# Talk to local RPC if necessary to debug RPC server.
# Compile vta on your host with make at the root.
# Make sure simulation library exists
# If this fails, build vta on host (make)
# with TARGET="sim" in the json.config file.
- assert simulator.enabled()
+ if env.TARGET == "sim":
+ assert simulator.enabled()
run_func(env, rpc.LocalSession())
elif env.TARGET == "pynq":
return data_;
}
/*! \return Physical address of the data. */
- uint32_t phy_addr() const {
+ vta_phy_addr_t phy_addr() const {
return phy_addr_;
}
/*!
/*! \brief The internal data. */
void* data_;
/*! \brief The physical address of the buffer, excluding header. */
- uint32_t phy_addr_;
+ vta_phy_addr_t phy_addr_;
};
/*!
return dram_buffer_;
}
/*! \return Physical address of DRAM. */
- uint32_t dram_phy_addr() const {
+ vta_phy_addr_t dram_phy_addr() const {
return dram_phy_addr_;
}
/*! \return Whether there is pending information. */
// The buffer in DRAM
char* dram_buffer_{nullptr};
// Physics address of the buffer
- uint32_t dram_phy_addr_;
+ vta_phy_addr_t dram_phy_addr_;
};
/*!
CHECK((dram_end_ - dram_begin_) == (sram_end_ - sram_begin_));
insn->memory_type = VTA_MEM_ID_UOP;
insn->sram_base = sram_begin_;
+#ifdef USE_TSIM
+ insn->dram_base = (uint32_t) dram_phy_addr_ + dram_begin_*kElemBytes;
+#else
insn->dram_base = dram_phy_addr_ / kElemBytes + dram_begin_;
+#endif
insn->y_size = 1;
insn->x_size = (dram_end_ - dram_begin_);
insn->x_stride = (dram_end_ - dram_begin_);
insn->memory_type = dst_memory_type;
insn->sram_base = dst_sram_index;
DataBuffer* src = DataBuffer::FromHandle(src_dram_addr);
+#ifdef USE_TSIM
+ insn->dram_base = (uint32_t) src->phy_addr() + src_elem_offset*GetElemBytes(dst_memory_type);
+#else
insn->dram_base = src->phy_addr() / GetElemBytes(dst_memory_type) + src_elem_offset;
+#endif
insn->y_size = y_size;
insn->x_size = x_size;
insn->x_stride = x_stride;
insn->memory_type = src_memory_type;
insn->sram_base = src_sram_index;
DataBuffer* dst = DataBuffer::FromHandle(dst_dram_addr);
+#ifdef USE_TSIM
+ insn->dram_base = (uint32_t) dst->phy_addr() + dst_elem_offset*GetElemBytes(src_memory_type);
+#else
insn->dram_base = dst->phy_addr() / GetElemBytes(src_memory_type) + dst_elem_offset;
+#endif
insn->y_size = y_size;
insn->x_size = x_size;
insn->x_stride = x_stride;
// Make sure that we don't exceed contiguous physical memory limits
CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) < VTA_MAX_XFER);
+#ifdef USE_TSIM
int timeout = VTADeviceRun(
device_,
insn_queue_.dram_phy_addr(),
+ uop_queue_.dram_phy_addr(),
+ inp_phy_addr_,
+ wgt_phy_addr_,
+ acc_phy_addr_,
+ out_phy_addr_,
insn_queue_.count(),
wait_cycles);
+#else
+ int timeout = VTADeviceRun(
+ device_,
+ insn_queue_.dram_phy_addr(),
+ insn_queue_.count(),
+ wait_cycles);
+#endif
CHECK_EQ(timeout, 0);
// Reset buffers
uop_queue_.Reset();
ThreadLocal().reset();
}
+#ifdef USE_TSIM
+  // Record the physical base address of one SRAM buffer type so it can be
+  // programmed into the TSIM device registers before a run.
+  void SetBufPhyAddr(uint32_t type, vta_phy_addr_t addr) {
+    switch (type) {
+      // break after each case: the original fell through and clobbered
+      // every buffer address below the one being set.
+      case VTA_MEM_ID_INP: inp_phy_addr_ = addr; break;
+      case VTA_MEM_ID_WGT: wgt_phy_addr_ = addr; break;
+      case VTA_MEM_ID_ACC: acc_phy_addr_ = addr; break;
+      case VTA_MEM_ID_OUT: out_phy_addr_ = addr; break;
+      default: break;
+    }
+  }
+#endif
+
private:
// Push GEMM uop to the command buffer
void PushGEMMOp(UopKernel* kernel) {
InsnQueue<VTA_MAX_XFER, true, true> insn_queue_;
// Device handle
VTADeviceHandle device_{nullptr};
+#ifdef USE_TSIM
+ // Input phy addr
+ vta_phy_addr_t inp_phy_addr_{0};
+ // Weight phy addr
+ vta_phy_addr_t wgt_phy_addr_{0};
+ // Accumulator phy addr
+ vta_phy_addr_t acc_phy_addr_{0};
+ // Output phy addr
+ vta_phy_addr_t out_phy_addr_{0};
+#endif
};
} // namespace vta
uint32_t y_pad_after,
uint32_t dst_sram_index,
uint32_t dst_memory_type) {
+#ifdef USE_TSIM
+ vta::DataBuffer* src = vta::DataBuffer::FromHandle(src_dram_addr);
+ static_cast<vta::CommandQueue*>(cmd)->SetBufPhyAddr(dst_memory_type, src->phy_addr());
+#endif
static_cast<vta::CommandQueue*>(cmd)->
LoadBuffer2D(src_dram_addr, src_elem_offset,
x_size, y_size, x_stride,
uint32_t x_size,
uint32_t y_size,
uint32_t x_stride) {
+#ifdef USE_TSIM
+ vta::DataBuffer* dst = vta::DataBuffer::FromHandle(dst_dram_addr);
+ static_cast<vta::CommandQueue*>(cmd)->SetBufPhyAddr(src_memory_type, dst->phy_addr());
+#endif
static_cast<vta::CommandQueue*>(cmd)->
StoreBuffer2D(src_sram_index, src_memory_type,
dst_dram_addr, dst_elem_offset,
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <vta/driver.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+#include <vta/dpi/module.h>
+
+namespace vta {
+namespace tsim {
+
+using vta::dpi::DPIModuleNode;
+using tvm::runtime::Module;
+
+// Process-wide holder for the loaded hardware library module; populated by
+// the "tvm.vta.tsim.init" packed function and queried by Device.
+class DPILoader {
+ public:
+  void Init(Module module) {
+    mod_ = module;
+  }
+
+  // Returns the DPI module node backing the loaded library.
+  DPIModuleNode* Get() {
+    return static_cast<DPIModuleNode*>(mod_.operator->());
+  }
+
+  static DPILoader* Global() {
+    static DPILoader inst;
+    return &inst;
+  }
+
+  // Keeps the module alive for the lifetime of the process.
+  Module mod_;
+};
+
+// TSIM device: programs the VTA register file through the DPI module,
+// launches the simulation thread and polls for completion.
+class Device {
+ public:
+  Device() {
+    dpi_ = DPILoader::Global();
+  }
+
+  // Run one instruction stream. NOTE(review): a timeout in
+  // WaitForCompletion is not reported -- this always returns 0 even when
+  // the finish bit never asserted; confirm callers expect that.
+  int Run(vta_phy_addr_t insn_phy_addr,
+          vta_phy_addr_t uop_phy_addr,
+          vta_phy_addr_t inp_phy_addr,
+          vta_phy_addr_t wgt_phy_addr,
+          vta_phy_addr_t acc_phy_addr,
+          vta_phy_addr_t out_phy_addr,
+          uint32_t insn_count,
+          uint32_t wait_cycles) {
+    this->Init();
+    this->Launch(insn_phy_addr,
+                 uop_phy_addr,
+                 inp_phy_addr,
+                 wgt_phy_addr,
+                 acc_phy_addr,
+                 out_phy_addr,
+                 insn_count,
+                 wait_cycles);
+    this->WaitForCompletion(wait_cycles);
+    dev_->Finish();
+    return 0;
+  }
+
+ private:
+  void Init() {
+    dev_ = dpi_->Get();
+  }
+
+  // Program the control registers and start the accelerator.
+  // Register map as used here (TODO(review): confirm against the VCR):
+  //   0x00 control/status, 0x10 instruction count,
+  //   0x14/0x18 low/high words of the instruction base address.
+  // For uop/inp/wgt/acc/out only the upper 32 bits are written; the low
+  // 32 bits travel inside each instruction's dram_base field (see the
+  // USE_TSIM paths in the runtime), so the low-word slots are left zero.
+  void Launch(vta_phy_addr_t insn_phy_addr,
+              vta_phy_addr_t uop_phy_addr,
+              vta_phy_addr_t inp_phy_addr,
+              vta_phy_addr_t wgt_phy_addr,
+              vta_phy_addr_t acc_phy_addr,
+              vta_phy_addr_t out_phy_addr,
+              uint32_t insn_count,
+              uint32_t wait_cycles) {
+    // launch simulation thread
+    dev_->Launch(wait_cycles);
+    dev_->WriteReg(0x10, insn_count);
+    dev_->WriteReg(0x14, insn_phy_addr);
+    dev_->WriteReg(0x18, insn_phy_addr >> 32);
+    dev_->WriteReg(0x1c, 0);
+    dev_->WriteReg(0x20, uop_phy_addr >> 32);
+    dev_->WriteReg(0x24, 0);
+    dev_->WriteReg(0x28, inp_phy_addr >> 32);
+    dev_->WriteReg(0x2c, 0);
+    dev_->WriteReg(0x30, wgt_phy_addr >> 32);
+    dev_->WriteReg(0x34, 0);
+    dev_->WriteReg(0x38, acc_phy_addr >> 32);
+    dev_->WriteReg(0x3c, 0);
+    dev_->WriteReg(0x40, out_phy_addr >> 32);
+    // start
+    dev_->WriteReg(0x00, 0x1);
+  }
+
+  // Busy-poll the finish bit (0x2) of the status register for at most
+  // wait_cycles iterations.
+  void WaitForCompletion(uint32_t wait_cycles) {
+    uint32_t i, val;
+    for (i = 0; i < wait_cycles; i++) {
+      val = dev_->ReadReg(0x00);
+      val &= 0x2;
+      if (val == 0x2) break; // finish
+    }
+  }
+
+  DPILoader* dpi_;
+  DPIModuleNode* dev_;
+};
+
+using tvm::runtime::TVMRetValue;
+using tvm::runtime::TVMArgs;
+
+// Exposed to Python as "tvm.vta.tsim.init" (see simulator.tsim_init):
+// registers the loaded hardware library with the global DPILoader.
+TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    Module m = args[0];
+    DPILoader::Global()->Init(m);
+  });
+
+} // namespace tsim
+} // namespace vta
+
+// Allocate simulation "device" memory from the host heap; the `cached`
+// flag has no effect here.
+void* VTAMemAlloc(size_t size, int cached) {
+  return malloc(size);
+}
+
+// Release memory obtained from VTAMemAlloc.
+void VTAMemFree(void* buf) {
+  free(buf);
+}
+
+// In TSIM the "physical" address is simply the host virtual address of the
+// buffer, reinterpreted as an integer.
+vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
+  return reinterpret_cast<uint64_t>(buf);
+}
+
+// Cache maintenance is a no-op in simulation: both "device" and host access
+// the same host memory, so the views are always coherent.
+void VTAFlushCache(vta_phy_addr_t buf, int size) {
+}
+
+void VTAInvalidateCache(vta_phy_addr_t buf, int size) {
+}
+
+// Create a TSIM device instance and hand it out as an opaque handle.
+VTADeviceHandle VTADeviceAlloc() {
+  return new vta::tsim::Device();
+}
+
+// Destroy a handle previously returned by VTADeviceAlloc.
+void VTADeviceFree(VTADeviceHandle handle) {
+  delete static_cast<vta::tsim::Device*>(handle);
+}
+
+// Run an instruction stream on the TSIM device: forwards all buffer base
+// addresses to Device::Run. Returns 0 on success (see Device::Run for the
+// timeout caveat).
+int VTADeviceRun(VTADeviceHandle handle,
+                 vta_phy_addr_t insn_phy_addr,
+                 vta_phy_addr_t uop_phy_addr,
+                 vta_phy_addr_t inp_phy_addr,
+                 vta_phy_addr_t wgt_phy_addr,
+                 vta_phy_addr_t acc_phy_addr,
+                 vta_phy_addr_t out_phy_addr,
+                 uint32_t insn_count,
+                 uint32_t wait_cycles) {
+  return static_cast<vta::tsim::Device*>(handle)->Run(
+      insn_phy_addr,
+      uop_phy_addr,
+      inp_phy_addr,
+      wgt_phy_addr,
+      acc_phy_addr,
+      out_phy_addr,
+      insn_count,
+      wait_cycles);
+}
y_np = x_np.astype(y.dtype)
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
+
+ if env.TARGET == "tsim":
+ simulator.tsim_init("libvta_hw")
+
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
:] = x_np
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
+
+ if env.TARGET == "tsim":
+ simulator.tsim_init("libvta_hw")
+
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
y_np = np.right_shift(y_np, 8)
y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
+ if env.TARGET == "tsim":
+ simulator.tsim_init("libvta_hw")
+
if env.TARGET == "sim":
simulator.clear_stats()
f(x_nd, w_nd, y_nd)
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+
+ if env.TARGET == "tsim":
+ simulator.tsim_init("libvta_hw")
+
if use_imm:
f(a_nd, res_nd)
else:
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+
+ if env.TARGET == "tsim":
+ simulator.tsim_init("libvta_hw")
+
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+
+ if env.TARGET == "tsim":
+ simulator.tsim_init("libvta_hw")
+
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
print("Load/store test")
test_save_load_out()
print("Padded load test")
- #test_padded_load()
+ test_padded_load()
print("GEMM test")
test_gemm()
- test_alu()
print("ALU test")
+ test_alu()
+ print("Relu test")
test_relu()
print("Shift and scale")
test_shift_and_scale()