from enum import Enum
-from .util import _internal_assert
+from .util import _internal_assert, _apply_indices
from . import calls
from . import util
from .preprocessor import determine_variable_usage
}
- def __init__(self, args, usage, symbols, func_name=None):
+ def __init__(self, args, usage, symbols, closure_vars, func_name=None):
"""
Parameters
----------
usage: A dict of variables used in last in this function
Provided by last lower pass, which collects this information
+ symbols : list of str
+ The symbol list of the global context of the function.
+
+ closure_vars: dict
+ A dict of external name reference captured by this function.
+
Returns
-------
func_name: str
if isinstance(v, types.FunctionType):
self.add_symbol(k, Symbol.Callable, v)
+ self.closure_vars = closure_vars
+
self.binds = {} # Thread binds
self.device = 0 # Is it generating device
def visit_Name(self, node):
name = node.id
if sys.version_info[0] == 2 and name in ['True', 'False']:
- return _api.convert(eval(name)) #pylint: disable=eval-used
+ return _api.convert(ast.literal_eval(name))
+
+ if name in self.closure_vars:
+ return _api.convert(self.closure_vars[name])
+
ty, entry = self.symbols[name]
_internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
buf = self.visit(node.value)
return getattr(buf, node.attr)
-
def visit_Subscript(self, node):
args = self.visit(node.slice)
if isinstance(node.value, ast.Name):
+ if node.value.id in self.closure_vars:
+ args = ast.literal_eval(str(args))
+ return _api.convert(_apply_indices(self.closure_vars[node.value.id], args))
buf = self.visit(node.value)
if isinstance(buf, Array):
return _make.AssertStmt(test, mesg, util.make_nop())
-def parse_python(src, symbols, args):
+def parse_python(src, args, symbols, closure_vars):
"""The helper function of calling the AST visitor
Parameters
If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it.
- symbols : str
- The symbol list of the global context of the function.
-
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effect.
+ symbols : list of str
+ The symbol list of the global context of the function.
+
+ closure_vars: dict
+ A dict of external name reference captured by this function.
+
Returns
-------
root : Stmt
"""
root = ast.parse(src) if isinstance(src, str) else src
_internal_assert(root, ast.AST)
- var_usage = determine_variable_usage(root, args, symbols)
- parser = HybridParser(args, var_usage, symbols)
+ var_usage = determine_variable_usage(root, args, symbols, closure_vars)
+ parser = HybridParser(args, var_usage, symbols, closure_vars)
parser.parsed_body = parser.visit(root)
_internal_assert(parser.returned, 'No valid return found in the function body!')
return parser
-def source_to_op(src, symbols, args):
+def source_to_op(src, args, symbols, closure_vars):
"""Another level of wrapper
Parameters
If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it.
- symbols : str
- The symbol list of the global context of the function.
-
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effect.
+ symbols : list of str
+ The symbol list of the global context of the function.
+
+ closure_vars: dict
+ A dict of external name reference captured by this function.
+
Returns
-------
res : list of output tensors
The result of output tensors of the formed OpNode.
"""
- parser = parse_python(src, symbols, args)
+ parser = parse_python(src, args, symbols, closure_vars)
input_tensors = []
for i in args:
"""The vistor class to determine the declaration, r/w status, and last use of each variable"""
#pylint: disable=invalid-name
#pylint: disable=missing-docstring
- def __init__(self, args, symbols):
+ def __init__(self, args, symbols, closure_vars):
self.status = {}
self.scope_level = []
self._args = {}
self.args = args
self.aug_assign_ = False
self.symbols = symbols
-
+ self.closure_vars = closure_vars
def visit_FunctionDef(self, node):
self.scope_level.append(node)
"Iter var cannot be overwritten")
if node.id not in self.status.keys():
+ # It is a captured value in closure
+ if node.id in self.closure_vars:
+ try:
+ ast.literal_eval(str(self.closure_vars[node.id]))
+ except ValueError:
+ raise ValueError("Only support capturing constant values in closure")
+ return
+
_internal_assert(isinstance(node.ctx, ast.Store), \
'Undeclared variable %s' % node.id)
if self.aug_assign_:
self.status[node.id] = (decl, loop, usage)
-def determine_variable_usage(root, args, symbols):
+def determine_variable_usage(root, args, symbols, closure_vars):
"""The helper function for calling the dedicated visitor."""
- visitor = PyVariableUsage(args, symbols)
+ visitor = PyVariableUsage(args, symbols, closure_vars)
visitor.visit(root)
return visitor.status