From be8fa6ac17fc82cc57390b30b247d6f932822efd Mon Sep 17 00:00:00 2001
From: Zhi <5145158+zhiics@users.noreply.github.com>
Date: Tue, 6 Aug 2019 14:05:06 -0700
Subject: [PATCH] [relay][frontend] clean up tf frontend (#3710)

* clean up tf frontend

* fix get_relay_op
---
 python/tvm/relay/frontend/common.py     |  30 ++++-
 python/tvm/relay/frontend/mxnet.py      |  10 +-
 python/tvm/relay/frontend/tensorflow.py | 204 +++----------------------------
 3 files changed, 47 insertions(+), 197 deletions(-)

diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index c5057f3..3dab51c 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -17,6 +17,8 @@
 """Common utilities"""
 from __future__ import absolute_import as _abs
 import logging
+
+import tvm
 from topi.util import get_const_tuple
 from .. import expr as _expr
 from .. import module as _module
@@ -224,6 +226,7 @@ class StrAttrsDict(object):
             raise AttributeError("Required attribute {} not found.".format(key))
         return default
 
+
 def get_relay_op(op_name):
     """Get the callable function from Relay based on operator name.
     Parameters
@@ -246,9 +249,10 @@ def get_relay_op(op_name):
         if op is not None:
             break
     if not op:
-        raise RuntimeError("Unable to map op_name {} to relay".format(op_name))
+        raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))
     return op
 
+
 class ExprTable(object):
     """Table storing Relay expressions by names."""
     def __init__(self):
@@ -298,21 +302,27 @@ class AttrCvt(object):
         If set as str, returned operator name is the str.
         If set as callable, returned operator is the str returned by calling:
         `op_name = func(attr)`
+
     transforms : dict of `new_name, or (new_name, default_value, transform function)`
         If only a new_name is provided, it's like renaming the attribute name.
         If default_value if provided, then the attribute is considered as optional.
         If transform function is provided, the original attribute value is handled
         by transform function.
+
     excludes : list
         A list of excluded attributes that should `NOT` appear.
         Raise NotImplementedError if occurred.
+
     disables : list
         A list of attributes that is disabled in relay. Log warnings.
+
     ignores : list
         A list of attributes that is ignored in relay. Debug level logging.
+
     extras : dict
         A series of additional attributes should be added anyway to the returned
         attribute dict.
+
     custom_check : callable
         A custom function takes attribute, and return True/False.
         Raise RuntimeError if not bool(True) returned.
@@ -329,6 +339,14 @@ class AttrCvt(object):
         self._custom_check = custom_check
 
     def __call__(self, inputs, attrs, *args):
+        self._ignores.append('_output_shapes')
+        self._ignores.append('_input_shapes')
+        self._ignores.append('T')
+        self._ignores.append('use_cudnn_on_gpu')
+        self._ignores.append('_node_name')
+        self._ignores.append('is_training')
+        self._ignores.append('_target_layout')
+
         # apply custom check
         if self._custom_check:
             func, msg = self._custom_check
@@ -348,7 +366,8 @@ class AttrCvt(object):
         new_attrs = {}
         for k in attrs.keys():
             if k in self._excludes:
-                raise NotImplementedError("Attribute {} not supported yet.".format(k))
+                raise NotImplementedError('Attribute {} in operator {} is not'
+                                          ' supported.'.format(k, op_name))
             elif k in self._disables:
                 logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
             elif k in self._ignores:
@@ -401,6 +420,7 @@ class AttrCvt(object):
             raise AttributeError("Required attribute {} not found.".format(key))
         return attr[key]
 
+
 def get_name(node):
     name = ''
     if hasattr(node, "name_hint"):
@@ -410,17 +430,19 @@ def get_name(node):
 
 def infer_type(node):
     """A method to infer the type of an intermediate node in the relay graph."""
-    mod = _module.Module.from_expr(node)
+    mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node)
     mod = _transform.InferType()(mod)
     entry = mod["main"]
     return entry if isinstance(node, _expr.Function) else entry.body
 
+
 def infer_shape(inputs):
     """A method to get the output shape of an intermediate node in the graph."""
     out_type = infer_type(inputs)
     out_shapes = get_const_tuple(out_type.checked_type.shape)
     return out_shapes
 
+
 def infer_channels(inputs, transpose=False):
     """A hack for getting 'channels' or 'units' since caffe2 does not provide
     these attributes. We check the shape of weights provided to get the number.
@@ -430,12 +452,14 @@
     channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
     return channels
 
+
 def new_var(name_hint,
             type_annotation=None,
             shape=None,
             dtype="float32"):
     return _expr.var(name_hint, type_annotation, shape, dtype)
 
+
 class Renamer(object):
     """A simply renamer for operators.
 
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 6949a6f..8d13d6c 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -20,13 +20,14 @@ from __future__ import absolute_import as _abs
 
 import json
 import tvm
 
-from .. import analysis, transform
+from .. import analysis
 from .. import expr as _expr
 from .. import op as _op
 from .. import module as _module
 from ... import nd as _nd
 from .common import StrAttrsDict
+from .common import infer_type as _infer_type
 from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
 from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
 from .nnvm_common import _clip, _transpose, _upsampling
@@ -41,13 +42,6 @@ _activation_map = {
     "relu" : _op.nn.relu
 }
 
-def _infer_type(node):
-    """A method to infer the type of an intermediate node in the relay graph."""
-    mod = _module.Module.from_expr(node)
-    mod = transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(node, _expr.Function) else entry.body
-
 def _mx_fully_connected(inputs, attrs):
     import mxnet as mx
     units = attrs.get_int("num_hidden")
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 12fa8ed..756022b 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -19,20 +19,21 @@
 from __future__ import absolute_import as _abs
 from __future__ import print_function
 
-import logging
 import warnings
 from collections import defaultdict
 
 # Numpy support
 import numpy as np
 import tvm
-from topi.util import get_const_tuple
 from .. import analysis
-from .. import transform as _transform
 from .. import expr as _expr
 from .. import op as _op
 from ..expr_functor import ExprMutator
 from .. import module as _module
+from .common import AttrCvt, get_relay_op
+from .common import infer_type as _infer_type
+from .common import infer_shape as _infer_shape
+from .common import infer_channels as _infer_channels
 
 __all__ = ['from_tensorflow']
 
@@ -50,140 +51,6 @@ def _infer_value(input_val, params):
     m.run()
     return m.get_output(0)
 
-def _get_relay_op(op_name):
-    ops = [_op, _op.nn, _op.image, _op.vision]
-    for operator in ops:
-        try:
-            op = getattr(operator, op_name)
-            return op
-        except AttributeError:
-            continue
-
-    raise tvm.error.OpNotImplemented(
-        'Operator {} is not supported for frontend TensorFlow.'.format(op_name))
-
-class AttrCvt(object):
-    """Common attribute converter. An AttrConverter instance is a callable:
-    ```
-    attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)})
-    new_op_name, new_attr = attr_converter(attrs)
-    ```
-
-    Parameters
-    ----------
-    op_name : str or callable
-        If set as str, returned operator name is the str.
-        If set as callable, returned operator is the str returned by calling:
-        `op_name = func(attr)`
-    transforms : dict of `new_name, or (new_name, default_value, transform function)`
-        If only a new_name is provided, it's like renaming the attribute name.
-        If default_value if provided, then the attribute is considered as optional.
-        If transform function is provided, the original attribute value is handled
-        by transform function.
-    excludes : list
-        A list of excluded attributes that should `NOT` appear.
-        Raise NotImplementedError if occurred.
-    disables : list
-        A list of attributes that is disabled in relay. Log warnings.
-    ignores : list
-        A list of attributes that is ignored in relay. Debug level logging.
-    extras : dict
-        A series of additional attributes should be added anyway to the returned
-        attribute dict.
-    custom_check : callable
-        A custom function takes attribute, and return True/False.
-        Raise RuntimeError if not bool(True) returned.
- """ - - def __init__(self, op_name, transforms=None, - excludes=None, disables=None, ignores=None, - extras=None, custom_check=None): - self._op_name = op_name - self._transforms = transforms if transforms else {} - self._excludes = excludes if excludes else [] - self._disables = disables if disables else [] - self._ignores = ignores if ignores else [] - self._extras = extras if extras else {} - self._custom_check = custom_check - - def __call__(self, inputs, attrs, *args): - self._ignores.append('_output_shapes') - self._ignores.append('_input_shapes') - self._ignores.append('T') - self._ignores.append('use_cudnn_on_gpu') - self._ignores.append('_node_name') - self._ignores.append('is_training') - self._ignores.append('_target_layout') - - # apply custom check - if self._custom_check: - func, msg = self._custom_check - if not func(attrs): - raise RuntimeError("Check failed: {}".format(msg)) - # get new op_name - if isinstance(self._op_name, str): - op_name = self._op_name - else: - assert callable(self._op_name), "op_name can either be string or callable" - op_name = self._op_name(attrs) - # convert attributes - new_attrs = {} - for k in attrs.keys(): - if k in self._excludes: - raise tvm.error.OpAttributeUnImplemented( - 'Attribute {} in operator {} is not supported.'.format(k, op_name)) - elif k in self._disables: - logging.warning("Attribute %s is disabled in relay.%s", k, op_name) - elif k in self._ignores: - logging.debug("Attribute %s is ignored in relay.%s", k, op_name) - elif k in self._transforms: - new_name, defaults, transform = self._parse_default(self._transforms[k]) - if defaults is None: - new_attr = self._required_attr(attrs, k) - else: - new_attr = attrs.get(k, None) - if new_attr is None: - new_attrs[new_name] = defaults - else: - new_attrs[new_name] = transform(new_attr) - else: - # copy - new_attrs[k] = attrs[k] - # add extras - new_attrs.update(self._extras) - return _get_relay_op(op_name)(*inputs, **new_attrs) - - def _parse_default(self, target): - """Helper function to parse default values.""" - if not isinstance(target, (list, tuple)): - k, v, t = target, None, lambda x: x - elif len(target) == 1: - k, v, t = target[0], None, lambda x: x - elif len(target) == 2: - k, v, t = target[0], target[1], lambda x: x - elif len(target) > 2: - k, v, t = target[0], target[1], target[2] - else: - k = None # should raise - if not isinstance(k, str): - msg = "{} is not a valid target, (name, default) expected.".format(target) - raise ValueError(msg) - return k, v, t - - def _parse_bool(self, value): - """Helper function to parse default boolean values.""" - if isinstance(value, str): - return value.strip().lower() in ['true', '1', 't', 'y', 'yes'] - return bool(value) - - def _required_attr(self, attr, key): - """Wrapper for getting required attributes.""" - assert isinstance(attr, dict) - if key not in attr: - raise tvm.error.OpAttributeRequired( - 'Attribute {} not found in operator {}'.format(key, self._op_name)) - return attr[key] - def _get_pad_pair(input1d, kernel1d, stride1d): if input1d % stride1d == 0: pad = max(kernel1d - stride1d, 0) @@ -195,12 +62,6 @@ def _get_pad_pair(input1d, kernel1d, stride1d): return [pad_before, pad_after] -def _get_name_hint(node): - name = '' - if hasattr(node, "name_hint"): - name = node.name_hint - return name - def _math_name_picker(surfix): def _impl(attr): return 'broadcast_' + surfix @@ -222,30 +83,6 @@ def _dimension_constraint(): return False return _dim_check, "Only 2d kernel supported." 
-def _infer_channels(node, params, transpose=False):
-    """A hack for getting 'channels' or 'units' since tensorflow don't provide
-    these attributes. We check the shape of weights provided to get the number.
-    """
-    out_shape = _infer_shape(node, params)
-    channels = out_shape[0] if not transpose else out_shape[1]
-    return channels
-
-def _infer_out_shapes(inputs, params):
-    """A method to get the output shape of intermediate nodes in the relay graph."""
-    return [_infer_shape(inputs, params)]
-
-def _infer_type(node):
-    """A method to infer the type of an intermediate node in the relay graph."""
-    mod = _module.Module.from_expr(node)
-    mod = _transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(node, _expr.Function) else entry.body
-
-def _infer_shape(node, params=None):
-    """A method to get the output shape of an intermediate node in the relay graph."""
-    out_type = _infer_type(node)
-    return get_const_tuple(out_type.checked_type.shape)
-
 def _get_param(params, input_node):
     return params.pop(input_node.name_hint).asnumpy()
 
@@ -280,7 +117,7 @@ def _argx(func, func_name):
 def _elemwise(name):
     def _impl(inputs, attr, params):
         assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
-        return _get_relay_op(name)(*inputs)
+        return get_relay_op(name)(*inputs)
     return _impl
 
 def _pooling(name):
@@ -300,7 +137,7 @@ def _pooling(name):
         else:
             msg = 'Value {} of attribute "data_format" of operator Pooling ' \
                   'is not valid.'
-            raise tvm.error.OpAttributeInvalid(msg.format(attrs['data_format']))
+            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
 
         if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
             tmp_shape = attr['_input_shapes'][inputs[0]]
@@ -539,7 +376,7 @@ def _crop_and_resize():
         res_crop = _op.strided_slice(inputs[0], begin=begin, end=size)
 
         # 2) Resize
-        res_resize = _get_relay_op('resize')(res_crop, **attrs)
+        res_resize = get_relay_op('resize')(res_crop, **attrs)
         out = _op.concatenate([out, res_resize], axis=0) if out else res_resize
         return out
     return _impl
@@ -598,7 +435,7 @@ def _check_numerics():
 
 def _matmul():
     def _impl(inputs, attr, params):
-        channels = _infer_channels(inputs[1], params, not attr['transpose_b'])
+        channels = _infer_channels(inputs[1], not attr['transpose_b'])
         if attr['transpose_a']:
             inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
         if not attr['transpose_b']:
@@ -615,15 +452,10 @@ def _batch_matmul():
         adj_y = attr['adj_y']
         input_x = _op.transpose(inputs[0], axes=[0, 2, 1]) if adj_x else inputs[0]
         input_y = _op.transpose(inputs[1], axes=[0, 2, 1]) if not adj_y else inputs[1]
-        ret = _get_relay_op('batch_matmul')(input_x, input_y)
+        ret = get_relay_op('batch_matmul')(input_x, input_y)
         return ret
     return _impl
 
-def _undef():
-    def _impl(inputs, attr, params):
-        return _sym.__undef__()
-    return _impl
-
 def _identity():
     def _impl(inputs, attr, params):
         return inputs[0]
@@ -985,7 +817,7 @@ def _stridedSlice():
         if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
             begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
         out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
-        out_shape = _infer_shape(out, params)
+        out_shape = _infer_shape(out)
         if not fshape_indices:
             fshape_indices = range(len(out_shape))
 
@@ -1178,8 +1010,8 @@ def _softplus():
         exp_out = AttrCvt('exp')(inputs, attr)
         inputs.append(tvm.relay.const(1, attr['T'].name))
         rh = tvm.relay.const(1, attr['T'].name)
-        add_out = _get_relay_op('add')(exp_out, rh)
-        return _get_relay_op('log')(add_out)
+        add_out = get_relay_op('add')(exp_out, rh)
+        return get_relay_op('log')(add_out)
     return _impl
 
 def _topk():
@@ -1200,7 +1032,7 @@ def _floordiv():
     def _impl(inputs, attr, params):
         assert len(inputs) == 2
         div = AttrCvt('divide')(inputs, attr)
-        return _get_relay_op('floor')(div)
+        return get_relay_op('floor')(div)
     return _impl
 
 def _logical(name):
@@ -1234,7 +1066,7 @@ def _space_to_batch_nd():
         axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \
                list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
         permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
-        permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, params)
+        permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded)
         # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
         # producing an output tensor of shape:
         # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
@@ -1277,7 +1109,7 @@ def _batch_to_space_nd():
         # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
         #  ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
         #  input_shape[M+1], ..., input_shape[N-1]]
-        reshaped_permuted_shape = _infer_shape(reshaped_permuted, params)
+        reshaped_permuted_shape = _infer_shape(reshaped_permuted)
         cropped = reshaped_permuted
         for axis in range(1, M+1):
             crop = crops[axis - 1]
@@ -1305,8 +1137,8 @@ def _log1p():
     # op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p
     def _impl(inputs, attr, params):
         one = tvm.relay.const(1, attr['T'].name)
-        add_out = _get_relay_op('add')(inputs[0], one)
-        return _get_relay_op('log')(add_out)
+        add_out = get_relay_op('add')(inputs[0], one)
+        return get_relay_op('log')(add_out)
     return _impl
 
 # compatible operators that do NOT require any conversion.
@@ -2399,7 +2231,7 @@ class GraphProto(object):
         convert_map = convert_map if convert_map else _convert_map
         convert_map_rnn = _convert_map_rnn
         if op_name in identity_list:
-            sym = _get_relay_op(op_name)(*inputs, **attrs)
+            sym = get_relay_op(op_name)(*inputs, **attrs)
         elif op_name in convert_map:
            sym = convert_map[op_name](inputs, attrs, self._params)
         elif op_name in convert_map_rnn:
-- 
2.7.4
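
A minimal usage sketch of the consolidated helpers, assuming a TVM checkout with this patch applied; the example expression, its shape, and the unknown operator name are illustrative, not part of the patch:

    import tvm
    from tvm import relay
    from tvm.relay.frontend.common import get_relay_op, infer_shape

    # get_relay_op now raises tvm.error.OpNotImplemented for unknown operator
    # names instead of a bare RuntimeError ('no_such_op' is a hypothetical
    # name used only to trigger the error path).
    try:
        get_relay_op('no_such_op')
    except tvm.error.OpNotImplemented as err:
        print(err)  # Unable to map op_name no_such_op to relay

    # infer_type (and therefore infer_shape) works on a plain Relay
    # expression, and infer_type also accepts a Module directly, since it
    # only wraps its argument with Module.from_expr when needed.
    x = relay.var('x', shape=(1, 3), dtype='float32')
    y = relay.nn.relu(x)
    print(infer_shape(y))  # (1, 3)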