".clang-format",
".gitmodules",
"CODEOWNERS",
+ ".scalafmt.conf",
}
# List of specific files allowed in relpath to <proj_root>
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+maxColumn = 100
+rewrite.rules = [SortModifiers, SortImports]
lib_path = $(build_dir)/$(LIBNAME).so
endif
-default: lib
+default: lint lib
+
+lint:
+ sbt scalafmt
lib: $(lib_path)
$(lib_path): $(verilator_build_dir)/V$(TOP).cpp
*/
logLevel := Level.Warn
+addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1")
val nVals = 2
val nPtrs = 2
val regBits = 32
- val ptrBits = 2*regBits
+ val ptrBits = 2 * regBits
}
class Accel extends Module {
val raddr = Reg(UInt(config.ptrBits.W))
val waddr = Reg(UInt(config.ptrBits.W))
- switch (state) {
- is (sIdle) {
- when (io.launch) {
+ switch(state) {
+ is(sIdle) {
+ when(io.launch) {
state := sReadReq
}
}
- is (sReadReq) {
+ is(sReadReq) {
state := sReadData
}
- is (sReadData) {
- when (io.mem.rd.valid) {
+ is(sReadData) {
+ when(io.mem.rd.valid) {
state := sWriteReq
}
}
- is (sWriteReq) {
+ is(sWriteReq) {
state := sWriteData
}
- is (sWriteData) {
- when (cnt === (length - 1.U)) {
+ is(sWriteData) {
+ when(cnt === (length - 1.U)) {
state := sIdle
- } .otherwise {
+ }.otherwise {
state := sReadReq
}
}
val last = state === sWriteData && cnt === (length - 1.U)
// cycle counter
- when (state === sIdle) {
+ when(state === sIdle) {
cycles := 0.U
- } .otherwise {
+ }.otherwise {
cycles := cycles + 1.U
}
io.ecnt(0).bits := cycles
// calculate next address
- when (state === sIdle) {
+ when(state === sIdle) {
raddr := io.ptrs(0)
waddr := io.ptrs(1)
- } .elsewhen (state === sWriteData) { // increment by 8-bytes
+ }.elsewhen(state === sWriteData) { // increment by 8-bytes
raddr := raddr + 8.U
waddr := waddr + 8.U
}
io.mem.req.addr := Mux(state === sReadReq, raddr, waddr)
// read
- when (state === sReadData && io.mem.rd.valid) {
+ when(state === sReadData && io.mem.rd.valid) {
reg := io.mem.rd.bits + const
}
io.mem.rd.ready := state === sReadData
io.mem.wr.bits := reg
// count read/write
- when (state === sIdle) {
+ when(state === sIdle) {
cnt := 0.U
- } .elsewhen (state === sWriteData) {
+ }.elsewhen(state === sWriteData) {
cnt := cnt + 1.U
}
val sIdle :: sRead :: Nil = Enum(2)
val state = RegInit(sIdle)
- switch (state) {
- is (sIdle) {
- when (io.host.req.valid && !io.host.req.opcode) {
+ switch(state) {
+ is(sIdle) {
+ when(io.host.req.valid && !io.host.req.opcode) {
state := sRead
}
}
- is (sRead) {
+ is(sRead) {
state := sIdle
}
}
io.host.req.deq := state === sIdle & io.host.req.valid
- val nTotal = config.nCtrl + config.nECnt + config.nVals + (2*config.nPtrs)
- val reg = Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
+ val nTotal = config.nCtrl + config.nECnt + config.nVals + (2 * config.nPtrs)
+ val reg =
+ Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
val addr = Seq.tabulate(nTotal)(_ * 4)
- val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
+ val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val eo = config.nCtrl
val vo = eo + config.nECnt
val po = vo + config.nVals
- when (io.finish) {
+ when(io.finish) {
reg(0) := "b_10".U
- } .elsewhen (state === sIdle && io.host.req.valid &&
- io.host.req.opcode && addr(0).U === io.host.req.addr) {
+ }.elsewhen(state === sIdle && io.host.req.valid &&
+ io.host.req.opcode && addr(0).U === io.host.req.addr) {
reg(0) := io.host.req.value
}
for (i <- 0 until config.nECnt) {
- when (io.ecnt(i).valid) {
+ when(io.ecnt(i).valid) {
reg(eo + i) := io.ecnt(i).bits
- } .elsewhen (state === sIdle && io.host.req.valid &&
- io.host.req.opcode && addr(eo + i).U === io.host.req.addr) {
+ }.elsewhen(state === sIdle && io.host.req.valid &&
+ io.host.req.opcode && addr(eo + i).U === io.host.req.addr) {
reg(eo + i) := io.host.req.value
}
}
- for (i <- 0 until (config.nVals + (2*config.nPtrs))) {
- when (state === sIdle && io.host.req.valid &&
- io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
+ for (i <- 0 until (config.nVals + (2 * config.nPtrs))) {
+ when(
+ state === sIdle && io.host.req.valid &&
+ io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
reg(vo + i) := io.host.req.value
}
}
val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))
- when (state === sIdle && io.host.req.valid && !io.host.req.opcode) {
+ when(state === sIdle && io.host.req.valid && !io.host.req.opcode) {
rdata := MuxLookup(io.host.req.addr, 0.U, reg_map)
}
}
for (i <- 0 until config.nPtrs) {
- io.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i))
+ io.ptrs(i) := Cat(reg(po + (2 * i) + 1), reg(po + (2 * i)))
}
}
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+maxColumn = 100
+rewrite.rules = [SortModifiers, SortImports]
lib_path = $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so
endif
-default: lib
+default: lint lib
+
+lint:
+ sbt scalafmt
lib: $(lib_path)
*/
logLevel := Level.Warn
+addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1")
val sIdle :: sSync :: sExe :: Nil = Enum(3)
val state = RegInit(sIdle)
- val s = Seq.tabulate(2)(_ => Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
+ val s = Seq.tabulate(2)(_ =>
+ Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
val loadUop = Module(new LoadUop)
val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
val dec = Module(new ComputeDecode)
dec.io.inst := inst_q.io.deq.bits
- val inst_type = Cat(dec.io.isFinish,
- dec.io.isAlu,
- dec.io.isGemm,
- dec.io.isLoadAcc,
- dec.io.isLoadUop).asUInt
+ val inst_type =
+ Cat(dec.io.isFinish,
+ dec.io.isAlu,
+ dec.io.isGemm,
+ dec.io.isLoadAcc,
+ dec.io.isLoadUop).asUInt
val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B)
val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
val start = snext & sprev
val done =
- MuxLookup(inst_type,
- false.B, // default
+ MuxLookup(
+ inst_type,
+ false.B, // default
Array(
"h_01".U -> loadUop.io.done,
"h_02".U -> tensorAcc.io.done,
)
// control
- switch (state) {
- is (sIdle) {
- when (start) {
- when (dec.io.isSync) {
+ switch(state) {
+ is(sIdle) {
+ when(start) {
+ when(dec.io.isSync) {
state := sSync
- } .elsewhen (inst_type.orR) {
+ }.elsewhen(inst_type.orR) {
state := sExe
}
}
}
- is (sSync) {
+ is(sSync) {
state := sIdle
}
- is (sExe) {
- when (done) {
+ is(sExe) {
+ when(done) {
state := sIdle
}
}
inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
// uop
- loadUop.io.start := state === sIdle & start & dec.io.isLoadUop
+ loadUop.io.start := state === sIdle & start & dec.io.isLoadUop
loadUop.io.inst := inst_q.io.deq.bits
loadUop.io.baddr := io.uop_baddr
io.vme_rd(0) <> loadUop.io.vme_rd
- loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)
+ loadUop.io.uop.idx <> Mux(dec.io.isGemm,
+ tensorGemm.io.uop.idx,
+ tensorAlu.io.uop.idx)
// acc
tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
tensorAcc.io.inst := inst_q.io.deq.bits
tensorAcc.io.baddr := io.acc_baddr
- tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
- tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
+ tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm,
+ tensorGemm.io.acc.rd.idx,
+ tensorAlu.io.acc.rd.idx)
+ tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm,
+ tensorGemm.io.acc.wr,
+ tensorAlu.io.acc.wr)
io.vme_rd(1) <> tensorAcc.io.vme_rd
// gemm
- tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
+ tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
tensorGemm.io.inst := inst_q.io.deq.bits
tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits
// alu
- tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
+ tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
tensorAlu.io.inst := inst_q.io.deq.bits
tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits
// out
- io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx)
+ io.out.rd.idx <> Mux(dec.io.isGemm,
+ tensorGemm.io.out.rd.idx,
+ tensorAlu.io.out.rd.idx)
io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
// semaphore
// debug
if (debug) {
// start
- when (state === sIdle && start) {
- when (dec.io.isSync) {
+ when(state === sIdle && start) {
+ when(dec.io.isSync) {
printf("[Compute] start sync\n")
- } .elsewhen (dec.io.isLoadUop) {
- printf("[Compute] start load uop\n")
- } .elsewhen (dec.io.isLoadAcc) {
- printf("[Compute] start load acc\n")
- } .elsewhen (dec.io.isGemm) {
- printf("[Compute] start gemm\n")
- } .elsewhen (dec.io.isAlu) {
- printf("[Compute] start alu\n")
- } .elsewhen (dec.io.isFinish) {
- printf("[Compute] start finish\n")
- }
+ }.elsewhen(dec.io.isLoadUop) {
+ printf("[Compute] start load uop\n")
+ }
+ .elsewhen(dec.io.isLoadAcc) {
+ printf("[Compute] start load acc\n")
+ }
+ .elsewhen(dec.io.isGemm) {
+ printf("[Compute] start gemm\n")
+ }
+ .elsewhen(dec.io.isAlu) {
+ printf("[Compute] start alu\n")
+ }
+ .elsewhen(dec.io.isFinish) {
+ printf("[Compute] start finish\n")
+ }
}
// done
- when (state === sSync) {
+ when(state === sSync) {
printf("[Compute] done sync\n")
}
- when (state === sExe) {
- when (done) {
- when (dec.io.isLoadUop) {
+ when(state === sExe) {
+ when(done) {
+ when(dec.io.isLoadUop) {
printf("[Compute] done load uop\n")
- } .elsewhen (dec.io.isLoadAcc) {
- printf("[Compute] done load acc\n")
- } .elsewhen (dec.io.isGemm) {
- printf("[Compute] done gemm\n")
- } .elsewhen (dec.io.isAlu) {
- printf("[Compute] done alu\n")
- } .elsewhen (dec.io.isFinish) {
- printf("[Compute] done finish\n")
- }
+ }.elsewhen(dec.io.isLoadAcc) {
+ printf("[Compute] done load acc\n")
+ }
+ .elsewhen(dec.io.isGemm) {
+ printf("[Compute] done gemm\n")
+ }
+ .elsewhen(dec.io.isAlu) {
+ printf("[Compute] done alu\n")
+ }
+ .elsewhen(dec.io.isFinish) {
+ printf("[Compute] done finish\n")
+ }
}
}
}
* be eventually filled out with class configurations that can be
* mixed/matched with Shell configurations for different backends.
*/
-class CoreConfig extends Config((site, here, up) => {
- case CoreKey => CoreParams(
- batch = 1,
- blockOut = 16,
- blockIn = 16,
- inpBits = 8,
- wgtBits = 8,
- uopBits = 32,
- accBits = 32,
- outBits = 8,
- uopMemDepth = 2048,
- inpMemDepth = 2048,
- wgtMemDepth = 1024,
- accMemDepth = 2048,
- outMemDepth = 2048,
- instQueueEntries = 512)
-})
+class CoreConfig
+ extends Config((site, here, up) => {
+ case CoreKey =>
+ CoreParams(
+ batch = 1,
+ blockOut = 16,
+ blockIn = 16,
+ inpBits = 8,
+ wgtBits = 8,
+ uopBits = 32,
+ accBits = 32,
+ outBits = 8,
+ uopMemDepth = 2048,
+ inpMemDepth = 2048,
+ wgtMemDepth = 1024,
+ accMemDepth = 2048,
+ outMemDepth = 2048,
+ instQueueEntries = 512
+ )
+ })
import vta.shell._
/** Core parameters */
-case class CoreParams (
- batch: Int = 1,
- blockOut: Int = 16,
- blockIn: Int = 16,
- inpBits: Int = 8,
- wgtBits: Int = 8,
- uopBits: Int = 32,
- accBits: Int = 32,
- outBits: Int = 8,
- uopMemDepth: Int = 512,
- inpMemDepth: Int = 512,
- wgtMemDepth: Int = 512,
- accMemDepth: Int = 512,
- outMemDepth: Int = 512,
- instQueueEntries: Int = 32
-)
-{
- require (uopBits % 8 == 0, s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n")
+case class CoreParams(
+ batch: Int = 1,
+ blockOut: Int = 16,
+ blockIn: Int = 16,
+ inpBits: Int = 8,
+ wgtBits: Int = 8,
+ uopBits: Int = 32,
+ accBits: Int = 32,
+ outBits: Int = 8,
+ uopMemDepth: Int = 512,
+ inpMemDepth: Int = 512,
+ wgtMemDepth: Int = 512,
+ accMemDepth: Int = 512,
+ outMemDepth: Int = 512,
+ instQueueEntries: Int = 32
+) {
+ require(uopBits % 8 == 0,
+ s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n")
}
case object CoreKey extends Field[CoreParams]
val isStore = Output(Bool())
})
val csignals =
- ListLookup(io.inst,
- List(N, OP_X),
+ ListLookup(
+ io.inst,
+ List(N, OP_X),
Array(
LUOP -> List(Y, OP_G),
LWGT -> List(Y, OP_L),
* If one would like to add an event counter, then the value of nECnt must be
* changed in VCRParams together with the corresponding counting logic here.
*/
-class EventCounters(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class EventCounters(debug: Boolean = false)(implicit p: Parameters)
+ extends Module {
val vp = p(ShellKey).vcrParams
- val io = IO(new Bundle{
+ val io = IO(new Bundle {
val launch = Input(Bool())
val finish = Input(Bool())
val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W)))
})
val cycle_cnt = RegInit(0.U(vp.regBits.W))
- when (io.launch && !io.finish) {
+ when(io.launch && !io.finish) {
cycle_cnt := cycle_cnt + 1.U
- } .otherwise {
+ }.otherwise {
cycle_cnt := 0.U
}
io.ecnt(0).valid := io.finish
val xrem = Reg(chiselTypeOf(io.ins_count))
val xsize = (io.ins_count << 1.U) - 1.U
val xmax = (1 << mp.lenBits).U
- val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+ val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5)
val state = RegInit(sIdle)
// control
- switch (state) {
- is (sIdle) {
- when (pulse) {
+ switch(state) {
+ is(sIdle) {
+ when(pulse) {
state := sReadCmd
- when (xsize < xmax) {
+ when(xsize < xmax) {
rlen := xsize
ilen := xsize >> 1.U
xrem := 0.U
- } .otherwise {
+ }.otherwise {
rlen := xmax - 1.U
ilen := (xmax >> 1.U) - 1.U
xrem := xsize - xmax
}
}
}
- is (sReadCmd) {
- when (io.vme_rd.cmd.ready) {
+ is(sReadCmd) {
+ when(io.vme_rd.cmd.ready) {
state := sReadLSB
}
}
- is (sReadLSB) {
- when (io.vme_rd.data.valid) {
+ is(sReadLSB) {
+ when(io.vme_rd.data.valid) {
state := sReadMSB
}
}
- is (sReadMSB) {
- when (io.vme_rd.data.valid) {
- when (inst_q.io.count === ilen) {
+ is(sReadMSB) {
+ when(io.vme_rd.data.valid) {
+ when(inst_q.io.count === ilen) {
state := sDrain
- } .otherwise {
+ }.otherwise {
state := sReadLSB
}
}
}
- is (sDrain) {
- when (inst_q.io.count === 0.U) {
- when (xrem === 0.U) {
+ is(sDrain) {
+ when(inst_q.io.count === 0.U) {
+ when(xrem === 0.U) {
state := sIdle
- } .elsewhen (xrem < xmax) {
- state := sReadCmd
- rlen := xrem
- ilen := xrem >> 1.U
- xrem := 0.U
- } .otherwise {
- state := sReadCmd
- rlen := xmax - 1.U
- ilen := (xmax >> 1.U) - 1.U
- xrem := xrem - xmax
- }
+ }.elsewhen(xrem < xmax) {
+ state := sReadCmd
+ rlen := xrem
+ ilen := xrem >> 1.U
+ xrem := 0.U
+ }
+ .otherwise {
+ state := sReadCmd
+ rlen := xmax - 1.U
+ ilen := (xmax >> 1.U) - 1.U
+ xrem := xrem - xmax
+ }
}
}
}
// read instructions from dram
- when (state === sIdle) {
+ when(state === sIdle) {
raddr := io.ins_baddr
- } .elsewhen (state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) {
+ }.elsewhen(state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) {
raddr := raddr + xmax_bytes
}
val msb = io.vme_rd.data.bits
val inst = Cat(msb, lsb)
- when (state === sReadLSB) { lsb := io.vme_rd.data.bits }
+ when(state === sReadLSB) { lsb := io.vme_rd.data.bits }
inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB
inst_q.io.enq.bits := inst
val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt
val deq_ready =
MuxLookup(deq_sel,
- false.B, // default
- Array(
- "h_01".U -> io.inst.ld.ready,
- "h_02".U -> io.inst.st.ready,
- "h_04".U -> io.inst.co.ready
- )
- )
+ false.B, // default
+ Array(
+ "h_01".U -> io.inst.ld.ready,
+ "h_02".U -> io.inst.st.ready,
+ "h_04".U -> io.inst.co.ready
+ ))
// dequeue instruction
inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain
-
// debug
if (debug) {
- when (state === sIdle && pulse) {
+ when(state === sIdle && pulse) {
printf("[Fetch] Launch\n")
}
// instruction
- when (inst_q.io.deq.fire()) {
- when (dec.io.isLoad) {
+ when(inst_q.io.deq.fire()) {
+ when(dec.io.isLoad) {
printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits)
}
- when (dec.io.isCompute) {
+ when(dec.io.isCompute) {
printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits)
}
- when (dec.io.isStore) {
+ when(dec.io.isStore) {
printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits)
}
}
*
* These constants are used for decoding (parsing) fields on instructions.
*/
-trait ISAConstants
-{
- val INST_BITS = 128
+trait ISAConstants {
+ val INST_BITS = 128
- val OP_BITS = 3
+ val OP_BITS = 3
- val M_DEP_BITS = 4
- val M_ID_BITS = 2
- val M_SRAM_OFFSET_BITS = 16
- val M_DRAM_OFFSET_BITS = 32
- val M_SIZE_BITS = 16
- val M_STRIDE_BITS = 16
- val M_PAD_BITS = 4
+ val M_DEP_BITS = 4
+ val M_ID_BITS = 2
+ val M_SRAM_OFFSET_BITS = 16
+ val M_DRAM_OFFSET_BITS = 32
+ val M_SIZE_BITS = 16
+ val M_STRIDE_BITS = 16
+ val M_PAD_BITS = 4
- val C_UOP_BGN_BITS = 13
- val C_UOP_END_BITS = 14
- val C_ITER_BITS = 14
- val C_AIDX_BITS = 11
- val C_IIDX_BITS = 11
- val C_WIDX_BITS = 10
- val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction
- val C_ALU_OP_BITS = 3
- val C_ALU_IMM_BITS = 16
+ val C_UOP_BGN_BITS = 13
+ val C_UOP_END_BITS = 14
+ val C_ITER_BITS = 14
+ val C_AIDX_BITS = 11
+ val C_IIDX_BITS = 11
+ val C_WIDX_BITS = 10
+ val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction
+ val C_ALU_OP_BITS = 3
+ val C_ALU_IMM_BITS = 16
- val Y = true.B
- val N = false.B
+ val Y = true.B
+ val N = false.B
- val OP_L = 0.asUInt(OP_BITS.W)
- val OP_S = 1.asUInt(OP_BITS.W)
- val OP_G = 2.asUInt(OP_BITS.W)
- val OP_F = 3.asUInt(OP_BITS.W)
- val OP_A = 4.asUInt(OP_BITS.W)
- val OP_X = 5.asUInt(OP_BITS.W)
+ val OP_L = 0.asUInt(OP_BITS.W)
+ val OP_S = 1.asUInt(OP_BITS.W)
+ val OP_G = 2.asUInt(OP_BITS.W)
+ val OP_F = 3.asUInt(OP_BITS.W)
+ val OP_A = 4.asUInt(OP_BITS.W)
+ val OP_X = 5.asUInt(OP_BITS.W)
- val ALU_OP_NUM = 5
- val ALU_OP = Enum(ALU_OP_NUM)
+ val ALU_OP_NUM = 5
+ val ALU_OP = Enum(ALU_OP_NUM)
- val M_ID_U = 0.asUInt(M_ID_BITS.W)
- val M_ID_W = 1.asUInt(M_ID_BITS.W)
- val M_ID_I = 2.asUInt(M_ID_BITS.W)
- val M_ID_A = 3.asUInt(M_ID_BITS.W)
+ val M_ID_U = 0.asUInt(M_ID_BITS.W)
+ val M_ID_W = 1.asUInt(M_ID_BITS.W)
+ val M_ID_I = 2.asUInt(M_ID_BITS.W)
+ val M_ID_A = 3.asUInt(M_ID_BITS.W)
}
/** ISA.
* TODO: Add VXOR to clear accumulator
*/
object ISA {
- def LUOP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000")
- def LWGT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000")
- def LINP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000")
- def LACC = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000")
- def SOUT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001")
- def GEMM = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010")
- def VMIN = BitPat("b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
- def VMAX = BitPat("b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
- def VADD = BitPat("b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
- def VSHX = BitPat("b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
- def FNSH = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011")
+ def LUOP =
+ BitPat(
+ "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000")
+ def LWGT =
+ BitPat(
+ "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000")
+ def LINP =
+ BitPat(
+ "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000")
+ def LACC =
+ BitPat(
+ "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000")
+ def SOUT =
+ BitPat(
+ "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001")
+ def GEMM =
+ BitPat(
+ "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010")
+ def VMIN =
+ BitPat(
+ "b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+ def VMAX =
+ BitPat(
+ "b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+ def VADD =
+ BitPat(
+ "b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+ def VSHX =
+ BitPat(
+ "b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+ def FNSH =
+ BitPat(
+ "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011")
}
val tensorType = Seq("inp", "wgt")
val tensorDec = Seq(dec.io.isInput, dec.io.isWeight)
- val tensorLoad = Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i))))
+ val tensorLoad =
+ Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i))))
val start = inst_q.io.deq.valid & Mux(dec.io.pop_next, s.io.sready, true.B)
val done = Mux(dec.io.isInput, tensorLoad(0).io.done, tensorLoad(1).io.done)
// control
- switch (state) {
- is (sIdle) {
- when (start) {
- when (dec.io.isSync) {
+ switch(state) {
+ is(sIdle) {
+ when(start) {
+ when(dec.io.isSync) {
state := sSync
- } .elsewhen (dec.io.isInput || dec.io.isWeight) {
+ }.elsewhen(dec.io.isInput || dec.io.isWeight) {
state := sExe
}
}
}
- is (sSync) {
+ is(sSync) {
state := sIdle
}
- is (sExe) {
- when (done) {
+ is(sExe) {
+ when(done) {
state := sIdle
}
}
// debug
if (debug) {
// start
- when (state === sIdle && start) {
- when (dec.io.isSync) {
+ when(state === sIdle && start) {
+ when(dec.io.isSync) {
printf("[Load] start sync\n")
- } .elsewhen (dec.io.isInput) {
- printf("[Load] start input\n")
- } .elsewhen (dec.io.isWeight) {
- printf("[Load] start weight\n")
- }
+ }.elsewhen(dec.io.isInput) {
+ printf("[Load] start input\n")
+ }
+ .elsewhen(dec.io.isWeight) {
+ printf("[Load] start weight\n")
+ }
}
// done
- when (state === sSync) {
+ when(state === sSync) {
printf("[Load] done sync\n")
}
- when (state === sExe) {
- when (done) {
- when (dec.io.isInput) {
+ when(state === sExe) {
+ when(done) {
+ when(dec.io.isInput) {
printf("[Load] done input\n")
- } .elsewhen (dec.io.isWeight) {
+ }.elsewhen(dec.io.isWeight) {
printf("[Load] done weight\n")
}
}
val xcnt = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
val xlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
val xrem = Reg(chiselTypeOf(dec.xsize))
- val xsize = (dec.xsize >> log2Ceil(numUop)) + dec.xsize(0) + (dec.sram_offset % 2.U) - 1.U
+ val xsize = (dec.xsize >> log2Ceil(numUop)) + dec.xsize(0) + (dec.sram_offset % 2.U) - 1.U
val xmax = (1 << mp.lenBits).U
- val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+ val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val offsetIsEven = (dec.sram_offset % 2.U) === 0.U
val sizeIsEven = (dec.xsize % 2.U) === 0.U
val state = RegInit(sIdle)
// control
- switch (state) {
- is (sIdle) {
- when (io.start) {
+ switch(state) {
+ is(sIdle) {
+ when(io.start) {
state := sReadCmd
- when (xsize < xmax) {
+ when(xsize < xmax) {
xlen := xsize
xrem := 0.U
- } .otherwise {
+ }.otherwise {
xlen := xmax - 1.U
xrem := xsize - xmax
}
}
}
- is (sReadCmd) {
- when (io.vme_rd.cmd.ready) {
+ is(sReadCmd) {
+ when(io.vme_rd.cmd.ready) {
state := sReadData
}
}
- is (sReadData) {
- when (io.vme_rd.data.valid) {
+ is(sReadData) {
+ when(io.vme_rd.data.valid) {
when(xcnt === xlen) {
- when (xrem === 0.U) {
+ when(xrem === 0.U) {
state := sIdle
- } .elsewhen (xrem < xmax) {
- state := sReadCmd
- xlen := xrem
- xrem := 0.U
- } .otherwise {
- state := sReadCmd
- xlen := xmax - 1.U
- xrem := xrem - xmax
- }
+ }.elsewhen(xrem < xmax) {
+ state := sReadCmd
+ xlen := xrem
+ xrem := 0.U
+ }
+ .otherwise {
+ state := sReadCmd
+ xlen := xmax - 1.U
+ xrem := xrem - xmax
+ }
}
}
}
// read-from-dram
val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
- when (state === sIdle) {
- when (offsetIsEven) {
+ when(state === sIdle) {
+ when(offsetIsEven) {
raddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))
- } .otherwise {
- raddr := (io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))) - uopBytes.U
+ }.otherwise {
+ raddr := (io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
+ uopBytes)))) - uopBytes.U
}
- } .elsewhen (state === sReadData && xcnt === xlen && xrem =/= 0.U) {
+ }.elsewhen(state === sReadData && xcnt === xlen && xrem =/= 0.U) {
raddr := raddr + xmax_bytes
}
io.vme_rd.data.ready := state === sReadData
- when (state =/= sReadData) {
+ when(state =/= sReadData) {
xcnt := 0.U
- } .elsewhen (io.vme_rd.data.fire()) {
+ }.elsewhen(io.vme_rd.data.fire()) {
xcnt := xcnt + 1.U
}
val waddr = Reg(UInt(log2Ceil(uopDepth).W))
- when (state === sIdle) {
+ when(state === sIdle) {
waddr := dec.sram_offset >> log2Ceil(numUop)
- } .elsewhen (io.vme_rd.data.fire()) {
+ }.elsewhen(io.vme_rd.data.fire()) {
waddr := waddr + 1.U
}
val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata))
val wmask = Reg(Vec(numUop, Bool()))
- when (offsetIsEven) {
- when (sizeIsEven) {
+ when(offsetIsEven) {
+ when(sizeIsEven) {
wmask := "b_11".U.asTypeOf(wmask)
- } .elsewhen (io.vme_rd.cmd.fire()) {
- when (dec.xsize === 1.U) {
- wmask := "b_01".U.asTypeOf(wmask)
- } .otherwise {
- wmask := "b_11".U.asTypeOf(wmask)
+ }.elsewhen(io.vme_rd.cmd.fire()) {
+ when(dec.xsize === 1.U) {
+ wmask := "b_01".U.asTypeOf(wmask)
+ }.otherwise {
+ wmask := "b_11".U.asTypeOf(wmask)
+ }
}
- } .elsewhen (io.vme_rd.data.fire()) {
- when (xcnt === xlen - 1.U) {
- wmask := "b_01".U.asTypeOf(wmask)
- } .otherwise {
- wmask := "b_11".U.asTypeOf(wmask)
+ .elsewhen(io.vme_rd.data.fire()) {
+ when(xcnt === xlen - 1.U) {
+ wmask := "b_01".U.asTypeOf(wmask)
+ }.otherwise {
+ wmask := "b_11".U.asTypeOf(wmask)
+ }
}
- }
- } .otherwise {
- when (io.vme_rd.cmd.fire()) {
+ }.otherwise {
+ when(io.vme_rd.cmd.fire()) {
wmask := "b_10".U.asTypeOf(wmask)
- } .elsewhen (io.vme_rd.data.fire()) {
- when (sizeIsEven && xcnt === xlen - 1.U) {
+ }.elsewhen(io.vme_rd.data.fire()) {
+ when(sizeIsEven && xcnt === xlen - 1.U) {
wmask := "b_01".U.asTypeOf(wmask)
- } .otherwise {
+ }.otherwise {
wmask := "b_11".U.asTypeOf(wmask)
}
}
}
wdata := io.vme_rd.data.bits.asTypeOf(wdata)
- when (io.vme_rd.data.fire()) {
+ when(io.vme_rd.data.fire()) {
mem.write(waddr, wdata, wmask)
}
// debug
if (debug) {
- when (io.vme_rd.cmd.fire()) {
+ when(io.vme_rd.cmd.fire()) {
printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem)
}
}
* depending on the push and pop fields on instructions to prevent RAW and WAR
* hazards.
*/
-class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) extends Module {
+class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1)
+ extends Module {
val io = IO(new Bundle {
val spost = Input(Bool())
val swait = Input(Bool())
val sready = Output(Bool())
})
val cnt = RegInit(counterInitValue.U(counterBits.W))
- when (io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) { cnt := cnt + 1.U }
- when (!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U }
+ when(io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) {
+ cnt := cnt + 1.U
+ }
+ when(!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U }
io.sready := cnt =/= 0.U
}
val done = tensorStore.io.done
// control
- switch (state) {
- is (sIdle) {
- when (start) {
- when (dec.io.isSync) {
+ switch(state) {
+ is(sIdle) {
+ when(start) {
+ when(dec.io.isSync) {
state := sSync
- } .elsewhen (dec.io.isStore) {
+ }.elsewhen(dec.io.isStore) {
state := sExe
}
}
}
- is (sSync) {
+ is(sSync) {
state := sIdle
}
- is (sExe) {
- when (done) {
+ is(sExe) {
+ when(done) {
state := sIdle
}
}
// debug
if (debug) {
// start
- when (state === sIdle && start) {
- when (dec.io.isSync) {
+ when(state === sIdle && start) {
+ when(dec.io.isSync) {
printf("[Store] start sync\n")
- } .elsewhen (dec.io.isStore) {
+ }.elsewhen(dec.io.isStore) {
printf("[Store] start\n")
}
}
// done
- when (state === sSync) {
+ when(state === sSync) {
printf("[Store] done sync\n")
}
- when (state === sExe) {
- when (done) {
+ when(state === sExe) {
+ when(done) {
printf("[Store] done\n")
}
}
val acc = new TensorMaster(tensorType = "acc")
val out = new TensorMaster(tensorType = "out")
})
- val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil = Enum(6)
+ val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil =
+ Enum(6)
val state = RegInit(sIdle)
val alu = Module(new AluVector)
val dec = io.inst.asTypeOf(new AluDecode)
val src_i = Reg(chiselTypeOf(dec.uop_end))
val done =
state === sExe &
- alu.io.out.data.valid &
- (cnt_o === dec.lp_0 - 1.U) &
- (cnt_i === dec.lp_1 - 1.U) &
- (uop_idx === uop_end - 1.U)
-
- switch (state) {
- is (sIdle) {
- when (io.start) {
+ alu.io.out.data.valid &
+ (cnt_o === dec.lp_0 - 1.U) &
+ (cnt_i === dec.lp_1 - 1.U) &
+ (uop_idx === uop_end - 1.U)
+
+ switch(state) {
+ is(sIdle) {
+ when(io.start) {
state := sReadUop
}
}
- is (sReadUop) {
+ is(sReadUop) {
state := sComputeIdx
}
- is (sComputeIdx) {
+ is(sComputeIdx) {
state := sReadTensorA
}
- is (sReadTensorA) {
+ is(sReadTensorA) {
state := sReadTensorB
}
- is (sReadTensorB) {
+ is(sReadTensorB) {
state := sExe
}
- is (sExe) {
- when (alu.io.out.data.valid) {
- when ((cnt_o === dec.lp_0 - 1.U) &&
- (cnt_i === dec.lp_1 - 1.U) &&
- (uop_idx === uop_end - 1.U)) {
+ is(sExe) {
+ when(alu.io.out.data.valid) {
+ when(
+ (cnt_o === dec.lp_0 - 1.U) &&
+ (cnt_i === dec.lp_1 - 1.U) &&
+ (uop_idx === uop_end - 1.U)) {
state := sIdle
- } .otherwise {
+ }.otherwise {
state := sReadUop
}
}
}
}
- when (state === sIdle ||
- (state === sExe &&
- alu.io.out.data.valid &&
- uop_idx === uop_end - 1.U)) {
+ when(
+ state === sIdle ||
+ (state === sExe &&
+ alu.io.out.data.valid &&
+ uop_idx === uop_end - 1.U)) {
uop_idx := dec.uop_begin
- } .elsewhen (state === sExe && alu.io.out.data.valid) {
+ }.elsewhen(state === sExe && alu.io.out.data.valid) {
uop_idx := uop_idx + 1.U
}
- when (state === sIdle) {
+ when(state === sIdle) {
cnt_o := 0.U
dst_o := 0.U
src_o := 0.U
- } .elsewhen (state === sExe &&
- alu.io.out.data.valid &&
- uop_idx === uop_end - 1.U &&
- cnt_i === dec.lp_1 - 1.U) {
+ }.elsewhen(
+ state === sExe &&
+ alu.io.out.data.valid &&
+ uop_idx === uop_end - 1.U &&
+ cnt_i === dec.lp_1 - 1.U) {
cnt_o := cnt_o + 1.U
dst_o := dst_o + dec.dst_0
src_o := src_o + dec.src_0
}
- when (state === sIdle) {
+ when(state === sIdle) {
cnt_i := 0.U
dst_i := 0.U
src_i := 0.U
- } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
- cnt_i := 0.U
- dst_i := dst_o
- src_i := src_o
- } .elsewhen (state === sExe &&
- alu.io.out.data.valid &&
- uop_idx === uop_end - 1.U) {
- cnt_i := cnt_i + 1.U
- dst_i := dst_i + dec.dst_1
- src_i := src_i + dec.src_1
- }
+ }.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
+ cnt_i := 0.U
+ dst_i := dst_o
+ src_i := src_o
+ }
+ .elsewhen(
+ state === sExe &&
+ alu.io.out.data.valid &&
+ uop_idx === uop_end - 1.U) {
+ cnt_i := cnt_i + 1.U
+ dst_i := dst_i + dec.dst_1
+ src_i := src_i + dec.src_1
+ }
- when (state === sComputeIdx && io.uop.data.valid) {
+ when(state === sComputeIdx && io.uop.data.valid) {
uop_dst := io.uop.data.bits.u0 + dst_i
uop_src := io.uop.data.bits.u1 + src_i
}
// imm
val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
tensorImm.data.valid := state === sReadTensorB
- tensorImm.data.bits.foreach { b => b.foreach { c => c := dec.alu_imm } }
+ tensorImm.data.bits.foreach { b =>
+ b.foreach { c =>
+ c := dec.alu_imm
+ }
+ }
// alu
val isSHR = dec.alu_op === ALU_OP(3)
- val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS-1)
+ val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op))
alu.io.opcode := fixme_alu_op
alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
alu.io.acc_a.data.bits <> io.acc.rd.data.bits
- alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, tensorImm.data.valid, io.acc.rd.data.valid & state === sExe)
- alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, tensorImm.data.bits, io.acc.rd.data.bits)
+ alu.io.acc_b.data.valid := Mux(dec.alu_use_imm,
+ tensorImm.data.valid,
+ io.acc.rd.data.valid & state === sExe)
+ alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm,
+ tensorImm.data.bits,
+ io.acc.rd.data.bits)
// acc_o
io.acc.wr.valid := alu.io.acc_y.data.valid
if (debug) {
- when (state === sReadUop) {
+ when(state === sReadUop) {
printf("[TensorAlu] [uop] idx:%x\n", uop_idx)
}
- when (state === sReadTensorA) {
+ when(state === sReadTensorA) {
printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src)
}
- when (state === sIdle && io.start) {
+ when(state === sIdle && io.start) {
printf(p"[TensorAlu] decode:$dec\n")
}
alu.io.acc_a.data.bits.foreach { tensor =>
- tensor.zipWithIndex.foreach { case(elem, i) =>
- when (alu.io.acc_a.data.valid) {
- printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem)
- }
+ tensor.zipWithIndex.foreach {
+ case (elem, i) =>
+ when(alu.io.acc_a.data.valid) {
+ printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem)
+ }
}
}
alu.io.acc_b.data.bits.foreach { tensor =>
- tensor.zipWithIndex.foreach { case(elem, i) =>
- when (alu.io.acc_b.data.valid) {
- printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem)
- }
+ tensor.zipWithIndex.foreach {
+ case (elem, i) =>
+ when(alu.io.acc_b.data.valid) {
+ printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem)
+ }
}
}
alu.io.acc_y.data.bits.foreach { tensor =>
- tensor.zipWithIndex.foreach { case(elem, i) =>
- when (alu.io.acc_y.data.valid) {
- printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem)
- }
+ tensor.zipWithIndex.foreach {
+ case (elem, i) =>
+ when(alu.io.acc_y.data.valid) {
+ printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem)
+ }
}
}
alu.io.out.data.bits.foreach { tensor =>
- tensor.zipWithIndex.foreach { case(elem, i) =>
- when (alu.io.out.data.valid) {
- printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem)
- }
+ tensor.zipWithIndex.foreach {
+ case (elem, i) =>
+ when(alu.io.out.data.valid) {
+ printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem)
+ }
}
}
}
}
/** Pipelined DotProduct based on MAC and PipeAdder */
-class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module {
- val errorMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
+class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16)
+ extends Module {
+ // Keep this text in sync with the `require` below, which accepts any power of
+ // two that is >= 2 (the previous wording claimed "greater than 4").
+ val errorMsg =
+ s"\n\n[VTA] [DotProduct] size must be at least 2 and a power of 2\n\n"
require(size >= 2 && isPow2(size), errorMsg)
val b = aBits + bBits
val outBits = b + log2Ceil(size) + 1
val b = Input(Vec(size, SInt(bBits.W)))
val y = Output(SInt(outBits.W))
})
- val s = Seq.tabulate(log2Ceil(size + 1))(i => pow(2, log2Ceil(size) - i).toInt) // # of total layers
+ val s = Seq.tabulate(log2Ceil(size + 1))(i =>
+ pow(2, log2Ceil(size) - i).toInt) // # of total layers
val p = log2Ceil(size / 2) + 1 // # of adder layers
val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs
- val a = Seq.tabulate(p)(i =>
- Seq.fill(s(i + 1))(Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1))))
- ) // # adders within each layer
+ val a = Seq.tabulate(p)(
+ i =>
+ Seq.fill(s(i + 1))(Module(new PipeAdder(
+ aBits = (b + i + 1),
+ bBits = (b + i + 1))))) // # adders within each layer
// Vector MACs
for (i <- 0 until s(0)) {
val inpBits = p(CoreKey).inpBits
val wgtBits = p(CoreKey).wgtBits
val outBits = p(CoreKey).outBits
- val io = IO(new Bundle{
+ val io = IO(new Bundle {
val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
val inp = new TensorMasterData(tensorType = "inp")
val wgt = new TensorMasterData(tensorType = "wgt")
val acc_o = new TensorClientData(tensorType = "acc")
val out = new TensorClientData(tensorType = "out")
})
- val dot = Seq.fill(size)(Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
- val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
+ val dot = Seq.fill(size)(
+ Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
+ val acc = Seq.fill(size)(
+ Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
val add = Seq.fill(size)(Wire(SInt(accBits.W)))
val vld = Wire(Vec(size, Bool()))
* Also, the TensorGemm uses the reset field in the Gemm instruction to
* clear or zero-out the acc-scratchpad locations based on the micro-ops.
*/
-class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
+ extends Module {
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
val acc = new TensorMaster(tensorType = "acc")
val out = new TensorMaster(tensorType = "out")
})
- val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6)
+ val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil =
+ Enum(6)
val state = RegInit(sIdle)
val mvc = Module(new MatrixVectorMultiplication)
val dec = io.inst.asTypeOf(new GemmDecode)
val inflight = Reg(UInt(pBits.W))
val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits))
val done = inflight === 0.U &
- ((state === sExe &
- cnt_o === dec.lp_0 - 1.U &
- cnt_i === dec.lp_1 - 1.U &
- uop_idx === uop_end - 1.U &
- inflight === 0.U) |
- state === sWait)
-
- switch (state) {
- is (sIdle) {
- when (io.start) {
+ ((state === sExe &
+ cnt_o === dec.lp_0 - 1.U &
+ cnt_i === dec.lp_1 - 1.U &
+ uop_idx === uop_end - 1.U &
+ inflight === 0.U) |
+ state === sWait)
+
+ switch(state) {
+ is(sIdle) {
+ when(io.start) {
state := sReadUop
}
}
- is (sReadUop) {
+ is(sReadUop) {
state := sComputeIdx
}
- is (sComputeIdx) {
+ is(sComputeIdx) {
state := sReadTensor
}
- is (sReadTensor) {
+ is(sReadTensor) {
state := sExe
}
- is (sExe) {
- when ((cnt_o === dec.lp_0 - 1.U) &&
- (cnt_i === dec.lp_1 - 1.U) &&
- (uop_idx === uop_end - 1.U)) {
- when (inflight =/= 0.U) {
+ is(sExe) {
+ when(
+ (cnt_o === dec.lp_0 - 1.U) &&
+ (cnt_i === dec.lp_1 - 1.U) &&
+ (uop_idx === uop_end - 1.U)) {
+ when(inflight =/= 0.U) {
state := sWait
- } .otherwise {
+ }.otherwise {
state := sIdle
}
- } .otherwise {
+ }.otherwise {
state := sReadUop
}
}
- is (sWait) {
- when (inflight === 0.U) {
+ is(sWait) {
+ when(inflight === 0.U) {
state := sIdle
}
}
}
- when (state === sIdle) {
+ when(state === sIdle) {
inflight := 0.U
- } .elsewhen (!dec.reset) {
- when (state === sReadTensor) { // issue a tensor
+ }.elsewhen(!dec.reset) {
+ when(state === sReadTensor) { // issue a tensor
inflight := inflight + 1.U
- } .elsewhen (mvc.io.acc_o.data.valid) { // commit a tensor
+ }.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor
inflight := inflight - 1.U
}
}
- when (state === sIdle ||
- (state === sExe &&
- uop_idx === uop_end - 1.U)) {
+ when(
+ state === sIdle ||
+ (state === sExe &&
+ uop_idx === uop_end - 1.U)) {
uop_idx := dec.uop_begin
- } .elsewhen (state === sExe) {
+ }.elsewhen(state === sExe) {
uop_idx := uop_idx + 1.U
}
- when (state === sIdle) {
+ when(state === sIdle) {
cnt_o := 0.U
acc_o := 0.U
inp_o := 0.U
wgt_o := 0.U
- } .elsewhen (state === sExe &&
- uop_idx === uop_end - 1.U &&
- cnt_i === dec.lp_1 - 1.U) {
+ }.elsewhen(
+ state === sExe &&
+ uop_idx === uop_end - 1.U &&
+ cnt_i === dec.lp_1 - 1.U) {
cnt_o := cnt_o + 1.U
acc_o := acc_o + dec.acc_0
inp_o := inp_o + dec.inp_0
wgt_o := wgt_o + dec.wgt_0
}
- when (state === sIdle) {
+ when(state === sIdle) {
cnt_i := 0.U
acc_i := 0.U
inp_i := 0.U
wgt_i := 0.U
- } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
- cnt_i := 0.U
- acc_i := acc_o
- inp_i := inp_o
- wgt_i := wgt_o
- } .elsewhen (state === sExe &&
- uop_idx === uop_end - 1.U) {
- cnt_i := cnt_i + 1.U
- acc_i := acc_i + dec.acc_1
- inp_i := inp_i + dec.inp_1
- wgt_i := wgt_i + dec.wgt_1
- }
+ }.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
+ cnt_i := 0.U
+ acc_i := acc_o
+ inp_i := inp_o
+ wgt_i := wgt_o
+ }
+ .elsewhen(state === sExe &&
+ uop_idx === uop_end - 1.U) {
+ cnt_i := cnt_i + 1.U
+ acc_i := acc_i + dec.acc_1
+ inp_i := inp_i + dec.inp_1
+ wgt_i := wgt_i + dec.wgt_1
+ }
- when (state === sComputeIdx && io.uop.data.valid) {
+ when(state === sComputeIdx && io.uop.data.valid) {
uop_acc := io.uop.data.bits.u0 + acc_i
uop_inp := io.uop.data.bits.u1 + inp_i
uop_wgt := io.uop.data.bits.u2 + wgt_i
mvc.io.acc_i.data <> io.acc.rd.data
// acc_o
- io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset, true.B, wrpipe.io.deq.valid)
+ io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset,
+ true.B,
+ wrpipe.io.deq.valid)
io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
io.acc.wr.bits.data <> mvc.io.acc_o.data.bits
io.done := done
if (debug) {
- when (state === sReadUop && ~dec.reset) {
+ when(state === sReadUop && ~dec.reset) {
printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
}
- when (state === sReadTensor && ~dec.reset) {
- printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt)
+ when(state === sReadTensor && ~dec.reset) {
+ printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n",
+ uop_acc,
+ uop_inp,
+ uop_wgt)
}
- io.inp.rd.data.bits.zipWithIndex.foreach { case(r, i) =>
- when (io.inp.rd.data.valid && ~dec.reset) {
- printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
- }
+ io.inp.rd.data.bits.zipWithIndex.foreach {
+ case (r, i) =>
+ when(io.inp.rd.data.valid && ~dec.reset) {
+ printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
+ }
}
- io.wgt.rd.data.bits.zipWithIndex.foreach { case(r, i) =>
- when (io.wgt.rd.data.valid && ~dec.reset) {
- printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
- }
+ io.wgt.rd.data.bits.zipWithIndex.foreach {
+ case (r, i) =>
+ when(io.wgt.rd.data.valid && ~dec.reset) {
+ printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
+ }
}
io.acc.rd.data.bits.foreach { tensor =>
- tensor.zipWithIndex.foreach { case(elem, i) =>
- when (io.acc.rd.data.valid && ~dec.reset) {
- printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
- }
+ tensor.zipWithIndex.foreach {
+ case (elem, i) =>
+ when(io.acc.rd.data.valid && ~dec.reset) {
+ printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
+ }
}
}
mvc.io.acc_o.data.bits.foreach { tensor =>
- tensor.zipWithIndex.foreach { case(elem, i) =>
- when (mvc.io.acc_o.data.valid && ~dec.reset) {
- printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
- }
+ tensor.zipWithIndex.foreach {
+ case (elem, i) =>
+ when(mvc.io.acc_o.data.valid && ~dec.reset) {
+ printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
+ }
}
}
mvc.io.out.data.bits.foreach { tensor =>
- tensor.zipWithIndex.foreach { case(elem, i) =>
- when (mvc.io.out.data.valid && ~dec.reset) {
- printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
- }
+ tensor.zipWithIndex.foreach {
+ case (elem, i) =>
+ when(mvc.io.out.data.valid && ~dec.reset) {
+ printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
+ }
}
}
}
* managed by TensorPadCtrl. The TensorDataCtrl is in charge of
* handling the way tensors are stored on the scratchpads.
*/
-class TensorLoad(tensorType: String = "none", debug: Boolean = false)
- (implicit p: Parameters) extends Module {
+class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
+ implicit p: Parameters)
+ extends Module {
val tp = new TensorParams(tensorType)
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val strideFactor = tp.tensorLength * tp.tensorWidth
val dec = io.inst.asTypeOf(new MemDecode)
- val dataCtrl = Module(new TensorDataCtrl(tensorType, sizeFactor, strideFactor))
+ val dataCtrl = Module(
+ new TensorDataCtrl(tensorType, sizeFactor, strideFactor))
val dataCtrlDone = RegInit(false.B)
val yPadCtrl0 = Module(new TensorPadCtrl(padType = "YPad0", sizeFactor))
val yPadCtrl1 = Module(new TensorPadCtrl(padType = "YPad1", sizeFactor))
val tag = Reg(UInt(log2Ceil(tp.numMemBlock).W))
val set = Reg(UInt(log2Ceil(tp.tensorLength).W))
- val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil = Enum(7)
+ val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil =
+ Enum(7)
val state = RegInit(sIdle)
// control
- switch (state) {
- is (sIdle) {
- when (io.start) {
- when (dec.ypad_0 =/= 0.U) {
+ switch(state) {
+ is(sIdle) {
+ when(io.start) {
+ when(dec.ypad_0 =/= 0.U) {
state := sYPad0
- } .elsewhen (dec.xpad_0 =/= 0.U) {
- state := sXPad0
- } .otherwise {
- state := sReadCmd
- }
+ }.elsewhen(dec.xpad_0 =/= 0.U) {
+ state := sXPad0
+ }
+ .otherwise {
+ state := sReadCmd
+ }
}
}
- is (sYPad0) {
- when (yPadCtrl0.io.done) {
- when (dec.xpad_0 =/= 0.U) {
+ is(sYPad0) {
+ when(yPadCtrl0.io.done) {
+ when(dec.xpad_0 =/= 0.U) {
state := sXPad0
- } .otherwise {
+ }.otherwise {
state := sReadCmd
}
}
}
- is (sXPad0) {
- when (xPadCtrl0.io.done) {
+ is(sXPad0) {
+ when(xPadCtrl0.io.done) {
state := sReadCmd
}
}
- is (sReadCmd) {
- when (io.vme_rd.cmd.ready) {
+ is(sReadCmd) {
+ when(io.vme_rd.cmd.ready) {
state := sReadData
}
}
- is (sReadData) {
- when (io.vme_rd.data.valid) {
- when (dataCtrl.io.done) {
- when (dec.xpad_1 =/= 0.U) {
+ is(sReadData) {
+ when(io.vme_rd.data.valid) {
+ when(dataCtrl.io.done) {
+ when(dec.xpad_1 =/= 0.U) {
state := sXPad1
- } .elsewhen (dec.ypad_1 =/= 0.U) {
- state := sYPad1
- } .otherwise {
- state := sIdle
- }
- } .elsewhen (dataCtrl.io.stride || dataCtrl.io.split) {
- when (dec.xpad_1 =/= 0.U) {
+ }.elsewhen(dec.ypad_1 =/= 0.U) {
+ state := sYPad1
+ }
+ .otherwise {
+ state := sIdle
+ }
+ }.elsewhen(dataCtrl.io.stride || dataCtrl.io.split) {
+ when(dec.xpad_1 =/= 0.U) {
state := sXPad1
- } .elsewhen (dec.xpad_0 =/= 0.U) {
- state := sXPad0
- } .otherwise {
+ }.elsewhen(dec.xpad_0 =/= 0.U) {
+ state := sXPad0
+ }
+ .otherwise {
state := sReadCmd
- }
+ }
}
}
}
- is (sXPad1) {
- when (xPadCtrl1.io.done) {
- when (dataCtrlDone) {
- when (dec.ypad_1 =/= 0.U) {
+ is(sXPad1) {
+ when(xPadCtrl1.io.done) {
+ when(dataCtrlDone) {
+ when(dec.ypad_1 =/= 0.U) {
state := sYPad1
- } .otherwise {
+ }.otherwise {
state := sIdle
}
- } .otherwise {
- when (dec.xpad_0 =/= 0.U) {
+ }.otherwise {
+ when(dec.xpad_0 =/= 0.U) {
state := sXPad0
- } .otherwise {
+ }.otherwise {
state := sReadCmd
}
}
}
}
- is (sYPad1) {
- when (yPadCtrl1.io.done && dataCtrlDone) {
+ is(sYPad1) {
+ when(yPadCtrl1.io.done && dataCtrlDone) {
state := sIdle
}
}
dataCtrl.io.xupdate := io.vme_rd.data.fire()
dataCtrl.io.yupdate := io.vme_rd.data.fire()
- when (state === sIdle) {
+ when(state === sIdle) {
dataCtrlDone := false.B
- } .elsewhen (io.vme_rd.data.fire() && dataCtrl.io.done) {
+ }.elsewhen(io.vme_rd.data.fire() && dataCtrl.io.done) {
dataCtrlDone := true.B
}
yPadCtrl0.io.start := dec.ypad_0 =/= 0.U & state === sIdle & io.start
yPadCtrl1.io.start := dec.ypad_1 =/= 0.U &
- ((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) |
- (state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone))
+ ((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) |
+ (state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone))
xPadCtrl0.io.start := dec.xpad_0 =/= 0.U &
- ((state === sIdle & io.start) |
- (state === sYPad0 & yPadCtrl0.io.done) |
- (io.vme_rd.data.fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) |
- (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))
+ ((state === sIdle & io.start) |
+ (state === sYPad0 & yPadCtrl0.io.done) |
+ (io.vme_rd.data
+ .fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) |
+ (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))
xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() &
- ((dataCtrl.io.done) |
- (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U))
+ ((dataCtrl.io.done) |
+ (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U))
yPadCtrl0.io.inst := io.inst
yPadCtrl1.io.inst := io.inst
// write-to-sram
val isZeroPad = state === sYPad0 |
- state === sXPad0 |
- state === sXPad1 |
- state === sYPad1
+ state === sXPad0 |
+ state === sXPad1 |
+ state === sYPad1
- when (state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
+ when(state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
tag := 0.U
- } .elsewhen (io.vme_rd.data.fire() || isZeroPad) {
+ }.elsewhen(io.vme_rd.data.fire() || isZeroPad) {
tag := tag + 1.U
}
- when (state === sIdle || dataCtrlDone || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
+ when(
+ state === sIdle || dataCtrlDone || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
set := 0.U
- } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) {
+ }.elsewhen(
+ (io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) {
set := set + 1.U
}
val waddr_cur = Reg(UInt(tp.memAddrBits.W))
val waddr_nxt = Reg(UInt(tp.memAddrBits.W))
- when (state === sIdle) {
+ when(state === sIdle) {
waddr_cur := dec.sram_offset
waddr_nxt := dec.sram_offset
- } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) {
- waddr_cur := waddr_cur + 1.U
- } .elsewhen (dataCtrl.io.stride) {
- waddr_cur := waddr_nxt + dec.xsize
- waddr_nxt := waddr_nxt + dec.xsize
- }
+ }.elsewhen((io.vme_rd.data
+ .fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) {
+ waddr_cur := waddr_cur + 1.U
+ }
+ .elsewhen(dataCtrl.io.stride) {
+ waddr_cur := waddr_nxt + dec.xsize
+ waddr_nxt := waddr_nxt + dec.xsize
+ }
- val tensorFile = Seq.fill(tp.tensorLength) { SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
+ val tensorFile = Seq.fill(tp.tensorLength) {
+ SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
+ }
val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) }
- val wdata = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
+ val wdata = Seq.fill(tp.tensorLength) {
+ Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
+ }
val no_mask = Wire(Vec(tp.numMemBlock, Bool()))
- no_mask.foreach { m => m := true.B }
+ no_mask.foreach { m =>
+ m := true.B
+ }
for (i <- 0 until tp.tensorLength) {
for (j <- 0 until tp.numMemBlock) {
wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits)
}
val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i))
- val muxWen = Mux(state === sIdle, io.tensor.wr.valid, (io.vme_rd.data.fire() | isZeroPad) & set === i.U)
+ val muxWen =
+ Mux(state === sIdle,
+ io.tensor.wr.valid,
+ (io.vme_rd.data.fire() | isZeroPad) & set === i.U)
val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur)
val muxWdata = Mux(state === sIdle, tdata, wdata(i))
val muxWmask = Mux(state === sIdle, no_mask, wmask(i))
- when (muxWen) {
+ when(muxWen) {
tensorFile(i).write(muxWaddr, muxWdata, muxWmask)
}
}
val rvalid = RegNext(io.tensor.rd.idx.valid)
io.tensor.rd.data.valid := rvalid
- val rdata = tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid))
- rdata.zipWithIndex.foreach { case(r, i) =>
- io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i))
+ val rdata =
+ tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid))
+ rdata.zipWithIndex.foreach {
+ case (r, i) =>
+ io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i))
}
// done
- val done_no_pad = io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U
+ val done_no_pad = io.vme_rd.data
+ .fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U
val done_x_pad = state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone & dec.ypad_1 === 0.U
val done_y_pad = state === sYPad1 & dataCtrlDone & yPadCtrl1.io.done
io.done := done_no_pad | done_x_pad | done_y_pad
// debug
if (debug) {
if (tensorType == "inp") {
- when (io.vme_rd.cmd.fire()) {
- printf("[TensorLoad] [inp] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
+ when(io.vme_rd.cmd.fire()) {
+ printf("[TensorLoad] [inp] cmd addr:%x len:%x\n",
+ dataCtrl.io.addr,
+ dataCtrl.io.len)
}
- when (state === sYPad0) {
+ when(state === sYPad0) {
printf("[TensorLoad] [inp] sYPad0\n")
}
- when (state === sYPad1) {
+ when(state === sYPad1) {
printf("[TensorLoad] [inp] sYPad1\n")
}
- when (state === sXPad0) {
+ when(state === sXPad0) {
printf("[TensorLoad] [inp] sXPad0\n")
}
- when (state === sXPad1) {
+ when(state === sXPad1) {
printf("[TensorLoad] [inp] sXPad1\n")
}
} else if (tensorType == "wgt") {
- when (io.vme_rd.cmd.fire()) {
- printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
+ when(io.vme_rd.cmd.fire()) {
+ printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n",
+ dataCtrl.io.addr,
+ dataCtrl.io.len)
}
} else if (tensorType == "acc") {
- when (io.vme_rd.cmd.fire()) {
- printf("[TensorLoad] [acc] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
+ when(io.vme_rd.cmd.fire()) {
+ printf("[TensorLoad] [acc] cmd addr:%x len:%x\n",
+ dataCtrl.io.addr,
+ dataCtrl.io.len)
}
}
}
*
* Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM).
*/
-class TensorStore(tensorType: String = "none", debug: Boolean = false)
- (implicit p: Parameters) extends Module {
+class TensorStore(tensorType: String = "none", debug: Boolean = false)(
+ implicit p: Parameters)
+ extends Module {
val tp = new TensorParams(tensorType)
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
val xrem = Reg(chiselTypeOf(dec.xsize))
- val xsize = (dec.xsize << log2Ceil(tensorLength*numMemBlock)) - 1.U
+ val xsize = (dec.xsize << log2Ceil(tensorLength * numMemBlock)) - 1.U
val xmax = (1 << mp.lenBits).U
- val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+ val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val ycnt = Reg(chiselTypeOf(dec.ysize))
val ysize = dec.ysize
val tag = Reg(UInt(8.W))
val state = RegInit(sIdle)
// control
- switch (state) {
- is (sIdle) {
- when (io.start) {
+ switch(state) {
+ is(sIdle) {
+ when(io.start) {
state := sWriteCmd
- when (xsize < xmax) {
+ when(xsize < xmax) {
xlen := xsize
xrem := 0.U
- } .otherwise {
+ }.otherwise {
xlen := xmax - 1.U
xrem := xsize - xmax
}
}
}
- is (sWriteCmd) {
- when (io.vme_wr.cmd.ready) {
+ is(sWriteCmd) {
+ when(io.vme_wr.cmd.ready) {
state := sWriteData
}
}
- is (sWriteData) {
- when (io.vme_wr.data.ready) {
- when (xcnt === xlen) {
+ is(sWriteData) {
+ when(io.vme_wr.data.ready) {
+ when(xcnt === xlen) {
state := sWriteAck
- } .elsewhen (tag === (numMemBlock - 1).U) {
+ }.elsewhen(tag === (numMemBlock - 1).U) {
state := sReadMem
}
}
}
- is (sReadMem) {
+ is(sReadMem) {
state := sWriteData
}
- is (sWriteAck) {
- when (io.vme_wr.ack) {
- when (xrem === 0.U) {
- when (ycnt === ysize - 1.U) {
+ is(sWriteAck) {
+ when(io.vme_wr.ack) {
+ when(xrem === 0.U) {
+ when(ycnt === ysize - 1.U) {
state := sIdle
- } .otherwise {
+ }.otherwise {
state := sWriteCmd
- when (xsize < xmax) {
+ when(xsize < xmax) {
xlen := xsize
xrem := 0.U
- } .otherwise {
+ }.otherwise {
xlen := xmax - 1.U
xrem := xsize - xmax
}
}
- } .elsewhen (xrem < xmax) {
- state := sWriteCmd
- xlen := xrem
- xrem := 0.U
- } .otherwise {
- state := sWriteCmd
- xlen := xmax - 1.U
- xrem := xrem - xmax
- }
+ }.elsewhen(xrem < xmax) {
+ state := sWriteCmd
+ xlen := xrem
+ xrem := 0.U
+ }
+ .otherwise {
+ state := sWriteCmd
+ xlen := xmax - 1.U
+ xrem := xrem - xmax
+ }
}
}
}
// write-to-sram
- val tensorFile = Seq.fill(tensorLength) { SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W))) }
+ val tensorFile = Seq.fill(tensorLength) {
+ SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W)))
+ }
val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W)))
val no_mask = Wire(Vec(numMemBlock, Bool()))
wdata_t := DontCare
- no_mask.foreach { m => m := true.B }
+ no_mask.foreach { m =>
+ m := true.B
+ }
for (i <- 0 until tensorLength) {
val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t)
- when (io.tensor.wr.valid) {
+ when(io.tensor.wr.valid) {
tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask)
}
}
// read-from-sram
val stride = state === sWriteAck &
- io.vme_wr.ack &
- xcnt === xlen + 1.U &
- xrem === 0.U &
- ycnt =/= ysize - 1.U
+ io.vme_wr.ack &
+ xcnt === xlen + 1.U &
+ xrem === 0.U &
+ ycnt =/= ysize - 1.U
- when (state === sIdle) {
+ when(state === sIdle) {
ycnt := 0.U
- } .elsewhen (stride) {
+ }.elsewhen(stride) {
ycnt := ycnt + 1.U
}
- when (state === sWriteCmd || tag === (numMemBlock - 1).U) {
+ when(state === sWriteCmd || tag === (numMemBlock - 1).U) {
tag := 0.U
- } .elsewhen (io.vme_wr.data.fire()) {
+ }.elsewhen(io.vme_wr.data.fire()) {
tag := tag + 1.U
}
- when (state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
+ when(
+ state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
set := 0.U
- } .elsewhen (io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
+ }.elsewhen(io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
set := set + 1.U
}
val raddr_cur = Reg(UInt(tp.memAddrBits.W))
val raddr_nxt = Reg(UInt(tp.memAddrBits.W))
- when (state === sIdle) {
+ when(state === sIdle) {
raddr_cur := dec.sram_offset
raddr_nxt := dec.sram_offset
- } .elsewhen (io.vme_wr.data.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) {
- raddr_cur := raddr_cur + 1.U
- } .elsewhen (stride) {
- raddr_cur := raddr_nxt + dec.xsize
- raddr_nxt := raddr_nxt + dec.xsize
- }
+ }.elsewhen(io.vme_wr.data
+ .fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) {
+ raddr_cur := raddr_cur + 1.U
+ }
+ .elsewhen(stride) {
+ raddr_cur := raddr_nxt + dec.xsize
+ raddr_nxt := raddr_nxt + dec.xsize
+ }
- val tread = Seq.tabulate(tensorLength) { i => i.U ->
- tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem) }
+ val tread = Seq.tabulate(tensorLength) { i =>
+ i.U ->
+ tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem)
+ }
val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread)
// write-to-dram
val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
val elemBytes = (p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).outBits) / 8
- when (state === sIdle) {
- waddr_cur := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
- waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
- } .elsewhen (state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
- waddr_cur := waddr_cur + xmax_bytes
- } .elsewhen (stride) {
- waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
- waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
- }
+ when(state === sIdle) {
+ waddr_cur := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
+ elemBytes)))
+ waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
+ elemBytes)))
+ }.elsewhen(state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
+ waddr_cur := waddr_cur + xmax_bytes
+ }
+ .elsewhen(stride) {
+ waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(
+ tensorLength * tensorWidth))
+ waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(
+ tensorLength * tensorWidth))
+ }
io.vme_wr.cmd.valid := state === sWriteCmd
io.vme_wr.cmd.bits.addr := waddr_cur
io.vme_wr.data.valid := state === sWriteData
io.vme_wr.data.bits := mdata(tag)
- when (state === sWriteCmd) {
+ when(state === sWriteCmd) {
xcnt := 0.U
- } .elsewhen (io.vme_wr.data.fire()) {
+ }.elsewhen(io.vme_wr.data.fire()) {
xcnt := xcnt + 1.U
}
// debug
if (debug) {
- when (io.vme_wr.cmd.fire()) {
- printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem)
+ when(io.vme_wr.cmd.fire()) {
+ printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n",
+ ysize,
+ ycnt,
+ raddr_cur,
+ waddr_cur,
+ xlen,
+ xrem)
}
- when (io.vme_wr.data.fire()) {
+ when(io.vme_wr.data.fire()) {
printf("[TensorStore] data:%x\n", io.vme_wr.data.bits)
}
- when (io.vme_wr.ack) {
+ when(io.vme_wr.ack) {
printf("[TensorStore] ack\n")
}
}
* weights (wgt), biases (acc), and outputs (out). This is used to avoid
* doing the same boring calculations over and over again.
*/
-class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends Bundle {
- val errorMsg = s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
+class TensorParams(tensorType: String = "none")(implicit p: Parameters)
+ extends Bundle {
+ val errorMsg =
+ s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
- require (tensorType == "inp" || tensorType == "wgt"
- || tensorType == "acc" || tensorType == "out", errorMsg)
+ require(tensorType == "inp" || tensorType == "wgt"
+ || tensorType == "acc" || tensorType == "out",
+ errorMsg)
val (tensorLength, tensorWidth, tensorElemBits) =
if (tensorType == "inp")
* biases (acc), and outputs (out).
*
*/
-class TensorMaster(tensorType: String = "none")
- (implicit p: Parameters) extends TensorParams(tensorType) {
- val rd = new Bundle {
- val idx = ValidIO(UInt(memAddrBits.W))
- val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
- }
- val wr = ValidIO(new Bundle {
- val idx = UInt(memAddrBits.W)
- val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
- })
- def tieoffRead() {
- rd.idx.valid := false.B
- rd.idx.bits := 0.U
- }
- def tieoffWrite() {
- wr.valid := false.B
- wr.bits.idx := 0.U
- wr.bits.data.foreach { b => b.foreach { c => c := 0.U } }
+class TensorMaster(tensorType: String = "none")(implicit p: Parameters)
+ extends TensorParams(tensorType) {
+ val rd = new Bundle {
+ val idx = ValidIO(UInt(memAddrBits.W))
+ val data = Flipped(
+ ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
+ }
+ val wr = ValidIO(new Bundle {
+ val idx = UInt(memAddrBits.W)
+ val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
+ })
+ def tieoffRead() {
+ rd.idx.valid := false.B
+ rd.idx.bits := 0.U
+ }
+ def tieoffWrite() {
+ wr.valid := false.B
+ wr.bits.idx := 0.U
+ wr.bits.data.foreach { b =>
+ b.foreach { c =>
+ c := 0.U
+ }
}
+ }
override def cloneType =
new TensorMaster(tensorType).asInstanceOf[this.type]
}
* The TensorLoad unit uses this interface for receiving read and write requests from
* the TensorGemm unit.
*/
-class TensorClient(tensorType: String = "none")
- (implicit p: Parameters) extends TensorParams(tensorType) {
- val rd = new Bundle {
- val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
- val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
- }
- val wr = Flipped(ValidIO(new Bundle {
- val idx = UInt(memAddrBits.W)
- val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
- }))
- def tieoffRead() {
- rd.data.valid := false.B
- rd.data.bits.foreach { b => b.foreach { c => c := 0.U } }
+class TensorClient(tensorType: String = "none")(implicit p: Parameters)
+ extends TensorParams(tensorType) {
+ val rd = new Bundle {
+ val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
+ val data = ValidIO(
+ Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
+ }
+ val wr = Flipped(ValidIO(new Bundle {
+ val idx = UInt(memAddrBits.W)
+ val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
+ }))
+ def tieoffRead() {
+ rd.data.valid := false.B
+ rd.data.bits.foreach { b =>
+ b.foreach { c =>
+ c := 0.U
+ }
}
+ }
override def cloneType =
new TensorClient(tensorType).asInstanceOf[this.type]
}
* is based on the TensorMaster interface, which means this is an input. This interface
* is used on datapath-only modules such as MatrixVectorCore or AluVector.
*/
-class TensorMasterData(tensorType: String = "none")
- (implicit p: Parameters) extends TensorParams(tensorType) {
- val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
+class TensorMasterData(tensorType: String = "none")(implicit p: Parameters)
+ extends TensorParams(tensorType) {
+ val data = Flipped(
+ ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
override def cloneType =
new TensorMasterData(tensorType).asInstanceOf[this.type]
}
* is based on the TensorClient interface, which means this is an output. This interface
* is used on datapath-only modules such as MatrixVectorCore or AluVector.
*/
-class TensorClientData(tensorType: String = "none")
- (implicit p: Parameters) extends TensorParams(tensorType) {
- val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
+class TensorClientData(tensorType: String = "none")(implicit p: Parameters)
+ extends TensorParams(tensorType) {
+ val data = ValidIO(
+ Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
override def cloneType =
new TensorClientData(tensorType).asInstanceOf[this.type]
}
/** TensorPadCtrl. Zero-padding controller for TensorLoad. */
-class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Module {
- val errorMsg = s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"
- require (padType == "YPad0" || padType == "YPad1"
- || padType == "XPad0" || padType == "XPad1", errorMsg)
+class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1)
+ extends Module {
+ val errorMsg =
+ s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"
+ require(padType == "YPad0" || padType == "YPad1"
+ || padType == "XPad0" || padType == "XPad1",
+ errorMsg)
val io = IO(new Bundle {
val start = Input(Bool())
val sIdle :: sActive :: Nil = Enum(2)
val state = RegInit(sIdle)
- switch (state) {
- is (sIdle) {
- when (io.start) {
+ switch(state) {
+ is(sIdle) {
+ when(io.start) {
state := sActive
}
}
- is (sActive) {
- when (ycnt === ymax && xcnt === xmax) {
+ is(sActive) {
+ when(ycnt === ymax && xcnt === xmax) {
state := sIdle
}
}
}
- when (state === sIdle) {
+ when(state === sIdle) {
xmax := xval
ymax := yval
}
- when (state === sIdle || xcnt === xmax) {
+ when(state === sIdle || xcnt === xmax) {
xcnt := 0.U
- } .elsewhen (state === sActive) {
+ }.elsewhen(state === sActive) {
xcnt := xcnt + 1.U
}
- when (state === sIdle || ymax === 0.U) {
+ when(state === sIdle || ymax === 0.U) {
ycnt := 0.U
- } .elsewhen (state === sActive && xcnt === xmax) {
+ }.elsewhen(state === sActive && xcnt === xmax) {
ycnt := ycnt + 1.U
}
}
/** TensorDataCtrl. Data controller for TensorLoad. */
-class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module {
+class TensorDataCtrl(tensorType: String = "none",
+ sizeFactor: Int = 1,
+ strideFactor: Int = 1)(implicit p: Parameters)
+ extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val start = Input(Bool())
val len = Reg(UInt(mp.lenBits.W))
- val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+ val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val xcnt = Reg(UInt(mp.lenBits.W))
val xrem = Reg(chiselTypeOf(dec.xsize))
val xsize = (dec.xsize << log2Ceil(sizeFactor)) - 1.U
val ycnt = Reg(chiselTypeOf(dec.ysize))
val stride = xcnt === len &
- xrem === 0.U &
- ycnt =/= dec.ysize - 1.U
+ xrem === 0.U &
+ ycnt =/= dec.ysize - 1.U
val split = xcnt === len & xrem =/= 0.U
- when (io.start || (io.xupdate && stride)) {
- when (xsize < xmax) {
+ when(io.start || (io.xupdate && stride)) {
+ when(xsize < xmax) {
len := xsize
xrem := 0.U
- } .otherwise {
+ }.otherwise {
len := xmax - 1.U
xrem := xsize - xmax
}
- } .elsewhen (io.xupdate && split) {
- when (xrem < xmax) {
+ }.elsewhen(io.xupdate && split) {
+ when(xrem < xmax) {
len := xrem
xrem := 0.U
- } .otherwise {
+ }.otherwise {
len := xmax - 1.U
xrem := xrem - xmax
}
}
- when (io.xinit) {
+ when(io.xinit) {
xcnt := 0.U
- } .elsewhen (io.xupdate) {
+ }.elsewhen(io.xupdate) {
xcnt := xcnt + 1.U
}
- when (io.start) {
+ when(io.start) {
ycnt := 0.U
- } .elsewhen (io.yupdate && stride) {
+ }.elsewhen(io.yupdate && stride) {
ycnt := ycnt + 1.U
}
(p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).accBits) / 8
}
- when (io.start) {
+ when(io.start) {
caddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
baddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
- } .elsewhen (io.yupdate) {
- when (split) {
+ }.elsewhen(io.yupdate) {
+ when(split) {
caddr := caddr + xmax_bytes
- } .elsewhen (stride) {
+ }.elsewhen(stride) {
caddr := baddr + (dec.xstride << log2Ceil(strideFactor))
baddr := baddr + (dec.xstride << log2Ceil(strideFactor))
}
io.addr := caddr
io.len := len
io.done := xcnt === len &
- xrem === 0.U &
- ycnt === dec.ysize - 1.U
+ xrem === 0.U &
+ ycnt === dec.ysize - 1.U
}
*
* Convert Host DPI to AXI for VTAShell
*/
-
-class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters)
+ extends Module {
val io = IO(new Bundle {
val dpi = new VTAHostDPIClient
val axi = new AXILiteMaster(p(ShellKey).hostParams)
})
val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
val data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.value)))
- val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
+ val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil =
+ Enum(6)
val state = RegInit(sIdle)
- switch (state) {
- is (sIdle) {
- when (io.dpi.req.valid) {
- when (io.dpi.req.opcode) {
+ switch(state) {
+ is(sIdle) {
+ when(io.dpi.req.valid) {
+ when(io.dpi.req.opcode) {
state := sWriteAddress
- } .otherwise {
+ }.otherwise {
state := sReadAddress
}
}
}
- is (sReadAddress) {
- when (io.axi.ar.ready) {
+ is(sReadAddress) {
+ when(io.axi.ar.ready) {
state := sReadData
}
}
- is (sReadData) {
- when (io.axi.r.valid) {
+ is(sReadData) {
+ when(io.axi.r.valid) {
state := sIdle
}
}
- is (sWriteAddress) {
- when (io.axi.aw.ready) {
+ is(sWriteAddress) {
+ when(io.axi.aw.ready) {
state := sWriteData
}
}
- is (sWriteData) {
- when (io.axi.w.ready) {
+ is(sWriteData) {
+ when(io.axi.w.ready) {
state := sWriteResponse
}
}
- is (sWriteResponse) {
- when (io.axi.b.valid) {
+ is(sWriteResponse) {
+ when(io.axi.b.valid) {
state := sIdle
}
}
}
- when (state === sIdle && io.dpi.req.valid) {
+ when(state === sIdle && io.dpi.req.valid) {
addr := io.dpi.req.addr
data := io.dpi.req.value
}
io.dpi.resp.bits := io.axi.r.bits.data
if (debug) {
- when (state === sWriteAddress && io.axi.aw.ready) { printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr) }
- when (state === sReadAddress && io.axi.ar.ready) { printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr) }
- when (io.axi.r.fire()) { printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data) }
- when (io.axi.w.fire()) { printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data) }
+ when(state === sWriteAddress && io.axi.aw.ready) {
+ printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr)
+ }
+ when(state === sReadAddress && io.axi.ar.ready) {
+ printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr)
+ }
+ when(io.axi.r.fire()) {
+ printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data)
+ }
+ when(io.axi.w.fire()) {
+ printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data)
+ }
}
}
setResource("/verilog/VTAMemDPI.v")
}
-class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters)
+ extends Module {
val io = IO(new Bundle {
val dpi = new VTAMemDPIMaster
val axi = new AXIClient(p(ShellKey).memParams)
val opcode = RegInit(false.B)
val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len)))
val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
- val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
+ val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil =
+ Enum(6)
val state = RegInit(sIdle)
- switch (state) {
- is (sIdle) {
- when (io.axi.ar.valid) {
+ switch(state) {
+ is(sIdle) {
+ when(io.axi.ar.valid) {
state := sReadAddress
- } .elsewhen (io.axi.aw.valid) {
+ }.elsewhen(io.axi.aw.valid) {
state := sWriteAddress
}
}
- is (sReadAddress) {
- when (io.axi.ar.valid) {
+ is(sReadAddress) {
+ when(io.axi.ar.valid) {
state := sReadData
}
}
- is (sReadData) {
- when (io.axi.r.ready && io.dpi.rd.valid && len === 0.U) {
+ is(sReadData) {
+ when(io.axi.r.ready && io.dpi.rd.valid && len === 0.U) {
state := sIdle
}
}
- is (sWriteAddress) {
- when (io.axi.aw.valid) {
+ is(sWriteAddress) {
+ when(io.axi.aw.valid) {
state := sWriteData
}
}
- is (sWriteData) {
- when (io.axi.w.valid && io.axi.w.bits.last) {
+ is(sWriteData) {
+ when(io.axi.w.valid && io.axi.w.bits.last) {
state := sWriteResponse
}
}
- is (sWriteResponse) {
- when (io.axi.b.ready) {
+ is(sWriteResponse) {
+ when(io.axi.b.ready) {
state := sIdle
}
}
}
- when (state === sIdle) {
- when (io.axi.ar.valid) {
+ when(state === sIdle) {
+ when(io.axi.ar.valid) {
opcode := false.B
len := io.axi.ar.bits.len
addr := io.axi.ar.bits.addr
- } .elsewhen (io.axi.aw.valid) {
+ }.elsewhen(io.axi.aw.valid) {
opcode := true.B
len := io.axi.aw.bits.len
addr := io.axi.aw.bits.addr
}
- } .elsewhen (state === sReadData) {
- when (io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) {
+ }.elsewhen(state === sReadData) {
+ when(io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) {
len := len - 1.U
}
}
io.axi.b.bits.id := 0.U
if (debug) {
- when (state === sReadAddress && io.axi.ar.valid) { printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len) }
- when (state === sWriteAddress && io.axi.aw.valid) { printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len) }
- when (io.axi.r.fire()) { printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", io.axi.r.bits.last, io.axi.r.bits.data) }
- when (io.axi.w.fire()) { printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", io.axi.w.bits.last, io.axi.w.bits.data) }
+ when(state === sReadAddress && io.axi.ar.valid) {
+ printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len)
+ }
+ when(state === sWriteAddress && io.axi.aw.valid) {
+ printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len)
+ }
+ when(io.axi.r.fire()) {
+ printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n",
+ io.axi.r.bits.last,
+ io.axi.r.bits.data)
+ }
+ when(io.axi.w.fire()) {
+ printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n",
+ io.axi.w.bits.last,
+ io.axi.w.bits.data)
+ }
}
}
import vta.util.genericbundle._
case class AXIParams(
- coherent: Boolean = false,
- idBits: Int = 1,
- addrBits: Int = 32,
- dataBits: Int = 64,
- lenBits: Int = 8,
- userBits: Int = 1
-)
-{
- require (addrBits > 0)
- require (dataBits >= 8 && dataBits % 2 == 0)
+ coherent: Boolean = false,
+ idBits: Int = 1,
+ addrBits: Int = 32,
+ dataBits: Int = 64,
+ lenBits: Int = 8,
+ userBits: Int = 1
+) {
+ require(addrBits > 0)
+ require(dataBits >= 8 && dataBits % 2 == 0)
- val strbBits = dataBits/8
+ val strbBits = dataBits / 8
val sizeBits = 3
val burstBits = 2
val lockBits = 2
val qosBits = 4
val regionBits = 4
val respBits = 2
- val sizeConst = log2Ceil(dataBits/8)
+ val sizeConst = log2Ceil(dataBits / 8)
val idConst = 0
val userConst = if (coherent) 1 else 0
val burstConst = 1
}
abstract class AXIBase(params: AXIParams)
- extends GenericParameterizedBundle(params)
+ extends GenericParameterizedBundle(params)
// AXILite
import vta.interface.axi._
/** PynqConfig. Shell configuration for Pynq */
-class PynqConfig extends Config((site, here, up) => {
- case ShellKey => ShellParams(
- hostParams = AXIParams(
- coherent = false,
- addrBits = 16,
- dataBits = 32,
- lenBits = 8,
- userBits = 1),
- memParams = AXIParams(
- coherent = true,
- addrBits = 32,
- dataBits = 64,
- lenBits = 8,
- userBits = 1),
- vcrParams = VCRParams(),
- vmeParams = VMEParams())
-})
+class PynqConfig
+ extends Config((site, here, up) => {
+ case ShellKey =>
+ ShellParams(
+ hostParams = AXIParams(coherent = false,
+ addrBits = 16,
+ dataBits = 32,
+ lenBits = 8,
+ userBits = 1),
+ memParams = AXIParams(coherent = true,
+ addrBits = 32,
+ dataBits = 64,
+ lenBits = 8,
+ userBits = 1),
+ vcrParams = VCRParams(),
+ vmeParams = VMEParams()
+ )
+ })
/** F1Config. Shell configuration for F1 */
-class F1Config extends Config((site, here, up) => {
- case ShellKey => ShellParams(
- hostParams = AXIParams(
- coherent = false,
- addrBits = 16,
- dataBits = 32,
- lenBits = 8,
- userBits = 1),
- memParams = AXIParams(
- coherent = false,
- addrBits = 64,
- dataBits = 64,
- lenBits = 8,
- userBits = 1),
- vcrParams = VCRParams(),
- vmeParams = VMEParams())
-})
+class F1Config
+ extends Config((site, here, up) => {
+ case ShellKey =>
+ ShellParams(
+ hostParams = AXIParams(coherent = false,
+ addrBits = 16,
+ dataBits = 32,
+ lenBits = 8,
+ userBits = 1),
+ memParams = AXIParams(coherent = false,
+ addrBits = 64,
+ dataBits = 64,
+ lenBits = 8,
+ userBits = 1),
+ vcrParams = VCRParams(),
+ vmeParams = VMEParams()
+ )
+ })
/** De10Config. Shell configuration for De10 */
-class De10Config extends Config((site, here, up) => {
- case ShellKey => ShellParams(
- hostParams = AXIParams(
- addrBits = 16, dataBits = 32, idBits = 13, lenBits = 4),
- memParams = AXIParams(
- addrBits = 32, dataBits = 64, userBits = 5,
- lenBits = 4, // limit to 16 beats, instead of 256 beats in AXI4
- coherent = true),
- vcrParams = VCRParams(),
- vmeParams = VMEParams())
-})
+class De10Config
+ extends Config((site, here, up) => {
+ case ShellKey =>
+ ShellParams(
+ hostParams =
+ AXIParams(addrBits = 16, dataBits = 32, idBits = 13, lenBits = 4),
+ memParams = AXIParams(
+ addrBits = 32,
+ dataBits = 64,
+ userBits = 5,
+ lenBits = 4, // limit to 16 beats, instead of 256 beats in AXI4
+ coherent = true),
+ vcrParams = VCRParams(),
+ vmeParams = VMEParams()
+ )
+ })
* system that can be used for simulation or real hardware.
*/
class IntelShell(implicit p: Parameters) extends Module {
- val io = IO(new Bundle{
+ val io = IO(new Bundle {
val host = new AXIClient(p(ShellKey).hostParams)
val mem = new AXIMaster(p(ShellKey).memParams)
})
sim.io.clock := clock
sim_wait := sim.io.dpi_wait
}
+
/** SimShell.
*
* The simulation shell instantiates the sim, host and memory DPI modules that
*
* These parameters are used on VCR interfaces and modules.
*/
-case class VCRParams()
-{
+case class VCRParams() {
val nCtrl = 1
val nECnt = 1
val nVals = 1
/** VCRBase. Parametrize base class. */
abstract class VCRBase(implicit p: Parameters)
- extends GenericParameterizedBundle(p)
+ extends GenericParameterizedBundle(p)
/** VCRMaster.
*
* registers that could be used as event counters by the Core unit.
*/
class VCR(implicit p: Parameters) extends Module {
- val io = IO(new Bundle{
+ val io = IO(new Bundle {
val host = new AXILiteClient(p(ShellKey).hostParams)
val vcr = new VCRMaster
})
val rdata = RegInit(0.U(vp.regBits.W))
// registers
- val nPtrs = if (mp.addrBits == 32) vp.nPtrs else 2*vp.nPtrs
+ val nPtrs = if (mp.addrBits == 32) vp.nPtrs else 2 * vp.nPtrs
val nTotal = vp.nCtrl + vp.nECnt + vp.nVals + nPtrs
val reg = Seq.fill(nTotal)(RegInit(0.U(vp.regBits.W)))
val addr = Seq.tabulate(nTotal)(_ * 4)
- val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
+ val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val eo = vp.nCtrl
val vo = eo + vp.nECnt
val po = vo + vp.nVals
- switch (wstate) {
- is (sWriteAddress) {
- when (io.host.aw.valid) {
+ switch(wstate) {
+ is(sWriteAddress) {
+ when(io.host.aw.valid) {
wstate := sWriteData
}
}
- is (sWriteData) {
- when (io.host.w.valid) {
+ is(sWriteData) {
+ when(io.host.w.valid) {
wstate := sWriteResponse
}
}
- is (sWriteResponse) {
- when (io.host.b.ready) {
+ is(sWriteResponse) {
+ when(io.host.b.ready) {
wstate := sWriteAddress
}
}
}
- when (io.host.aw.fire()) { waddr := io.host.aw.bits.addr }
+ when(io.host.aw.fire()) { waddr := io.host.aw.bits.addr }
io.host.aw.ready := wstate === sWriteAddress
io.host.w.ready := wstate === sWriteData
io.host.b.valid := wstate === sWriteResponse
io.host.b.bits.resp := 0.U
-
- switch (rstate) {
- is (sReadAddress) {
- when (io.host.ar.valid) {
+ switch(rstate) {
+ is(sReadAddress) {
+ when(io.host.ar.valid) {
rstate := sReadData
}
}
- is (sReadData) {
- when (io.host.r.ready) {
+ is(sReadData) {
+ when(io.host.r.ready) {
rstate := sReadAddress
}
}
io.host.r.bits.data := rdata
io.host.r.bits.resp := 0.U
- when (io.vcr.finish) {
+ when(io.vcr.finish) {
reg(0) := "b_10".U
- } .elsewhen (io.host.w.fire() && addr(0).U === waddr) {
+ }.elsewhen(io.host.w.fire() && addr(0).U === waddr) {
reg(0) := wdata
}
for (i <- 0 until vp.nECnt) {
- when (io.vcr.ecnt(i).valid) {
+ when(io.vcr.ecnt(i).valid) {
reg(eo + i) := io.vcr.ecnt(i).bits
- } .elsewhen (io.host.w.fire() && addr(eo + i).U === waddr) {
+ }.elsewhen(io.host.w.fire() && addr(eo + i).U === waddr) {
reg(eo + i) := wdata
}
}
for (i <- 0 until (vp.nVals + nPtrs)) {
- when (io.host.w.fire() && addr(vo + i).U === waddr) {
+ when(io.host.w.fire() && addr(vo + i).U === waddr) {
reg(vo + i) := wdata
}
}
- when (io.host.ar.fire()) {
+ when(io.host.ar.fire()) {
rdata := MuxLookup(io.host.ar.bits.addr, 0.U, reg_map)
}
io.vcr.vals(i) := reg(vo + i)
}
- if (mp.addrBits == 32) { // 32-bit pointers
+ if (mp.addrBits == 32) { // 32-bit pointers
for (i <- 0 until nPtrs) {
io.vcr.ptrs(i) := reg(po + i)
}
- } else { // 64-bits pointers
- for (i <- 0 until (nPtrs/2)) {
- io.vcr.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i))
+ } else { // 64-bits pointers
+ for (i <- 0 until (nPtrs / 2)) {
+ io.vcr.ptrs(i) := Cat(reg(po + 2 * i + 1), reg(po + 2 * i))
}
}
}
case class VMEParams() {
val nReadClients: Int = 5
val nWriteClients: Int = 1
- require (nReadClients > 0, s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
- require (nWriteClients == 1, s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
+ require(nReadClients > 0,
+ s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
+ require(
+ nWriteClients == 1,
+ s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
}
/** VMEBase. Parametrize base class. */
abstract class VMEBase(implicit p: Parameters)
- extends GenericParameterizedBundle(p)
+ extends GenericParameterizedBundle(p)
/** VMECmd.
*
val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3)
val rstate = RegInit(sReadIdle)
- switch (rstate) {
- is (sReadIdle) {
- when (rd_arb.io.out.valid) {
+ switch(rstate) {
+ is(sReadIdle) {
+ when(rd_arb.io.out.valid) {
rstate := sReadAddr
}
}
- is (sReadAddr) {
- when (io.mem.ar.ready) {
+ is(sReadAddr) {
+ when(io.mem.ar.ready) {
rstate := sReadData
}
}
- is (sReadData) {
- when (io.mem.r.fire() && io.mem.r.bits.last) {
+ is(sReadData) {
+ when(io.mem.r.fire() && io.mem.r.bits.last) {
rstate := sReadIdle
}
}
val lenBits = p(ShellKey).memParams.lenBits
val wr_cnt = RegInit(0.U(lenBits.W))
- when (wstate === sWriteIdle) {
+ when(wstate === sWriteIdle) {
wr_cnt := 0.U
- } .elsewhen (io.mem.w.fire()) {
+ }.elsewhen(io.mem.w.fire()) {
wr_cnt := wr_cnt + 1.U
}
- switch (wstate) {
- is (sWriteIdle) {
- when (io.vme.wr(0).cmd.valid) {
+ switch(wstate) {
+ is(sWriteIdle) {
+ when(io.vme.wr(0).cmd.valid) {
wstate := sWriteAddr
}
}
- is (sWriteAddr) {
- when (io.mem.aw.ready) {
+ is(sWriteAddr) {
+ when(io.mem.aw.ready) {
wstate := sWriteData
}
}
- is (sWriteData) {
- when (io.vme.wr(0).data.valid && io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) {
+ is(sWriteData) {
+ when(
+ io.vme
+ .wr(0)
+ .data
+ .valid && io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) {
wstate := sWriteResp
}
}
- is (sWriteResp) {
- when (io.mem.b.valid) {
+ is(sWriteResp) {
+ when(io.mem.b.valid) {
wstate := sWriteIdle
}
}
val rd_addr = RegInit(0.U(addrBits.W))
val wr_addr = RegInit(0.U(addrBits.W))
- when (rd_arb.io.out.fire()) {
+ when(rd_arb.io.out.fire()) {
rd_len := rd_arb.io.out.bits.len
rd_addr := rd_arb.io.out.bits.addr
}
- when (io.vme.wr(0).cmd.fire()) {
+ when(io.vme.wr(0).cmd.fire()) {
wr_len := io.vme.wr(0).cmd.bits.len
wr_addr := io.vme.wr(0).cmd.bits.addr
}
io.vme.wr(0).cmd.ready := wstate === sWriteIdle
io.vme.wr(0).ack := io.mem.b.fire()
- io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready
+ io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready
// mem
io.mem.aw.valid := wstate === sWriteAddr
/** Shell parameters. */
case class ShellParams(
- hostParams: AXIParams,
- memParams: AXIParams,
- vcrParams: VCRParams,
- vmeParams: VMEParams
+ hostParams: AXIParams,
+ memParams: AXIParams,
+ vcrParams: VCRParams,
+ vmeParams: VMEParams
)
// Configuration key for the shell parameters: shell configs bind it with
// `case ShellKey => ShellParams(...)` and modules read it as `p(ShellKey)`.
case object ShellKey extends Field[ShellParams]
* system that can be used for simulation or real hardware.
*/
class VTAShell(implicit p: Parameters) extends Module {
- val io = IO(new Bundle{
+ val io = IO(new Bundle {
val host = new AXILiteClient(p(ShellKey).hostParams)
val mem = new AXIMaster(p(ShellKey).memParams)
})
package vta.shell
import chisel3._
-import chisel3.experimental.{RawModule, withClockAndReset}
+import chisel3.experimental.{withClockAndReset, RawModule}
import vta.util.config._
import vta.interface.axi._
val m_axi_gmem = IO(new XilinxAXIMaster(mp))
val s_axi_control = IO(new XilinxAXILiteClient(hp))
- val shell = withClockAndReset (clock = ap_clk, reset = ~ap_rst_n) { Module(new VTAShell) }
+ val shell = withClockAndReset(clock = ap_clk, reset = ~ap_rst_n) {
+ Module(new VTAShell)
+ }
// memory
m_axi_gmem.AWVALID := shell.io.mem.aw.valid
// taken from https://github.com/freechipsproject/rocket-chip
-abstract class Field[T] private (val default: Option[T])
-{
+abstract class Field[T] private (val default: Option[T]) {
def this() = this(None)
def this(default: T) = this(Some(default))
}
final def apply[T](pname: Field[T]): T = apply(pname, this)
final def apply[T](pname: Field[T], site: View): T = {
val out = find(pname, site)
- require (out.isDefined, s"Key ${pname} is not defined in Parameters")
+ require(out.isDefined, s"Key ${pname} is not defined in Parameters")
out.get
}
final def lift[T](pname: Field[T]): Option[T] = lift(pname, this)
- final def lift[T](pname: Field[T], site: View): Option[T] = find(pname, site).map(_.asInstanceOf[T])
+ final def lift[T](pname: Field[T], site: View): Option[T] =
+ find(pname, site).map(_.asInstanceOf[T])
protected[config] def find[T](pname: Field[T], site: View): Option[T]
}
abstract class Parameters extends View {
- final def ++ (x: Parameters): Parameters =
+ final def ++(x: Parameters): Parameters =
new ChainParameters(this, x)
- final def alter(f: (View, View, View) => PartialFunction[Any,Any]): Parameters =
+ final def alter(
+ f: (View, View, View) => PartialFunction[Any, Any]): Parameters =
Parameters(f) ++ this
- final def alterPartial(f: PartialFunction[Any,Any]): Parameters =
- Parameters((_,_,_) => f) ++ this
+ final def alterPartial(f: PartialFunction[Any, Any]): Parameters =
+ Parameters((_, _, _) => f) ++ this
- final def alterMap(m: Map[Any,Any]): Parameters =
+ final def alterMap(m: Map[Any, Any]): Parameters =
new MapParameters(m) ++ this
- protected[config] def chain[T](site: View, tail: View, pname: Field[T]): Option[T]
- protected[config] def find[T](pname: Field[T], site: View) = chain(site, new TerminalView, pname)
+ protected[config] def chain[T](site: View,
+ tail: View,
+ pname: Field[T]): Option[T]
+ protected[config] def find[T](pname: Field[T], site: View) =
+ chain(site, new TerminalView, pname)
}
object Parameters {
def empty: Parameters = new EmptyParameters
- def apply(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = new PartialParameters(f)
+ def apply(f: (View, View, View) => PartialFunction[Any, Any]): Parameters =
+ new PartialParameters(f)
}
class Config(p: Parameters) extends Parameters {
- def this(f: (View, View, View) => PartialFunction[Any,Any]) = this(Parameters(f))
+ def this(f: (View, View, View) => PartialFunction[Any, Any]) =
+ this(Parameters(f))
- protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = p.chain(site, tail, pname)
+ protected[config] def chain[T](site: View, tail: View, pname: Field[T]) =
+ p.chain(site, tail, pname)
override def toString = this.getClass.getSimpleName
def toInstance = this
}
}
private class ChainParameters(x: Parameters, y: Parameters) extends Parameters {
- def chain[T](site: View, tail: View, pname: Field[T]) = x.chain(site, new ChainView(y, tail), pname)
+ def chain[T](site: View, tail: View, pname: Field[T]) =
+ x.chain(site, new ChainView(y, tail), pname)
}
// Parameters instance that binds no keys itself: `chain` forwards every
// lookup unchanged to `tail.find`, so the result comes entirely from `tail`.
private class EmptyParameters extends Parameters {
def chain[T](site: View, tail: View, pname: Field[T]) = tail.find(pname, site)
}
-private class PartialParameters(f: (View, View, View) => PartialFunction[Any,Any]) extends Parameters {
+private class PartialParameters(
+ f: (View, View, View) => PartialFunction[Any, Any])
+ extends Parameters {
protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = {
val g = f(site, this, tail)
- if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T]) else tail.find(pname, site)
+ if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T])
+ else tail.find(pname, site)
}
}
import chisel3._
-abstract class GenericParameterizedBundle[+T <: Object](val params: T) extends Bundle
-{
+abstract class GenericParameterizedBundle[+T <: Object](val params: T)
+ extends Bundle {
override def cloneType = {
try {
- this.getClass.getConstructors.head.newInstance(params).asInstanceOf[this.type]
+ this.getClass.getConstructors.head
+ .newInstance(params)
+ .asInstanceOf[this.type]
} catch {
case e: java.lang.IllegalArgumentException =>
- throw new Exception("Unable to use GenericParameterizedBundle.cloneType on " +
- this.getClass + ", probably because " + this.getClass +
- "() takes more than one argument. Consider overriding " +
- "cloneType() on " + this.getClass, e)
+ throw new Exception(
+ "Unable to use GenericParameterizedBundle.cloneType on " +
+ this.getClass + ", probably because " + this.getClass +
+ "() takes more than one argument. Consider overriding " +
+ "cloneType() on " + this.getClass,
+ e
+ )
}
}
}
-
* These configurations are built in a mix/match form based on core
* and shell configurations.
*/
-
// Pynq target: CoreConfig chained in front of PynqConfig (`++` consults its left operand first).
class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig)
// F1 target: CoreConfig chained in front of F1Config (`++` consults its left operand first).
class DefaultF1Config extends Config(new CoreConfig ++ new F1Config)
// De10 target: CoreConfig chained in front of De10Config (`++` consults its left operand first).
class DefaultDe10Config extends Config(new CoreConfig ++ new De10Config)