[Relay] Add list update to prelude (#2866)
authorWei Chen <ipondering.weic@gmail.com>
Sat, 23 Mar 2019 01:21:31 +0000 (18:21 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Sat, 23 Mar 2019 01:21:31 +0000 (18:21 -0700)
python/tvm/relay/prelude.py
tests/python/relay/test_adt.py

index 26f00c5c5e6d25ae969b575026777cf9aa7c6b2b..fb8c58bf431e89515b9e2b53f6d5f19f214695a3 100644 (file)
@@ -62,6 +62,25 @@ class Prelude:
         s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y))
         self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a])
 
+    def define_list_update(self):
+        """Defines a function to update the nth element of a list and return the updated list.
+
+        update(l, i, v) : list[a] -> nat -> a -> list[a]
+        """
+        self.update = GlobalVar("update")
+        a = TypeVar("a")
+        l = Var("l", self.l(a))
+        n = Var("n", self.nat())
+        v = Var("v", a)
+
+        y = Var("y")
+
+        z_case = Clause(PatternConstructor(self.z), self.cons(v, self.tl(l)))
+        s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
+                        self.cons(self.hd(l), self.update(self.tl(l), y, v)))
+
+        self.mod[self.update] = Function([l, n, v], Match(n, [z_case, s_case]), self.l(a), [a])
+
     def define_list_map(self):
         """Defines a function for mapping a function over a list's
         elements. That is, map(f, l) returns a new list where
@@ -470,6 +489,7 @@ class Prelude:
         self.define_nat_add()
         self.define_list_length()
         self.define_list_nth()
+        self.define_list_update()
         self.define_list_sum()
 
         self.define_tree_adt()
index e176194fede600ad50e5a322a168499ed3c1e8f2..e9e2915f28a895abf0b53e7fd9ce9d5685c822b1 100644 (file)
@@ -26,6 +26,7 @@ l = p.l
 hd = p.hd
 tl = p.tl
 nth = p.nth
+update = p.update
 length = p.length
 map = p.map
 foldl = p.foldl
@@ -148,6 +149,23 @@ def test_nth():
 
     assert got == expected
 
+def test_update():
+    expected = list(range(10))
+    l = nil()
+    # create zero initialized list
+    for i in range(len(expected)):
+        l = cons(build_nat(0), l)
+
+    # set value
+    for i, v in enumerate(expected):
+        l = update(l, build_nat(i), build_nat(v))
+
+    got = []
+    for i in range(len(expected)):
+        got.append(count(intrp.evaluate(nth(l, build_nat(i)))))
+
+    assert got == expected
+
 def test_length():
     a = relay.TypeVar("a")
     assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a])