s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y))
self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a])
+ def define_list_update(self):
+ """Defines a function to update the nth element of a list and return the updated list.
+
+ update(l, i, v) : list[a] -> nat -> a -> list[a]
+ """
+ self.update = GlobalVar("update")
+ a = TypeVar("a")
+ l = Var("l", self.l(a))
+ n = Var("n", self.nat())
+ v = Var("v", a)
+
+ y = Var("y")
+
+ z_case = Clause(PatternConstructor(self.z), self.cons(v, self.tl(l)))
+ s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
+ self.cons(self.hd(l), self.update(self.tl(l), y, v)))
+
+ self.mod[self.update] = Function([l, n, v], Match(n, [z_case, s_case]), self.l(a), [a])
+
def define_list_map(self):
"""Defines a function for mapping a function over a list's
elements. That is, map(f, l) returns a new list where
self.define_nat_add()
self.define_list_length()
self.define_list_nth()
+ self.define_list_update()
self.define_list_sum()
self.define_tree_adt()
hd = p.hd
tl = p.tl
nth = p.nth
+update = p.update
length = p.length
map = p.map
foldl = p.foldl
assert got == expected
+def test_update():
+ expected = list(range(10))
+ l = nil()
+ # create zero initialized list
+ for i in range(len(expected)):
+ l = cons(build_nat(0), l)
+
+ # set value
+ for i, v in enumerate(expected):
+ l = update(l, build_nat(i), build_nat(v))
+
+ got = []
+ for i in range(len(expected)):
+ got.append(count(intrp.evaluate(nth(l, build_nat(i)))))
+
+ assert got == expected
+
def test_length():
a = relay.TypeVar("a")
assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a])