[VTA][Chisel] add scalafmt and format existing scala codebase (#3880)
authorLuis Vega <vegaluisjose@users.noreply.github.com>
Wed, 4 Sep 2019 05:19:01 +0000 (22:19 -0700)
committerThierry Moreau <moreau@uw.edu>
Wed, 4 Sep 2019 05:19:01 +0000 (22:19 -0700)
* [VTA][Chisel] add scalafmt and format existing scala codebase

* change column width to 100

* add scalafmt conf file as a valid file type

* add asf header to scalafmt conf file and rerun formatter

39 files changed:
tests/lint/check_file_type.py
vta/apps/tsim_example/hardware/chisel/.scalafmt.conf [new file with mode: 0644]
vta/apps/tsim_example/hardware/chisel/Makefile
vta/apps/tsim_example/hardware/chisel/project/plugins.sbt
vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Accel.scala
vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/Compute.scala
vta/apps/tsim_example/hardware/chisel/src/main/scala/accel/RegFile.scala
vta/hardware/chisel/.scalafmt.conf [new file with mode: 0644]
vta/hardware/chisel/Makefile
vta/hardware/chisel/project/plugins.sbt
vta/hardware/chisel/src/main/scala/core/Compute.scala
vta/hardware/chisel/src/main/scala/core/Configs.scala
vta/hardware/chisel/src/main/scala/core/Core.scala
vta/hardware/chisel/src/main/scala/core/Decode.scala
vta/hardware/chisel/src/main/scala/core/EventCounters.scala
vta/hardware/chisel/src/main/scala/core/Fetch.scala
vta/hardware/chisel/src/main/scala/core/ISA.scala
vta/hardware/chisel/src/main/scala/core/Load.scala
vta/hardware/chisel/src/main/scala/core/LoadUop.scala
vta/hardware/chisel/src/main/scala/core/Semaphore.scala
vta/hardware/chisel/src/main/scala/core/Store.scala
vta/hardware/chisel/src/main/scala/core/TensorAlu.scala
vta/hardware/chisel/src/main/scala/core/TensorGemm.scala
vta/hardware/chisel/src/main/scala/core/TensorLoad.scala
vta/hardware/chisel/src/main/scala/core/TensorStore.scala
vta/hardware/chisel/src/main/scala/core/TensorUtil.scala
vta/hardware/chisel/src/main/scala/dpi/VTAHostDPI.scala
vta/hardware/chisel/src/main/scala/dpi/VTAMemDPI.scala
vta/hardware/chisel/src/main/scala/interface/axi/AXI.scala
vta/hardware/chisel/src/main/scala/shell/Configs.scala
vta/hardware/chisel/src/main/scala/shell/IntelShell.scala
vta/hardware/chisel/src/main/scala/shell/SimShell.scala
vta/hardware/chisel/src/main/scala/shell/VCR.scala
vta/hardware/chisel/src/main/scala/shell/VME.scala
vta/hardware/chisel/src/main/scala/shell/VTAShell.scala
vta/hardware/chisel/src/main/scala/shell/XilinxShell.scala
vta/hardware/chisel/src/main/scala/util/Config.scala
vta/hardware/chisel/src/main/scala/util/GenericParameterizedBundle.scala
vta/hardware/chisel/src/main/scala/vta/Configs.scala

index c6691bb..e5f2dc7 100644 (file)
@@ -87,6 +87,7 @@ ALLOW_FILE_NAME = {
     ".clang-format",
     ".gitmodules",
     "CODEOWNERS",
+    ".scalafmt.conf",
    }
 
 # List of specific files allowed in relpath to <proj_root>
diff --git a/vta/apps/tsim_example/hardware/chisel/.scalafmt.conf b/vta/apps/tsim_example/hardware/chisel/.scalafmt.conf
new file mode 100644 (file)
index 0000000..9172d5e
--- /dev/null
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+maxColumn = 100
+rewrite.rules = [SortModifiers, SortImports]
index 4f555ba..0f97945 100644 (file)
@@ -91,7 +91,10 @@ else
   lib_path = $(build_dir)/$(LIBNAME).so
 endif
 
-default: lib
+default: lint lib
+
+lint:
+       sbt scalafmt
 
 lib: $(lib_path)
 $(lib_path): $(verilator_build_dir)/V$(TOP).cpp
index d654a7f..b90c729 100644 (file)
@@ -41,7 +41,7 @@ case class AccelConfig() {
   val nVals = 2
   val nPtrs = 2
   val regBits = 32
-  val ptrBits = 2*regBits
+  val ptrBits = 2 * regBits
 }
 
 class Accel extends Module {
index f24cbdd..7ad965c 100644 (file)
@@ -54,27 +54,27 @@ class Compute(implicit config: AccelConfig) extends Module {
   val raddr = Reg(UInt(config.ptrBits.W))
   val waddr = Reg(UInt(config.ptrBits.W))
 
-  switch (state) {
-    is (sIdle) {
-      when (io.launch) {
+  switch(state) {
+    is(sIdle) {
+      when(io.launch) {
         state := sReadReq
       }
     }
-    is (sReadReq) {
+    is(sReadReq) {
       state := sReadData
     }
-    is (sReadData) {
-      when (io.mem.rd.valid) {
+    is(sReadData) {
+      when(io.mem.rd.valid) {
         state := sWriteReq
       }
     }
-    is (sWriteReq) {
+    is(sWriteReq) {
       state := sWriteData
     }
-    is (sWriteData) {
-      when (cnt === (length - 1.U)) {
+    is(sWriteData) {
+      when(cnt === (length - 1.U)) {
         state := sIdle
-      } .otherwise {
+      }.otherwise {
         state := sReadReq
       }
     }
@@ -83,9 +83,9 @@ class Compute(implicit config: AccelConfig) extends Module {
   val last = state === sWriteData && cnt === (length - 1.U)
 
   // cycle counter
-  when (state === sIdle) {
+  when(state === sIdle) {
     cycles := 0.U
-  } .otherwise {
+  }.otherwise {
     cycles := cycles + 1.U
   }
 
@@ -93,10 +93,10 @@ class Compute(implicit config: AccelConfig) extends Module {
   io.ecnt(0).bits := cycles
 
   // calculate next address
-  when (state === sIdle) {
+  when(state === sIdle) {
     raddr := io.ptrs(0)
     waddr := io.ptrs(1)
-  } .elsewhen (state === sWriteData) { // increment by 8-bytes
+  }.elsewhen(state === sWriteData) { // increment by 8-bytes
     raddr := raddr + 8.U
     waddr := waddr + 8.U
   }
@@ -108,7 +108,7 @@ class Compute(implicit config: AccelConfig) extends Module {
   io.mem.req.addr := Mux(state === sReadReq, raddr, waddr)
 
   // read
-  when (state === sReadData && io.mem.rd.valid) {
+  when(state === sReadData && io.mem.rd.valid) {
     reg := io.mem.rd.bits + const
   }
   io.mem.rd.ready := state === sReadData
@@ -118,9 +118,9 @@ class Compute(implicit config: AccelConfig) extends Module {
   io.mem.wr.bits := reg
 
   // count read/write
-  when (state === sIdle) {
+  when(state === sIdle) {
     cnt := 0.U
-  } .elsewhen (state === sWriteData) {
+  }.elsewhen(state === sWriteData) {
     cnt := cnt + 1.U
   }
 
index 92a9833..1982f18 100644 (file)
@@ -59,52 +59,54 @@ class RegFile(implicit config: AccelConfig) extends Module {
   val sIdle :: sRead :: Nil = Enum(2)
   val state = RegInit(sIdle)
 
-  switch (state) {
-    is (sIdle) {
-      when (io.host.req.valid && !io.host.req.opcode) {
+  switch(state) {
+    is(sIdle) {
+      when(io.host.req.valid && !io.host.req.opcode) {
         state := sRead
       }
     }
-    is (sRead) {
+    is(sRead) {
       state := sIdle
     }
   }
 
   io.host.req.deq := state === sIdle & io.host.req.valid
 
-  val nTotal = config.nCtrl + config.nECnt + config.nVals + (2*config.nPtrs)
-  val reg = Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
+  val nTotal = config.nCtrl + config.nECnt + config.nVals + (2 * config.nPtrs)
+  val reg =
+    Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
   val addr = Seq.tabulate(nTotal)(_ * 4)
-  val reg_map = (addr zip reg)  map { case (a, r) => a.U -> r }
+  val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
   val eo = config.nCtrl
   val vo = eo + config.nECnt
   val po = vo + config.nVals
 
-  when (io.finish) {
+  when(io.finish) {
     reg(0) := "b_10".U
-  } .elsewhen (state === sIdle && io.host.req.valid &&
-        io.host.req.opcode && addr(0).U === io.host.req.addr) {
+  }.elsewhen(state === sIdle && io.host.req.valid &&
+    io.host.req.opcode && addr(0).U === io.host.req.addr) {
     reg(0) := io.host.req.value
   }
 
   for (i <- 0 until config.nECnt) {
-    when (io.ecnt(i).valid) {
+    when(io.ecnt(i).valid) {
       reg(eo + i) := io.ecnt(i).bits
-    } .elsewhen (state === sIdle && io.host.req.valid &&
-          io.host.req.opcode && addr(eo + i).U === io.host.req.addr) {
+    }.elsewhen(state === sIdle && io.host.req.valid &&
+      io.host.req.opcode && addr(eo + i).U === io.host.req.addr) {
       reg(eo + i) := io.host.req.value
     }
   }
 
-  for (i <- 0 until (config.nVals + (2*config.nPtrs))) {
-    when (state === sIdle && io.host.req.valid &&
-          io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
+  for (i <- 0 until (config.nVals + (2 * config.nPtrs))) {
+    when(
+      state === sIdle && io.host.req.valid &&
+        io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
       reg(vo + i) := io.host.req.value
     }
   }
 
   val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))
-  when (state === sIdle && io.host.req.valid && !io.host.req.opcode) {
+  when(state === sIdle && io.host.req.valid && !io.host.req.opcode) {
     rdata := MuxLookup(io.host.req.addr, 0.U, reg_map)
   }
 
@@ -118,6 +120,6 @@ class RegFile(implicit config: AccelConfig) extends Module {
   }
 
   for (i <- 0 until config.nPtrs) {
-    io.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i))
+    io.ptrs(i) := Cat(reg(po + (2 * i) + 1), reg(po + (2 * i)))
   }
 }
diff --git a/vta/hardware/chisel/.scalafmt.conf b/vta/hardware/chisel/.scalafmt.conf
new file mode 100644 (file)
index 0000000..9172d5e
--- /dev/null
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+maxColumn = 100
+rewrite.rules = [SortModifiers, SortImports]
index 6cd2802..7c88915 100644 (file)
@@ -102,7 +102,10 @@ else
   lib_path = $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so
 endif
 
-default: lib
+default: lint lib
+
+lint:
+       sbt scalafmt
 
 lib: $(lib_path)
 
index 79ffb22..e14e694 100644 (file)
@@ -18,3 +18,4 @@
  */
 
 logLevel := Level.Warn
+addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1")
index 01fa9d6..7751bf7 100644 (file)
@@ -49,7 +49,8 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
   val sIdle :: sSync :: sExe :: Nil = Enum(3)
   val state = RegInit(sIdle)
 
-  val s = Seq.tabulate(2)(_ => Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
+  val s = Seq.tabulate(2)(_ =>
+    Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
 
   val loadUop = Module(new LoadUop)
   val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
@@ -62,18 +63,20 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
   val dec = Module(new ComputeDecode)
   dec.io.inst := inst_q.io.deq.bits
 
-  val inst_type = Cat(dec.io.isFinish,
-                      dec.io.isAlu,
-                      dec.io.isGemm,
-                      dec.io.isLoadAcc,
-                      dec.io.isLoadUop).asUInt
+  val inst_type =
+    Cat(dec.io.isFinish,
+        dec.io.isAlu,
+        dec.io.isGemm,
+        dec.io.isLoadAcc,
+        dec.io.isLoadUop).asUInt
 
   val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B)
   val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
   val start = snext & sprev
   val done =
-    MuxLookup(inst_type,
-               false.B, // default
+    MuxLookup(
+      inst_type,
+      false.B, // default
       Array(
         "h_01".U -> loadUop.io.done,
         "h_02".U -> tensorAcc.io.done,
@@ -84,21 +87,21 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
     )
 
   // control
-  switch (state) {
-    is (sIdle) {
-      when (start) {
-        when (dec.io.isSync) {
+  switch(state) {
+    is(sIdle) {
+      when(start) {
+        when(dec.io.isSync) {
           state := sSync
-        } .elsewhen (inst_type.orR) {
+        }.elsewhen(inst_type.orR) {
           state := sExe
         }
       }
     }
-    is (sSync) {
+    is(sSync) {
       state := sIdle
     }
-    is (sExe) {
-      when (done) {
+    is(sExe) {
+      when(done) {
         state := sIdle
       }
     }
@@ -109,22 +112,28 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
   inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
 
   // uop
-  loadUop.io.start :=  state === sIdle & start & dec.io.isLoadUop
+  loadUop.io.start := state === sIdle & start & dec.io.isLoadUop
   loadUop.io.inst := inst_q.io.deq.bits
   loadUop.io.baddr := io.uop_baddr
   io.vme_rd(0) <> loadUop.io.vme_rd
-  loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)
+  loadUop.io.uop.idx <> Mux(dec.io.isGemm,
+                            tensorGemm.io.uop.idx,
+                            tensorAlu.io.uop.idx)
 
   // acc
   tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
   tensorAcc.io.inst := inst_q.io.deq.bits
   tensorAcc.io.baddr := io.acc_baddr
-  tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
-  tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
+  tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm,
+                                    tensorGemm.io.acc.rd.idx,
+                                    tensorAlu.io.acc.rd.idx)
+  tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm,
+                                tensorGemm.io.acc.wr,
+                                tensorAlu.io.acc.wr)
   io.vme_rd(1) <> tensorAcc.io.vme_rd
 
   // gemm
-  tensorGemm.io.start :=  state === sIdle & start & dec.io.isGemm
+  tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
   tensorGemm.io.inst := inst_q.io.deq.bits
   tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
   tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
@@ -136,7 +145,7 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
   tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits
 
   // alu
-  tensorAlu.io.start :=  state === sIdle & start & dec.io.isAlu
+  tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
   tensorAlu.io.inst := inst_q.io.deq.bits
   tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
   tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
@@ -146,7 +155,9 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
   tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits
 
   // out
-  io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx)
+  io.out.rd.idx <> Mux(dec.io.isGemm,
+                       tensorGemm.io.out.rd.idx,
+                       tensorAlu.io.out.rd.idx)
   io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
 
   // semaphore
@@ -163,38 +174,45 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
   // debug
   if (debug) {
     // start
-    when (state === sIdle && start) {
-      when (dec.io.isSync) {
+    when(state === sIdle && start) {
+      when(dec.io.isSync) {
         printf("[Compute] start sync\n")
-      } .elsewhen (dec.io.isLoadUop) {
-        printf("[Compute] start load uop\n")
-      } .elsewhen (dec.io.isLoadAcc) {
-        printf("[Compute] start load acc\n")
-      } .elsewhen (dec.io.isGemm) {
-        printf("[Compute] start gemm\n")
-      } .elsewhen (dec.io.isAlu) {
-        printf("[Compute] start alu\n")
-      } .elsewhen (dec.io.isFinish) {
-        printf("[Compute] start finish\n")
-      }
+      }.elsewhen(dec.io.isLoadUop) {
+          printf("[Compute] start load uop\n")
+        }
+        .elsewhen(dec.io.isLoadAcc) {
+          printf("[Compute] start load acc\n")
+        }
+        .elsewhen(dec.io.isGemm) {
+          printf("[Compute] start gemm\n")
+        }
+        .elsewhen(dec.io.isAlu) {
+          printf("[Compute] start alu\n")
+        }
+        .elsewhen(dec.io.isFinish) {
+          printf("[Compute] start finish\n")
+        }
     }
     // done
-    when (state === sSync) {
+    when(state === sSync) {
       printf("[Compute] done sync\n")
     }
-    when (state === sExe) {
-      when (done) {
-        when (dec.io.isLoadUop) {
+    when(state === sExe) {
+      when(done) {
+        when(dec.io.isLoadUop) {
           printf("[Compute] done load uop\n")
-        } .elsewhen (dec.io.isLoadAcc) {
-          printf("[Compute] done load acc\n")
-        } .elsewhen (dec.io.isGemm) {
-          printf("[Compute] done gemm\n")
-        } .elsewhen (dec.io.isAlu) {
-          printf("[Compute] done alu\n")
-        } .elsewhen (dec.io.isFinish) {
-          printf("[Compute] done finish\n")
-        }
+        }.elsewhen(dec.io.isLoadAcc) {
+            printf("[Compute] done load acc\n")
+          }
+          .elsewhen(dec.io.isGemm) {
+            printf("[Compute] done gemm\n")
+          }
+          .elsewhen(dec.io.isAlu) {
+            printf("[Compute] done alu\n")
+          }
+          .elsewhen(dec.io.isFinish) {
+            printf("[Compute] done finish\n")
+          }
       }
     }
   }
index b4e764b..de7012b 100644 (file)
@@ -27,20 +27,23 @@ import vta.util.config._
   * be eventually filled out with class configurations that can be
   * mixed/matched with Shell configurations for different backends.
   */
-class CoreConfig extends Config((site, here, up) => {
-  case CoreKey => CoreParams(
-    batch = 1,
-    blockOut = 16,
-    blockIn = 16,
-    inpBits = 8,
-    wgtBits = 8,
-    uopBits = 32,
-    accBits = 32,
-    outBits = 8,
-    uopMemDepth = 2048,
-    inpMemDepth = 2048,
-    wgtMemDepth = 1024,
-    accMemDepth = 2048,
-    outMemDepth = 2048,
-    instQueueEntries = 512)
-})
+class CoreConfig
+    extends Config((site, here, up) => {
+      case CoreKey =>
+        CoreParams(
+          batch = 1,
+          blockOut = 16,
+          blockIn = 16,
+          inpBits = 8,
+          wgtBits = 8,
+          uopBits = 32,
+          accBits = 32,
+          outBits = 8,
+          uopMemDepth = 2048,
+          inpMemDepth = 2048,
+          wgtMemDepth = 1024,
+          accMemDepth = 2048,
+          outMemDepth = 2048,
+          instQueueEntries = 512
+        )
+    })
index e63a112..a7228ee 100644 (file)
@@ -24,24 +24,24 @@ import vta.util.config._
 import vta.shell._
 
 /** Core parameters */
-case class CoreParams (
-  batch: Int = 1,
-  blockOut: Int = 16,
-  blockIn: Int = 16,
-  inpBits: Int = 8,
-  wgtBits: Int = 8,
-  uopBits: Int = 32,
-  accBits: Int = 32,
-  outBits: Int = 8,
-  uopMemDepth: Int = 512,
-  inpMemDepth: Int = 512,
-  wgtMemDepth: Int = 512,
-  accMemDepth: Int = 512,
-  outMemDepth: Int = 512,
-  instQueueEntries: Int = 32
-)
-{
-  require (uopBits % 8 == 0, s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n")
+case class CoreParams(
+    batch: Int = 1,
+    blockOut: Int = 16,
+    blockIn: Int = 16,
+    inpBits: Int = 8,
+    wgtBits: Int = 8,
+    uopBits: Int = 32,
+    accBits: Int = 32,
+    outBits: Int = 8,
+    uopMemDepth: Int = 512,
+    inpMemDepth: Int = 512,
+    wgtMemDepth: Int = 512,
+    accMemDepth: Int = 512,
+    outMemDepth: Int = 512,
+    instQueueEntries: Int = 32
+) {
+  require(uopBits % 8 == 0,
+          s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n")
 }
 
 case object CoreKey extends Field[CoreParams]
index f5bf340..a49ddce 100644 (file)
@@ -133,8 +133,9 @@ class FetchDecode extends Module {
     val isStore = Output(Bool())
   })
   val csignals =
-    ListLookup(io.inst,
-        List(N, OP_X),
+    ListLookup(
+      io.inst,
+      List(N, OP_X),
       Array(
         LUOP -> List(Y, OP_G),
         LWGT -> List(Y, OP_L),
index 5a5b095..8990aef 100644 (file)
@@ -38,17 +38,18 @@ import vta.shell._
   * If one would like to add an event counter, then the value of nECnt must be
   * changed in VCRParams together with the corresponding counting logic here.
   */
-class EventCounters(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class EventCounters(debug: Boolean = false)(implicit p: Parameters)
+    extends Module {
   val vp = p(ShellKey).vcrParams
-  val io = IO(new Bundle{
+  val io = IO(new Bundle {
     val launch = Input(Bool())
     val finish = Input(Bool())
     val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W)))
   })
   val cycle_cnt = RegInit(0.U(vp.regBits.W))
-  when (io.launch && !io.finish) {
+  when(io.launch && !io.finish) {
     cycle_cnt := cycle_cnt + 1.U
-  } .otherwise {
+  }.otherwise {
     cycle_cnt := 0.U
   }
   io.ecnt(0).valid := io.finish
index c7a6d50..9baf1cc 100644 (file)
@@ -67,69 +67,70 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
   val xrem = Reg(chiselTypeOf(io.ins_count))
   val xsize = (io.ins_count << 1.U) - 1.U
   val xmax = (1 << mp.lenBits).U
-  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+  val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
 
   val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5)
   val state = RegInit(sIdle)
 
   // control
-  switch (state) {
-    is (sIdle) {
-      when (pulse) {
+  switch(state) {
+    is(sIdle) {
+      when(pulse) {
         state := sReadCmd
-        when (xsize < xmax) {
+        when(xsize < xmax) {
           rlen := xsize
           ilen := xsize >> 1.U
           xrem := 0.U
-        } .otherwise {
+        }.otherwise {
           rlen := xmax - 1.U
           ilen := (xmax >> 1.U) - 1.U
           xrem := xsize - xmax
         }
       }
     }
-    is (sReadCmd) {
-      when (io.vme_rd.cmd.ready) {
+    is(sReadCmd) {
+      when(io.vme_rd.cmd.ready) {
         state := sReadLSB
       }
     }
-    is (sReadLSB) {
-      when (io.vme_rd.data.valid) {
+    is(sReadLSB) {
+      when(io.vme_rd.data.valid) {
         state := sReadMSB
       }
     }
-    is (sReadMSB) {
-      when (io.vme_rd.data.valid) {
-        when (inst_q.io.count === ilen) {
+    is(sReadMSB) {
+      when(io.vme_rd.data.valid) {
+        when(inst_q.io.count === ilen) {
           state := sDrain
-        } .otherwise {
+        }.otherwise {
           state := sReadLSB
         }
       }
     }
-    is (sDrain) {
-      when (inst_q.io.count === 0.U) {
-        when (xrem === 0.U) {
+    is(sDrain) {
+      when(inst_q.io.count === 0.U) {
+        when(xrem === 0.U) {
           state := sIdle
-        } .elsewhen (xrem < xmax) {
-          state := sReadCmd
-          rlen := xrem
-          ilen := xrem >> 1.U
-          xrem := 0.U
-        } .otherwise {
-          state := sReadCmd
-          rlen := xmax - 1.U
-          ilen := (xmax >> 1.U) - 1.U
-          xrem := xrem - xmax
-        }
+        }.elsewhen(xrem < xmax) {
+            state := sReadCmd
+            rlen := xrem
+            ilen := xrem >> 1.U
+            xrem := 0.U
+          }
+          .otherwise {
+            state := sReadCmd
+            rlen := xmax - 1.U
+            ilen := (xmax >> 1.U) - 1.U
+            xrem := xrem - xmax
+          }
       }
     }
   }
 
   // read instructions from dram
-  when (state === sIdle) {
+  when(state === sIdle) {
     raddr := io.ins_baddr
-  } .elsewhen (state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) {
+  }.elsewhen(state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) {
     raddr := raddr + xmax_bytes
   }
 
@@ -143,7 +144,7 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
   val msb = io.vme_rd.data.bits
   val inst = Cat(msb, lsb)
 
-  when (state === sReadLSB) { lsb := io.vme_rd.data.bits }
+  when(state === sReadLSB) { lsb := io.vme_rd.data.bits }
 
   inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB
   inst_q.io.enq.bits := inst
@@ -164,32 +165,30 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
   val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt
   val deq_ready =
     MuxLookup(deq_sel,
-               false.B, // default
-      Array(
-        "h_01".U -> io.inst.ld.ready,
-        "h_02".U -> io.inst.st.ready,
-        "h_04".U -> io.inst.co.ready
-      )
-    )
+              false.B, // default
+              Array(
+                "h_01".U -> io.inst.ld.ready,
+                "h_02".U -> io.inst.st.ready,
+                "h_04".U -> io.inst.co.ready
+              ))
 
   // dequeue instruction
   inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain
 
-
   // debug
   if (debug) {
-    when (state === sIdle && pulse) {
+    when(state === sIdle && pulse) {
       printf("[Fetch] Launch\n")
     }
     // instruction
-    when (inst_q.io.deq.fire()) {
-      when (dec.io.isLoad) {
+    when(inst_q.io.deq.fire()) {
+      when(dec.io.isLoad) {
         printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits)
       }
-      when (dec.io.isCompute) {
+      when(dec.io.isCompute) {
         printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits)
       }
-      when (dec.io.isStore) {
+      when(dec.io.isStore) {
         printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits)
       }
     }
index c3bf609..f08b23b 100644 (file)
@@ -26,47 +26,46 @@ import chisel3.util._
   *
   * These constants are used for decoding (parsing) fields on instructions.
   */
-trait ISAConstants
-{
-   val INST_BITS = 128
+trait ISAConstants {
+  val INST_BITS = 128
 
-   val OP_BITS = 3
+  val OP_BITS = 3
 
-   val M_DEP_BITS = 4
-   val M_ID_BITS = 2
-   val M_SRAM_OFFSET_BITS = 16
-   val M_DRAM_OFFSET_BITS = 32
-   val M_SIZE_BITS = 16
-   val M_STRIDE_BITS = 16
-   val M_PAD_BITS = 4
+  val M_DEP_BITS = 4
+  val M_ID_BITS = 2
+  val M_SRAM_OFFSET_BITS = 16
+  val M_DRAM_OFFSET_BITS = 32
+  val M_SIZE_BITS = 16
+  val M_STRIDE_BITS = 16
+  val M_PAD_BITS = 4
 
-   val C_UOP_BGN_BITS = 13
-   val C_UOP_END_BITS = 14
-   val C_ITER_BITS = 14
-   val C_AIDX_BITS = 11
-   val C_IIDX_BITS = 11
-   val C_WIDX_BITS = 10
-   val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction
-   val C_ALU_OP_BITS = 3
-   val C_ALU_IMM_BITS = 16
+  val C_UOP_BGN_BITS = 13
+  val C_UOP_END_BITS = 14
+  val C_ITER_BITS = 14
+  val C_AIDX_BITS = 11
+  val C_IIDX_BITS = 11
+  val C_WIDX_BITS = 10
+  val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction
+  val C_ALU_OP_BITS = 3
+  val C_ALU_IMM_BITS = 16
 
-   val Y = true.B
-   val N = false.B
+  val Y = true.B
+  val N = false.B
 
-   val OP_L = 0.asUInt(OP_BITS.W)
-   val OP_S = 1.asUInt(OP_BITS.W)
-   val OP_G = 2.asUInt(OP_BITS.W)
-   val OP_F = 3.asUInt(OP_BITS.W)
-   val OP_A = 4.asUInt(OP_BITS.W)
-   val OP_X = 5.asUInt(OP_BITS.W)
+  val OP_L = 0.asUInt(OP_BITS.W)
+  val OP_S = 1.asUInt(OP_BITS.W)
+  val OP_G = 2.asUInt(OP_BITS.W)
+  val OP_F = 3.asUInt(OP_BITS.W)
+  val OP_A = 4.asUInt(OP_BITS.W)
+  val OP_X = 5.asUInt(OP_BITS.W)
 
-   val ALU_OP_NUM = 5
-   val ALU_OP = Enum(ALU_OP_NUM)
+  val ALU_OP_NUM = 5
+  val ALU_OP = Enum(ALU_OP_NUM)
 
-   val M_ID_U = 0.asUInt(M_ID_BITS.W)
-   val M_ID_W = 1.asUInt(M_ID_BITS.W)
-   val M_ID_I = 2.asUInt(M_ID_BITS.W)
-   val M_ID_A = 3.asUInt(M_ID_BITS.W)
+  val M_ID_U = 0.asUInt(M_ID_BITS.W)
+  val M_ID_W = 1.asUInt(M_ID_BITS.W)
+  val M_ID_I = 2.asUInt(M_ID_BITS.W)
+  val M_ID_A = 3.asUInt(M_ID_BITS.W)
 }
 
 /** ISA.
@@ -79,15 +78,37 @@ trait ISAConstants
   * TODO: Add VXOR to clear accumulator
   */
 object ISA {
-  def LUOP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000")
-  def LWGT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000")
-  def LINP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000")
-  def LACC = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000")
-  def SOUT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001")
-  def GEMM = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010")
-  def VMIN = BitPat("b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
-  def VMAX = BitPat("b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
-  def VADD = BitPat("b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
-  def VSHX = BitPat("b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
-  def FNSH = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011")
+  def LUOP =
+    BitPat(
+      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000")
+  def LWGT =
+    BitPat(
+      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000")
+  def LINP =
+    BitPat(
+      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000")
+  def LACC =
+    BitPat(
+      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000")
+  def SOUT =
+    BitPat(
+      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001")
+  def GEMM =
+    BitPat(
+      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010")
+  def VMIN =
+    BitPat(
+      "b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+  def VMAX =
+    BitPat(
+      "b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+  def VADD =
+    BitPat(
+      "b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+  def VSHX =
+    BitPat(
+      "b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
+  def FNSH =
+    BitPat(
+      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011")
 }
index bbc6600..7c79498 100644 (file)
@@ -54,27 +54,28 @@ class Load(debug: Boolean = false)(implicit p: Parameters) extends Module {
 
   val tensorType = Seq("inp", "wgt")
   val tensorDec = Seq(dec.io.isInput, dec.io.isWeight)
-  val tensorLoad = Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i))))
+  val tensorLoad =
+    Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i))))
 
   val start = inst_q.io.deq.valid & Mux(dec.io.pop_next, s.io.sready, true.B)
   val done = Mux(dec.io.isInput, tensorLoad(0).io.done, tensorLoad(1).io.done)
 
   // control
-  switch (state) {
-    is (sIdle) {
-      when (start) {
-        when (dec.io.isSync) {
+  switch(state) {
+    is(sIdle) {
+      when(start) {
+        when(dec.io.isSync) {
           state := sSync
-        } .elsewhen (dec.io.isInput || dec.io.isWeight) {
+        }.elsewhen(dec.io.isInput || dec.io.isWeight) {
           state := sExe
         }
       }
     }
-    is (sSync) {
+    is(sSync) {
       state := sIdle
     }
-    is (sExe) {
-      when (done) {
+    is(sExe) {
+      when(done) {
         state := sIdle
       }
     }
@@ -105,24 +106,25 @@ class Load(debug: Boolean = false)(implicit p: Parameters) extends Module {
   // debug
   if (debug) {
     // start
-    when (state === sIdle && start) {
-      when (dec.io.isSync) {
+    when(state === sIdle && start) {
+      when(dec.io.isSync) {
         printf("[Load] start sync\n")
-      } .elsewhen (dec.io.isInput) {
-        printf("[Load] start input\n")
-      } .elsewhen (dec.io.isWeight) {
-        printf("[Load] start weight\n")
-      }
+      }.elsewhen(dec.io.isInput) {
+          printf("[Load] start input\n")
+        }
+        .elsewhen(dec.io.isWeight) {
+          printf("[Load] start weight\n")
+        }
     }
     // done
-    when (state === sSync) {
+    when(state === sSync) {
       printf("[Load] done sync\n")
     }
-    when (state === sExe) {
-      when (done) {
-        when (dec.io.isInput) {
+    when(state === sExe) {
+      when(done) {
+        when(dec.io.isInput) {
           printf("[Load] done input\n")
-        } .elsewhen (dec.io.isWeight) {
+        }.elsewhen(dec.io.isWeight) {
           printf("[Load] done weight\n")
         }
       }
index bbf8cf1..fcde836 100644 (file)
@@ -77,9 +77,9 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
   val xcnt = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
   val xlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
   val xrem = Reg(chiselTypeOf(dec.xsize))
-  val xsize =  (dec.xsize >> log2Ceil(numUop)) + dec.xsize(0) + (dec.sram_offset % 2.U) - 1.U
+  val xsize = (dec.xsize >> log2Ceil(numUop)) + dec.xsize(0) + (dec.sram_offset % 2.U) - 1.U
   val xmax = (1 << mp.lenBits).U
-  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+  val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
 
   val offsetIsEven = (dec.sram_offset % 2.U) === 0.U
   val sizeIsEven = (dec.xsize % 2.U) === 0.U
@@ -88,38 +88,39 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
   val state = RegInit(sIdle)
 
   // control
-  switch (state) {
-    is (sIdle) {
-      when (io.start) {
+  switch(state) {
+    is(sIdle) {
+      when(io.start) {
         state := sReadCmd
-        when (xsize < xmax) {
+        when(xsize < xmax) {
           xlen := xsize
           xrem := 0.U
-        } .otherwise {
+        }.otherwise {
           xlen := xmax - 1.U
           xrem := xsize - xmax
         }
       }
     }
-    is (sReadCmd) {
-      when (io.vme_rd.cmd.ready) {
+    is(sReadCmd) {
+      when(io.vme_rd.cmd.ready) {
         state := sReadData
       }
     }
-    is (sReadData) {
-      when (io.vme_rd.data.valid) {
+    is(sReadData) {
+      when(io.vme_rd.data.valid) {
         when(xcnt === xlen) {
-          when (xrem === 0.U) {
+          when(xrem === 0.U) {
             state := sIdle
-          } .elsewhen (xrem < xmax) {
-            state := sReadCmd
-            xlen := xrem
-            xrem := 0.U
-          } .otherwise {
-            state := sReadCmd
-            xlen := xmax - 1.U
-            xrem := xrem - xmax
-          }
+          }.elsewhen(xrem < xmax) {
+              state := sReadCmd
+              xlen := xrem
+              xrem := 0.U
+            }
+            .otherwise {
+              state := sReadCmd
+              xlen := xmax - 1.U
+              xrem := xrem - xmax
+            }
         }
       }
     }
@@ -127,13 +128,14 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
 
   // read-from-dram
   val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
-  when (state === sIdle) {
-    when (offsetIsEven) {
+  when(state === sIdle) {
+    when(offsetIsEven) {
       raddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))
-    } .otherwise {
-      raddr := (io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))) - uopBytes.U
+    }.otherwise {
+      raddr := (io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
+        uopBytes)))) - uopBytes.U
     }
-  } .elsewhen (state === sReadData && xcnt === xlen && xrem =/= 0.U) {
+  }.elsewhen(state === sReadData && xcnt === xlen && xrem =/= 0.U) {
     raddr := raddr + xmax_bytes
   }
 
@@ -143,16 +145,16 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
 
   io.vme_rd.data.ready := state === sReadData
 
-  when (state =/= sReadData) {
+  when(state =/= sReadData) {
     xcnt := 0.U
-  } .elsewhen (io.vme_rd.data.fire()) {
+  }.elsewhen(io.vme_rd.data.fire()) {
     xcnt := xcnt + 1.U
   }
 
   val waddr = Reg(UInt(log2Ceil(uopDepth).W))
-  when (state === sIdle) {
+  when(state === sIdle) {
     waddr := dec.sram_offset >> log2Ceil(numUop)
-  } .elsewhen (io.vme_rd.data.fire()) {
+  }.elsewhen(io.vme_rd.data.fire()) {
     waddr := waddr + 1.U
   }
 
@@ -160,36 +162,37 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
   val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata))
   val wmask = Reg(Vec(numUop, Bool()))
 
-  when (offsetIsEven) {
-    when (sizeIsEven) {
+  when(offsetIsEven) {
+    when(sizeIsEven) {
       wmask := "b_11".U.asTypeOf(wmask)
-    } .elsewhen (io.vme_rd.cmd.fire()) {
-      when (dec.xsize === 1.U) {
-        wmask := "b_01".U.asTypeOf(wmask)
-      } .otherwise {
-        wmask := "b_11".U.asTypeOf(wmask)
+    }.elsewhen(io.vme_rd.cmd.fire()) {
+        when(dec.xsize === 1.U) {
+          wmask := "b_01".U.asTypeOf(wmask)
+        }.otherwise {
+          wmask := "b_11".U.asTypeOf(wmask)
+        }
       }
-    } .elsewhen (io.vme_rd.data.fire()) {
-      when (xcnt === xlen - 1.U) {
-        wmask := "b_01".U.asTypeOf(wmask)
-      } .otherwise {
-        wmask := "b_11".U.asTypeOf(wmask)
+      .elsewhen(io.vme_rd.data.fire()) {
+        when(xcnt === xlen - 1.U) {
+          wmask := "b_01".U.asTypeOf(wmask)
+        }.otherwise {
+          wmask := "b_11".U.asTypeOf(wmask)
+        }
       }
-    }
-  } .otherwise {
-    when (io.vme_rd.cmd.fire()) {
+  }.otherwise {
+    when(io.vme_rd.cmd.fire()) {
       wmask := "b_10".U.asTypeOf(wmask)
-    } .elsewhen (io.vme_rd.data.fire()) {
-      when (sizeIsEven && xcnt === xlen - 1.U) {
+    }.elsewhen(io.vme_rd.data.fire()) {
+      when(sizeIsEven && xcnt === xlen - 1.U) {
         wmask := "b_01".U.asTypeOf(wmask)
-      } .otherwise {
+      }.otherwise {
         wmask := "b_11".U.asTypeOf(wmask)
       }
     }
   }
 
   wdata := io.vme_rd.data.bits.asTypeOf(wdata)
-  when (io.vme_rd.data.fire()) {
+  when(io.vme_rd.data.fire()) {
     mem.write(waddr, wdata, wmask)
   }
 
@@ -209,7 +212,7 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
 
   // debug
   if (debug) {
-    when (io.vme_rd.cmd.fire()) {
+    when(io.vme_rd.cmd.fire()) {
       printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem)
     }
   }
index 06df51e..f268e79 100644 (file)
@@ -29,14 +29,17 @@ import chisel3.util._
   * depending on the push and pop fields on instructions to prevent RAW and WAR
   * hazards.
   */
-class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) extends Module {
+class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1)
+    extends Module {
   val io = IO(new Bundle {
     val spost = Input(Bool())
     val swait = Input(Bool())
     val sready = Output(Bool())
   })
   val cnt = RegInit(counterInitValue.U(counterBits.W))
-  when (io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) { cnt := cnt + 1.U }
-  when (!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U }
+  when(io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) {
+    cnt := cnt + 1.U
+  }
+  when(!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U }
   io.sready := cnt =/= 0.U
 }
index 71d9208..04bc7f5 100644 (file)
@@ -55,21 +55,21 @@ class Store(debug: Boolean = false)(implicit p: Parameters) extends Module {
   val done = tensorStore.io.done
 
   // control
-  switch (state) {
-    is (sIdle) {
-      when (start) {
-        when (dec.io.isSync) {
+  switch(state) {
+    is(sIdle) {
+      when(start) {
+        when(dec.io.isSync) {
           state := sSync
-        } .elsewhen (dec.io.isStore) {
+        }.elsewhen(dec.io.isStore) {
           state := sExe
         }
       }
     }
-    is (sSync) {
+    is(sSync) {
       state := sIdle
     }
-    is (sExe) {
-      when (done) {
+    is(sExe) {
+      when(done) {
         state := sIdle
       }
     }
@@ -94,19 +94,19 @@ class Store(debug: Boolean = false)(implicit p: Parameters) extends Module {
   // debug
   if (debug) {
     // start
-    when (state === sIdle && start) {
-      when (dec.io.isSync) {
+    when(state === sIdle && start) {
+      when(dec.io.isSync) {
         printf("[Store] start sync\n")
-      } .elsewhen (dec.io.isStore) {
+      }.elsewhen(dec.io.isStore) {
         printf("[Store] start\n")
       }
     }
     // done
-    when (state === sSync) {
+    when(state === sSync) {
       printf("[Store] done sync\n")
     }
-    when (state === sExe) {
-      when (done) {
+    when(state === sExe) {
+      when(done) {
         printf("[Store] done\n")
       }
     }
index fbb0578..b438641 100644 (file)
@@ -116,7 +116,8 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
     val acc = new TensorMaster(tensorType = "acc")
     val out = new TensorMaster(tensorType = "out")
   })
-  val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil = Enum(6)
+  val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil =
+    Enum(6)
   val state = RegInit(sIdle)
   val alu = Module(new AluVector)
   val dec = io.inst.asTypeOf(new AluDecode)
@@ -132,81 +133,86 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
   val src_i = Reg(chiselTypeOf(dec.uop_end))
   val done =
     state === sExe &
-    alu.io.out.data.valid &
-    (cnt_o === dec.lp_0 - 1.U) &
-    (cnt_i === dec.lp_1 - 1.U) &
-    (uop_idx === uop_end - 1.U)
-
-  switch (state) {
-    is (sIdle) {
-      when (io.start) {
+      alu.io.out.data.valid &
+      (cnt_o === dec.lp_0 - 1.U) &
+      (cnt_i === dec.lp_1 - 1.U) &
+      (uop_idx === uop_end - 1.U)
+
+  switch(state) {
+    is(sIdle) {
+      when(io.start) {
         state := sReadUop
       }
     }
-    is (sReadUop) {
+    is(sReadUop) {
       state := sComputeIdx
     }
-    is (sComputeIdx) {
+    is(sComputeIdx) {
       state := sReadTensorA
     }
-    is (sReadTensorA) {
+    is(sReadTensorA) {
       state := sReadTensorB
     }
-    is (sReadTensorB) {
+    is(sReadTensorB) {
       state := sExe
     }
-    is (sExe) {
-      when (alu.io.out.data.valid) {
-        when ((cnt_o === dec.lp_0 - 1.U) &&
-              (cnt_i === dec.lp_1 - 1.U) &&
-              (uop_idx === uop_end - 1.U)) {
+    is(sExe) {
+      when(alu.io.out.data.valid) {
+        when(
+          (cnt_o === dec.lp_0 - 1.U) &&
+            (cnt_i === dec.lp_1 - 1.U) &&
+            (uop_idx === uop_end - 1.U)) {
           state := sIdle
-        } .otherwise {
+        }.otherwise {
           state := sReadUop
         }
       }
     }
   }
 
-  when (state === sIdle ||
-         (state === sExe &&
-          alu.io.out.data.valid &&
-          uop_idx === uop_end - 1.U)) {
+  when(
+    state === sIdle ||
+      (state === sExe &&
+        alu.io.out.data.valid &&
+        uop_idx === uop_end - 1.U)) {
     uop_idx := dec.uop_begin
-  } .elsewhen (state === sExe && alu.io.out.data.valid) {
+  }.elsewhen(state === sExe && alu.io.out.data.valid) {
     uop_idx := uop_idx + 1.U
   }
 
-  when (state === sIdle) {
+  when(state === sIdle) {
     cnt_o := 0.U
     dst_o := 0.U
     src_o := 0.U
-  } .elsewhen (state === sExe &&
-               alu.io.out.data.valid &&
-               uop_idx === uop_end - 1.U &&
-               cnt_i === dec.lp_1 - 1.U) {
+  }.elsewhen(
+    state === sExe &&
+      alu.io.out.data.valid &&
+      uop_idx === uop_end - 1.U &&
+      cnt_i === dec.lp_1 - 1.U) {
     cnt_o := cnt_o + 1.U
     dst_o := dst_o + dec.dst_0
     src_o := src_o + dec.src_0
   }
 
-  when (state === sIdle) {
+  when(state === sIdle) {
     cnt_i := 0.U
     dst_i := 0.U
     src_i := 0.U
-  } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
-    cnt_i := 0.U
-    dst_i := dst_o
-    src_i := src_o
-  } .elsewhen (state === sExe &&
-               alu.io.out.data.valid &&
-               uop_idx === uop_end - 1.U) {
-    cnt_i := cnt_i + 1.U
-    dst_i := dst_i + dec.dst_1
-    src_i := src_i + dec.src_1
-  }
+  }.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
+      cnt_i := 0.U
+      dst_i := dst_o
+      src_i := src_o
+    }
+    .elsewhen(
+      state === sExe &&
+        alu.io.out.data.valid &&
+        uop_idx === uop_end - 1.U) {
+      cnt_i := cnt_i + 1.U
+      dst_i := dst_i + dec.dst_1
+      src_i := src_i + dec.src_1
+    }
 
-  when (state === sComputeIdx && io.uop.data.valid) {
+  when(state === sComputeIdx && io.uop.data.valid) {
     uop_dst := io.uop.data.bits.u0 + dst_i
     uop_src := io.uop.data.bits.u1 + src_i
   }
@@ -222,17 +228,25 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
   // imm
   val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
   tensorImm.data.valid := state === sReadTensorB
-  tensorImm.data.bits.foreach { b => b.foreach { c => c := dec.alu_imm } }
+  tensorImm.data.bits.foreach { b =>
+    b.foreach { c =>
+      c := dec.alu_imm
+    }
+  }
 
   // alu
   val isSHR = dec.alu_op === ALU_OP(3)
-  val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS-1)
+  val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
   val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op))
   alu.io.opcode := fixme_alu_op
   alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
   alu.io.acc_a.data.bits <> io.acc.rd.data.bits
-  alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, tensorImm.data.valid, io.acc.rd.data.valid & state === sExe)
-  alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, tensorImm.data.bits, io.acc.rd.data.bits)
+  alu.io.acc_b.data.valid := Mux(dec.alu_use_imm,
+                                 tensorImm.data.valid,
+                                 io.acc.rd.data.valid & state === sExe)
+  alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm,
+                                tensorImm.data.bits,
+                                io.acc.rd.data.bits)
 
   // acc_o
   io.acc.wr.valid := alu.io.acc_y.data.valid
@@ -249,47 +263,51 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
 
   if (debug) {
 
-    when (state === sReadUop) {
+    when(state === sReadUop) {
       printf("[TensorAlu] [uop] idx:%x\n", uop_idx)
     }
 
-    when (state === sReadTensorA) {
+    when(state === sReadTensorA) {
       printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src)
     }
 
-    when (state === sIdle && io.start) {
+    when(state === sIdle && io.start) {
       printf(p"[TensorAlu] decode:$dec\n")
     }
 
     alu.io.acc_a.data.bits.foreach { tensor =>
-      tensor.zipWithIndex.foreach { case(elem, i) =>
-        when (alu.io.acc_a.data.valid) {
-          printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem)
-        }
+      tensor.zipWithIndex.foreach {
+        case (elem, i) =>
+          when(alu.io.acc_a.data.valid) {
+            printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem)
+          }
       }
     }
 
     alu.io.acc_b.data.bits.foreach { tensor =>
-      tensor.zipWithIndex.foreach { case(elem, i) =>
-        when (alu.io.acc_b.data.valid) {
-          printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem)
-        }
+      tensor.zipWithIndex.foreach {
+        case (elem, i) =>
+          when(alu.io.acc_b.data.valid) {
+            printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem)
+          }
       }
     }
 
     alu.io.acc_y.data.bits.foreach { tensor =>
-      tensor.zipWithIndex.foreach { case(elem, i) =>
-        when (alu.io.acc_y.data.valid) {
-          printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem)
-        }
+      tensor.zipWithIndex.foreach {
+        case (elem, i) =>
+          when(alu.io.acc_y.data.valid) {
+            printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem)
+          }
       }
     }
 
     alu.io.out.data.bits.foreach { tensor =>
-      tensor.zipWithIndex.foreach { case(elem, i) =>
-        when (alu.io.out.data.valid) {
-          printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem)
-        }
+      tensor.zipWithIndex.foreach {
+        case (elem, i) =>
+          when(alu.io.out.data.valid) {
+            printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem)
+          }
       }
     }
   }
index 051e011..3f5f387 100644 (file)
@@ -62,8 +62,10 @@ class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
 }
 
 /** Pipelined DotProduct based on MAC and PipeAdder */
-class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module {
-  val errorMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
+class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16)
+    extends Module {
+  val errorMsg =
+    s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
   require(size >= 2 && isPow2(size), errorMsg)
   val b = aBits + bBits
   val outBits = b + log2Ceil(size) + 1
@@ -72,12 +74,15 @@ class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module
     val b = Input(Vec(size, SInt(bBits.W)))
     val y = Output(SInt(outBits.W))
   })
-  val s = Seq.tabulate(log2Ceil(size + 1))(i => pow(2, log2Ceil(size) - i).toInt) // # of total layers
+  val s = Seq.tabulate(log2Ceil(size + 1))(i =>
+    pow(2, log2Ceil(size) - i).toInt) // # of total layers
   val p = log2Ceil(size / 2) + 1 // # of adder layers
   val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs
-  val a = Seq.tabulate(p)(i =>
-    Seq.fill(s(i + 1))(Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1))))
-  ) // # adders within each layer
+  val a = Seq.tabulate(p)(
+    i =>
+      Seq.fill(s(i + 1))(Module(new PipeAdder(
+        aBits = (b + i + 1),
+        bBits = (b + i + 1))))) // # adders within each layer
 
   // Vector MACs
   for (i <- 0 until s(0)) {
@@ -111,7 +116,7 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
   val inpBits = p(CoreKey).inpBits
   val wgtBits = p(CoreKey).wgtBits
   val outBits = p(CoreKey).outBits
-  val io = IO(new Bundle{
+  val io = IO(new Bundle {
     val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
     val inp = new TensorMasterData(tensorType = "inp")
     val wgt = new TensorMasterData(tensorType = "wgt")
@@ -119,8 +124,10 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
     val acc_o = new TensorClientData(tensorType = "acc")
     val out = new TensorClientData(tensorType = "out")
   })
-  val dot = Seq.fill(size)(Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
-  val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
+  val dot = Seq.fill(size)(
+    Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
+  val acc = Seq.fill(size)(
+    Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
   val add = Seq.fill(size)(Wire(SInt(accBits.W)))
   val vld = Wire(Vec(size, Bool()))
 
@@ -149,7 +156,8 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
   * Also, the TensorGemm uses the reset field in the Gemm instruction to
   * clear or zero-out the acc-scratchpad locations based on the micro-ops.
   */
-class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
+    extends Module {
   val io = IO(new Bundle {
     val start = Input(Bool())
     val done = Output(Bool())
@@ -160,7 +168,8 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
     val acc = new TensorMaster(tensorType = "acc")
     val out = new TensorMaster(tensorType = "out")
   })
-  val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6)
+  val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil =
+    Enum(6)
   val state = RegInit(sIdle)
   val mvc = Module(new MatrixVectorMultiplication)
   val dec = io.inst.asTypeOf(new GemmDecode)
@@ -181,99 +190,103 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
   val inflight = Reg(UInt(pBits.W))
   val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits))
   val done = inflight === 0.U &
-             ((state === sExe &
-              cnt_o === dec.lp_0 - 1.U &
-              cnt_i === dec.lp_1 - 1.U &
-              uop_idx === uop_end - 1.U &
-              inflight === 0.U) |
-             state === sWait)
-
-  switch (state) {
-    is (sIdle) {
-      when (io.start) {
+    ((state === sExe &
+      cnt_o === dec.lp_0 - 1.U &
+      cnt_i === dec.lp_1 - 1.U &
+      uop_idx === uop_end - 1.U &
+      inflight === 0.U) |
+      state === sWait)
+
+  switch(state) {
+    is(sIdle) {
+      when(io.start) {
         state := sReadUop
       }
     }
-    is (sReadUop) {
+    is(sReadUop) {
       state := sComputeIdx
     }
-    is (sComputeIdx) {
+    is(sComputeIdx) {
       state := sReadTensor
     }
-    is (sReadTensor) {
+    is(sReadTensor) {
       state := sExe
     }
-    is (sExe) {
-      when ((cnt_o === dec.lp_0 - 1.U) &&
-            (cnt_i === dec.lp_1 - 1.U) &&
-            (uop_idx === uop_end - 1.U)) {
-        when (inflight =/= 0.U) {
+    is(sExe) {
+      when(
+        (cnt_o === dec.lp_0 - 1.U) &&
+          (cnt_i === dec.lp_1 - 1.U) &&
+          (uop_idx === uop_end - 1.U)) {
+        when(inflight =/= 0.U) {
           state := sWait
-        } .otherwise {
+        }.otherwise {
           state := sIdle
         }
-      } .otherwise {
+      }.otherwise {
         state := sReadUop
       }
     }
-    is (sWait) {
-      when (inflight === 0.U) {
+    is(sWait) {
+      when(inflight === 0.U) {
         state := sIdle
       }
     }
   }
 
-  when (state === sIdle) {
+  when(state === sIdle) {
     inflight := 0.U
-  } .elsewhen (!dec.reset) {
-    when (state === sReadTensor) { // issue a tensor
+  }.elsewhen(!dec.reset) {
+    when(state === sReadTensor) { // issue a tensor
       inflight := inflight + 1.U
-    } .elsewhen (mvc.io.acc_o.data.valid) { // commit a tensor
+    }.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor
       inflight := inflight - 1.U
     }
   }
 
-  when (state === sIdle ||
-         (state === sExe &&
-          uop_idx === uop_end - 1.U)) {
+  when(
+    state === sIdle ||
+      (state === sExe &&
+        uop_idx === uop_end - 1.U)) {
     uop_idx := dec.uop_begin
-  } .elsewhen (state === sExe) {
+  }.elsewhen(state === sExe) {
     uop_idx := uop_idx + 1.U
   }
 
-  when (state === sIdle) {
+  when(state === sIdle) {
     cnt_o := 0.U
     acc_o := 0.U
     inp_o := 0.U
     wgt_o := 0.U
-  } .elsewhen (state === sExe &&
-               uop_idx === uop_end - 1.U &&
-               cnt_i === dec.lp_1 - 1.U) {
+  }.elsewhen(
+    state === sExe &&
+      uop_idx === uop_end - 1.U &&
+      cnt_i === dec.lp_1 - 1.U) {
     cnt_o := cnt_o + 1.U
     acc_o := acc_o + dec.acc_0
     inp_o := inp_o + dec.inp_0
     wgt_o := wgt_o + dec.wgt_0
   }
 
-  when (state === sIdle) {
+  when(state === sIdle) {
     cnt_i := 0.U
     acc_i := 0.U
     inp_i := 0.U
     wgt_i := 0.U
-  } .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
-    cnt_i := 0.U
-    acc_i := acc_o
-    inp_i := inp_o
-    wgt_i := wgt_o
-  } .elsewhen (state === sExe &&
-               uop_idx === uop_end - 1.U) {
-    cnt_i := cnt_i + 1.U
-    acc_i := acc_i + dec.acc_1
-    inp_i := inp_i + dec.inp_1
-    wgt_i := wgt_i + dec.wgt_1
-  }
+  }.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
+      cnt_i := 0.U
+      acc_i := acc_o
+      inp_i := inp_o
+      wgt_i := wgt_o
+    }
+    .elsewhen(state === sExe &&
+      uop_idx === uop_end - 1.U) {
+      cnt_i := cnt_i + 1.U
+      acc_i := acc_i + dec.acc_1
+      inp_i := inp_i + dec.inp_1
+      wgt_i := wgt_i + dec.wgt_1
+    }
 
-  when (state === sComputeIdx && io.uop.data.valid) {
+  when(state === sComputeIdx && io.uop.data.valid) {
     uop_acc := io.uop.data.bits.u0 + acc_i
     uop_inp := io.uop.data.bits.u1 + inp_i
     uop_wgt := io.uop.data.bits.u2 + wgt_i
@@ -307,7 +320,9 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
   mvc.io.acc_i.data <> io.acc.rd.data
 
   // acc_o
-  io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset, true.B, wrpipe.io.deq.valid)
+  io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset,
+                                                   true.B,
+                                                   wrpipe.io.deq.valid)
   io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
   io.acc.wr.bits.data <> mvc.io.acc_o.data.bits
 
@@ -320,47 +335,55 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
   io.done := done
 
   if (debug) {
-    when (state === sReadUop && ~dec.reset) {
+    when(state === sReadUop && ~dec.reset) {
       printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
     }
 
-    when (state === sReadTensor && ~dec.reset) {
-      printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt)
+    when(state === sReadTensor && ~dec.reset) {
+      printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n",
+             uop_acc,
+             uop_inp,
+             uop_wgt)
     }
 
-    io.inp.rd.data.bits.zipWithIndex.foreach { case(r, i) =>
-      when (io.inp.rd.data.valid && ~dec.reset) {
-        printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
-      }
+    io.inp.rd.data.bits.zipWithIndex.foreach {
+      case (r, i) =>
+        when(io.inp.rd.data.valid && ~dec.reset) {
+          printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
+        }
     }
 
-    io.wgt.rd.data.bits.zipWithIndex.foreach { case(r, i) =>
-      when (io.wgt.rd.data.valid && ~dec.reset) {
-        printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
-      }
+    io.wgt.rd.data.bits.zipWithIndex.foreach {
+      case (r, i) =>
+        when(io.wgt.rd.data.valid && ~dec.reset) {
+          printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
+        }
     }
 
     io.acc.rd.data.bits.foreach { tensor =>
-      tensor.zipWithIndex.foreach { case(elem, i) =>
-        when (io.acc.rd.data.valid && ~dec.reset) {
-          printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
-        }
+      tensor.zipWithIndex.foreach {
+        case (elem, i) =>
+          when(io.acc.rd.data.valid && ~dec.reset) {
+            printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
+          }
       }
     }
 
     mvc.io.acc_o.data.bits.foreach { tensor =>
-      tensor.zipWithIndex.foreach { case(elem, i) =>
-        when (mvc.io.acc_o.data.valid && ~dec.reset) {
-          printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
-        }
+      tensor.zipWithIndex.foreach {
+        case (elem, i) =>
+          when(mvc.io.acc_o.data.valid && ~dec.reset) {
+            printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
+          }
       }
     }
 
     mvc.io.out.data.bits.foreach { tensor =>
-      tensor.zipWithIndex.foreach { case(elem, i) =>
-        when (mvc.io.out.data.valid && ~dec.reset) {
-          printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
-        }
+      tensor.zipWithIndex.foreach {
+        case (elem, i) =>
+          when(mvc.io.out.data.valid && ~dec.reset) {
+            printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
+          }
       }
     }
   }
index 8f1956f..ca6803c 100644 (file)
@@ -32,8 +32,9 @@ import vta.shell._
   * managed by TensorPadCtrl. The TensorDataCtrl is in charge of
   * handling the way tensors are stored on the scratchpads.
   */
-class TensorLoad(tensorType: String = "none", debug: Boolean = false)
-  (implicit p: Parameters) extends Module {
+class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
+    implicit p: Parameters)
+    extends Module {
   val tp = new TensorParams(tensorType)
   val mp = p(ShellKey).memParams
   val io = IO(new Bundle {
@@ -48,7 +49,8 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
   val strideFactor = tp.tensorLength * tp.tensorWidth
 
   val dec = io.inst.asTypeOf(new MemDecode)
-  val dataCtrl = Module(new TensorDataCtrl(tensorType, sizeFactor, strideFactor))
+  val dataCtrl = Module(
+    new TensorDataCtrl(tensorType, sizeFactor, strideFactor))
   val dataCtrlDone = RegInit(false.B)
   val yPadCtrl0 = Module(new TensorPadCtrl(padType = "YPad0", sizeFactor))
   val yPadCtrl1 = Module(new TensorPadCtrl(padType = "YPad1", sizeFactor))
@@ -58,81 +60,85 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
   val tag = Reg(UInt(log2Ceil(tp.numMemBlock).W))
   val set = Reg(UInt(log2Ceil(tp.tensorLength).W))
 
-  val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil = Enum(7)
+  val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil =
+    Enum(7)
   val state = RegInit(sIdle)
 
   // control
-  switch (state) {
-    is (sIdle) {
-      when (io.start) {
-        when (dec.ypad_0 =/= 0.U) {
+  switch(state) {
+    is(sIdle) {
+      when(io.start) {
+        when(dec.ypad_0 =/= 0.U) {
           state := sYPad0
-        } .elsewhen (dec.xpad_0 =/= 0.U) {
-          state := sXPad0
-        } .otherwise {
-          state := sReadCmd
-        }
+        }.elsewhen(dec.xpad_0 =/= 0.U) {
+            state := sXPad0
+          }
+          .otherwise {
+            state := sReadCmd
+          }
       }
     }
-    is (sYPad0) {
-      when (yPadCtrl0.io.done) {
-        when (dec.xpad_0 =/= 0.U) {
+    is(sYPad0) {
+      when(yPadCtrl0.io.done) {
+        when(dec.xpad_0 =/= 0.U) {
           state := sXPad0
-        } .otherwise {
+        }.otherwise {
           state := sReadCmd
         }
       }
     }
-    is (sXPad0) {
-      when (xPadCtrl0.io.done) {
+    is(sXPad0) {
+      when(xPadCtrl0.io.done) {
         state := sReadCmd
       }
     }
-    is (sReadCmd) {
-      when (io.vme_rd.cmd.ready) {
+    is(sReadCmd) {
+      when(io.vme_rd.cmd.ready) {
         state := sReadData
       }
     }
-    is (sReadData) {
-      when (io.vme_rd.data.valid) {
-        when (dataCtrl.io.done) {
-          when (dec.xpad_1 =/= 0.U) {
+    is(sReadData) {
+      when(io.vme_rd.data.valid) {
+        when(dataCtrl.io.done) {
+          when(dec.xpad_1 =/= 0.U) {
             state := sXPad1
-          } .elsewhen (dec.ypad_1 =/= 0.U) {
-            state := sYPad1
-          } .otherwise  {
-            state := sIdle
-          }
-        } .elsewhen (dataCtrl.io.stride || dataCtrl.io.split) {
-          when (dec.xpad_1 =/= 0.U) {
+          }.elsewhen(dec.ypad_1 =/= 0.U) {
+              state := sYPad1
+            }
+            .otherwise {
+              state := sIdle
+            }
+        }.elsewhen(dataCtrl.io.stride || dataCtrl.io.split) {
+          when(dec.xpad_1 =/= 0.U) {
             state := sXPad1
-          } .elsewhen (dec.xpad_0 =/= 0.U) {
-            state := sXPad0
-          } .otherwise {
+          }.elsewhen(dec.xpad_0 =/= 0.U) {
+              state := sXPad0
+            }
+            .otherwise {
               state := sReadCmd
-          }
+            }
         }
       }
     }
-    is (sXPad1) {
-      when (xPadCtrl1.io.done) {
-        when (dataCtrlDone) {
-          when (dec.ypad_1 =/= 0.U) {
+    is(sXPad1) {
+      when(xPadCtrl1.io.done) {
+        when(dataCtrlDone) {
+          when(dec.ypad_1 =/= 0.U) {
             state := sYPad1
-          } .otherwise {
+          }.otherwise {
             state := sIdle
           }
-        } .otherwise {
-          when (dec.xpad_0 =/= 0.U) {
+        }.otherwise {
+          when(dec.xpad_0 =/= 0.U) {
             state := sXPad0
-          } .otherwise {
+          }.otherwise {
             state := sReadCmd
           }
         }
       }
     }
-    is (sYPad1) {
-      when (yPadCtrl1.io.done && dataCtrlDone) {
+    is(sYPad1) {
+      when(yPadCtrl1.io.done && dataCtrlDone) {
         state := sIdle
       }
     }
@@ -146,9 +152,9 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
   dataCtrl.io.xupdate := io.vme_rd.data.fire()
   dataCtrl.io.yupdate := io.vme_rd.data.fire()
 
-  when (state === sIdle) {
+  when(state === sIdle) {
     dataCtrlDone := false.B
-  } .elsewhen (io.vme_rd.data.fire() && dataCtrl.io.done) {
+  }.elsewhen(io.vme_rd.data.fire() && dataCtrl.io.done) {
     dataCtrlDone := true.B
   }
 
@@ -156,18 +162,19 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
   yPadCtrl0.io.start := dec.ypad_0 =/= 0.U & state === sIdle & io.start
 
   yPadCtrl1.io.start := dec.ypad_1 =/= 0.U &
-                          ((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) |
-                           (state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone))
+    ((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) |
+      (state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone))
 
   xPadCtrl0.io.start := dec.xpad_0 =/= 0.U &
-                          ((state === sIdle & io.start) |
-                          (state === sYPad0 & yPadCtrl0.io.done) |
-                          (io.vme_rd.data.fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) |
-                          (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))
+    ((state === sIdle & io.start) |
+      (state === sYPad0 & yPadCtrl0.io.done) |
+      (io.vme_rd.data
+        .fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) |
+      (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))
 
   xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() &
-                          ((dataCtrl.io.done) |
-                          (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U))
+    ((dataCtrl.io.done) |
+      (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U))
 
   yPadCtrl0.io.inst := io.inst
   yPadCtrl1.io.inst := io.inst
@@ -183,39 +190,49 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
 
   // write-to-sram
   val isZeroPad = state === sYPad0 |
-                  state === sXPad0 |
-                  state === sXPad1 |
-                  state === sYPad1
+    state === sXPad0 |
+    state === sXPad1 |
+    state === sYPad1
 
-  when (state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
+  when(state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
     tag := 0.U
-  } .elsewhen (io.vme_rd.data.fire() || isZeroPad) {
+  }.elsewhen(io.vme_rd.data.fire() || isZeroPad) {
     tag := tag + 1.U
   }
 
-  when (state === sIdle || dataCtrlDone || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
+  when(
+    state === sIdle || dataCtrlDone || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
     set := 0.U
-  } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) {
+  }.elsewhen(
+    (io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) {
     set := set + 1.U
   }
 
   val waddr_cur = Reg(UInt(tp.memAddrBits.W))
   val waddr_nxt = Reg(UInt(tp.memAddrBits.W))
-  when (state === sIdle) {
+  when(state === sIdle) {
     waddr_cur := dec.sram_offset
     waddr_nxt := dec.sram_offset
-  } .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) {
-    waddr_cur := waddr_cur + 1.U
-  } .elsewhen (dataCtrl.io.stride) {
-    waddr_cur := waddr_nxt + dec.xsize
-    waddr_nxt := waddr_nxt + dec.xsize
-  }
+  }.elsewhen((io.vme_rd.data
+      .fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) {
+      waddr_cur := waddr_cur + 1.U
+    }
+    .elsewhen(dataCtrl.io.stride) {
+      waddr_cur := waddr_nxt + dec.xsize
+      waddr_nxt := waddr_nxt + dec.xsize
+    }
 
-  val tensorFile = Seq.fill(tp.tensorLength) { SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
+  val tensorFile = Seq.fill(tp.tensorLength) {
+    SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
+  }
   val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) }
-  val wdata = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
+  val wdata = Seq.fill(tp.tensorLength) {
+    Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
+  }
   val no_mask = Wire(Vec(tp.numMemBlock, Bool()))
-  no_mask.foreach { m => m := true.B }
+  no_mask.foreach { m =>
+    m := true.B
+  }
 
   for (i <- 0 until tp.tensorLength) {
     for (j <- 0 until tp.numMemBlock) {
@@ -223,11 +240,14 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
       wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits)
     }
     val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i))
-    val muxWen = Mux(state === sIdle, io.tensor.wr.valid, (io.vme_rd.data.fire() | isZeroPad) & set === i.U)
+    val muxWen =
+      Mux(state === sIdle,
+          io.tensor.wr.valid,
+          (io.vme_rd.data.fire() | isZeroPad) & set === i.U)
     val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur)
     val muxWdata = Mux(state === sIdle, tdata, wdata(i))
     val muxWmask = Mux(state === sIdle, no_mask, wmask(i))
-    when (muxWen) {
+    when(muxWen) {
       tensorFile(i).write(muxWaddr, muxWdata, muxWmask)
     }
   }
@@ -236,13 +256,16 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
   val rvalid = RegNext(io.tensor.rd.idx.valid)
   io.tensor.rd.data.valid := rvalid
 
-  val rdata = tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid))
-  rdata.zipWithIndex.foreach { case(r, i) =>
-    io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i))
+  val rdata =
+    tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid))
+  rdata.zipWithIndex.foreach {
+    case (r, i) =>
+      io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i))
   }
 
   // done
-  val done_no_pad = io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U
+  val done_no_pad = io.vme_rd.data
+    .fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U
   val done_x_pad = state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone & dec.ypad_1 === 0.U
   val done_y_pad = state === sYPad1 & dataCtrlDone & yPadCtrl1.io.done
   io.done := done_no_pad | done_x_pad | done_y_pad
@@ -250,28 +273,34 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
   // debug
   if (debug) {
     if (tensorType == "inp") {
-      when (io.vme_rd.cmd.fire()) {
-        printf("[TensorLoad] [inp] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
+      when(io.vme_rd.cmd.fire()) {
+        printf("[TensorLoad] [inp] cmd addr:%x len:%x\n",
+               dataCtrl.io.addr,
+               dataCtrl.io.len)
       }
-      when (state === sYPad0) {
+      when(state === sYPad0) {
         printf("[TensorLoad] [inp] sYPad0\n")
       }
-      when (state === sYPad1) {
+      when(state === sYPad1) {
         printf("[TensorLoad] [inp] sYPad1\n")
       }
-      when (state === sXPad0) {
+      when(state === sXPad0) {
         printf("[TensorLoad] [inp] sXPad0\n")
       }
-      when (state === sXPad1) {
+      when(state === sXPad1) {
         printf("[TensorLoad] [inp] sXPad1\n")
       }
     } else if (tensorType == "wgt") {
-      when (io.vme_rd.cmd.fire()) {
-        printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
+      when(io.vme_rd.cmd.fire()) {
+        printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n",
+               dataCtrl.io.addr,
+               dataCtrl.io.len)
       }
     } else if (tensorType == "acc") {
-      when (io.vme_rd.cmd.fire()) {
-        printf("[TensorLoad] [acc] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
+      when(io.vme_rd.cmd.fire()) {
+        printf("[TensorLoad] [acc] cmd addr:%x len:%x\n",
+               dataCtrl.io.addr,
+               dataCtrl.io.len)
       }
     }
   }
index 20ff6f4..083a70c 100644 (file)
@@ -28,8 +28,9 @@ import vta.shell._
   *
   * Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM).
   */
-class TensorStore(tensorType: String = "none", debug: Boolean = false)
-  (implicit p: Parameters) extends Module {
+class TensorStore(tensorType: String = "none", debug: Boolean = false)(
+    implicit p: Parameters)
+    extends Module {
   val tp = new TensorParams(tensorType)
   val mp = p(ShellKey).memParams
   val io = IO(new Bundle {
@@ -53,9 +54,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
   val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
   val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
   val xrem = Reg(chiselTypeOf(dec.xsize))
-  val xsize = (dec.xsize << log2Ceil(tensorLength*numMemBlock)) - 1.U
+  val xsize = (dec.xsize << log2Ceil(tensorLength * numMemBlock)) - 1.U
   val xmax = (1 << mp.lenBits).U
-  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+  val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
   val ycnt = Reg(chiselTypeOf(dec.ysize))
   val ysize = dec.ysize
   val tag = Reg(UInt(8.W))
@@ -65,132 +66,147 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
   val state = RegInit(sIdle)
 
   // control
-  switch (state) {
-    is (sIdle) {
-      when (io.start) {
+  switch(state) {
+    is(sIdle) {
+      when(io.start) {
         state := sWriteCmd
-        when (xsize < xmax) {
+        when(xsize < xmax) {
           xlen := xsize
           xrem := 0.U
-        } .otherwise {
+        }.otherwise {
           xlen := xmax - 1.U
           xrem := xsize - xmax
         }
       }
     }
-    is (sWriteCmd) {
-      when (io.vme_wr.cmd.ready) {
+    is(sWriteCmd) {
+      when(io.vme_wr.cmd.ready) {
         state := sWriteData
       }
     }
-    is (sWriteData) {
-      when (io.vme_wr.data.ready) {
-        when (xcnt === xlen) {
+    is(sWriteData) {
+      when(io.vme_wr.data.ready) {
+        when(xcnt === xlen) {
           state := sWriteAck
-        } .elsewhen (tag === (numMemBlock - 1).U) {
+        }.elsewhen(tag === (numMemBlock - 1).U) {
           state := sReadMem
         }
       }
     }
-    is (sReadMem) {
+    is(sReadMem) {
       state := sWriteData
     }
-    is (sWriteAck) {
-      when (io.vme_wr.ack) {
-        when (xrem === 0.U) {
-          when (ycnt === ysize - 1.U) {
+    is(sWriteAck) {
+      when(io.vme_wr.ack) {
+        when(xrem === 0.U) {
+          when(ycnt === ysize - 1.U) {
             state := sIdle
-          } .otherwise {
+          }.otherwise {
             state := sWriteCmd
-            when (xsize < xmax) {
+            when(xsize < xmax) {
               xlen := xsize
               xrem := 0.U
-            } .otherwise {
+            }.otherwise {
               xlen := xmax - 1.U
               xrem := xsize - xmax
             }
           }
-        } .elsewhen (xrem < xmax) {
-          state := sWriteCmd
-          xlen := xrem
-          xrem := 0.U
-        } .otherwise {
-          state := sWriteCmd
-          xlen := xmax - 1.U
-          xrem := xrem - xmax
-        }
+        }.elsewhen(xrem < xmax) {
+            state := sWriteCmd
+            xlen := xrem
+            xrem := 0.U
+          }
+          .otherwise {
+            state := sWriteCmd
+            xlen := xmax - 1.U
+            xrem := xrem - xmax
+          }
       }
     }
   }
 
   // write-to-sram
-  val tensorFile = Seq.fill(tensorLength) { SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W))) }
+  val tensorFile = Seq.fill(tensorLength) {
+    SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W)))
+  }
   val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W)))
   val no_mask = Wire(Vec(numMemBlock, Bool()))
 
   wdata_t := DontCare
-  no_mask.foreach { m => m := true.B }
+  no_mask.foreach { m =>
+    m := true.B
+  }
 
   for (i <- 0 until tensorLength) {
     val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t)
-    when (io.tensor.wr.valid) {
+    when(io.tensor.wr.valid) {
       tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask)
     }
   }
 
   // read-from-sram
   val stride = state === sWriteAck &
-              io.vme_wr.ack &
-              xcnt === xlen + 1.U &
-              xrem === 0.U &
-              ycnt =/= ysize - 1.U
+    io.vme_wr.ack &
+    xcnt === xlen + 1.U &
+    xrem === 0.U &
+    ycnt =/= ysize - 1.U
 
-  when (state === sIdle) {
+  when(state === sIdle) {
     ycnt := 0.U
-  } .elsewhen (stride) {
+  }.elsewhen(stride) {
     ycnt := ycnt + 1.U
   }
 
-  when (state === sWriteCmd || tag === (numMemBlock - 1).U) {
+  when(state === sWriteCmd || tag === (numMemBlock - 1).U) {
     tag := 0.U
-  } .elsewhen (io.vme_wr.data.fire()) {
+  }.elsewhen(io.vme_wr.data.fire()) {
     tag := tag + 1.U
   }
 
-  when (state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
+  when(
+    state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
     set := 0.U
-  } .elsewhen (io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
+  }.elsewhen(io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
     set := set + 1.U
   }
 
   val raddr_cur = Reg(UInt(tp.memAddrBits.W))
   val raddr_nxt = Reg(UInt(tp.memAddrBits.W))
-  when (state === sIdle) {
+  when(state === sIdle) {
     raddr_cur := dec.sram_offset
     raddr_nxt := dec.sram_offset
-  } .elsewhen (io.vme_wr.data.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) {
-    raddr_cur := raddr_cur + 1.U
-  } .elsewhen (stride) {
-    raddr_cur := raddr_nxt + dec.xsize
-    raddr_nxt := raddr_nxt + dec.xsize
-  }
+  }.elsewhen(io.vme_wr.data
+      .fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) {
+      raddr_cur := raddr_cur + 1.U
+    }
+    .elsewhen(stride) {
+      raddr_cur := raddr_nxt + dec.xsize
+      raddr_nxt := raddr_nxt + dec.xsize
+    }
 
-  val tread = Seq.tabulate(tensorLength) { i => i.U ->
-    tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem) }
+  val tread = Seq.tabulate(tensorLength) { i =>
+    i.U ->
+      tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem)
+  }
   val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread)
 
   // write-to-dram
   val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
   val elemBytes = (p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).outBits) / 8
-  when (state === sIdle) {
-    waddr_cur := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
-    waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
-  } .elsewhen (state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
-    waddr_cur := waddr_cur + xmax_bytes
-  } .elsewhen (stride) {
-    waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
-    waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
-  }
+  when(state === sIdle) {
+    waddr_cur := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
+      elemBytes)))
+    waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
+      elemBytes)))
+  }.elsewhen(state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
+      waddr_cur := waddr_cur + xmax_bytes
+    }
+    .elsewhen(stride) {
+      waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(
+        tensorLength * tensorWidth))
+      waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(
+        tensorLength * tensorWidth))
+    }
 
   io.vme_wr.cmd.valid := state === sWriteCmd
   io.vme_wr.cmd.bits.addr := waddr_cur
@@ -199,9 +215,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
   io.vme_wr.data.valid := state === sWriteData
   io.vme_wr.data.bits := mdata(tag)
 
-  when (state === sWriteCmd) {
+  when(state === sWriteCmd) {
     xcnt := 0.U
-  } .elsewhen (io.vme_wr.data.fire()) {
+  }.elsewhen(io.vme_wr.data.fire()) {
     xcnt := xcnt + 1.U
   }
 
@@ -213,13 +229,19 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
 
   // debug
   if (debug) {
-    when (io.vme_wr.cmd.fire()) {
-      printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem)
+    when(io.vme_wr.cmd.fire()) {
+      printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n",
+             ysize,
+             ycnt,
+             raddr_cur,
+             waddr_cur,
+             xlen,
+             xrem)
     }
-    when (io.vme_wr.data.fire()) {
+    when(io.vme_wr.data.fire()) {
       printf("[TensorStore] data:%x\n", io.vme_wr.data.bits)
     }
-    when (io.vme_wr.ack) {
+    when(io.vme_wr.ack) {
       printf("[TensorStore] ack\n")
     }
   }
index bb03846..1f00554 100644 (file)
@@ -30,11 +30,14 @@ import vta.shell._
   * weights (wgt), biases (acc), and outputs (out). This is used to avoid
   * doing the same boring calculations over and over again.
   */
-class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends Bundle {
-  val errorMsg = s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
+class TensorParams(tensorType: String = "none")(implicit p: Parameters)
+    extends Bundle {
+  val errorMsg =
+    s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
 
-  require (tensorType == "inp" || tensorType == "wgt"
-    || tensorType == "acc" || tensorType == "out", errorMsg)
+  require(tensorType == "inp" || tensorType == "wgt"
+            || tensorType == "acc" || tensorType == "out",
+          errorMsg)
 
   val (tensorLength, tensorWidth, tensorElemBits) =
     if (tensorType == "inp")
@@ -69,25 +72,30 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
   * biases (acc), and outputs (out).
   *
   */
-class TensorMaster(tensorType: String = "none")
-  (implicit p: Parameters) extends TensorParams(tensorType) {
-    val rd = new Bundle {
-      val idx = ValidIO(UInt(memAddrBits.W))
-      val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
-    }
-    val wr = ValidIO(new Bundle {
-      val idx = UInt(memAddrBits.W)
-      val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
-    })
-    def tieoffRead() {
-      rd.idx.valid := false.B
-      rd.idx.bits := 0.U
-    }
-    def tieoffWrite() {
-      wr.valid := false.B
-      wr.bits.idx := 0.U
-      wr.bits.data.foreach { b => b.foreach { c => c := 0.U } }
+class TensorMaster(tensorType: String = "none")(implicit p: Parameters)
+    extends TensorParams(tensorType) {
+  val rd = new Bundle {
+    val idx = ValidIO(UInt(memAddrBits.W))
+    val data = Flipped(
+      ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
+  }
+  val wr = ValidIO(new Bundle {
+    val idx = UInt(memAddrBits.W)
+    val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
+  })
+  def tieoffRead() {
+    rd.idx.valid := false.B
+    rd.idx.bits := 0.U
+  }
+  def tieoffWrite() {
+    wr.valid := false.B
+    wr.bits.idx := 0.U
+    wr.bits.data.foreach { b =>
+      b.foreach { c =>
+        c := 0.U
+      }
     }
+  }
   override def cloneType =
     new TensorMaster(tensorType).asInstanceOf[this.type]
 }
@@ -98,20 +106,25 @@ class TensorMaster(tensorType: String = "none")
   * The TensorLoad unit uses this interface for receiving read and write requests from
   * the TensorGemm unit.
   */
-class TensorClient(tensorType: String = "none")
-  (implicit p: Parameters) extends TensorParams(tensorType) {
-    val rd = new Bundle {
-      val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
-      val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
-    }
-    val wr = Flipped(ValidIO(new Bundle {
-      val idx = UInt(memAddrBits.W)
-      val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
-    }))
-    def tieoffRead() {
-      rd.data.valid := false.B
-      rd.data.bits.foreach { b => b.foreach { c => c := 0.U } }
+class TensorClient(tensorType: String = "none")(implicit p: Parameters)
+    extends TensorParams(tensorType) {
+  val rd = new Bundle {
+    val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
+    val data = ValidIO(
+      Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
+  }
+  val wr = Flipped(ValidIO(new Bundle {
+    val idx = UInt(memAddrBits.W)
+    val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
+  }))
+  def tieoffRead() {
+    rd.data.valid := false.B
+    rd.data.bits.foreach { b =>
+      b.foreach { c =>
+        c := 0.U
+      }
     }
+  }
   override def cloneType =
     new TensorClient(tensorType).asInstanceOf[this.type]
 }
@@ -122,9 +135,10 @@ class TensorClient(tensorType: String = "none")
   * is based on the TensorMaster interface, which means this is an input. This interface
   * is used on datapath only module such MatrixVectorCore or AluVector.
   */
-class TensorMasterData(tensorType: String = "none")
-  (implicit p: Parameters) extends TensorParams(tensorType) {
-  val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
+class TensorMasterData(tensorType: String = "none")(implicit p: Parameters)
+    extends TensorParams(tensorType) {
+  val data = Flipped(
+    ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
   override def cloneType =
     new TensorMasterData(tensorType).asInstanceOf[this.type]
 }
@@ -135,18 +149,22 @@ class TensorMasterData(tensorType: String = "none")
   * is based on the TensorClient interface, which means this is an output. This interface
   * is used on datapath only module such MatrixVectorCore or AluVector.
   */
-class TensorClientData(tensorType: String = "none")
-  (implicit p: Parameters) extends TensorParams(tensorType) {
-  val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
+class TensorClientData(tensorType: String = "none")(implicit p: Parameters)
+    extends TensorParams(tensorType) {
+  val data = ValidIO(
+    Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
   override def cloneType =
     new TensorClientData(tensorType).asInstanceOf[this.type]
 }
 
 /** TensorPadCtrl. Zero-padding controller for TensorLoad. */
-class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Module {
-  val errorMsg = s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"
-  require (padType == "YPad0" || padType == "YPad1"
-    || padType == "XPad0" || padType == "XPad1", errorMsg)
+class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1)
+    extends Module {
+  val errorMsg =
+    s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"
+  require(padType == "YPad0" || padType == "YPad1"
+            || padType == "XPad0" || padType == "XPad1",
+          errorMsg)
 
   val io = IO(new Bundle {
     val start = Input(Bool())
@@ -180,33 +198,33 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Modul
   val sIdle :: sActive :: Nil = Enum(2)
   val state = RegInit(sIdle)
 
-  switch (state) {
-    is (sIdle) {
-      when (io.start) {
+  switch(state) {
+    is(sIdle) {
+      when(io.start) {
         state := sActive
       }
     }
-    is (sActive) {
-      when (ycnt === ymax && xcnt === xmax) {
+    is(sActive) {
+      when(ycnt === ymax && xcnt === xmax) {
         state := sIdle
       }
     }
   }
 
-  when (state === sIdle) {
+  when(state === sIdle) {
     xmax := xval
     ymax := yval
   }
 
-  when (state === sIdle || xcnt === xmax) {
+  when(state === sIdle || xcnt === xmax) {
     xcnt := 0.U
-  } .elsewhen (state === sActive) {
+  }.elsewhen(state === sActive) {
     xcnt := xcnt + 1.U
   }
 
-  when (state === sIdle || ymax === 0.U) {
+  when(state === sIdle || ymax === 0.U) {
     ycnt := 0.U
-  } .elsewhen (state === sActive && xcnt === xmax) {
+  }.elsewhen(state === sActive && xcnt === xmax) {
     ycnt := ycnt + 1.U
   }
 
@@ -214,7 +232,10 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Modul
 }
 
 /** TensorDataCtrl. Data controller for TensorLoad. */
-class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module {
+class TensorDataCtrl(tensorType: String = "none",
+                     sizeFactor: Int = 1,
+                     strideFactor: Int = 1)(implicit p: Parameters)
+    extends Module {
   val mp = p(ShellKey).memParams
   val io = IO(new Bundle {
     val start = Input(Bool())
@@ -238,7 +259,7 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac
 
   val len = Reg(UInt(mp.lenBits.W))
 
-  val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
+  val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
   val xcnt = Reg(UInt(mp.lenBits.W))
   val xrem = Reg(chiselTypeOf(dec.xsize))
   val xsize = (dec.xsize << log2Ceil(sizeFactor)) - 1.U
@@ -246,38 +267,38 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac
   val ycnt = Reg(chiselTypeOf(dec.ysize))
 
   val stride = xcnt === len &
-               xrem === 0.U &
-               ycnt =/= dec.ysize - 1.U
+    xrem === 0.U &
+    ycnt =/= dec.ysize - 1.U
 
   val split = xcnt === len & xrem =/= 0.U
 
-  when (io.start || (io.xupdate && stride)) {
-    when (xsize < xmax) {
+  when(io.start || (io.xupdate && stride)) {
+    when(xsize < xmax) {
       len := xsize
       xrem := 0.U
-    } .otherwise {
+    }.otherwise {
       len := xmax - 1.U
       xrem := xsize - xmax
     }
-  } .elsewhen (io.xupdate && split) {
-    when (xrem < xmax) {
+  }.elsewhen(io.xupdate && split) {
+    when(xrem < xmax) {
       len := xrem
       xrem := 0.U
-    } .otherwise {
+    }.otherwise {
       len := xmax - 1.U
       xrem := xrem - xmax
     }
   }
 
-  when (io.xinit) {
+  when(io.xinit) {
     xcnt := 0.U
-  } .elsewhen (io.xupdate) {
+  }.elsewhen(io.xupdate) {
     xcnt := xcnt + 1.U
   }
 
-  when (io.start) {
+  when(io.start) {
     ycnt := 0.U
-  } .elsewhen (io.yupdate && stride) {
+  }.elsewhen(io.yupdate && stride) {
     ycnt := ycnt + 1.U
   }
 
@@ -291,13 +312,13 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac
       (p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).accBits) / 8
     }
 
-  when (io.start) {
+  when(io.start) {
     caddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
     baddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
-  } .elsewhen (io.yupdate) {
-    when (split) {
+  }.elsewhen(io.yupdate) {
+    when(split) {
       caddr := caddr + xmax_bytes
-    } .elsewhen (stride) {
+    }.elsewhen(stride) {
       caddr := baddr + (dec.xstride << log2Ceil(strideFactor))
       baddr := baddr + (dec.xstride << log2Ceil(strideFactor))
     }
@@ -309,6 +330,6 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac
   io.addr := caddr
   io.len := len
   io.done := xcnt === len &
-             xrem === 0.U &
-             ycnt === dec.ysize - 1.U
+    xrem === 0.U &
+    ycnt === dec.ysize - 1.U
 }
index 115bcbc..3318251 100644 (file)
@@ -78,55 +78,56 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource {
   *
   * Convert Host DPI to AXI for VTAShell
   */
-
-class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters)
+    extends Module {
   val io = IO(new Bundle {
     val dpi = new VTAHostDPIClient
     val axi = new AXILiteMaster(p(ShellKey).hostParams)
   })
   val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
   val data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.value)))
-  val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
+  val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil =
+    Enum(6)
   val state = RegInit(sIdle)
 
-  switch (state) {
-    is (sIdle) {
-      when (io.dpi.req.valid) {
-        when (io.dpi.req.opcode) {
+  switch(state) {
+    is(sIdle) {
+      when(io.dpi.req.valid) {
+        when(io.dpi.req.opcode) {
           state := sWriteAddress
-        } .otherwise {
+        }.otherwise {
           state := sReadAddress
         }
       }
     }
-    is (sReadAddress) {
-      when (io.axi.ar.ready) {
+    is(sReadAddress) {
+      when(io.axi.ar.ready) {
         state := sReadData
       }
     }
-    is (sReadData) {
-      when (io.axi.r.valid) {
+    is(sReadData) {
+      when(io.axi.r.valid) {
         state := sIdle
       }
     }
-    is (sWriteAddress) {
-      when (io.axi.aw.ready) {
+    is(sWriteAddress) {
+      when(io.axi.aw.ready) {
         state := sWriteData
       }
     }
-    is (sWriteData) {
-      when (io.axi.w.ready) {
+    is(sWriteData) {
+      when(io.axi.w.ready) {
         state := sWriteResponse
       }
     }
-    is (sWriteResponse) {
-      when (io.axi.b.valid) {
+    is(sWriteResponse) {
+      when(io.axi.b.valid) {
         state := sIdle
       }
     }
   }
 
-  when (state === sIdle && io.dpi.req.valid) {
+  when(state === sIdle && io.dpi.req.valid) {
     addr := io.dpi.req.addr
     data := io.dpi.req.value
   }
@@ -147,9 +148,17 @@ class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mo
   io.dpi.resp.bits := io.axi.r.bits.data
 
   if (debug) {
-    when (state === sWriteAddress && io.axi.aw.ready) { printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr) }
-    when (state === sReadAddress && io.axi.ar.ready) { printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr) }
-    when (io.axi.r.fire()) { printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data) }
-    when (io.axi.w.fire()) { printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data) }
+    when(state === sWriteAddress && io.axi.aw.ready) {
+      printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr)
+    }
+    when(state === sReadAddress && io.axi.ar.ready) {
+      printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr)
+    }
+    when(io.axi.r.fire()) {
+      printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data)
+    }
+    when(io.axi.w.fire()) {
+      printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data)
+    }
   }
 }
index 5e2fa74..f46b778 100644 (file)
@@ -75,7 +75,8 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource {
   setResource("/verilog/VTAMemDPI.v")
 }
 
-class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
+class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters)
+    extends Module {
   val io = IO(new Bundle {
     val dpi = new VTAMemDPIMaster
     val axi = new AXIClient(p(ShellKey).memParams)
@@ -83,56 +84,57 @@ class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mod
   val opcode = RegInit(false.B)
   val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len)))
   val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
-  val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
+  val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil =
+    Enum(6)
   val state = RegInit(sIdle)
 
-  switch (state) {
-    is (sIdle) {
-      when (io.axi.ar.valid) {
+  switch(state) {
+    is(sIdle) {
+      when(io.axi.ar.valid) {
         state := sReadAddress
-      } .elsewhen (io.axi.aw.valid) {
+      }.elsewhen(io.axi.aw.valid) {
         state := sWriteAddress
       }
     }
-    is (sReadAddress) {
-      when (io.axi.ar.valid) {
+    is(sReadAddress) {
+      when(io.axi.ar.valid) {
         state := sReadData
       }
     }
-    is (sReadData) {
-      when (io.axi.r.ready && io.dpi.rd.valid && len === 0.U) {
+    is(sReadData) {
+      when(io.axi.r.ready && io.dpi.rd.valid && len === 0.U) {
         state := sIdle
       }
     }
-    is (sWriteAddress) {
-      when (io.axi.aw.valid) {
+    is(sWriteAddress) {
+      when(io.axi.aw.valid) {
         state := sWriteData
       }
     }
-    is (sWriteData) {
-      when (io.axi.w.valid && io.axi.w.bits.last) {
+    is(sWriteData) {
+      when(io.axi.w.valid && io.axi.w.bits.last) {
         state := sWriteResponse
       }
     }
-    is (sWriteResponse) {
-      when (io.axi.b.ready) {
+    is(sWriteResponse) {
+      when(io.axi.b.ready) {
         state := sIdle
       }
     }
   }
 
-  when (state === sIdle) {
-    when (io.axi.ar.valid) {
+  when(state === sIdle) {
+    when(io.axi.ar.valid) {
       opcode := false.B
       len := io.axi.ar.bits.len
       addr := io.axi.ar.bits.addr
-    } .elsewhen (io.axi.aw.valid) {
+    }.elsewhen(io.axi.aw.valid) {
       opcode := true.B
       len := io.axi.aw.bits.len
       addr := io.axi.aw.bits.addr
     }
-  } .elsewhen (state === sReadData) {
-    when (io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) {
+  }.elsewhen(state === sReadData) {
+    when(io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) {
       len := len - 1.U
     }
   }
@@ -163,9 +165,21 @@ class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mod
   io.axi.b.bits.id := 0.U
 
   if (debug) {
-    when (state === sReadAddress && io.axi.ar.valid) { printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len) }
-    when (state === sWriteAddress && io.axi.aw.valid) { printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len) }
-    when (io.axi.r.fire()) { printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", io.axi.r.bits.last, io.axi.r.bits.data) }
-    when (io.axi.w.fire()) { printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", io.axi.w.bits.last, io.axi.w.bits.data) }
+    when(state === sReadAddress && io.axi.ar.valid) {
+      printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len)
+    }
+    when(state === sWriteAddress && io.axi.aw.valid) {
+      printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len)
+    }
+    when(io.axi.r.fire()) {
+      printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n",
+             io.axi.r.bits.last,
+             io.axi.r.bits.data)
+    }
+    when(io.axi.w.fire()) {
+      printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n",
+             io.axi.w.bits.last,
+             io.axi.w.bits.data)
+    }
   }
 }
index bf34b6a..8fd0fa8 100644 (file)
@@ -24,18 +24,17 @@ import chisel3.util._
 import vta.util.genericbundle._
 
 case class AXIParams(
-  coherent: Boolean = false,
-  idBits: Int = 1,
-  addrBits: Int = 32,
-  dataBits: Int = 64,
-  lenBits: Int = 8,
-  userBits: Int = 1
-)
-{
-  require (addrBits > 0)
-  require (dataBits >= 8 && dataBits % 2 == 0)
+    coherent: Boolean = false,
+    idBits: Int = 1,
+    addrBits: Int = 32,
+    dataBits: Int = 64,
+    lenBits: Int = 8,
+    userBits: Int = 1
+) {
+  require(addrBits > 0)
+  require(dataBits >= 8 && dataBits % 2 == 0)
 
-  val strbBits = dataBits/8
+  val strbBits = dataBits / 8
   val sizeBits = 3
   val burstBits = 2
   val lockBits = 2
@@ -44,7 +43,7 @@ case class AXIParams(
   val qosBits = 4
   val regionBits = 4
   val respBits = 2
-  val sizeConst = log2Ceil(dataBits/8)
+  val sizeConst = log2Ceil(dataBits / 8)
   val idConst = 0
   val userConst = if (coherent) 1 else 0
   val burstConst = 1
@@ -56,7 +55,7 @@ case class AXIParams(
 }
 
 abstract class AXIBase(params: AXIParams)
-  extends GenericParameterizedBundle(params)
+    extends GenericParameterizedBundle(params)
 
 // AXILite
 
index fd9309e..3c271f5 100644 (file)
@@ -25,52 +25,59 @@ import vta.util.config._
 import vta.interface.axi._
 
 /** PynqConfig. Shell configuration for Pynq */
-class PynqConfig extends Config((site, here, up) => {
-  case ShellKey => ShellParams(
-    hostParams = AXIParams(
-      coherent = false,
-      addrBits = 16,
-      dataBits = 32,
-      lenBits = 8,
-      userBits = 1),
-    memParams = AXIParams(
-      coherent = true,
-      addrBits = 32,
-      dataBits = 64,
-      lenBits = 8,
-      userBits = 1),
-    vcrParams = VCRParams(),
-    vmeParams = VMEParams())
-})
+class PynqConfig
+    extends Config((site, here, up) => {
+      case ShellKey =>
+        ShellParams(
+          hostParams = AXIParams(coherent = false,
+                                 addrBits = 16,
+                                 dataBits = 32,
+                                 lenBits = 8,
+                                 userBits = 1),
+          memParams = AXIParams(coherent = true,
+                                addrBits = 32,
+                                dataBits = 64,
+                                lenBits = 8,
+                                userBits = 1),
+          vcrParams = VCRParams(),
+          vmeParams = VMEParams()
+        )
+    })
 
 /** F1Config. Shell configuration for F1 */
-class F1Config extends Config((site, here, up) => {
-  case ShellKey => ShellParams(
-    hostParams = AXIParams(
-      coherent = false,
-      addrBits = 16,
-      dataBits = 32,
-      lenBits = 8,
-      userBits = 1),
-    memParams = AXIParams(
-      coherent = false,
-      addrBits = 64,
-      dataBits = 64,
-      lenBits = 8,
-      userBits = 1),
-    vcrParams = VCRParams(),
-    vmeParams = VMEParams())
-})
+class F1Config
+    extends Config((site, here, up) => {
+      case ShellKey =>
+        ShellParams(
+          hostParams = AXIParams(coherent = false,
+                                 addrBits = 16,
+                                 dataBits = 32,
+                                 lenBits = 8,
+                                 userBits = 1),
+          memParams = AXIParams(coherent = false,
+                                addrBits = 64,
+                                dataBits = 64,
+                                lenBits = 8,
+                                userBits = 1),
+          vcrParams = VCRParams(),
+          vmeParams = VMEParams()
+        )
+    })
 
 /** De10Config. Shell configuration for De10 */
-class De10Config extends Config((site, here, up) => {
-  case ShellKey => ShellParams(
-    hostParams = AXIParams(
-      addrBits = 16, dataBits = 32, idBits = 13, lenBits = 4),
-    memParams = AXIParams(
-      addrBits = 32, dataBits = 64, userBits = 5,
-      lenBits = 4, // limit to 16 beats, instead of 256 beats in AXI4
-      coherent = true),
-    vcrParams = VCRParams(),
-    vmeParams = VMEParams())
-})
+class De10Config
+    extends Config((site, here, up) => {
+      case ShellKey =>
+        ShellParams(
+          hostParams =
+            AXIParams(addrBits = 16, dataBits = 32, idBits = 13, lenBits = 4),
+          memParams = AXIParams(
+            addrBits = 32,
+            dataBits = 64,
+            userBits = 5,
+            lenBits = 4, // limit to 16 beats, instead of 256 beats in AXI4
+            coherent = true),
+          vcrParams = VCRParams(),
+          vmeParams = VMEParams()
+        )
+    })
index 817b786..6eb2222 100644 (file)
@@ -30,7 +30,7 @@ import vta.core._
   * system that can be used for simulation or real hardware.
   */
 class IntelShell(implicit p: Parameters) extends Module {
-  val io = IO(new Bundle{
+  val io = IO(new Bundle {
     val host = new AXIClient(p(ShellKey).hostParams)
     val mem = new AXIMaster(p(ShellKey).memParams)
   })
index f3d74ef..30b84d6 100644 (file)
@@ -76,6 +76,7 @@ class VTASim(implicit p: Parameters) extends MultiIOModule {
   sim.io.clock := clock
   sim_wait := sim.io.dpi_wait
 }
+
 /** SimShell.
   *
   * The simulation shell instantiate the sim, host and memory DPI modules that
index efff6a4..517f581 100644 (file)
@@ -29,8 +29,7 @@ import vta.interface.axi._
   *
   * These parameters are used on VCR interfaces and modules.
   */
-case class VCRParams()
-{
+case class VCRParams() {
   val nCtrl = 1
   val nECnt = 1
   val nVals = 1
@@ -40,7 +39,7 @@ case class VCRParams()
 
 /** VCRBase. Parametrize base class. */
 abstract class VCRBase(implicit p: Parameters)
-  extends GenericParameterizedBundle(p)
+    extends GenericParameterizedBundle(p)
 
 /** VCRMaster.
   *
@@ -80,7 +79,7 @@ class VCRClient(implicit p: Parameters) extends VCRBase {
   * registers that could be used as event counters by the Core unit.
   */
 class VCR(implicit p: Parameters) extends Module {
-  val io = IO(new Bundle{
+  val io = IO(new Bundle {
     val host = new AXILiteClient(p(ShellKey).hostParams)
     val vcr = new VCRMaster
   })
@@ -101,50 +100,49 @@ class VCR(implicit p: Parameters) extends Module {
   val rdata = RegInit(0.U(vp.regBits.W))
 
   // registers
-  val nPtrs = if (mp.addrBits == 32) vp.nPtrs else 2*vp.nPtrs
+  val nPtrs = if (mp.addrBits == 32) vp.nPtrs else 2 * vp.nPtrs
   val nTotal = vp.nCtrl + vp.nECnt + vp.nVals + nPtrs
 
   val reg = Seq.fill(nTotal)(RegInit(0.U(vp.regBits.W)))
   val addr = Seq.tabulate(nTotal)(_ * 4)
-  val reg_map = (addr zip reg)  map { case (a, r) => a.U -> r }
+  val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
   val eo = vp.nCtrl
   val vo = eo + vp.nECnt
   val po = vo + vp.nVals
 
-  switch (wstate) {
-    is (sWriteAddress) {
-      when (io.host.aw.valid) {
+  switch(wstate) {
+    is(sWriteAddress) {
+      when(io.host.aw.valid) {
         wstate := sWriteData
       }
     }
-    is (sWriteData) {
-      when (io.host.w.valid) {
+    is(sWriteData) {
+      when(io.host.w.valid) {
         wstate := sWriteResponse
       }
     }
-    is (sWriteResponse) {
-      when (io.host.b.ready) {
+    is(sWriteResponse) {
+      when(io.host.b.ready) {
         wstate := sWriteAddress
       }
     }
   }
 
-  when (io.host.aw.fire()) { waddr := io.host.aw.bits.addr }
+  when(io.host.aw.fire()) { waddr := io.host.aw.bits.addr }
 
   io.host.aw.ready := wstate === sWriteAddress
   io.host.w.ready := wstate === sWriteData
   io.host.b.valid := wstate === sWriteResponse
   io.host.b.bits.resp := 0.U
 
-
-  switch (rstate) {
-    is (sReadAddress) {
-      when (io.host.ar.valid) {
+  switch(rstate) {
+    is(sReadAddress) {
+      when(io.host.ar.valid) {
         rstate := sReadData
       }
     }
-    is (sReadData) {
-      when (io.host.r.ready) {
+    is(sReadData) {
+      when(io.host.r.ready) {
         rstate := sReadAddress
       }
     }
@@ -155,27 +153,27 @@ class VCR(implicit p: Parameters) extends Module {
   io.host.r.bits.data := rdata
   io.host.r.bits.resp := 0.U
 
-  when (io.vcr.finish) {
+  when(io.vcr.finish) {
     reg(0) := "b_10".U
-  } .elsewhen (io.host.w.fire() && addr(0).U === waddr) {
+  }.elsewhen(io.host.w.fire() && addr(0).U === waddr) {
     reg(0) := wdata
   }
 
   for (i <- 0 until vp.nECnt) {
-    when (io.vcr.ecnt(i).valid) {
+    when(io.vcr.ecnt(i).valid) {
       reg(eo + i) := io.vcr.ecnt(i).bits
-    } .elsewhen (io.host.w.fire() && addr(eo + i).U === waddr) {
+    }.elsewhen(io.host.w.fire() && addr(eo + i).U === waddr) {
       reg(eo + i) := wdata
     }
   }
 
   for (i <- 0 until (vp.nVals + nPtrs)) {
-    when (io.host.w.fire() && addr(vo + i).U === waddr) {
+    when(io.host.w.fire() && addr(vo + i).U === waddr) {
       reg(vo + i) := wdata
     }
   }
 
-  when (io.host.ar.fire()) {
+  when(io.host.ar.fire()) {
     rdata := MuxLookup(io.host.ar.bits.addr, 0.U, reg_map)
   }
 
@@ -185,13 +183,13 @@ class VCR(implicit p: Parameters) extends Module {
     io.vcr.vals(i) := reg(vo + i)
   }
 
-  if (mp.addrBits == 32) {  // 32-bit pointers
+  if (mp.addrBits == 32) { // 32-bit pointers
     for (i <- 0 until nPtrs) {
       io.vcr.ptrs(i) := reg(po + i)
     }
-  } else {  // 64-bits pointers
-    for (i <- 0 until (nPtrs/2)) {
-      io.vcr.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i))
+  } else { // 64-bits pointers
+    for (i <- 0 until (nPtrs / 2)) {
+      io.vcr.ptrs(i) := Cat(reg(po + 2 * i + 1), reg(po + 2 * i))
     }
   }
 }
index db46295..949929a 100644 (file)
@@ -32,13 +32,16 @@ import vta.interface.axi._
 case class VMEParams() {
   val nReadClients: Int = 5
   val nWriteClients: Int = 1
-  require (nReadClients > 0, s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
-  require (nWriteClients == 1, s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
+  require(nReadClients > 0,
+          s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
+  require(
+    nWriteClients == 1,
+    s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
 }
 
 /** VMEBase. Parametrize base class. */
 abstract class VMEBase(implicit p: Parameters)
-  extends GenericParameterizedBundle(p)
+    extends GenericParameterizedBundle(p)
 
 /** VMECmd.
   *
@@ -149,19 +152,19 @@ class VME(implicit p: Parameters) extends Module {
   val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3)
   val rstate = RegInit(sReadIdle)
 
-  switch (rstate) {
-    is (sReadIdle) {
-      when (rd_arb.io.out.valid) {
+  switch(rstate) {
+    is(sReadIdle) {
+      when(rd_arb.io.out.valid) {
         rstate := sReadAddr
       }
     }
-    is (sReadAddr) {
-      when (io.mem.ar.ready) {
+    is(sReadAddr) {
+      when(io.mem.ar.ready) {
         rstate := sReadData
       }
     }
-    is (sReadData) {
-      when (io.mem.r.fire() && io.mem.r.bits.last) {
+    is(sReadData) {
+      when(io.mem.r.fire() && io.mem.r.bits.last) {
         rstate := sReadIdle
       }
     }
@@ -173,30 +176,34 @@ class VME(implicit p: Parameters) extends Module {
   val lenBits = p(ShellKey).memParams.lenBits
   val wr_cnt = RegInit(0.U(lenBits.W))
 
-  when (wstate === sWriteIdle) {
+  when(wstate === sWriteIdle) {
     wr_cnt := 0.U
-  } .elsewhen (io.mem.w.fire()) {
+  }.elsewhen(io.mem.w.fire()) {
     wr_cnt := wr_cnt + 1.U
   }
 
-  switch (wstate) {
-    is (sWriteIdle) {
-      when (io.vme.wr(0).cmd.valid) {
+  switch(wstate) {
+    is(sWriteIdle) {
+      when(io.vme.wr(0).cmd.valid) {
         wstate := sWriteAddr
       }
     }
-    is (sWriteAddr) {
-      when (io.mem.aw.ready) {
+    is(sWriteAddr) {
+      when(io.mem.aw.ready) {
         wstate := sWriteData
       }
     }
-    is (sWriteData) {
-      when (io.vme.wr(0).data.valid && io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) {
+    is(sWriteData) {
+      when(
+        io.vme
+          .wr(0)
+          .data
+          .valid && io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) {
         wstate := sWriteResp
       }
     }
-    is (sWriteResp) {
-      when (io.mem.b.valid) {
+    is(sWriteResp) {
+      when(io.mem.b.valid) {
         wstate := sWriteIdle
       }
     }
@@ -209,12 +216,12 @@ class VME(implicit p: Parameters) extends Module {
   val rd_addr = RegInit(0.U(addrBits.W))
   val wr_addr = RegInit(0.U(addrBits.W))
 
-  when (rd_arb.io.out.fire()) {
+  when(rd_arb.io.out.fire()) {
     rd_len := rd_arb.io.out.bits.len
     rd_addr := rd_arb.io.out.bits.addr
   }
 
-  when (io.vme.wr(0).cmd.fire()) {
+  when(io.vme.wr(0).cmd.fire()) {
     wr_len := io.vme.wr(0).cmd.bits.len
     wr_addr := io.vme.wr(0).cmd.bits.addr
   }
@@ -230,7 +237,7 @@ class VME(implicit p: Parameters) extends Module {
 
   io.vme.wr(0).cmd.ready := wstate === sWriteIdle
   io.vme.wr(0).ack := io.mem.b.fire()
-  io.vme.wr(0).data.ready := wstate === sWriteData &  io.mem.w.ready
+  io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready
 
   // mem
   io.mem.aw.valid := wstate === sWriteAddr
index c809311..782aeae 100644 (file)
@@ -26,10 +26,10 @@ import vta.core._
 
 /** Shell parameters. */
 case class ShellParams(
-  hostParams: AXIParams,
-  memParams: AXIParams,
-  vcrParams: VCRParams,
-  vmeParams: VMEParams
+    hostParams: AXIParams,
+    memParams: AXIParams,
+    vcrParams: VCRParams,
+    vmeParams: VMEParams
 )
 
 case object ShellKey extends Field[ShellParams]
@@ -40,7 +40,7 @@ case object ShellKey extends Field[ShellParams]
   * system that can be used for simulation or real hardware.
   */
 class VTAShell(implicit p: Parameters) extends Module {
-  val io = IO(new Bundle{
+  val io = IO(new Bundle {
     val host = new AXILiteClient(p(ShellKey).hostParams)
     val mem = new AXIMaster(p(ShellKey).memParams)
   })
index db72137..ec7bffb 100644 (file)
@@ -20,7 +20,7 @@
 package vta.shell
 
 import chisel3._
-import chisel3.experimental.{RawModule, withClockAndReset}
+import chisel3.experimental.{withClockAndReset, RawModule}
 import vta.util.config._
 import vta.interface.axi._
 
@@ -39,7 +39,9 @@ class XilinxShell(implicit p: Parameters) extends RawModule {
   val m_axi_gmem = IO(new XilinxAXIMaster(mp))
   val s_axi_control = IO(new XilinxAXILiteClient(hp))
 
-  val shell = withClockAndReset (clock = ap_clk, reset = ~ap_rst_n) { Module(new VTAShell) }
+  val shell = withClockAndReset(clock = ap_clk, reset = ~ap_rst_n) {
+    Module(new VTAShell)
+  }
 
   // memory
   m_axi_gmem.AWVALID := shell.io.mem.aw.valid
index 6699507..41104c4 100644 (file)
@@ -21,8 +21,7 @@ package vta.util.config
 
 // taken from https://github.com/vta.roject/rocket-chip
 
-abstract class Field[T] private (val default: Option[T])
-{
+abstract class Field[T] private (val default: Option[T]) {
   def this() = this(None)
   def this(default: T) = this(Some(default))
 }
@@ -31,42 +30,50 @@ abstract class View {
   final def apply[T](pname: Field[T]): T = apply(pname, this)
   final def apply[T](pname: Field[T], site: View): T = {
     val out = find(pname, site)
-    require (out.isDefined, s"Key ${pname} is not defined in Parameters")
+    require(out.isDefined, s"Key ${pname} is not defined in Parameters")
     out.get
   }
 
   final def lift[T](pname: Field[T]): Option[T] = lift(pname, this)
-  final def lift[T](pname: Field[T], site: View): Option[T] = find(pname, site).map(_.asInstanceOf[T])
+  final def lift[T](pname: Field[T], site: View): Option[T] =
+    find(pname, site).map(_.asInstanceOf[T])
 
   protected[config] def find[T](pname: Field[T], site: View): Option[T]
 }
 
 abstract class Parameters extends View {
-  final def ++ (x: Parameters): Parameters =
+  final def ++(x: Parameters): Parameters =
     new ChainParameters(this, x)
 
-  final def alter(f: (View, View, View) => PartialFunction[Any,Any]): Parameters =
+  final def alter(
+      f: (View, View, View) => PartialFunction[Any, Any]): Parameters =
     Parameters(f) ++ this
 
-  final def alterPartial(f: PartialFunction[Any,Any]): Parameters =
-    Parameters((_,_,_) => f) ++ this
+  final def alterPartial(f: PartialFunction[Any, Any]): Parameters =
+    Parameters((_, _, _) => f) ++ this
 
-  final def alterMap(m: Map[Any,Any]): Parameters =
+  final def alterMap(m: Map[Any, Any]): Parameters =
     new MapParameters(m) ++ this
 
-  protected[config] def chain[T](site: View, tail: View, pname: Field[T]): Option[T]
-  protected[config] def find[T](pname: Field[T], site: View) = chain(site, new TerminalView, pname)
+  protected[config] def chain[T](site: View,
+                                 tail: View,
+                                 pname: Field[T]): Option[T]
+  protected[config] def find[T](pname: Field[T], site: View) =
+    chain(site, new TerminalView, pname)
 }
 
 object Parameters {
   def empty: Parameters = new EmptyParameters
-  def apply(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = new PartialParameters(f)
+  def apply(f: (View, View, View) => PartialFunction[Any, Any]): Parameters =
+    new PartialParameters(f)
 }
 
 class Config(p: Parameters) extends Parameters {
-  def this(f: (View, View, View) => PartialFunction[Any,Any]) = this(Parameters(f))
+  def this(f: (View, View, View) => PartialFunction[Any, Any]) =
+    this(Parameters(f))
 
-  protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = p.chain(site, tail, pname)
+  protected[config] def chain[T](site: View, tail: View, pname: Field[T]) =
+    p.chain(site, tail, pname)
   override def toString = this.getClass.getSimpleName
   def toInstance = this
 }
@@ -82,17 +89,21 @@ private class ChainView(head: Parameters, tail: View) extends View {
 }
 
 private class ChainParameters(x: Parameters, y: Parameters) extends Parameters {
-  def chain[T](site: View, tail: View, pname: Field[T]) = x.chain(site, new ChainView(y, tail), pname)
+  def chain[T](site: View, tail: View, pname: Field[T]) =
+    x.chain(site, new ChainView(y, tail), pname)
 }
 
 private class EmptyParameters extends Parameters {
   def chain[T](site: View, tail: View, pname: Field[T]) = tail.find(pname, site)
 }
 
-private class PartialParameters(f: (View, View, View) => PartialFunction[Any,Any]) extends Parameters {
+private class PartialParameters(
+    f: (View, View, View) => PartialFunction[Any, Any])
+    extends Parameters {
   protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = {
     val g = f(site, this, tail)
-    if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T]) else tail.find(pname, site)
+    if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T])
+    else tail.find(pname, site)
   }
 }
 
index db19635..db8f5d2 100644 (file)
@@ -23,18 +23,22 @@ package vta.util.genericbundle
 
 import chisel3._
 
-abstract class GenericParameterizedBundle[+T <: Object](val params: T) extends Bundle
-{
+abstract class GenericParameterizedBundle[+T <: Object](val params: T)
+    extends Bundle {
   override def cloneType = {
     try {
-      this.getClass.getConstructors.head.newInstance(params).asInstanceOf[this.type]
+      this.getClass.getConstructors.head
+        .newInstance(params)
+        .asInstanceOf[this.type]
     } catch {
       case e: java.lang.IllegalArgumentException =>
-        throw new Exception("Unable to use GenericParameterizedBundle.cloneType on " +
-                       this.getClass + ", probably because " + this.getClass +
-                       "() takes more than one argument.  Consider overriding " +
-                       "cloneType() on " + this.getClass, e)
+        throw new Exception(
+          "Unable to use GenericParameterizedBundle.cloneType on " +
+            this.getClass + ", probably because " + this.getClass +
+            "() takes more than one argument.  Consider overriding " +
+            "cloneType() on " + this.getClass,
+          e
+        )
     }
   }
 }
-
index 78c7316..f137ab6 100644 (file)
@@ -31,7 +31,6 @@ import vta.test._
   * These configurations are built in a mix/match form based on core
   * and shell configurations.
   */
-
 class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig)
 class DefaultF1Config extends Config(new CoreConfig ++ new F1Config)
 class DefaultDe10Config extends Config(new CoreConfig ++ new De10Config)