From 95bfd4a2427b14f400f48aca83dcc125bdd37e62 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 22 May 2019 13:57:53 -0700 Subject: [PATCH] [Relay][Prelude] Remove Peano nats from the prelude (#3045) --- python/tvm/relay/prelude.py | 132 ++++++++-------- python/tvm/relay/testing/__init__.py | 1 + python/tvm/relay/testing/nat.py | 184 +++++++++++++++++++++++ tests/python/relay/test_adt.py | 121 +++++++-------- tests/python/relay/test_ir_well_formed.py | 8 +- tests/python/relay/test_pass_alpha_equal.py | 6 +- tests/python/relay/test_pass_gradient.py | 4 +- tests/python/relay/test_pass_to_a_normal_form.py | 15 +- 8 files changed, 326 insertions(+), 145 deletions(-) create mode 100644 python/tvm/relay/testing/nat.py diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index ff823c3..92647e5 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -17,7 +17,8 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """Adds certain standard global functions and ADT definitions to the module.""" from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type -from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem +from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const +from .op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard @@ -34,6 +35,7 @@ class Prelude: self.cons = Constructor("cons", [a, self.l(a)], self.l) self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons]) + def define_list_hd(self): """Defines a function to get the head of a list. Assume the list has at least one element. @@ -48,6 +50,7 @@ class Prelude: cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y) self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a]) + def define_list_tl(self): """Defines a function to get the tail of a list. @@ -61,39 +64,44 @@ class Prelude: cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z) self.mod[self.tl] = Function([x], Match(x, [cons_case]), self.l(a), [a]) + def define_list_nth(self): """Defines a function to get the nth element of a list. - nth(l) : list[a] -> a + nth(l) : list[a] -> Tensor[(), int32] -> a """ self.nth = GlobalVar("nth") a = TypeVar("a") x = Var("x", self.l(a)) - n = Var("n", self.nat()) + n = Var("n", scalar_type('int32')) + + body = If(equal(n, const(0)), + self.hd(x), + self.nth(self.tl(x), subtract(n, const(1)))) + + self.mod[self.nth] = Function([x, n], body, a, [a]) - y = Var("y") - z_case = Clause(PatternConstructor(self.z), self.hd(x)) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y)) - self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a]) def define_list_update(self): """Defines a function to update the nth element of a list and return the updated list. - update(l, i, v) : list[a] -> nat -> a -> list[a] + update(l, i, v) : list[a] -> Tensor[(), int32] -> a -> list[a] """ self.update = GlobalVar("update") a = TypeVar("a") l = Var("l", self.l(a)) - n = Var("n", self.nat()) + n = Var("n", scalar_type('int32')) v = Var("v", a) - y = Var("y") + body = If(equal(n, const(0)), + self.cons(v, self.tl(l)), + self.cons(self.hd(l), + self.update(self.tl(l), + subtract(n, const(1)), + v))) - z_case = Clause(PatternConstructor(self.z), self.cons(v, self.tl(l))) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), - self.cons(self.hd(l), self.update(self.tl(l), y, v))) + self.mod[self.update] = Function([l, n, v], body, self.l(a), [a]) - self.mod[self.update] = Function([l, n, v], Match(n, [z_case, s_case]), self.l(a), [a]) def define_list_map(self): """Defines a function for mapping a function over a list's @@ -114,6 +122,7 @@ class Prelude: self.cons(f(y), self.map(f, z))) self.mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), self.l(b), [a, b]) + def define_list_foldl(self): """Defines a left-way fold over a list. @@ -136,6 +145,7 @@ class Prelude: self.mod[self.foldl] = Function([f, av, bv], Match(bv, [nil_case, cons_case]), a, [a, b]) + def define_list_foldr(self): """Defines a right-way fold over a list. @@ -158,6 +168,7 @@ class Prelude: self.mod[self.foldr] = Function([f, bv, av], Match(av, [nil_case, cons_case]), b, [a, b]) + def define_list_foldr1(self): """Defines a right-way fold over a nonempty list. @@ -196,6 +207,7 @@ class Prelude: self.foldr(updater, l2, l1), self.l(a), [a]) + def define_list_filter(self): """Defines a function that filters a list. @@ -214,6 +226,7 @@ class Prelude: If(f(h), self.cons(h, self.filter(f, t)), self.filter(f, t))) self.mod[self.filter] = Function([f, l], Match(l, [nil_case, cons_case]), self.l(a), [a]) + def define_list_zip(self): """Defines a function that combines two lists into a list of tuples of their elements. @@ -238,6 +251,7 @@ class Prelude: self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]), self.l(TupleType([a, b])), [a, b]) + def define_list_rev(self): """Defines a function that reverses a list. @@ -253,6 +267,7 @@ class Prelude: self.foldl(updater, self.nil(), l), self.l(a), [a]) + def define_list_map_accumr(self): """Defines an accumulative map, which is a fold that simulataneously updates an accumulator value and a list of results. @@ -282,6 +297,7 @@ class Prelude: TupleType([a, self.l(c)]), [a, b, c]) + def define_list_map_accuml(self): """Defines an accumulative map, which is a fold that simulataneously updates an accumulator value and a list of results. @@ -321,6 +337,7 @@ class Prelude: self.none = Constructor("none", [], self.optional) self.mod[self.optional] = TypeData(self.optional, [a], [self.some, self.none]) + def define_list_unfoldr(self): """Defines a function that builds up a list starting from a seed value. @@ -343,6 +360,7 @@ class Prelude: self.mod[self.unfoldr] = Function([f, s], Match(f(s), [none_case, some_case]), self.l(b), [a, b]) + def define_list_unfoldl(self): """Defines a function that builds up a list starting from a seed value. @@ -362,52 +380,29 @@ class Prelude: self.rev(self.unfoldr(f, s)), self.l(b), [a, b]) - def define_nat_adt(self): - """Defines a Peano (unary) natural number ADT. - Zero is represented by z(). s(n) adds 1 to a nat n.""" - self.nat = GlobalTypeVar("nat") - self.z = Constructor("z", [], self.nat) - self.s = Constructor("s", [self.nat()], self.nat) - self.mod[self.nat] = TypeData(self.nat, [], [self.z, self.s]) - - def define_nat_double(self): - """Defines a function that doubles a nat.""" - self.double = GlobalVar("double") - x = Var("x", self.nat()) - y = Var("y") - z_case = Clause(PatternConstructor(self.z), self.z()) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), - self.s(self.s(self.double(y)))) - self.mod[self.double] = Function([x], Match(x, [z_case, s_case])) - - def define_nat_add(self): - """Defines a function that adds two nats.""" - self.add = GlobalVar("add") - x = Var("x", self.nat()) - y = Var("y", self.nat()) - a = Var("a") - z_case = Clause(PatternConstructor(self.z), y) - s_case = Clause(PatternConstructor(self.s, [PatternVar(a)]), - self.s(self.add(a, y))) - self.mod[self.add] = Function([x, y], Match(x, [z_case, s_case])) def define_list_sum(self): - """Defines a function that computes the sum of a list of nats.""" + """Defines a function that computes the sum of a list of integer scalars.""" self.sum = GlobalVar("sum") - a = Var("a", self.l(self.nat())) - self.mod[self.sum] = Function([a], self.foldl(self.add, self.z(), a)) + a = Var("a", self.l(scalar_type('int32'))) + x = Var('x') + y = Var('y') + addf = Function([x, y], add(x, y)) + self.mod[self.sum] = Function([a], self.foldl(addf, const(0), a)) + def define_list_length(self): - """Defines a function that returns the length of a list as a nat""" + """Defines a function that returns the length of a list""" self.length = GlobalVar("length") a = TypeVar("a") x = Var("x", self.l(a)) y = Var("y") - nil_case = Clause(PatternConstructor(self.nil), self.z()) + nil_case = Clause(PatternConstructor(self.nil), const(0)) cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]), - self.s(self.length(y))) + add(const(1), self.length(y))) self.mod[self.length] = Function([x], - Match(x, [nil_case, cons_case]), None, [a]) + Match(x, [nil_case, cons_case]), scalar_type('int32'), [a]) + def define_tree_adt(self): """Defines a tree ADT. A tree can contain any type. @@ -420,6 +415,7 @@ class Prelude: self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree) self.mod[self.tree] = TypeData(self.tree, [a], [self.rose]) + def define_tree_map(self): """Defines a function that maps over a tree. The function is applied to each subtree's contents. @@ -439,23 +435,24 @@ class Prelude: self.mod[self.tmap] = Function([f, t], Match(t, [rose_case]), self.tree(b), [a, b]) + def define_tree_size(self): - """Defines a function that computes the size of a tree as a nat. + """Defines a function that computes the size of a tree. - Signature: fn(t : tree[a]) -> nat + Signature: fn(t : tree[a]) -> Tensor[(), int32] """ self.size = GlobalVar("size") a = TypeVar("a") t = Var("t", self.tree(a)) - x = Var("x", self.tree(a)) z = Var("z") rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]), - self.s(self.sum(self.map(Function([x], self.size(x)), z)))) + add(const(1), self.sum(self.map(self.size, z)))) self.mod[self.size] = Function([t], - Match(t, [rose_case]), self.nat(), [a]) + Match(t, [rose_case]), scalar_type('int32'), [a]) + def define_id(self): - """Defines a function that return it's argument. + """Defines a function that return its argument. Signature: fn(x : a) -> a """ @@ -466,7 +463,7 @@ class Prelude: def define_compose(self): - """Defines a function that compose two function. + """Defines a function that composes two function. Signature: fn(f : fn(b) -> c, g : fn(a) -> b) -> fn(a) -> c """ @@ -484,24 +481,26 @@ class Prelude: def define_iterate(self): - """Define a function that take a number n, a function f, - and return a closure that apply f n time on it's argument. + """Defines a function that take a number n and a function f; + returns a closure that takes an argument and applies f + n times to its argument. - Signature: fn(n : nat, f : fn(a) -> a) -> fn(a) -> a + Signature: fn(f : fn(a) -> a, n : Tensor[(), int32]) -> fn(a) -> a """ self.iterate = GlobalVar("iterate") a = TypeVar("a") f = Var("f", FuncType([a], a)) - x = Var("x", self.nat()) - y = Var("y", self.nat()) - z_case = Clause(PatternConstructor(self.z), self.id) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), - self.compose(f, self.iterate(f, y))) + x = Var("x", scalar_type('int32')) + body = If(equal(x, const(0)), + self.id, + self.compose(f, + self.iterate(f, subtract(x, const(1))))) self.mod[self.iterate] = Function([f, x], - Match(x, [z_case, s_case]), + body, FuncType([a], a), [a]) + def __init__(self, mod): self.mod = mod self.define_list_adt() @@ -522,9 +521,6 @@ class Prelude: self.define_list_unfoldr() self.define_list_unfoldl() - self.define_nat_adt() - self.define_nat_double() - self.define_nat_add() self.define_list_length() self.define_list_nth() self.define_list_update() diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index b4a8394..192afe1 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -30,3 +30,4 @@ from . import densenet from .config import ctx_list from .init import create_workload +from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py new file mode 100644 index 0000000..4c0c87c --- /dev/null +++ b/python/tvm/relay/testing/nat.py @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Defines a unary natural number (Peano natural number) abstract +data type for Relay and provides some utility functions for it. +Nats are useful for testing purposes, as they make it easy to write +test cases for recursion and pattern matching.""" + +from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar +from tvm.relay.backend.interpreter import ConstructorValue +from tvm.relay.expr import Var, Function, GlobalVar +from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType + +def define_nat_adt(prelude): + """Defines a Peano (unary) natural number ADT. + Zero is represented by z(). s(n) adds 1 to a nat n. + Adds the fields nat, z, and s to the preluide, representing + (respectively) the nat ADT and the z and s constructors. + """ + prelude.nat = GlobalTypeVar("nat") + prelude.z = Constructor("z", [], prelude.nat) + prelude.s = Constructor("s", [prelude.nat()], prelude.nat) + prelude.mod[prelude.nat] = TypeData(prelude.nat, [], [prelude.z, prelude.s]) + + +def define_nat_double(prelude): + """Defines a function that doubles a nat. Adds a field called + 'double' to the prelude, giving the GlobalVar pointing to + the function. + """ + prelude.double = GlobalVar("double") + x = Var("x", prelude.nat()) + y = Var("y") + z_case = Clause(PatternConstructor(prelude.z), prelude.z()) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]), + prelude.s(prelude.s(prelude.double(y)))) + prelude.mod[prelude.double] = Function([x], Match(x, [z_case, s_case])) + + +def define_nat_add(prelude): + """Defines a function that adds two nats and adds a field to the + prelude 'add' giving the GlobalVar pointing to that function. + """ + prelude.add = GlobalVar("add") + x = Var("x", prelude.nat()) + y = Var("y", prelude.nat()) + a = Var("a") + z_case = Clause(PatternConstructor(prelude.z), y) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(a)]), + prelude.s(prelude.add(a, y))) + prelude.mod[prelude.add] = Function([x, y], Match(x, [z_case, s_case])) + + +# versions of prelude functions that use nats instead of scalars + +def define_nat_nth(prelude): + """Defines a function to get the nth eleemnt of a list using + a nat to index into the list. + + nat_nth(l, n): fun(list[a], nat) -> a + """ + prelude.nat_nth = GlobalVar("nat_nth") + a = TypeVar("a") + x = Var("x", prelude.l(a)) + n = Var("n", prelude.nat()) + y = Var("y") + + z_case = Clause(PatternConstructor(prelude.z), prelude.hd(x)) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]), + prelude.nat_nth(prelude.tl(x), y)) + + prelude.mod[prelude.nat_nth] = Function([x, n], + Match(n, [z_case, s_case]), + a, [a]) + + +def define_nat_update(prelude): + """Defines a function to update the nth element of a list and return the updated list. + + nat_update(l, i, v) : fun(list[a], nat, a) -> list[a] + """ + prelude.nat_update = GlobalVar("nat_update") + a = TypeVar("a") + # pylint: disable=invalid-name + l = Var("l", prelude.l(a)) + n = Var("n", prelude.nat()) + v = Var("v", a) + y = Var("y") + + z_case = Clause(PatternConstructor(prelude.z), + prelude.cons(v, prelude.tl(l))) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]), + prelude.cons( + prelude.hd(l), + prelude.nat_update(prelude.tl(l), y, v))) + + prelude.mod[prelude.nat_update] = Function([l, n, v], + Match(n, [z_case, s_case]), + prelude.l(a), [a]) + + +def define_nat_iterate(prelude): + """Defines a function that takes a number n and a function f; + returns a closure that takes an argument and applies f + n times to its argument. + + Signature: fn(fn(a) -> a, nat) -> fn(a) -> a + """ + prelude.nat_iterate = GlobalVar("nat_iterate") + a = TypeVar("a") + f = Var("f", FuncType([a], a)) + x = Var("x", prelude.nat()) + y = Var("y", prelude.nat()) + + z_case = Clause(PatternConstructor(prelude.z), prelude.id) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]), + prelude.compose(f, prelude.nat_iterate(f, y))) + + prelude.mod[prelude.nat_iterate] = Function([f, x], + Match(x, [z_case, s_case]), + FuncType([a], a), + [a]) + + +def add_nat_definitions(prelude): + """Given a Relay prelude, adds a Peano nat ADT, as well as functions + for adding nats and doubling nats. It also adds versions of + update, nth, and iterate that take nats instead of scalars (the + names are prefixed with 'nat_').""" + define_nat_adt(prelude) + define_nat_double(prelude) + define_nat_add(prelude) + define_nat_nth(prelude) + define_nat_update(prelude) + define_nat_iterate(prelude) + + +# helper functions for working with nats + + +def count(n): + """Takes a ConstructorValue corresponding to a nat ADT + and converts it into a Python integer. This is an example of + using an ADT value in Python. + """ + assert isinstance(n, ConstructorValue) + if n.constructor.name_hint == 'z': + return 0 + assert n.constructor.name_hint == 's' + return 1 + count(n.fields[0]) + + +def make_nat_value(prelude, n): + """The inverse of count(): Given a non-negative Python integer, + constructs a ConstructorValue representing that value as a nat. + """ + if n == 0: + return ConstructorValue(prelude.z, [], []) + return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], []) + + +def make_nat_expr(prelude, n): + """Given a non-negative Python integer, constructs a Python + expression representing that integer's value as a nat. + """ + assert n >= 0 + ret = prelude.z() + while n > 0: + ret = prelude.s(ret) + n = n - 1 + return ret diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 58ab0c4..77f4ab1 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -14,15 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np import tvm from tvm import relay from tvm.relay.ir_pass import infer_type from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue from tvm.relay import testing, create_executor from tvm.relay.prelude import Prelude +from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr mod = relay.Module() p = Prelude(mod) +add_nat_definitions(p) + ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") @@ -67,15 +71,6 @@ size = p.size compose = p.compose iterate = p.iterate -# this is an example of using the adt value in python side -def count(n): - assert isinstance(n, ConstructorValue) - if n.constructor.name_hint == 's': - return 1 + count(n.fields[0]) - else: - assert n.constructor.name_hint == 'z' - return 0 - # this is an example of creating the adt value in python side def make_nat(n): if n != 0: @@ -83,7 +78,7 @@ def make_nat(n): else: return ConstructorValue(z, [], []) -def build_nat(n): +def make_nat_expr(n): assert n >= 0 ret = z() while n > 0: @@ -115,8 +110,14 @@ def tree_to_dict(t): ret['children'].append(l) return ret + +# turns a scalar-valued relay tensor value into a python number +def get_scalar(tv): + return tv.asnumpy().item() + + def test_nat_value(): - assert count(make_nat(10)) == 10 + assert count(make_nat_value(p, 10)) == 10 assert count(intrp.evaluate(s(s(z())))) == 2 @@ -145,7 +146,7 @@ def test_hd_tl(): expected = list(range(10)) l = nil() for i in reversed(expected): - l = cons(build_nat(i), l) + l = cons(make_nat_expr(i), l) got = [] for i in range(len(expected)): @@ -158,36 +159,35 @@ def test_nth(): expected = list(range(10)) l = nil() for i in reversed(expected): - l = cons(build_nat(i), l) + l = cons(relay.const(i), l) - got = [] for i in range(len(expected)): - got.append(count(intrp.evaluate(nth(l, build_nat(i))))) + item = intrp.evaluate(nth(l, relay.const(i))) + assert get_scalar(item) == i - assert got == expected def test_update(): expected = list(range(10)) l = nil() # create zero initialized list for i in range(len(expected)): - l = cons(build_nat(0), l) + l = cons(make_nat_expr(0), l) # set value for i, v in enumerate(expected): - l = update(l, build_nat(i), build_nat(v)) + l = update(l, relay.const(i), make_nat_expr(v)) got = [] for i in range(len(expected)): - got.append(count(intrp.evaluate(nth(l, build_nat(i))))) + got.append(count(intrp.evaluate(nth(l, relay.const(i))))) assert got == expected def test_length(): a = relay.TypeVar("a") - assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a]) + assert mod[length].checked_type == relay.FuncType([l(a)], relay.scalar_type('int32'), [a]) res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil()))))) - assert count(res) == 3 + assert get_scalar(res) == 3 def test_map(): @@ -216,9 +216,9 @@ def test_foldl(): y = relay.Var("y") rev_dup = relay.Function([y, x], cons(x, cons(x, y))) res = intrp.evaluate(foldl(rev_dup, nil(), - cons(build_nat(1), - cons(build_nat(2), - cons(build_nat(3), nil()))))) + cons(make_nat_expr(1), + cons(make_nat_expr(2), + cons(make_nat_expr(3), nil()))))) reversed = to_list(res) assert len(reversed) == 6 assert count(reversed[0]) == 3 and count(reversed[1]) == 3 @@ -237,9 +237,9 @@ def test_foldr(): y = relay.Var("y") identity = relay.Function([x, y], cons(x, y)) res = intrp.evaluate(foldr(identity, nil(), - cons(build_nat(1), - cons(build_nat(2), - cons(build_nat(3), nil()))))) + cons(make_nat_expr(1), + cons(make_nat_expr(2), + cons(make_nat_expr(3), nil()))))) same = to_list(res) assert len(same) == 3 assert count(same[0]) == 1 and count(same[1]) == 2 and count(same[2]) == 3 @@ -255,25 +255,25 @@ def test_foldr1(): y = relay.Var("y") f = relay.Function([x, y], add(x, y)) res = intrp.evaluate(foldr1(f, - cons(build_nat(1), - cons(build_nat(2), - cons(build_nat(3), nil()))))) + cons(make_nat_expr(1), + cons(make_nat_expr(2), + cons(make_nat_expr(3), nil()))))) assert count(res) == 6 def test_sum(): - assert mod[sum].checked_type == relay.FuncType([l(nat())], nat()) - res = intrp.evaluate(sum(cons(build_nat(1), cons(build_nat(2), nil())))) - assert count(res) == 3 + assert mod[sum].checked_type == relay.FuncType([l(relay.scalar_type('int32'))], relay.scalar_type('int32')) + res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil())))) + assert get_scalar(res) == 3 def test_concat(): a = relay.TypeVar("a") assert mod[concat].checked_type == relay.FuncType([l(a), l(a)], l(a), [a]) - l1 = cons(build_nat(1), cons(build_nat(2), nil())) - l2 = cons(build_nat(3), cons(build_nat(4), nil())) + l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), nil())) + l2 = cons(make_nat_expr(3), cons(make_nat_expr(4), nil())) res = intrp.evaluate(concat(l1, l2)) catted = to_list(res) @@ -305,12 +305,12 @@ def test_filter(): ])) res = intrp.evaluate( filter(greater_than_one, - cons(build_nat(1), - cons(build_nat(1), - cons(build_nat(3), - cons(build_nat(1), - cons(build_nat(5), - cons(build_nat(1), + cons(make_nat_expr(1), + cons(make_nat_expr(1), + cons(make_nat_expr(3), + cons(make_nat_expr(1), + cons(make_nat_expr(5), + cons(make_nat_expr(1), nil())))))))) filtered = to_list(res) assert len(filtered) == 2 @@ -325,7 +325,7 @@ def test_zip(): l(relay.TupleType([a, b])), [a, b]) assert mod[zip].checked_type == expected_type - l1 = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))) + l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))) l2 = cons(nil(), cons(cons(nil(), nil()), cons(cons(nil(), cons(nil(), nil())), @@ -342,7 +342,7 @@ def test_zip(): assert len(to_list(zipped[2][1])) == 2 # test truncation - l3 = cons(build_nat(4), cons(build_nat(5), nil())) + l3 = cons(make_nat_expr(4), cons(make_nat_expr(5), nil())) shorter_res = intrp.evaluate(zip(l3, l2)) truncated = to_list(shorter_res) assert len(truncated) == 2 @@ -363,9 +363,9 @@ def test_rev(): a = relay.TypeVar("a") assert mod[rev].checked_type == relay.FuncType([l(a)], l(a), [a]) - res = intrp.evaluate(rev(cons(build_nat(1), - cons(build_nat(2), - cons(build_nat(3), nil()))))) + res = intrp.evaluate(rev(cons(make_nat_expr(1), + cons(make_nat_expr(2), + cons(make_nat_expr(3), nil()))))) reversed = to_list(res) assert len(reversed) == 3 @@ -392,7 +392,7 @@ def test_unfoldr(): relay.Clause(relay.PatternConstructor(z, []), none()) ])) - res = intrp.evaluate(unfoldr(count_down, build_nat(3))) + res = intrp.evaluate(unfoldr(count_down, make_nat_expr(3))) unfolded = to_list(res) assert len(unfolded) == 3 @@ -419,7 +419,7 @@ def test_unfoldl(): relay.Clause(relay.PatternConstructor(z, []), none()) ])) - res = intrp.evaluate(unfoldl(count_down, build_nat(3))) + res = intrp.evaluate(unfoldl(count_down, make_nat_expr(3))) unfolded = to_list(res) assert len(unfolded) == 3 @@ -444,7 +444,7 @@ def test_map_accumr(): relay.Tuple([add(x, acc), add(x, acc)])) - vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))) + vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))) res = intrp.evaluate(map_accumr(add_acc_to_each, z(), vals)) sum = count(res[0]) @@ -472,7 +472,7 @@ def test_map_accuml(): add_to_acc = relay.Function([acc, x], relay.Tuple([add(x, acc), x])) - vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil()))) + vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil()))) res = intrp.evaluate(map_accuml(add_to_acc, z(), vals)) sum = count(res[0]) @@ -497,8 +497,8 @@ def test_optional_matching(): ])) res = intrp.evaluate(foldr(condense, nil(), cons( - some(build_nat(3)), - cons(none(), cons(some(build_nat(1)), nil()))))) + some(make_nat_expr(3)), + cons(none(), cons(some(make_nat_expr(1)), nil()))))) reduced = to_list(res) assert len(reduced) == 2 @@ -532,7 +532,7 @@ def test_tmap(): def test_size(): a = relay.TypeVar("a") lhs = mod[size].checked_type - rhs = relay.FuncType([tree(a)], nat(), [a]) + rhs = relay.FuncType([tree(a)], relay.scalar_type('int32'), [a]) assert lhs == rhs root = rose(z(), cons(rose(z(), nil()), @@ -540,7 +540,7 @@ def test_size(): nil()))) t = rose(z(), cons(root, cons(root, cons(root, nil())))) res = intrp.evaluate(size(t)) - assert count(res) == 10 + assert get_scalar(res) == 10 def test_wildcard_match_solo(): @@ -601,10 +601,10 @@ def test_nested_matches(): inner_match) ]), l(a), [a]) - first_list = cons(build_nat(1), cons(build_nat(2), - cons(build_nat(3), nil()))) - second_list = cons(build_nat(4), cons(build_nat(5), - cons(build_nat(6), nil()))) + first_list = cons(make_nat_expr(1), cons(make_nat_expr(2), + cons(make_nat_expr(3), nil()))) + second_list = cons(make_nat_expr(4), cons(make_nat_expr(5), + cons(make_nat_expr(6), nil()))) final_list = cons(first_list, cons(second_list, nil())) res = intrp.evaluate(flatten(final_list)) @@ -660,6 +660,7 @@ def test_nested_pattern_match(): assert count(res) == 2 + def test_compose(): n = relay.Var('n') inc = relay.Function([n], s(n)) @@ -667,11 +668,13 @@ def test_compose(): res = intrp.evaluate(relay.Call(compose(inc, double), [s(s(z()))])) assert count(res) == 5 + def test_iterate(): - expr = relay.Call(iterate(double, build_nat(2)), [build_nat(3)]) + expr = relay.Call(iterate(double, relay.const(2)), [make_nat_expr(3)]) res = intrp.evaluate(relay.Function([], expr)()) assert count(res) == 12 + if __name__ == "__main__": test_nat_constructor() test_double() diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py index e69f839..3cf73ae 100644 --- a/tests/python/relay/test_ir_well_formed.py +++ b/tests/python/relay/test_ir_well_formed.py @@ -53,10 +53,12 @@ def test_adt(): mod = relay.Module() p = Prelude(mod) x = relay.Var("x") - s_case = relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), x) + some_case = relay.Clause(relay.PatternConstructor(p.some, + [relay.PatternVar(x)]), + x) default_case = relay.Clause(relay.PatternVar(x), x) - m0 = relay.Match(p.z(), [default_case]) - m1 = relay.Match(p.z(), [s_case, default_case]) + m0 = relay.Match(p.none(), [default_case]) + m1 = relay.Match(p.none(), [some_case, default_case]) assert well_formed(m0) assert not well_formed(m1) diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index f00dc85..478b433 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -521,7 +521,7 @@ def test_match_alpha_equal(): relay.PatternVar(a)]), p.cons(z, a)) - data = p.cons(p.z(), p.cons(p.z(), p.nil())) + data = p.cons(relay.const(1), p.cons(relay.const(2), p.nil())) match = relay.Match(data, [nil_case, cons_case]) equivalent = relay.Match(data, [nil_case, equivalent_cons]) @@ -547,8 +547,8 @@ def test_match_alpha_equal(): relay.Clause(relay.PatternWildcard(), p.nil()) ]) wrong_constructors = relay.Match(data, [ - relay.Clause(relay.PatternConstructor(p.z), p.nil()), - relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), + relay.Clause(relay.PatternConstructor(p.none), p.nil()), + relay.Clause(relay.PatternConstructor(p.some, [relay.PatternVar(x)]), p.cons(x, p.nil())) ]) diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index f5968a4..d99bee5 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -19,6 +19,7 @@ from tvm import relay from tvm.relay.ir_pass import free_vars, free_type_vars, gradient from tvm.relay import create_executor from tvm.relay.prelude import Prelude +from tvm.relay.testing import add_nat_definitions, make_nat_expr import numpy as np @@ -174,13 +175,14 @@ def test_tuple(): def test_pow(): mod = relay.Module() p = Prelude(mod) + add_nat_definitions(p) shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) x = relay.var("x", t) double = relay.Function([x], x + x) i = relay.var("i", t) - func = relay.Function([i], relay.Call(p.iterate(double, p.s(p.s(p.s(p.z())))), [i])) + func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i)) back_func = relay.ir_pass.infer_type(gradient(func, mod=mod), mod=mod) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) i_nd = rand(dtype, *shape) diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 2e95dbe..f395580 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -21,6 +21,7 @@ from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type from tvm.relay import op, create_executor from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue from tvm.relay.prelude import Prelude +from tvm.relay.testing import add_nat_definitions, count def check_eval(expr, expected_result, mod=None, rtol=1e-07): @@ -130,19 +131,10 @@ def test_ref(): check_eval(to_a_normal_form(body), 3) -# this is an example of using the adt value in python side -def count(n): - assert isinstance(n, ConstructorValue) - if n.constructor.name_hint == 's': - return 1 + count(n.fields[0]) - else: - assert n.constructor.name_hint == 'z' - return 0 - - -def test_add(): +def test_nat_add(): mod = relay.Module() p = Prelude(mod) + add_nat_definitions(p) nat = p.nat add = p.add s = p.s @@ -183,4 +175,5 @@ if __name__ == '__main__': test_ref() test_add() test_let() + test_nat_add() test_function() -- 2.7.4