From 5fe61fd1b4357dfa111a165d8a0367aca9132b7b Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Tue, 3 Sep 2019 22:19:01 -0700 Subject: [PATCH] [VTA][Chisel] add scalafmt and format existing scala codebase (#3880) * [VTA][Chisel] add scalafmt and format existing scala codebase * change column width to 100 * add scalafmt conf file as a valid file type * add asf header to scalafmt conf file and rerun formatter --- tests/lint/check_file_type.py | 1 + .../tsim_example/hardware/chisel/.scalafmt.conf | 21 +++ vta/apps/tsim_example/hardware/chisel/Makefile | 5 +- .../hardware/chisel/project/plugins.sbt | 1 + .../chisel/src/main/scala/accel/Accel.scala | 2 +- .../chisel/src/main/scala/accel/Compute.scala | 34 ++-- .../chisel/src/main/scala/accel/RegFile.scala | 38 ++-- vta/hardware/chisel/.scalafmt.conf | 21 +++ vta/hardware/chisel/Makefile | 5 +- vta/hardware/chisel/project/plugins.sbt | 1 + .../chisel/src/main/scala/core/Compute.scala | 116 +++++++----- .../chisel/src/main/scala/core/Configs.scala | 37 ++-- vta/hardware/chisel/src/main/scala/core/Core.scala | 36 ++-- .../chisel/src/main/scala/core/Decode.scala | 5 +- .../chisel/src/main/scala/core/EventCounters.scala | 9 +- .../chisel/src/main/scala/core/Fetch.scala | 87 +++++---- vta/hardware/chisel/src/main/scala/core/ISA.scala | 111 ++++++----- vta/hardware/chisel/src/main/scala/core/Load.scala | 44 ++--- .../chisel/src/main/scala/core/LoadUop.scala | 103 +++++----- .../chisel/src/main/scala/core/Semaphore.scala | 9 +- .../chisel/src/main/scala/core/Store.scala | 28 +-- .../chisel/src/main/scala/core/TensorAlu.scala | 148 ++++++++------- .../chisel/src/main/scala/core/TensorGemm.scala | 189 ++++++++++--------- .../chisel/src/main/scala/core/TensorLoad.scala | 207 ++++++++++++--------- .../chisel/src/main/scala/core/TensorStore.scala | 160 +++++++++------- .../chisel/src/main/scala/core/TensorUtil.scala | 171 +++++++++-------- .../chisel/src/main/scala/dpi/VTAHostDPI.scala | 55 +++--- .../chisel/src/main/scala/dpi/VTAMemDPI.scala | 64 ++++--- .../chisel/src/main/scala/interface/axi/AXI.scala | 25 ++- .../chisel/src/main/scala/shell/Configs.scala | 97 +++++----- .../chisel/src/main/scala/shell/IntelShell.scala | 2 +- .../chisel/src/main/scala/shell/SimShell.scala | 1 + vta/hardware/chisel/src/main/scala/shell/VCR.scala | 58 +++--- vta/hardware/chisel/src/main/scala/shell/VME.scala | 55 +++--- .../chisel/src/main/scala/shell/VTAShell.scala | 10 +- .../chisel/src/main/scala/shell/XilinxShell.scala | 6 +- .../chisel/src/main/scala/util/Config.scala | 45 +++-- .../scala/util/GenericParameterizedBundle.scala | 20 +- .../chisel/src/main/scala/vta/Configs.scala | 1 - 39 files changed, 1148 insertions(+), 880 deletions(-) create mode 100644 vta/apps/tsim_example/hardware/chisel/.scalafmt.conf create mode 100644 vta/hardware/chisel/.scalafmt.conf diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index c6691bb..e5f2dc7 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -87,6 +87,7 @@ ALLOW_FILE_NAME = { ".clang-format", ".gitmodules", "CODEOWNERS", + ".scalafmt.conf", } # List of specific files allowed in relpath to diff --git a/vta/apps/tsim_example/hardware/chisel/.scalafmt.conf b/vta/apps/tsim_example/hardware/chisel/.scalafmt.conf new file mode 100644 index 0000000..9172d5e --- /dev/null +++ b/vta/apps/tsim_example/hardware/chisel/.scalafmt.conf @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +maxColumn = 100 +rewrite.rules = [SortModifiers, SortImports] diff --git a/vta/apps/tsim_example/hardware/chisel/Makefile b/vta/apps/tsim_example/hardware/chisel/Makefile index 4f555ba..0f97945 100644 --- a/vta/apps/tsim_example/hardware/chisel/Makefile +++ b/vta/apps/tsim_example/hardware/chisel/Makefile @@ -91,7 +91,10 @@ else lib_path = $(build_dir)/$(LIBNAME).so endif -default: lib +default: lint lib + +lint: + sbt scalafmt lib: $(lib_path) $(lib_path): $(verilator_build_dir)/V$(TOP).cpp diff --git a/vta/apps/tsim_example/hardware/chisel/project/plugins.sbt b/vta/apps/tsim_example/hardware/chisel/project/plugins.sbt index 79ffb22..e14e694 100644 --- a/vta/apps/tsim_example/hardware/chisel/project/plugins.sbt +++ b/vta/apps/tsim_example/hardware/chisel/project/plugins.sbt @@ -18,3 +18,4 @@ */ logLevel := Level.Warn +addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1") diff --git a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Accel.scala b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Accel.scala index d654a7f..b90c729 100644 --- a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Accel.scala +++ b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Accel.scala @@ -41,7 +41,7 @@ case class AccelConfig() { val nVals = 2 val nPtrs = 2 val regBits = 32 - val ptrBits = 2*regBits + val ptrBits = 2 * regBits } class Accel extends Module { diff --git a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Compute.scala b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Compute.scala index f24cbdd..7ad965c 100644 --- a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Compute.scala +++ b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Compute.scala @@ -54,27 +54,27 @@ class Compute(implicit config: AccelConfig) extends Module { val raddr = Reg(UInt(config.ptrBits.W)) val waddr = Reg(UInt(config.ptrBits.W)) - switch (state) { - is (sIdle) { - when (io.launch) { + switch(state) { + is(sIdle) { + when(io.launch) { state := sReadReq } } - is (sReadReq) { + is(sReadReq) { state := sReadData } - is (sReadData) { - when (io.mem.rd.valid) { + is(sReadData) { + when(io.mem.rd.valid) { state := sWriteReq } } - is (sWriteReq) { + is(sWriteReq) { state := sWriteData } - is (sWriteData) { - when (cnt === (length - 1.U)) { + is(sWriteData) { + when(cnt === (length - 1.U)) { state := sIdle - } .otherwise { + }.otherwise { state := sReadReq } } @@ -83,9 +83,9 @@ class Compute(implicit config: AccelConfig) extends Module { val last = state === sWriteData && cnt === (length - 1.U) // cycle counter - when (state === sIdle) { + when(state === sIdle) { cycles := 0.U - } .otherwise { + }.otherwise { cycles := cycles + 1.U } @@ -93,10 +93,10 @@ class Compute(implicit config: AccelConfig) extends Module { io.ecnt(0).bits := cycles // calculate next address - when (state === sIdle) { + when(state === sIdle) { raddr := io.ptrs(0) waddr := io.ptrs(1) - } .elsewhen (state === sWriteData) { // increment by 8-bytes + }.elsewhen(state === sWriteData) { // increment by 8-bytes raddr := raddr + 8.U waddr := waddr + 8.U } @@ -108,7 +108,7 @@ class Compute(implicit config: AccelConfig) extends Module { io.mem.req.addr := Mux(state === sReadReq, raddr, waddr) // read - when (state === sReadData && io.mem.rd.valid) { + when(state === sReadData && io.mem.rd.valid) { reg := io.mem.rd.bits + const } io.mem.rd.ready := state === sReadData @@ -118,9 +118,9 @@ class Compute(implicit config: AccelConfig) extends Module { io.mem.wr.bits := reg // count read/write - when (state === sIdle) { + when(state === sIdle) { cnt := 0.U - } .elsewhen (state === sWriteData) { + }.elsewhen(state === sWriteData) { cnt := cnt + 1.U } diff --git a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala index 92a9833..1982f18 100644 --- a/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala +++ b/vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala @@ -59,52 +59,54 @@ class RegFile(implicit config: AccelConfig) extends Module { val sIdle :: sRead :: Nil = Enum(2) val state = RegInit(sIdle) - switch (state) { - is (sIdle) { - when (io.host.req.valid && !io.host.req.opcode) { + switch(state) { + is(sIdle) { + when(io.host.req.valid && !io.host.req.opcode) { state := sRead } } - is (sRead) { + is(sRead) { state := sIdle } } io.host.req.deq := state === sIdle & io.host.req.valid - val nTotal = config.nCtrl + config.nECnt + config.nVals + (2*config.nPtrs) - val reg = Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))) + val nTotal = config.nCtrl + config.nECnt + config.nVals + (2 * config.nPtrs) + val reg = + Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))) val addr = Seq.tabulate(nTotal)(_ * 4) - val reg_map = (addr zip reg) map { case (a, r) => a.U -> r } + val reg_map = (addr zip reg) map { case (a, r) => a.U -> r } val eo = config.nCtrl val vo = eo + config.nECnt val po = vo + config.nVals - when (io.finish) { + when(io.finish) { reg(0) := "b_10".U - } .elsewhen (state === sIdle && io.host.req.valid && - io.host.req.opcode && addr(0).U === io.host.req.addr) { + }.elsewhen(state === sIdle && io.host.req.valid && + io.host.req.opcode && addr(0).U === io.host.req.addr) { reg(0) := io.host.req.value } for (i <- 0 until config.nECnt) { - when (io.ecnt(i).valid) { + when(io.ecnt(i).valid) { reg(eo + i) := io.ecnt(i).bits - } .elsewhen (state === sIdle && io.host.req.valid && - io.host.req.opcode && addr(eo + i).U === io.host.req.addr) { + }.elsewhen(state === sIdle && io.host.req.valid && + io.host.req.opcode && addr(eo + i).U === io.host.req.addr) { reg(eo + i) := io.host.req.value } } - for (i <- 0 until (config.nVals + (2*config.nPtrs))) { - when (state === sIdle && io.host.req.valid && - io.host.req.opcode && addr(vo + i).U === io.host.req.addr) { + for (i <- 0 until (config.nVals + (2 * config.nPtrs))) { + when( + state === sIdle && io.host.req.valid && + io.host.req.opcode && addr(vo + i).U === io.host.req.addr) { reg(vo + i) := io.host.req.value } } val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))) - when (state === sIdle && io.host.req.valid && !io.host.req.opcode) { + when(state === sIdle && io.host.req.valid && !io.host.req.opcode) { rdata := MuxLookup(io.host.req.addr, 0.U, reg_map) } @@ -118,6 +120,6 @@ class RegFile(implicit config: AccelConfig) extends Module { } for (i <- 0 until config.nPtrs) { - io.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i)) + io.ptrs(i) := Cat(reg(po + (2 * i) + 1), reg(po + (2 * i))) } } diff --git a/vta/hardware/chisel/.scalafmt.conf b/vta/hardware/chisel/.scalafmt.conf new file mode 100644 index 0000000..9172d5e --- /dev/null +++ b/vta/hardware/chisel/.scalafmt.conf @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +maxColumn = 100 +rewrite.rules = [SortModifiers, SortImports] diff --git a/vta/hardware/chisel/Makefile b/vta/hardware/chisel/Makefile index 6cd2802..7c88915 100644 --- a/vta/hardware/chisel/Makefile +++ b/vta/hardware/chisel/Makefile @@ -102,7 +102,10 @@ else lib_path = $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so endif -default: lib +default: lint lib + +lint: + sbt scalafmt lib: $(lib_path) diff --git a/vta/hardware/chisel/project/plugins.sbt b/vta/hardware/chisel/project/plugins.sbt index 79ffb22..e14e694 100644 --- a/vta/hardware/chisel/project/plugins.sbt +++ b/vta/hardware/chisel/project/plugins.sbt @@ -18,3 +18,4 @@ */ logLevel := Level.Warn +addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1") diff --git a/vta/hardware/chisel/src/main/scala/core/Compute.scala b/vta/hardware/chisel/src/main/scala/core/Compute.scala index 01fa9d6..7751bf7 100644 --- a/vta/hardware/chisel/src/main/scala/core/Compute.scala +++ b/vta/hardware/chisel/src/main/scala/core/Compute.scala @@ -49,7 +49,8 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { val sIdle :: sSync :: sExe :: Nil = Enum(3) val state = RegInit(sIdle) - val s = Seq.tabulate(2)(_ => Module(new Semaphore(counterBits = 8, counterInitValue = 0))) + val s = Seq.tabulate(2)(_ => + Module(new Semaphore(counterBits = 8, counterInitValue = 0))) val loadUop = Module(new LoadUop) val tensorAcc = Module(new TensorLoad(tensorType = "acc")) @@ -62,18 +63,20 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { val dec = Module(new ComputeDecode) dec.io.inst := inst_q.io.deq.bits - val inst_type = Cat(dec.io.isFinish, - dec.io.isAlu, - dec.io.isGemm, - dec.io.isLoadAcc, - dec.io.isLoadUop).asUInt + val inst_type = + Cat(dec.io.isFinish, + dec.io.isAlu, + dec.io.isGemm, + dec.io.isLoadAcc, + dec.io.isLoadUop).asUInt val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B) val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B) val start = snext & sprev val done = - MuxLookup(inst_type, - false.B, // default + MuxLookup( + inst_type, + false.B, // default Array( "h_01".U -> loadUop.io.done, "h_02".U -> tensorAcc.io.done, @@ -84,21 +87,21 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ) // control - switch (state) { - is (sIdle) { - when (start) { - when (dec.io.isSync) { + switch(state) { + is(sIdle) { + when(start) { + when(dec.io.isSync) { state := sSync - } .elsewhen (inst_type.orR) { + }.elsewhen(inst_type.orR) { state := sExe } } } - is (sSync) { + is(sSync) { state := sIdle } - is (sExe) { - when (done) { + is(sExe) { + when(done) { state := sIdle } } @@ -109,22 +112,28 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { inst_q.io.deq.ready := (state === sExe & done) | (state === sSync) // uop - loadUop.io.start := state === sIdle & start & dec.io.isLoadUop + loadUop.io.start := state === sIdle & start & dec.io.isLoadUop loadUop.io.inst := inst_q.io.deq.bits loadUop.io.baddr := io.uop_baddr io.vme_rd(0) <> loadUop.io.vme_rd - loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx) + loadUop.io.uop.idx <> Mux(dec.io.isGemm, + tensorGemm.io.uop.idx, + tensorAlu.io.uop.idx) // acc tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc tensorAcc.io.inst := inst_q.io.deq.bits tensorAcc.io.baddr := io.acc_baddr - tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx) - tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr) + tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, + tensorGemm.io.acc.rd.idx, + tensorAlu.io.acc.rd.idx) + tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, + tensorGemm.io.acc.wr, + tensorAlu.io.acc.wr) io.vme_rd(1) <> tensorAcc.io.vme_rd // gemm - tensorGemm.io.start := state === sIdle & start & dec.io.isGemm + tensorGemm.io.start := state === sIdle & start & dec.io.isGemm tensorGemm.io.inst := inst_q.io.deq.bits tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits @@ -136,7 +145,7 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits // alu - tensorAlu.io.start := state === sIdle & start & dec.io.isAlu + tensorAlu.io.start := state === sIdle & start & dec.io.isAlu tensorAlu.io.inst := inst_q.io.deq.bits tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits @@ -146,7 +155,9 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits // out - io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx) + io.out.rd.idx <> Mux(dec.io.isGemm, + tensorGemm.io.out.rd.idx, + tensorAlu.io.out.rd.idx) io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr) // semaphore @@ -163,38 +174,45 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { // debug if (debug) { // start - when (state === sIdle && start) { - when (dec.io.isSync) { + when(state === sIdle && start) { + when(dec.io.isSync) { printf("[Compute] start sync\n") - } .elsewhen (dec.io.isLoadUop) { - printf("[Compute] start load uop\n") - } .elsewhen (dec.io.isLoadAcc) { - printf("[Compute] start load acc\n") - } .elsewhen (dec.io.isGemm) { - printf("[Compute] start gemm\n") - } .elsewhen (dec.io.isAlu) { - printf("[Compute] start alu\n") - } .elsewhen (dec.io.isFinish) { - printf("[Compute] start finish\n") - } + }.elsewhen(dec.io.isLoadUop) { + printf("[Compute] start load uop\n") + } + .elsewhen(dec.io.isLoadAcc) { + printf("[Compute] start load acc\n") + } + .elsewhen(dec.io.isGemm) { + printf("[Compute] start gemm\n") + } + .elsewhen(dec.io.isAlu) { + printf("[Compute] start alu\n") + } + .elsewhen(dec.io.isFinish) { + printf("[Compute] start finish\n") + } } // done - when (state === sSync) { + when(state === sSync) { printf("[Compute] done sync\n") } - when (state === sExe) { - when (done) { - when (dec.io.isLoadUop) { + when(state === sExe) { + when(done) { + when(dec.io.isLoadUop) { printf("[Compute] done load uop\n") - } .elsewhen (dec.io.isLoadAcc) { - printf("[Compute] done load acc\n") - } .elsewhen (dec.io.isGemm) { - printf("[Compute] done gemm\n") - } .elsewhen (dec.io.isAlu) { - printf("[Compute] done alu\n") - } .elsewhen (dec.io.isFinish) { - printf("[Compute] done finish\n") - } + }.elsewhen(dec.io.isLoadAcc) { + printf("[Compute] done load acc\n") + } + .elsewhen(dec.io.isGemm) { + printf("[Compute] done gemm\n") + } + .elsewhen(dec.io.isAlu) { + printf("[Compute] done alu\n") + } + .elsewhen(dec.io.isFinish) { + printf("[Compute] done finish\n") + } } } } diff --git a/vta/hardware/chisel/src/main/scala/core/Configs.scala b/vta/hardware/chisel/src/main/scala/core/Configs.scala index b4e764b..de7012b 100644 --- a/vta/hardware/chisel/src/main/scala/core/Configs.scala +++ b/vta/hardware/chisel/src/main/scala/core/Configs.scala @@ -27,20 +27,23 @@ import vta.util.config._ * be eventually filled out with class configurations that can be * mixed/matched with Shell configurations for different backends. */ -class CoreConfig extends Config((site, here, up) => { - case CoreKey => CoreParams( - batch = 1, - blockOut = 16, - blockIn = 16, - inpBits = 8, - wgtBits = 8, - uopBits = 32, - accBits = 32, - outBits = 8, - uopMemDepth = 2048, - inpMemDepth = 2048, - wgtMemDepth = 1024, - accMemDepth = 2048, - outMemDepth = 2048, - instQueueEntries = 512) -}) +class CoreConfig + extends Config((site, here, up) => { + case CoreKey => + CoreParams( + batch = 1, + blockOut = 16, + blockIn = 16, + inpBits = 8, + wgtBits = 8, + uopBits = 32, + accBits = 32, + outBits = 8, + uopMemDepth = 2048, + inpMemDepth = 2048, + wgtMemDepth = 1024, + accMemDepth = 2048, + outMemDepth = 2048, + instQueueEntries = 512 + ) + }) diff --git a/vta/hardware/chisel/src/main/scala/core/Core.scala b/vta/hardware/chisel/src/main/scala/core/Core.scala index e63a112..a7228ee 100644 --- a/vta/hardware/chisel/src/main/scala/core/Core.scala +++ b/vta/hardware/chisel/src/main/scala/core/Core.scala @@ -24,24 +24,24 @@ import vta.util.config._ import vta.shell._ /** Core parameters */ -case class CoreParams ( - batch: Int = 1, - blockOut: Int = 16, - blockIn: Int = 16, - inpBits: Int = 8, - wgtBits: Int = 8, - uopBits: Int = 32, - accBits: Int = 32, - outBits: Int = 8, - uopMemDepth: Int = 512, - inpMemDepth: Int = 512, - wgtMemDepth: Int = 512, - accMemDepth: Int = 512, - outMemDepth: Int = 512, - instQueueEntries: Int = 32 -) -{ - require (uopBits % 8 == 0, s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n") +case class CoreParams( + batch: Int = 1, + blockOut: Int = 16, + blockIn: Int = 16, + inpBits: Int = 8, + wgtBits: Int = 8, + uopBits: Int = 32, + accBits: Int = 32, + outBits: Int = 8, + uopMemDepth: Int = 512, + inpMemDepth: Int = 512, + wgtMemDepth: Int = 512, + accMemDepth: Int = 512, + outMemDepth: Int = 512, + instQueueEntries: Int = 32 +) { + require(uopBits % 8 == 0, + s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n") } case object CoreKey extends Field[CoreParams] diff --git a/vta/hardware/chisel/src/main/scala/core/Decode.scala b/vta/hardware/chisel/src/main/scala/core/Decode.scala index f5bf340..a49ddce 100644 --- a/vta/hardware/chisel/src/main/scala/core/Decode.scala +++ b/vta/hardware/chisel/src/main/scala/core/Decode.scala @@ -133,8 +133,9 @@ class FetchDecode extends Module { val isStore = Output(Bool()) }) val csignals = - ListLookup(io.inst, - List(N, OP_X), + ListLookup( + io.inst, + List(N, OP_X), Array( LUOP -> List(Y, OP_G), LWGT -> List(Y, OP_L), diff --git a/vta/hardware/chisel/src/main/scala/core/EventCounters.scala b/vta/hardware/chisel/src/main/scala/core/EventCounters.scala index 5a5b095..8990aef 100644 --- a/vta/hardware/chisel/src/main/scala/core/EventCounters.scala +++ b/vta/hardware/chisel/src/main/scala/core/EventCounters.scala @@ -38,17 +38,18 @@ import vta.shell._ * If one would like to add an event counter, then the value of nECnt must be * changed in VCRParams together with the corresponding counting logic here. */ -class EventCounters(debug: Boolean = false)(implicit p: Parameters) extends Module { +class EventCounters(debug: Boolean = false)(implicit p: Parameters) + extends Module { val vp = p(ShellKey).vcrParams - val io = IO(new Bundle{ + val io = IO(new Bundle { val launch = Input(Bool()) val finish = Input(Bool()) val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W))) }) val cycle_cnt = RegInit(0.U(vp.regBits.W)) - when (io.launch && !io.finish) { + when(io.launch && !io.finish) { cycle_cnt := cycle_cnt + 1.U - } .otherwise { + }.otherwise { cycle_cnt := 0.U } io.ecnt(0).valid := io.finish diff --git a/vta/hardware/chisel/src/main/scala/core/Fetch.scala b/vta/hardware/chisel/src/main/scala/core/Fetch.scala index c7a6d50..9baf1cc 100644 --- a/vta/hardware/chisel/src/main/scala/core/Fetch.scala +++ b/vta/hardware/chisel/src/main/scala/core/Fetch.scala @@ -67,69 +67,70 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module { val xrem = Reg(chiselTypeOf(io.ins_count)) val xsize = (io.ins_count << 1.U) - 1.U val xmax = (1 << mp.lenBits).U - val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U + val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5) val state = RegInit(sIdle) // control - switch (state) { - is (sIdle) { - when (pulse) { + switch(state) { + is(sIdle) { + when(pulse) { state := sReadCmd - when (xsize < xmax) { + when(xsize < xmax) { rlen := xsize ilen := xsize >> 1.U xrem := 0.U - } .otherwise { + }.otherwise { rlen := xmax - 1.U ilen := (xmax >> 1.U) - 1.U xrem := xsize - xmax } } } - is (sReadCmd) { - when (io.vme_rd.cmd.ready) { + is(sReadCmd) { + when(io.vme_rd.cmd.ready) { state := sReadLSB } } - is (sReadLSB) { - when (io.vme_rd.data.valid) { + is(sReadLSB) { + when(io.vme_rd.data.valid) { state := sReadMSB } } - is (sReadMSB) { - when (io.vme_rd.data.valid) { - when (inst_q.io.count === ilen) { + is(sReadMSB) { + when(io.vme_rd.data.valid) { + when(inst_q.io.count === ilen) { state := sDrain - } .otherwise { + }.otherwise { state := sReadLSB } } } - is (sDrain) { - when (inst_q.io.count === 0.U) { - when (xrem === 0.U) { + is(sDrain) { + when(inst_q.io.count === 0.U) { + when(xrem === 0.U) { state := sIdle - } .elsewhen (xrem < xmax) { - state := sReadCmd - rlen := xrem - ilen := xrem >> 1.U - xrem := 0.U - } .otherwise { - state := sReadCmd - rlen := xmax - 1.U - ilen := (xmax >> 1.U) - 1.U - xrem := xrem - xmax - } + }.elsewhen(xrem < xmax) { + state := sReadCmd + rlen := xrem + ilen := xrem >> 1.U + xrem := 0.U + } + .otherwise { + state := sReadCmd + rlen := xmax - 1.U + ilen := (xmax >> 1.U) - 1.U + xrem := xrem - xmax + } } } } // read instructions from dram - when (state === sIdle) { + when(state === sIdle) { raddr := io.ins_baddr - } .elsewhen (state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) { + }.elsewhen(state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) { raddr := raddr + xmax_bytes } @@ -143,7 +144,7 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module { val msb = io.vme_rd.data.bits val inst = Cat(msb, lsb) - when (state === sReadLSB) { lsb := io.vme_rd.data.bits } + when(state === sReadLSB) { lsb := io.vme_rd.data.bits } inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB inst_q.io.enq.bits := inst @@ -164,32 +165,30 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module { val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt val deq_ready = MuxLookup(deq_sel, - false.B, // default - Array( - "h_01".U -> io.inst.ld.ready, - "h_02".U -> io.inst.st.ready, - "h_04".U -> io.inst.co.ready - ) - ) + false.B, // default + Array( + "h_01".U -> io.inst.ld.ready, + "h_02".U -> io.inst.st.ready, + "h_04".U -> io.inst.co.ready + )) // dequeue instruction inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain - // debug if (debug) { - when (state === sIdle && pulse) { + when(state === sIdle && pulse) { printf("[Fetch] Launch\n") } // instruction - when (inst_q.io.deq.fire()) { - when (dec.io.isLoad) { + when(inst_q.io.deq.fire()) { + when(dec.io.isLoad) { printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits) } - when (dec.io.isCompute) { + when(dec.io.isCompute) { printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits) } - when (dec.io.isStore) { + when(dec.io.isStore) { printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits) } } diff --git a/vta/hardware/chisel/src/main/scala/core/ISA.scala b/vta/hardware/chisel/src/main/scala/core/ISA.scala index c3bf609..f08b23b 100644 --- a/vta/hardware/chisel/src/main/scala/core/ISA.scala +++ b/vta/hardware/chisel/src/main/scala/core/ISA.scala @@ -26,47 +26,46 @@ import chisel3.util._ * * These constants are used for decoding (parsing) fields on instructions. */ -trait ISAConstants -{ - val INST_BITS = 128 +trait ISAConstants { + val INST_BITS = 128 - val OP_BITS = 3 + val OP_BITS = 3 - val M_DEP_BITS = 4 - val M_ID_BITS = 2 - val M_SRAM_OFFSET_BITS = 16 - val M_DRAM_OFFSET_BITS = 32 - val M_SIZE_BITS = 16 - val M_STRIDE_BITS = 16 - val M_PAD_BITS = 4 + val M_DEP_BITS = 4 + val M_ID_BITS = 2 + val M_SRAM_OFFSET_BITS = 16 + val M_DRAM_OFFSET_BITS = 32 + val M_SIZE_BITS = 16 + val M_STRIDE_BITS = 16 + val M_PAD_BITS = 4 - val C_UOP_BGN_BITS = 13 - val C_UOP_END_BITS = 14 - val C_ITER_BITS = 14 - val C_AIDX_BITS = 11 - val C_IIDX_BITS = 11 - val C_WIDX_BITS = 10 - val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction - val C_ALU_OP_BITS = 3 - val C_ALU_IMM_BITS = 16 + val C_UOP_BGN_BITS = 13 + val C_UOP_END_BITS = 14 + val C_ITER_BITS = 14 + val C_AIDX_BITS = 11 + val C_IIDX_BITS = 11 + val C_WIDX_BITS = 10 + val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction + val C_ALU_OP_BITS = 3 + val C_ALU_IMM_BITS = 16 - val Y = true.B - val N = false.B + val Y = true.B + val N = false.B - val OP_L = 0.asUInt(OP_BITS.W) - val OP_S = 1.asUInt(OP_BITS.W) - val OP_G = 2.asUInt(OP_BITS.W) - val OP_F = 3.asUInt(OP_BITS.W) - val OP_A = 4.asUInt(OP_BITS.W) - val OP_X = 5.asUInt(OP_BITS.W) + val OP_L = 0.asUInt(OP_BITS.W) + val OP_S = 1.asUInt(OP_BITS.W) + val OP_G = 2.asUInt(OP_BITS.W) + val OP_F = 3.asUInt(OP_BITS.W) + val OP_A = 4.asUInt(OP_BITS.W) + val OP_X = 5.asUInt(OP_BITS.W) - val ALU_OP_NUM = 5 - val ALU_OP = Enum(ALU_OP_NUM) + val ALU_OP_NUM = 5 + val ALU_OP = Enum(ALU_OP_NUM) - val M_ID_U = 0.asUInt(M_ID_BITS.W) - val M_ID_W = 1.asUInt(M_ID_BITS.W) - val M_ID_I = 2.asUInt(M_ID_BITS.W) - val M_ID_A = 3.asUInt(M_ID_BITS.W) + val M_ID_U = 0.asUInt(M_ID_BITS.W) + val M_ID_W = 1.asUInt(M_ID_BITS.W) + val M_ID_I = 2.asUInt(M_ID_BITS.W) + val M_ID_A = 3.asUInt(M_ID_BITS.W) } /** ISA. @@ -79,15 +78,37 @@ trait ISAConstants * TODO: Add VXOR to clear accumulator */ object ISA { - def LUOP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000") - def LWGT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000") - def LINP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000") - def LACC = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000") - def SOUT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001") - def GEMM = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010") - def VMIN = BitPat("b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") - def VMAX = BitPat("b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") - def VADD = BitPat("b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") - def VSHX = BitPat("b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") - def FNSH = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011") + def LUOP = + BitPat( + "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000") + def LWGT = + BitPat( + "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000") + def LINP = + BitPat( + "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000") + def LACC = + BitPat( + "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000") + def SOUT = + BitPat( + "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001") + def GEMM = + BitPat( + "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010") + def VMIN = + BitPat( + "b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") + def VMAX = + BitPat( + "b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") + def VADD = + BitPat( + "b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") + def VSHX = + BitPat( + "b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") + def FNSH = + BitPat( + "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011") } diff --git a/vta/hardware/chisel/src/main/scala/core/Load.scala b/vta/hardware/chisel/src/main/scala/core/Load.scala index bbc6600..7c79498 100644 --- a/vta/hardware/chisel/src/main/scala/core/Load.scala +++ b/vta/hardware/chisel/src/main/scala/core/Load.scala @@ -54,27 +54,28 @@ class Load(debug: Boolean = false)(implicit p: Parameters) extends Module { val tensorType = Seq("inp", "wgt") val tensorDec = Seq(dec.io.isInput, dec.io.isWeight) - val tensorLoad = Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i)))) + val tensorLoad = + Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i)))) val start = inst_q.io.deq.valid & Mux(dec.io.pop_next, s.io.sready, true.B) val done = Mux(dec.io.isInput, tensorLoad(0).io.done, tensorLoad(1).io.done) // control - switch (state) { - is (sIdle) { - when (start) { - when (dec.io.isSync) { + switch(state) { + is(sIdle) { + when(start) { + when(dec.io.isSync) { state := sSync - } .elsewhen (dec.io.isInput || dec.io.isWeight) { + }.elsewhen(dec.io.isInput || dec.io.isWeight) { state := sExe } } } - is (sSync) { + is(sSync) { state := sIdle } - is (sExe) { - when (done) { + is(sExe) { + when(done) { state := sIdle } } @@ -105,24 +106,25 @@ class Load(debug: Boolean = false)(implicit p: Parameters) extends Module { // debug if (debug) { // start - when (state === sIdle && start) { - when (dec.io.isSync) { + when(state === sIdle && start) { + when(dec.io.isSync) { printf("[Load] start sync\n") - } .elsewhen (dec.io.isInput) { - printf("[Load] start input\n") - } .elsewhen (dec.io.isWeight) { - printf("[Load] start weight\n") - } + }.elsewhen(dec.io.isInput) { + printf("[Load] start input\n") + } + .elsewhen(dec.io.isWeight) { + printf("[Load] start weight\n") + } } // done - when (state === sSync) { + when(state === sSync) { printf("[Load] done sync\n") } - when (state === sExe) { - when (done) { - when (dec.io.isInput) { + when(state === sExe) { + when(done) { + when(dec.io.isInput) { printf("[Load] done input\n") - } .elsewhen (dec.io.isWeight) { + }.elsewhen(dec.io.isWeight) { printf("[Load] done weight\n") } } diff --git a/vta/hardware/chisel/src/main/scala/core/LoadUop.scala b/vta/hardware/chisel/src/main/scala/core/LoadUop.scala index bbf8cf1..fcde836 100644 --- a/vta/hardware/chisel/src/main/scala/core/LoadUop.scala +++ b/vta/hardware/chisel/src/main/scala/core/LoadUop.scala @@ -77,9 +77,9 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { val xcnt = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len)) val xlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len)) val xrem = Reg(chiselTypeOf(dec.xsize)) - val xsize = (dec.xsize >> log2Ceil(numUop)) + dec.xsize(0) + (dec.sram_offset % 2.U) - 1.U + val xsize = (dec.xsize >> log2Ceil(numUop)) + dec.xsize(0) + (dec.sram_offset % 2.U) - 1.U val xmax = (1 << mp.lenBits).U - val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U + val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U val offsetIsEven = (dec.sram_offset % 2.U) === 0.U val sizeIsEven = (dec.xsize % 2.U) === 0.U @@ -88,38 +88,39 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { val state = RegInit(sIdle) // control - switch (state) { - is (sIdle) { - when (io.start) { + switch(state) { + is(sIdle) { + when(io.start) { state := sReadCmd - when (xsize < xmax) { + when(xsize < xmax) { xlen := xsize xrem := 0.U - } .otherwise { + }.otherwise { xlen := xmax - 1.U xrem := xsize - xmax } } } - is (sReadCmd) { - when (io.vme_rd.cmd.ready) { + is(sReadCmd) { + when(io.vme_rd.cmd.ready) { state := sReadData } } - is (sReadData) { - when (io.vme_rd.data.valid) { + is(sReadData) { + when(io.vme_rd.data.valid) { when(xcnt === xlen) { - when (xrem === 0.U) { + when(xrem === 0.U) { state := sIdle - } .elsewhen (xrem < xmax) { - state := sReadCmd - xlen := xrem - xrem := 0.U - } .otherwise { - state := sReadCmd - xlen := xmax - 1.U - xrem := xrem - xmax - } + }.elsewhen(xrem < xmax) { + state := sReadCmd + xlen := xrem + xrem := 0.U + } + .otherwise { + state := sReadCmd + xlen := xmax - 1.U + xrem := xrem - xmax + } } } } @@ -127,13 +128,14 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { // read-from-dram val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt - when (state === sIdle) { - when (offsetIsEven) { + when(state === sIdle) { + when(offsetIsEven) { raddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes))) - } .otherwise { - raddr := (io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))) - uopBytes.U + }.otherwise { + raddr := (io.baddr | (maskOffset & (dec.dram_offset << log2Ceil( + uopBytes)))) - uopBytes.U } - } .elsewhen (state === sReadData && xcnt === xlen && xrem =/= 0.U) { + }.elsewhen(state === sReadData && xcnt === xlen && xrem =/= 0.U) { raddr := raddr + xmax_bytes } @@ -143,16 +145,16 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { io.vme_rd.data.ready := state === sReadData - when (state =/= sReadData) { + when(state =/= sReadData) { xcnt := 0.U - } .elsewhen (io.vme_rd.data.fire()) { + }.elsewhen(io.vme_rd.data.fire()) { xcnt := xcnt + 1.U } val waddr = Reg(UInt(log2Ceil(uopDepth).W)) - when (state === sIdle) { + when(state === sIdle) { waddr := dec.sram_offset >> log2Ceil(numUop) - } .elsewhen (io.vme_rd.data.fire()) { + }.elsewhen(io.vme_rd.data.fire()) { waddr := waddr + 1.U } @@ -160,36 +162,37 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata)) val wmask = Reg(Vec(numUop, Bool())) - when (offsetIsEven) { - when (sizeIsEven) { + when(offsetIsEven) { + when(sizeIsEven) { wmask := "b_11".U.asTypeOf(wmask) - } .elsewhen (io.vme_rd.cmd.fire()) { - when (dec.xsize === 1.U) { - wmask := "b_01".U.asTypeOf(wmask) - } .otherwise { - wmask := "b_11".U.asTypeOf(wmask) + }.elsewhen(io.vme_rd.cmd.fire()) { + when(dec.xsize === 1.U) { + wmask := "b_01".U.asTypeOf(wmask) + }.otherwise { + wmask := "b_11".U.asTypeOf(wmask) + } } - } .elsewhen (io.vme_rd.data.fire()) { - when (xcnt === xlen - 1.U) { - wmask := "b_01".U.asTypeOf(wmask) - } .otherwise { - wmask := "b_11".U.asTypeOf(wmask) + .elsewhen(io.vme_rd.data.fire()) { + when(xcnt === xlen - 1.U) { + wmask := "b_01".U.asTypeOf(wmask) + }.otherwise { + wmask := "b_11".U.asTypeOf(wmask) + } } - } - } .otherwise { - when (io.vme_rd.cmd.fire()) { + }.otherwise { + when(io.vme_rd.cmd.fire()) { wmask := "b_10".U.asTypeOf(wmask) - } .elsewhen (io.vme_rd.data.fire()) { - when (sizeIsEven && xcnt === xlen - 1.U) { + }.elsewhen(io.vme_rd.data.fire()) { + when(sizeIsEven && xcnt === xlen - 1.U) { wmask := "b_01".U.asTypeOf(wmask) - } .otherwise { + }.otherwise { wmask := "b_11".U.asTypeOf(wmask) } } } wdata := io.vme_rd.data.bits.asTypeOf(wdata) - when (io.vme_rd.data.fire()) { + when(io.vme_rd.data.fire()) { mem.write(waddr, wdata, wmask) } @@ -209,7 +212,7 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { // debug if (debug) { - when (io.vme_rd.cmd.fire()) { + when(io.vme_rd.cmd.fire()) { printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem) } } diff --git a/vta/hardware/chisel/src/main/scala/core/Semaphore.scala b/vta/hardware/chisel/src/main/scala/core/Semaphore.scala index 06df51e..f268e79 100644 --- a/vta/hardware/chisel/src/main/scala/core/Semaphore.scala +++ b/vta/hardware/chisel/src/main/scala/core/Semaphore.scala @@ -29,14 +29,17 @@ import chisel3.util._ * depending on the push and pop fields on instructions to prevent RAW and WAR * hazards. */ -class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) extends Module { +class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) + extends Module { val io = IO(new Bundle { val spost = Input(Bool()) val swait = Input(Bool()) val sready = Output(Bool()) }) val cnt = RegInit(counterInitValue.U(counterBits.W)) - when (io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) { cnt := cnt + 1.U } - when (!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U } + when(io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) { + cnt := cnt + 1.U + } + when(!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U } io.sready := cnt =/= 0.U } diff --git a/vta/hardware/chisel/src/main/scala/core/Store.scala b/vta/hardware/chisel/src/main/scala/core/Store.scala index 71d9208..04bc7f5 100644 --- a/vta/hardware/chisel/src/main/scala/core/Store.scala +++ b/vta/hardware/chisel/src/main/scala/core/Store.scala @@ -55,21 +55,21 @@ class Store(debug: Boolean = false)(implicit p: Parameters) extends Module { val done = tensorStore.io.done // control - switch (state) { - is (sIdle) { - when (start) { - when (dec.io.isSync) { + switch(state) { + is(sIdle) { + when(start) { + when(dec.io.isSync) { state := sSync - } .elsewhen (dec.io.isStore) { + }.elsewhen(dec.io.isStore) { state := sExe } } } - is (sSync) { + is(sSync) { state := sIdle } - is (sExe) { - when (done) { + is(sExe) { + when(done) { state := sIdle } } @@ -94,19 +94,19 @@ class Store(debug: Boolean = false)(implicit p: Parameters) extends Module { // debug if (debug) { // start - when (state === sIdle && start) { - when (dec.io.isSync) { + when(state === sIdle && start) { + when(dec.io.isSync) { printf("[Store] start sync\n") - } .elsewhen (dec.io.isStore) { + }.elsewhen(dec.io.isStore) { printf("[Store] start\n") } } // done - when (state === sSync) { + when(state === sSync) { printf("[Store] done sync\n") } - when (state === sExe) { - when (done) { + when(state === sExe) { + when(done) { printf("[Store] done\n") } } diff --git a/vta/hardware/chisel/src/main/scala/core/TensorAlu.scala b/vta/hardware/chisel/src/main/scala/core/TensorAlu.scala index fbb0578..b438641 100644 --- a/vta/hardware/chisel/src/main/scala/core/TensorAlu.scala +++ b/vta/hardware/chisel/src/main/scala/core/TensorAlu.scala @@ -116,7 +116,8 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { val acc = new TensorMaster(tensorType = "acc") val out = new TensorMaster(tensorType = "out") }) - val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil = Enum(6) + val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil = + Enum(6) val state = RegInit(sIdle) val alu = Module(new AluVector) val dec = io.inst.asTypeOf(new AluDecode) @@ -132,81 +133,86 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { val src_i = Reg(chiselTypeOf(dec.uop_end)) val done = state === sExe & - alu.io.out.data.valid & - (cnt_o === dec.lp_0 - 1.U) & - (cnt_i === dec.lp_1 - 1.U) & - (uop_idx === uop_end - 1.U) - - switch (state) { - is (sIdle) { - when (io.start) { + alu.io.out.data.valid & + (cnt_o === dec.lp_0 - 1.U) & + (cnt_i === dec.lp_1 - 1.U) & + (uop_idx === uop_end - 1.U) + + switch(state) { + is(sIdle) { + when(io.start) { state := sReadUop } } - is (sReadUop) { + is(sReadUop) { state := sComputeIdx } - is (sComputeIdx) { + is(sComputeIdx) { state := sReadTensorA } - is (sReadTensorA) { + is(sReadTensorA) { state := sReadTensorB } - is (sReadTensorB) { + is(sReadTensorB) { state := sExe } - is (sExe) { - when (alu.io.out.data.valid) { - when ((cnt_o === dec.lp_0 - 1.U) && - (cnt_i === dec.lp_1 - 1.U) && - (uop_idx === uop_end - 1.U)) { + is(sExe) { + when(alu.io.out.data.valid) { + when( + (cnt_o === dec.lp_0 - 1.U) && + (cnt_i === dec.lp_1 - 1.U) && + (uop_idx === uop_end - 1.U)) { state := sIdle - } .otherwise { + }.otherwise { state := sReadUop } } } } - when (state === sIdle || - (state === sExe && - alu.io.out.data.valid && - uop_idx === uop_end - 1.U)) { + when( + state === sIdle || + (state === sExe && + alu.io.out.data.valid && + uop_idx === uop_end - 1.U)) { uop_idx := dec.uop_begin - } .elsewhen (state === sExe && alu.io.out.data.valid) { + }.elsewhen(state === sExe && alu.io.out.data.valid) { uop_idx := uop_idx + 1.U } - when (state === sIdle) { + when(state === sIdle) { cnt_o := 0.U dst_o := 0.U src_o := 0.U - } .elsewhen (state === sExe && - alu.io.out.data.valid && - uop_idx === uop_end - 1.U && - cnt_i === dec.lp_1 - 1.U) { + }.elsewhen( + state === sExe && + alu.io.out.data.valid && + uop_idx === uop_end - 1.U && + cnt_i === dec.lp_1 - 1.U) { cnt_o := cnt_o + 1.U dst_o := dst_o + dec.dst_0 src_o := src_o + dec.src_0 } - when (state === sIdle) { + when(state === sIdle) { cnt_i := 0.U dst_i := 0.U src_i := 0.U - } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) { - cnt_i := 0.U - dst_i := dst_o - src_i := src_o - } .elsewhen (state === sExe && - alu.io.out.data.valid && - uop_idx === uop_end - 1.U) { - cnt_i := cnt_i + 1.U - dst_i := dst_i + dec.dst_1 - src_i := src_i + dec.src_1 - } + }.elsewhen(state === sReadUop && cnt_i === dec.lp_1) { + cnt_i := 0.U + dst_i := dst_o + src_i := src_o + } + .elsewhen( + state === sExe && + alu.io.out.data.valid && + uop_idx === uop_end - 1.U) { + cnt_i := cnt_i + 1.U + dst_i := dst_i + dec.dst_1 + src_i := src_i + dec.src_1 + } - when (state === sComputeIdx && io.uop.data.valid) { + when(state === sComputeIdx && io.uop.data.valid) { uop_dst := io.uop.data.bits.u0 + dst_i uop_src := io.uop.data.bits.u1 + src_i } @@ -222,17 +228,25 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { // imm val tensorImm = Wire(new TensorClientData(tensorType = "acc")) tensorImm.data.valid := state === sReadTensorB - tensorImm.data.bits.foreach { b => b.foreach { c => c := dec.alu_imm } } + tensorImm.data.bits.foreach { b => + b.foreach { c => + c := dec.alu_imm + } + } // alu val isSHR = dec.alu_op === ALU_OP(3) - val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS-1) + val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1) val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op)) alu.io.opcode := fixme_alu_op alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB alu.io.acc_a.data.bits <> io.acc.rd.data.bits - alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, tensorImm.data.valid, io.acc.rd.data.valid & state === sExe) - alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, tensorImm.data.bits, io.acc.rd.data.bits) + alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, + tensorImm.data.valid, + io.acc.rd.data.valid & state === sExe) + alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, + tensorImm.data.bits, + io.acc.rd.data.bits) // acc_o io.acc.wr.valid := alu.io.acc_y.data.valid @@ -249,47 +263,51 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { if (debug) { - when (state === sReadUop) { + when(state === sReadUop) { printf("[TensorAlu] [uop] idx:%x\n", uop_idx) } - when (state === sReadTensorA) { + when(state === sReadTensorA) { printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src) } - when (state === sIdle && io.start) { + when(state === sIdle && io.start) { printf(p"[TensorAlu] decode:$dec\n") } alu.io.acc_a.data.bits.foreach { tensor => - tensor.zipWithIndex.foreach { case(elem, i) => - when (alu.io.acc_a.data.valid) { - printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem) - } + tensor.zipWithIndex.foreach { + case (elem, i) => + when(alu.io.acc_a.data.valid) { + printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem) + } } } alu.io.acc_b.data.bits.foreach { tensor => - tensor.zipWithIndex.foreach { case(elem, i) => - when (alu.io.acc_b.data.valid) { - printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem) - } + tensor.zipWithIndex.foreach { + case (elem, i) => + when(alu.io.acc_b.data.valid) { + printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem) + } } } alu.io.acc_y.data.bits.foreach { tensor => - tensor.zipWithIndex.foreach { case(elem, i) => - when (alu.io.acc_y.data.valid) { - printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem) - } + tensor.zipWithIndex.foreach { + case (elem, i) => + when(alu.io.acc_y.data.valid) { + printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem) + } } } alu.io.out.data.bits.foreach { tensor => - tensor.zipWithIndex.foreach { case(elem, i) => - when (alu.io.out.data.valid) { - printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem) - } + tensor.zipWithIndex.foreach { + case (elem, i) => + when(alu.io.out.data.valid) { + printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem) + } } } } diff --git a/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala index 051e011..3f5f387 100644 --- a/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala +++ b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala @@ -62,8 +62,10 @@ class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module { } /** Pipelined DotProduct based on MAC and PipeAdder */ -class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module { - val errorMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n" +class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) + extends Module { + val errorMsg = + s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n" require(size >= 2 && isPow2(size), errorMsg) val b = aBits + bBits val outBits = b + log2Ceil(size) + 1 @@ -72,12 +74,15 @@ class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module val b = Input(Vec(size, SInt(bBits.W))) val y = Output(SInt(outBits.W)) }) - val s = Seq.tabulate(log2Ceil(size + 1))(i => pow(2, log2Ceil(size) - i).toInt) // # of total layers + val s = Seq.tabulate(log2Ceil(size + 1))(i => + pow(2, log2Ceil(size) - i).toInt) // # of total layers val p = log2Ceil(size / 2) + 1 // # of adder layers val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs - val a = Seq.tabulate(p)(i => - Seq.fill(s(i + 1))(Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1)))) - ) // # adders within each layer + val a = Seq.tabulate(p)( + i => + Seq.fill(s(i + 1))(Module(new PipeAdder( + aBits = (b + i + 1), + bBits = (b + i + 1))))) // # adders within each layer // Vector MACs for (i <- 0 until s(0)) { @@ -111,7 +116,7 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module { val inpBits = p(CoreKey).inpBits val wgtBits = p(CoreKey).wgtBits val outBits = p(CoreKey).outBits - val io = IO(new Bundle{ + val io = IO(new Bundle { val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr val inp = new TensorMasterData(tensorType = "inp") val wgt = new TensorMasterData(tensorType = "wgt") @@ -119,8 +124,10 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module { val acc_o = new TensorClientData(tensorType = "acc") val out = new TensorClientData(tensorType = "out") }) - val dot = Seq.fill(size)(Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size))) - val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1))) + val dot = Seq.fill(size)( + Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size))) + val acc = Seq.fill(size)( + Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1))) val add = Seq.fill(size)(Wire(SInt(accBits.W))) val vld = Wire(Vec(size, Bool())) @@ -149,7 +156,8 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module { * Also, the TensorGemm uses the reset field in the Gemm instruction to * clear or zero-out the acc-scratchpad locations based on the micro-ops. */ -class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module { +class TensorGemm(debug: Boolean = false)(implicit p: Parameters) + extends Module { val io = IO(new Bundle { val start = Input(Bool()) val done = Output(Bool()) @@ -160,7 +168,8 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module val acc = new TensorMaster(tensorType = "acc") val out = new TensorMaster(tensorType = "out") }) - val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6) + val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = + Enum(6) val state = RegInit(sIdle) val mvc = Module(new MatrixVectorMultiplication) val dec = io.inst.asTypeOf(new GemmDecode) @@ -181,99 +190,103 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module val inflight = Reg(UInt(pBits.W)) val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits)) val done = inflight === 0.U & - ((state === sExe & - cnt_o === dec.lp_0 - 1.U & - cnt_i === dec.lp_1 - 1.U & - uop_idx === uop_end - 1.U & - inflight === 0.U) | - state === sWait) - - switch (state) { - is (sIdle) { - when (io.start) { + ((state === sExe & + cnt_o === dec.lp_0 - 1.U & + cnt_i === dec.lp_1 - 1.U & + uop_idx === uop_end - 1.U & + inflight === 0.U) | + state === sWait) + + switch(state) { + is(sIdle) { + when(io.start) { state := sReadUop } } - is (sReadUop) { + is(sReadUop) { state := sComputeIdx } - is (sComputeIdx) { + is(sComputeIdx) { state := sReadTensor } - is (sReadTensor) { + is(sReadTensor) { state := sExe } - is (sExe) { - when ((cnt_o === dec.lp_0 - 1.U) && - (cnt_i === dec.lp_1 - 1.U) && - (uop_idx === uop_end - 1.U)) { - when (inflight =/= 0.U) { + is(sExe) { + when( + (cnt_o === dec.lp_0 - 1.U) && + (cnt_i === dec.lp_1 - 1.U) && + (uop_idx === uop_end - 1.U)) { + when(inflight =/= 0.U) { state := sWait - } .otherwise { + }.otherwise { state := sIdle } - } .otherwise { + }.otherwise { state := sReadUop } } - is (sWait) { - when (inflight === 0.U) { + is(sWait) { + when(inflight === 0.U) { state := sIdle } } } - when (state === sIdle) { + when(state === sIdle) { inflight := 0.U - } .elsewhen (!dec.reset) { - when (state === sReadTensor) { // issue a tensor + }.elsewhen(!dec.reset) { + when(state === sReadTensor) { // issue a tensor inflight := inflight + 1.U - } .elsewhen (mvc.io.acc_o.data.valid) { // commit a tensor + }.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor inflight := inflight - 1.U } } - when (state === sIdle || - (state === sExe && - uop_idx === uop_end - 1.U)) { + when( + state === sIdle || + (state === sExe && + uop_idx === uop_end - 1.U)) { uop_idx := dec.uop_begin - } .elsewhen (state === sExe) { + }.elsewhen(state === sExe) { uop_idx := uop_idx + 1.U } - when (state === sIdle) { + when(state === sIdle) { cnt_o := 0.U acc_o := 0.U inp_o := 0.U wgt_o := 0.U - } .elsewhen (state === sExe && - uop_idx === uop_end - 1.U && - cnt_i === dec.lp_1 - 1.U) { + }.elsewhen( + state === sExe && + uop_idx === uop_end - 1.U && + cnt_i === dec.lp_1 - 1.U) { cnt_o := cnt_o + 1.U acc_o := acc_o + dec.acc_0 inp_o := inp_o + dec.inp_0 wgt_o := wgt_o + dec.wgt_0 } - when (state === sIdle) { + when(state === sIdle) { cnt_i := 0.U acc_i := 0.U inp_i := 0.U wgt_i := 0.U - } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) { - cnt_i := 0.U - acc_i := acc_o - inp_i := inp_o - wgt_i := wgt_o - } .elsewhen (state === sExe && - uop_idx === uop_end - 1.U) { - cnt_i := cnt_i + 1.U - acc_i := acc_i + dec.acc_1 - inp_i := inp_i + dec.inp_1 - wgt_i := wgt_i + dec.wgt_1 - } + }.elsewhen(state === sReadUop && cnt_i === dec.lp_1) { + cnt_i := 0.U + acc_i := acc_o + inp_i := inp_o + wgt_i := wgt_o + } + .elsewhen(state === sExe && + uop_idx === uop_end - 1.U) { + cnt_i := cnt_i + 1.U + acc_i := acc_i + dec.acc_1 + inp_i := inp_i + dec.inp_1 + wgt_i := wgt_i + dec.wgt_1 + } - when (state === sComputeIdx && io.uop.data.valid) { + when(state === sComputeIdx && io.uop.data.valid) { uop_acc := io.uop.data.bits.u0 + acc_i uop_inp := io.uop.data.bits.u1 + inp_i uop_wgt := io.uop.data.bits.u2 + wgt_i @@ -307,7 +320,9 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module mvc.io.acc_i.data <> io.acc.rd.data // acc_o - io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset, true.B, wrpipe.io.deq.valid) + io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset, + true.B, + wrpipe.io.deq.valid) io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits) io.acc.wr.bits.data <> mvc.io.acc_o.data.bits @@ -320,47 +335,55 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module io.done := done if (debug) { - when (state === sReadUop && ~dec.reset) { + when(state === sReadUop && ~dec.reset) { printf("[TensorGemm] [uop] idx:%x\n", uop_idx) } - when (state === sReadTensor && ~dec.reset) { - printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt) + when(state === sReadTensor && ~dec.reset) { + printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", + uop_acc, + uop_inp, + uop_wgt) } - io.inp.rd.data.bits.zipWithIndex.foreach { case(r, i) => - when (io.inp.rd.data.valid && ~dec.reset) { - printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt) - } + io.inp.rd.data.bits.zipWithIndex.foreach { + case (r, i) => + when(io.inp.rd.data.valid && ~dec.reset) { + printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt) + } } - io.wgt.rd.data.bits.zipWithIndex.foreach { case(r, i) => - when (io.wgt.rd.data.valid && ~dec.reset) { - printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt) - } + io.wgt.rd.data.bits.zipWithIndex.foreach { + case (r, i) => + when(io.wgt.rd.data.valid && ~dec.reset) { + printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt) + } } io.acc.rd.data.bits.foreach { tensor => - tensor.zipWithIndex.foreach { case(elem, i) => - when (io.acc.rd.data.valid && ~dec.reset) { - printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem) - } + tensor.zipWithIndex.foreach { + case (elem, i) => + when(io.acc.rd.data.valid && ~dec.reset) { + printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem) + } } } mvc.io.acc_o.data.bits.foreach { tensor => - tensor.zipWithIndex.foreach { case(elem, i) => - when (mvc.io.acc_o.data.valid && ~dec.reset) { - printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem) - } + tensor.zipWithIndex.foreach { + case (elem, i) => + when(mvc.io.acc_o.data.valid && ~dec.reset) { + printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem) + } } } mvc.io.out.data.bits.foreach { tensor => - tensor.zipWithIndex.foreach { case(elem, i) => - when (mvc.io.out.data.valid && ~dec.reset) { - printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem) - } + tensor.zipWithIndex.foreach { + case (elem, i) => + when(mvc.io.out.data.valid && ~dec.reset) { + printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem) + } } } } diff --git a/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala b/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala index 8f1956f..ca6803c 100644 --- a/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala +++ b/vta/hardware/chisel/src/main/scala/core/TensorLoad.scala @@ -32,8 +32,9 @@ import vta.shell._ * managed by TensorPadCtrl. The TensorDataCtrl is in charge of * handling the way tensors are stored on the scratchpads. */ -class TensorLoad(tensorType: String = "none", debug: Boolean = false) - (implicit p: Parameters) extends Module { +class TensorLoad(tensorType: String = "none", debug: Boolean = false)( + implicit p: Parameters) + extends Module { val tp = new TensorParams(tensorType) val mp = p(ShellKey).memParams val io = IO(new Bundle { @@ -48,7 +49,8 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false) val strideFactor = tp.tensorLength * tp.tensorWidth val dec = io.inst.asTypeOf(new MemDecode) - val dataCtrl = Module(new TensorDataCtrl(tensorType, sizeFactor, strideFactor)) + val dataCtrl = Module( + new TensorDataCtrl(tensorType, sizeFactor, strideFactor)) val dataCtrlDone = RegInit(false.B) val yPadCtrl0 = Module(new TensorPadCtrl(padType = "YPad0", sizeFactor)) val yPadCtrl1 = Module(new TensorPadCtrl(padType = "YPad1", sizeFactor)) @@ -58,81 +60,85 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false) val tag = Reg(UInt(log2Ceil(tp.numMemBlock).W)) val set = Reg(UInt(log2Ceil(tp.tensorLength).W)) - val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil = Enum(7) + val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil = + Enum(7) val state = RegInit(sIdle) // control - switch (state) { - is (sIdle) { - when (io.start) { - when (dec.ypad_0 =/= 0.U) { + switch(state) { + is(sIdle) { + when(io.start) { + when(dec.ypad_0 =/= 0.U) { state := sYPad0 - } .elsewhen (dec.xpad_0 =/= 0.U) { - state := sXPad0 - } .otherwise { - state := sReadCmd - } + }.elsewhen(dec.xpad_0 =/= 0.U) { + state := sXPad0 + } + .otherwise { + state := sReadCmd + } } } - is (sYPad0) { - when (yPadCtrl0.io.done) { - when (dec.xpad_0 =/= 0.U) { + is(sYPad0) { + when(yPadCtrl0.io.done) { + when(dec.xpad_0 =/= 0.U) { state := sXPad0 - } .otherwise { + }.otherwise { state := sReadCmd } } } - is (sXPad0) { - when (xPadCtrl0.io.done) { + is(sXPad0) { + when(xPadCtrl0.io.done) { state := sReadCmd } } - is (sReadCmd) { - when (io.vme_rd.cmd.ready) { + is(sReadCmd) { + when(io.vme_rd.cmd.ready) { state := sReadData } } - is (sReadData) { - when (io.vme_rd.data.valid) { - when (dataCtrl.io.done) { - when (dec.xpad_1 =/= 0.U) { + is(sReadData) { + when(io.vme_rd.data.valid) { + when(dataCtrl.io.done) { + when(dec.xpad_1 =/= 0.U) { state := sXPad1 - } .elsewhen (dec.ypad_1 =/= 0.U) { - state := sYPad1 - } .otherwise { - state := sIdle - } - } .elsewhen (dataCtrl.io.stride || dataCtrl.io.split) { - when (dec.xpad_1 =/= 0.U) { + }.elsewhen(dec.ypad_1 =/= 0.U) { + state := sYPad1 + } + .otherwise { + state := sIdle + } + }.elsewhen(dataCtrl.io.stride || dataCtrl.io.split) { + when(dec.xpad_1 =/= 0.U) { state := sXPad1 - } .elsewhen (dec.xpad_0 =/= 0.U) { - state := sXPad0 - } .otherwise { + }.elsewhen(dec.xpad_0 =/= 0.U) { + state := sXPad0 + } + .otherwise { state := sReadCmd - } + } } } } - is (sXPad1) { - when (xPadCtrl1.io.done) { - when (dataCtrlDone) { - when (dec.ypad_1 =/= 0.U) { + is(sXPad1) { + when(xPadCtrl1.io.done) { + when(dataCtrlDone) { + when(dec.ypad_1 =/= 0.U) { state := sYPad1 - } .otherwise { + }.otherwise { state := sIdle } - } .otherwise { - when (dec.xpad_0 =/= 0.U) { + }.otherwise { + when(dec.xpad_0 =/= 0.U) { state := sXPad0 - } .otherwise { + }.otherwise { state := sReadCmd } } } } - is (sYPad1) { - when (yPadCtrl1.io.done && dataCtrlDone) { + is(sYPad1) { + when(yPadCtrl1.io.done && dataCtrlDone) { state := sIdle } } @@ -146,9 +152,9 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false) dataCtrl.io.xupdate := io.vme_rd.data.fire() dataCtrl.io.yupdate := io.vme_rd.data.fire() - when (state === sIdle) { + when(state === sIdle) { dataCtrlDone := false.B - } .elsewhen (io.vme_rd.data.fire() && dataCtrl.io.done) { + }.elsewhen(io.vme_rd.data.fire() && dataCtrl.io.done) { dataCtrlDone := true.B } @@ -156,18 +162,19 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false) yPadCtrl0.io.start := dec.ypad_0 =/= 0.U & state === sIdle & io.start yPadCtrl1.io.start := dec.ypad_1 =/= 0.U & - ((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) | - (state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone)) + ((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) | + (state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone)) xPadCtrl0.io.start := dec.xpad_0 =/= 0.U & - ((state === sIdle & io.start) | - (state === sYPad0 & yPadCtrl0.io.done) | - (io.vme_rd.data.fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) | - (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone)) + ((state === sIdle & io.start) | + (state === sYPad0 & yPadCtrl0.io.done) | + (io.vme_rd.data + .fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) | + (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone)) xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() & - ((dataCtrl.io.done) | - (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U)) + ((dataCtrl.io.done) | + (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U)) yPadCtrl0.io.inst := io.inst yPadCtrl1.io.inst := io.inst @@ -183,39 +190,49 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false) // write-to-sram val isZeroPad = state === sYPad0 | - state === sXPad0 | - state === sXPad1 | - state === sYPad1 + state === sXPad0 | + state === sXPad1 | + state === sYPad1 - when (state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) { + when(state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) { tag := 0.U - } .elsewhen (io.vme_rd.data.fire() || isZeroPad) { + }.elsewhen(io.vme_rd.data.fire() || isZeroPad) { tag := tag + 1.U } - when (state === sIdle || dataCtrlDone || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) { + when( + state === sIdle || dataCtrlDone || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) { set := 0.U - } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) { + }.elsewhen( + (io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) { set := set + 1.U } val waddr_cur = Reg(UInt(tp.memAddrBits.W)) val waddr_nxt = Reg(UInt(tp.memAddrBits.W)) - when (state === sIdle) { + when(state === sIdle) { waddr_cur := dec.sram_offset waddr_nxt := dec.sram_offset - } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) { - waddr_cur := waddr_cur + 1.U - } .elsewhen (dataCtrl.io.stride) { - waddr_cur := waddr_nxt + dec.xsize - waddr_nxt := waddr_nxt + dec.xsize - } + }.elsewhen((io.vme_rd.data + .fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) { + waddr_cur := waddr_cur + 1.U + } + .elsewhen(dataCtrl.io.stride) { + waddr_cur := waddr_nxt + dec.xsize + waddr_nxt := waddr_nxt + dec.xsize + } - val tensorFile = Seq.fill(tp.tensorLength) { SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) } + val tensorFile = Seq.fill(tp.tensorLength) { + SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) + } val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) } - val wdata = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) } + val wdata = Seq.fill(tp.tensorLength) { + Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) + } val no_mask = Wire(Vec(tp.numMemBlock, Bool())) - no_mask.foreach { m => m := true.B } + no_mask.foreach { m => + m := true.B + } for (i <- 0 until tp.tensorLength) { for (j <- 0 until tp.numMemBlock) { @@ -223,11 +240,14 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false) wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits) } val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i)) - val muxWen = Mux(state === sIdle, io.tensor.wr.valid, (io.vme_rd.data.fire() | isZeroPad) & set === i.U) + val muxWen = + Mux(state === sIdle, + io.tensor.wr.valid, + (io.vme_rd.data.fire() | isZeroPad) & set === i.U) val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur) val muxWdata = Mux(state === sIdle, tdata, wdata(i)) val muxWmask = Mux(state === sIdle, no_mask, wmask(i)) - when (muxWen) { + when(muxWen) { tensorFile(i).write(muxWaddr, muxWdata, muxWmask) } } @@ -236,13 +256,16 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false) val rvalid = RegNext(io.tensor.rd.idx.valid) io.tensor.rd.data.valid := rvalid - val rdata = tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid)) - rdata.zipWithIndex.foreach { case(r, i) => - io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i)) + val rdata = + tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid)) + rdata.zipWithIndex.foreach { + case (r, i) => + io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i)) } // done - val done_no_pad = io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U + val done_no_pad = io.vme_rd.data + .fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U val done_x_pad = state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone & dec.ypad_1 === 0.U val done_y_pad = state === sYPad1 & dataCtrlDone & yPadCtrl1.io.done io.done := done_no_pad | done_x_pad | done_y_pad @@ -250,28 +273,34 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false) // debug if (debug) { if (tensorType == "inp") { - when (io.vme_rd.cmd.fire()) { - printf("[TensorLoad] [inp] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len) + when(io.vme_rd.cmd.fire()) { + printf("[TensorLoad] [inp] cmd addr:%x len:%x\n", + dataCtrl.io.addr, + dataCtrl.io.len) } - when (state === sYPad0) { + when(state === sYPad0) { printf("[TensorLoad] [inp] sYPad0\n") } - when (state === sYPad1) { + when(state === sYPad1) { printf("[TensorLoad] [inp] sYPad1\n") } - when (state === sXPad0) { + when(state === sXPad0) { printf("[TensorLoad] [inp] sXPad0\n") } - when (state === sXPad1) { + when(state === sXPad1) { printf("[TensorLoad] [inp] sXPad1\n") } } else if (tensorType == "wgt") { - when (io.vme_rd.cmd.fire()) { - printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len) + when(io.vme_rd.cmd.fire()) { + printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n", + dataCtrl.io.addr, + dataCtrl.io.len) } } else if (tensorType == "acc") { - when (io.vme_rd.cmd.fire()) { - printf("[TensorLoad] [acc] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len) + when(io.vme_rd.cmd.fire()) { + printf("[TensorLoad] [acc] cmd addr:%x len:%x\n", + dataCtrl.io.addr, + dataCtrl.io.len) } } } diff --git a/vta/hardware/chisel/src/main/scala/core/TensorStore.scala b/vta/hardware/chisel/src/main/scala/core/TensorStore.scala index 20ff6f4..083a70c 100644 --- a/vta/hardware/chisel/src/main/scala/core/TensorStore.scala +++ b/vta/hardware/chisel/src/main/scala/core/TensorStore.scala @@ -28,8 +28,9 @@ import vta.shell._ * * Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM). */ -class TensorStore(tensorType: String = "none", debug: Boolean = false) - (implicit p: Parameters) extends Module { +class TensorStore(tensorType: String = "none", debug: Boolean = false)( + implicit p: Parameters) + extends Module { val tp = new TensorParams(tensorType) val mp = p(ShellKey).memParams val io = IO(new Bundle { @@ -53,9 +54,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false) val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len)) val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len)) val xrem = Reg(chiselTypeOf(dec.xsize)) - val xsize = (dec.xsize << log2Ceil(tensorLength*numMemBlock)) - 1.U + val xsize = (dec.xsize << log2Ceil(tensorLength * numMemBlock)) - 1.U val xmax = (1 << mp.lenBits).U - val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U + val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U val ycnt = Reg(chiselTypeOf(dec.ysize)) val ysize = dec.ysize val tag = Reg(UInt(8.W)) @@ -65,132 +66,147 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false) val state = RegInit(sIdle) // control - switch (state) { - is (sIdle) { - when (io.start) { + switch(state) { + is(sIdle) { + when(io.start) { state := sWriteCmd - when (xsize < xmax) { + when(xsize < xmax) { xlen := xsize xrem := 0.U - } .otherwise { + }.otherwise { xlen := xmax - 1.U xrem := xsize - xmax } } } - is (sWriteCmd) { - when (io.vme_wr.cmd.ready) { + is(sWriteCmd) { + when(io.vme_wr.cmd.ready) { state := sWriteData } } - is (sWriteData) { - when (io.vme_wr.data.ready) { - when (xcnt === xlen) { + is(sWriteData) { + when(io.vme_wr.data.ready) { + when(xcnt === xlen) { state := sWriteAck - } .elsewhen (tag === (numMemBlock - 1).U) { + }.elsewhen(tag === (numMemBlock - 1).U) { state := sReadMem } } } - is (sReadMem) { + is(sReadMem) { state := sWriteData } - is (sWriteAck) { - when (io.vme_wr.ack) { - when (xrem === 0.U) { - when (ycnt === ysize - 1.U) { + is(sWriteAck) { + when(io.vme_wr.ack) { + when(xrem === 0.U) { + when(ycnt === ysize - 1.U) { state := sIdle - } .otherwise { + }.otherwise { state := sWriteCmd - when (xsize < xmax) { + when(xsize < xmax) { xlen := xsize xrem := 0.U - } .otherwise { + }.otherwise { xlen := xmax - 1.U xrem := xsize - xmax } } - } .elsewhen (xrem < xmax) { - state := sWriteCmd - xlen := xrem - xrem := 0.U - } .otherwise { - state := sWriteCmd - xlen := xmax - 1.U - xrem := xrem - xmax - } + }.elsewhen(xrem < xmax) { + state := sWriteCmd + xlen := xrem + xrem := 0.U + } + .otherwise { + state := sWriteCmd + xlen := xmax - 1.U + xrem := xrem - xmax + } } } } // write-to-sram - val tensorFile = Seq.fill(tensorLength) { SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W))) } + val tensorFile = Seq.fill(tensorLength) { + SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W))) + } val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W))) val no_mask = Wire(Vec(numMemBlock, Bool())) wdata_t := DontCare - no_mask.foreach { m => m := true.B } + no_mask.foreach { m => + m := true.B + } for (i <- 0 until tensorLength) { val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t) - when (io.tensor.wr.valid) { + when(io.tensor.wr.valid) { tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask) } } // read-from-sram val stride = state === sWriteAck & - io.vme_wr.ack & - xcnt === xlen + 1.U & - xrem === 0.U & - ycnt =/= ysize - 1.U + io.vme_wr.ack & + xcnt === xlen + 1.U & + xrem === 0.U & + ycnt =/= ysize - 1.U - when (state === sIdle) { + when(state === sIdle) { ycnt := 0.U - } .elsewhen (stride) { + }.elsewhen(stride) { ycnt := ycnt + 1.U } - when (state === sWriteCmd || tag === (numMemBlock - 1).U) { + when(state === sWriteCmd || tag === (numMemBlock - 1).U) { tag := 0.U - } .elsewhen (io.vme_wr.data.fire()) { + }.elsewhen(io.vme_wr.data.fire()) { tag := tag + 1.U } - when (state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) { + when( + state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) { set := 0.U - } .elsewhen (io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) { + }.elsewhen(io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) { set := set + 1.U } val raddr_cur = Reg(UInt(tp.memAddrBits.W)) val raddr_nxt = Reg(UInt(tp.memAddrBits.W)) - when (state === sIdle) { + when(state === sIdle) { raddr_cur := dec.sram_offset raddr_nxt := dec.sram_offset - } .elsewhen (io.vme_wr.data.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) { - raddr_cur := raddr_cur + 1.U - } .elsewhen (stride) { - raddr_cur := raddr_nxt + dec.xsize - raddr_nxt := raddr_nxt + dec.xsize - } + }.elsewhen(io.vme_wr.data + .fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) { + raddr_cur := raddr_cur + 1.U + } + .elsewhen(stride) { + raddr_cur := raddr_nxt + dec.xsize + raddr_nxt := raddr_nxt + dec.xsize + } - val tread = Seq.tabulate(tensorLength) { i => i.U -> - tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem) } + val tread = Seq.tabulate(tensorLength) { i => + i.U -> + tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem) + } val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread) // write-to-dram val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt val elemBytes = (p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).outBits) / 8 - when (state === sIdle) { - waddr_cur := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes))) - waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes))) - } .elsewhen (state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) { - waddr_cur := waddr_cur + xmax_bytes - } .elsewhen (stride) { - waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth)) - waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth)) - } + when(state === sIdle) { + waddr_cur := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil( + elemBytes))) + waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil( + elemBytes))) + }.elsewhen(state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) { + waddr_cur := waddr_cur + xmax_bytes + } + .elsewhen(stride) { + waddr_cur := waddr_nxt + (dec.xstride << log2Ceil( + tensorLength * tensorWidth)) + waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil( + tensorLength * tensorWidth)) + } io.vme_wr.cmd.valid := state === sWriteCmd io.vme_wr.cmd.bits.addr := waddr_cur @@ -199,9 +215,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false) io.vme_wr.data.valid := state === sWriteData io.vme_wr.data.bits := mdata(tag) - when (state === sWriteCmd) { + when(state === sWriteCmd) { xcnt := 0.U - } .elsewhen (io.vme_wr.data.fire()) { + }.elsewhen(io.vme_wr.data.fire()) { xcnt := xcnt + 1.U } @@ -213,13 +229,19 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false) // debug if (debug) { - when (io.vme_wr.cmd.fire()) { - printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem) + when(io.vme_wr.cmd.fire()) { + printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", + ysize, + ycnt, + raddr_cur, + waddr_cur, + xlen, + xrem) } - when (io.vme_wr.data.fire()) { + when(io.vme_wr.data.fire()) { printf("[TensorStore] data:%x\n", io.vme_wr.data.bits) } - when (io.vme_wr.ack) { + when(io.vme_wr.ack) { printf("[TensorStore] ack\n") } } diff --git a/vta/hardware/chisel/src/main/scala/core/TensorUtil.scala b/vta/hardware/chisel/src/main/scala/core/TensorUtil.scala index bb03846..1f00554 100644 --- a/vta/hardware/chisel/src/main/scala/core/TensorUtil.scala +++ b/vta/hardware/chisel/src/main/scala/core/TensorUtil.scala @@ -30,11 +30,14 @@ import vta.shell._ * weights (wgt), biases (acc), and outputs (out). This is used to avoid * doing the same boring calculations over and over again. */ -class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends Bundle { - val errorMsg = s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n" +class TensorParams(tensorType: String = "none")(implicit p: Parameters) + extends Bundle { + val errorMsg = + s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n" - require (tensorType == "inp" || tensorType == "wgt" - || tensorType == "acc" || tensorType == "out", errorMsg) + require(tensorType == "inp" || tensorType == "wgt" + || tensorType == "acc" || tensorType == "out", + errorMsg) val (tensorLength, tensorWidth, tensorElemBits) = if (tensorType == "inp") @@ -69,25 +72,30 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends * biases (acc), and outputs (out). * */ -class TensorMaster(tensorType: String = "none") - (implicit p: Parameters) extends TensorParams(tensorType) { - val rd = new Bundle { - val idx = ValidIO(UInt(memAddrBits.W)) - val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))) - } - val wr = ValidIO(new Bundle { - val idx = UInt(memAddrBits.W) - val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))) - }) - def tieoffRead() { - rd.idx.valid := false.B - rd.idx.bits := 0.U - } - def tieoffWrite() { - wr.valid := false.B - wr.bits.idx := 0.U - wr.bits.data.foreach { b => b.foreach { c => c := 0.U } } +class TensorMaster(tensorType: String = "none")(implicit p: Parameters) + extends TensorParams(tensorType) { + val rd = new Bundle { + val idx = ValidIO(UInt(memAddrBits.W)) + val data = Flipped( + ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))) + } + val wr = ValidIO(new Bundle { + val idx = UInt(memAddrBits.W) + val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))) + }) + def tieoffRead() { + rd.idx.valid := false.B + rd.idx.bits := 0.U + } + def tieoffWrite() { + wr.valid := false.B + wr.bits.idx := 0.U + wr.bits.data.foreach { b => + b.foreach { c => + c := 0.U + } } + } override def cloneType = new TensorMaster(tensorType).asInstanceOf[this.type] } @@ -98,20 +106,25 @@ class TensorMaster(tensorType: String = "none") * The TensorLoad unit uses this interface for receiving read and write requests from * the TensorGemm unit. */ -class TensorClient(tensorType: String = "none") - (implicit p: Parameters) extends TensorParams(tensorType) { - val rd = new Bundle { - val idx = Flipped(ValidIO(UInt(memAddrBits.W))) - val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))) - } - val wr = Flipped(ValidIO(new Bundle { - val idx = UInt(memAddrBits.W) - val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))) - })) - def tieoffRead() { - rd.data.valid := false.B - rd.data.bits.foreach { b => b.foreach { c => c := 0.U } } +class TensorClient(tensorType: String = "none")(implicit p: Parameters) + extends TensorParams(tensorType) { + val rd = new Bundle { + val idx = Flipped(ValidIO(UInt(memAddrBits.W))) + val data = ValidIO( + Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))) + } + val wr = Flipped(ValidIO(new Bundle { + val idx = UInt(memAddrBits.W) + val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))) + })) + def tieoffRead() { + rd.data.valid := false.B + rd.data.bits.foreach { b => + b.foreach { c => + c := 0.U + } } + } override def cloneType = new TensorClient(tensorType).asInstanceOf[this.type] } @@ -122,9 +135,10 @@ class TensorClient(tensorType: String = "none") * is based on the TensorMaster interface, which means this is an input. This interface * is used on datapath only module such MatrixVectorCore or AluVector. */ -class TensorMasterData(tensorType: String = "none") - (implicit p: Parameters) extends TensorParams(tensorType) { - val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))) +class TensorMasterData(tensorType: String = "none")(implicit p: Parameters) + extends TensorParams(tensorType) { + val data = Flipped( + ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))) override def cloneType = new TensorMasterData(tensorType).asInstanceOf[this.type] } @@ -135,18 +149,22 @@ class TensorMasterData(tensorType: String = "none") * is based on the TensorClient interface, which means this is an output. This interface * is used on datapath only module such MatrixVectorCore or AluVector. */ -class TensorClientData(tensorType: String = "none") - (implicit p: Parameters) extends TensorParams(tensorType) { - val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))) +class TensorClientData(tensorType: String = "none")(implicit p: Parameters) + extends TensorParams(tensorType) { + val data = ValidIO( + Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))) override def cloneType = new TensorClientData(tensorType).asInstanceOf[this.type] } /** TensorPadCtrl. Zero-padding controller for TensorLoad. */ -class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Module { - val errorMsg = s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n" - require (padType == "YPad0" || padType == "YPad1" - || padType == "XPad0" || padType == "XPad1", errorMsg) +class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) + extends Module { + val errorMsg = + s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n" + require(padType == "YPad0" || padType == "YPad1" + || padType == "XPad0" || padType == "XPad1", + errorMsg) val io = IO(new Bundle { val start = Input(Bool()) @@ -180,33 +198,33 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Modul val sIdle :: sActive :: Nil = Enum(2) val state = RegInit(sIdle) - switch (state) { - is (sIdle) { - when (io.start) { + switch(state) { + is(sIdle) { + when(io.start) { state := sActive } } - is (sActive) { - when (ycnt === ymax && xcnt === xmax) { + is(sActive) { + when(ycnt === ymax && xcnt === xmax) { state := sIdle } } } - when (state === sIdle) { + when(state === sIdle) { xmax := xval ymax := yval } - when (state === sIdle || xcnt === xmax) { + when(state === sIdle || xcnt === xmax) { xcnt := 0.U - } .elsewhen (state === sActive) { + }.elsewhen(state === sActive) { xcnt := xcnt + 1.U } - when (state === sIdle || ymax === 0.U) { + when(state === sIdle || ymax === 0.U) { ycnt := 0.U - } .elsewhen (state === sActive && xcnt === xmax) { + }.elsewhen(state === sActive && xcnt === xmax) { ycnt := ycnt + 1.U } @@ -214,7 +232,10 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Modul } /** TensorDataCtrl. Data controller for TensorLoad. */ -class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module { +class TensorDataCtrl(tensorType: String = "none", + sizeFactor: Int = 1, + strideFactor: Int = 1)(implicit p: Parameters) + extends Module { val mp = p(ShellKey).memParams val io = IO(new Bundle { val start = Input(Bool()) @@ -238,7 +259,7 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac val len = Reg(UInt(mp.lenBits.W)) - val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U + val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U val xcnt = Reg(UInt(mp.lenBits.W)) val xrem = Reg(chiselTypeOf(dec.xsize)) val xsize = (dec.xsize << log2Ceil(sizeFactor)) - 1.U @@ -246,38 +267,38 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac val ycnt = Reg(chiselTypeOf(dec.ysize)) val stride = xcnt === len & - xrem === 0.U & - ycnt =/= dec.ysize - 1.U + xrem === 0.U & + ycnt =/= dec.ysize - 1.U val split = xcnt === len & xrem =/= 0.U - when (io.start || (io.xupdate && stride)) { - when (xsize < xmax) { + when(io.start || (io.xupdate && stride)) { + when(xsize < xmax) { len := xsize xrem := 0.U - } .otherwise { + }.otherwise { len := xmax - 1.U xrem := xsize - xmax } - } .elsewhen (io.xupdate && split) { - when (xrem < xmax) { + }.elsewhen(io.xupdate && split) { + when(xrem < xmax) { len := xrem xrem := 0.U - } .otherwise { + }.otherwise { len := xmax - 1.U xrem := xrem - xmax } } - when (io.xinit) { + when(io.xinit) { xcnt := 0.U - } .elsewhen (io.xupdate) { + }.elsewhen(io.xupdate) { xcnt := xcnt + 1.U } - when (io.start) { + when(io.start) { ycnt := 0.U - } .elsewhen (io.yupdate && stride) { + }.elsewhen(io.yupdate && stride) { ycnt := ycnt + 1.U } @@ -291,13 +312,13 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac (p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).accBits) / 8 } - when (io.start) { + when(io.start) { caddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes))) baddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes))) - } .elsewhen (io.yupdate) { - when (split) { + }.elsewhen(io.yupdate) { + when(split) { caddr := caddr + xmax_bytes - } .elsewhen (stride) { + }.elsewhen(stride) { caddr := baddr + (dec.xstride << log2Ceil(strideFactor)) baddr := baddr + (dec.xstride << log2Ceil(strideFactor)) } @@ -309,6 +330,6 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac io.addr := caddr io.len := len io.done := xcnt === len & - xrem === 0.U & - ycnt === dec.ysize - 1.U + xrem === 0.U & + ycnt === dec.ysize - 1.U } diff --git a/vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala b/vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala index 115bcbc..3318251 100644 --- a/vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala +++ b/vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala @@ -78,55 +78,56 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource { * * Convert Host DPI to AXI for VTAShell */ - -class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module { +class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) + extends Module { val io = IO(new Bundle { val dpi = new VTAHostDPIClient val axi = new AXILiteMaster(p(ShellKey).hostParams) }) val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr))) val data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.value))) - val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6) + val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = + Enum(6) val state = RegInit(sIdle) - switch (state) { - is (sIdle) { - when (io.dpi.req.valid) { - when (io.dpi.req.opcode) { + switch(state) { + is(sIdle) { + when(io.dpi.req.valid) { + when(io.dpi.req.opcode) { state := sWriteAddress - } .otherwise { + }.otherwise { state := sReadAddress } } } - is (sReadAddress) { - when (io.axi.ar.ready) { + is(sReadAddress) { + when(io.axi.ar.ready) { state := sReadData } } - is (sReadData) { - when (io.axi.r.valid) { + is(sReadData) { + when(io.axi.r.valid) { state := sIdle } } - is (sWriteAddress) { - when (io.axi.aw.ready) { + is(sWriteAddress) { + when(io.axi.aw.ready) { state := sWriteData } } - is (sWriteData) { - when (io.axi.w.ready) { + is(sWriteData) { + when(io.axi.w.ready) { state := sWriteResponse } } - is (sWriteResponse) { - when (io.axi.b.valid) { + is(sWriteResponse) { + when(io.axi.b.valid) { state := sIdle } } } - when (state === sIdle && io.dpi.req.valid) { + when(state === sIdle && io.dpi.req.valid) { addr := io.dpi.req.addr data := io.dpi.req.value } @@ -147,9 +148,17 @@ class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mo io.dpi.resp.bits := io.axi.r.bits.data if (debug) { - when (state === sWriteAddress && io.axi.aw.ready) { printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr) } - when (state === sReadAddress && io.axi.ar.ready) { printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr) } - when (io.axi.r.fire()) { printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data) } - when (io.axi.w.fire()) { printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data) } + when(state === sWriteAddress && io.axi.aw.ready) { + printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr) + } + when(state === sReadAddress && io.axi.ar.ready) { + printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr) + } + when(io.axi.r.fire()) { + printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data) + } + when(io.axi.w.fire()) { + printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data) + } } } diff --git a/vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala b/vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala index 5e2fa74..f46b778 100644 --- a/vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala +++ b/vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala @@ -75,7 +75,8 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource { setResource("/verilog/VTAMemDPI.v") } -class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module { +class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) + extends Module { val io = IO(new Bundle { val dpi = new VTAMemDPIMaster val axi = new AXIClient(p(ShellKey).memParams) @@ -83,56 +84,57 @@ class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mod val opcode = RegInit(false.B) val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len))) val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr))) - val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6) + val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = + Enum(6) val state = RegInit(sIdle) - switch (state) { - is (sIdle) { - when (io.axi.ar.valid) { + switch(state) { + is(sIdle) { + when(io.axi.ar.valid) { state := sReadAddress - } .elsewhen (io.axi.aw.valid) { + }.elsewhen(io.axi.aw.valid) { state := sWriteAddress } } - is (sReadAddress) { - when (io.axi.ar.valid) { + is(sReadAddress) { + when(io.axi.ar.valid) { state := sReadData } } - is (sReadData) { - when (io.axi.r.ready && io.dpi.rd.valid && len === 0.U) { + is(sReadData) { + when(io.axi.r.ready && io.dpi.rd.valid && len === 0.U) { state := sIdle } } - is (sWriteAddress) { - when (io.axi.aw.valid) { + is(sWriteAddress) { + when(io.axi.aw.valid) { state := sWriteData } } - is (sWriteData) { - when (io.axi.w.valid && io.axi.w.bits.last) { + is(sWriteData) { + when(io.axi.w.valid && io.axi.w.bits.last) { state := sWriteResponse } } - is (sWriteResponse) { - when (io.axi.b.ready) { + is(sWriteResponse) { + when(io.axi.b.ready) { state := sIdle } } } - when (state === sIdle) { - when (io.axi.ar.valid) { + when(state === sIdle) { + when(io.axi.ar.valid) { opcode := false.B len := io.axi.ar.bits.len addr := io.axi.ar.bits.addr - } .elsewhen (io.axi.aw.valid) { + }.elsewhen(io.axi.aw.valid) { opcode := true.B len := io.axi.aw.bits.len addr := io.axi.aw.bits.addr } - } .elsewhen (state === sReadData) { - when (io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) { + }.elsewhen(state === sReadData) { + when(io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) { len := len - 1.U } } @@ -163,9 +165,21 @@ class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mod io.axi.b.bits.id := 0.U if (debug) { - when (state === sReadAddress && io.axi.ar.valid) { printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len) } - when (state === sWriteAddress && io.axi.aw.valid) { printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len) } - when (io.axi.r.fire()) { printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", io.axi.r.bits.last, io.axi.r.bits.data) } - when (io.axi.w.fire()) { printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", io.axi.w.bits.last, io.axi.w.bits.data) } + when(state === sReadAddress && io.axi.ar.valid) { + printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len) + } + when(state === sWriteAddress && io.axi.aw.valid) { + printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len) + } + when(io.axi.r.fire()) { + printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", + io.axi.r.bits.last, + io.axi.r.bits.data) + } + when(io.axi.w.fire()) { + printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", + io.axi.w.bits.last, + io.axi.w.bits.data) + } } } diff --git a/vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala b/vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala index bf34b6a..8fd0fa8 100644 --- a/vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala +++ b/vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala @@ -24,18 +24,17 @@ import chisel3.util._ import vta.util.genericbundle._ case class AXIParams( - coherent: Boolean = false, - idBits: Int = 1, - addrBits: Int = 32, - dataBits: Int = 64, - lenBits: Int = 8, - userBits: Int = 1 -) -{ - require (addrBits > 0) - require (dataBits >= 8 && dataBits % 2 == 0) + coherent: Boolean = false, + idBits: Int = 1, + addrBits: Int = 32, + dataBits: Int = 64, + lenBits: Int = 8, + userBits: Int = 1 +) { + require(addrBits > 0) + require(dataBits >= 8 && dataBits % 2 == 0) - val strbBits = dataBits/8 + val strbBits = dataBits / 8 val sizeBits = 3 val burstBits = 2 val lockBits = 2 @@ -44,7 +43,7 @@ case class AXIParams( val qosBits = 4 val regionBits = 4 val respBits = 2 - val sizeConst = log2Ceil(dataBits/8) + val sizeConst = log2Ceil(dataBits / 8) val idConst = 0 val userConst = if (coherent) 1 else 0 val burstConst = 1 @@ -56,7 +55,7 @@ case class AXIParams( } abstract class AXIBase(params: AXIParams) - extends GenericParameterizedBundle(params) + extends GenericParameterizedBundle(params) // AXILite diff --git a/vta/hardware/chisel/src/main/scala/shell/Configs.scala b/vta/hardware/chisel/src/main/scala/shell/Configs.scala index fd9309e..3c271f5 100644 --- a/vta/hardware/chisel/src/main/scala/shell/Configs.scala +++ b/vta/hardware/chisel/src/main/scala/shell/Configs.scala @@ -25,52 +25,59 @@ import vta.util.config._ import vta.interface.axi._ /** PynqConfig. Shell configuration for Pynq */ -class PynqConfig extends Config((site, here, up) => { - case ShellKey => ShellParams( - hostParams = AXIParams( - coherent = false, - addrBits = 16, - dataBits = 32, - lenBits = 8, - userBits = 1), - memParams = AXIParams( - coherent = true, - addrBits = 32, - dataBits = 64, - lenBits = 8, - userBits = 1), - vcrParams = VCRParams(), - vmeParams = VMEParams()) -}) +class PynqConfig + extends Config((site, here, up) => { + case ShellKey => + ShellParams( + hostParams = AXIParams(coherent = false, + addrBits = 16, + dataBits = 32, + lenBits = 8, + userBits = 1), + memParams = AXIParams(coherent = true, + addrBits = 32, + dataBits = 64, + lenBits = 8, + userBits = 1), + vcrParams = VCRParams(), + vmeParams = VMEParams() + ) + }) /** F1Config. Shell configuration for F1 */ -class F1Config extends Config((site, here, up) => { - case ShellKey => ShellParams( - hostParams = AXIParams( - coherent = false, - addrBits = 16, - dataBits = 32, - lenBits = 8, - userBits = 1), - memParams = AXIParams( - coherent = false, - addrBits = 64, - dataBits = 64, - lenBits = 8, - userBits = 1), - vcrParams = VCRParams(), - vmeParams = VMEParams()) -}) +class F1Config + extends Config((site, here, up) => { + case ShellKey => + ShellParams( + hostParams = AXIParams(coherent = false, + addrBits = 16, + dataBits = 32, + lenBits = 8, + userBits = 1), + memParams = AXIParams(coherent = false, + addrBits = 64, + dataBits = 64, + lenBits = 8, + userBits = 1), + vcrParams = VCRParams(), + vmeParams = VMEParams() + ) + }) /** De10Config. Shell configuration for De10 */ -class De10Config extends Config((site, here, up) => { - case ShellKey => ShellParams( - hostParams = AXIParams( - addrBits = 16, dataBits = 32, idBits = 13, lenBits = 4), - memParams = AXIParams( - addrBits = 32, dataBits = 64, userBits = 5, - lenBits = 4, // limit to 16 beats, instead of 256 beats in AXI4 - coherent = true), - vcrParams = VCRParams(), - vmeParams = VMEParams()) -}) +class De10Config + extends Config((site, here, up) => { + case ShellKey => + ShellParams( + hostParams = + AXIParams(addrBits = 16, dataBits = 32, idBits = 13, lenBits = 4), + memParams = AXIParams( + addrBits = 32, + dataBits = 64, + userBits = 5, + lenBits = 4, // limit to 16 beats, instead of 256 beats in AXI4 + coherent = true), + vcrParams = VCRParams(), + vmeParams = VMEParams() + ) + }) diff --git a/vta/hardware/chisel/src/main/scala/shell/IntelShell.scala b/vta/hardware/chisel/src/main/scala/shell/IntelShell.scala index 817b786..6eb2222 100644 --- a/vta/hardware/chisel/src/main/scala/shell/IntelShell.scala +++ b/vta/hardware/chisel/src/main/scala/shell/IntelShell.scala @@ -30,7 +30,7 @@ import vta.core._ * system that can be used for simulation or real hardware. */ class IntelShell(implicit p: Parameters) extends Module { - val io = IO(new Bundle{ + val io = IO(new Bundle { val host = new AXIClient(p(ShellKey).hostParams) val mem = new AXIMaster(p(ShellKey).memParams) }) diff --git a/vta/hardware/chisel/src/main/scala/shell/SimShell.scala b/vta/hardware/chisel/src/main/scala/shell/SimShell.scala index f3d74ef..30b84d6 100644 --- a/vta/hardware/chisel/src/main/scala/shell/SimShell.scala +++ b/vta/hardware/chisel/src/main/scala/shell/SimShell.scala @@ -76,6 +76,7 @@ class VTASim(implicit p: Parameters) extends MultiIOModule { sim.io.clock := clock sim_wait := sim.io.dpi_wait } + /** SimShell. * * The simulation shell instantiate the sim, host and memory DPI modules that diff --git a/vta/hardware/chisel/src/main/scala/shell/VCR.scala b/vta/hardware/chisel/src/main/scala/shell/VCR.scala index efff6a4..517f581 100644 --- a/vta/hardware/chisel/src/main/scala/shell/VCR.scala +++ b/vta/hardware/chisel/src/main/scala/shell/VCR.scala @@ -29,8 +29,7 @@ import vta.interface.axi._ * * These parameters are used on VCR interfaces and modules. */ -case class VCRParams() -{ +case class VCRParams() { val nCtrl = 1 val nECnt = 1 val nVals = 1 @@ -40,7 +39,7 @@ case class VCRParams() /** VCRBase. Parametrize base class. */ abstract class VCRBase(implicit p: Parameters) - extends GenericParameterizedBundle(p) + extends GenericParameterizedBundle(p) /** VCRMaster. * @@ -80,7 +79,7 @@ class VCRClient(implicit p: Parameters) extends VCRBase { * registers that could be used as event counters by the Core unit. */ class VCR(implicit p: Parameters) extends Module { - val io = IO(new Bundle{ + val io = IO(new Bundle { val host = new AXILiteClient(p(ShellKey).hostParams) val vcr = new VCRMaster }) @@ -101,50 +100,49 @@ class VCR(implicit p: Parameters) extends Module { val rdata = RegInit(0.U(vp.regBits.W)) // registers - val nPtrs = if (mp.addrBits == 32) vp.nPtrs else 2*vp.nPtrs + val nPtrs = if (mp.addrBits == 32) vp.nPtrs else 2 * vp.nPtrs val nTotal = vp.nCtrl + vp.nECnt + vp.nVals + nPtrs val reg = Seq.fill(nTotal)(RegInit(0.U(vp.regBits.W))) val addr = Seq.tabulate(nTotal)(_ * 4) - val reg_map = (addr zip reg) map { case (a, r) => a.U -> r } + val reg_map = (addr zip reg) map { case (a, r) => a.U -> r } val eo = vp.nCtrl val vo = eo + vp.nECnt val po = vo + vp.nVals - switch (wstate) { - is (sWriteAddress) { - when (io.host.aw.valid) { + switch(wstate) { + is(sWriteAddress) { + when(io.host.aw.valid) { wstate := sWriteData } } - is (sWriteData) { - when (io.host.w.valid) { + is(sWriteData) { + when(io.host.w.valid) { wstate := sWriteResponse } } - is (sWriteResponse) { - when (io.host.b.ready) { + is(sWriteResponse) { + when(io.host.b.ready) { wstate := sWriteAddress } } } - when (io.host.aw.fire()) { waddr := io.host.aw.bits.addr } + when(io.host.aw.fire()) { waddr := io.host.aw.bits.addr } io.host.aw.ready := wstate === sWriteAddress io.host.w.ready := wstate === sWriteData io.host.b.valid := wstate === sWriteResponse io.host.b.bits.resp := 0.U - - switch (rstate) { - is (sReadAddress) { - when (io.host.ar.valid) { + switch(rstate) { + is(sReadAddress) { + when(io.host.ar.valid) { rstate := sReadData } } - is (sReadData) { - when (io.host.r.ready) { + is(sReadData) { + when(io.host.r.ready) { rstate := sReadAddress } } @@ -155,27 +153,27 @@ class VCR(implicit p: Parameters) extends Module { io.host.r.bits.data := rdata io.host.r.bits.resp := 0.U - when (io.vcr.finish) { + when(io.vcr.finish) { reg(0) := "b_10".U - } .elsewhen (io.host.w.fire() && addr(0).U === waddr) { + }.elsewhen(io.host.w.fire() && addr(0).U === waddr) { reg(0) := wdata } for (i <- 0 until vp.nECnt) { - when (io.vcr.ecnt(i).valid) { + when(io.vcr.ecnt(i).valid) { reg(eo + i) := io.vcr.ecnt(i).bits - } .elsewhen (io.host.w.fire() && addr(eo + i).U === waddr) { + }.elsewhen(io.host.w.fire() && addr(eo + i).U === waddr) { reg(eo + i) := wdata } } for (i <- 0 until (vp.nVals + nPtrs)) { - when (io.host.w.fire() && addr(vo + i).U === waddr) { + when(io.host.w.fire() && addr(vo + i).U === waddr) { reg(vo + i) := wdata } } - when (io.host.ar.fire()) { + when(io.host.ar.fire()) { rdata := MuxLookup(io.host.ar.bits.addr, 0.U, reg_map) } @@ -185,13 +183,13 @@ class VCR(implicit p: Parameters) extends Module { io.vcr.vals(i) := reg(vo + i) } - if (mp.addrBits == 32) { // 32-bit pointers + if (mp.addrBits == 32) { // 32-bit pointers for (i <- 0 until nPtrs) { io.vcr.ptrs(i) := reg(po + i) } - } else { // 64-bits pointers - for (i <- 0 until (nPtrs/2)) { - io.vcr.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i)) + } else { // 64-bits pointers + for (i <- 0 until (nPtrs / 2)) { + io.vcr.ptrs(i) := Cat(reg(po + 2 * i + 1), reg(po + 2 * i)) } } } diff --git a/vta/hardware/chisel/src/main/scala/shell/VME.scala b/vta/hardware/chisel/src/main/scala/shell/VME.scala index db46295..949929a 100644 --- a/vta/hardware/chisel/src/main/scala/shell/VME.scala +++ b/vta/hardware/chisel/src/main/scala/shell/VME.scala @@ -32,13 +32,16 @@ import vta.interface.axi._ case class VMEParams() { val nReadClients: Int = 5 val nWriteClients: Int = 1 - require (nReadClients > 0, s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n") - require (nWriteClients == 1, s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n") + require(nReadClients > 0, + s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n") + require( + nWriteClients == 1, + s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n") } /** VMEBase. Parametrize base class. */ abstract class VMEBase(implicit p: Parameters) - extends GenericParameterizedBundle(p) + extends GenericParameterizedBundle(p) /** VMECmd. * @@ -149,19 +152,19 @@ class VME(implicit p: Parameters) extends Module { val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3) val rstate = RegInit(sReadIdle) - switch (rstate) { - is (sReadIdle) { - when (rd_arb.io.out.valid) { + switch(rstate) { + is(sReadIdle) { + when(rd_arb.io.out.valid) { rstate := sReadAddr } } - is (sReadAddr) { - when (io.mem.ar.ready) { + is(sReadAddr) { + when(io.mem.ar.ready) { rstate := sReadData } } - is (sReadData) { - when (io.mem.r.fire() && io.mem.r.bits.last) { + is(sReadData) { + when(io.mem.r.fire() && io.mem.r.bits.last) { rstate := sReadIdle } } @@ -173,30 +176,34 @@ class VME(implicit p: Parameters) extends Module { val lenBits = p(ShellKey).memParams.lenBits val wr_cnt = RegInit(0.U(lenBits.W)) - when (wstate === sWriteIdle) { + when(wstate === sWriteIdle) { wr_cnt := 0.U - } .elsewhen (io.mem.w.fire()) { + }.elsewhen(io.mem.w.fire()) { wr_cnt := wr_cnt + 1.U } - switch (wstate) { - is (sWriteIdle) { - when (io.vme.wr(0).cmd.valid) { + switch(wstate) { + is(sWriteIdle) { + when(io.vme.wr(0).cmd.valid) { wstate := sWriteAddr } } - is (sWriteAddr) { - when (io.mem.aw.ready) { + is(sWriteAddr) { + when(io.mem.aw.ready) { wstate := sWriteData } } - is (sWriteData) { - when (io.vme.wr(0).data.valid && io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) { + is(sWriteData) { + when( + io.vme + .wr(0) + .data + .valid && io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) { wstate := sWriteResp } } - is (sWriteResp) { - when (io.mem.b.valid) { + is(sWriteResp) { + when(io.mem.b.valid) { wstate := sWriteIdle } } @@ -209,12 +216,12 @@ class VME(implicit p: Parameters) extends Module { val rd_addr = RegInit(0.U(addrBits.W)) val wr_addr = RegInit(0.U(addrBits.W)) - when (rd_arb.io.out.fire()) { + when(rd_arb.io.out.fire()) { rd_len := rd_arb.io.out.bits.len rd_addr := rd_arb.io.out.bits.addr } - when (io.vme.wr(0).cmd.fire()) { + when(io.vme.wr(0).cmd.fire()) { wr_len := io.vme.wr(0).cmd.bits.len wr_addr := io.vme.wr(0).cmd.bits.addr } @@ -230,7 +237,7 @@ class VME(implicit p: Parameters) extends Module { io.vme.wr(0).cmd.ready := wstate === sWriteIdle io.vme.wr(0).ack := io.mem.b.fire() - io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready + io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready // mem io.mem.aw.valid := wstate === sWriteAddr diff --git a/vta/hardware/chisel/src/main/scala/shell/VTAShell.scala b/vta/hardware/chisel/src/main/scala/shell/VTAShell.scala index c809311..782aeae 100644 --- a/vta/hardware/chisel/src/main/scala/shell/VTAShell.scala +++ b/vta/hardware/chisel/src/main/scala/shell/VTAShell.scala @@ -26,10 +26,10 @@ import vta.core._ /** Shell parameters. */ case class ShellParams( - hostParams: AXIParams, - memParams: AXIParams, - vcrParams: VCRParams, - vmeParams: VMEParams + hostParams: AXIParams, + memParams: AXIParams, + vcrParams: VCRParams, + vmeParams: VMEParams ) case object ShellKey extends Field[ShellParams] @@ -40,7 +40,7 @@ case object ShellKey extends Field[ShellParams] * system that can be used for simulation or real hardware. */ class VTAShell(implicit p: Parameters) extends Module { - val io = IO(new Bundle{ + val io = IO(new Bundle { val host = new AXILiteClient(p(ShellKey).hostParams) val mem = new AXIMaster(p(ShellKey).memParams) }) diff --git a/vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala b/vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala index db72137..ec7bffb 100644 --- a/vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala +++ b/vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala @@ -20,7 +20,7 @@ package vta.shell import chisel3._ -import chisel3.experimental.{RawModule, withClockAndReset} +import chisel3.experimental.{withClockAndReset, RawModule} import vta.util.config._ import vta.interface.axi._ @@ -39,7 +39,9 @@ class XilinxShell(implicit p: Parameters) extends RawModule { val m_axi_gmem = IO(new XilinxAXIMaster(mp)) val s_axi_control = IO(new XilinxAXILiteClient(hp)) - val shell = withClockAndReset (clock = ap_clk, reset = ~ap_rst_n) { Module(new VTAShell) } + val shell = withClockAndReset(clock = ap_clk, reset = ~ap_rst_n) { + Module(new VTAShell) + } // memory m_axi_gmem.AWVALID := shell.io.mem.aw.valid diff --git a/vta/hardware/chisel/src/main/scala/util/Config.scala b/vta/hardware/chisel/src/main/scala/util/Config.scala index 6699507..41104c4 100644 --- a/vta/hardware/chisel/src/main/scala/util/Config.scala +++ b/vta/hardware/chisel/src/main/scala/util/Config.scala @@ -21,8 +21,7 @@ package vta.util.config // taken from https://github.com/vta.roject/rocket-chip -abstract class Field[T] private (val default: Option[T]) -{ +abstract class Field[T] private (val default: Option[T]) { def this() = this(None) def this(default: T) = this(Some(default)) } @@ -31,42 +30,50 @@ abstract class View { final def apply[T](pname: Field[T]): T = apply(pname, this) final def apply[T](pname: Field[T], site: View): T = { val out = find(pname, site) - require (out.isDefined, s"Key ${pname} is not defined in Parameters") + require(out.isDefined, s"Key ${pname} is not defined in Parameters") out.get } final def lift[T](pname: Field[T]): Option[T] = lift(pname, this) - final def lift[T](pname: Field[T], site: View): Option[T] = find(pname, site).map(_.asInstanceOf[T]) + final def lift[T](pname: Field[T], site: View): Option[T] = + find(pname, site).map(_.asInstanceOf[T]) protected[config] def find[T](pname: Field[T], site: View): Option[T] } abstract class Parameters extends View { - final def ++ (x: Parameters): Parameters = + final def ++(x: Parameters): Parameters = new ChainParameters(this, x) - final def alter(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = + final def alter( + f: (View, View, View) => PartialFunction[Any, Any]): Parameters = Parameters(f) ++ this - final def alterPartial(f: PartialFunction[Any,Any]): Parameters = - Parameters((_,_,_) => f) ++ this + final def alterPartial(f: PartialFunction[Any, Any]): Parameters = + Parameters((_, _, _) => f) ++ this - final def alterMap(m: Map[Any,Any]): Parameters = + final def alterMap(m: Map[Any, Any]): Parameters = new MapParameters(m) ++ this - protected[config] def chain[T](site: View, tail: View, pname: Field[T]): Option[T] - protected[config] def find[T](pname: Field[T], site: View) = chain(site, new TerminalView, pname) + protected[config] def chain[T](site: View, + tail: View, + pname: Field[T]): Option[T] + protected[config] def find[T](pname: Field[T], site: View) = + chain(site, new TerminalView, pname) } object Parameters { def empty: Parameters = new EmptyParameters - def apply(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = new PartialParameters(f) + def apply(f: (View, View, View) => PartialFunction[Any, Any]): Parameters = + new PartialParameters(f) } class Config(p: Parameters) extends Parameters { - def this(f: (View, View, View) => PartialFunction[Any,Any]) = this(Parameters(f)) + def this(f: (View, View, View) => PartialFunction[Any, Any]) = + this(Parameters(f)) - protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = p.chain(site, tail, pname) + protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = + p.chain(site, tail, pname) override def toString = this.getClass.getSimpleName def toInstance = this } @@ -82,17 +89,21 @@ private class ChainView(head: Parameters, tail: View) extends View { } private class ChainParameters(x: Parameters, y: Parameters) extends Parameters { - def chain[T](site: View, tail: View, pname: Field[T]) = x.chain(site, new ChainView(y, tail), pname) + def chain[T](site: View, tail: View, pname: Field[T]) = + x.chain(site, new ChainView(y, tail), pname) } private class EmptyParameters extends Parameters { def chain[T](site: View, tail: View, pname: Field[T]) = tail.find(pname, site) } -private class PartialParameters(f: (View, View, View) => PartialFunction[Any,Any]) extends Parameters { +private class PartialParameters( + f: (View, View, View) => PartialFunction[Any, Any]) + extends Parameters { protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = { val g = f(site, this, tail) - if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T]) else tail.find(pname, site) + if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T]) + else tail.find(pname, site) } } diff --git a/vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala b/vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala index db19635..db8f5d2 100644 --- a/vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala +++ b/vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala @@ -23,18 +23,22 @@ package vta.util.genericbundle import chisel3._ -abstract class GenericParameterizedBundle[+T <: Object](val params: T) extends Bundle -{ +abstract class GenericParameterizedBundle[+T <: Object](val params: T) + extends Bundle { override def cloneType = { try { - this.getClass.getConstructors.head.newInstance(params).asInstanceOf[this.type] + this.getClass.getConstructors.head + .newInstance(params) + .asInstanceOf[this.type] } catch { case e: java.lang.IllegalArgumentException => - throw new Exception("Unable to use GenericParameterizedBundle.cloneType on " + - this.getClass + ", probably because " + this.getClass + - "() takes more than one argument. Consider overriding " + - "cloneType() on " + this.getClass, e) + throw new Exception( + "Unable to use GenericParameterizedBundle.cloneType on " + + this.getClass + ", probably because " + this.getClass + + "() takes more than one argument. Consider overriding " + + "cloneType() on " + this.getClass, + e + ) } } } - diff --git a/vta/hardware/chisel/src/main/scala/vta/Configs.scala b/vta/hardware/chisel/src/main/scala/vta/Configs.scala index 78c7316..f137ab6 100644 --- a/vta/hardware/chisel/src/main/scala/vta/Configs.scala +++ b/vta/hardware/chisel/src/main/scala/vta/Configs.scala @@ -31,7 +31,6 @@ import vta.test._ * These configurations are built in a mix/match form based on core * and shell configurations. */ - class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig) class DefaultF1Config extends Config(new CoreConfig ++ new F1Config) class DefaultDe10Config extends Config(new CoreConfig ++ new De10Config) -- 2.7.4