[relay][frontend] clean up tf frontend (#3710)
authorZhi <5145158+zhiics@users.noreply.github.com>
Tue, 6 Aug 2019 21:05:06 +0000 (14:05 -0700)
committerThierry Moreau <moreau@uw.edu>
Tue, 6 Aug 2019 21:05:06 +0000 (14:05 -0700)
* clean up tf frontend

* fix get_relay_op

python/tvm/relay/frontend/common.py
python/tvm/relay/frontend/mxnet.py
python/tvm/relay/frontend/tensorflow.py

index c5057f3..3dab51c 100644 (file)
@@ -17,6 +17,8 @@
 """Common utilities"""
 from __future__ import absolute_import as _abs
 import logging
+
+import tvm
 from topi.util import get_const_tuple
 from .. import expr as _expr
 from .. import module as _module
@@ -224,6 +226,7 @@ class StrAttrsDict(object):
             raise AttributeError("Required attribute {} not found.".format(key))
         return default
 
+
 def get_relay_op(op_name):
     """Get the callable function from Relay based on operator name.
     Parameters
@@ -246,9 +249,10 @@ def get_relay_op(op_name):
             if op is not None:
                 break
     if not op:
-        raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
+        raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))
     return op
 
+
 class ExprTable(object):
     """Table storing Relay expressions by names."""
     def __init__(self):
@@ -298,21 +302,27 @@ class AttrCvt(object):
         If set as str, returned operator name is the str.
         If set as callable, returned operator is the str returned by calling:
         `op_name = func(attr)`
+
     transforms : dict of `new_name, or (new_name, default_value, transform function)`
         If only a new_name is provided, it's like renaming the attribute name.
         If default_value if provided, then the attribute is considered as optional.
         If transform function is provided, the original attribute value is handled
         by transform function.
+
     excludes : list
         A list of excluded attributes that should `NOT` appear.
         Raise NotImplementedError if occurred.
+
     disables : list
         A list of attributes that is disabled in relay. Log warnings.
+
     ignores : list
         A list of attributes that is ignored in relay. Debug level logging.
+
     extras : dict
         A series of additional attributes should be added anyway to the returned
         attribute dict.
+
     custom_check : callable
         A custom function takes attribute, and return True/False.
         Raise RuntimeError if not bool(True) returned.
@@ -329,6 +339,14 @@ class AttrCvt(object):
         self._custom_check = custom_check
 
     def __call__(self, inputs, attrs, *args):
+        self._ignores.append('_output_shapes')
+        self._ignores.append('_input_shapes')
+        self._ignores.append('T')
+        self._ignores.append('use_cudnn_on_gpu')
+        self._ignores.append('_node_name')
+        self._ignores.append('is_training')
+        self._ignores.append('_target_layout')
+
         # apply custom check
         if self._custom_check:
             func, msg = self._custom_check
@@ -348,7 +366,8 @@ class AttrCvt(object):
         new_attrs = {}
         for k in attrs.keys():
             if k in self._excludes:
-                raise NotImplementedError("Attribute {} not supported yet.".format(k))
+                raise NotImplementedError('Attribute %s in operator %s is not' +
+                                          ' supported.', k, op_name)
             elif k in self._disables:
                 logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
             elif k in self._ignores:
@@ -401,6 +420,7 @@ class AttrCvt(object):
             raise AttributeError("Required attribute {} not found.".format(key))
         return attr[key]
 
+
 def get_name(node):
     name = ''
     if hasattr(node, "name_hint"):
@@ -410,17 +430,19 @@ def get_name(node):
 
 def infer_type(node):
     """A method to infer the type of an intermediate node in the relay graph."""
-    mod = _module.Module.from_expr(node)
+    mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node)
     mod = _transform.InferType()(mod)
     entry = mod["main"]
     return entry if isinstance(node, _expr.Function) else entry.body
 
+
 def infer_shape(inputs):
     """A method to get the output shape of an intermediate node in the graph."""
     out_type = infer_type(inputs)
     out_shapes = get_const_tuple(out_type.checked_type.shape)
     return out_shapes
 
+
 def infer_channels(inputs, transpose=False):
     """A hack for getting 'channels' or 'units' since caffe2 does not provide
     these attributes. We check the shape of weights provided to get the number.
@@ -430,12 +452,14 @@ def infer_channels(inputs, transpose=False):
     channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
     return channels
 
+
 def new_var(name_hint,
             type_annotation=None,
             shape=None,
             dtype="float32"):
     return _expr.var(name_hint, type_annotation, shape, dtype)
 
+
 class Renamer(object):
     """A simply renamer for operators.
 
index 6949a6f..8d13d6c 100644 (file)
@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs
 
 import json
 import tvm
-from .. import analysis, transform
+from .. import analysis
 from .. import expr as _expr
 from .. import op as _op
 from .. import module as _module
 from ... import nd as _nd
 
 from .common import StrAttrsDict
+from .common import infer_type as _infer_type
 from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
 from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
 from .nnvm_common import _clip, _transpose, _upsampling
@@ -41,13 +42,6 @@ _activation_map = {
     "relu"   : _op.nn.relu
 }
 
-def _infer_type(node):
-    """A method to infer the type of an intermediate node in the relay graph."""
-    mod = _module.Module.from_expr(node)
-    mod = transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(node, _expr.Function) else entry.body
-
 def _mx_fully_connected(inputs, attrs):
     import mxnet as mx
     units = attrs.get_int("num_hidden")
index 12fa8ed..756022b 100644 (file)
 from __future__ import absolute_import as _abs
 from __future__ import print_function
 
-import logging
 import warnings
 from collections import defaultdict
 # Numpy support
 import numpy as np
 
 import tvm
-from topi.util import get_const_tuple
 from .. import analysis
-from .. import transform as _transform
 from .. import expr as _expr
 from .. import op as _op
 from ..expr_functor import ExprMutator
 from .. import module as _module
+from .common import AttrCvt, get_relay_op
+from .common import infer_type as _infer_type
+from .common import infer_shape as _infer_shape
+from .common import infer_channels as _infer_channels
 
 __all__ = ['from_tensorflow']
 
@@ -50,140 +51,6 @@ def _infer_value(input_val, params):
     m.run()
     return m.get_output(0)
 
-def _get_relay_op(op_name):
-    ops = [_op, _op.nn, _op.image, _op.vision]
-    for operator in ops:
-        try:
-            op = getattr(operator, op_name)
-            return op
-        except AttributeError:
-            continue
-
-    raise tvm.error.OpNotImplemented(
-        'Operator {} is not supported for frontend TensorFlow.'.format(op_name))
-
-class AttrCvt(object):
-    """Common attribute converter. An AttrConverter instance is a callable:
-    ```
-    attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
-    new_op_name, new_attr = attr_converter(attrs)
-    ```
-
-    Parameters
-    ----------
-    op_name : str or callable
-        If set as str, returned operator name is the str.
-        If set as callable, returned operator is the str returned by calling:
-        `op_name = func(attr)`
-    transforms : dict of `new_name, or (new_name, default_value, transform function)`
-        If only a new_name is provided, it's like renaming the attribute name.
-        If default_value if provided, then the attribute is considered as optional.
-        If transform function is provided, the original attribute value is handled
-        by transform function.
-    excludes : list
-        A list of excluded attributes that should `NOT` appear.
-        Raise NotImplementedError if occurred.
-    disables : list
-        A list of attributes that is disabled in relay. Log warnings.
-    ignores : list
-        A list of attributes that is ignored in relay. Debug level logging.
-    extras : dict
-        A series of additional attributes should be added anyway to the returned
-        attribute dict.
-    custom_check : callable
-        A custom function takes attribute, and return True/False.
-        Raise RuntimeError if not bool(True) returned.
-    """
-
-    def __init__(self, op_name, transforms=None,
-                 excludes=None, disables=None, ignores=None,
-                 extras=None, custom_check=None):
-        self._op_name = op_name
-        self._transforms = transforms if transforms else {}
-        self._excludes = excludes if excludes else []
-        self._disables = disables if disables else []
-        self._ignores = ignores if ignores else []
-        self._extras = extras if extras else {}
-        self._custom_check = custom_check
-
-    def __call__(self, inputs, attrs, *args):
-        self._ignores.append('_output_shapes')
-        self._ignores.append('_input_shapes')
-        self._ignores.append('T')
-        self._ignores.append('use_cudnn_on_gpu')
-        self._ignores.append('_node_name')
-        self._ignores.append('is_training')
-        self._ignores.append('_target_layout')
-
-        # apply custom check
-        if self._custom_check:
-            func, msg = self._custom_check
-            if not func(attrs):
-                raise RuntimeError("Check failed: {}".format(msg))
-        # get new op_name
-        if isinstance(self._op_name, str):
-            op_name = self._op_name
-        else:
-            assert callable(self._op_name), "op_name can either be string or callable"
-            op_name = self._op_name(attrs)
-        # convert attributes
-        new_attrs = {}
-        for k in attrs.keys():
-            if k in self._excludes:
-                raise tvm.error.OpAttributeUnImplemented(
-                    'Attribute {} in operator {} is not supported.'.format(k, op_name))
-            elif k in self._disables:
-                logging.warning("Attribute %s is disabled in relay.%s", k, op_name)
-            elif k in self._ignores:
-                logging.debug("Attribute %s is ignored in relay.%s", k, op_name)
-            elif k in self._transforms:
-                new_name, defaults, transform = self._parse_default(self._transforms[k])
-                if defaults is None:
-                    new_attr = self._required_attr(attrs, k)
-                else:
-                    new_attr = attrs.get(k, None)
-                if new_attr is None:
-                    new_attrs[new_name] = defaults
-                else:
-                    new_attrs[new_name] = transform(new_attr)
-            else:
-                # copy
-                new_attrs[k] = attrs[k]
-        # add extras
-        new_attrs.update(self._extras)
-        return _get_relay_op(op_name)(*inputs, **new_attrs)
-
-    def _parse_default(self, target):
-        """Helper function to parse default values."""
-        if not isinstance(target, (list, tuple)):
-            k, v, t = target, None, lambda x: x
-        elif len(target) == 1:
-            k, v, t = target[0], None, lambda x: x
-        elif len(target) == 2:
-            k, v, t = target[0], target[1], lambda x: x
-        elif len(target) > 2:
-            k, v, t = target[0], target[1], target[2]
-        else:
-            k = None  # should raise
-        if not isinstance(k, str):
-            msg = "{} is not a valid target, (name, default) expected.".format(target)
-            raise ValueError(msg)
-        return k, v, t
-
-    def _parse_bool(self, value):
-        """Helper function to parse default boolean values."""
-        if isinstance(value, str):
-            return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
-        return bool(value)
-
-    def _required_attr(self, attr, key):
-        """Wrapper for getting required attributes."""
-        assert isinstance(attr, dict)
-        if key not in attr:
-            raise tvm.error.OpAttributeRequired(
-                'Attribute {} not found in operator {}'.format(key, self._op_name))
-        return attr[key]
-
 def _get_pad_pair(input1d, kernel1d, stride1d):
     if input1d % stride1d == 0:
         pad = max(kernel1d - stride1d, 0)
@@ -195,12 +62,6 @@ def _get_pad_pair(input1d, kernel1d, stride1d):
 
     return [pad_before, pad_after]
 
-def _get_name_hint(node):
-    name = ''
-    if hasattr(node, "name_hint"):
-        name = node.name_hint
-    return name
-
 def _math_name_picker(surfix):
     def _impl(attr):
         return 'broadcast_' + surfix
@@ -222,30 +83,6 @@ def _dimension_constraint():
         return False
     return _dim_check, "Only 2d kernel supported."
 
-def _infer_channels(node, params, transpose=False):
-    """A hack for getting 'channels' or 'units' since tensorflow don't provide
-    these attributes. We check the shape of weights provided to get the number.
-    """
-    out_shape = _infer_shape(node, params)
-    channels = out_shape[0] if not transpose else out_shape[1]
-    return channels
-
-def _infer_out_shapes(inputs, params):
-    """A method to get the output shape of intermediate nodes in the relay graph."""
-    return [_infer_shape(inputs, params)]
-
-def _infer_type(node):
-    """A method to infer the type of an intermediate node in the relay graph."""
-    mod = _module.Module.from_expr(node)
-    mod = _transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(node, _expr.Function) else entry.body
-
-def _infer_shape(node, params=None):
-    """A method to get the output shape of an intermediate node in the relay graph."""
-    out_type = _infer_type(node)
-    return get_const_tuple(out_type.checked_type.shape)
-
 def _get_param(params, input_node):
     return params.pop(input_node.name_hint).asnumpy()
 
@@ -280,7 +117,7 @@ def _argx(func, func_name):
 def _elemwise(name):
     def _impl(inputs, attr, params):
         assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
-        return _get_relay_op(name)(*inputs)
+        return get_relay_op(name)(*inputs)
     return _impl
 
 def _pooling(name):
@@ -300,7 +137,7 @@ def _pooling(name):
         else:
             msg = 'Value {} of attribute "data_format" of operator Pooling ' \
                   'is not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(attrs['data_format']))
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
 
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
             tmp_shape = attr['_input_shapes'][inputs[0]]
@@ -539,7 +376,7 @@ def _crop_and_resize():
             res_crop = _op.strided_slice(inputs[0], begin=begin, end=size)
 
             # 2) Resize
-            res_resize = _get_relay_op('resize')(res_crop, **attrs)
+            res_resize = get_relay_op('resize')(res_crop, **attrs)
             out = _op.concatenate([out, res_resize], axis=0) if out else res_resize
         return out
     return _impl
@@ -598,7 +435,7 @@ def _check_numerics():
 
 def _matmul():
     def _impl(inputs, attr, params):
-        channels = _infer_channels(inputs[1], params, not attr['transpose_b'])
+        channels = _infer_channels(inputs[1], not attr['transpose_b'])
         if attr['transpose_a']:
             inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
         if not attr['transpose_b']:
@@ -615,15 +452,10 @@ def _batch_matmul():
         adj_y = attr['adj_y']
         input_x = _op.transpose(inputs[0], axes=[0, 2, 1]) if adj_x else inputs[0]
         input_y = _op.transpose(inputs[1], axes=[0, 2, 1]) if not adj_y else inputs[1]
-        ret = _get_relay_op('batch_matmul')(input_x, input_y)
+        ret = get_relay_op('batch_matmul')(input_x, input_y)
         return ret
     return _impl
 
-def _undef():
-    def _impl(inputs, attr, params):
-        return _sym.__undef__()
-    return _impl
-
 def _identity():
     def _impl(inputs, attr, params):
         return inputs[0]
@@ -985,7 +817,7 @@ def _stridedSlice():
         if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
             begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
         out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
-        out_shape = _infer_shape(out, params)
+        out_shape = _infer_shape(out)
         if not fshape_indices:
             fshape_indices = range(len(out_shape))
 
@@ -1178,8 +1010,8 @@ def _softplus():
         exp_out = AttrCvt('exp')(inputs, attr)
         inputs.append(tvm.relay.const(1, attr['T'].name))
         rh = tvm.relay.const(1, attr['T'].name)
-        add_out = _get_relay_op('add')(exp_out, rh)
-        return _get_relay_op('log')(add_out)
+        add_out = get_relay_op('add')(exp_out, rh)
+        return get_relay_op('log')(add_out)
     return _impl
 
 def _topk():
@@ -1200,7 +1032,7 @@ def _floordiv():
     def _impl(inputs, attr, params):
         assert len(inputs) == 2
         div = AttrCvt('divide')(inputs, attr)
-        return _get_relay_op('floor')(div)
+        return get_relay_op('floor')(div)
     return _impl
 
 def _logical(name):
@@ -1234,7 +1066,7 @@ def _space_to_batch_nd():
         axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \
                list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
         permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
-        permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, params)
+        permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded)
         # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
         # producing an output tensor of shape:
         # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
@@ -1277,7 +1109,7 @@ def _batch_to_space_nd():
         # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
         #  ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
         #  input_shape[M+1], ..., input_shape[N-1]]
-        reshaped_permuted_shape = _infer_shape(reshaped_permuted, params)
+        reshaped_permuted_shape = _infer_shape(reshaped_permuted)
         cropped = reshaped_permuted
         for axis in range(1, M+1):
             crop = crops[axis - 1]
@@ -1305,8 +1137,8 @@ def _log1p():
     # op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p
     def _impl(inputs, attr, params):
         one = tvm.relay.const(1, attr['T'].name)
-        add_out = _get_relay_op('add')(inputs[0], one)
-        return _get_relay_op('log')(add_out)
+        add_out = get_relay_op('add')(inputs[0], one)
+        return get_relay_op('log')(add_out)
     return _impl
 
 # compatible operators that do NOT require any conversion.
@@ -2399,7 +2231,7 @@ class GraphProto(object):
         convert_map = convert_map if convert_map else _convert_map
         convert_map_rnn = _convert_map_rnn
         if op_name in identity_list:
-            sym = _get_relay_op(op_name)(*inputs, **attrs)
+            sym = get_relay_op(op_name)(*inputs, **attrs)
         elif op_name in convert_map:
             sym = convert_map[op_name](inputs, attrs, self._params)
         elif op_name in convert_map_rnn: