[Relay][Prelude] Remove Peano nats from the prelude (#3045)
authorSteven S. Lyubomirsky <sslyu@cs.washington.edu>
Wed, 22 May 2019 20:57:53 +0000 (13:57 -0700)
committerTianqi Chen <tqchen@users.noreply.github.com>
Wed, 22 May 2019 20:57:53 +0000 (13:57 -0700)
python/tvm/relay/prelude.py
python/tvm/relay/testing/__init__.py
python/tvm/relay/testing/nat.py [new file with mode: 0644]
tests/python/relay/test_adt.py
tests/python/relay/test_ir_well_formed.py
tests/python/relay/test_pass_alpha_equal.py
tests/python/relay/test_pass_gradient.py
tests/python/relay/test_pass_to_a_normal_form.py

index ff823c3..92647e5 100644 (file)
@@ -17,7 +17,8 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """Adds certain standard global functions and ADT definitions to the module."""
 from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type
-from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem
+from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
+from .op.tensor import add, subtract, equal
 from .adt import Constructor, TypeData, Clause, Match
 from .adt import PatternConstructor, PatternVar, PatternWildcard
 
@@ -34,6 +35,7 @@ class Prelude:
         self.cons = Constructor("cons", [a, self.l(a)], self.l)
         self.mod[self.l] = TypeData(self.l, [a], [self.nil, self.cons])
 
+
     def define_list_hd(self):
         """Defines a function to get the head of a list. Assume the list has at least one
         element.
@@ -48,6 +50,7 @@ class Prelude:
         cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), y)
         self.mod[self.hd] = Function([x], Match(x, [cons_case]), a, [a])
 
+
     def define_list_tl(self):
         """Defines a function to get the tail of a list.
 
@@ -61,39 +64,44 @@ class Prelude:
         cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z)
         self.mod[self.tl] = Function([x], Match(x, [cons_case]), self.l(a), [a])
 
+
     def define_list_nth(self):
         """Defines a function to get the nth element of a list.
 
-        nth(l) : list[a] -> a
+        nth(l) : list[a] -> Tensor[(), int32] -> a
         """
         self.nth = GlobalVar("nth")
         a = TypeVar("a")
         x = Var("x", self.l(a))
-        n = Var("n", self.nat())
+        n = Var("n", scalar_type('int32'))
+
+        body = If(equal(n, const(0)),
+                  self.hd(x),
+                  self.nth(self.tl(x), subtract(n, const(1))))
+
+        self.mod[self.nth] = Function([x, n], body, a, [a])
 
-        y = Var("y")
-        z_case = Clause(PatternConstructor(self.z), self.hd(x))
-        s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y))
-        self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a])
 
     def define_list_update(self):
         """Defines a function to update the nth element of a list and return the updated list.
 
-        update(l, i, v) : list[a] -> nat -> a -> list[a]
+        update(l, i, v) : list[a] -> Tensor[(), int32] -> a -> list[a]
         """
         self.update = GlobalVar("update")
         a = TypeVar("a")
         l = Var("l", self.l(a))
-        n = Var("n", self.nat())
+        n = Var("n", scalar_type('int32'))
         v = Var("v", a)
 
-        y = Var("y")
+        body = If(equal(n, const(0)),
+                  self.cons(v, self.tl(l)),
+                  self.cons(self.hd(l),
+                            self.update(self.tl(l),
+                                        subtract(n, const(1)),
+                                        v)))
 
-        z_case = Clause(PatternConstructor(self.z), self.cons(v, self.tl(l)))
-        s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
-                        self.cons(self.hd(l), self.update(self.tl(l), y, v)))
+        self.mod[self.update] = Function([l, n, v], body, self.l(a), [a])
 
-        self.mod[self.update] = Function([l, n, v], Match(n, [z_case, s_case]), self.l(a), [a])
 
     def define_list_map(self):
         """Defines a function for mapping a function over a list's
@@ -114,6 +122,7 @@ class Prelude:
                            self.cons(f(y), self.map(f, z)))
         self.mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), self.l(b), [a, b])
 
+
     def define_list_foldl(self):
         """Defines a left-way fold over a list.
 
@@ -136,6 +145,7 @@ class Prelude:
         self.mod[self.foldl] = Function([f, av, bv],
                                         Match(bv, [nil_case, cons_case]), a, [a, b])
 
+
     def define_list_foldr(self):
         """Defines a right-way fold over a list.
 
@@ -158,6 +168,7 @@ class Prelude:
         self.mod[self.foldr] = Function([f, bv, av],
                                         Match(av, [nil_case, cons_case]), b, [a, b])
 
+
     def define_list_foldr1(self):
         """Defines a right-way fold over a nonempty list.
 
@@ -196,6 +207,7 @@ class Prelude:
                                          self.foldr(updater, l2, l1),
                                          self.l(a), [a])
 
+
     def define_list_filter(self):
         """Defines a function that filters a list.
 
@@ -214,6 +226,7 @@ class Prelude:
                            If(f(h), self.cons(h, self.filter(f, t)), self.filter(f, t)))
         self.mod[self.filter] = Function([f, l], Match(l, [nil_case, cons_case]), self.l(a), [a])
 
+
     def define_list_zip(self):
         """Defines a function that combines two lists into a list of tuples of their elements.
 
@@ -238,6 +251,7 @@ class Prelude:
         self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]),
                                       self.l(TupleType([a, b])), [a, b])
 
+
     def define_list_rev(self):
         """Defines a function that reverses a list.
 
@@ -253,6 +267,7 @@ class Prelude:
                                       self.foldl(updater, self.nil(), l),
                                       self.l(a), [a])
 
+
     def define_list_map_accumr(self):
         """Defines an accumulative map, which is a fold that simulataneously updates
         an accumulator value and a list of results.
@@ -282,6 +297,7 @@ class Prelude:
                                              TupleType([a, self.l(c)]),
                                              [a, b, c])
 
+
     def define_list_map_accuml(self):
         """Defines an accumulative map, which is a fold that simulataneously updates
         an accumulator value and a list of results.
@@ -321,6 +337,7 @@ class Prelude:
         self.none = Constructor("none", [], self.optional)
         self.mod[self.optional] = TypeData(self.optional, [a], [self.some, self.none])
 
+
     def define_list_unfoldr(self):
         """Defines a function that builds up a list starting from a seed value.
 
@@ -343,6 +360,7 @@ class Prelude:
         self.mod[self.unfoldr] = Function([f, s], Match(f(s), [none_case, some_case]),
                                           self.l(b), [a, b])
 
+
     def define_list_unfoldl(self):
         """Defines a function that builds up a list starting from a seed value.
 
@@ -362,52 +380,29 @@ class Prelude:
                                           self.rev(self.unfoldr(f, s)),
                                           self.l(b), [a, b])
 
-    def define_nat_adt(self):
-        """Defines a Peano (unary) natural number ADT.
-        Zero is represented by z(). s(n) adds 1 to a nat n."""
-        self.nat = GlobalTypeVar("nat")
-        self.z = Constructor("z", [], self.nat)
-        self.s = Constructor("s", [self.nat()], self.nat)
-        self.mod[self.nat] = TypeData(self.nat, [], [self.z, self.s])
-
-    def define_nat_double(self):
-        """Defines a function that doubles a nat."""
-        self.double = GlobalVar("double")
-        x = Var("x", self.nat())
-        y = Var("y")
-        z_case = Clause(PatternConstructor(self.z), self.z())
-        s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
-                        self.s(self.s(self.double(y))))
-        self.mod[self.double] = Function([x], Match(x, [z_case, s_case]))
-
-    def define_nat_add(self):
-        """Defines a function that adds two nats."""
-        self.add = GlobalVar("add")
-        x = Var("x", self.nat())
-        y = Var("y", self.nat())
-        a = Var("a")
-        z_case = Clause(PatternConstructor(self.z), y)
-        s_case = Clause(PatternConstructor(self.s, [PatternVar(a)]),
-                        self.s(self.add(a, y)))
-        self.mod[self.add] = Function([x, y], Match(x, [z_case, s_case]))
 
     def define_list_sum(self):
-        """Defines a function that computes the sum of a list of nats."""
+        """Defines a function that computes the sum of a list of integer scalars."""
         self.sum = GlobalVar("sum")
-        a = Var("a", self.l(self.nat()))
-        self.mod[self.sum] = Function([a], self.foldl(self.add, self.z(), a))
+        a = Var("a", self.l(scalar_type('int32')))
+        x = Var('x')
+        y = Var('y')
+        addf = Function([x, y], add(x, y))
+        self.mod[self.sum] = Function([a], self.foldl(addf, const(0), a))
+
 
     def define_list_length(self):
-        """Defines a function that returns the length of a list as a nat"""
+        """Defines a function that returns the length of a list"""
         self.length = GlobalVar("length")
         a = TypeVar("a")
         x = Var("x", self.l(a))
         y = Var("y")
-        nil_case = Clause(PatternConstructor(self.nil), self.z())
+        nil_case = Clause(PatternConstructor(self.nil), const(0))
         cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]),
-                           self.s(self.length(y)))
+                           add(const(1), self.length(y)))
         self.mod[self.length] = Function([x],
-                                         Match(x, [nil_case, cons_case]), None, [a])
+                                         Match(x, [nil_case, cons_case]), scalar_type('int32'), [a])
+
 
     def define_tree_adt(self):
         """Defines a tree ADT. A tree can contain any type.
@@ -420,6 +415,7 @@ class Prelude:
         self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree)
         self.mod[self.tree] = TypeData(self.tree, [a], [self.rose])
 
+
     def define_tree_map(self):
         """Defines a function that maps over a tree. The function
         is applied to each subtree's contents.
@@ -439,23 +435,24 @@ class Prelude:
         self.mod[self.tmap] = Function([f, t],
                                        Match(t, [rose_case]), self.tree(b), [a, b])
 
+
     def define_tree_size(self):
-        """Defines a function that computes the size of a tree as a nat.
+        """Defines a function that computes the size of a tree.
 
-        Signature: fn<a>(t : tree[a]) -> nat
+        Signature: fn<a>(t : tree[a]) -> Tensor[(), int32]
         """
         self.size = GlobalVar("size")
         a = TypeVar("a")
         t = Var("t", self.tree(a))
-        x = Var("x", self.tree(a))
         z = Var("z")
         rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]),
-                           self.s(self.sum(self.map(Function([x], self.size(x)), z))))
+                           add(const(1), self.sum(self.map(self.size, z))))
         self.mod[self.size] = Function([t],
-                                       Match(t, [rose_case]), self.nat(), [a])
+                                       Match(t, [rose_case]), scalar_type('int32'), [a])
+
 
     def define_id(self):
-        """Defines a function that return it's argument.
+        """Defines a function that return its argument.
 
         Signature: fn<a>(x : a) -> a
         """
@@ -466,7 +463,7 @@ class Prelude:
 
 
     def define_compose(self):
-        """Defines a function that compose two function.
+        """Defines a function that composes two function.
 
         Signature: fn<a, b, c>(f : fn(b) -> c, g : fn(a) -> b) -> fn(a) -> c
         """
@@ -484,24 +481,26 @@ class Prelude:
 
 
     def define_iterate(self):
-        """Define a function that take a number n, a function f,
-        and return a closure that apply f n time on it's argument.
+        """Defines a function that take a number n and a function f;
+        returns a closure that takes an argument and applies f
+        n times to its argument.
 
-        Signature: fn<a>(n : nat, f : fn(a) -> a) -> fn(a) -> a
+        Signature: fn<a>(f : fn(a) -> a, n : Tensor[(), int32]) -> fn(a) -> a
         """
         self.iterate = GlobalVar("iterate")
         a = TypeVar("a")
         f = Var("f", FuncType([a], a))
-        x = Var("x", self.nat())
-        y = Var("y", self.nat())
-        z_case = Clause(PatternConstructor(self.z), self.id)
-        s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]),
-                        self.compose(f, self.iterate(f, y)))
+        x = Var("x", scalar_type('int32'))
+        body = If(equal(x, const(0)),
+                  self.id,
+                  self.compose(f,
+                               self.iterate(f, subtract(x, const(1)))))
         self.mod[self.iterate] = Function([f, x],
-                                          Match(x, [z_case, s_case]),
+                                          body,
                                           FuncType([a], a),
                                           [a])
 
+
     def __init__(self, mod):
         self.mod = mod
         self.define_list_adt()
@@ -522,9 +521,6 @@ class Prelude:
         self.define_list_unfoldr()
         self.define_list_unfoldl()
 
-        self.define_nat_adt()
-        self.define_nat_double()
-        self.define_nat_add()
         self.define_list_length()
         self.define_list_nth()
         self.define_list_update()
index b4a8394..192afe1 100644 (file)
@@ -30,3 +30,4 @@ from . import densenet
 
 from .config import ctx_list
 from .init import create_workload
+from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py
new file mode 100644 (file)
index 0000000..4c0c87c
--- /dev/null
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Defines a unary natural number (Peano natural number) abstract
+data type for Relay and provides some utility functions for it.
+Nats are useful for testing purposes, as they make it easy to write
+test cases for recursion and pattern matching."""
+
+from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar
+from tvm.relay.backend.interpreter import ConstructorValue
+from tvm.relay.expr import Var, Function, GlobalVar
+from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType
+
+def define_nat_adt(prelude):
+    """Defines a Peano (unary) natural number ADT.
+    Zero is represented by z(). s(n) adds 1 to a nat n.
+    Adds the fields nat, z, and s to the preluide, representing
+    (respectively) the nat ADT and the z and s constructors.
+    """
+    prelude.nat = GlobalTypeVar("nat")
+    prelude.z = Constructor("z", [], prelude.nat)
+    prelude.s = Constructor("s", [prelude.nat()], prelude.nat)
+    prelude.mod[prelude.nat] = TypeData(prelude.nat, [], [prelude.z, prelude.s])
+
+
+def define_nat_double(prelude):
+    """Defines a function that doubles a nat. Adds a field called
+    'double' to the prelude, giving the GlobalVar pointing to
+    the function.
+    """
+    prelude.double = GlobalVar("double")
+    x = Var("x", prelude.nat())
+    y = Var("y")
+    z_case = Clause(PatternConstructor(prelude.z), prelude.z())
+    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
+                    prelude.s(prelude.s(prelude.double(y))))
+    prelude.mod[prelude.double] = Function([x], Match(x, [z_case, s_case]))
+
+
+def define_nat_add(prelude):
+    """Defines a function that adds two nats and adds a field to the
+    prelude 'add' giving the GlobalVar pointing to that function.
+    """
+    prelude.add = GlobalVar("add")
+    x = Var("x", prelude.nat())
+    y = Var("y", prelude.nat())
+    a = Var("a")
+    z_case = Clause(PatternConstructor(prelude.z), y)
+    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(a)]),
+                    prelude.s(prelude.add(a, y)))
+    prelude.mod[prelude.add] = Function([x, y], Match(x, [z_case, s_case]))
+
+
+# versions of prelude functions that use nats instead of scalars
+
+def define_nat_nth(prelude):
+    """Defines a function to get the nth eleemnt of a list using
+    a nat to index into the list.
+
+    nat_nth(l, n): fun<a>(list[a], nat) -> a
+    """
+    prelude.nat_nth = GlobalVar("nat_nth")
+    a = TypeVar("a")
+    x = Var("x", prelude.l(a))
+    n = Var("n", prelude.nat())
+    y = Var("y")
+
+    z_case = Clause(PatternConstructor(prelude.z), prelude.hd(x))
+    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
+                    prelude.nat_nth(prelude.tl(x), y))
+
+    prelude.mod[prelude.nat_nth] = Function([x, n],
+                                            Match(n, [z_case, s_case]),
+                                            a, [a])
+
+
+def define_nat_update(prelude):
+    """Defines a function to update the nth element of a list and return the updated list.
+
+    nat_update(l, i, v) : fun<a>(list[a], nat, a) -> list[a]
+    """
+    prelude.nat_update = GlobalVar("nat_update")
+    a = TypeVar("a")
+    # pylint: disable=invalid-name
+    l = Var("l", prelude.l(a))
+    n = Var("n", prelude.nat())
+    v = Var("v", a)
+    y = Var("y")
+
+    z_case = Clause(PatternConstructor(prelude.z),
+                    prelude.cons(v, prelude.tl(l)))
+    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
+                    prelude.cons(
+                        prelude.hd(l),
+                        prelude.nat_update(prelude.tl(l), y, v)))
+
+    prelude.mod[prelude.nat_update] = Function([l, n, v],
+                                               Match(n, [z_case, s_case]),
+                                               prelude.l(a), [a])
+
+
+def define_nat_iterate(prelude):
+    """Defines a function that takes a number n and a function f;
+    returns a closure that takes an argument and applies f
+    n times to its argument.
+
+    Signature: fn<a>(fn(a) -> a, nat) -> fn(a) -> a
+    """
+    prelude.nat_iterate = GlobalVar("nat_iterate")
+    a = TypeVar("a")
+    f = Var("f", FuncType([a], a))
+    x = Var("x", prelude.nat())
+    y = Var("y", prelude.nat())
+
+    z_case = Clause(PatternConstructor(prelude.z), prelude.id)
+    s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]),
+                    prelude.compose(f, prelude.nat_iterate(f, y)))
+
+    prelude.mod[prelude.nat_iterate] = Function([f, x],
+                                                Match(x, [z_case, s_case]),
+                                                FuncType([a], a),
+                                                [a])
+
+
+def add_nat_definitions(prelude):
+    """Given a Relay prelude, adds a Peano nat ADT, as well as functions
+    for adding nats and doubling nats. It also adds versions of
+    update, nth, and iterate that take nats instead of scalars (the
+    names are prefixed with 'nat_')."""
+    define_nat_adt(prelude)
+    define_nat_double(prelude)
+    define_nat_add(prelude)
+    define_nat_nth(prelude)
+    define_nat_update(prelude)
+    define_nat_iterate(prelude)
+
+
+# helper functions for working with nats
+
+
+def count(n):
+    """Takes a ConstructorValue corresponding to a nat ADT
+    and converts it into a Python integer. This is an example of
+    using an ADT value in Python.
+    """
+    assert isinstance(n, ConstructorValue)
+    if n.constructor.name_hint == 'z':
+        return 0
+    assert n.constructor.name_hint == 's'
+    return 1 + count(n.fields[0])
+
+
+def make_nat_value(prelude, n):
+    """The inverse of count(): Given a non-negative Python integer,
+    constructs a ConstructorValue representing that value as a nat.
+    """
+    if n == 0:
+        return ConstructorValue(prelude.z, [], [])
+    return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], [])
+
+
+def make_nat_expr(prelude, n):
+    """Given a non-negative Python integer, constructs a Python
+    expression representing that integer's value as a nat.
+    """
+    assert n >= 0
+    ret = prelude.z()
+    while n > 0:
+        ret = prelude.s(ret)
+        n = n - 1
+    return ret
index 58ab0c4..77f4ab1 100644 (file)
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy as np
 import tvm
 from tvm import relay
 from tvm.relay.ir_pass import infer_type
 from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
 from tvm.relay import testing, create_executor
 from tvm.relay.prelude import Prelude
+from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr
 
 mod = relay.Module()
 p = Prelude(mod)
+add_nat_definitions(p)
+
 ctx = tvm.context("llvm", 0)
 intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
 
@@ -67,15 +71,6 @@ size = p.size
 compose = p.compose
 iterate = p.iterate
 
-# this is an example of using the adt value in python side
-def count(n):
-    assert isinstance(n, ConstructorValue)
-    if n.constructor.name_hint == 's':
-        return 1 + count(n.fields[0])
-    else:
-        assert n.constructor.name_hint == 'z'
-        return 0
-
 # this is an example of creating the adt value in python side
 def make_nat(n):
     if n != 0:
@@ -83,7 +78,7 @@ def make_nat(n):
     else:
         return ConstructorValue(z, [], [])
 
-def build_nat(n):
+def make_nat_expr(n):
     assert n >= 0
     ret = z()
     while n > 0:
@@ -115,8 +110,14 @@ def tree_to_dict(t):
         ret['children'].append(l)
     return ret
 
+
+# turns a scalar-valued relay tensor value into a python number
+def get_scalar(tv):
+    return tv.asnumpy().item()
+
+
 def test_nat_value():
-    assert count(make_nat(10)) == 10
+    assert count(make_nat_value(p, 10)) == 10
     assert count(intrp.evaluate(s(s(z())))) == 2
 
 
@@ -145,7 +146,7 @@ def test_hd_tl():
     expected = list(range(10))
     l = nil()
     for i in reversed(expected):
-        l = cons(build_nat(i), l)
+        l = cons(make_nat_expr(i), l)
 
     got = []
     for i in range(len(expected)):
@@ -158,36 +159,35 @@ def test_nth():
     expected = list(range(10))
     l = nil()
     for i in reversed(expected):
-        l = cons(build_nat(i), l)
+        l = cons(relay.const(i), l)
 
-    got = []
     for i in range(len(expected)):
-        got.append(count(intrp.evaluate(nth(l, build_nat(i)))))
+        item = intrp.evaluate(nth(l, relay.const(i)))
+        assert get_scalar(item) == i
 
-    assert got == expected
 
 def test_update():
     expected = list(range(10))
     l = nil()
     # create zero initialized list
     for i in range(len(expected)):
-        l = cons(build_nat(0), l)
+        l = cons(make_nat_expr(0), l)
 
     # set value
     for i, v in enumerate(expected):
-        l = update(l, build_nat(i), build_nat(v))
+        l = update(l, relay.const(i), make_nat_expr(v))
 
     got = []
     for i in range(len(expected)):
-        got.append(count(intrp.evaluate(nth(l, build_nat(i)))))
+        got.append(count(intrp.evaluate(nth(l, relay.const(i)))))
 
     assert got == expected
 
 def test_length():
     a = relay.TypeVar("a")
-    assert mod[length].checked_type == relay.FuncType([l(a)], nat(), [a])
+    assert mod[length].checked_type == relay.FuncType([l(a)], relay.scalar_type('int32'), [a])
     res = intrp.evaluate(length(cons(z(), cons(z(), cons(z(), nil())))))
-    assert count(res) == 3
+    assert get_scalar(res) == 3
 
 
 def test_map():
@@ -216,9 +216,9 @@ def test_foldl():
     y = relay.Var("y")
     rev_dup = relay.Function([y, x], cons(x, cons(x, y)))
     res = intrp.evaluate(foldl(rev_dup, nil(),
-                               cons(build_nat(1),
-                                    cons(build_nat(2),
-                                         cons(build_nat(3), nil())))))
+                               cons(make_nat_expr(1),
+                                    cons(make_nat_expr(2),
+                                         cons(make_nat_expr(3), nil())))))
     reversed = to_list(res)
     assert len(reversed) == 6
     assert count(reversed[0]) == 3 and count(reversed[1]) == 3
@@ -237,9 +237,9 @@ def test_foldr():
     y = relay.Var("y")
     identity = relay.Function([x, y], cons(x, y))
     res = intrp.evaluate(foldr(identity, nil(),
-                               cons(build_nat(1),
-                                    cons(build_nat(2),
-                                         cons(build_nat(3), nil())))))
+                               cons(make_nat_expr(1),
+                                    cons(make_nat_expr(2),
+                                         cons(make_nat_expr(3), nil())))))
     same = to_list(res)
     assert len(same) == 3
     assert count(same[0]) == 1 and count(same[1]) == 2 and count(same[2]) == 3
@@ -255,25 +255,25 @@ def test_foldr1():
     y = relay.Var("y")
     f = relay.Function([x, y], add(x, y))
     res = intrp.evaluate(foldr1(f,
-                                cons(build_nat(1),
-                                    cons(build_nat(2),
-                                         cons(build_nat(3), nil())))))
+                                cons(make_nat_expr(1),
+                                    cons(make_nat_expr(2),
+                                         cons(make_nat_expr(3), nil())))))
 
     assert count(res) == 6
 
 
 def test_sum():
-    assert mod[sum].checked_type == relay.FuncType([l(nat())], nat())
-    res = intrp.evaluate(sum(cons(build_nat(1), cons(build_nat(2), nil()))))
-    assert count(res) == 3
+    assert mod[sum].checked_type == relay.FuncType([l(relay.scalar_type('int32'))], relay.scalar_type('int32'))
+    res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil()))))
+    assert get_scalar(res) == 3
 
 
 def test_concat():
     a = relay.TypeVar("a")
     assert mod[concat].checked_type == relay.FuncType([l(a), l(a)], l(a), [a])
 
-    l1 = cons(build_nat(1), cons(build_nat(2), nil()))
-    l2 = cons(build_nat(3), cons(build_nat(4), nil()))
+    l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), nil()))
+    l2 = cons(make_nat_expr(3), cons(make_nat_expr(4), nil()))
     res = intrp.evaluate(concat(l1, l2))
 
     catted = to_list(res)
@@ -305,12 +305,12 @@ def test_filter():
         ]))
     res = intrp.evaluate(
         filter(greater_than_one,
-               cons(build_nat(1),
-                    cons(build_nat(1),
-                         cons(build_nat(3),
-                              cons(build_nat(1),
-                                   cons(build_nat(5),
-                                        cons(build_nat(1),
+               cons(make_nat_expr(1),
+                    cons(make_nat_expr(1),
+                         cons(make_nat_expr(3),
+                              cons(make_nat_expr(1),
+                                   cons(make_nat_expr(5),
+                                        cons(make_nat_expr(1),
                                              nil()))))))))
     filtered = to_list(res)
     assert len(filtered) == 2
@@ -325,7 +325,7 @@ def test_zip():
                                    l(relay.TupleType([a, b])), [a, b])
     assert mod[zip].checked_type == expected_type
 
-    l1 = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil())))
+    l1 = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
     l2 = cons(nil(),
               cons(cons(nil(), nil()),
                    cons(cons(nil(), cons(nil(), nil())),
@@ -342,7 +342,7 @@ def test_zip():
     assert len(to_list(zipped[2][1])) == 2
 
     # test truncation
-    l3 = cons(build_nat(4), cons(build_nat(5), nil()))
+    l3 = cons(make_nat_expr(4), cons(make_nat_expr(5), nil()))
     shorter_res = intrp.evaluate(zip(l3, l2))
     truncated = to_list(shorter_res)
     assert len(truncated) == 2
@@ -363,9 +363,9 @@ def test_rev():
     a = relay.TypeVar("a")
     assert mod[rev].checked_type == relay.FuncType([l(a)], l(a), [a])
 
-    res = intrp.evaluate(rev(cons(build_nat(1),
-                                  cons(build_nat(2),
-                                       cons(build_nat(3), nil())))))
+    res = intrp.evaluate(rev(cons(make_nat_expr(1),
+                                  cons(make_nat_expr(2),
+                                       cons(make_nat_expr(3), nil())))))
     reversed = to_list(res)
 
     assert len(reversed) == 3
@@ -392,7 +392,7 @@ def test_unfoldr():
             relay.Clause(relay.PatternConstructor(z, []), none())
         ]))
 
-    res = intrp.evaluate(unfoldr(count_down, build_nat(3)))
+    res = intrp.evaluate(unfoldr(count_down, make_nat_expr(3)))
     unfolded = to_list(res)
 
     assert len(unfolded) == 3
@@ -419,7 +419,7 @@ def test_unfoldl():
             relay.Clause(relay.PatternConstructor(z, []), none())
         ]))
 
-    res = intrp.evaluate(unfoldl(count_down, build_nat(3)))
+    res = intrp.evaluate(unfoldl(count_down, make_nat_expr(3)))
     unfolded = to_list(res)
 
     assert len(unfolded) == 3
@@ -444,7 +444,7 @@ def test_map_accumr():
                                      relay.Tuple([add(x, acc),
                                                   add(x, acc)]))
 
-    vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil())))
+    vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
     res = intrp.evaluate(map_accumr(add_acc_to_each, z(), vals))
 
     sum = count(res[0])
@@ -472,7 +472,7 @@ def test_map_accuml():
     add_to_acc = relay.Function([acc, x],
                                 relay.Tuple([add(x, acc), x]))
 
-    vals = cons(build_nat(1), cons(build_nat(2), cons(build_nat(3), nil())))
+    vals = cons(make_nat_expr(1), cons(make_nat_expr(2), cons(make_nat_expr(3), nil())))
     res = intrp.evaluate(map_accuml(add_to_acc, z(), vals))
 
     sum = count(res[0])
@@ -497,8 +497,8 @@ def test_optional_matching():
         ]))
 
     res = intrp.evaluate(foldr(condense, nil(), cons(
-        some(build_nat(3)),
-        cons(none(), cons(some(build_nat(1)), nil())))))
+        some(make_nat_expr(3)),
+        cons(none(), cons(some(make_nat_expr(1)), nil())))))
 
     reduced = to_list(res)
     assert len(reduced) == 2
@@ -532,7 +532,7 @@ def test_tmap():
 def test_size():
     a = relay.TypeVar("a")
     lhs = mod[size].checked_type
-    rhs = relay.FuncType([tree(a)], nat(), [a])
+    rhs = relay.FuncType([tree(a)], relay.scalar_type('int32'), [a])
     assert lhs == rhs
 
     root = rose(z(), cons(rose(z(), nil()),
@@ -540,7 +540,7 @@ def test_size():
                                        nil())))
     t = rose(z(), cons(root, cons(root, cons(root, nil()))))
     res = intrp.evaluate(size(t))
-    assert count(res) == 10
+    assert get_scalar(res) == 10
 
 
 def test_wildcard_match_solo():
@@ -601,10 +601,10 @@ def test_nested_matches():
                          inner_match)
         ]), l(a), [a])
 
-    first_list = cons(build_nat(1), cons(build_nat(2),
-                                         cons(build_nat(3), nil())))
-    second_list = cons(build_nat(4), cons(build_nat(5),
-                                          cons(build_nat(6), nil())))
+    first_list = cons(make_nat_expr(1), cons(make_nat_expr(2),
+                                         cons(make_nat_expr(3), nil())))
+    second_list = cons(make_nat_expr(4), cons(make_nat_expr(5),
+                                          cons(make_nat_expr(6), nil())))
     final_list = cons(first_list, cons(second_list, nil()))
 
     res = intrp.evaluate(flatten(final_list))
@@ -660,6 +660,7 @@ def test_nested_pattern_match():
 
     assert count(res) == 2
 
+
 def test_compose():
     n = relay.Var('n')
     inc = relay.Function([n], s(n))
@@ -667,11 +668,13 @@ def test_compose():
     res = intrp.evaluate(relay.Call(compose(inc, double), [s(s(z()))]))
     assert count(res) == 5
 
+
 def test_iterate():
-    expr = relay.Call(iterate(double, build_nat(2)), [build_nat(3)])
+    expr = relay.Call(iterate(double, relay.const(2)), [make_nat_expr(3)])
     res = intrp.evaluate(relay.Function([], expr)())
     assert count(res) == 12
 
+
 if __name__ == "__main__":
     test_nat_constructor()
     test_double()
index e69f839..3cf73ae 100644 (file)
@@ -53,10 +53,12 @@ def test_adt():
     mod = relay.Module()
     p = Prelude(mod)
     x = relay.Var("x")
-    s_case = relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]), x)
+    some_case = relay.Clause(relay.PatternConstructor(p.some,
+                                                      [relay.PatternVar(x)]),
+                             x)
     default_case = relay.Clause(relay.PatternVar(x), x)
-    m0 = relay.Match(p.z(), [default_case])
-    m1 = relay.Match(p.z(), [s_case, default_case])
+    m0 = relay.Match(p.none(), [default_case])
+    m1 = relay.Match(p.none(), [some_case, default_case])
     assert well_formed(m0)
     assert not well_formed(m1)
 
index f00dc85..478b433 100644 (file)
@@ -521,7 +521,7 @@ def test_match_alpha_equal():
                                                              relay.PatternVar(a)]),
                                    p.cons(z, a))
 
-    data = p.cons(p.z(), p.cons(p.z(), p.nil()))
+    data = p.cons(relay.const(1), p.cons(relay.const(2), p.nil()))
 
     match = relay.Match(data, [nil_case, cons_case])
     equivalent = relay.Match(data, [nil_case, equivalent_cons])
@@ -547,8 +547,8 @@ def test_match_alpha_equal():
         relay.Clause(relay.PatternWildcard(), p.nil())
     ])
     wrong_constructors = relay.Match(data, [
-        relay.Clause(relay.PatternConstructor(p.z), p.nil()),
-        relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]),
+        relay.Clause(relay.PatternConstructor(p.none), p.nil()),
+        relay.Clause(relay.PatternConstructor(p.some, [relay.PatternVar(x)]),
                      p.cons(x, p.nil()))
     ])
 
index f5968a4..d99bee5 100644 (file)
@@ -19,6 +19,7 @@ from tvm import relay
 from tvm.relay.ir_pass import free_vars, free_type_vars, gradient
 from tvm.relay import create_executor
 from tvm.relay.prelude import Prelude
+from tvm.relay.testing import add_nat_definitions, make_nat_expr
 
 import numpy as np
 
@@ -174,13 +175,14 @@ def test_tuple():
 def test_pow():
     mod = relay.Module()
     p = Prelude(mod)
+    add_nat_definitions(p)
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
     double = relay.Function([x], x + x)
     i = relay.var("i", t)
-    func = relay.Function([i], relay.Call(p.iterate(double, p.s(p.s(p.s(p.z())))), [i]))
+    func = relay.Function([i], p.nat_iterate(double, make_nat_expr(p, 3))(i))
     back_func = relay.ir_pass.infer_type(gradient(func, mod=mod), mod=mod)
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
     i_nd = rand(dtype, *shape)
index 2e95dbe..f395580 100644 (file)
@@ -21,6 +21,7 @@ from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type
 from tvm.relay import op, create_executor
 from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue
 from tvm.relay.prelude import Prelude
+from tvm.relay.testing import add_nat_definitions, count
 
 
 def check_eval(expr, expected_result, mod=None, rtol=1e-07):
@@ -130,19 +131,10 @@ def test_ref():
     check_eval(to_a_normal_form(body), 3)
 
 
-# this is an example of using the adt value in python side
-def count(n):
-    assert isinstance(n, ConstructorValue)
-    if n.constructor.name_hint == 's':
-        return 1 + count(n.fields[0])
-    else:
-        assert n.constructor.name_hint == 'z'
-        return 0
-
-
-def test_add():
+def test_nat_add():
     mod = relay.Module()
     p = Prelude(mod)
+    add_nat_definitions(p)
     nat = p.nat
     add = p.add
     s = p.s
@@ -183,4 +175,5 @@ if __name__ == '__main__':
     test_ref()
     test_add()
     test_let()
+    test_nat_add()
     test_function()