[VTA] Bug fix for padded load with large inputs (#4293)
authorLiangfu Chen <liangfu.chen@icloud.com>
Fri, 15 Nov 2019 17:59:04 +0000 (01:59 +0800)
committerThierry Moreau <moreau@uw.edu>
Fri, 15 Nov 2019 17:59:04 +0000 (09:59 -0800)
* bug fix for padded load with large inputs

* Update TensorLoad.scala

* Update test_vta_insn.py

vta/hardware/chisel/src/main/scala/core/TensorLoad.scala
vta/tests/python/unittest/test_vta_insn.py

index ca6803c..d184cd2 100644 (file)
@@ -103,20 +103,21 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
           when(dec.xpad_1 =/= 0.U) {
             state := sXPad1
           }.elsewhen(dec.ypad_1 =/= 0.U) {
-              state := sYPad1
-            }
-            .otherwise {
-              state := sIdle
-            }
-        }.elsewhen(dataCtrl.io.stride || dataCtrl.io.split) {
+            state := sYPad1
+          }
+          .otherwise {
+            state := sIdle
+          }
+        }.elsewhen(dataCtrl.io.stride) {
           when(dec.xpad_1 =/= 0.U) {
             state := sXPad1
           }.elsewhen(dec.xpad_0 =/= 0.U) {
-              state := sXPad0
-            }
-            .otherwise {
-              state := sReadCmd
-            }
+            state := sXPad0
+          }.otherwise {
+            state := sReadCmd
+          }
+        }.elsewhen(dataCtrl.io.split) {
+          state := sReadCmd
         }
       }
     }
@@ -168,13 +169,11 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
   xPadCtrl0.io.start := dec.xpad_0 =/= 0.U &
     ((state === sIdle & io.start) |
       (state === sYPad0 & yPadCtrl0.io.done) |
-      (io.vme_rd.data
-        .fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) |
+      (io.vme_rd.data.fire() & ~dataCtrlDone & dataCtrl.io.stride & dec.xpad_1 === 0.U) |
       (state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))
 
   xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() &
-    ((dataCtrl.io.done) |
-      (~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U))
+    ((dataCtrl.io.done) | (~dataCtrl.io.done & dataCtrl.io.stride & dec.xpad_1 =/= 0.U))
 
   yPadCtrl0.io.inst := io.inst
   yPadCtrl1.io.inst := io.inst
index 574273f..ef3c45c 100644 (file)
@@ -24,6 +24,7 @@ import vta
 import vta.testing
 from vta.testing import simulator
 
+np.random.seed(0xdeadb)
 
 def test_save_load_out():
     """Test save/store output command"""
@@ -88,68 +89,73 @@ def test_save_load_out():
 def test_padded_load():
     """Test padded load."""
     def _run(env, remote):
-        # declare
-        n = 3
-        m = 5
-        pad_before = [2, 1, 0, 0]
-        pad_after = [1, 2, 0, 0]
-        x = tvm.placeholder(
-            (n, m, env.BATCH, env.BLOCK_OUT),
-            name="x",
-            dtype=env.acc_dtype)
-        x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
-        # insert no-op that won't be optimized away
-        y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
+        def check_padded_load(pad_before, pad_after, test_name=None):
+            # declare
+            n = 3
+            m = 5
+            x = tvm.placeholder(
+                (n, m, env.BATCH, env.BLOCK_OUT),
+                name="x",
+                dtype=env.acc_dtype)
+            x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
+            # insert no-op that won't be optimized away
+            y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
+                                 m + pad_before[1] + pad_after[1],
+                                 env.BATCH,
+                                 env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
+            y = tvm.compute((n + pad_before[0] + pad_after[0],
                              m + pad_before[1] + pad_after[1],
                              env.BATCH,
-                             env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
-        y = tvm.compute((n + pad_before[0] + pad_after[0],
-                         m + pad_before[1] + pad_after[1],
-                         env.BATCH,
-                         env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
-        # schedule
-        s = tvm.create_schedule(y.op)
-        s[x_buf].set_scope(env.acc_scope)
-        s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
-        s[y_buf].set_scope(env.acc_scope)
-        s[y_buf].pragma(y_buf.op.axis[0], env.alu)
-        s[y].pragma(y.op.axis[0], env.dma_copy)
-        # build
-        with vta.build_config():
-            mod = vta.build(s, [x, y], "ext_dev", env.target_host)
+                             env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
+            # schedule
+            s = tvm.create_schedule(y.op)
+            s[x_buf].set_scope(env.acc_scope)
+            s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
+            s[y_buf].set_scope(env.acc_scope)
+            s[y_buf].pragma(y_buf.op.axis[0], env.alu)
+            s[y].pragma(y.op.axis[0], env.dma_copy)
+            # build
+            with vta.build_config():
+                mod = vta.build(s, [x, y], "ext_dev", env.target_host)
 
-        if not remote:
-            return
-        temp = util.tempdir()
-        mod.save(temp.relpath("padded_load.o"))
-        remote.upload(temp.relpath("padded_load.o"))
-        f = remote.load_module("padded_load.o")
-        # verify
-        ctx = remote.ext_dev(0)
-        x_np = np.random.randint(-10, 10, size=(
-            n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
-        y_np = np.zeros((n + pad_before[0] + pad_after[0],
-                         m + pad_before[1] + pad_after[1],
-                         env.BATCH,
-                         env.BLOCK_OUT)).astype(y.dtype)
-        y_np[pad_before[0]:pad_before[0] + n,
-             pad_before[1]:pad_before[1] + m,
-             :] = x_np
-        x_nd = tvm.nd.array(x_np, ctx)
-        y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
+            if not remote:
+                return
+            temp = util.tempdir()
+            mod.save(temp.relpath("padded_load.o"))
+            remote.upload(temp.relpath("padded_load.o"))
+            f = remote.load_module("padded_load.o")
+            # verify
+            ctx = remote.ext_dev(0)
+            x_np = np.random.randint(0, 10, size=(
+                n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
+            y_np = np.zeros((n + pad_before[0] + pad_after[0],
+                             m + pad_before[1] + pad_after[1],
+                             env.BATCH,
+                             env.BLOCK_OUT)).astype(y.dtype)
+            y_np[pad_before[0]:pad_before[0] + n,
+                 pad_before[1]:pad_before[1] + m,
+                 :] = x_np
+            x_nd = tvm.nd.array(x_np, ctx)
+            y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
 
-        if env.TARGET in ["sim", "tsim"]:
-            simulator.clear_stats()
+            if env.TARGET in ["sim", "tsim"]:
+                simulator.clear_stats()
 
-        f(x_nd, y_nd)
+            f(x_nd, y_nd)
 
-        np.testing.assert_equal(y_np, y_nd.asnumpy())
+            np.testing.assert_equal(y_np, y_nd.asnumpy())
 
-        if env.TARGET in ["sim", "tsim"]:
-            sim_stats = simulator.stats()
-            print("Padded load execution statistics:")
-            for k, v in sim_stats.items():
-                print("\t{:<16}: {:>16}".format(k, v))
+            if env.TARGET in ["sim", "tsim"]:
+                sim_stats = simulator.stats()
+                print("Padded {} load execution statistics:".format(test_name))
+                for k, v in sim_stats.items():
+                    print("\t{:<16}: {:>16}".format(k, v))
+
+        check_padded_load([2, 0, 0, 0], [0, 0, 0, 0], test_name="Y0")
+        check_padded_load([0, 2, 0, 0], [0, 0, 0, 0], test_name="Y1")
+        check_padded_load([0, 0, 0, 0], [2, 0, 0, 0], test_name="X0")
+        check_padded_load([0, 0, 0, 0], [0, 2, 0, 0], test_name="X1")
+        check_padded_load([1, 1, 0, 0], [1, 1, 0, 0], test_name="all")
 
     vta.testing.run(_run)