[VTA] Enable streamlined GEMM execution (#4392)
authorLiangfu Chen <liangfu.chen@harman.com>
Wed, 27 Nov 2019 17:04:19 +0000 (01:04 +0800)
committerThierry Moreau <moreau@uw.edu>
Wed, 27 Nov 2019 17:04:18 +0000 (09:04 -0800)
* disable pipelined adder and enable streamlined gemm execution

* pipeline first layer of adder

* explain difference between pipeadder and adder

* add comment for explaining the hard-coded latency

vta/hardware/chisel/src/main/scala/core/TensorGemm.scala

index 3f5f387..7328c42 100644 (file)
@@ -46,7 +46,10 @@ class MAC(aBits: Int = 8, bBits: Int = 8, cBits: Int = 16) extends Module {
   io.y := add
 }
 
-/** Pipelined adder */
+/** PipeAdder
+  *
+  * This unit loads input bits into register and performs addition in the next cycle
+  */
 class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
   val outBits = Math.max(aBits, bBits) + 1
   val io = IO(new Bundle {
@@ -61,6 +64,27 @@ class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
   io.y := add
 }
 
+/** Adder
+  *
+  * This unit wires input bits to an adder directly.
+  * The output comes out of combinational logic without waiting for another cycle.
+  */
+class Adder(aBits: Int = 8, bBits: Int = 8) extends Module {
+  val outBits = Math.max(aBits, bBits) + 1
+  val io = IO(new Bundle {
+    val a = Input(SInt(aBits.W))
+    val b = Input(SInt(bBits.W))
+    val y = Output(SInt(outBits.W))
+  })
+  val add = Wire(SInt(outBits.W))
+  val rA = Wire(SInt(aBits.W))
+  val rB = Wire(SInt(bBits.W))
+  rA := io.a
+  rB := io.b
+  add := rA +& rB
+  io.y := add
+}
+
 /** Pipelined DotProduct based on MAC and PipeAdder */
 class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16)
     extends Module {
@@ -80,9 +104,11 @@ class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16)
   val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs
   val a = Seq.tabulate(p)(
     i =>
-      Seq.fill(s(i + 1))(Module(new PipeAdder(
-        aBits = (b + i + 1),
-        bBits = (b + i + 1))))) // # adders within each layer
+      Seq.fill(s(i + 1))(
+        if (i == 0)
+          Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1)))
+        else
+          Module(new Adder(aBits = (b + i + 1), bBits = (b + i + 1))))) // # adders within each layer
 
   // Vector MACs
   for (i <- 0 until s(0)) {
@@ -126,8 +152,9 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
   })
   val dot = Seq.fill(size)(
     Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
-  val acc = Seq.fill(size)(
-    Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
+  // Latency is defined as two in the following, because there is one cycle in the MAC module,
+  // and another cycle in the pipelined adders as the first layer of the accumulator
+  val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = 2)))
   val add = Seq.fill(size)(Wire(SInt(accBits.W)))
   val vld = Wire(Vec(size, Bool()))
 
@@ -188,7 +215,9 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
   val wgt_i = Reg(chiselTypeOf(dec.uop_end))
   val pBits = log2Ceil(p(CoreKey).blockOut) + 1
   val inflight = Reg(UInt(pBits.W))
-  val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits))
+  // Latency is defined as two in the following, because there is one cycle in the MAC module,
+  // and another cycle in the pipelined adders as the first layer of the accumulator
+  val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = 2))
   val done = inflight === 0.U &
     ((state === sExe &
       cnt_o === dec.lp_0 - 1.U &
@@ -236,11 +265,14 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
   when(state === sIdle) {
     inflight := 0.U
   }.elsewhen(!dec.reset) {
-    when(state === sReadTensor) { // issue a tensor
-      inflight := inflight + 1.U
-    }.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor
-      inflight := inflight - 1.U
-    }
+    when((state === sReadTensor) && mvc.io.acc_o.data.valid) { // issue & commit
+      inflight := inflight
+    }.elsewhen(state === sReadTensor) { // issue a tensor
+        inflight := inflight + 1.U
+      }
+      .elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor
+        inflight := inflight - 1.U
+      }
   }
 
   when(
@@ -278,8 +310,7 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
       inp_i := inp_o
       wgt_i := wgt_o
     }
-    .elsewhen(state === sExe &&
-      uop_idx === uop_end - 1.U) {
+    .elsewhen(state === sExe && uop_idx === uop_end - 1.U) {
       cnt_i := cnt_i + 1.U
       acc_i := acc_i + dec.acc_1
       inp_i := inp_i + dec.inp_1