From 87e18a4411d11ff08a18f40d1c44ccf6a32b7db1 Mon Sep 17 00:00:00 2001 From: Benjamin Tu Date: Thu, 25 Jul 2019 18:47:04 -0700 Subject: [PATCH] [VTA] [Chisel] support for different inp/wgt bits, rewrote DotProduct for clarity (#3605) * support for different inp/wgt bits, rewrote dot for clarity * [VTA] [Chisel] support for different inp/wgt bits, rewrote DotProduct for clarity * [VTA] [Chisel] support for different inp/wgt bits, rewrote DotProduct for clarity * change back to sim * fix index * fix index * fix indent * fix indent * fix indent * fix trailing spaces * fix trailing spaces * change to more descriptive name * matric->matrix * fix spacing * fix spacing & added generic name for dot * better parameter flow * spacing * spacing * spacing * update requirement (tested) for dot, spacing * function call convention * small edit --- .../chisel/src/main/scala/core/TensorGemm.scala | 99 +++++++++++----------- 1 file changed, 51 insertions(+), 48 deletions(-) diff --git a/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala index a910864..bfa79dd 100644 --- a/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala +++ b/vta/hardware/chisel/src/main/scala/core/TensorGemm.scala @@ -26,88 +26,91 @@ import vta.util.config._ import scala.math.pow /** Pipelined multiply and accumulate */ -class MAC(dataBits: Int = 8, cBits: Int = 16, outBits: Int = 17) extends Module { - require (cBits >= dataBits * 2) - require (outBits >= dataBits * 2) +class MAC(aBits: Int = 8, bBits: Int = 8, cBits: Int = 16) extends Module { + val outBits = Math.max(aBits + bBits, cBits) + 1 val io = IO(new Bundle { - val a = Input(SInt(dataBits.W)) - val b = Input(SInt(dataBits.W)) + val a = Input(SInt(aBits.W)) + val b = Input(SInt(bBits.W)) val c = Input(SInt(cBits.W)) val y = Output(SInt(outBits.W)) }) - val mult = Wire(SInt(cBits.W)) + val mult = Wire(SInt((aBits + bBits).W)) val add = Wire(SInt(outBits.W)) val rA = RegNext(io.a) val rB = RegNext(io.b) val rC = RegNext(io.c) + mult := rA * rB - add := rC + mult + add := rC +& mult + io.y := add } /** Pipelined adder */ -class Adder(dataBits: Int = 8, outBits: Int = 17) extends Module { - require (outBits >= dataBits) +class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module { + val outBits = Math.max(aBits, bBits) + 1 val io = IO(new Bundle { - val a = Input(SInt(dataBits.W)) - val b = Input(SInt(dataBits.W)) + val a = Input(SInt(aBits.W)) + val b = Input(SInt(bBits.W)) val y = Output(SInt(outBits.W)) }) val add = Wire(SInt(outBits.W)) val rA = RegNext(io.a) val rB = RegNext(io.b) - add := rA + rB + add := rA +& rB io.y := add } -/** Pipelined DotProduct based on MAC and Adder */ -class DotProduct(dataBits: Int = 8, size: Int = 16) extends Module { - val errMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n" - require(size >= 4 && isPow2(size), errMsg) - val b = dataBits * 2 +/** Pipelined DotProduct based on MAC and PipeAdder */ +class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module { + val errorMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n" + require(size >= 2 && isPow2(size), errorMsg) + val b = aBits + bBits val outBits = b + log2Ceil(size) + 1 val io = IO(new Bundle { - val a = Input(Vec(size, SInt(dataBits.W))) - val b = Input(Vec(size, SInt(dataBits.W))) + val a = Input(Vec(size, SInt(aBits.W))) + val b = Input(Vec(size, SInt(bBits.W))) val y = Output(SInt(outBits.W)) }) - val p = log2Ceil(size/2) - val s = Seq.tabulate(log2Ceil(size))(i => pow(2, p - i).toInt) - val da = Seq.tabulate(s(0))(i => RegNext(io.a(s(0) + i))) - val db = Seq.tabulate(s(0))(i => RegNext(io.b(s(0) + i))) - val m = Seq.tabulate(2)(i => - Seq.fill(s(0))(Module(new MAC(dataBits = dataBits, cBits = b + i, outBits = b + i + 1))) - ) + val s = Seq.tabulate(log2Ceil(size + 1))(i => pow(2, log2Ceil(size) - i).toInt) // # of total layers + val p = log2Ceil(size / 2) + 1 // # of adder layers + val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs val a = Seq.tabulate(p)(i => - Seq.fill(s(i + 1))(Module(new Adder(dataBits = b + i + 2, outBits = b + i + 3))) - ) + Seq.fill(s(i + 1))(Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1)))) + ) // # adders within each layer - for (i <- 0 until log2Ceil(size)) { - for (j <- 0 until s(i)) { + // Vector MACs + for (i <- 0 until s(0)) { + m(i).io.a := io.a(i) + m(i).io.b := io.b(i) + m(i).io.c := 0.S + } + + // PipeAdder Reduction + for (i <- 0 until p) { + for (j <- 0 until s(i + 1)) { if (i == 0) { - m(i)(j).io.a := io.a(j) - m(i)(j).io.b := io.b(j) - m(i)(j).io.c := 0.S - m(i + 1)(j).io.a := da(j) - m(i + 1)(j).io.b := db(j) - m(i + 1)(j).io.c := m(i)(j).io.y - } else if (i == 1) { - a(i - 1)(j).io.a := m(i)(2*j).io.y - a(i - 1)(j).io.b := m(i)(2*j + 1).io.y + // First layer of PipeAdders + a(i)(j).io.a := m(2 * j).io.y + a(i)(j).io.b := m(2 * j + 1).io.y } else { - a(i - 1)(j).io.a := a(i - 2)(2*j).io.y - a(i - 1)(j).io.b := a(i - 2)(2*j + 1).io.y + a(i)(j).io.a := a(i - 1)(2 * j).io.y + a(i)(j).io.b := a(i - 1)(2 * j + 1).io.y } } } - io.y := a(p-1)(0).io.y + + // last adder + io.y := a(p - 1)(0).io.y } -/** Perform matric-vector-multiplication based on DotProduct */ -class MatrixVectorCore(implicit p: Parameters) extends Module { +/** Perform matrix-vector-multiplication based on DotProduct */ +class MatrixVectorMultiplication(implicit p: Parameters) extends Module { val accBits = p(CoreKey).accBits val size = p(CoreKey).blockOut - val dataBits = p(CoreKey).inpBits + val inpBits = p(CoreKey).inpBits + val wgtBits = p(CoreKey).wgtBits + val outBits = p(CoreKey).outBits val io = IO(new Bundle{ val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr val inp = new TensorMasterData(tensorType = "inp") @@ -116,7 +119,7 @@ class MatrixVectorCore(implicit p: Parameters) extends Module { val acc_o = new TensorClientData(tensorType = "acc") val out = new TensorClientData(tensorType = "out") }) - val dot = Seq.fill(size)(Module(new DotProduct(dataBits, size))) + val dot = Seq.fill(size)(Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size))) val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1))) val add = Seq.fill(size)(Wire(SInt(accBits.W))) val vld = Wire(Vec(size, Bool())) @@ -139,7 +142,7 @@ class MatrixVectorCore(implicit p: Parameters) extends Module { /** TensorGemm. * - * This unit instantiate the MatrixVectorCore and go over the + * This unit instantiate the MatrixVectorMultiplication and go over the * micro-ops (uops) which are used to read inputs, weights and biases, * and writes results back to the acc and out scratchpads. * @@ -159,7 +162,7 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module }) val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6) val state = RegInit(sIdle) - val mvc = Module(new MatrixVectorCore) + val mvc = Module(new MatrixVectorMultiplication) val dec = io.inst.asTypeOf(new GemmDecode) val uop_idx = Reg(chiselTypeOf(dec.uop_end)) val uop_end = dec.uop_end -- 2.7.4