[HybridScript] Capture constant external python variables (#3157)
authorLianmin Zheng <lianminzheng@gmail.com>
Fri, 10 May 2019 23:36:54 +0000 (07:36 +0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 10 May 2019 23:36:53 +0000 (16:36 -0700)
python/tvm/hybrid/__init__.py
python/tvm/hybrid/module.py
python/tvm/hybrid/parser.py
python/tvm/hybrid/preprocessor.py
python/tvm/hybrid/util.py
tests/python/unittest/test_hybrid_script.py

index 7aca007..11ecbc8 100644 (file)
@@ -31,6 +31,8 @@ HalideIR.
 
 from __future__ import absolute_import as _abs
 
+import inspect
+
 from .._ffi.base import decorate
 from .._ffi.function import _init_api
 from ..build_module import form_body
@@ -55,7 +57,9 @@ def script(pyfunc):
         from .util import _is_tvm_arg_types
         if _is_tvm_arg_types(args):
             src = _pruned_source(func)
-            return source_to_op(src, func.__globals__, args)
+            closure_vars = inspect.getclosurevars(func).nonlocals
+            closure_vars.update(inspect.getclosurevars(func).globals)
+            return source_to_op(src, args, func.__globals__, closure_vars)
 
         from .runtime import _enter_hybrid_runtime, _restore_runtime
         intersect = _enter_hybrid_runtime(func)
index 297dd0b..13e45a7 100644 (file)
@@ -62,7 +62,7 @@ class HybridModule(object):
 
     def __call__(self, *args):
         if _is_tvm_arg_types(args):
-            return source_to_op(self.root_, globals(), args)
+            return source_to_op(self.root_, args, globals(), {})
         return self.func_(*args)
 
 
index 1c1525e..40ea171 100644 (file)
@@ -25,7 +25,7 @@ import numbers
 
 from enum import Enum
 
-from .util import _internal_assert
+from .util import _internal_assert, _apply_indices
 from . import calls
 from . import util
 from .preprocessor import determine_variable_usage
@@ -112,7 +112,7 @@ class HybridParser(ast.NodeVisitor):
     }
 
 
-    def __init__(self, args, usage, symbols, func_name=None):
+    def __init__(self, args, usage, symbols, closure_vars, func_name=None):
         """
         Parameters
         ----------
@@ -122,6 +122,12 @@ class HybridParser(ast.NodeVisitor):
         usage: A dict of variables used in last in this function
             Provided by last lower pass, which collects this information
 
+        symbols : list of str
+            The symbol list of the global context of the function.
+
+        closure_vars: dict
+            A dict of external name reference captured by this function.
+
         Returns
         -------
         func_name: str
@@ -136,6 +142,8 @@ class HybridParser(ast.NodeVisitor):
             if isinstance(v, types.FunctionType):
                 self.add_symbol(k, Symbol.Callable, v)
 
+        self.closure_vars = closure_vars
+
         self.binds = {} # Thread binds
         self.device = 0 # Is it generating device
 
@@ -236,7 +244,11 @@ class HybridParser(ast.NodeVisitor):
     def visit_Name(self, node):
         name = node.id
         if sys.version_info[0] == 2 and name in ['True', 'False']:
-            return _api.convert(eval(name)) #pylint: disable=eval-used
+            return _api.convert(ast.literal_eval(name))
+
+        if name in self.closure_vars:
+            return _api.convert(self.closure_vars[name])
+
         ty, entry = self.symbols[name]
         _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
         if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
@@ -356,10 +368,12 @@ class HybridParser(ast.NodeVisitor):
         buf = self.visit(node.value)
         return getattr(buf, node.attr)
 
-
     def visit_Subscript(self, node):
         args = self.visit(node.slice)
         if isinstance(node.value, ast.Name):
+            if node.value.id in self.closure_vars:
+                args = ast.literal_eval(str(args))
+                return _api.convert(_apply_indices(self.closure_vars[node.value.id], args))
 
             buf = self.visit(node.value)
             if isinstance(buf, Array):
@@ -576,7 +590,7 @@ class HybridParser(ast.NodeVisitor):
         return _make.AssertStmt(test, mesg, util.make_nop())
 
 
-def parse_python(src, symbols, args):
+def parse_python(src, args, symbols, closure_vars):
     """The helper function of calling the AST visitor
 
     Parameters
@@ -585,14 +599,17 @@ def parse_python(src, symbols, args):
         If an ast.node, then directly lower it.
         If a str, then parse it to ast and lower it.
 
-    symbols : str
-        The symbol list of the global context of the function.
-
     args : list of Tensors or Vars
         The argument lists to the function.
         It is NOT encouraged to write a function without arguments.
         It is NOT encouraged to write a function with side effect.
 
+    symbols : list of str
+        The symbol list of the global context of the function.
+
+    closure_vars: dict
+        A dict of external name reference captured by this function.
+
     Returns
     -------
     root : Stmt
@@ -600,14 +617,14 @@ def parse_python(src, symbols, args):
     """
     root = ast.parse(src) if isinstance(src, str) else src
     _internal_assert(root, ast.AST)
-    var_usage = determine_variable_usage(root, args, symbols)
-    parser = HybridParser(args, var_usage, symbols)
+    var_usage = determine_variable_usage(root, args, symbols, closure_vars)
+    parser = HybridParser(args, var_usage, symbols, closure_vars)
     parser.parsed_body = parser.visit(root)
     _internal_assert(parser.returned, 'No valid return found in the function body!')
     return parser
 
 
-def source_to_op(src, symbols, args):
+def source_to_op(src, args, symbols, closure_vars):
     """Another level of wrapper
 
     Parameters
@@ -616,20 +633,23 @@ def source_to_op(src, symbols, args):
         If an ast.node, then directly lower it.
         If a str, then parse it to ast and lower it.
 
-    symbols : str
-        The symbol list of the global context of the function.
-
     args : list of Tensors or Vars
         The argument lists to the function.
         It is NOT encouraged to write a function without arguments.
         It is NOT encouraged to write a function with side effect.
 
+    symbols : list of str
+        The symbol list of the global context of the function.
+
+    closure_vars: dict
+        A dict of external name reference captured by this function.
+
     Returns
     -------
     res : list of output tensors
         The result of output tensors of the formed OpNode.
     """
-    parser = parse_python(src, symbols, args)
+    parser = parse_python(src, args, symbols, closure_vars)
 
     input_tensors = []
     for i in args:
index 117ebd3..1a9de4e 100644 (file)
@@ -26,14 +26,14 @@ class PyVariableUsage(ast.NodeVisitor):
     """The vistor class to determine the declaration, r/w status, and last use of each variable"""
     #pylint: disable=invalid-name
     #pylint: disable=missing-docstring
-    def __init__(self, args, symbols):
+    def __init__(self, args, symbols, closure_vars):
         self.status = {}
         self.scope_level = []
         self._args = {}
         self.args = args
         self.aug_assign_ = False
         self.symbols = symbols
-
+        self.closure_vars = closure_vars
 
     def visit_FunctionDef(self, node):
         self.scope_level.append(node)
@@ -89,6 +89,14 @@ class PyVariableUsage(ast.NodeVisitor):
                          "Iter var cannot be overwritten")
 
         if node.id not in self.status.keys():
+            # It is a captured value in closure
+            if node.id in self.closure_vars:
+                try:
+                    ast.literal_eval(str(self.closure_vars[node.id]))
+                except ValueError:
+                    raise ValueError("Only support capturing constant values in closure")
+                return
+
             _internal_assert(isinstance(node.ctx, ast.Store), \
                              'Undeclared variable %s' % node.id)
             if self.aug_assign_:
@@ -102,8 +110,8 @@ class PyVariableUsage(ast.NodeVisitor):
             self.status[node.id] = (decl, loop, usage)
 
 
-def determine_variable_usage(root, args, symbols):
+def determine_variable_usage(root, args, symbols, closure_vars):
     """The helper function for calling the dedicated visitor."""
-    visitor = PyVariableUsage(args, symbols)
+    visitor = PyVariableUsage(args, symbols, closure_vars)
     visitor.visit(root)
     return visitor.status
index 0dd1fa1..058c5aa 100644 (file)
@@ -101,3 +101,9 @@ def _is_tvm_arg_types(args):
         _internal_assert(isinstance(elem, np_arg_types), \
                          "Expect a numpy type but %s get!" % str(type(elem)))
     return False
+
+def _apply_indices(value, indices):
+    """Apply multidimensional index"""
+    if indices:
+        return _apply_indices(value[indices[0]], indices[1:])
+    return value
index 2542646..805cff8 100644 (file)
@@ -768,6 +768,24 @@ def test_schedule():
 
     # Test loop binds
 
+def test_capture():
+    n = 8
+
+    constant_tuple = (10, n)
+    constant_list = [[1, 2], [3, n]]
+    const_value = 1
+
+    @tvm.hybrid.script
+    def add_something(a):
+        c = output_tensor((constant_tuple[1],), 'int32')
+        for i in range(constant_tuple[1]):
+            c[i] = a[i] + constant_list[1][const_value]
+        return c
+
+    a = tvm.placeholder((n, ), dtype='int32', name='a')
+
+    func, ins, outs = run_and_check(add_something, [a])
+    run_and_check(func, ins, outs=outs)
 
 if __name__ == "__main__":
     test_outer_product()
@@ -786,5 +804,6 @@ if __name__ == "__main__":
     test_bool()
     test_const_range()
     test_schedule()
+    test_capture()
     # TODO:
     # test_inplace()