[VTA][Chisel] add ISA BitPat generation (#3891)
authorLuis Vega <vegaluisjose@users.noreply.github.com>
Wed, 4 Sep 2019 17:36:21 +0000 (10:36 -0700)
committerThierry Moreau <moreau@uw.edu>
Wed, 4 Sep 2019 17:36:21 +0000 (10:36 -0700)
vta/hardware/chisel/src/main/scala/core/ISA.scala

index f08b23b..edc1823 100644 (file)
@@ -21,6 +21,7 @@ package vta.core
 
 import chisel3._
 import chisel3.util._
+import scala.collection.mutable.HashMap
 
 /** ISAConstants.
   *
@@ -70,45 +71,78 @@ trait ISAConstants {
 
 /** ISA.
   *
-  * This is the VTA ISA, here we specify the cares and dont-cares that makes
-  * decoding easier. Since instructions are quite long 128-bit, we could generate
-  * these based on ISAConstants.
+  * This is the VTA task ISA
   *
-  * FIXME: VSHX should be replaced by VSHR and VSHL once we modify the compiler
   * TODO: Add VXOR to clear accumulator
+  * TODO: Use ISA object for decoding as well
+  * TODO: Eventually deprecate ISAConstants
   */
 object ISA {
-  def LUOP =
-    BitPat(
-      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000")
-  def LWGT =
-    BitPat(
-      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000")
-  def LINP =
-    BitPat(
-      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000")
-  def LACC =
-    BitPat(
-      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000")
-  def SOUT =
-    BitPat(
-      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001")
-  def GEMM =
-    BitPat(
-      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010")
-  def VMIN =
-    BitPat(
-      "b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
-  def VMAX =
-    BitPat(
-      "b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
-  def VADD =
-    BitPat(
-      "b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
-  def VSHX =
-    BitPat(
-      "b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
-  def FNSH =
-    BitPat(
-      "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011")
+  private val xLen = 128
+  private val depBits = 4
+
+  private val idBits: HashMap[String, Int] =
+    HashMap(("task", 3), ("mem", 2), ("alu", 2))
+
+  private val taskId: HashMap[String, String] =
+    HashMap(("load", "000"),
+            ("store", "001"),
+            ("gemm", "010"),
+            ("finish", "011"),
+            ("alu", "100"))
+
+  private val memId: HashMap[String, String] =
+    HashMap(("uop", "00"), ("wgt", "01"), ("inp", "10"), ("acc", "11"))
+
+  private val aluId: HashMap[String, String] =
+    HashMap(("minpool", "00"),
+            ("maxpool", "01"),
+            ("add", "10"),
+            ("shift", "11"))
+
+  private def dontCare(bits: Int): String = "?" * bits
+
+  private def instPat(bin: String): BitPat = BitPat("b" + bin)
+
+  private def load(id: String): BitPat = {
+    val rem = xLen - idBits("mem") - depBits - idBits("task")
+    val inst = dontCare(rem) + memId(id) + dontCare(depBits) + taskId("load")
+    instPat(inst)
+  }
+
+  private def store: BitPat = {
+    val rem = xLen - idBits("task")
+    val inst = dontCare(rem) + taskId("store")
+    instPat(inst)
+  }
+
+  private def gemm: BitPat = {
+    val rem = xLen - idBits("task")
+    val inst = dontCare(rem) + taskId("gemm")
+    instPat(inst)
+  }
+
+  private def alu(id: String): BitPat = {
+    // TODO: move alu id next to task id
+    val inst = dontCare(18) + aluId(id) + dontCare(105) + taskId("alu")
+    instPat(inst)
+  }
+
+  private def finish: BitPat = {
+    val rem = xLen - idBits("task")
+    val inst = dontCare(rem) + taskId("finish")
+    instPat(inst)
+  }
+
+  def LUOP = load("uop")
+  def LWGT = load("wgt")
+  def LINP = load("inp")
+  def LACC = load("acc")
+  def SOUT = store
+  def GEMM = gemm
+  def VMIN = alu("minpool")
+  def VMAX = alu("maxpool")
+  def VADD = alu("add")
+  def VSHX = alu("shift")
+  def FNSH = finish
 }