# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""Adds certain standard global functions and ADT definitions to the module."""
from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type
-from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem
+from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
+from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
self.cons = Constructor("cons", [a, self.l(a)], self.l)
self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons])
+
def define_list_hd(self):
"""Defines a function to get the head of a list. Assume the list has at least one
element.
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y)
self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a])
+
def define_list_tl(self):
"""Defines a function to get the tail of a list.
cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z)
self.mod[self.tl] = Function([x], Match(x, [cons_case]), self.l(a), [a])
+
def define_list_nth(self):
"""Defines a function to get the nth element of a list.
- nth(l) : list[a] -> a
+ nth(l) : list[a] -> Tensor[(), int32] -> a
"""
self.nth = GlobalVar("nth")
a = TypeVar("a")
x = Var("x", self.l(a))
- n = Var("n", self.nat())
+ n = Var("n", scalar_type('int32'))
+
+ body = If(equal(n, const(0)),
+ self.hd(x),
+ self.nth(self.tl(x), subtract(n, const(1))))
+
+ self.mod[self.nth] = Function([x, n], body, a, [a])
- y = Var("y")
- z_case = Clause(PatternConstructor(self.z), self.hd(x))
- s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y))
- self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a])
def define_list_update(self):
"""Defines a function to update the nth element of a list and return the updated list.
- update(l, i, v) : list[a] -> nat -> a -> list[a]
+ update(l, i, v) : list[a] -> Tensor[(), int32] -> a -> list[a]
"""
self.update = GlobalVar("update")
a = TypeVar("a")
l = Var("l", self.l(a))
- n = Var("n", self.nat())
+ n = Var("n", scalar_type('int32'))
v = Var("v", a)
- y = Var("y")
+ body = If(equal(n, const(0)),
+ self.cons(v, self.tl(l)),
+ self.cons(self.hd(l),
+ self.update(self.tl(l),
+ subtract(n, const(1)),
+ v)))
- z_case = Clause(PatternConstructor(self.z), self.cons(v, self.tl(l)))
- s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
- self.cons(self.hd(l), self.update(self.tl(l), y, v)))
+ self.mod[self.update] = Function([l, n, v], body, self.l(a), [a])
- self.mod[self.update] = Function([l, n, v], Match(n, [z_case, s_case]), self.l(a), [a])
def define_list_map(self):
"""Defines a function for mapping a function over a list's
self.cons(f(y), self.map(f, z)))
self.mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), self.l(b), [a, b])
+
def define_list_foldl(self):
"""Defines a left-way fold over a list.
self.mod[self.foldl] = Function([f, av, bv],
Match(bv, [nil_case, cons_case]), a, [a, b])
+
def define_list_foldr(self):
"""Defines a right-way fold over a list.
self.mod[self.foldr] = Function([f, bv, av],
Match(av, [nil_case, cons_case]), b, [a, b])
+
def define_list_foldr1(self):
"""Defines a right-way fold over a nonempty list.
self.foldr(updater, l2, l1),
self.l(a), [a])
+
def define_list_filter(self):
"""Defines a function that filters a list.
If(f(h), self.cons(h, self.filter(f, t)), self.filter(f, t)))
self.mod[self.filter] = Function([f, l], Match(l, [nil_case, cons_case]), self.l(a), [a])
+
def define_list_zip(self):
"""Defines a function that combines two lists into a list of tuples of their elements.
self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]),
self.l(TupleType([a, b])), [a, b])
+
def define_list_rev(self):
"""Defines a function that reverses a list.
self.foldl(updater, self.nil(), l),
self.l(a), [a])
+
def define_list_map_accumr(self):
"""Defines an accumulative map, which is a fold that simulataneously updates
an accumulator value and a list of results.
TupleType([a, self.l(c)]),
[a, b, c])
+
def define_list_map_accuml(self):
"""Defines an accumulative map, which is a fold that simulataneously updates
an accumulator value and a list of results.
self.none = Constructor("none", [], self.optional)
self.mod[self.optional] = TypeData(self.optional, [a], [self.some, self.none])
+
def define_list_unfoldr(self):
"""Defines a function that builds up a list starting from a seed value.
self.mod[self.unfoldr] = Function([f, s], Match(f(s), [none_case, some_case]),
self.l(b), [a, b])
+
def define_list_unfoldl(self):
"""Defines a function that builds up a list starting from a seed value.
self.rev(self.unfoldr(f, s)),
self.l(b), [a, b])
- def define_nat_adt(self):
- """Defines a Peano (unary) natural number ADT.
- Zero is represented by z(). s(n) adds 1 to a nat n."""
- self.nat = GlobalTypeVar("nat")
- self.z = Constructor("z", [], self.nat)
- self.s = Constructor("s", [self.nat()], self.nat)
- self.mod[self.nat] = TypeData(self.nat, [], [self.z, self.s])
-
- def define_nat_double(self):
- """Defines a function that doubles a nat."""
- self.double = GlobalVar("double")
- x = Var("x", self.nat())
- y = Var("y")
- z_case = Clause(PatternConstructor(self.z), self.z())
- s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
- self.s(self.s(self.double(y))))
- self.mod[self.double] = Function([x], Match(x, [z_case, s_case]))
-
- def define_nat_add(self):
- """Defines a function that adds two nats."""
- self.add = GlobalVar("add")
- x = Var("x", self.nat())
- y = Var("y", self.nat())
- a = Var("a")
- z_case = Clause(PatternConstructor(self.z), y)
- s_case = Clause(PatternConstructor(self.s, [PatternVar(a)]),
- self.s(self.add(a, y)))
- self.mod[self.add] = Function([x, y], Match(x, [z_case, s_case]))
def define_list_sum(self):
- """Defines a function that computes the sum of a list of nats."""
+ """Defines a function that computes the sum of a list of integer scalars."""
self.sum = GlobalVar("sum")
- a = Var("a", self.l(self.nat()))
- self.mod[self.sum] = Function([a], self.foldl(self.add, self.z(), a))
+ a = Var("a", self.l(scalar_type('int32')))
+ x = Var('x')
+ y = Var('y')
+ addf = Function([x, y], add(x, y))
+ self.mod[self.sum] = Function([a], self.foldl(addf, const(0), a))
+
def define_list_length(self):
- """Defines a function that returns the length of a list as a nat"""
+ """Defines a function that returns the length of a list"""
self.length = GlobalVar("length")
a = TypeVar("a")
x = Var("x", self.l(a))
y = Var("y")
- nil_case = Clause(PatternConstructor(self.nil), self.z())
+ nil_case = Clause(PatternConstructor(self.nil), const(0))
cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]),
- self.s(self.length(y)))
+ add(const(1), self.length(y)))
self.mod[self.length] = Function([x],
- Match(x, [nil_case, cons_case]), None, [a])
+ Match(x, [nil_case, cons_case]), scalar_type('int32'), [a])
+
def define_tree_adt(self):
"""Defines a tree ADT. A tree can contain any type.
self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree)
self.mod[self.tree] = TypeData(self.tree, [a], [self.rose])
+
def define_tree_map(self):
"""Defines a function that maps over a tree. The function
is applied to each subtree's contents.
self.mod[self.tmap] = Function([f, t],
Match(t, [rose_case]), self.tree(b), [a, b])
+
def define_tree_size(self):
- """Defines a function that computes the size of a tree as a nat.
+ """Defines a function that computes the size of a tree.
- Signature: fn<a>(t : tree[a]) -> nat
+ Signature: fn<a>(t : tree[a]) -> Tensor[(), int32]
"""
self.size = GlobalVar("size")
a = TypeVar("a")
t = Var("t", self.tree(a))
- x = Var("x", self.tree(a))
z = Var("z")
rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]),
- self.s(self.sum(self.map(Function([x], self.size(x)), z))))
+ add(const(1), self.sum(self.map(self.size, z))))
self.mod[self.size] = Function([t],
- Match(t, [rose_case]), self.nat(), [a])
+ Match(t, [rose_case]), scalar_type('int32'), [a])
+
def define_id(self):
- """Defines a function that return it's argument.
+ """Defines a function that return its argument.
Signature: fn<a>(x : a) -> a
"""
def define_compose(self):
- """Defines a function that compose two function.
+ """Defines a function that composes two function.
Signature: fn<a, b, c>(f : fn(b) -> c, g : fn(a) -> b) -> fn(a) -> c
"""
def define_iterate(self):
- """Define a function that take a number n, a function f,
- and return a closure that apply f n time on it's argument.
+ """Defines a function that take a number n and a function f;
+ returns a closure that takes an argument and applies f
+ n times to its argument.
- Signature: fn<a>(n : nat, f : fn(a) -> a) -> fn(a) -> a
+ Signature: fn<a>(f : fn(a) -> a, n : Tensor[(), int32]) -> fn(a) -> a
"""
self.iterate = GlobalVar("iterate")
a = TypeVar("a")
f = Var("f", FuncType([a], a))
- x = Var("x", self.nat())
- y = Var("y", self.nat())
- z_case = Clause(PatternConstructor(self.z), self.id)
- s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
- self.compose(f, self.iterate(f, y)))
+ x = Var("x", scalar_type('int32'))
+ body = If(equal(x, const(0)),
+ self.id,
+ self.compose(f,
+ self.iterate(f, subtract(x, const(1)))))
self.mod[self.iterate] = Function([f, x],
- Match(x, [z_case, s_case]),
+ body,
FuncType([a], a),
[a])
+
def __init__(self, mod):
self.mod = mod
self.define_list_adt()
self.define_list_unfoldr()
self.define_list_unfoldl()
- self.define_nat_adt()
- self.define_nat_double()
- self.define_nat_add()
self.define_list_length()
self.define_list_nth()
self.define_list_update()
from .config import ctx_list
from .init import create_workload
+from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Defines a unary natural number (Peano natural number) abstract
+data type for Relay and provides some utility functions for it.
+Nats are useful for testing purposes, as they make it easy to write
+test cases for recursion and pattern matching."""
+
+from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar
+from tvm.relay.backend.interpreter import ConstructorValue
+from tvm.relay.expr import Var, Function, GlobalVar
+from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType
+
+def define_nat_adt(prelude):
+ """Defines a Peano (unary) natural number ADT.
+ Zero is represented by z(). s(n) adds 1 to a nat n.
+ Adds the fields nat, z, and s to the preluide, representing
+ (respectively) the nat ADT and the z and s constructors.
+ """
+ prelude.nat = GlobalTypeVar("nat")
+ prelude.z = Constructor("z", [], prelude.nat)
+ prelude.s = Constructor("s", [prelude.nat()], prelude.nat)
+ prelude.mod[prelude.nat] = TypeData(prelude.nat, [], [prelude.z, prelude.s])
+
+
+def define_nat_double(prelude):
+ """Defines a function that doubles a nat. Adds a field called
+ 'double' to the prelude, giving the GlobalVar pointing to
+ the function.
+ """
+ prelude.double = GlobalVar("double")
+ x = Var("x", prelude.nat())
+ y = Var("y")
+ z_case = Clause(PatternConstructor(prelude.z), prelude.z())
+ s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
+ prelude.s(prelude.s(prelude.double(y))))
+ prelude.mod[prelude.double] = Function([x], Match(x, [z_case, s_case]))
+
+
+def define_nat_add(prelude):
+ """Defines a function that adds two nats and adds a field to the
+ prelude 'add' giving the GlobalVar pointing to that function.
+ """
+ prelude.add = GlobalVar("add")
+ x = Var("x", prelude.nat())
+ y = Var("y", prelude.nat())
+ a = Var("a")
+ z_case = Clause(PatternConstructor(prelude.z), y)
+ s_case = Clause(PatternConstructor(prelude.s, [PatternVar(a)]),
+ prelude.s(prelude.add(a, y)))
+ prelude.mod[prelude.add] = Function([x, y], Match(x, [z_case, s_case]))
+
+
+# versions of prelude functions that use nats instead of scalars
+
+def define_nat_nth(prelude):
+ """Defines a function to get the nth eleemnt of a list using
+ a nat to index into the list.
+
+ nat_nth(l, n): fun<a>(list[a], nat) -> a
+ """
+ prelude.nat_nth = GlobalVar("nat_nth")
+ a = TypeVar("a")
+ x = Var("x", prelude.l(a))
+ n = Var("n", prelude.nat())
+ y = Var("y")
+
+ z_case = Clause(PatternConstructor(prelude.z), prelude.hd(x))
+ s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
+ prelude.nat_nth(prelude.tl(x), y))
+
+ prelude.mod[prelude.nat_nth] = Function([x, n],
+ Match(n, [z_case, s_case]),
+ a, [a])
+
+
+def define_nat_update(prelude):
+ """Defines a function to update the nth element of a list and return the updated list.
+
+ nat_update(l, i, v) : fun<a>(list[a], nat, a) -> list[a]
+ """
+ prelude.nat_update = GlobalVar("nat_update")
+ a = TypeVar("a")
+ # pylint: disable=invalid-name
+ l = Var("l", prelude.l(a))
+ n = Var("n", prelude.nat())
+ v = Var("v", a)
+ y = Var("y")
+
+ z_case = Clause(PatternConstructor(prelude.z),
+ prelude.cons(v, prelude.tl(l)))
+ s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
+ prelude.cons(
+ prelude.hd(l),
+ prelude.nat_update(prelude.tl(l), y, v)))
+
+ prelude.mod[prelude.nat_update] = Function([l, n, v],
+ Match(n, [z_case, s_case]),
+ prelude.l(a), [a])
+
+
+def define_nat_iterate(prelude):
+ """Defines a function that takes a number n and a function f;
+ returns a closure that takes an argument and applies f
+ n times to its argument.
+
+ Signature: fn<a>(fn(a) -> a, nat) -> fn(a) -> a
+ """
+ prelude.nat_iterate = GlobalVar("nat_iterate")
+ a = TypeVar("a")
+ f = Var("f", FuncType([a], a))
+ x = Var("x", prelude.nat())
+ y = Var("y", prelude.nat())
+
+ z_case = Clause(PatternConstructor(prelude.z), prelude.id)
+ s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
+ prelude.compose(f, prelude.nat_iterate(f, y)))
+
+ prelude.mod[prelude.nat_iterate] = Function([f, x],
+ Match(x, [z_case, s_case]),
+ FuncType([a], a),
+ [a])
+
+
+def add_nat_definitions(prelude):
+ """Given a Relay prelude, adds a Peano nat ADT, as well as functions
+ for adding nats and doubling nats. It also adds versions of
+ update, nth, and iterate that take nats instead of scalars (the
+ names are prefixed with 'nat_')."""
+ define_nat_adt(prelude)
+ define_nat_double(prelude)
+ define_nat_add(prelude)
+ define_nat_nth(prelude)
+ define_nat_update(prelude)
+ define_nat_iterate(prelude)
+
+
+# helper functions for working with nats
+
+
+def count(n):
+ """Takes a ConstructorValue corresponding to a nat ADT
+ and converts it into a Python integer. This is an example of
+ using an ADT value in Python.
+ """
+ assert isinstance(n, ConstructorValue)
+ if n.constructor.name_hint == 'z':
+ return 0
+ assert n.constructor.name_hint == 's'
+ return 1 + count(n.fields[0])
+
+
+def make_nat_value(prelude, n):
+ """The inverse of count(): Given a non-negative Python integer,
+ constructs a ConstructorValue representing that value as a nat.
+ """
+ if n == 0:
+ return ConstructorValue(prelude.z, [], [])
+ return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], [])
+
+
+def make_nat_expr(prelude, n):
+ """Given a non-negative Python integer, constructs a Python
+ expression representing that integer's value as a nat.
+ """
+ assert n >= 0
+ ret = prelude.z()
+ while n > 0:
+ ret = prelude.s(ret)
+ n = n - 1
+ return ret
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import numpy as np
import tvm
from tvm import relay
from tvm.relay.ir_pass import infer_type
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay import testing, create_executor
from tvm.relay.prelude import Prelude
+from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
mod = relay.Module()
p = Prelude(mod)
+add_nat_definitions(p)
+
ctx = tvm.context("llvm", 0)
intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
compose = p.compose
iterate = p.iterate
-# this is an example of using the adt value in python side
-def count(n):
- assert isinstance(n, ConstructorValue)
- if n.constructor.name_hint == 's':
- return 1 + count(n.fields[0])
- else:
- assert n.constructor.name_hint == 'z'
- return 0
-
# this is an example of creating the adt value in python side
def make_nat(n):
if n != 0:
else:
return ConstructorValue(z, [], [])
-def build_nat(n):
+def make_nat_expr(n):
assert n >= 0
ret = z()
while n > 0:
ret['children'].append(l)
return ret
+
+# turns a scalar-valued relay tensor value into a python number
+def get_scalar(tv):
+ return tv.asnumpy().item()
+
+
def test_nat_value():
- assert count(make_nat(10)) == 10
+ assert count(make_nat_value(p, 10)) == 10
assert count(intrp.evaluate(s(s(z())))) == 2
expected = list(range(10))
l = nil()
for i in reversed(expected):
- l = cons(build_nat(i), l)
+ l = cons(make_nat_expr(i), l)
got = []
for i in range(len(expected)):
expected = list(range(10))
l = nil()
for i in reversed(expected):
- l = cons(build_nat(i), l)
+ l = cons(relay.const(i), l)
- got = []
for i in range(len(expected)):
- got.append(count(intrp.evaluate(nth(l, build_nat(i)))))
+ item = intrp.evaluate(nth(l, relay.const(i)))
+ assert get_scalar(item) == i
- assert got == expected
def test_update():
expected = list(range(10))
l = nil()
# create zero initialized list
for i in range(len(expected)):
- l = cons(build_nat(0), l)
+ l = cons(make_nat_expr(0), l)
# set value
for i, v in enumerate(expected):
- l = update(l, build_nat(i), build_nat(v))
+ l = update(l, relay.const(i), make_nat_expr(v))
got = []
for i in range(len(expected)):
- got.append(count(intrp.evaluate(nth(l, build_nat(i)))))
+ got.append(count(intrp.evaluate(nth(l, relay.const(i)))))
assert got == expected
def test_length():
a = relay.TypeVar("a")
- assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a])
+ assert mod[length].checked_type == relay.FuncType([l(a)], relay.scalar_type('int32'), [a])
res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil())))))
- assert count(res) == 3
+ assert get_scalar(res) == 3
def test_map():
y = relay.Var("y")
rev_dup = relay.Function([y, x], cons(x, cons(x, y)))
res = intrp.evaluate(foldl(rev_dup, nil(),
- cons(build_nat(1),
- cons(build_nat(2),
- cons(build_nat(3), nil())))))
+ cons(make_nat_expr(1),
+ cons(make_nat_expr(2),
+ cons(make_nat_expr(3), nil())))))
reversed = to_list(res)
assert len(reversed) == 6
assert count(reversed[0]) == 3 and count(reversed[1]) == 3
y = relay.Var("y")
identity = relay.Function([x, y], cons(x, y))
res = intrp.evaluate(foldr(identity, nil(),
- cons(build_nat(1),
- cons(build_nat(2),
- cons(build_nat(3), nil())))))
+ cons(make_nat_expr(1),
+ cons(make_nat_expr(2),
+ cons(make_nat_expr(3), nil())))))
same = to_list(res)
assert len(same) == 3
assert count(same[0]) == 1 and count(same[1]) == 2 and count(same[2]) == 3
y = relay.Var("y")
f = relay.Function([x, y], add(x, y))
res = intrp.evaluate(foldr1(f,
- cons(build_nat(1),
- cons(build_nat(2),
- cons(build_nat(3), nil())))))
+ cons(make_nat_expr(1),
+ cons(make_nat_expr(2),
+ cons(make_nat_expr(3), nil())))))
assert count(res) == 6
def test_sum():
- assert mod[sum].checked_type == relay.FuncType([l(nat())], nat())
- res = intrp.evaluate(sum(cons(build_nat(1), cons(build_nat(2), nil()))))
- assert count(res) == 3
+ assert mod[sum].checked_type == relay.FuncType([l(relay.scalar_type('int32'))], relay.scalar_type('int32'))
+ res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil()))))
+ assert get_scalar(res) == 3
def test_concat():
a = relay.TypeVar("a")
assert mod[concat].checked_type == relay.FuncType([l(a), l(a)], l(a), [a])
- l1 = cons(build_nat(1), cons(build_nat(2), nil()))
- l2 = cons(build_nat(3), cons(build_nat(4), nil()))
+ l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), nil()))
+ l2 = cons(make_nat_expr(3), cons(make_nat_expr(4), nil()))
res = intrp.evaluate(concat(l1, l2))
catted = to_list(res)
]))
res = intrp.evaluate(
filter(greater_than_one,
- cons(build_nat(1),
- cons(build_nat(1),
- cons(build_nat(3),
- cons(build_nat(1),
- cons(build_nat(5),
- cons(build_nat(1),
+ cons(make_nat_expr(1),
+ cons(make_nat_expr(1),
+ cons(make_nat_expr(3),
+ cons(make_nat_expr(1),
+ cons(make_nat_expr(5),
+ cons(make_nat_expr(1),
nil()))))))))
filtered = to_list(res)
assert len(filtered) == 2
l(relay.TupleType([a, b])), [a, b])
assert mod[zip].checked_type == expected_type
- l1 = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil())))
+ l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
l2 = cons(nil(),
cons(cons(nil(), nil()),
cons(cons(nil(), cons(nil(), nil())),
assert len(to_list(zipped[2][1])) == 2
# test truncation
- l3 = cons(build_nat(4), cons(build_nat(5), nil()))
+ l3 = cons(make_nat_expr(4), cons(make_nat_expr(5), nil()))
shorter_res = intrp.evaluate(zip(l3, l2))
truncated = to_list(shorter_res)
assert len(truncated) == 2
a = relay.TypeVar("a")
assert mod[rev].checked_type == relay.FuncType([l(a)], l(a), [a])
- res = intrp.evaluate(rev(cons(build_nat(1),
- cons(build_nat(2),
- cons(build_nat(3), nil())))))
+ res = intrp.evaluate(rev(cons(make_nat_expr(1),
+ cons(make_nat_expr(2),
+ cons(make_nat_expr(3), nil())))))
reversed = to_list(res)
assert len(reversed) == 3
relay.Clause(relay.PatternConstructor(z, []), none())
]))
- res = intrp.evaluate(unfoldr(count_down, build_nat(3)))
+ res = intrp.evaluate(unfoldr(count_down, make_nat_expr(3)))
unfolded = to_list(res)
assert len(unfolded) == 3
relay.Clause(relay.PatternConstructor(z, []), none())
]))
- res = intrp.evaluate(unfoldl(count_down, build_nat(3)))
+ res = intrp.evaluate(unfoldl(count_down, make_nat_expr(3)))
unfolded = to_list(res)
assert len(unfolded) == 3
relay.Tuple([add(x, acc),
add(x, acc)]))
- vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil())))
+ vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
res = intrp.evaluate(map_accumr(add_acc_to_each, z(), vals))
sum = count(res[0])
add_to_acc = relay.Function([acc, x],
relay.Tuple([add(x, acc), x]))
- vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil())))
+ vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
res = intrp.evaluate(map_accuml(add_to_acc, z(), vals))
sum = count(res[0])
]))
res = intrp.evaluate(foldr(condense, nil(), cons(
- some(build_nat(3)),
- cons(none(), cons(some(build_nat(1)), nil())))))
+ some(make_nat_expr(3)),
+ cons(none(), cons(some(make_nat_expr(1)), nil())))))
reduced = to_list(res)
assert len(reduced) == 2
def test_size():
a = relay.TypeVar("a")
lhs = mod[size].checked_type
- rhs = relay.FuncType([tree(a)], nat(), [a])
+ rhs = relay.FuncType([tree(a)], relay.scalar_type('int32'), [a])
assert lhs == rhs
root = rose(z(), cons(rose(z(), nil()),
nil())))
t = rose(z(), cons(root, cons(root, cons(root, nil()))))
res = intrp.evaluate(size(t))
- assert count(res) == 10
+ assert get_scalar(res) == 10
def test_wildcard_match_solo():
inner_match)
]), l(a), [a])
- first_list = cons(build_nat(1), cons(build_nat(2),
- cons(build_nat(3), nil())))
- second_list = cons(build_nat(4), cons(build_nat(5),
- cons(build_nat(6), nil())))
+ first_list = cons(make_nat_expr(1), cons(make_nat_expr(2),
+ cons(make_nat_expr(3), nil())))
+ second_list = cons(make_nat_expr(4), cons(make_nat_expr(5),
+ cons(make_nat_expr(6), nil())))
final_list = cons(first_list, cons(second_list, nil()))
res = intrp.evaluate(flatten(final_list))
assert count(res) == 2
+
def test_compose():
n = relay.Var('n')
inc = relay.Function([n], s(n))
res = intrp.evaluate(relay.Call(compose(inc, double), [s(s(z()))]))
assert count(res) == 5
+
def test_iterate():
- expr = relay.Call(iterate(double, build_nat(2)), [build_nat(3)])
+ expr = relay.Call(iterate(double, relay.const(2)), [make_nat_expr(3)])
res = intrp.evaluate(relay.Function([], expr)())
assert count(res) == 12
+
if __name__ == "__main__":
test_nat_constructor()
test_double()
mod = relay.Module()
p = Prelude(mod)
x = relay.Var("x")
- s_case = relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), x)
+ some_case = relay.Clause(relay.PatternConstructor(p.some,
+ [relay.PatternVar(x)]),
+ x)
default_case = relay.Clause(relay.PatternVar(x), x)
- m0 = relay.Match(p.z(), [default_case])
- m1 = relay.Match(p.z(), [s_case, default_case])
+ m0 = relay.Match(p.none(), [default_case])
+ m1 = relay.Match(p.none(), [some_case, default_case])
assert well_formed(m0)
assert not well_formed(m1)
relay.PatternVar(a)]),
p.cons(z, a))
- data = p.cons(p.z(), p.cons(p.z(), p.nil()))
+ data = p.cons(relay.const(1), p.cons(relay.const(2), p.nil()))
match = relay.Match(data, [nil_case, cons_case])
equivalent = relay.Match(data, [nil_case, equivalent_cons])
relay.Clause(relay.PatternWildcard(), p.nil())
])
wrong_constructors = relay.Match(data, [
- relay.Clause(relay.PatternConstructor(p.z), p.nil()),
- relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]),
+ relay.Clause(relay.PatternConstructor(p.none), p.nil()),
+ relay.Clause(relay.PatternConstructor(p.some, [relay.PatternVar(x)]),
p.cons(x, p.nil()))
])
from tvm.relay.ir_pass import free_vars, free_type_vars, gradient
from tvm.relay import create_executor
from tvm.relay.prelude import Prelude
+from tvm.relay.testing import add_nat_definitions, make_nat_expr
import numpy as np
def test_pow():
mod = relay.Module()
p = Prelude(mod)
+ add_nat_definitions(p)
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
double = relay.Function([x], x + x)
i = relay.var("i", t)
- func = relay.Function([i], relay.Call(p.iterate(double, p.s(p.s(p.s(p.z())))), [i]))
+ func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
back_func = relay.ir_pass.infer_type(gradient(func, mod=mod), mod=mod)
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
i_nd = rand(dtype, *shape)
from tvm.relay import op, create_executor
from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
from tvm.relay.prelude import Prelude
+from tvm.relay.testing import add_nat_definitions, count
def check_eval(expr, expected_result, mod=None, rtol=1e-07):
check_eval(to_a_normal_form(body), 3)
-# this is an example of using the adt value in python side
-def count(n):
- assert isinstance(n, ConstructorValue)
- if n.constructor.name_hint == 's':
- return 1 + count(n.fields[0])
- else:
- assert n.constructor.name_hint == 'z'
- return 0
-
-
-def test_add():
+def test_nat_add():
mod = relay.Module()
p = Prelude(mod)
+ add_nat_definitions(p)
nat = p.nat
add = p.add
s = p.s
test_ref()
test_add()
test_let()
+ test_nat_add()
test_function()