[VTA] [Hardware] Chisel implementation (#3258)
authorLuis Vega <vegaluisjose@users.noreply.github.com>
Wed, 5 Jun 2019 17:17:11 +0000 (10:17 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Wed, 5 Jun 2019 17:17:11 +0000 (10:17 -0700)
43 files changed:
cmake/config.cmake
cmake/modules/VTA.cmake
vta/apps/tsim_example/README.md
vta/apps/tsim_example/cmake/modules/hw.cmake
vta/hardware/chisel/Makefile
vta/hardware/chisel/src/main/resources/verilog/VTAHostDPI.v
vta/hardware/chisel/src/main/scala/core/Compute.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/Configs.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/Core.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/Decode.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/Fetch.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/ISA.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/Load.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/LoadUop.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/Semaphore.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/Store.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/TensorAlu.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/TensorGemm.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/TensorLoad.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/TensorStore.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/TensorUtil.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/core/package.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala
vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala
vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/shell/Configs.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/shell/SimShell.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/shell/VCR.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/shell/VME.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/shell/VTAShell.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/test/Test.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/util/Config.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala [new file with mode: 0644]
vta/hardware/chisel/src/main/scala/vta/Configs.scala [new file with mode: 0644]
vta/hardware/dpi/tsim_device.cc
vta/include/vta/driver.h
vta/python/vta/environment.py
vta/python/vta/testing/simulator.py
vta/python/vta/testing/util.py
vta/src/runtime.cc
vta/src/tsim/tsim_driver.cc [new file with mode: 0644]
vta/tests/python/unittest/test_vta_insn.py

index 7c5add5..6239bc4 100644 (file)
@@ -132,9 +132,6 @@ set(USE_SORT ON)
 # Build ANTLR parser for Relay text format
 set(USE_ANTLR OFF)
 
-# Build TSIM for VTA
-set(USE_VTA_TSIM OFF)
-
 # Whether use Relay debug mode
 set(USE_RELAY_DEBUG OFF)
 
index 1df6c66..6d5ea00 100644 (file)
@@ -29,8 +29,7 @@ elseif(PYTHON)
       --use-cfg=${CMAKE_CURRENT_BINARY_DIR}/vta_config.json)
   endif()
 
-  execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE __vta_target)
-  string(STRIP ${__vta_target} VTA_TARGET)
+  execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE VTA_TARGET OUTPUT_STRIP_TRAILING_WHITESPACE)
 
   message(STATUS "Build VTA runtime with target: " ${VTA_TARGET})
 
@@ -44,6 +43,13 @@ elseif(PYTHON)
 
   add_library(vta SHARED ${VTA_RUNTIME_SRCS})
 
+  if(${VTA_TARGET} STREQUAL "tsim")
+    target_compile_definitions(vta PUBLIC USE_TSIM)
+    include_directories("vta/include")
+    file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc)
+    list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS})
+  endif()
+
   target_include_directories(vta PUBLIC vta/include)
 
   foreach(__def ${VTA_DEFINITIONS})
@@ -61,12 +67,6 @@ elseif(PYTHON)
     target_link_libraries(vta ${__cma_lib})
   endif()
 
-  if(NOT USE_VTA_TSIM STREQUAL "OFF")
-    include_directories("vta/include")
-    file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc)
-    list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS})
-  endif()
-
 else()
   message(STATUS "Cannot found python in env, VTA build is skipped..")
 endif()
index b557b24..dc06a92 100644 (file)
@@ -49,7 +49,7 @@ sudo apt install verilator sbt
 ## Setup in TVM
 
 1. Install `verilator` and `sbt` as described above
-2. Enable VTA TSIM by turning on the switch `USE_VTA_TSIM` in config.cmake
+2. Set the VTA TARGET to `tsim` on `<tvm-root>/vta/config/vta_config.json`
 3. Build tvm
 
 ## How to run VTA TSIM examples
index 87dd72b..e016ea0 100644 (file)
@@ -124,7 +124,7 @@ else()
       file(GLOB VERILATOR_SRC ${VTA_HW_DPI_DIR}/tsim_device.cc)
       add_library(hw SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC})
 
-      set(VERILATOR_DEF VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
+      set(VERILATOR_DEF VL_USER_FINISH VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
       if (NOT TSIM_USE_TRACE STREQUAL "OFF")
         list(APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=${TSIM_BUILD_DIR}/${TSIM_TRACE_NAME}.vcd)
       else()
index 65a9ed1..7371dd1 100644 (file)
 # specific language governing permissions and limitations
 # under the License.
 
+CONFIG = DefaultF1Config
+TOP = VTA
+TOP_TEST = Test
+BUILD_NAME = build
+USE_TRACE = 0
+VTA_LIBNAME = libvta_hw
+
+config_test = $(TOP_TEST)$(CONFIG)
+vta_dir = $(abspath ../../)
+tvm_dir = $(abspath ../../../)
+verilator_inc_dir = /usr/local/share/verilator/include
+verilator_build_dir = $(vta_dir)/$(BUILD_NAME)/verilator
+chisel_build_dir = $(vta_dir)/$(BUILD_NAME)/chisel
+
+verilator_opt = --cc
+verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
+verilator_opt += +define+RANDOMIZE_REG_INIT
+verilator_opt += +define+RANDOMIZE_MEM_INIT
+verilator_opt += --x-assign unique
+verilator_opt += --output-split 20000
+verilator_opt += --output-split-cfuncs 20000
+verilator_opt += --top-module ${TOP_TEST}
+verilator_opt += -Mdir ${verilator_build_dir}
+verilator_opt += -I$(chisel_build_dir)
+
+cxx_flags = -O2 -Wall -fPIC -shared
+cxx_flags += -fvisibility=hidden -std=c++11
+cxx_flags += -DVL_TSIM_NAME=V$(TOP_TEST)
+cxx_flags += -DVL_PRINTF=printf
+cxx_flags += -DVL_USER_FINISH
+cxx_flags += -DVM_COVERAGE=0
+cxx_flags += -DVM_SC=0
+cxx_flags += -Wno-sign-compare
+cxx_flags += -include V$(TOP_TEST).h
+cxx_flags += -I$(verilator_build_dir)
+cxx_flags += -I$(verilator_inc_dir)
+cxx_flags += -I$(verilator_inc_dir)/vltstd
+cxx_flags += -I$(vta_dir)/include
+cxx_flags += -I$(tvm_dir)/include
+cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include
+
+cxx_files = $(verilator_inc_dir)/verilated.cpp
+cxx_files += $(verilator_inc_dir)/verilated_dpi.cpp
+cxx_files += $(wildcard $(verilator_build_dir)/*.cpp)
+cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc
+
+ifneq ($(USE_TRACE), 0)
+  verilator_opt += --trace
+  cxx_flags += -DVM_TRACE=1
+  cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP_TEST).vcd
+  cxx_files += $(verilator_inc_dir)/verilated_vcd_c.cpp
+else
+  cxx_flags += -DVM_TRACE=0
+endif
+
+default: lib
+
+lib: $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so
+$(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so: $(verilator_build_dir)/V$(TOP_TEST).cpp
+       g++ $(cxx_flags) $(cxx_files) -o $@
+
+verilator: $(verilator_build_dir)/V$(TOP_TEST).cpp
+$(verilator_build_dir)/V$(TOP_TEST).cpp: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
+       verilator $(verilator_opt) $<
+
+verilog: $(chisel_build_dir)/$(TOP).$(CONFIG).v
+$(chisel_build_dir)/$(TOP).$(CONFIG).v:
+       sbt 'runMain vta.$(CONFIG) --target-dir $(chisel_build_dir) --top-name $(TOP).$(CONFIG)'
+
+verilog_test: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
+$(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v:
+       sbt 'runMain vta.$(config_test) --target-dir $(chisel_build_dir) --top-name $(TOP_TEST).$(CONFIG)'
+
 clean:
        -rm -rf target project/target project/project
+
+cleanall:
+       -rm -rf $(vta_dir)/$(BUILD_NAME)
index 02fcf0d..8ab85f6 100644 (file)
@@ -112,7 +112,7 @@ module VTAHostDPI #
 
   always_ff @(posedge clock) begin
     if (__exit == 'd1) begin
-      $display("[DONE] at cycle:%016d", cycles);
+      $display("[TSIM] Verilog $finish called at cycle:%016d", cycles);
       $finish;
     end
   end
diff --git a/vta/hardware/chisel/src/main/scala/core/Compute.scala b/vta/hardware/chisel/src/main/scala/core/Compute.scala
new file mode 100644 (file)
index 0000000..ef56c3d
--- /dev/null
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** Compute.
+  *
+  * The compute unit is in charge of the following:
+  * - Loading micro-ops from memory (loadUop module)
+  * - Loading biases (acc) from memory (tensorAcc module)
+  * - Compute ALU instructions (tensorAlu module)
+  * - Compute GEMM instructions (tensorGemm module)
+  */
+class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val i_post = Vec(2, Input(Bool()))
+    val o_post = Vec(2, Output(Bool()))
+    val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
+    val uop_baddr = Input(UInt(mp.addrBits.W))
+    val acc_baddr = Input(UInt(mp.addrBits.W))
+    val vme_rd = Vec(2, new VMEReadMaster)
+    val inp = new TensorMaster(tensorType = "inp")
+    val wgt = new TensorMaster(tensorType = "wgt")
+    val out = new TensorMaster(tensorType = "out")
+    val finish = Output(Bool())
+  })
+  val sIdle :: sSync :: sExe :: Nil = Enum(3)
+  val state = RegInit(sIdle)
+
+  val s = Seq.tabulate(2)(_ => Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
+
+  val loadUop = Module(new LoadUop)
+  val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
+  val tensorGemm = Module(new TensorGemm)
+  val tensorAlu = Module(new TensorAlu)
+
+  val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
+
+  // decode
+  val dec = Module(new ComputeDecode)
+  dec.io.inst := inst_q.io.deq.bits
+
+  val inst_type = Cat(dec.io.isFinish,
+                      dec.io.isAlu,
+                     dec.io.isGemm,
+                     dec.io.isLoadAcc,
+                     dec.io.isLoadUop).asUInt
+
+  val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B)
+  val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
+  val start = snext & sprev
+  val done =
+    MuxLookup(inst_type,
+               false.B, // default
+      Array(
+        "h_01".U -> loadUop.io.done,
+        "h_02".U -> tensorAcc.io.done,
+        "h_04".U -> tensorGemm.io.done,
+        "h_08".U -> tensorAlu.io.done,
+        "h_10".U -> true.B // Finish
+      )
+    )
+
+  // control
+  switch (state) {
+    is (sIdle) {
+      when (start) {
+       when (dec.io.isSync) {
+          state := sSync
+       } .elsewhen (inst_type.orR) {
+          state := sExe
+       }
+      }
+    }
+    is (sSync) {
+      state := sIdle
+    }
+    is (sExe) {
+      when (done) {
+        state := sIdle
+      }
+    }
+  }
+
+  // instructions
+  inst_q.io.enq <> io.inst
+  inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
+
+  // uop
+  loadUop.io.start :=  state === sIdle & start & dec.io.isLoadUop
+  loadUop.io.inst := inst_q.io.deq.bits
+  loadUop.io.baddr := io.uop_baddr
+  io.vme_rd(0) <> loadUop.io.vme_rd
+  loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)
+
+  // acc
+  tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
+  tensorAcc.io.inst := inst_q.io.deq.bits
+  tensorAcc.io.baddr := io.acc_baddr
+  tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
+  tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
+  io.vme_rd(1) <> tensorAcc.io.vme_rd
+
+  // gemm
+  tensorGemm.io.start :=  state === sIdle & start & dec.io.isGemm
+  tensorGemm.io.inst := inst_q.io.deq.bits
+  tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
+  tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
+  tensorGemm.io.inp <> io.inp
+  tensorGemm.io.wgt <> io.wgt
+  tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isGemm
+  tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
+  tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm
+  tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits
+
+  // alu
+  tensorAlu.io.start :=  state === sIdle & start & dec.io.isAlu
+  tensorAlu.io.inst := inst_q.io.deq.bits
+  tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
+  tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
+  tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isAlu
+  tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
+  tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu
+  tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits
+
+  // out
+  io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx)
+  io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
+
+  // semaphore
+  s(0).io.spost := io.i_post(0)
+  s(1).io.spost := io.i_post(1)
+  s(0).io.swait := dec.io.pop_prev & (state === sIdle & start)
+  s(1).io.swait := dec.io.pop_next & (state === sIdle & start)
+  io.o_post(0) := dec.io.push_prev & ((state === sExe & done) | (state === sSync))
+  io.o_post(1) := dec.io.push_next & ((state === sExe & done) | (state === sSync))
+
+  // finish
+  io.finish := state === sExe & done & dec.io.isFinish
+
+  // debug
+  if (debug) {
+    // start
+    when (state === sIdle && start) {
+      when (dec.io.isSync) {
+        printf("[Compute] start sync\n")
+      } .elsewhen (dec.io.isLoadUop) {
+        printf("[Compute] start load uop\n")
+      } .elsewhen (dec.io.isLoadAcc) {
+        printf("[Compute] start load acc\n")
+      } .elsewhen (dec.io.isGemm) {
+        printf("[Compute] start gemm\n")
+      } .elsewhen (dec.io.isAlu) {
+        printf("[Compute] start alu\n")
+      } .elsewhen (dec.io.isFinish) {
+        printf("[Compute] start finish\n")
+      }
+    }
+    // done
+    when (state === sSync) {
+      printf("[Compute] done sync\n")
+    }
+    when (state === sExe) {
+      when (done) {
+        when (dec.io.isLoadUop) {
+          printf("[Compute] done load uop\n")
+        } .elsewhen (dec.io.isLoadAcc) {
+          printf("[Compute] done load acc\n")
+        } .elsewhen (dec.io.isGemm) {
+          printf("[Compute] done gemm\n")
+        } .elsewhen (dec.io.isAlu) {
+          printf("[Compute] done alu\n")
+        } .elsewhen (dec.io.isFinish) {
+          printf("[Compute] done finish\n")
+        }
+      }
+    }
+  }
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/Configs.scala b/vta/hardware/chisel/src/main/scala/core/Configs.scala
new file mode 100644 (file)
index 0000000..b4e764b
--- /dev/null
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import vta.util.config._
+
+/** CoreConfig.
+  *
+  * This is one supported configuration for VTA. This file will
+  * be eventually filled out with class configurations that can be
+  * mixed/matched with Shell configurations for different backends.
+  */
+class CoreConfig extends Config((site, here, up) => {
+  case CoreKey => CoreParams(
+    batch = 1,
+    blockOut = 16,
+    blockIn = 16,
+    inpBits = 8,
+    wgtBits = 8,
+    uopBits = 32,
+    accBits = 32,
+    outBits = 8,
+    uopMemDepth = 2048,
+    inpMemDepth = 2048,
+    wgtMemDepth = 1024,
+    accMemDepth = 2048,
+    outMemDepth = 2048,
+    instQueueEntries = 512)
+})
diff --git a/vta/hardware/chisel/src/main/scala/core/Core.scala b/vta/hardware/chisel/src/main/scala/core/Core.scala
new file mode 100644 (file)
index 0000000..2a2d4e0
--- /dev/null
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import vta.util.config._
+import vta.shell._
+
+/** Core parameters */
+case class CoreParams (
+  batch: Int = 1,
+  blockOut: Int = 16,
+  blockIn: Int = 16,
+  inpBits: Int = 8,
+  wgtBits: Int = 8,
+  uopBits: Int = 32,
+  accBits: Int = 32,
+  outBits: Int = 8,
+  uopMemDepth: Int = 512,
+  inpMemDepth: Int = 512,
+  wgtMemDepth: Int = 512,
+  accMemDepth: Int = 512,
+  outMemDepth: Int = 512,
+  instQueueEntries: Int = 32
+)
+
+case object CoreKey extends Field[CoreParams]
+
+/** Core.
+  *
+  * The core defines the current VTA architecture by connecting memory and
+  * compute modules together such as load/store and compute. Most of the
+  * connections in the core are bulk (<>), and we should try to keep it this
+  * way, because it is easier to understand what is going on.
+  *
+  * Also, the core must be instantiated by a shell using the
+  * VTA Control Register (VCR) and the VTA Memory Engine (VME) interfaces.
+  * More info about these interfaces and modules can be found in the shell
+  * directory.
+  */
+class Core(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val vcr = new VCRClient
+    val vme = new VMEMaster
+  })
+  val fetch = Module(new Fetch)
+  val load = Module(new Load)
+  val compute = Module(new Compute)
+  val store = Module(new Store)
+
+  // Read(rd) and write(wr) from/to memory (i.e. DRAM)
+  io.vme.rd(0) <> fetch.io.vme_rd
+  io.vme.rd(1) <> compute.io.vme_rd(0)
+  io.vme.rd(2) <> load.io.vme_rd(0)
+  io.vme.rd(3) <> load.io.vme_rd(1)
+  io.vme.rd(4) <> compute.io.vme_rd(1)
+  io.vme.wr(0) <> store.io.vme_wr
+
+  // Fetch instructions (tasks) from memory (DRAM) into queues (SRAMs)
+  fetch.io.launch := io.vcr.launch
+  fetch.io.ins_baddr := io.vcr.ptrs(0)
+  fetch.io.ins_count := io.vcr.vals(0)
+
+  // Load inputs and weights from memory (DRAM) into scratchpads (SRAMs)
+  load.io.i_post := compute.io.o_post(0)
+  load.io.inst <> fetch.io.inst.ld
+  load.io.inp_baddr := io.vcr.ptrs(2)
+  load.io.wgt_baddr := io.vcr.ptrs(3)
+
+  // The compute module performs the following:
+  // - Load micro-ops (uops) and accumulations (acc)
+  // - Compute dense and ALU instructions (tasks)
+  compute.io.i_post(0) := load.io.o_post
+  compute.io.i_post(1) := store.io.o_post
+  compute.io.inst <> fetch.io.inst.co
+  compute.io.uop_baddr := io.vcr.ptrs(1)
+  compute.io.acc_baddr := io.vcr.ptrs(4)
+  compute.io.inp <> load.io.inp
+  compute.io.wgt <> load.io.wgt
+
+  // The store module performs the following:
+  // - Writes results from compute into scratchpads (SRAMs)
+  // - Store results from scratchpads (SRAMs) to memory (DRAM)
+  store.io.i_post := compute.io.o_post(1)
+  store.io.inst <> fetch.io.inst.st
+  store.io.out_baddr := io.vcr.ptrs(5)
+  store.io.out <> compute.io.out
+
+  // Finish instruction is executed and asserts the VCR finish flag
+  val finish = RegNext(compute.io.finish)
+  io.vcr.finish := finish
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/Decode.scala b/vta/hardware/chisel/src/main/scala/core/Decode.scala
new file mode 100644 (file)
index 0000000..f5bf340
--- /dev/null
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+
+import ISA._
+
+/** MemDecode.
+  *
+  * Decode memory instructions with a Bundle. This is similar to an union,
+  * therefore order matters when declaring fields. These are the instructions
+  * decoded with this bundle:
+  *   - LUOP
+  *   - LWGT
+  *   - LINP
+  *   - LACC
+  *   - SOUT
+  */
+class MemDecode extends Bundle {
+  val xpad_1 = UInt(M_PAD_BITS.W)
+  val xpad_0 = UInt(M_PAD_BITS.W)
+  val ypad_1 = UInt(M_PAD_BITS.W)
+  val ypad_0 = UInt(M_PAD_BITS.W)
+  val xstride = UInt(M_STRIDE_BITS.W)
+  val xsize = UInt(M_SIZE_BITS.W)
+  val ysize = UInt(M_SIZE_BITS.W)
+  val empty_0 = UInt(7.W) // derive this
+  val dram_offset = UInt(M_DRAM_OFFSET_BITS.W)
+  val sram_offset = UInt(M_SRAM_OFFSET_BITS.W)
+  val id = UInt(M_ID_BITS.W)
+  val push_next = Bool()
+  val push_prev = Bool()
+  val pop_next = Bool()
+  val pop_prev = Bool()
+  val op = UInt(OP_BITS.W)
+}
+
+/** GemmDecode.
+  *
+  * Decode GEMM instruction with a Bundle. This is similar to an union,
+  * therefore order matters when declaring fields.
+  */
+class GemmDecode extends Bundle {
+  val wgt_1 = UInt(C_WIDX_BITS.W)
+  val wgt_0 = UInt(C_WIDX_BITS.W)
+  val inp_1 = UInt(C_IIDX_BITS.W)
+  val inp_0 = UInt(C_IIDX_BITS.W)
+  val acc_1 = UInt(C_AIDX_BITS.W)
+  val acc_0 = UInt(C_AIDX_BITS.W)
+  val empty_0 = Bool()
+  val lp_1 = UInt(C_ITER_BITS.W)
+  val lp_0 = UInt(C_ITER_BITS.W)
+  val uop_end = UInt(C_UOP_END_BITS.W)
+  val uop_begin = UInt(C_UOP_BGN_BITS.W)
+  val reset = Bool()
+  val push_next = Bool()
+  val push_prev = Bool()
+  val pop_next = Bool()
+  val pop_prev = Bool()
+  val op = UInt(OP_BITS.W)
+}
+
+/** AluDecode.
+  *
+  * Decode ALU instructions with a Bundle. This is similar to an union,
+  * therefore order matters when declaring fields. These are the instructions
+  * decoded with this bundle:
+  *   - VMIN
+  *   - VMAX
+  *   - VADD
+  *   - VSHX
+  */
+class AluDecode extends Bundle {
+  val empty_1 = Bool()
+  val alu_imm = UInt(C_ALU_IMM_BITS.W)
+  val alu_use_imm = Bool()
+  val alu_op = UInt(C_ALU_DEC_BITS.W)
+  val src_1 = UInt(C_IIDX_BITS.W)
+  val src_0 = UInt(C_IIDX_BITS.W)
+  val dst_1 = UInt(C_AIDX_BITS.W)
+  val dst_0 = UInt(C_AIDX_BITS.W)
+  val empty_0 = Bool()
+  val lp_1 = UInt(C_ITER_BITS.W)
+  val lp_0 = UInt(C_ITER_BITS.W)
+  val uop_end = UInt(C_UOP_END_BITS.W)
+  val uop_begin = UInt(C_UOP_BGN_BITS.W)
+  val reset = Bool()
+  val push_next = Bool()
+  val push_prev = Bool()
+  val pop_next = Bool()
+  val pop_prev = Bool()
+  val op = UInt(OP_BITS.W)
+}
+
+/** UopDecode.
+  *
+  * Decode micro-ops (uops).
+  */
+class UopDecode extends Bundle {
+  val u2 = UInt(10.W)
+  val u1 = UInt(11.W)
+  val u0 = UInt(11.W)
+}
+
+/** FetchDecode.
+  *
+  * Partial decoding for dispatching instructions to Load, Compute, and Store.
+  */
+class FetchDecode extends Module {
+  val io = IO(new Bundle {
+    val inst = Input(UInt(INST_BITS.W))
+    val isLoad = Output(Bool())
+    val isCompute = Output(Bool())
+    val isStore = Output(Bool())
+  })
+  val csignals =
+    ListLookup(io.inst,
+        List(N, OP_X),
+      Array(
+        LUOP -> List(Y, OP_G),
+        LWGT -> List(Y, OP_L),
+        LINP -> List(Y, OP_L),
+        LACC -> List(Y, OP_G),
+        SOUT -> List(Y, OP_S),
+        GEMM -> List(Y, OP_G),
+        FNSH -> List(Y, OP_G),
+        VMIN -> List(Y, OP_G),
+        VMAX -> List(Y, OP_G),
+        VADD -> List(Y, OP_G),
+        VSHX -> List(Y, OP_G)
+      )
+    )
+
+  val (cs_val_inst: Bool) :: cs_op_type :: Nil = csignals
+
+  io.isLoad := cs_val_inst & cs_op_type === OP_L
+  io.isCompute := cs_val_inst & cs_op_type === OP_G
+  io.isStore := cs_val_inst & cs_op_type === OP_S
+}
+
+/** LoadDecode.
+  *
+  * Decode dependencies, type and sync for Load module.
+  */
+class LoadDecode extends Module {
+  val io = IO(new Bundle {
+    val inst = Input(UInt(INST_BITS.W))
+    val push_next = Output(Bool())
+    val pop_next = Output(Bool())
+    val isInput = Output(Bool())
+    val isWeight = Output(Bool())
+    val isSync = Output(Bool())
+  })
+  val dec = io.inst.asTypeOf(new MemDecode)
+  io.push_next := dec.push_next
+  io.pop_next := dec.pop_next
+  io.isInput := io.inst === LINP & dec.xsize =/= 0.U
+  io.isWeight := io.inst === LWGT & dec.xsize =/= 0.U
+  io.isSync := (io.inst === LINP | io.inst === LWGT) & dec.xsize === 0.U
+}
+
+/** ComputeDecode.
+  *
+  * Decode dependencies, type and sync for Compute module.
+  */
+class ComputeDecode extends Module {
+  val io = IO(new Bundle {
+    val inst = Input(UInt(INST_BITS.W))
+    val push_next = Output(Bool())
+    val push_prev = Output(Bool())
+    val pop_next = Output(Bool())
+    val pop_prev = Output(Bool())
+    val isLoadAcc = Output(Bool())
+    val isLoadUop = Output(Bool())
+    val isSync = Output(Bool())
+    val isAlu = Output(Bool())
+    val isGemm = Output(Bool())
+    val isFinish = Output(Bool())
+  })
+  val dec = io.inst.asTypeOf(new MemDecode)
+  io.push_next := dec.push_next
+  io.push_prev := dec.push_prev
+  io.pop_next := dec.pop_next
+  io.pop_prev := dec.pop_prev
+  io.isLoadAcc := io.inst === LACC & dec.xsize =/= 0.U
+  io.isLoadUop := io.inst === LUOP & dec.xsize =/= 0.U
+  io.isSync := (io.inst === LACC | io.inst === LUOP) & dec.xsize === 0.U
+  io.isAlu := io.inst === VMIN | io.inst === VMAX | io.inst === VADD | io.inst === VSHX
+  io.isGemm := io.inst === GEMM
+  io.isFinish := io.inst === FNSH
+}
+
+/** StoreDecode.
+  *
+  * Decode dependencies, type and sync for Store module.
+  */
+class StoreDecode extends Module {
+  val io = IO(new Bundle {
+    val inst = Input(UInt(INST_BITS.W))
+    val push_prev = Output(Bool())
+    val pop_prev = Output(Bool())
+    val isStore = Output(Bool())
+    val isSync = Output(Bool())
+  })
+  val dec = io.inst.asTypeOf(new MemDecode)
+  io.push_prev := dec.push_prev
+  io.pop_prev := dec.pop_prev
+  io.isStore := io.inst === SOUT & dec.xsize =/= 0.U
+  io.isSync := io.inst === SOUT & dec.xsize === 0.U
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/Fetch.scala b/vta/hardware/chisel/src/main/scala/core/Fetch.scala
new file mode 100644 (file)
index 0000000..bcc164a
--- /dev/null
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** Fetch.
+  *
+  * The fetch unit reads instructions (tasks) from memory (i.e. DRAM), using the
+  * VTA Memory Engine (VME), and push them into an instruction queue called
+  * inst_q. Once the instruction queue is full, instructions are dispatched to
+  * the Load, Compute and Store module queues based on the instruction opcode.
+  * After draining the queue, the fetch unit checks if there are more instructions
+  * via the ins_count register which is written by the host.
+  *
+  * Additionally, instructions are read into two chunks (see sReadLSB and sReadMSB)
+  * because we are using a DRAM payload of 8-bytes or half of a VTA instruction.
+  * This should be configurable for larger payloads, i.e. 64-bytes, which can load
+  * more than one instruction at the time. Finally, the instruction queue is
+  * sized (entries_q), depending on the maximum burst allowed in the memory.
+  */
+class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val vp = p(ShellKey).vcrParams
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val launch = Input(Bool())
+    val ins_baddr = Input(UInt(mp.addrBits.W))
+    val ins_count = Input(UInt(vp.regBits.W))
+    val vme_rd = new VMEReadMaster
+    val inst = new Bundle {
+      val ld = Decoupled(UInt(INST_BITS.W))
+      val co = Decoupled(UInt(INST_BITS.W))
+      val st = Decoupled(UInt(INST_BITS.W))
+    }
+  })
+  val entries_q = 1 << (mp.lenBits - 1) // one-instr-every-two-vme-word
+  val inst_q = Module(new Queue(UInt(INST_BITS.W), entries_q))
+  val dec = Module(new FetchDecode)
+
+  val s1_launch = RegNext(io.launch)
+  val pulse = io.launch & ~s1_launch
+
+  val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr))
+  val rlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
+  val ilen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
+
+  val xrem = Reg(chiselTypeOf(io.ins_count))
+  val xsize = (io.ins_count << 1.U) - 1.U
+  val xmax = (1 << mp.lenBits).U
+  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+
+  val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5)
+  val state = RegInit(sIdle)
+
+  // control
+  switch (state) {
+    is (sIdle) {
+      when (pulse) {
+        state := sReadCmd
+       when (xsize < xmax) {
+          rlen := xsize
+         ilen := xsize >> 1.U
+          xrem := 0.U
+       } .otherwise {
+          rlen := xmax - 1.U
+         ilen := (xmax >> 1.U) - 1.U
+          xrem := xsize - xmax
+       }
+      }
+    }
+    is (sReadCmd) {
+      when (io.vme_rd.cmd.ready) {
+        state := sReadLSB
+      }
+    }
+    is (sReadLSB) {
+      when (io.vme_rd.data.valid) {
+          state := sReadMSB
+      }
+    }
+    is (sReadMSB) {
+      when (io.vme_rd.data.valid) {
+        when (inst_q.io.count === ilen) {
+          state := sDrain
+        } .otherwise {
+          state := sReadLSB
+       }
+      }
+    }
+    is (sDrain) {
+      when (inst_q.io.count === 0.U) {
+        when (xrem === 0.U) {
+          state := sIdle
+        } .elsewhen (xrem < xmax) {
+          state := sReadCmd
+          rlen := xrem
+         ilen := xrem >> 1.U
+          xrem := 0.U
+        } .otherwise {
+          state := sReadCmd
+          rlen := xmax - 1.U
+         ilen := (xmax >> 1.U) - 1.U
+          xrem := xrem - xmax
+        }
+      }
+    }
+  }
+
+  // read instructions from dram
+  when (state === sIdle) {
+    raddr := io.ins_baddr
+  } .elsewhen (state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) {
+    raddr := raddr + xmax_bytes
+  }
+
+  io.vme_rd.cmd.valid := state === sReadCmd
+  io.vme_rd.cmd.bits.addr := raddr
+  io.vme_rd.cmd.bits.len := rlen
+
+  io.vme_rd.data.ready := inst_q.io.enq.ready
+
+  val lsb = Reg(chiselTypeOf(io.vme_rd.data.bits))
+  val msb = io.vme_rd.data.bits
+  val inst = Cat(msb, lsb)
+
+  when (state === sReadLSB) { lsb := io.vme_rd.data.bits }
+
+  inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB
+  inst_q.io.enq.bits := inst
+
+  // decode
+  dec.io.inst := inst_q.io.deq.bits
+
+  // instruction queues
+  io.inst.ld.valid := dec.io.isLoad & inst_q.io.deq.valid & state === sDrain
+  io.inst.co.valid := dec.io.isCompute & inst_q.io.deq.valid & state === sDrain
+  io.inst.st.valid := dec.io.isStore & inst_q.io.deq.valid & state === sDrain
+
+  io.inst.ld.bits := inst_q.io.deq.bits
+  io.inst.co.bits := inst_q.io.deq.bits
+  io.inst.st.bits := inst_q.io.deq.bits
+
+  // check if selected queue is ready
+  val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt
+  val deq_ready =
+    MuxLookup(deq_sel,
+               false.B, // default
+      Array(
+        "h_01".U -> io.inst.ld.ready,
+        "h_02".U -> io.inst.st.ready,
+        "h_04".U -> io.inst.co.ready
+      )
+    )
+
+  // dequeue instruction
+  inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain
+
+
+  // debug
+  if (debug) {
+    when (state === sIdle && pulse) {
+      printf("[Fetch] Launch\n")
+    }
+    // instruction
+    when (inst_q.io.deq.fire()) {
+      when (dec.io.isLoad) {
+        printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits)
+      }
+      when (dec.io.isCompute) {
+        printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits)
+      }
+      when (dec.io.isStore) {
+        printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits)
+      }
+    }
+  }
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/ISA.scala b/vta/hardware/chisel/src/main/scala/core/ISA.scala
new file mode 100644 (file)
index 0000000..c3bf609
--- /dev/null
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+
+/** ISAConstants.
+  *
+  * These constants are used for decoding (parsing) fields on instructions.
+  */
+trait ISAConstants
+{
+   val INST_BITS = 128
+
+   val OP_BITS = 3
+
+   val M_DEP_BITS = 4
+   val M_ID_BITS = 2
+   val M_SRAM_OFFSET_BITS = 16
+   val M_DRAM_OFFSET_BITS = 32
+   val M_SIZE_BITS = 16
+   val M_STRIDE_BITS = 16
+   val M_PAD_BITS = 4
+
+   val C_UOP_BGN_BITS = 13
+   val C_UOP_END_BITS = 14
+   val C_ITER_BITS = 14
+   val C_AIDX_BITS = 11
+   val C_IIDX_BITS = 11
+   val C_WIDX_BITS = 10
+   val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction
+   val C_ALU_OP_BITS = 3
+   val C_ALU_IMM_BITS = 16
+
+   val Y = true.B
+   val N = false.B
+
+   val OP_L = 0.asUInt(OP_BITS.W)
+   val OP_S = 1.asUInt(OP_BITS.W)
+   val OP_G = 2.asUInt(OP_BITS.W)
+   val OP_F = 3.asUInt(OP_BITS.W)
+   val OP_A = 4.asUInt(OP_BITS.W)
+   val OP_X = 5.asUInt(OP_BITS.W)
+
+   val ALU_OP_NUM = 5
+   val ALU_OP = Enum(ALU_OP_NUM)
+
+   val M_ID_U = 0.asUInt(M_ID_BITS.W)
+   val M_ID_W = 1.asUInt(M_ID_BITS.W)
+   val M_ID_I = 2.asUInt(M_ID_BITS.W)
+   val M_ID_A = 3.asUInt(M_ID_BITS.W)
+}
+
+/** ISA.
+  *
+  * This is the VTA ISA, here we specify the cares and dont-cares that makes
+  * decoding easier. Since instructions are quite long 128-bit, we could generate
+  * these based on ISAConstants.
+  *
+  * FIXME: VSHX should be replaced by VSHR and VSHL once we modify the compiler
+  * TODO: Add VXOR to clear accumulator
+  */
+object ISA {
+  def LUOP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000")
+  def LWGT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000")
+  def LINP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000")
+  def LACC = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000")
+  def SOUT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001")
+  def GEMM = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010")
+  def VMIN = BitPat("b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+  def VMAX = BitPat("b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+  def VADD = BitPat("b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+  def VSHX = BitPat("b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+  def FNSH = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011")
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/Load.scala b/vta/hardware/chisel/src/main/scala/core/Load.scala
new file mode 100644 (file)
index 0000000..6479513
--- /dev/null
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** Load.
+  *
+  * Load inputs and weights from memory (DRAM) into scratchpads (SRAMs).
+  * This module instantiate the TensorLoad unit which is in charge of
+  * loading 1D and 2D tensors to scratchpads, so it can be used by
+  * other modules such as Compute.
+  */
+class Load(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val i_post = Input(Bool())
+    val o_post = Output(Bool())
+    val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
+    val inp_baddr = Input(UInt(mp.addrBits.W))
+    val wgt_baddr = Input(UInt(mp.addrBits.W))
+    val vme_rd = Vec(2, new VMEReadMaster)
+    val inp = new TensorClient(tensorType = "inp")
+    val wgt = new TensorClient(tensorType = "wgt")
+  })
+  val sIdle :: sSync :: sExe :: Nil = Enum(3)
+  val state = RegInit(sIdle)
+
+  val s = Module(new Semaphore(counterBits = 8, counterInitValue = 0))
+  val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
+
+  val dec = Module(new LoadDecode)
+  dec.io.inst := inst_q.io.deq.bits
+
+  val tensorType = Seq("inp", "wgt")
+  val tensorDec = Seq(dec.io.isInput, dec.io.isWeight)
+  val tensorLoad = Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i))))
+
+  val start = inst_q.io.deq.valid & Mux(dec.io.pop_next, s.io.sready, true.B)
+  val done = Mux(dec.io.isInput, tensorLoad(0).io.done, tensorLoad(1).io.done)
+
+  // control
+  switch (state) {
+    is (sIdle) {
+      when (start) {
+        when (dec.io.isSync) {
+          state := sSync
+       } .elsewhen (dec.io.isInput || dec.io.isWeight) {
+          state := sExe
+       }
+      }
+    }
+    is (sSync) {
+      state := sIdle
+    }
+    is (sExe) {
+      when (done) {
+        state := sIdle
+      }
+    }
+  }
+
+  // instructions
+  inst_q.io.enq <> io.inst
+  inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
+
+  // load tensor
+  // [0] input (inp)
+  // [1] weight (wgt)
+  val ptr = Seq(io.inp_baddr, io.wgt_baddr)
+  val tsor = Seq(io.inp, io.wgt)
+  for (i <- 0 until 2) {
+    tensorLoad(i).io.start := state === sIdle & start & tensorDec(i)
+    tensorLoad(i).io.inst := inst_q.io.deq.bits
+    tensorLoad(i).io.baddr := ptr(i)
+    tensorLoad(i).io.tensor <> tsor(i)
+    io.vme_rd(i) <> tensorLoad(i).io.vme_rd
+  }
+
+  // semaphore
+  s.io.spost := io.i_post
+  s.io.swait := dec.io.pop_next & (state === sIdle & start)
+  io.o_post := dec.io.push_next & ((state === sExe & done) | (state === sSync))
+
+  // debug
+  if (debug) {
+    // start
+    when (state === sIdle && start) {
+      when (dec.io.isSync) {
+        printf("[Load] start sync\n")
+      } .elsewhen (dec.io.isInput) {
+        printf("[Load] start input\n")
+      } .elsewhen (dec.io.isWeight) {
+        printf("[Load] start weight\n")
+      }
+    }
+    // done
+    when (state === sSync) {
+      printf("[Load] done sync\n")
+    }
+    when (state === sExe) {
+      when (done) {
+        when (dec.io.isInput) {
+         printf("[Load] done input\n")
+       } .elsewhen (dec.io.isWeight) {
+         printf("[Load] done weight\n")
+       }
+      }
+    }
+  }
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/LoadUop.scala b/vta/hardware/chisel/src/main/scala/core/LoadUop.scala
new file mode 100644 (file)
index 0000000..0729652
--- /dev/null
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** UopMaster.
+  *
+  * Uop interface used by a master module, i.e. TensorAlu or TensorGemm,
+  * to request a micro-op (uop) from the uop-scratchpad. The index (idx) is
+  * used as an address to find the uop in the uop-scratchpad.
+  */
+class UopMaster(implicit p: Parameters) extends Bundle {
+  val addrBits = log2Ceil(p(CoreKey).uopMemDepth)
+  val idx = ValidIO(UInt(addrBits.W))
+  val data = Flipped(ValidIO(new UopDecode))
+  override def cloneType = new UopMaster().asInstanceOf[this.type]
+}
+
+/** UopClient.
+  *
+  * Uop interface used by a client module, i.e. LoadUop, to receive
+  * a request from a master module, i.e. TensorAlu or TensorGemm.
+  * The index (idx) is used as an address to find the uop in the uop-scratchpad.
+  */
+class UopClient(implicit p: Parameters) extends Bundle {
+  val addrBits = log2Ceil(p(CoreKey).uopMemDepth)
+  val idx = Flipped(ValidIO(UInt(addrBits.W)))
+  val data = ValidIO(new UopDecode)
+  override def cloneType = new UopClient().asInstanceOf[this.type]
+}
+
+/** LoadUop.
+  *
+  * Load micro-ops (uops) from memory, i.e. DRAM, and store them in the
+  * uop-scratchpad. Currently, micro-ops are 32-bit wide and loaded in
+  * group of 2 given the fact that the DRAM payload is 8-bytes. This module
+  * should be modified later on to support different DRAM sizes efficiently.
+  */
+class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val baddr = Input(UInt(mp.addrBits.W))
+    val vme_rd = new VMEReadMaster
+    val uop = new UopClient
+  })
+  val numUop = 2 // store two uops per sram word
+  val uopBits = p(CoreKey).uopBits
+  val uopDepth = p(CoreKey).uopMemDepth / numUop
+
+  val dec = io.inst.asTypeOf(new MemDecode)
+  val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr))
+  val xcnt = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
+  val xlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
+  val xrem = Reg(chiselTypeOf(dec.xsize))
+  val xsize = dec.xsize(0) + (dec.xsize >> log2Ceil(numUop)) - 1.U
+  val xmax = (1 << mp.lenBits).U
+  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+
+  val offsetIsEven = (dec.sram_offset % 2.U) === 0.U
+  val sizeIsEven = (dec.xsize % 2.U) === 0.U
+
+  val sIdle :: sReadCmd :: sReadData :: Nil = Enum(3)
+  val state = RegInit(sIdle)
+
+  // control
+  switch (state) {
+    is (sIdle) {
+      when (io.start) {
+        state := sReadCmd
+       when (xsize < xmax) {
+          xlen := xsize
+          xrem := 0.U
+       } .otherwise {
+          xlen := xmax - 1.U
+          xrem := xsize - xmax
+       }
+      }
+    }
+    is (sReadCmd) {
+      when (io.vme_rd.cmd.ready) {
+        state := sReadData
+      }
+    }
+    is (sReadData) {
+      when (io.vme_rd.data.valid) {
+        when(xcnt === xlen) {
+          when (xrem === 0.U) {
+            state := sIdle
+          } .elsewhen (xrem < xmax) {
+            state := sReadCmd
+            xlen := xrem
+            xrem := 0.U
+          } .otherwise {
+            state := sReadCmd
+            xlen := xmax - 1.U
+            xrem := xrem - xmax
+          }
+        }
+      }
+    }
+  }
+
+  // read-from-dram
+  when (state === sIdle) {
+    when (offsetIsEven) {
+      raddr := io.baddr + dec.dram_offset
+    } .otherwise {
+      raddr := io.baddr + dec.dram_offset - 4.U
+    }
+  } .elsewhen (state === sReadData && xcnt === xlen && xrem =/= 0.U) {
+    raddr := raddr + xmax_bytes
+  }
+
+  io.vme_rd.cmd.valid := state === sReadCmd
+  io.vme_rd.cmd.bits.addr := raddr
+  io.vme_rd.cmd.bits.len := xlen
+
+  io.vme_rd.data.ready := state === sReadData
+
+  when (state =/= sReadData) {
+    xcnt := 0.U
+  } .elsewhen (io.vme_rd.data.fire()) {
+    xcnt := xcnt + 1.U
+  }
+
+  val waddr = Reg(UInt(log2Ceil(uopDepth).W))
+  when (state === sIdle) {
+    waddr := dec.sram_offset >> log2Ceil(numUop)
+  } .elsewhen (io.vme_rd.data.fire()) {
+    waddr := waddr + 1.U
+  }
+
+  val wdata = Wire(Vec(numUop, UInt(uopBits.W)))
+  val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata))
+  val wmask = Reg(Vec(numUop, Bool()))
+
+  when (offsetIsEven) {
+    when (sizeIsEven) {
+      wmask := "b_11".U.asTypeOf(wmask)
+    } .elsewhen (io.vme_rd.cmd.fire()) {
+      when (dec.xsize === 1.U) {
+        wmask := "b_01".U.asTypeOf(wmask)
+      } .otherwise {
+        wmask := "b_11".U.asTypeOf(wmask)
+      }
+    } .elsewhen (io.vme_rd.data.fire()) {
+      when (xcnt === xlen - 1.U) {
+        wmask := "b_01".U.asTypeOf(wmask)
+      } .otherwise {
+        wmask := "b_11".U.asTypeOf(wmask)
+      }
+    }
+  } .otherwise {
+    when (io.vme_rd.cmd.fire()) {
+      wmask := "b_10".U.asTypeOf(wmask)
+    } .elsewhen (io.vme_rd.data.fire()) {
+      when (sizeIsEven && xcnt === xlen - 1.U) {
+        wmask := "b_01".U.asTypeOf(wmask)
+      } .otherwise {
+        wmask := "b_11".U.asTypeOf(wmask)
+      }
+    }
+  }
+
+  wdata := io.vme_rd.data.bits.asTypeOf(wdata)
+  when (io.vme_rd.data.fire()) {
+    mem.write(waddr, wdata, wmask)
+  }
+
+  // read-from-sram
+  io.uop.data.valid := RegNext(io.uop.idx.valid)
+
+  val sIdx = io.uop.idx.bits % numUop.U
+  val rIdx = io.uop.idx.bits >> log2Ceil(numUop)
+  val memRead = mem.read(rIdx, io.uop.idx.valid)
+  val sWord = memRead.asUInt.asTypeOf(wdata)
+  val sUop = sWord(sIdx).asTypeOf(io.uop.data.bits)
+
+  io.uop.data.bits <> sUop
+
+  // done
+  io.done := state === sReadData & io.vme_rd.data.valid & xcnt === xlen & xrem === 0.U
+
+  // debug
+  if (debug) {
+    when (io.vme_rd.cmd.fire()) {
+      printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem)
+    }
+  }
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/Semaphore.scala b/vta/hardware/chisel/src/main/scala/core/Semaphore.scala
new file mode 100644 (file)
index 0000000..06df51e
--- /dev/null
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+
+/** Semaphore.
+  *
+  * This semaphore is used instead of push/pop fifo, used in the initial
+  * version of VTA. This semaphore is incremented (spost) or decremented (swait)
+  * depending on the push and pop fields on instructions to prevent RAW and WAR
+  * hazards.
+  */
+class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) extends Module {
+  val io = IO(new Bundle {
+    val spost = Input(Bool())
+    val swait = Input(Bool())
+    val sready = Output(Bool())
+  })
+  val cnt = RegInit(counterInitValue.U(counterBits.W))
+  when (io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) { cnt := cnt + 1.U }
+  when (!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U }
+  io.sready := cnt =/= 0.U
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/Store.scala b/vta/hardware/chisel/src/main/scala/core/Store.scala
new file mode 100644 (file)
index 0000000..5d89871
--- /dev/null
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** Store.
+  *
+  * Store results back to memory (DRAM) from scratchpads (SRAMs).
+  * This module instantiate the TensorStore unit which is in charge
+  * of storing 1D and 2D tensors to main memory.
+  */
+class Store(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val i_post = Input(Bool())
+    val o_post = Output(Bool())
+    val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
+    val out_baddr = Input(UInt(mp.addrBits.W))
+    val vme_wr = new VMEWriteMaster
+    val out = new TensorClient(tensorType = "out")
+  })
+  val sIdle :: sSync :: sExe :: Nil = Enum(3)
+  val state = RegInit(sIdle)
+
+  val s = Module(new Semaphore(counterBits = 8, counterInitValue = 0))
+  val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
+
+  val dec = Module(new StoreDecode)
+  dec.io.inst := inst_q.io.deq.bits
+
+  val tensorStore = Module(new TensorStore(tensorType = "out"))
+
+  val start = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s.io.sready, true.B)
+  val done = tensorStore.io.done
+
+  // control
+  switch (state) {
+    is (sIdle) {
+      when (start) {
+        when (dec.io.isSync) {
+          state := sSync
+       } .elsewhen (dec.io.isStore) {
+          state := sExe
+       }
+      }
+    }
+    is (sSync) {
+      state := sIdle
+    }
+    is (sExe) {
+      when (done) {
+        state := sIdle
+      }
+    }
+  }
+
+  // instructions
+  inst_q.io.enq <> io.inst
+  inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
+
+  // store
+  tensorStore.io.start := state === sIdle & start & dec.io.isStore
+  tensorStore.io.inst := inst_q.io.deq.bits
+  tensorStore.io.baddr := io.out_baddr
+  io.vme_wr <> tensorStore.io.vme_wr
+  tensorStore.io.tensor <> io.out
+
+  // semaphore
+  s.io.spost := io.i_post
+  s.io.swait := dec.io.pop_prev & (state === sIdle & start)
+  io.o_post := dec.io.push_prev & ((state === sExe & done) | (state === sSync))
+
+  // debug
+  if (debug) {
+    // start
+    when (state === sIdle && start) {
+      when (dec.io.isSync) {
+        printf("[Store] start sync\n")
+      } .elsewhen (dec.io.isStore) {
+        printf("[Store] start\n")
+      }
+    }
+    // done
+    when (state === sSync) {
+      printf("[Store] done sync\n")
+    }
+    when (state === sExe) {
+      when (done) {
+       printf("[Store] done\n")
+      }
+    }
+  }
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/TensorAlu.scala b/vta/hardware/chisel/src/main/scala/core/TensorAlu.scala
new file mode 100644 (file)
index 0000000..7f429be
--- /dev/null
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+
+/** ALU datapath */
+class Alu(implicit p: Parameters) extends Module {
+  val aluBits = p(CoreKey).accBits
+  val io = IO(new Bundle {
+    val opcode = Input(UInt(C_ALU_OP_BITS.W))
+    val a = Input(SInt(aluBits.W))
+    val b = Input(SInt(aluBits.W))
+    val y = Output(SInt(aluBits.W))
+  })
+
+  // FIXME: the following three will change once we support properly SHR and SHL
+  val ub = io.b.asUInt
+  val width = log2Ceil(aluBits)
+  val m = ~ub(width - 1, 0) + 1.U
+
+  val n = ub(width - 1, 0)
+  val fop = Seq(Mux(io.a < io.b, io.a, io.b),
+                Mux(io.a < io.b, io.b, io.a),
+                io.a + io.b,
+                io.a >> n,
+               io.a << m)
+
+  val opmux = Seq.tabulate(ALU_OP_NUM)(i => ALU_OP(i) -> fop(i))
+  io.y := MuxLookup(io.opcode, io.a, opmux)
+}
+
+/** Pipelined ALU */
+class AluReg(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val opcode = Input(UInt(C_ALU_OP_BITS.W))
+    val a = Flipped(ValidIO(UInt(p(CoreKey).accBits.W)))
+    val b = Flipped(ValidIO(UInt(p(CoreKey).accBits.W)))
+    val y = ValidIO(UInt(p(CoreKey).accBits.W))
+  })
+  val alu = Module(new Alu)
+  val rA = RegEnable(io.a.bits, io.a.valid)
+  val rB = RegEnable(io.b.bits, io.b.valid)
+  val valid = RegNext(io.b.valid)
+
+  alu.io.opcode := io.opcode
+
+  // register input
+  alu.io.a := rA.asSInt
+  alu.io.b := rB.asSInt
+
+  // output
+  io.y.valid := valid
+  io.y.bits := alu.io.y.asUInt
+}
+
+/** Vector of pipeline ALUs */
+class AluVector(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val opcode = Input(UInt(C_ALU_OP_BITS.W))
+    val acc_a = new TensorMasterData(tensorType = "acc")
+    val acc_b = new TensorMasterData(tensorType = "acc")
+    val acc_y = new TensorClientData(tensorType = "acc")
+    val out = new TensorClientData(tensorType = "out")
+  })
+  val blockOut = p(CoreKey).blockOut
+  val f = Seq.fill(blockOut)(Module(new AluReg))
+  val valid = Wire(Vec(blockOut, Bool()))
+  for (i <- 0 until blockOut) {
+    f(i).io.opcode := io.opcode
+    f(i).io.a.valid := io.acc_a.data.valid
+    f(i).io.a.bits := io.acc_a.data.bits(0)(i)
+    f(i).io.b.valid := io.acc_b.data.valid
+    f(i).io.b.bits := io.acc_b.data.bits(0)(i)
+    valid(i) := f(i).io.y.valid
+    io.acc_y.data.bits(0)(i) := f(i).io.y.bits
+    io.out.data.bits(0)(i) := f(i).io.y.bits
+  }
+  io.acc_y.data.valid := valid.asUInt.andR
+  io.out.data.valid := valid.asUInt.andR
+}
+
+/** TensorAlu.
+  *
+  * This unit instantiate the ALU vector unit (AluVector) and go over the
+  * micro-ops (uops) which are used to read the source operands (vectors)
+  * from the acc-scratchpad and then they are written back the same
+  * acc-scratchpad.
+  */
+class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val uop = new UopMaster
+    val acc = new TensorMaster(tensorType = "acc")
+    val out = new TensorMaster(tensorType = "out")
+  })
+  val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil = Enum(6)
+  val state = RegInit(sIdle)
+  val alu = Module(new AluVector)
+  val dec = io.inst.asTypeOf(new AluDecode)
+  val uop_idx = Reg(chiselTypeOf(dec.uop_end))
+  val uop_end = dec.uop_end
+  val uop_dst = Reg(chiselTypeOf(dec.uop_end))
+  val uop_src = Reg(chiselTypeOf(dec.uop_end))
+  val cnt_o = Reg(chiselTypeOf(dec.lp_0))
+  val dst_o = Reg(chiselTypeOf(dec.uop_end))
+  val src_o = Reg(chiselTypeOf(dec.uop_end))
+  val cnt_i = Reg(chiselTypeOf(dec.lp_1))
+  val dst_i = Reg(chiselTypeOf(dec.uop_end))
+  val src_i = Reg(chiselTypeOf(dec.uop_end))
+  val done =
+    state === sExe &
+    alu.io.out.data.valid &
+    (cnt_o === dec.lp_0 - 1.U) &
+    (cnt_i === dec.lp_1 - 1.U) &
+    (uop_idx === uop_end - 1.U)
+
+  switch (state) {
+    is (sIdle) {
+      when (io.start) {
+        state := sReadUop
+      }
+    }
+    is (sReadUop) {
+      state := sComputeIdx
+    }
+    is (sComputeIdx) {
+      state := sReadTensorA
+    }
+    is (sReadTensorA) {
+      state := sReadTensorB
+    }
+    is (sReadTensorB) {
+      state := sExe
+    }
+    is (sExe) {
+      when (alu.io.out.data.valid) {
+        when ((cnt_o === dec.lp_0 - 1.U) &&
+             (cnt_i === dec.lp_1 - 1.U) &&
+             (uop_idx === uop_end - 1.U)) {
+          state := sIdle
+        } .otherwise {
+          state := sReadUop
+        }
+      }
+    }
+  }
+
+  when (state === sIdle ||
+         (state === sExe &&
+         alu.io.out.data.valid &&
+         uop_idx === uop_end - 1.U)) {
+    uop_idx := dec.uop_begin
+  } .elsewhen (state === sExe && alu.io.out.data.valid) {
+    uop_idx := uop_idx + 1.U
+  }
+
+  when (state === sIdle) {
+    cnt_o := 0.U
+    dst_o := 0.U
+    src_o := 0.U
+  } .elsewhen (state === sExe &&
+               alu.io.out.data.valid &&
+               uop_idx === uop_end - 1.U &&
+              cnt_i === dec.lp_1 - 1.U) {
+    cnt_o := cnt_o + 1.U
+    dst_o := dst_o + dec.dst_0
+    src_o := src_o + dec.src_0
+  }
+
+  when (state === sIdle) {
+    cnt_i := 0.U
+    dst_i := 0.U
+    src_i := 0.U
+  } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
+    cnt_i := 0.U
+    dst_i := dst_o
+    src_i := src_o
+  } .elsewhen (state === sExe &&
+               alu.io.out.data.valid &&
+              uop_idx === uop_end - 1.U) {
+    cnt_i := cnt_i + 1.U
+    dst_i := dst_i + dec.dst_1
+    src_i := src_i + dec.src_1
+  }
+
+  when (state === sComputeIdx && io.uop.data.valid) {
+    uop_dst := io.uop.data.bits.u0 + dst_i
+    uop_src := io.uop.data.bits.u1 + src_i
+  }
+
+  // uop
+  io.uop.idx.valid := state === sReadUop
+  io.uop.idx.bits := uop_idx
+
+  // acc_i
+  io.acc.rd.idx.valid := state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm)
+  io.acc.rd.idx.bits := Mux(state === sReadTensorA, uop_dst, uop_src)
+
+  // imm
+  val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
+  tensorImm.data.valid := state === sReadTensorB
+  tensorImm.data.bits.foreach { b => b.foreach { c => c := dec.alu_imm } }
+
+  // alu
+  val isSHR = dec.alu_op === ALU_OP(3)
+  val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS-1)
+  val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op))
+  alu.io.opcode := fixme_alu_op
+  alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
+  alu.io.acc_a.data.bits <> io.acc.rd.data.bits
+  alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, tensorImm.data.valid, io.acc.rd.data.valid & state === sExe)
+  alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, tensorImm.data.bits, io.acc.rd.data.bits)
+
+  // acc_o
+  io.acc.wr.valid := alu.io.acc_y.data.valid
+  io.acc.wr.bits.idx := uop_dst
+  io.acc.wr.bits.data <> alu.io.acc_y.data.bits
+
+  // out
+  io.out.wr.valid := alu.io.out.data.valid
+  io.out.wr.bits.idx := uop_dst
+  io.out.wr.bits.data <> alu.io.out.data.bits
+  io.out.tieoffRead() // write-only
+
+  io.done := done
+
+  if (debug) {
+
+    when (state === sReadUop) {
+      printf("[TensorAlu] [uop] idx:%x\n", uop_idx)
+    }
+
+    when (state === sReadTensorA) {
+      printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src)
+    }
+
+    when (state === sIdle && io.start) {
+      printf(p"[TensorAlu] decode:$dec\n")
+    }
+
+    alu.io.acc_a.data.bits.foreach { tensor =>
+      tensor.zipWithIndex.foreach { case(elem, i) =>
+        when (alu.io.acc_a.data.valid) {
+          printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem)
+        }
+      }
+    }
+
+    alu.io.acc_b.data.bits.foreach { tensor =>
+      tensor.zipWithIndex.foreach { case(elem, i) =>
+        when (alu.io.acc_b.data.valid) {
+          printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem)
+        }
+      }
+    }
+
+    alu.io.acc_y.data.bits.foreach { tensor =>
+      tensor.zipWithIndex.foreach { case(elem, i) =>
+        when (alu.io.acc_y.data.valid) {
+          printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem)
+        }
+      }
+    }
+
+    alu.io.out.data.bits.foreach { tensor =>
+      tensor.zipWithIndex.foreach { case(elem, i) =>
+        when (alu.io.out.data.valid) {
+          printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem)
+        }
+      }
+    }
+  }
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala
new file mode 100644 (file)
index 0000000..2dd8c33
--- /dev/null
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import chisel3.experimental._
+import vta.util.config._
+import scala.math.pow
+
+/** Pipelined multiply and accumulate */
+class MAC(dataBits: Int = 8, cBits: Int = 16, outBits: Int = 17) extends Module {
+  require (cBits >= dataBits * 2)
+  require (outBits >= dataBits * 2)
+  val io = IO(new Bundle {
+    val a = Input(SInt(dataBits.W))
+    val b = Input(SInt(dataBits.W))
+    val c = Input(SInt(cBits.W))
+    val y = Output(SInt(outBits.W))
+  })
+  val mult = Wire(SInt(cBits.W))
+  val add = Wire(SInt(outBits.W))
+  val rA = RegNext(io.a)
+  val rB = RegNext(io.b)
+  val rC = RegNext(io.c)
+  mult := rA * rB
+  add := rC + mult
+  io.y := add
+}
+
+/** Pipelined adder */
+class Adder(dataBits: Int = 8, outBits: Int = 17) extends Module {
+  require (outBits >= dataBits)
+  val io = IO(new Bundle {
+    val a = Input(SInt(dataBits.W))
+    val b = Input(SInt(dataBits.W))
+    val y = Output(SInt(outBits.W))
+  })
+  val add = Wire(SInt(outBits.W))
+  val rA = RegNext(io.a)
+  val rB = RegNext(io.b)
+  add := rA + rB
+  io.y := add
+}
+
+/** Pipelined DotProduct based on MAC and Adder */
+class DotProduct(dataBits: Int = 8, size: Int = 16) extends Module {
+  val errMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
+  require(size >= 4 && isPow2(size), errMsg)
+  val b = dataBits * 2
+  val outBits = b + log2Ceil(size) + 1
+  val io = IO(new Bundle {
+    val a = Input(Vec(size, SInt(dataBits.W)))
+    val b = Input(Vec(size, SInt(dataBits.W)))
+    val y = Output(SInt(outBits.W))
+  })
+  val p = log2Ceil(size/2)
+  val s = Seq.tabulate(log2Ceil(size))(i => pow(2, p - i).toInt)
+  val da = Seq.tabulate(s(0))(i => RegNext(io.a(s(0) + i)))
+  val db = Seq.tabulate(s(0))(i => RegNext(io.b(s(0) + i)))
+  val m = Seq.tabulate(2)(i =>
+    Seq.fill(s(0))(Module(new MAC(dataBits = dataBits, cBits = b + i, outBits = b + i + 1)))
+  )
+  val a = Seq.tabulate(p)(i =>
+    Seq.fill(s(i + 1))(Module(new Adder(dataBits = b + i + 2, outBits = b + i + 3)))
+  )
+
+  for (i <- 0 until log2Ceil(size)) {
+    for (j <- 0 until s(i)) {
+      if (i == 0) {
+        m(i)(j).io.a := io.a(j)
+        m(i)(j).io.b := io.b(j)
+        m(i)(j).io.c := 0.S
+        m(i + 1)(j).io.a := da(j)
+        m(i + 1)(j).io.b := db(j)
+        m(i + 1)(j).io.c := m(i)(j).io.y
+      } else if (i == 1) {
+        a(i - 1)(j).io.a := m(i)(2*j).io.y
+        a(i - 1)(j).io.b := m(i)(2*j + 1).io.y
+      } else {
+        a(i - 1)(j).io.a := a(i - 2)(2*j).io.y
+        a(i - 1)(j).io.b := a(i - 2)(2*j + 1).io.y
+      }
+    }
+  }
+  io.y := a(p-1)(0).io.y
+}
+
+/** Perform matric-vector-multiplication based on DotProduct */
+class MatrixVectorCore(implicit p: Parameters) extends Module {
+  val accBits = p(CoreKey).accBits
+  val size = p(CoreKey).blockOut
+  val dataBits = p(CoreKey).inpBits
+  val io = IO(new Bundle{
+    val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
+    val inp = new TensorMasterData(tensorType = "inp")
+    val wgt = new TensorMasterData(tensorType = "wgt")
+    val acc_i = new TensorMasterData(tensorType = "acc")
+    val acc_o = new TensorClientData(tensorType = "acc")
+    val out = new TensorClientData(tensorType = "out")
+  })
+  val dot = Seq.fill(size)(Module(new DotProduct(dataBits, size)))
+  val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
+  val add = Seq.fill(size)(Wire(SInt(accBits.W)))
+  val vld = Wire(Vec(size, Bool()))
+
+  for (i <- 0 until size) {
+    acc(i).io.enq.valid := io.inp.data.valid & io.wgt.data.valid & io.acc_i.data.valid & ~io.reset
+    acc(i).io.enq.bits := io.acc_i.data.bits(0)(i)
+    for (j <- 0 until size) {
+      dot(i).io.a(j) := io.inp.data.bits(0)(j).asSInt
+      dot(i).io.b(j) := io.wgt.data.bits(i)(j).asSInt
+    }
+    add(i) := acc(i).io.deq.bits.asSInt + dot(i).io.y
+    io.acc_o.data.bits(0)(i) := Mux(io.reset, 0.U, add(i).asUInt)
+    io.out.data.bits(0)(i) := add(i).asUInt
+    vld(i) := acc(i).io.deq.valid
+  }
+  io.acc_o.data.valid := vld.asUInt.andR | io.reset
+  io.out.data.valid := vld.asUInt.andR
+}
+
+/** TensorGemm.
+  *
+  * This unit instantiate the MatrixVectorCore and go over the
+  * micro-ops (uops) which are used to read inputs, weights and biases,
+  * and writes results back to the acc and out scratchpads.
+  *
+  * Also, the TensorGemm uses the reset field in the Gemm instruction to
+  * clear or zero-out the acc-scratchpad locations based on the micro-ops.
+  */
+class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val uop = new UopMaster
+    val inp = new TensorMaster(tensorType = "inp")
+    val wgt = new TensorMaster(tensorType = "wgt")
+    val acc = new TensorMaster(tensorType = "acc")
+    val out = new TensorMaster(tensorType = "out")
+  })
+  val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6)
+  val state = RegInit(sIdle)
+  val mvc = Module(new MatrixVectorCore)
+  val dec = io.inst.asTypeOf(new GemmDecode)
+  val uop_idx = Reg(chiselTypeOf(dec.uop_end))
+  val uop_end = dec.uop_end
+  val uop_acc = Reg(chiselTypeOf(dec.uop_end))
+  val uop_inp = Reg(chiselTypeOf(dec.uop_end))
+  val uop_wgt = Reg(chiselTypeOf(dec.uop_end))
+  val cnt_o = Reg(chiselTypeOf(dec.lp_0))
+  val acc_o = Reg(chiselTypeOf(dec.uop_end))
+  val inp_o = Reg(chiselTypeOf(dec.uop_end))
+  val wgt_o = Reg(chiselTypeOf(dec.uop_end))
+  val cnt_i = Reg(chiselTypeOf(dec.lp_1))
+  val acc_i = Reg(chiselTypeOf(dec.uop_end))
+  val inp_i = Reg(chiselTypeOf(dec.uop_end))
+  val wgt_i = Reg(chiselTypeOf(dec.uop_end))
+  val pBits = log2Ceil(p(CoreKey).blockOut) + 1
+  val inflight = Reg(UInt(pBits.W))
+  val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits))
+  val done = inflight === 0.U &
+             ((state === sExe &
+              cnt_o === dec.lp_0 - 1.U &
+             cnt_i === dec.lp_1 - 1.U &
+             uop_idx === uop_end - 1.U &
+             inflight === 0.U) |
+            state === sWait)
+
+  switch (state) {
+    is (sIdle) {
+      when (io.start) {
+        state := sReadUop
+      }
+    }
+    is (sReadUop) {
+      state := sComputeIdx
+    }
+    is (sComputeIdx) {
+      state := sReadTensor
+    }
+    is (sReadTensor) {
+      state := sExe
+    }
+    is (sExe) {
+      when ((cnt_o === dec.lp_0 - 1.U) &&
+            (cnt_i === dec.lp_1 - 1.U) &&
+            (uop_idx === uop_end - 1.U)) {
+       when (inflight =/= 0.U) {
+          state := sWait
+       } .otherwise {
+          state := sIdle
+       }
+      } .otherwise {
+        state := sReadUop
+      }
+    }
+    is (sWait) {
+      when (inflight === 0.U) {
+        state := sIdle
+      }
+    }
+  }
+
+  when (state === sIdle) {
+    inflight := 0.U
+  } .elsewhen (!dec.reset) {
+    when (state === sExe && inflight =/= ((1 << pBits) - 1).asUInt) { // overflow check
+      inflight := inflight + 1.U
+    } .elsewhen (mvc.io.acc_o.data.valid && inflight =/= 0.U) { // underflow check
+      inflight := inflight - 1.U
+    }
+  }
+
+  when (state === sIdle ||
+         (state === sExe &&
+         uop_idx === uop_end - 1.U)) {
+    uop_idx := dec.uop_begin
+  } .elsewhen (state === sExe) {
+    uop_idx := uop_idx + 1.U
+  }
+
+  when (state === sIdle) {
+    cnt_o := 0.U
+    acc_o := 0.U
+    inp_o := 0.U
+    wgt_o := 0.U
+  } .elsewhen (state === sExe &&
+              uop_idx === uop_end - 1.U &&
+              cnt_i === dec.lp_1 - 1.U) {
+    cnt_o := cnt_o + 1.U
+    acc_o := acc_o + dec.acc_0
+    inp_o := inp_o + dec.inp_0
+    wgt_o := wgt_o + dec.wgt_0
+  }
+
+  when (state === sIdle) {
+    cnt_i := 0.U
+    acc_i := 0.U
+    inp_i := 0.U
+    wgt_i := 0.U
+  } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
+    cnt_i := 0.U
+    acc_i := acc_o
+    inp_i := inp_o
+    wgt_i := wgt_o
+  } .elsewhen (state === sExe &&
+              uop_idx === uop_end - 1.U) {
+    cnt_i := cnt_i + 1.U
+    acc_i := acc_i + dec.acc_1
+    inp_i := inp_i + dec.inp_1
+    wgt_i := wgt_i + dec.wgt_1
+  }
+
+  when (state === sComputeIdx && io.uop.data.valid) {
+    uop_acc := io.uop.data.bits.u0 + acc_i
+    uop_inp := io.uop.data.bits.u1 + inp_i
+    uop_wgt := io.uop.data.bits.u2 + wgt_i
+  }
+
+  wrpipe.io.enq.valid := state === sExe & ~dec.reset
+  wrpipe.io.enq.bits := uop_acc
+
+  // uop
+  io.uop.idx.valid := state === sReadUop
+  io.uop.idx.bits := uop_idx
+
+  // inp
+  io.inp.rd.idx.valid := state === sReadTensor
+  io.inp.rd.idx.bits := uop_inp
+  io.inp.tieoffWrite() // read-only
+
+  // wgt
+  io.wgt.rd.idx.valid := state === sReadTensor
+  io.wgt.rd.idx.bits := uop_wgt
+  io.wgt.tieoffWrite() // read-only
+
+  // acc_i
+  io.acc.rd.idx.valid := state === sReadTensor
+  io.acc.rd.idx.bits := uop_acc
+
+  // mvc
+  mvc.io.reset := dec.reset & state === sExe
+  mvc.io.inp.data <> io.inp.rd.data
+  mvc.io.wgt.data <> io.wgt.rd.data
+  mvc.io.acc_i.data <> io.acc.rd.data
+
+  // acc_o
+  io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset, true.B, wrpipe.io.deq.valid)
+  io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
+  io.acc.wr.bits.data <> mvc.io.acc_o.data.bits
+
+  // out
+  io.out.wr.valid := mvc.io.out.data.valid & wrpipe.io.deq.valid
+  io.out.wr.bits.idx := wrpipe.io.deq.bits
+  io.out.wr.bits.data <> mvc.io.out.data.bits
+  io.out.tieoffRead() // write-only
+
+  io.done := done
+
+  if (debug) {
+    when (state === sReadUop && ~dec.reset) {
+      printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
+    }
+
+    when (state === sReadTensor && ~dec.reset) {
+      printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt)
+    }
+
+    io.inp.rd.data.bits.zipWithIndex.foreach { case(r, i) =>
+      when (io.inp.rd.data.valid && ~dec.reset) {
+        printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
+      }
+    }
+
+    io.wgt.rd.data.bits.zipWithIndex.foreach { case(r, i) =>
+      when (io.wgt.rd.data.valid && ~dec.reset) {
+        printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
+      }
+    }
+
+    io.acc.rd.data.bits.foreach { tensor =>
+      tensor.zipWithIndex.foreach { case(elem, i) =>
+        when (io.acc.rd.data.valid && ~dec.reset) {
+          printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
+        }
+      }
+    }
+
+    mvc.io.acc_o.data.bits.foreach { tensor =>
+      tensor.zipWithIndex.foreach { case(elem, i) =>
+        when (mvc.io.acc_o.data.valid && ~dec.reset) {
+          printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
+        }
+      }
+    }
+
+    mvc.io.out.data.bits.foreach { tensor =>
+      tensor.zipWithIndex.foreach { case(elem, i) =>
+        when (mvc.io.out.data.valid && ~dec.reset) {
+          printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
+        }
+      }
+    }
+  }
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala b/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala
new file mode 100644 (file)
index 0000000..d96a681
--- /dev/null
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** TensorStore.
+  *
+  * Load 1D and 2D tensors from main memory (DRAM) to input/weight
+  * scratchpads (SRAM). Also, there is support for zero padding, while
+  * doing the load. Zero-padding works on the y and x axis, and it is
+  * managed by TensorPadCtrl. The TensorDataCtrl is in charge of
+  * handling the way tensors are stored on the scratchpads.
+  */
+class TensorLoad(tensorType: String = "none", debug: Boolean = false)
+  (implicit p: Parameters) extends Module {
+  val tp = new TensorParams(tensorType)
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val baddr = Input(UInt(mp.addrBits.W))
+    val vme_rd = new VMEReadMaster
+    val tensor = new TensorClient(tensorType)
+  })
+  val sizeFactor = tp.tensorLength * tp.numMemBlock
+  val strideFactor = tp.tensorLength * tp.tensorWidth
+
+  val dec = io.inst.asTypeOf(new MemDecode)
+  val dataCtrl = Module(new TensorDataCtrl(sizeFactor, strideFactor))
+  val dataCtrlDone = RegInit(false.B)
+  val yPadCtrl0 = Module(new TensorPadCtrl(padType = "YPad0", sizeFactor))
+  val yPadCtrl1 = Module(new TensorPadCtrl(padType = "YPad1", sizeFactor))
+  val xPadCtrl0 = Module(new TensorPadCtrl(padType = "XPad0", sizeFactor))
+  val xPadCtrl1 = Module(new TensorPadCtrl(padType = "XPad1", sizeFactor))
+
+  val tag = Reg(UInt(8.W))
+  val set = Reg(UInt(8.W))
+
+  val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil = Enum(7)
+  val state = RegInit(sIdle)
+
+  // control
+  switch (state) {
+    is (sIdle) {
+      when (io.start) {
+        when (dec.ypad_0 =/= 0.U) {
+          state := sYPad0
+       } .elsewhen (dec.xpad_0 =/= 0.U) {
+          state := sXPad0
+       } .otherwise {
+          state := sReadCmd
+       }
+      }
+    }
+    is (sYPad0) {
+      when (yPadCtrl0.io.done) {
+        when (dec.xpad_0 =/= 0.U) {
+          state := sXPad0
+       } .otherwise {
+          state := sReadCmd
+       }
+      }
+    }
+    is (sXPad0) {
+      when (xPadCtrl0.io.done) {
+        state := sReadCmd
+      }
+    }
+    is (sReadCmd) {
+      when (io.vme_rd.cmd.ready) {
+        state := sReadData
+      }
+    }
+    is (sReadData) {
+      when (io.vme_rd.data.valid) {
+        when (dataCtrl.io.done) {
+         when (dec.xpad_1 =/= 0.U) {
+           state := sXPad1
+         } .elsewhen (dec.ypad_1 =/= 0.U) {
+           state := sYPad1
+         } .otherwise  {
+           state := sIdle
+         }
+       } .elsewhen (dataCtrl.io.stride || dataCtrl.io.split) {
+          when (dec.xpad_1 =/= 0.U) {
+           state := sXPad1
+         } .elsewhen (dec.xpad_0 =/= 0.U) {
+            state := sXPad0
+         } .otherwise {
+              state := sReadCmd
+         }
+       }
+      }
+    }
+    is (sXPad1) {
+      when (xPadCtrl1.io.done) {
+        when (dataCtrlDone) {
+          when (dec.ypad_1 =/= 0.U) {
+            state := sYPad1
+          } .otherwise {
+            state := sIdle
+          }
+        } .otherwise {
+          when (dec.xpad_0 =/= 0.U) {
+            state := sXPad0
+          } .otherwise {
+            state := sReadCmd
+          }
+        }
+      }
+    }
+    is (sYPad1) {
+      when (yPadCtrl1.io.done && dataCtrlDone) {
+        state := sIdle
+      }
+    }
+  }
+
+  // data controller
+  dataCtrl.io.start := state === sIdle & io.start
+  dataCtrl.io.inst := io.inst
+  dataCtrl.io.baddr := io.baddr
+  dataCtrl.io.xinit := io.vme_rd.cmd.fire()
+  dataCtrl.io.xupdate := io.vme_rd.data.fire()
+  dataCtrl.io.yupdate := io.vme_rd.data.fire()
+
+  when (state === sIdle) {
+    dataCtrlDone := false.B
+  } .elsewhen (io.vme_rd.data.fire() && dataCtrl.io.done) {
+    dataCtrlDone := true.B
+  }
+
+  // pad
+  yPadCtrl0.io.start := dec.ypad_0 =/= 0.U & state === sIdle & io.start
+
+  yPadCtrl1.io.start := dec.ypad_1 =/= 0.U &
+                          ((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) |
+                           (state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone))
+
+  xPadCtrl0.io.start := dec.xpad_0 =/= 0.U &
+                          ((state === sIdle & io.start) |
+                         (state === sYPad0 & yPadCtrl0.io.done) |
+                          (io.vme_rd.data.fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) |
+                         (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))
+
+  xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() &
+                          ((dataCtrl.io.done) |
+                          (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U))
+
+  yPadCtrl0.io.inst := io.inst
+  yPadCtrl1.io.inst := io.inst
+  xPadCtrl0.io.inst := io.inst
+  xPadCtrl1.io.inst := io.inst
+
+  // read-from-dram
+  io.vme_rd.cmd.valid := state === sReadCmd
+  io.vme_rd.cmd.bits.addr := dataCtrl.io.addr
+  io.vme_rd.cmd.bits.len := dataCtrl.io.len
+
+  io.vme_rd.data.ready := state === sReadData
+
+  // write-to-sram
+  val isZeroPad = state === sYPad0 |
+                  state === sXPad0 |
+                 state === sXPad1 |
+                 state === sYPad1
+
+  when (state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
+    tag := 0.U
+  } .elsewhen (io.vme_rd.data.fire() || isZeroPad) {
+    tag := tag + 1.U
+  }
+
+  when (state === sIdle || state === sReadCmd || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
+    set := 0.U
+  } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) {
+    set := set + 1.U
+  }
+
+  val waddr_cur = Reg(UInt(tp.memAddrBits.W))
+  val waddr_nxt = Reg(UInt(tp.memAddrBits.W))
+  when (state === sIdle) {
+    waddr_cur := dec.sram_offset
+    waddr_nxt := dec.sram_offset
+  } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) {
+    waddr_cur := waddr_cur + 1.U
+  } .elsewhen (dataCtrl.io.stride) {
+    waddr_cur := waddr_nxt + dec.xsize
+    waddr_nxt := waddr_nxt + dec.xsize
+  }
+
+  val tensorFile = Seq.fill(tp.tensorLength) { SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
+  val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) }
+  val wdata = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
+  val no_mask = Wire(Vec(tp.numMemBlock, Bool()))
+  no_mask.foreach { m => m := true.B }
+
+  for (i <- 0 until tp.tensorLength) {
+    for (j <- 0 until tp.numMemBlock) {
+      wmask(i)(j) := tag === j.U
+      wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits)
+    }
+    val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i))
+    val muxWen = Mux(state === sIdle, io.tensor.wr.valid, (io.vme_rd.data.fire() | isZeroPad) & set === i.U)
+    val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur)
+    val muxWdata = Mux(state === sIdle, tdata, wdata(i))
+    val muxWmask = Mux(state === sIdle, no_mask, wmask(i))
+    when (muxWen) {
+      tensorFile(i).write(muxWaddr, muxWdata, muxWmask)
+    }
+  }
+
+  // read-from-sram
+  val rvalid = RegNext(io.tensor.rd.idx.valid)
+  io.tensor.rd.data.valid := rvalid
+
+  val rdata = tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid))
+  rdata.zipWithIndex.foreach { case(r, i) =>
+    io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i))
+  }
+
+  // done
+  val done_no_pad = io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U
+  val done_x_pad = state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone & dec.ypad_1 === 0.U
+  val done_y_pad = state === sYPad1 & dataCtrlDone & yPadCtrl1.io.done
+  io.done := done_no_pad | done_x_pad | done_y_pad
+
+  // debug
+  if (debug) {
+    if (tensorType == "inp") {
+      when (io.vme_rd.cmd.fire()) {
+        printf("[TensorLoad] [inp] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
+      }
+      when (state === sYPad0) {
+        printf("[TensorLoad] [inp] sYPad0\n")
+      }
+      when (state === sYPad1) {
+        printf("[TensorLoad] [inp] sYPad1\n")
+      }
+      when (state === sXPad0) {
+        printf("[TensorLoad] [inp] sXPad0\n")
+      }
+      when (state === sXPad1) {
+        printf("[TensorLoad] [inp] sXPad1\n")
+      }
+    } else if (tensorType == "wgt") {
+      when (io.vme_rd.cmd.fire()) {
+        printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
+      }
+    } else if (tensorType == "acc") {
+      when (io.vme_rd.cmd.fire()) {
+        printf("[TensorLoad] [acc] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
+      }
+    }
+  }
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/TensorStore.scala b/vta/hardware/chisel/src/main/scala/core/TensorStore.scala
new file mode 100644 (file)
index 0000000..0012e47
--- /dev/null
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** TensorStore.
+  *
+  * Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM).
+  */
+class TensorStore(tensorType: String = "true", debug: Boolean = false)
+  (implicit p: Parameters) extends Module {
+  val tp = new TensorParams(tensorType)
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val baddr = Input(UInt(mp.addrBits.W))
+    val vme_wr = new VMEWriteMaster
+    val tensor = new TensorClient(tensorType)
+  })
+  val tensorLength = tp.tensorLength
+  val tensorWidth = tp.tensorWidth
+  val tensorElemBits = tp.tensorElemBits
+  val memBlockBits = tp.memBlockBits
+  val memDepth = tp.memDepth
+  val numMemBlock = tp.numMemBlock
+
+  val dec = io.inst.asTypeOf(new MemDecode)
+  val waddr_cur = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
+  val waddr_nxt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
+  val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
+  val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
+  val xrem = Reg(chiselTypeOf(dec.xsize))
+  val xsize = (dec.xsize << log2Ceil(tensorLength*numMemBlock)) - 1.U
+  val xmax = (1 << mp.lenBits).U
+  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+  val ycnt = Reg(chiselTypeOf(dec.ysize))
+  val ysize = dec.ysize
+  val tag = Reg(UInt(8.W))
+  val set = Reg(UInt(8.W))
+
+  val sIdle :: sWriteCmd :: sWriteData :: sReadMem :: sWriteAck :: Nil = Enum(5)
+  val state = RegInit(sIdle)
+
+  // control
+  switch (state) {
+    is (sIdle) {
+      when (io.start) {
+        state := sWriteCmd
+       when (xsize < xmax) {
+          xlen := xsize
+          xrem := 0.U
+       } .otherwise {
+          xlen := xmax - 1.U
+          xrem := xsize - xmax
+       }
+      }
+    }
+    is (sWriteCmd) {
+      when (io.vme_wr.cmd.ready) {
+        state := sWriteData
+      }
+    }
+    is (sWriteData) {
+      when (io.vme_wr.data.ready) {
+        when (xcnt === xlen) {
+          state := sWriteAck
+        } .elsewhen (tag === (numMemBlock - 1).U) {
+          state := sReadMem
+       }
+      }
+    }
+    is (sReadMem) {
+      state := sWriteData
+    }
+    is (sWriteAck) {
+      when (io.vme_wr.ack) {
+        when (xrem === 0.U) {
+         when (ycnt === ysize - 1.U) {
+            state := sIdle
+         } .otherwise {
+            state := sWriteCmd
+           when (xsize < xmax) {
+              xlen := xsize
+              xrem := 0.U
+           } .otherwise {
+              xlen := xmax - 1.U
+              xrem := xsize - xmax
+           }
+         }
+       } .elsewhen (xrem < xmax) {
+          state := sWriteCmd
+          xlen := xrem
+          xrem := 0.U
+       } .otherwise {
+          state := sWriteCmd
+          xlen := xmax - 1.U
+          xrem := xrem - xmax
+       }
+      }
+    }
+  }
+
+  // write-to-sram
+  val tensorFile = Seq.fill(tensorLength) { SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W))) }
+  val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W)))
+  val no_mask = Wire(Vec(numMemBlock, Bool()))
+
+  wdata_t := DontCare
+  no_mask.foreach { m => m := true.B }
+
+  for (i <- 0 until tensorLength) {
+    val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t)
+    when (io.tensor.wr.valid) {
+      tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask)
+    }
+  }
+
+  // read-from-sram
+  val stride = state === sWriteAck &
+              io.vme_wr.ack &
+              xcnt === xlen + 1.U &
+             xrem === 0.U &
+             ycnt =/= ysize - 1.U
+
+  when (state === sIdle) {
+    ycnt := 0.U
+  } .elsewhen (stride) {
+    ycnt := ycnt + 1.U
+  }
+
+  when (state === sWriteCmd || tag === (numMemBlock - 1).U) {
+    tag := 0.U
+  } .elsewhen (io.vme_wr.data.fire()) {
+    tag := tag + 1.U
+  }
+
+  when (state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
+    set := 0.U
+  } .elsewhen (io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
+    set := set + 1.U
+  }
+
+  val raddr_cur = Reg(UInt(tp.memAddrBits.W))
+  val raddr_nxt = Reg(UInt(tp.memAddrBits.W))
+  when (state === sIdle) {
+    raddr_cur := dec.sram_offset
+    raddr_nxt := dec.sram_offset
+  } .elsewhen (io.vme_wr.data.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) {
+    raddr_cur := raddr_cur + 1.U
+  } .elsewhen (stride) {
+    raddr_cur := raddr_nxt + dec.xsize
+    raddr_nxt := raddr_nxt + dec.xsize
+  }
+
+  val tread = Seq.tabulate(tensorLength) { i => i.U ->
+    tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem) }
+  val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread)
+
+  // write-to-dram
+  when (state === sIdle) {
+    waddr_cur := io.baddr + dec.dram_offset
+    waddr_nxt := io.baddr + dec.dram_offset
+  } .elsewhen (state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
+    waddr_cur := waddr_cur + xmax_bytes
+  } .elsewhen (stride) {
+    waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
+    waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
+  }
+
+  io.vme_wr.cmd.valid := state === sWriteCmd
+  io.vme_wr.cmd.bits.addr := waddr_cur
+  io.vme_wr.cmd.bits.len := xlen
+
+  io.vme_wr.data.valid := state === sWriteData
+  io.vme_wr.data.bits := mdata(tag)
+
+  when (state === sWriteCmd) {
+    xcnt := 0.U
+  } .elsewhen (io.vme_wr.data.fire()) {
+    xcnt := xcnt + 1.U
+  }
+
+  // disable external read-from-sram requests
+  io.tensor.tieoffRead()
+
+  // done
+  io.done := state === sWriteAck & io.vme_wr.ack & xrem === 0.U & ycnt === ysize - 1.U
+
+  // debug
+  if (debug) {
+    when (io.vme_wr.cmd.fire()) {
+      printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem)
+    }
+    when (io.vme_wr.data.fire()) {
+      printf("[TensorStore] data:%x\n", io.vme_wr.data.bits)
+    }
+    when (io.vme_wr.ack) {
+      printf("[TensorStore] ack\n")
+    }
+  }
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/TensorUtil.scala b/vta/hardware/chisel/src/main/scala/core/TensorUtil.scala
new file mode 100644 (file)
index 0000000..e41a2c5
--- /dev/null
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.core
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.shell._
+
+/** TensorParams.
+  *
+  * This Bundle derives parameters for each tensorType, including inputs (inp),
+  * weights (wgt), biases (acc), and outputs (out). This is used to avoid
+  * doing the same boring calculations over and over again.
+  */
+class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends Bundle {
+  val errorMsg = s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
+
+  require (tensorType == "inp" || tensorType == "wgt"
+    || tensorType == "acc" || tensorType == "out", errorMsg)
+
+  val (tensorLength, tensorWidth, tensorElemBits) =
+    if (tensorType == "inp")
+      (p(CoreKey).batch, p(CoreKey).blockIn, p(CoreKey).inpBits)
+    else if (tensorType == "wgt")
+      (p(CoreKey).blockOut, p(CoreKey).blockIn, p(CoreKey).wgtBits)
+    else if (tensorType == "acc")
+      (p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).accBits)
+    else
+      (p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).outBits)
+
+  val memBlockBits = p(ShellKey).memParams.dataBits
+  val numMemBlock = (tensorWidth * tensorElemBits) / memBlockBits
+
+  val memDepth =
+    if (tensorType == "inp")
+      p(CoreKey).inpMemDepth
+    else if (tensorType == "wgt")
+      p(CoreKey).wgtMemDepth
+    else if (tensorType == "acc")
+      p(CoreKey).accMemDepth
+    else
+      p(CoreKey).outMemDepth
+
+  val memAddrBits = log2Ceil(memDepth)
+}
+
+/** TensorMaster.
+  *
+  * This interface issue read and write tensor-requests to scratchpads. For example,
+  * The TensorGemm unit uses this interface for managing the inputs (inp), weights (wgt),
+  * biases (acc), and outputs (out).
+  *
+  */
+class TensorMaster(tensorType: String = "none")
+  (implicit p: Parameters) extends TensorParams(tensorType) {
+    val rd = new Bundle {
+      val idx = ValidIO(UInt(memAddrBits.W))
+      val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
+    }
+    val wr = ValidIO(new Bundle {
+      val idx = UInt(memAddrBits.W)
+      val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
+    })
+    def tieoffRead() {
+      rd.idx.valid := false.B
+      rd.idx.bits := 0.U
+    }
+    def tieoffWrite() {
+      wr.valid := false.B
+      wr.bits.idx := 0.U
+      wr.bits.data.foreach { b => b.foreach { c => c := 0.U } }
+    }
+  override def cloneType =
+    new TensorMaster(tensorType).asInstanceOf[this.type]
+}
+
+/** TensorClient.
+  *
+  * This interface receives read and write tensor-requests to scratchpads. For example,
+  * The TensorLoad unit uses this interface for receiving read and write requests from
+  * the TensorGemm unit.
+  */
+class TensorClient(tensorType: String = "none")
+  (implicit p: Parameters) extends TensorParams(tensorType) {
+    val rd = new Bundle {
+      val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
+      val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
+    }
+    val wr = Flipped(ValidIO(new Bundle {
+      val idx = UInt(memAddrBits.W)
+      val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
+    }))
+    def tieoffRead() {
+      rd.data.valid := false.B
+      rd.data.bits.foreach { b => b.foreach { c => c := 0.U } }
+    }
+  override def cloneType =
+    new TensorClient(tensorType).asInstanceOf[this.type]
+}
+
+/** TensorMasterData.
+  *
+  * This interface is only used for datapath only purposes and the direction convention
+  * is based on the TensorMaster interface, which means this is an input. This interface
+  * is used on datapath only module such MatrixVectorCore or AluVector.
+  */
+class TensorMasterData(tensorType: String = "none")
+  (implicit p: Parameters) extends TensorParams(tensorType) {
+  val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
+  override def cloneType =
+    new TensorMasterData(tensorType).asInstanceOf[this.type]
+}
+
+/** TensorClientData.
+  *
+  * This interface is only used for datapath only purposes and the direction convention
+  * is based on the TensorClient interface, which means this is an output. This interface
+  * is used on datapath only module such MatrixVectorCore or AluVector.
+  */
+class TensorClientData(tensorType: String = "none")
+  (implicit p: Parameters) extends TensorParams(tensorType) {
+  val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
+  override def cloneType =
+    new TensorClientData(tensorType).asInstanceOf[this.type]
+}
+
+/** TensorPadCtrl. Zero-padding controller for TensorLoad. */
+class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Module {
+  val errorMsg = s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"
+  require (padType == "YPad0" || padType == "YPad1"
+    || padType == "XPad0" || padType == "XPad1", errorMsg)
+
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+  })
+
+  val dec = io.inst.asTypeOf(new MemDecode)
+
+  val xmax = Reg(chiselTypeOf(dec.xsize))
+  val ymax = Reg(chiselTypeOf(dec.ypad_0))
+  val xcnt = Reg(chiselTypeOf(dec.xsize))
+  val ycnt = Reg(chiselTypeOf(dec.ypad_0))
+
+  val xval =
+    if (padType == "YPad0" || padType == "YPad1")
+      ((dec.xpad_0 + dec.xsize + dec.xpad_1) << log2Ceil(sizeFactor)) - 1.U
+    else if (padType == "XPad0")
+      (dec.xpad_0 << log2Ceil(sizeFactor)) - 1.U
+    else
+      (dec.xpad_1 << log2Ceil(sizeFactor)) - 1.U
+
+  val yval =
+    if (padType == "YPad0")
+      Mux(dec.ypad_0 =/= 0.U, dec.ypad_0 - 1.U, 0.U)
+    else if (padType == "YPad1")
+      Mux(dec.ypad_1 =/= 0.U, dec.ypad_1 - 1.U, 0.U)
+    else
+      0.U
+
+  val sIdle :: sActive :: Nil = Enum(2)
+  val state = RegInit(sIdle)
+
+  switch (state) {
+    is (sIdle) {
+      when (io.start) {
+        state := sActive
+      }
+    }
+    is (sActive) {
+      when (ycnt === ymax && xcnt === xmax) {
+        state := sIdle
+      }
+    }
+  }
+
+  when (state === sIdle) {
+    xmax := xval
+    ymax := yval
+  }
+
+  when (state === sIdle || xcnt === xmax) {
+    xcnt := 0.U
+  } .elsewhen (state === sActive) {
+    xcnt := xcnt + 1.U
+  }
+
+  when (state === sIdle || ymax === 0.U) {
+    ycnt := 0.U
+  } .elsewhen (state === sActive && xcnt === xmax) {
+    ycnt := ycnt + 1.U
+  }
+
+  io.done := state === sActive & ycnt === ymax & xcnt === xmax
+}
+
+/** TensorDataCtrl. Data controller for TensorLoad. */
+class TensorDataCtrl(sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module {
+  val mp = p(ShellKey).memParams
+  val io = IO(new Bundle {
+    val start = Input(Bool())
+    val done = Output(Bool())
+    val inst = Input(UInt(INST_BITS.W))
+    val baddr = Input(UInt(mp.addrBits.W))
+    val xinit = Input(Bool())
+    val xupdate = Input(Bool())
+    val yupdate = Input(Bool())
+    val stride = Output(Bool())
+    val split = Output(Bool())
+    val commit = Output(Bool())
+    val addr = Output(UInt(mp.addrBits.W))
+    val len = Output(UInt(mp.lenBits.W))
+  })
+
+  val dec = io.inst.asTypeOf(new MemDecode)
+
+  val caddr = Reg(UInt(mp.addrBits.W))
+  val baddr = Reg(UInt(mp.addrBits.W))
+
+  val len = Reg(UInt(mp.lenBits.W))
+
+  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+  val xcnt = Reg(UInt(mp.lenBits.W))
+  val xrem = Reg(chiselTypeOf(dec.xsize))
+  val xsize = (dec.xsize << log2Ceil(sizeFactor)) - 1.U
+  val xmax = (1 << mp.lenBits).U
+  val ycnt = Reg(chiselTypeOf(dec.ysize))
+
+  val stride = xcnt === len &
+              xrem === 0.U &
+              ycnt =/= dec.ysize - 1.U
+
+  val split = xcnt === len & xrem =/= 0.U
+
+  when (io.start || (io.xupdate && stride)) {
+    when (xsize < xmax) {
+      len := xsize
+      xrem := 0.U
+    } .otherwise {
+      len := xmax - 1.U
+      xrem := xsize - xmax
+    }
+  } .elsewhen (io.xupdate && split) {
+    when (xrem < xmax) {
+      len := xrem
+      xrem := 0.U
+    } .otherwise {
+      len := xmax - 1.U
+      xrem := xrem - xmax
+    }
+  }
+
+  when (io.xinit) {
+    xcnt := 0.U
+  } .elsewhen (io.xupdate) {
+    xcnt := xcnt + 1.U
+  }
+
+  when (io.start) {
+    ycnt := 0.U
+  } .elsewhen (io.yupdate && stride) {
+    ycnt := ycnt + 1.U
+  }
+
+  when (io.start) {
+    caddr := io.baddr + dec.dram_offset
+    baddr := io.baddr + dec.dram_offset
+  } .elsewhen (io.yupdate) {
+    when (split) {
+      caddr := caddr + xmax_bytes
+    } .elsewhen (stride) {
+      caddr := baddr + (dec.xstride << log2Ceil(strideFactor))
+      baddr := baddr + (dec.xstride << log2Ceil(strideFactor))
+    }
+  }
+
+  io.stride := stride
+  io.split := split
+  io.commit := xcnt === len
+  io.addr := caddr
+  io.len := len
+  io.done := xcnt === len &
+            xrem === 0.U &
+            ycnt === dec.ysize - 1.U
+}
diff --git a/vta/hardware/chisel/src/main/scala/core/package.scala b/vta/hardware/chisel/src/main/scala/core/package.scala
new file mode 100644 (file)
index 0000000..673d390
--- /dev/null
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta
+
+/** This trick makes ISAConstants globally available */
+package object core extends vta.core.ISAConstants
index aab2d63..115bcbc 100644 (file)
@@ -21,6 +21,9 @@ package vta.dpi
 
 import chisel3._
 import chisel3.util._
+import vta.util.config._
+import vta.interface.axi._
+import vta.shell._
 
 /** Host DPI parameters */
 trait VTAHostDPIParams {
@@ -70,3 +73,83 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource {
   })
   setResource("/verilog/VTAHostDPI.v")
 }
+
+/** Host DPI to AXI Converter.
+  *
+  * Convert Host DPI to AXI for VTAShell
+  */
+
+class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val dpi = new VTAHostDPIClient
+    val axi = new AXILiteMaster(p(ShellKey).hostParams)
+  })
+  val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
+  val data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.value)))
+  val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
+  val state = RegInit(sIdle)
+
+  switch (state) {
+    is (sIdle) {
+      when (io.dpi.req.valid) {
+        when (io.dpi.req.opcode) {
+          state := sWriteAddress
+        } .otherwise {
+          state := sReadAddress
+        }
+      }
+    }
+    is (sReadAddress) {
+      when (io.axi.ar.ready) {
+        state := sReadData
+      }
+    }
+    is (sReadData) {
+      when (io.axi.r.valid) {
+        state := sIdle
+      }
+    }
+    is (sWriteAddress) {
+      when (io.axi.aw.ready) {
+        state := sWriteData
+      }
+    }
+    is (sWriteData) {
+      when (io.axi.w.ready) {
+        state := sWriteResponse
+      }
+    }
+    is (sWriteResponse) {
+      when (io.axi.b.valid) {
+        state := sIdle
+      }
+    }
+  }
+
+  when (state === sIdle && io.dpi.req.valid) {
+    addr := io.dpi.req.addr
+    data := io.dpi.req.value
+  }
+
+  io.axi.aw.valid := state === sWriteAddress
+  io.axi.aw.bits.addr := addr
+  io.axi.w.valid := state === sWriteData
+  io.axi.w.bits.data := data
+  io.axi.w.bits.strb := "h_f".U
+  io.axi.b.ready := state === sWriteResponse
+
+  io.axi.ar.valid := state === sReadAddress
+  io.axi.ar.bits.addr := addr
+  io.axi.r.ready := state === sReadData
+
+  io.dpi.req.deq := (state === sReadAddress & io.axi.ar.ready) | (state === sWriteAddress & io.axi.aw.ready)
+  io.dpi.resp.valid := io.axi.r.valid
+  io.dpi.resp.bits := io.axi.r.bits.data
+
+  if (debug) {
+    when (state === sWriteAddress && io.axi.aw.ready) { printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr) }
+    when (state === sReadAddress && io.axi.ar.ready) { printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr) }
+    when (io.axi.r.fire()) { printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data) }
+    when (io.axi.w.fire()) { printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data) }
+  }
+}
index 090f045..5e2fa74 100644 (file)
@@ -21,6 +21,9 @@ package vta.dpi
 
 import chisel3._
 import chisel3.util._
+import vta.util.config._
+import vta.interface.axi._
+import vta.shell._
 
 /** Memory DPI parameters */
 trait VTAMemDPIParams {
@@ -71,3 +74,98 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource {
   })
   setResource("/verilog/VTAMemDPI.v")
 }
+
+class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val dpi = new VTAMemDPIMaster
+    val axi = new AXIClient(p(ShellKey).memParams)
+  })
+  val opcode = RegInit(false.B)
+  val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len)))
+  val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
+  val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
+  val state = RegInit(sIdle)
+
+  switch (state) {
+    is (sIdle) {
+      when (io.axi.ar.valid) {
+        state := sReadAddress
+      } .elsewhen (io.axi.aw.valid) {
+        state := sWriteAddress
+      }
+    }
+    is (sReadAddress) {
+      when (io.axi.ar.valid) {
+        state := sReadData
+      }
+    }
+    is (sReadData) {
+      when (io.axi.r.ready && io.dpi.rd.valid && len === 0.U) {
+        state := sIdle
+      }
+    }
+    is (sWriteAddress) {
+      when (io.axi.aw.valid) {
+        state := sWriteData
+      }
+    }
+    is (sWriteData) {
+      when (io.axi.w.valid && io.axi.w.bits.last) {
+        state := sWriteResponse
+      }
+    }
+    is (sWriteResponse) {
+      when (io.axi.b.ready) {
+        state := sIdle
+      }
+    }
+  }
+
+  when (state === sIdle) {
+    when (io.axi.ar.valid) {
+      opcode := false.B
+      len := io.axi.ar.bits.len
+      addr := io.axi.ar.bits.addr
+    } .elsewhen (io.axi.aw.valid) {
+      opcode := true.B
+      len := io.axi.aw.bits.len
+      addr := io.axi.aw.bits.addr
+    }
+  } .elsewhen (state === sReadData) {
+    when (io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) {
+      len := len - 1.U
+    }
+  }
+
+  io.dpi.req.valid := (state === sReadAddress & io.axi.ar.valid) | (state === sWriteAddress & io.axi.aw.valid)
+  io.dpi.req.opcode := opcode
+  io.dpi.req.len := len
+  io.dpi.req.addr := addr
+
+  io.axi.ar.ready := state === sReadAddress
+  io.axi.aw.ready := state === sWriteAddress
+
+  io.axi.r.valid := state === sReadData & io.dpi.rd.valid
+  io.axi.r.bits.data := io.dpi.rd.bits
+  io.axi.r.bits.last := len === 0.U
+  io.axi.r.bits.resp := 0.U
+  io.axi.r.bits.user := 0.U
+  io.axi.r.bits.id := 0.U
+  io.dpi.rd.ready := state === sReadData & io.axi.r.ready
+
+  io.dpi.wr.valid := state === sWriteData & io.axi.w.valid
+  io.dpi.wr.bits := io.axi.w.bits.data
+  io.axi.w.ready := state === sWriteData
+
+  io.axi.b.valid := state === sWriteResponse
+  io.axi.b.bits.resp := 0.U
+  io.axi.b.bits.user := 0.U
+  io.axi.b.bits.id := 0.U
+
+  if (debug) {
+    when (state === sReadAddress && io.axi.ar.valid) { printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len) }
+    when (state === sWriteAddress && io.axi.aw.valid) { printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len) }
+    when (io.axi.r.fire()) { printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", io.axi.r.bits.last, io.axi.r.bits.data) }
+    when (io.axi.w.fire()) { printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", io.axi.w.bits.last, io.axi.w.bits.data) }
+  }
+}
diff --git a/vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala b/vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala
new file mode 100644 (file)
index 0000000..a853e85
--- /dev/null
@@ -0,0 +1,312 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.interface.axi
+
+import chisel3._
+import chisel3.util._
+import vta.util.genericbundle._
+
+case class AXIParams(
+  addrBits: Int = 32,
+  dataBits: Int = 64
+)
+{
+  require (addrBits > 0)
+  require (dataBits >= 8 && dataBits % 2 == 0)
+
+  val idBits = 1
+  val userBits = 1
+  val strbBits = dataBits/8
+  val lenBits = 8
+  val sizeBits = 3
+  val burstBits = 2
+  val lockBits = 2
+  val cacheBits = 4
+  val protBits = 3
+  val qosBits = 4
+  val regionBits = 4
+  val respBits = 2
+  val sizeConst = log2Ceil(dataBits/8)
+  val idConst = 0
+  val userConst = 0
+  val burstConst = 1
+  val lockConst = 0
+  val cacheConst = 3
+  val protConst = 0
+  val qosConst = 0
+  val regionConst = 0
+}
+
+abstract class AXIBase(params: AXIParams)
+  extends GenericParameterizedBundle(params)
+
+// AXILite
+
+class AXILiteAddress(params: AXIParams) extends AXIBase(params) {
+  val addr = UInt(params.addrBits.W)
+}
+
+class AXILiteWriteData(params: AXIParams) extends AXIBase(params) {
+  val data = UInt(params.dataBits.W)
+  val strb = UInt(params.strbBits.W)
+}
+
+class AXILiteWriteResponse(params: AXIParams) extends AXIBase(params) {
+  val resp = UInt(params.respBits.W)
+}
+
+class AXILiteReadData(params: AXIParams) extends AXIBase(params) {
+  val data = UInt(params.dataBits.W)
+  val resp = UInt(params.respBits.W)
+}
+
+class AXILiteMaster(params: AXIParams) extends AXIBase(params) {
+  val aw = Decoupled(new AXILiteAddress(params))
+  val w = Decoupled(new AXILiteWriteData(params))
+  val b = Flipped(Decoupled(new AXILiteWriteResponse(params)))
+  val ar = Decoupled(new AXILiteAddress(params))
+  val r = Flipped(Decoupled(new AXILiteReadData(params)))
+
+  def tieoff() {
+    aw.valid := false.B
+    aw.bits.addr := 0.U
+    w.valid := false.B
+    w.bits.data := 0.U
+    w.bits.strb := 0.U
+    b.ready := false.B
+    ar.valid := false.B
+    ar.bits.addr := 0.U
+    r.ready := false.B
+  }
+}
+
+class AXILiteClient(params: AXIParams) extends AXIBase(params) {
+  val aw = Flipped(Decoupled(new AXILiteAddress(params)))
+  val w = Flipped(Decoupled(new AXILiteWriteData(params)))
+  val b = Decoupled(new AXILiteWriteResponse(params))
+  val ar = Flipped(Decoupled(new AXILiteAddress(params)))
+  val r = Decoupled(new AXILiteReadData(params))
+
+  def tieoff() {
+    aw.ready := false.B
+    w.ready := false.B
+    b.valid := false.B
+    b.bits.resp := 0.U
+    ar.ready := false.B
+    r.valid := false.B
+    r.bits.resp := 0.U
+    r.bits.data := 0.U
+  }
+}
+
+// AXI extends AXILite
+
+class AXIAddress(params: AXIParams) extends AXILiteAddress(params) {
+  val id = UInt(params.idBits.W)
+  val user = UInt(params.userBits.W)
+  val len = UInt(params.lenBits.W)
+  val size = UInt(params.sizeBits.W)
+  val burst = UInt(params.burstBits.W)
+  val lock = UInt(params.lockBits.W)
+  val cache = UInt(params.cacheBits.W)
+  val prot = UInt(params.protBits.W)
+  val qos = UInt(params.qosBits.W)
+  val region = UInt(params.regionBits.W)
+}
+
+class AXIWriteData(params: AXIParams) extends AXILiteWriteData(params) {
+  val last = Bool()
+  val id = UInt(params.idBits.W)
+  val user = UInt(params.userBits.W)
+}
+
+class AXIWriteResponse(params: AXIParams) extends AXILiteWriteResponse(params) {
+  val id = UInt(params.idBits.W)
+  val user = UInt(params.userBits.W)
+}
+
+class AXIReadData(params: AXIParams) extends AXILiteReadData(params) {
+  val last = Bool()
+  val id = UInt(params.idBits.W)
+  val user = UInt(params.userBits.W)
+}
+
+class AXIMaster(params: AXIParams) extends AXIBase(params) {
+  val aw = Decoupled(new AXIAddress(params))
+  val w = Decoupled(new AXIWriteData(params))
+  val b = Flipped(Decoupled(new AXIWriteResponse(params)))
+  val ar = Decoupled(new AXIAddress(params))
+  val r = Flipped(Decoupled(new AXIReadData(params)))
+
+  def tieoff() {
+    aw.valid := false.B
+    aw.bits.addr := 0.U
+    aw.bits.id := 0.U
+    aw.bits.user := 0.U
+    aw.bits.len := 0.U
+    aw.bits.size := 0.U
+    aw.bits.burst := 0.U
+    aw.bits.lock := 0.U
+    aw.bits.cache := 0.U
+    aw.bits.prot := 0.U
+    aw.bits.qos := 0.U
+    aw.bits.region := 0.U
+    w.valid := false.B
+    w.bits.data := 0.U
+    w.bits.strb := 0.U
+    w.bits.last := false.B
+    w.bits.id := 0.U
+    w.bits.user := 0.U
+    b.ready := false.B
+    ar.valid := false.B
+    ar.bits.addr := 0.U
+    ar.bits.id := 0.U
+    ar.bits.user := 0.U
+    ar.bits.len := 0.U
+    ar.bits.size := 0.U
+    ar.bits.burst := 0.U
+    ar.bits.lock := 0.U
+    ar.bits.cache := 0.U
+    ar.bits.prot := 0.U
+    ar.bits.qos := 0.U
+    ar.bits.region := 0.U
+    r.ready := false.B
+  }
+
+  def setConst() {
+    aw.bits.user := params.userConst.U
+    aw.bits.burst := params.burstConst.U
+    aw.bits.lock := params.lockConst.U
+    aw.bits.cache := params.cacheConst.U
+    aw.bits.prot := params.protConst.U
+    aw.bits.qos := params.qosConst.U
+    aw.bits.region := params.regionConst.U
+    aw.bits.size := params.sizeConst.U
+    aw.bits.id := params.idConst.U
+    w.bits.id := params.idConst.U
+    w.bits.user := params.userConst.U
+    w.bits.strb := Fill(params.strbBits, true.B)
+    ar.bits.user := params.userConst.U
+    ar.bits.burst := params.burstConst.U
+    ar.bits.lock := params.lockConst.U
+    ar.bits.cache := params.cacheConst.U
+    ar.bits.prot := params.protConst.U
+    ar.bits.qos := params.qosConst.U
+    ar.bits.region := params.regionConst.U
+    ar.bits.size := params.sizeConst.U
+    ar.bits.id := params.idConst.U
+  }
+}
+
+class AXIClient(params: AXIParams) extends AXIBase(params) {
+  val aw = Flipped(Decoupled(new AXIAddress(params)))
+  val w = Flipped(Decoupled(new AXIWriteData(params)))
+  val b = Decoupled(new AXIWriteResponse(params))
+  val ar = Flipped(Decoupled(new AXIAddress(params)))
+  val r = Decoupled(new AXIReadData(params))
+
+  def tieoff() {
+    aw.ready := false.B
+    w.ready := false.B
+    b.valid := false.B
+    b.bits.resp := 0.U
+    b.bits.user := 0.U
+    b.bits.id := 0.U
+    ar.ready := false.B
+    r.valid := false.B
+    r.bits.resp := 0.U
+    r.bits.data := 0.U
+    r.bits.user := 0.U
+    r.bits.last := false.B
+    r.bits.id := 0.U
+  }
+}
+
+// XilinxAXILiteClient and XilinxAXIMaster bundles are needed
+// for wrapper purposes, because the package RTL tool in Xilinx Vivado
+// only allows certain name formats
+
+class XilinxAXILiteClient(params: AXIParams) extends AXIBase(params) {
+  val AWVALID = Input(Bool())
+  val AWREADY = Output(Bool())
+  val AWADDR = Input(UInt(params.addrBits.W))
+  val WVALID = Input(Bool())
+  val WREADY = Output(Bool())
+  val WDATA = Input(UInt(params.dataBits.W))
+  val WSTRB = Input(UInt(params.strbBits.W))
+  val BVALID = Output(Bool())
+  val BREADY = Input(Bool())
+  val BRESP = Output(UInt(params.respBits.W))
+  val ARVALID = Input(Bool())
+  val ARREADY = Output(Bool())
+  val ARADDR = Input(UInt(params.addrBits.W))
+  val RVALID = Output(Bool())
+  val RREADY = Input(Bool())
+  val RDATA = Output(UInt(params.dataBits.W))
+  val RRESP = Output(UInt(params.respBits.W))
+}
+
+class XilinxAXIMaster(params: AXIParams) extends AXIBase(params) {
+  val AWVALID = Output(Bool())
+  val AWREADY = Input(Bool())
+  val AWADDR = Output(UInt(params.addrBits.W))
+  val AWID = Output(UInt(params.idBits.W))
+  val AWUSER = Output(UInt(params.userBits.W))
+  val AWLEN = Output(UInt(params.lenBits.W))
+  val AWSIZE = Output(UInt(params.sizeBits.W))
+  val AWBURST = Output(UInt(params.burstBits.W))
+  val AWLOCK = Output(UInt(params.lockBits.W))
+  val AWCACHE = Output(UInt(params.cacheBits.W))
+  val AWPROT = Output(UInt(params.protBits.W))
+  val AWQOS = Output(UInt(params.qosBits.W))
+  val AWREGION = Output(UInt(params.regionBits.W))
+  val WVALID = Output(Bool())
+  val WREADY = Input(Bool())
+  val WDATA = Output(UInt(params.dataBits.W))
+  val WSTRB = Output(UInt(params.strbBits.W))
+  val WLAST = Output(Bool())
+  val WID = Output(UInt(params.idBits.W))
+  val WUSER = Output(UInt(params.userBits.W))
+  val BVALID = Input(Bool())
+  val BREADY = Output(Bool())
+  val BRESP = Input(UInt(params.respBits.W))
+  val BID = Input(UInt(params.idBits.W))
+  val BUSER = Input(UInt(params.userBits.W))
+  val ARVALID = Output(Bool())
+  val ARREADY = Input(Bool())
+  val ARADDR = Output(UInt(params.addrBits.W))
+  val ARID = Output(UInt(params.idBits.W))
+  val ARUSER = Output(UInt(params.userBits.W))
+  val ARLEN = Output(UInt(params.lenBits.W))
+  val ARSIZE = Output(UInt(params.sizeBits.W))
+  val ARBURST = Output(UInt(params.burstBits.W))
+  val ARLOCK = Output(UInt(params.lockBits.W))
+  val ARCACHE = Output(UInt(params.cacheBits.W))
+  val ARPROT = Output(UInt(params.protBits.W))
+  val ARQOS = Output(UInt(params.qosBits.W))
+  val ARREGION = Output(UInt(params.regionBits.W))
+  val RVALID = Input(Bool())
+  val RREADY = Output(Bool())
+  val RDATA = Input(UInt(params.dataBits.W))
+  val RRESP = Input(UInt(params.respBits.W))
+  val RLAST = Input(Bool())
+  val RID = Input(UInt(params.idBits.W))
+  val RUSER = Input(UInt(params.userBits.W))
+}
diff --git a/vta/hardware/chisel/src/main/scala/shell/Configs.scala b/vta/hardware/chisel/src/main/scala/shell/Configs.scala
new file mode 100644 (file)
index 0000000..1d1d522
--- /dev/null
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.shell
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.interface.axi._
+
+/** PynqConfig. Shell configuration for Pynq */
+class PynqConfig extends Config((site, here, up) => {
+  case ShellKey => ShellParams(
+    hostParams = AXIParams(
+      addrBits = 16,
+      dataBits = 32),
+    memParams = AXIParams(
+      addrBits = 32,
+      dataBits = 64),
+    vcrParams = VCRParams(),
+    vmeParams = VMEParams())
+})
+
+/** F1Config. Shell configuration for F1 */
+class F1Config extends Config((site, here, up) => {
+  case ShellKey => ShellParams(
+    hostParams = AXIParams(
+      addrBits = 16,
+      dataBits = 32),
+    memParams = AXIParams(
+      addrBits = 64,
+      dataBits = 64),
+    vcrParams = VCRParams(),
+    vmeParams = VMEParams())
+})
diff --git a/vta/hardware/chisel/src/main/scala/shell/SimShell.scala b/vta/hardware/chisel/src/main/scala/shell/SimShell.scala
new file mode 100644 (file)
index 0000000..3ad4b65
--- /dev/null
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.shell
+
+import chisel3._
+import vta.util.config._
+import vta.interface.axi._
+import vta.shell._
+import vta.dpi._
+
+/** VTAHost.
+  *
+  * This module translate the DPI protocol into AXI. This is a simulation only
+  * module and used to test host-to-VTA communication. This module should be updated
+  * for testing hosts using a different bus protocol, other than AXI.
+  */
+class VTAHost(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val axi = new AXILiteMaster(p(ShellKey).hostParams)
+  })
+  val host_dpi = Module(new VTAHostDPI)
+  val host_axi = Module(new VTAHostDPIToAXI)
+  host_dpi.io.reset := reset
+  host_dpi.io.clock := clock
+  host_axi.io.dpi <> host_dpi.io.dpi
+  io.axi <> host_axi.io.axi
+}
+
+/** VTAMem.
+  *
+  * This module translate the DPI protocol into AXI. This is a simulation only
+  * module and used to test VTA-to-memory communication. This module should be updated
+  * for testing memories using a different bus protocol, other than AXI.
+  */
+class VTAMem(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val axi = new AXIClient(p(ShellKey).memParams)
+  })
+  val mem_dpi = Module(new VTAMemDPI)
+  val mem_axi = Module(new VTAMemDPIToAXI)
+  mem_dpi.io.reset := reset
+  mem_dpi.io.clock := clock
+  mem_dpi.io.dpi <> mem_axi.io.dpi
+  mem_axi.io.axi <> io.axi
+}
+
+/** SimShell.
+  *
+  * The simulation shell instantiate a host and memory simulation modules and it is
+  * intended to be connected to the VTAShell.
+  */
+class SimShell(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val mem = new AXIClient(p(ShellKey).memParams)
+    val host = new AXILiteMaster(p(ShellKey).hostParams)
+  })
+  val host = Module(new VTAHost)
+  val mem = Module(new VTAMem)
+  io.mem <> mem.io.axi
+  io.host <> host.io.axi
+}
diff --git a/vta/hardware/chisel/src/main/scala/shell/VCR.scala b/vta/hardware/chisel/src/main/scala/shell/VCR.scala
new file mode 100644 (file)
index 0000000..463f55b
--- /dev/null
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.shell
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.util.genericbundle._
+import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.LinkedHashMap
+import vta.interface.axi._
+
+/** VCR parameters.
+  *
+  * These parameters are used on VCR interfaces and modules.
+  */
+case class VCRParams()
+{
+  val nValsReg: Int = 1
+  val nPtrsReg: Int = 6
+  val regBits: Int = 32
+  val nCtrlReg: Int = 4
+  val ctrlBaseAddr: Int = 0
+
+  require (nValsReg > 0)
+  require (nPtrsReg > 0)
+}
+
+/** VCRBase. Parametrize base class. */
+abstract class VCRBase(implicit p: Parameters)
+  extends GenericParameterizedBundle(p)
+
+/** VCRMaster.
+  *
+  * This is the master interface used by VCR in the VTAShell to control
+  * the Core unit.
+  */
+class VCRMaster(implicit p: Parameters) extends VCRBase {
+  val vp = p(ShellKey).vcrParams
+  val mp = p(ShellKey).memParams
+  val launch = Output(Bool())
+  val finish = Input(Bool())
+  val irq = Output(Bool())
+  val ptrs = Output(Vec(vp.nPtrsReg, UInt(mp.addrBits.W)))
+  val vals = Output(Vec(vp.nValsReg, UInt(vp.regBits.W)))
+}
+
+/** VCRClient.
+  *
+  * This is the client interface used by the Core module to communicate
+  * to the VCR in the VTAShell.
+  */
+class VCRClient(implicit p: Parameters) extends VCRBase {
+  val vp = p(ShellKey).vcrParams
+  val mp = p(ShellKey).memParams
+  val launch = Input(Bool())
+  val finish = Output(Bool())
+  val irq = Input(Bool())
+  val ptrs = Input(Vec(vp.nPtrsReg, UInt(mp.addrBits.W)))
+  val vals = Input(Vec(vp.nValsReg, UInt(vp.regBits.W)))
+}
+
+/** VTA Control Registers (VCR).
+  *
+  * This unit provides control registers (32 and 64 bits) to be used by a control'
+  * unit, typically a host processor. These registers are read-only by the core
+  * at the moment but this will likely change once we add support to general purpose
+  * registers that could be used as event counters by the Core unit.
+  */
+class VCR(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle{
+    val host = new AXILiteClient(p(ShellKey).hostParams)
+    val vcr = new VCRMaster
+  })
+
+  val vp = p(ShellKey).vcrParams
+  val mp = p(ShellKey).memParams
+  val hp = p(ShellKey).hostParams
+
+  // Write control (AW, W, B)
+  val waddr = RegInit("h_ffff".U(hp.addrBits.W)) // init with invalid address
+  val wdata = io.host.w.bits.data
+  val wstrb = io.host.w.bits.strb
+  val wmask = Cat(Fill(8, wstrb(3)), Fill(8, wstrb(2)), Fill(8, wstrb(1)), Fill(8, wstrb(0)))
+  val sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(3)
+  val wstate = RegInit(sWriteAddress)
+  switch (wstate) {
+    is (sWriteAddress) {
+      when (io.host.aw.valid) {
+        wstate := sWriteData
+      }
+    }
+    is (sWriteData) {
+      when (io.host.w.valid) {
+        wstate := sWriteResponse
+      }
+    }
+    is (sWriteResponse) {
+      when (io.host.b.ready) {
+        wstate := sWriteAddress
+      }
+    }
+  }
+
+  when (io.host.aw.fire()) { waddr := io.host.aw.bits.addr }
+
+  io.host.aw.ready := wstate === sWriteAddress
+  io.host.w.ready := wstate === sWriteData
+  io.host.b.valid := wstate === sWriteResponse
+  io.host.b.bits.resp := "h_0".U
+
+  // read control (AR, R)
+  val sReadAddress :: sReadData :: Nil = Enum(2)
+  val rstate = RegInit(sReadAddress)
+
+  switch (rstate) {
+    is (sReadAddress) {
+      when (io.host.ar.valid) {
+        rstate := sReadData
+      }
+    }
+    is (sReadData) {
+      when (io.host.r.ready) {
+        rstate := sReadAddress
+      }
+    }
+  }
+
+  io.host.ar.ready := rstate === sReadAddress
+  io.host.r.valid := rstate === sReadData
+
+  val nPtrsReg = vp.nPtrsReg
+  val nValsReg = vp.nValsReg
+  val regBits = vp.regBits
+  val ptrsBits = mp.addrBits
+  val nCtrlReg = vp.nCtrlReg
+  val rStride = regBits/8
+  val pStride = ptrsBits/8
+  val ctrlBaseAddr = vp.ctrlBaseAddr
+  val valsBaseAddr = ctrlBaseAddr + nCtrlReg*rStride
+  val ptrsBaseAddr = valsBaseAddr + nValsReg*rStride
+
+  val ctrlAddr = Seq.tabulate(nCtrlReg)(i => i*rStride + ctrlBaseAddr)
+  val valsAddr = Seq.tabulate(nValsReg)(i => i*rStride + valsBaseAddr)
+
+  val ptrsAddr = new ListBuffer[Int]()
+  for (i <- 0 until nPtrsReg) {
+    ptrsAddr += i*pStride + ptrsBaseAddr
+    if (ptrsBits == 64) {
+      ptrsAddr += i*pStride + rStride + ptrsBaseAddr
+    }
+  }
+
+  // AP register
+  val c0 = RegInit(VecInit(Seq.fill(regBits)(false.B)))
+
+  // ap start
+  when (io.host.w.fire() && waddr === ctrlAddr(0).asUInt && wstrb(0) && wdata(0)) {
+    c0(0) := true.B
+  } .elsewhen (io.vcr.finish) {
+    c0(0) := false.B
+  }
+
+  // ap done = finish
+  when (io.vcr.finish) {
+    c0(1) := true.B
+  } .elsewhen (io.host.ar.fire() && io.host.ar.bits.addr === ctrlAddr(0).asUInt) {
+    c0(1) := false.B
+  }
+
+  val c1 = 0.U
+  val c2 = 0.U
+  val c3 = 0.U
+
+  val ctrlRegList = List(c0, c1, c2, c3)
+
+  io.vcr.launch := c0(0)
+
+  // interrupts not supported atm
+  io.vcr.irq := false.B
+
+  // Write pointer and value registers
+  val pvAddr = valsAddr ++ ptrsAddr
+  val pvNumReg =  if (ptrsBits == 64) nValsReg + nPtrsReg*2 else nValsReg + nPtrsReg
+  val pvReg = RegInit(VecInit(Seq.fill(pvNumReg)(0.U(regBits.W))))
+  val pvRegList = new ListBuffer[UInt]()
+
+  for (i <- 0 until pvNumReg) {
+    when (io.host.w.fire() && (waddr === pvAddr(i).U)) {
+      pvReg(i) := (wdata & wmask) | (pvReg(i) & ~wmask)
+    }
+    pvRegList += pvReg(i)
+  }
+
+  for (i <- 0 until nValsReg) {
+    io.vcr.vals(i) := pvReg(i)
+  }
+
+  for (i <- 0 until nPtrsReg) {
+    if (ptrsBits == 64) {
+      io.vcr.ptrs(i) := Cat(pvReg(nValsReg + i*2 + 1), pvReg(nValsReg + i*2))
+    } else {
+      io.vcr.ptrs(i) := pvReg(nValsReg + i)
+    }
+  }
+
+  // Read pointer and value registers
+  val mapAddr = ctrlAddr ++ valsAddr ++ ptrsAddr
+  val mapRegList = ctrlRegList ++ pvRegList
+
+  val rdata = RegInit(0.U(regBits.W))
+  val rmap = LinkedHashMap[Int,UInt]()
+
+  val totalReg = mapRegList.length
+  for (i <- 0 until totalReg) { rmap += mapAddr(i) -> mapRegList(i).asUInt }
+
+  val decodeAddr = rmap map { case (k, _) => k -> (io.host.ar.bits.addr === k.asUInt) }
+
+  when (io.host.ar.fire()) {
+    rdata := Mux1H(for ((k, v) <- rmap) yield decodeAddr(k) -> v)
+  }
+
+  io.host.r.bits.resp := 0.U
+  io.host.r.bits.data := rdata
+}
diff --git a/vta/hardware/chisel/src/main/scala/shell/VME.scala b/vta/hardware/chisel/src/main/scala/shell/VME.scala
new file mode 100644 (file)
index 0000000..862e981
--- /dev/null
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.shell
+
+import chisel3._
+import chisel3.util._
+import vta.util.config._
+import vta.util.genericbundle._
+import vta.interface.axi._
+
+/** VME parameters.
+  *
+  * These parameters are used on VME interfaces and modules.
+  */
+case class VMEParams() {
+  val nReadClients: Int = 5
+  val nWriteClients: Int = 1
+  require (nReadClients > 0, s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
+  require (nWriteClients == 1, s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
+}
+
+/** VMEBase. Parametrize base class. */
+abstract class VMEBase(implicit p: Parameters)
+  extends GenericParameterizedBundle(p)
+
+/** VMECmd.
+  *
+  * This interface is used for creating write and read requests to memory.
+  */
+class VMECmd(implicit p: Parameters) extends VMEBase {
+  val addrBits = p(ShellKey).memParams.addrBits
+  val lenBits = p(ShellKey).memParams.lenBits
+  val addr = UInt(addrBits.W)
+  val len = UInt(lenBits.W)
+}
+
+/** VMEReadMaster.
+  *
+  * This interface is used by modules inside the core to generate read requests
+  * and receive responses from VME.
+  */
+class VMEReadMaster(implicit p: Parameters) extends Bundle {
+  val dataBits = p(ShellKey).memParams.dataBits
+  val cmd = Decoupled(new VMECmd)
+  val data = Flipped(Decoupled(UInt(dataBits.W)))
+  override def cloneType =
+    new VMEReadMaster().asInstanceOf[this.type]
+}
+
+/** VMEReadClient.
+  *
+  * This interface is used by the VME to receive read requests and generate
+  * responses to modules inside the core.
+  */
+class VMEReadClient(implicit p: Parameters) extends Bundle {
+  val dataBits = p(ShellKey).memParams.dataBits
+  val cmd = Flipped(Decoupled(new VMECmd))
+  val data = Decoupled(UInt(dataBits.W))
+  override def cloneType =
+    new VMEReadClient().asInstanceOf[this.type]
+}
+
+/** VMEWriteMaster.
+  *
+  * This interface is used by modules inside the core to generate write requests
+  * to the VME.
+  */
+class VMEWriteMaster(implicit p: Parameters) extends Bundle {
+  val dataBits = p(ShellKey).memParams.dataBits
+  val cmd = Decoupled(new VMECmd)
+  val data = Decoupled(UInt(dataBits.W))
+  val ack = Input(Bool())
+  override def cloneType =
+    new VMEWriteMaster().asInstanceOf[this.type]
+}
+
+/** VMEWriteClient.
+  *
+  * This interface is used by the VME to handle write requests from modules inside
+  * the core.
+  */
+class VMEWriteClient(implicit p: Parameters) extends Bundle {
+  val dataBits = p(ShellKey).memParams.dataBits
+  val cmd = Flipped(Decoupled(new VMECmd))
+  val data = Flipped(Decoupled(UInt(dataBits.W)))
+  val ack = Output(Bool())
+  override def cloneType =
+    new VMEWriteClient().asInstanceOf[this.type]
+}
+
+/** VMEMaster.
+  *
+  * Pack nRd number of VMEReadMaster interfaces and nWr number of VMEWriteMaster
+  * interfaces.
+  */
+class VMEMaster(implicit p: Parameters) extends Bundle {
+  val nRd = p(ShellKey).vmeParams.nReadClients
+  val nWr = p(ShellKey).vmeParams.nWriteClients
+  val rd = Vec(nRd, new VMEReadMaster)
+  val wr = Vec(nWr, new VMEWriteMaster)
+}
+
+/** VMEClient.
+  *
+  * Pack nRd number of VMEReadClient interfaces and nWr number of VMEWriteClient
+  * interfaces.
+  */
+class VMEClient(implicit p: Parameters) extends Bundle {
+  val nRd = p(ShellKey).vmeParams.nReadClients
+  val nWr = p(ShellKey).vmeParams.nWriteClients
+  val rd = Vec(nRd, new VMEReadClient)
+  val wr = Vec(nWr, new VMEWriteClient)
+}
+
+/** VTA Memory Engine (VME).
+  *
+  * This unit multiplexes the memory controller interface for the Core. Currently,
+  * it supports single-writer and multiple-reader mode and it is also based on AXI.
+  */
+class VME(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {
+    val mem = new AXIMaster(p(ShellKey).memParams)
+    val vme = new VMEClient
+  })
+
+  val nReadClients = p(ShellKey).vmeParams.nReadClients
+  val rd_arb = Module(new Arbiter(new VMECmd, nReadClients))
+  val rd_arb_chosen = RegEnable(rd_arb.io.chosen, rd_arb.io.out.fire())
+
+  for (i <- 0 until nReadClients) { rd_arb.io.in(i) <> io.vme.rd(i).cmd }
+
+  val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3)
+  val rstate = RegInit(sReadIdle)
+
+  switch (rstate) {
+    is (sReadIdle) {
+      when (rd_arb.io.out.valid) {
+        rstate := sReadAddr
+      }
+    }
+    is (sReadAddr) {
+      when (io.mem.ar.ready) {
+        rstate := sReadData
+      }
+    }
+    is (sReadData) {
+      when (io.mem.r.fire() && io.mem.r.bits.last) {
+        rstate := sReadIdle
+      }
+    }
+  }
+
+  val sWriteIdle :: sWriteAddr :: sWriteData :: sWriteResp :: Nil = Enum(4)
+  val wstate = RegInit(sWriteIdle)
+  val addrBits = p(ShellKey).memParams.addrBits
+  val lenBits = p(ShellKey).memParams.lenBits
+  val wr_cnt = RegInit(0.U(lenBits.W))
+
+  when (wstate === sWriteIdle) {
+    wr_cnt := 0.U
+  } .elsewhen (io.mem.w.fire()) {
+    wr_cnt := wr_cnt + 1.U
+  }
+
+  switch (wstate) {
+    is (sWriteIdle) {
+      when (io.vme.wr(0).cmd.valid) {
+        wstate := sWriteAddr
+      }
+    }
+    is (sWriteAddr) {
+      when (io.mem.aw.ready) {
+        wstate := sWriteData
+      }
+    }
+    is (sWriteData) {
+      when (io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) {
+        wstate := sWriteResp
+      }
+    }
+    is (sWriteResp) {
+      when (io.mem.b.valid) {
+        wstate := sWriteIdle
+      }
+    }
+  }
+
+  // registers storing read/write cmds
+
+  val rd_len = RegInit(0.U(lenBits.W))
+  val wr_len = RegInit(0.U(lenBits.W))
+  val rd_addr = RegInit(0.U(addrBits.W))
+  val wr_addr = RegInit(0.U(addrBits.W))
+
+  when (rd_arb.io.out.fire()) {
+    rd_len := rd_arb.io.out.bits.len
+    rd_addr := rd_arb.io.out.bits.addr
+  }
+
+  when (io.vme.wr(0).cmd.fire()) {
+    wr_len := io.vme.wr(0).cmd.bits.len
+    wr_addr := io.vme.wr(0).cmd.bits.addr
+  }
+
+  // rd arb
+  rd_arb.io.out.ready := rstate === sReadIdle
+
+  // vme
+  for (i <- 0 until nReadClients) {
+    io.vme.rd(i).data.valid := rd_arb_chosen === i.asUInt & io.mem.r.valid
+    io.vme.rd(i).data.bits := io.mem.r.bits.data
+  }
+
+  io.vme.wr(0).cmd.ready := wstate === sWriteIdle
+  io.vme.wr(0).ack := io.mem.b.fire()
+  io.vme.wr(0).data.ready := wstate === sWriteData &  io.mem.w.ready
+
+  // mem
+  io.mem.aw.valid := wstate === sWriteAddr
+  io.mem.aw.bits.addr := wr_addr
+  io.mem.aw.bits.len := wr_len
+
+  io.mem.w.valid := wstate === sWriteData & io.vme.wr(0).data.valid
+  io.mem.w.bits.data := io.vme.wr(0).data.bits
+  io.mem.w.bits.last := wr_cnt === io.vme.wr(0).cmd.bits.len
+
+  io.mem.b.ready := wstate === sWriteResp
+
+  io.mem.ar.valid := rstate === sReadAddr
+  io.mem.ar.bits.addr := rd_addr
+  io.mem.ar.bits.len := rd_len
+
+  io.mem.r.ready := rstate === sReadData & io.vme.rd(rd_arb_chosen).data.ready
+
+  // AXI constants - statically defined
+  io.mem.setConst()
+}
diff --git a/vta/hardware/chisel/src/main/scala/shell/VTAShell.scala b/vta/hardware/chisel/src/main/scala/shell/VTAShell.scala
new file mode 100644 (file)
index 0000000..c809311
--- /dev/null
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.shell
+
+import chisel3._
+import vta.util.config._
+import vta.interface.axi._
+import vta.core._
+
+/** Shell parameters. */
+case class ShellParams(
+  hostParams: AXIParams,
+  memParams: AXIParams,
+  vcrParams: VCRParams,
+  vmeParams: VMEParams
+)
+
+case object ShellKey extends Field[ShellParams]
+
+/** VTAShell.
+  *
+  * The VTAShell is based on a VME, VCR and core. This creates a complete VTA
+  * system that can be used for simulation or real hardware.
+  */
+class VTAShell(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle{
+    val host = new AXILiteClient(p(ShellKey).hostParams)
+    val mem = new AXIMaster(p(ShellKey).memParams)
+  })
+
+  val vcr = Module(new VCR)
+  val vme = Module(new VME)
+  val core = Module(new Core)
+
+  core.io.vcr <> vcr.io.vcr
+  vme.io.vme <> core.io.vme
+
+  vcr.io.host <> io.host
+  io.mem <> vme.io.mem
+}
diff --git a/vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala b/vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala
new file mode 100644 (file)
index 0000000..db72137
--- /dev/null
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.shell
+
+import chisel3._
+import chisel3.experimental.{RawModule, withClockAndReset}
+import vta.util.config._
+import vta.interface.axi._
+
+/** XilinxShell.
+  *
+  * This is a wrapper shell mostly used to match Xilinx convention naming,
+  * therefore we can pack VTA as an IP for IPI based flows.
+  */
+class XilinxShell(implicit p: Parameters) extends RawModule {
+
+  val hp = p(ShellKey).hostParams
+  val mp = p(ShellKey).memParams
+
+  val ap_clk = IO(Input(Clock()))
+  val ap_rst_n = IO(Input(Bool()))
+  val m_axi_gmem = IO(new XilinxAXIMaster(mp))
+  val s_axi_control = IO(new XilinxAXILiteClient(hp))
+
+  val shell = withClockAndReset (clock = ap_clk, reset = ~ap_rst_n) { Module(new VTAShell) }
+
+  // memory
+  m_axi_gmem.AWVALID := shell.io.mem.aw.valid
+  shell.io.mem.aw.ready := m_axi_gmem.AWREADY
+  m_axi_gmem.AWADDR := shell.io.mem.aw.bits.addr
+  m_axi_gmem.AWID := shell.io.mem.aw.bits.id
+  m_axi_gmem.AWUSER := shell.io.mem.aw.bits.user
+  m_axi_gmem.AWLEN := shell.io.mem.aw.bits.len
+  m_axi_gmem.AWSIZE := shell.io.mem.aw.bits.size
+  m_axi_gmem.AWBURST := shell.io.mem.aw.bits.burst
+  m_axi_gmem.AWLOCK := shell.io.mem.aw.bits.lock
+  m_axi_gmem.AWCACHE := shell.io.mem.aw.bits.cache
+  m_axi_gmem.AWPROT := shell.io.mem.aw.bits.prot
+  m_axi_gmem.AWQOS := shell.io.mem.aw.bits.qos
+  m_axi_gmem.AWREGION := shell.io.mem.aw.bits.region
+
+  m_axi_gmem.WVALID := shell.io.mem.w.valid
+  shell.io.mem.w.ready := m_axi_gmem.WREADY
+  m_axi_gmem.WDATA := shell.io.mem.w.bits.data
+  m_axi_gmem.WSTRB := shell.io.mem.w.bits.strb
+  m_axi_gmem.WLAST := shell.io.mem.w.bits.last
+  m_axi_gmem.WID := shell.io.mem.w.bits.id
+  m_axi_gmem.WUSER := shell.io.mem.w.bits.user
+
+  shell.io.mem.b.valid := m_axi_gmem.BVALID
+  m_axi_gmem.BREADY := shell.io.mem.b.valid
+  shell.io.mem.b.bits.resp := m_axi_gmem.BRESP
+  shell.io.mem.b.bits.id := m_axi_gmem.BID
+  shell.io.mem.b.bits.user := m_axi_gmem.BUSER
+
+  m_axi_gmem.ARVALID := shell.io.mem.ar.valid
+  shell.io.mem.ar.ready := m_axi_gmem.ARREADY
+  m_axi_gmem.ARADDR := shell.io.mem.ar.bits.addr
+  m_axi_gmem.ARID := shell.io.mem.ar.bits.id
+  m_axi_gmem.ARUSER := shell.io.mem.ar.bits.user
+  m_axi_gmem.ARLEN := shell.io.mem.ar.bits.len
+  m_axi_gmem.ARSIZE := shell.io.mem.ar.bits.size
+  m_axi_gmem.ARBURST := shell.io.mem.ar.bits.burst
+  m_axi_gmem.ARLOCK := shell.io.mem.ar.bits.lock
+  m_axi_gmem.ARCACHE := shell.io.mem.ar.bits.cache
+  m_axi_gmem.ARPROT := shell.io.mem.ar.bits.prot
+  m_axi_gmem.ARQOS := shell.io.mem.ar.bits.qos
+  m_axi_gmem.ARREGION := shell.io.mem.ar.bits.region
+
+  shell.io.mem.r.valid := m_axi_gmem.RVALID
+  m_axi_gmem.RREADY := shell.io.mem.r.ready
+  shell.io.mem.r.bits.data := m_axi_gmem.RDATA
+  shell.io.mem.r.bits.resp := m_axi_gmem.RRESP
+  shell.io.mem.r.bits.last := m_axi_gmem.RLAST
+  shell.io.mem.r.bits.id := m_axi_gmem.RID
+  shell.io.mem.r.bits.user := m_axi_gmem.RUSER
+
+  // host
+  shell.io.host.aw.valid := s_axi_control.AWVALID
+  s_axi_control.AWREADY := shell.io.host.aw.ready
+  shell.io.host.aw.bits.addr := s_axi_control.AWADDR
+
+  shell.io.host.w.valid := s_axi_control.WVALID
+  s_axi_control.WREADY := shell.io.host.w.ready
+  shell.io.host.w.bits.data := s_axi_control.WDATA
+  shell.io.host.w.bits.strb := s_axi_control.WSTRB
+
+  s_axi_control.BVALID := shell.io.host.b.valid
+  shell.io.host.b.ready := s_axi_control.BREADY
+  s_axi_control.BRESP := shell.io.host.b.bits.resp
+
+  shell.io.host.ar.valid := s_axi_control.ARVALID
+  s_axi_control.ARREADY := shell.io.host.ar.ready
+  shell.io.host.ar.bits.addr := s_axi_control.ARADDR
+
+  s_axi_control.RVALID := shell.io.host.r.valid
+  shell.io.host.r.ready := s_axi_control.RREADY
+  s_axi_control.RDATA := shell.io.host.r.bits.data
+  s_axi_control.RRESP := shell.io.host.r.bits.resp
+}
diff --git a/vta/hardware/chisel/src/main/scala/test/Test.scala b/vta/hardware/chisel/src/main/scala/test/Test.scala
new file mode 100644 (file)
index 0000000..db06073
--- /dev/null
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.test
+
+import chisel3._
+import vta.util.config._
+import vta.shell._
+
+/** Test. This generates a testbench file for simulation */
+class Test(implicit p: Parameters) extends Module {
+  val io = IO(new Bundle {})
+  val sim_shell = Module(new SimShell)
+  val vta_shell = Module(new VTAShell)
+  vta_shell.io.host <> sim_shell.io.host
+  sim_shell.io.mem <> vta_shell.io.mem
+}
diff --git a/vta/hardware/chisel/src/main/scala/util/Config.scala b/vta/hardware/chisel/src/main/scala/util/Config.scala
new file mode 100644 (file)
index 0000000..6699507
--- /dev/null
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.util.config
+
+// taken from https://github.com/vta.roject/rocket-chip
+
+abstract class Field[T] private (val default: Option[T])
+{
+  def this() = this(None)
+  def this(default: T) = this(Some(default))
+}
+
+abstract class View {
+  final def apply[T](pname: Field[T]): T = apply(pname, this)
+  final def apply[T](pname: Field[T], site: View): T = {
+    val out = find(pname, site)
+    require (out.isDefined, s"Key ${pname} is not defined in Parameters")
+    out.get
+  }
+
+  final def lift[T](pname: Field[T]): Option[T] = lift(pname, this)
+  final def lift[T](pname: Field[T], site: View): Option[T] = find(pname, site).map(_.asInstanceOf[T])
+
+  protected[config] def find[T](pname: Field[T], site: View): Option[T]
+}
+
+abstract class Parameters extends View {
+  final def ++ (x: Parameters): Parameters =
+    new ChainParameters(this, x)
+
+  final def alter(f: (View, View, View) => PartialFunction[Any,Any]): Parameters =
+    Parameters(f) ++ this
+
+  final def alterPartial(f: PartialFunction[Any,Any]): Parameters =
+    Parameters((_,_,_) => f) ++ this
+
+  final def alterMap(m: Map[Any,Any]): Parameters =
+    new MapParameters(m) ++ this
+
+  protected[config] def chain[T](site: View, tail: View, pname: Field[T]): Option[T]
+  protected[config] def find[T](pname: Field[T], site: View) = chain(site, new TerminalView, pname)
+}
+
+object Parameters {
+  def empty: Parameters = new EmptyParameters
+  def apply(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = new PartialParameters(f)
+}
+
+class Config(p: Parameters) extends Parameters {
+  def this(f: (View, View, View) => PartialFunction[Any,Any]) = this(Parameters(f))
+
+  protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = p.chain(site, tail, pname)
+  override def toString = this.getClass.getSimpleName
+  def toInstance = this
+}
+
+// Internal implementation:
+
+private class TerminalView extends View {
+  def find[T](pname: Field[T], site: View): Option[T] = pname.default
+}
+
+private class ChainView(head: Parameters, tail: View) extends View {
+  def find[T](pname: Field[T], site: View) = head.chain(site, tail, pname)
+}
+
+private class ChainParameters(x: Parameters, y: Parameters) extends Parameters {
+  def chain[T](site: View, tail: View, pname: Field[T]) = x.chain(site, new ChainView(y, tail), pname)
+}
+
+private class EmptyParameters extends Parameters {
+  def chain[T](site: View, tail: View, pname: Field[T]) = tail.find(pname, site)
+}
+
+private class PartialParameters(f: (View, View, View) => PartialFunction[Any,Any]) extends Parameters {
+  protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = {
+    val g = f(site, this, tail)
+    if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T]) else tail.find(pname, site)
+  }
+}
+
+private class MapParameters(map: Map[Any, Any]) extends Parameters {
+  protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = {
+    val g = map.get(pname)
+    if (g.isDefined) Some(g.get.asInstanceOf[T]) else tail.find(pname, site)
+  }
+}
diff --git a/vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala b/vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala
new file mode 100644 (file)
index 0000000..db19635
--- /dev/null
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta.util.genericbundle
+
+// taken from https://github.com/vta.roject/rocket-chip
+
+import chisel3._
+
+abstract class GenericParameterizedBundle[+T <: Object](val params: T) extends Bundle
+{
+  override def cloneType = {
+    try {
+      this.getClass.getConstructors.head.newInstance(params).asInstanceOf[this.type]
+    } catch {
+      case e: java.lang.IllegalArgumentException =>
+        throw new Exception("Unable to use GenericParameterizedBundle.cloneType on " +
+                       this.getClass + ", probably because " + this.getClass +
+                       "() takes more than one argument.  Consider overriding " +
+                       "cloneType() on " + this.getClass, e)
+    }
+  }
+}
+
diff --git a/vta/hardware/chisel/src/main/scala/vta/Configs.scala b/vta/hardware/chisel/src/main/scala/vta/Configs.scala
new file mode 100644 (file)
index 0000000..d5aa127
--- /dev/null
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package vta
+
+import chisel3._
+import vta.util.config._
+import vta.shell._
+import vta.core._
+import vta.test._
+
+/** VTA.
+  *
+  * This file contains all the configurations supported by VTA.
+  * These configurations are built in a mix/match form based on core
+  * and shell configurations.
+  */
+
+class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig)
+class DefaultF1Config extends Config(new CoreConfig ++ new F1Config)
+
+object DefaultPynqConfig extends App {
+  implicit val p: Parameters = new DefaultPynqConfig
+  chisel3.Driver.execute(args, () => new XilinxShell)
+}
+
+object DefaultF1Config extends App {
+  implicit val p: Parameters = new DefaultF1Config
+  chisel3.Driver.execute(args, () => new XilinxShell)
+}
+
+object TestDefaultF1Config extends App {
+  implicit val p: Parameters = new DefaultF1Config
+  chisel3.Driver.execute(args, () => new Test)
+}
index 0895417..0b315e4 100644 (file)
@@ -70,8 +70,18 @@ void VTADPIInit(VTAContextHandle handle,
   _mem_dpi = mem_dpi;
 }
 
+
+// Override Verilator finish definition
+// VL_USER_FINISH needs to be defined when compiling Verilator code
+void vl_finish(const char* filename, int linenum, const char* hier) {
+  Verilated::gotFinish(true);
+  VL_PRINTF("[TSIM] exiting simulation\n");
+}
+
 int VTADPISim(uint64_t max_cycles) {
   uint64_t trace_count = 0;
+  Verilated::flushCall();
+  Verilated::gotFinish(false);
 
 #if VM_TRACE
   uint64_t start = 0;
index d583051..eca9e4d 100644 (file)
@@ -53,7 +53,11 @@ extern "C" {
 typedef void * VTADeviceHandle;
 
 /*! \brief physical address */
+#ifdef USE_TSIM
+typedef uint64_t vta_phy_addr_t;
+#else
 typedef uint32_t vta_phy_addr_t;
+#endif
 
 /*!
  * \brief Allocate a device resource handle
@@ -76,10 +80,22 @@ void VTADeviceFree(VTADeviceHandle handle);
  *
  * \return 0 if running is successful, 1 if timeout.
  */
+#ifdef USE_TSIM
+int VTADeviceRun(VTADeviceHandle device,
+                 vta_phy_addr_t insn_phy_addr,
+                 vta_phy_addr_t uop_phy_addr,
+                 vta_phy_addr_t inp_phy_addr,
+                 vta_phy_addr_t wgt_phy_addr,
+                 vta_phy_addr_t acc_phy_addr,
+                 vta_phy_addr_t out_phy_addr,
+                 uint32_t insn_count,
+                 uint32_t wait_cycles);
+#else
 int VTADeviceRun(VTADeviceHandle device,
                  vta_phy_addr_t insn_phy_addr,
                  uint32_t insn_count,
                  uint32_t wait_cycles);
+#endif
 
 /*!
  * \brief Allocates physically contiguous region in memory (limited by MAX_XFER).
index d5400d8..4c2200d 100644 (file)
@@ -239,7 +239,7 @@ class Environment(object):
         """The target host"""
         if self.TARGET == "pynq":
             return "llvm -target=armv7-none-linux-gnueabihf"
-        if self.TARGET == "sim":
+        if self.TARGET == "sim" or self.TARGET == "tsim":
             return "llvm"
         raise ValueError("Unknown target %s" % self.TARGET)
 
index a1e15ba..858e115 100644 (file)
@@ -17,6 +17,8 @@
 """Utilities to start simulator."""
 import ctypes
 import json
+import sys
+import os
 import tvm
 from ..libinfo import find_libvta
 
@@ -55,5 +57,22 @@ def stats():
     x = tvm.get_global_func("vta.simulator.profiler_status")()
     return json.loads(x)
 
+def tsim_init(hw_lib):
+    """Init hardware shared library for TSIM
+
+     Parameters
+     ------------
+     hw_lib : str
+        Name of hardware shared library
+    """
+    cur_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+    vta_build_path = os.path.join(cur_path, "..", "..", "..", "build")
+    if not hw_lib.endswith(("dylib", "so")):
+        hw_lib += ".dylib" if sys.platform == "darwin" else ".so"
+    lib = os.path.join(vta_build_path, hw_lib)
+    f = tvm.get_global_func("tvm.vta.tsim.init")
+    m = tvm.module.load(lib, "vta-tsim")
+    f(m)
+
 
 LIBS = _load_lib()
index 48dd085..06c700c 100644 (file)
@@ -31,7 +31,7 @@ def run(run_func):
     """
     env = get_env()
 
-    if env.TARGET == "sim":
+    if env.TARGET in ["sim", "tsim"]:
 
         # Talk to local RPC if necessary to debug RPC server.
         # Compile vta on your host with make at the root.
@@ -48,7 +48,8 @@ def run(run_func):
             # Make sure simulation library exists
             # If this fails, build vta on host (make)
             # with TARGET="sim" in the json.config file.
-            assert simulator.enabled()
+            if env.TARGET == "sim":
+                assert simulator.enabled()
             run_func(env, rpc.LocalSession())
 
     elif env.TARGET == "pynq":
index 79a407f..06b3474 100644 (file)
@@ -56,7 +56,7 @@ struct DataBuffer {
     return data_;
   }
   /*! \return Physical address of the data. */
-  uint32_t phy_addr() const {
+  vta_phy_addr_t phy_addr() const {
     return phy_addr_;
   }
   /*!
@@ -113,7 +113,7 @@ struct DataBuffer {
   /*! \brief The internal data. */
   void* data_;
   /*! \brief The physical address of the buffer, excluding header. */
-  uint32_t phy_addr_;
+  vta_phy_addr_t phy_addr_;
 };
 
 /*!
@@ -302,7 +302,7 @@ class BaseQueue {
     return dram_buffer_;
   }
   /*! \return Physical address of DRAM. */
-  uint32_t dram_phy_addr() const {
+  vta_phy_addr_t dram_phy_addr() const {
     return dram_phy_addr_;
   }
   /*! \return Whether there is pending information. */
@@ -367,7 +367,7 @@ class BaseQueue {
   // The buffer in DRAM
   char* dram_buffer_{nullptr};
   // Physics address of the buffer
-  uint32_t dram_phy_addr_;
+  vta_phy_addr_t dram_phy_addr_;
 };
 
 /*!
@@ -424,7 +424,11 @@ class UopQueue : public BaseQueue {
       CHECK((dram_end_ - dram_begin_) == (sram_end_ - sram_begin_));
       insn->memory_type = VTA_MEM_ID_UOP;
       insn->sram_base = sram_begin_;
+#ifdef USE_TSIM
+      insn->dram_base = (uint32_t) dram_phy_addr_ + dram_begin_*kElemBytes;
+#else
       insn->dram_base = dram_phy_addr_ / kElemBytes + dram_begin_;
+#endif
       insn->y_size = 1;
       insn->x_size = (dram_end_ - dram_begin_);
       insn->x_stride = (dram_end_ - dram_begin_);
@@ -958,7 +962,11 @@ class CommandQueue {
     insn->memory_type = dst_memory_type;
     insn->sram_base = dst_sram_index;
     DataBuffer* src = DataBuffer::FromHandle(src_dram_addr);
+#ifdef USE_TSIM
+    insn->dram_base = (uint32_t) src->phy_addr() + src_elem_offset*GetElemBytes(dst_memory_type);
+#else
     insn->dram_base = src->phy_addr() / GetElemBytes(dst_memory_type) + src_elem_offset;
+#endif
     insn->y_size = y_size;
     insn->x_size = x_size;
     insn->x_stride = x_stride;
@@ -981,7 +989,11 @@ class CommandQueue {
     insn->memory_type = src_memory_type;
     insn->sram_base = src_sram_index;
     DataBuffer* dst = DataBuffer::FromHandle(dst_dram_addr);
+#ifdef USE_TSIM
+    insn->dram_base = (uint32_t) dst->phy_addr() + dst_elem_offset*GetElemBytes(src_memory_type);
+#else
     insn->dram_base = dst->phy_addr() / GetElemBytes(src_memory_type) + dst_elem_offset;
+#endif
     insn->y_size = y_size;
     insn->x_size = x_size;
     insn->x_stride = x_stride;
@@ -1046,11 +1058,24 @@ class CommandQueue {
 
     // Make sure that we don't exceed contiguous physical memory limits
     CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) < VTA_MAX_XFER);
+#ifdef USE_TSIM
     int timeout = VTADeviceRun(
         device_,
         insn_queue_.dram_phy_addr(),
+        uop_queue_.dram_phy_addr(),
+        inp_phy_addr_,
+        wgt_phy_addr_,
+        acc_phy_addr_,
+        out_phy_addr_,
         insn_queue_.count(),
         wait_cycles);
+#else
+    int timeout = VTADeviceRun(
+        device_,
+        insn_queue_.dram_phy_addr(),
+        insn_queue_.count(),
+        wait_cycles);
+#endif
     CHECK_EQ(timeout, 0);
     // Reset buffers
     uop_queue_.Reset();
@@ -1125,6 +1150,18 @@ class CommandQueue {
     ThreadLocal().reset();
   }
 
+#ifdef USE_TSIM
+  void SetBufPhyAddr(uint32_t type, vta_phy_addr_t addr) {
+    switch (type) {
+      case VTA_MEM_ID_INP: inp_phy_addr_ = addr;
+      case VTA_MEM_ID_WGT: wgt_phy_addr_ = addr;
+      case VTA_MEM_ID_ACC: acc_phy_addr_ = addr;
+      case VTA_MEM_ID_OUT: out_phy_addr_ = addr;
+      default: break;
+    }
+  }
+#endif
+
  private:
   // Push GEMM uop to the command buffer
   void PushGEMMOp(UopKernel* kernel) {
@@ -1229,6 +1266,16 @@ class CommandQueue {
   InsnQueue<VTA_MAX_XFER, true, true> insn_queue_;
   // Device handle
   VTADeviceHandle device_{nullptr};
+#ifdef USE_TSIM
+  // Input phy addr
+  vta_phy_addr_t inp_phy_addr_{0};
+  // Weight phy addr
+  vta_phy_addr_t wgt_phy_addr_{0};
+  // Accumulator phy addr
+  vta_phy_addr_t acc_phy_addr_{0};
+  // Output phy addr
+  vta_phy_addr_t out_phy_addr_{0};
+#endif
 };
 
 }  // namespace vta
@@ -1317,6 +1364,10 @@ void VTALoadBuffer2D(VTACommandHandle cmd,
                      uint32_t y_pad_after,
                      uint32_t dst_sram_index,
                      uint32_t dst_memory_type) {
+#ifdef USE_TSIM
+  vta::DataBuffer* src = vta::DataBuffer::FromHandle(src_dram_addr);
+  static_cast<vta::CommandQueue*>(cmd)->SetBufPhyAddr(dst_memory_type, src->phy_addr());
+#endif
   static_cast<vta::CommandQueue*>(cmd)->
       LoadBuffer2D(src_dram_addr, src_elem_offset,
                    x_size, y_size, x_stride,
@@ -1333,6 +1384,10 @@ void VTAStoreBuffer2D(VTACommandHandle cmd,
                       uint32_t x_size,
                       uint32_t y_size,
                       uint32_t x_stride) {
+#ifdef USE_TSIM
+  vta::DataBuffer* dst = vta::DataBuffer::FromHandle(dst_dram_addr);
+  static_cast<vta::CommandQueue*>(cmd)->SetBufPhyAddr(src_memory_type, dst->phy_addr());
+#endif
   static_cast<vta::CommandQueue*>(cmd)->
       StoreBuffer2D(src_sram_index, src_memory_type,
                     dst_dram_addr, dst_elem_offset,
diff --git a/vta/src/tsim/tsim_driver.cc b/vta/src/tsim/tsim_driver.cc
new file mode 100644 (file)
index 0000000..e0ceb90
--- /dev/null
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <vta/driver.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+#include <vta/dpi/module.h>
+
+namespace vta {
+namespace tsim {
+
+using vta::dpi::DPIModuleNode;
+using tvm::runtime::Module;
+
+class DPILoader {
+ public:
+  void Init(Module module) {
+    mod_ = module;
+  }
+
+  DPIModuleNode* Get() {
+    return static_cast<DPIModuleNode*>(mod_.operator->());
+  }
+
+  static DPILoader* Global() {
+    static DPILoader inst;
+    return &inst;
+  }
+
+  Module mod_;
+};
+
+class Device {
+ public:
+  Device() {
+    dpi_ = DPILoader::Global();
+  }
+
+  int Run(vta_phy_addr_t insn_phy_addr,
+          vta_phy_addr_t uop_phy_addr,
+          vta_phy_addr_t inp_phy_addr,
+          vta_phy_addr_t wgt_phy_addr,
+          vta_phy_addr_t acc_phy_addr,
+          vta_phy_addr_t out_phy_addr,
+          uint32_t insn_count,
+          uint32_t wait_cycles) {
+    this->Init();
+    this->Launch(insn_phy_addr,
+                 uop_phy_addr,
+                 inp_phy_addr,
+                 wgt_phy_addr,
+                 acc_phy_addr,
+                 out_phy_addr,
+                 insn_count,
+                 wait_cycles);
+    this->WaitForCompletion(wait_cycles);
+    dev_->Finish();
+    return 0;
+  }
+
+ private:
+  void Init() {
+    dev_ = dpi_->Get();
+  }
+
+  void Launch(vta_phy_addr_t insn_phy_addr,
+              vta_phy_addr_t uop_phy_addr,
+              vta_phy_addr_t inp_phy_addr,
+              vta_phy_addr_t wgt_phy_addr,
+              vta_phy_addr_t acc_phy_addr,
+              vta_phy_addr_t out_phy_addr,
+              uint32_t insn_count,
+              uint32_t wait_cycles) {
+    // launch simulation thread
+    dev_->Launch(wait_cycles);
+    dev_->WriteReg(0x10, insn_count);
+    dev_->WriteReg(0x14, insn_phy_addr);
+    dev_->WriteReg(0x18, insn_phy_addr >> 32);
+    dev_->WriteReg(0x1c, 0);
+    dev_->WriteReg(0x20, uop_phy_addr >> 32);
+    dev_->WriteReg(0x24, 0);
+    dev_->WriteReg(0x28, inp_phy_addr >> 32);
+    dev_->WriteReg(0x2c, 0);
+    dev_->WriteReg(0x30, wgt_phy_addr >> 32);
+    dev_->WriteReg(0x34, 0);
+    dev_->WriteReg(0x38, acc_phy_addr >> 32);
+    dev_->WriteReg(0x3c, 0);
+    dev_->WriteReg(0x40, out_phy_addr >> 32);
+    // start
+    dev_->WriteReg(0x00, 0x1);
+  }
+
+  void WaitForCompletion(uint32_t wait_cycles) {
+    uint32_t i, val;
+    for (i = 0; i < wait_cycles; i++) {
+      val = dev_->ReadReg(0x00);
+      val &= 0x2;
+      if (val == 0x2) break;  // finish
+    }
+  }
+
+  DPILoader* dpi_;
+  DPIModuleNode* dev_;
+};
+
+using tvm::runtime::TVMRetValue;
+using tvm::runtime::TVMArgs;
+
+TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    Module m = args[0];
+    DPILoader::Global()->Init(m);
+  });
+
+}  // namespace tsim
+}  // namespace vta
+
+void* VTAMemAlloc(size_t size, int cached) {
+  void *p = malloc(size);
+  return p;
+}
+
+void VTAMemFree(void* buf) {
+  free(buf);
+}
+
+vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
+  return reinterpret_cast<uint64_t>(reinterpret_cast<uint64_t*>(buf));
+}
+
+void VTAFlushCache(vta_phy_addr_t buf, int size) {
+}
+
+void VTAInvalidateCache(vta_phy_addr_t buf, int size) {
+}
+
+VTADeviceHandle VTADeviceAlloc() {
+  return new vta::tsim::Device();
+}
+
+void VTADeviceFree(VTADeviceHandle handle) {
+  delete static_cast<vta::tsim::Device*>(handle);
+}
+
+int VTADeviceRun(VTADeviceHandle handle,
+                 vta_phy_addr_t insn_phy_addr,
+                 vta_phy_addr_t uop_phy_addr,
+                 vta_phy_addr_t inp_phy_addr,
+                 vta_phy_addr_t wgt_phy_addr,
+                 vta_phy_addr_t acc_phy_addr,
+                 vta_phy_addr_t out_phy_addr,
+                 uint32_t insn_count,
+                 uint32_t wait_cycles) {
+  return static_cast<vta::tsim::Device*>(handle)->Run(
+      insn_phy_addr,
+      uop_phy_addr,
+      inp_phy_addr,
+      wgt_phy_addr,
+      acc_phy_addr,
+      out_phy_addr,
+      insn_count,
+      wait_cycles);
+}
index 58835bb..2cedcea 100644 (file)
@@ -68,6 +68,10 @@ def test_save_load_out():
         y_np = x_np.astype(y.dtype)
         x_nd = tvm.nd.array(x_np, ctx)
         y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
+
+        if env.TARGET == "tsim":
+            simulator.tsim_init("libvta_hw")
+
         f(x_nd, y_nd)
         np.testing.assert_equal(y_np, y_nd.asnumpy())
 
@@ -126,6 +130,10 @@ def test_padded_load():
              :] = x_np
         x_nd = tvm.nd.array(x_np, ctx)
         y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
+
+        if env.TARGET == "tsim":
+            simulator.tsim_init("libvta_hw")
+
         f(x_nd, y_nd)
         np.testing.assert_equal(y_np, y_nd.asnumpy())
 
@@ -197,6 +205,9 @@ def test_gemm():
             y_np = np.right_shift(y_np, 8)
             y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
 
+            if env.TARGET == "tsim":
+                simulator.tsim_init("libvta_hw")
+
             if env.TARGET == "sim":
                 simulator.clear_stats()
                 f(x_nd, w_nd, y_nd)
@@ -351,6 +362,10 @@ def test_alu():
             a_nd = tvm.nd.array(a_np, ctx)
             res_nd = tvm.nd.array(
                 np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+
+            if env.TARGET == "tsim":
+                simulator.tsim_init("libvta_hw")
+
             if use_imm:
                 f(a_nd, res_nd)
             else:
@@ -420,6 +435,10 @@ def test_relu():
         a_nd = tvm.nd.array(a_np, ctx)
         res_nd = tvm.nd.array(
             np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+
+        if env.TARGET == "tsim":
+            simulator.tsim_init("libvta_hw")
+
         f(a_nd, res_nd)
         np.testing.assert_equal(res_np, res_nd.asnumpy())
 
@@ -479,6 +498,10 @@ def test_shift_and_scale():
         a_nd = tvm.nd.array(a_np, ctx)
         res_nd = tvm.nd.array(
             np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+
+        if env.TARGET == "tsim":
+            simulator.tsim_init("libvta_hw")
+
         f(a_nd, res_nd)
         np.testing.assert_equal(res_np, res_nd.asnumpy())
 
@@ -503,11 +526,12 @@ if __name__ == "__main__":
     print("Load/store test")
     test_save_load_out()
     print("Padded load test")
-    #test_padded_load()
+    test_padded_load()
     print("GEMM test")
     test_gemm()
-    test_alu()
     print("ALU test")
+    test_alu()
+    print("Relu test")
     test_relu()
     print("Shift and scale")
     test_shift_and_scale()