From 3a1c8c5d29ac4b98183620a1c0ed646a1b17fa37 Mon Sep 17 00:00:00 2001 From: Liangfu Chen Date: Thu, 28 Nov 2019 01:04:19 +0800 Subject: [PATCH] [VTA] Enable streamlined GEMM execution (#4392) * disable pipelined adder and enable streamlined gemm execution * pipeline first layer of adder * explain difference between pipeadder and adder * add comment for explaining the hard-coded latency --- .../chisel/src/main/scala/core/TensorGemm.scala | 59 +++++++++++++++++----- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala index 3f5f387..7328c42 100644 --- a/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala +++ b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala @@ -46,7 +46,10 @@ class MAC(aBits: Int = 8, bBits: Int = 8, cBits: Int = 16) extends Module { io.y := add } -/** Pipelined adder */ +/** PipeAdder + * + * This unit loads input bits into register and performs addition in the next cycle + */ class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module { val outBits = Math.max(aBits, bBits) + 1 val io = IO(new Bundle { @@ -61,6 +64,27 @@ class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module { io.y := add } +/** Adder + * + * This unit wires input bits to an adder directly. + * The output comes out of combinational logic without waiting for another cycle. + */ +class Adder(aBits: Int = 8, bBits: Int = 8) extends Module { + val outBits = Math.max(aBits, bBits) + 1 + val io = IO(new Bundle { + val a = Input(SInt(aBits.W)) + val b = Input(SInt(bBits.W)) + val y = Output(SInt(outBits.W)) + }) + val add = Wire(SInt(outBits.W)) + val rA = Wire(SInt(aBits.W)) + val rB = Wire(SInt(bBits.W)) + rA := io.a + rB := io.b + add := rA +& rB + io.y := add +} + /** Pipelined DotProduct based on MAC and PipeAdder */ class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module { @@ -80,9 +104,11 @@ class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs val a = Seq.tabulate(p)( i => - Seq.fill(s(i + 1))(Module(new PipeAdder( - aBits = (b + i + 1), - bBits = (b + i + 1))))) // # adders within each layer + Seq.fill(s(i + 1))( + if (i == 0) + Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1))) + else + Module(new Adder(aBits = (b + i + 1), bBits = (b + i + 1))))) // # adders within each layer // Vector MACs for (i <- 0 until s(0)) { @@ -126,8 +152,9 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module { }) val dot = Seq.fill(size)( Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size))) - val acc = Seq.fill(size)( - Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1))) + // Latency is defined as two in the following, because there is one cycle in the MAC module, + // and another cycle in the pipelined adders as the first layer of the accumulator + val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = 2))) val add = Seq.fill(size)(Wire(SInt(accBits.W))) val vld = Wire(Vec(size, Bool())) @@ -188,7 +215,9 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) val wgt_i = Reg(chiselTypeOf(dec.uop_end)) val pBits = log2Ceil(p(CoreKey).blockOut) + 1 val inflight = Reg(UInt(pBits.W)) - val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits)) + // Latency is defined as two in the following, because there is one cycle in the MAC module, + // and another cycle in the pipelined adders as the first layer of the accumulator + val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = 2)) val done = inflight === 0.U & ((state === sExe & cnt_o === dec.lp_0 - 1.U & @@ -236,11 +265,14 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) when(state === sIdle) { inflight := 0.U }.elsewhen(!dec.reset) { - when(state === sReadTensor) { // issue a tensor - inflight := inflight + 1.U - }.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor - inflight := inflight - 1.U - } + when((state === sReadTensor) && mvc.io.acc_o.data.valid) { // issue & commit + inflight := inflight + }.elsewhen(state === sReadTensor) { // issue a tensor + inflight := inflight + 1.U + } + .elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor + inflight := inflight - 1.U + } } when( @@ -278,8 +310,7 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) inp_i := inp_o wgt_i := wgt_o } - .elsewhen(state === sExe && - uop_idx === uop_end - 1.U) { + .elsewhen(state === sExe && uop_idx === uop_end - 1.U) { cnt_i := cnt_i + 1.U acc_i := acc_i + dec.acc_1 inp_i := inp_i + dec.inp_1 -- 2.7.4