[VTA] [Chisel] support for different inp/wgt bits, rewrote DotProduct for clarity...
authorBenjamin Tu <tu.benjamin1115@gmail.com>
Fri, 26 Jul 2019 01:47:04 +0000 (18:47 -0700)
committerThierry Moreau <moreau@uw.edu>
Fri, 26 Jul 2019 01:47:04 +0000 (18:47 -0700)
* support for different inp/wgt bits, rewrote dot for clarity

* [VTA] [Chisel] support for different inp/wgt bits, rewrote DotProduct for clarity

* [VTA] [Chisel] support for different inp/wgt bits, rewrote DotProduct for clarity

* change back to sim

* fix index

* fix index

* fix indent

* fix indent

* fix indent

* fix trailing spaces

* fix trailing spaces

* change to more descriptive name

* matric->matrix

* fix spacing

* fix spacing & added generic name for dot

* better parameter flow

* spacing

* spacing

* spacing

* update requirement (tested) for dot, spacing

* function call convention

* small edit

vta/hardware/chisel/src/main/scala/core/TensorGemm.scala

index a910864..bfa79dd 100644 (file)
@@ -26,88 +26,91 @@ import vta.util.config._
 import scala.math.pow
 
 /** Pipelined multiply and accumulate */
-class MAC(dataBits: Int = 8, cBits: Int = 16, outBits: Int = 17) extends Module {
-  require (cBits >= dataBits * 2)
-  require (outBits >= dataBits * 2)
+class MAC(aBits: Int = 8, bBits: Int = 8, cBits: Int = 16) extends Module {
+  val outBits = Math.max(aBits + bBits, cBits) + 1
   val io = IO(new Bundle {
-    val a = Input(SInt(dataBits.W))
-    val b = Input(SInt(dataBits.W))
+    val a = Input(SInt(aBits.W))
+    val b = Input(SInt(bBits.W))
     val c = Input(SInt(cBits.W))
     val y = Output(SInt(outBits.W))
   })
-  val mult = Wire(SInt(cBits.W))
+  val mult = Wire(SInt((aBits + bBits).W))
   val add = Wire(SInt(outBits.W))
   val rA = RegNext(io.a)
   val rB = RegNext(io.b)
   val rC = RegNext(io.c)
+  
   mult := rA * rB
-  add := rC + mult
+  add := rC +& mult
+  
   io.y := add
 }
 
 /** Pipelined adder */
-class Adder(dataBits: Int = 8, outBits: Int = 17) extends Module {
-  require (outBits >= dataBits)
+class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
+  val outBits = Math.max(aBits, bBits) + 1
   val io = IO(new Bundle {
-    val a = Input(SInt(dataBits.W))
-    val b = Input(SInt(dataBits.W))
+    val a = Input(SInt(aBits.W))
+    val b = Input(SInt(bBits.W))
     val y = Output(SInt(outBits.W))
   })
   val add = Wire(SInt(outBits.W))
   val rA = RegNext(io.a)
   val rB = RegNext(io.b)
-  add := rA + rB
+  add := rA +& rB
   io.y := add
 }
 
-/** Pipelined DotProduct based on MAC and Adder */
-class DotProduct(dataBits: Int = 8, size: Int = 16) extends Module {
-  val errMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
-  require(size >= 4 && isPow2(size), errMsg)
-  val b = dataBits * 2
+/** Pipelined DotProduct based on MAC and PipeAdder */
+class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module {
+  val errorMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
+  require(size >= 2 && isPow2(size), errorMsg)
+  val b = aBits + bBits
   val outBits = b + log2Ceil(size) + 1
   val io = IO(new Bundle {
-    val a = Input(Vec(size, SInt(dataBits.W)))
-    val b = Input(Vec(size, SInt(dataBits.W)))
+    val a = Input(Vec(size, SInt(aBits.W)))
+    val b = Input(Vec(size, SInt(bBits.W)))
     val y = Output(SInt(outBits.W))
   })
-  val p = log2Ceil(size/2)
-  val s = Seq.tabulate(log2Ceil(size))(i => pow(2, p - i).toInt)
-  val da = Seq.tabulate(s(0))(i => RegNext(io.a(s(0) + i)))
-  val db = Seq.tabulate(s(0))(i => RegNext(io.b(s(0) + i)))
-  val m = Seq.tabulate(2)(i =>
-    Seq.fill(s(0))(Module(new MAC(dataBits = dataBits, cBits = b + i, outBits = b + i + 1)))
-  )
+  val s = Seq.tabulate(log2Ceil(size + 1))(i => pow(2, log2Ceil(size) - i).toInt) // # of total layers
+  val p = log2Ceil(size / 2) + 1 // # of adder layers
+  val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs
   val a = Seq.tabulate(p)(i =>
-    Seq.fill(s(i + 1))(Module(new Adder(dataBits = b + i + 2, outBits = b + i + 3)))
-  )
+    Seq.fill(s(i + 1))(Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1))))
+  ) // # adders within each layer
 
-  for (i <- 0 until log2Ceil(size)) {
-    for (j <- 0 until s(i)) {
+  // Vector MACs
+  for (i <- 0 until s(0)) {
+    m(i).io.a := io.a(i)
+    m(i).io.b := io.b(i)
+    m(i).io.c := 0.S
+  }
+
+  // PipeAdder Reduction
+  for (i <- 0 until p) {
+    for (j <- 0 until s(i + 1)) {
       if (i == 0) {
-        m(i)(j).io.a := io.a(j)
-        m(i)(j).io.b := io.b(j)
-        m(i)(j).io.c := 0.S
-        m(i + 1)(j).io.a := da(j)
-        m(i + 1)(j).io.b := db(j)
-        m(i + 1)(j).io.c := m(i)(j).io.y
-      } else if (i == 1) {
-        a(i - 1)(j).io.a := m(i)(2*j).io.y
-        a(i - 1)(j).io.b := m(i)(2*j + 1).io.y
+        // First layer of PipeAdders
+        a(i)(j).io.a := m(2 * j).io.y
+        a(i)(j).io.b := m(2 * j + 1).io.y
       } else {
-        a(i - 1)(j).io.a := a(i - 2)(2*j).io.y
-        a(i - 1)(j).io.b := a(i - 2)(2*j + 1).io.y
+        a(i)(j).io.a := a(i - 1)(2 * j).io.y
+        a(i)(j).io.b := a(i - 1)(2 * j + 1).io.y
       }
     }
   }
-  io.y := a(p-1)(0).io.y
+
+  // last adder
+  io.y := a(p - 1)(0).io.y
 }
 
-/** Perform matric-vector-multiplication based on DotProduct */
-class MatrixVectorCore(implicit p: Parameters) extends Module {
+/** Perform matrix-vector-multiplication based on DotProduct */
+class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
   val accBits = p(CoreKey).accBits
   val size = p(CoreKey).blockOut
-  val dataBits = p(CoreKey).inpBits
+  val inpBits = p(CoreKey).inpBits
+  val wgtBits = p(CoreKey).wgtBits
+  val outBits = p(CoreKey).outBits
   val io = IO(new Bundle{
     val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
     val inp = new TensorMasterData(tensorType = "inp")
@@ -116,7 +119,7 @@ class MatrixVectorCore(implicit p: Parameters) extends Module {
     val acc_o = new TensorClientData(tensorType = "acc")
     val out = new TensorClientData(tensorType = "out")
   })
-  val dot = Seq.fill(size)(Module(new DotProduct(dataBits, size)))
+  val dot = Seq.fill(size)(Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
   val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
   val add = Seq.fill(size)(Wire(SInt(accBits.W)))
   val vld = Wire(Vec(size, Bool()))
@@ -139,7 +142,7 @@ class MatrixVectorCore(implicit p: Parameters) extends Module {
 
 /** TensorGemm.
   *
-  * This unit instantiate the MatrixVectorCore and go over the
+  * This unit instantiate the MatrixVectorMultiplication and go over the
   * micro-ops (uops) which are used to read inputs, weights and biases,
   * and writes results back to the acc and out scratchpads.
   *
@@ -159,7 +162,7 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
   })
   val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6)
   val state = RegInit(sIdle)
-  val mvc = Module(new MatrixVectorCore)
+  val mvc = Module(new MatrixVectorMultiplication)
   val dec = io.inst.asTypeOf(new GemmDecode)
   val uop_idx = Reg(chiselTypeOf(dec.uop_end))
   val uop_end = dec.uop_end