From cbec5b94b87455f07918f7f4488c9a82a2d26708 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Tue, 25 Jun 2019 01:38:55 -0700 Subject: [PATCH] [Relay] Add ResizeNearestNeighbor and CropAndResize in tf converter (#3393) --- python/tvm/relay/frontend/tensorflow.py | 67 +++++++++++++++++++++++- python/tvm/relay/frontend/tensorflow_parser.py | 8 +-- tests/python/frontend/tensorflow/test_forward.py | 62 ++++++++++++++++++++-- 3 files changed, 127 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index c0df8e6..1b55731 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -484,6 +484,54 @@ def _decode_image(): return inputs[0] return _impl +def _crop_and_resize(): + def _impl(inputs, attr, params): + # input image is a 4-D tensor of shape [batch, image_height, image_width, depth] + # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2] + try: + boxes = params.pop(inputs[1].name_hint).asnumpy().tolist() + box_ind = params.pop(inputs[2].name_hint).asnumpy().tolist() + crop_size = params.pop(inputs[3].name_hint).asnumpy().tolist() + except (IndexError, KeyError): + boxes = _infer_value(inputs[1], params).asnumpy().tolist() + box_ind = _infer_value(inputs[2], params).asnumpy().tolist() + crop_size = _infer_value(inputs[3], params).asnumpy().tolist() + + data_shape = attr['_input_shapes'][inputs[0]] + data_dim = len(data_shape) + method = attr['method'].decode() + + attrs = {} + attrs['size'] = crop_size + attrs['layout'] = 'NHWC' + if method.lower() == 'nearest': + raise tvm.error.OpAttributeUnimplemented( + 'Attribute method=nearest is not supported') + else: + attrs['align_corners'] = True + attrs['method'] = 'BILINEAR' + + out = None + begin = [0] * data_dim + size = data_shape[:] + for idx in box_ind: + # 1) Crop + # y is mapped to the image coordinate at y * (image_height - 1) + # x is mapped to the image coordinate at x * (image_width - 1) + begin[0] = idx + begin[1] = int(round(boxes[idx][0] * (data_shape[1] - 1))) + begin[2] = int(round(boxes[idx][1] * (data_shape[2] - 1))) + size[0] = idx + 1 + size[1] = int(round((data_shape[1] - 1) * boxes[idx][2])) + 1 + size[2] = int(round((data_shape[2] - 1) * boxes[idx][3])) + 1 + res_crop = _op.strided_slice(inputs[0], begin=begin, end=size) + + # 2) Resize + res_resize = _get_relay_op('resize')(res_crop, **attrs) + out = _op.concatenate([out, res_resize], axis=0) if out else res_resize + return out + return _impl + def _cast(): def _impl(inputs, attr, params): return inputs[0].astype(attr['DstT'].name) @@ -514,6 +562,21 @@ def _resize_bilinear(): extras={'method': "BILINEAR"})(inputs, attr) return _impl +def _resize_nearest_neighbor(): + def _impl(inputs, attr, params): + size = attr['_output_shapes'][0][1:3] + if -1 in size: + size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist() + attr['size'] = size + inputs.pop(1) + # NHWC + attr['layout'] = 'NHWC' + + return AttrCvt(op_name="resize", + ignores=['Tdim'], + extras={'method': "NEAREST_NEIGHBOR"})(inputs, attr) + return _impl + def _check_numerics(): def _impl(inputs, attr, params): # Making a copy node assuming no need to verify @@ -593,7 +656,7 @@ def _slice(): end[i] = data_shape[i] - begin[i] else: end[i] += begin[i] - return _op.strided_slice(inputs[0], begin=begin, end=size) + return _op.strided_slice(inputs[0], begin=begin, end=end) return _impl @@ -1243,6 +1306,7 @@ _convert_map = { 'Concat' : _concat(), 'ConcatV2' : _concatV2(), 'Conv2D' : _conv('conv'), + 'CropAndResize' : _crop_and_resize(), 'DecodeJpeg' : _decode_image(), 'DepthwiseConv2dNative' : _conv('depthwise'), 'DepthToSpace' : _depth_to_space(), @@ -1295,6 +1359,7 @@ _convert_map = { 'Reshape' : _reshape(), 'ResizeBilinear' : _resize_bilinear(), 'ResizeBicubic' : _resize_bilinear(), + 'ResizeNearestNeighbor' : _resize_nearest_neighbor(), 'ReverseV2' : _reverse_v2(), 'RightShift' : AttrCvt('right_shift'), 'Round' : AttrCvt('round'), diff --git a/python/tvm/relay/frontend/tensorflow_parser.py b/python/tvm/relay/frontend/tensorflow_parser.py index 9cb7eab..8105ef0 100644 --- a/python/tvm/relay/frontend/tensorflow_parser.py +++ b/python/tvm/relay/frontend/tensorflow_parser.py @@ -18,7 +18,6 @@ from __future__ import absolute_import as _abs from __future__ import print_function import os -from tensorflow.core.framework import graph_pb2 from tvm.contrib import util @@ -35,12 +34,12 @@ class TFParser(object): -------- .. code-block:: python - parser = TfParser(model_dir) - graph = parser.parse() - # graph is related graphdef of the model + parser = TFParser(model_dir) + graphdef = parser.parse() """ def __init__(self, model_dir): + from tensorflow.core.framework import graph_pb2 self._tmp_dir = util.tempdir() self._model_dir = model_dir self._graph = graph_pb2.GraphDef() @@ -96,6 +95,7 @@ class TFParser(object): from tensorflow.python.tools import freeze_graph from tensorflow.python.framework import ops from tensorflow.python.framework import graph_util + from tensorflow.core.framework import graph_pb2 except ImportError: raise ImportError( "InputConfiguration: Unable to import tensorflow which is " diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 4035d02..2fa1c73 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -949,8 +949,8 @@ def test_forward_multi_output(): tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5) ####################################################################### -# Resize Bilinear -# --------------- +# Resize Bilinear, Nearest_Neighbor +# --------------------------------- def _test_resize_bilinear(in_shape, to_shape, align_corners): """ One iteration of resize bilinear """ @@ -980,13 +980,31 @@ def _test_resize_bilinear_from_tensor(in_shape, align_corners): compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0') -def test_forward_resize_bilinear(): - """ Resize Bilinear """ + +def _test_resize_nearest_neighbor(in_shape, to_shape): + """ One iteration of resize nearest neighbor """ + + data = np.random.uniform(size=in_shape).astype('float32') + shape_data = np.array(to_shape).astype('int32') + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + shape_data = constant_op.constant( + shape_data, shape=shape_data.shape, dtype=shape_data.dtype) + tf.image.resize_nearest_neighbor(in_data, shape_data, name='resize_nearest_neighbor') + + compare_tf_with_tvm(data, 'Placeholder:0', 'resize_nearest_neighbor:0') + + +def test_forward_resize(): + """ Resize Bilinear, Nearest_Neighbor """ _test_resize_bilinear((4, 16, 32, 32), [50, 50], False) _test_resize_bilinear((6, 32, 64, 64), [20, 20], True) _test_resize_bilinear_from_tensor((4, 16, 32, 32), False) _test_resize_bilinear_from_tensor((6, 32, 50, 50), True) + _test_resize_nearest_neighbor((6, 32, 64, 64), [20, 20]) + ####################################################################### # BroadcastTo @@ -1081,6 +1099,39 @@ def test_forward_crop(): ####################################################################### +# CropAndResize +# ------------- + +def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size, method='bilinear', dtype="float32"): + image = np.random.uniform(0, 10, size=img_shape).astype(dtype) + tf.reset_default_graph() + in_data = tf.placeholder(dtype, image.shape, name="in_data") + tf.image.crop_and_resize(in_data, boxes=boxes, box_ind=box_idx, crop_size=crop_size, + method=method, name="crop_and_resize") + compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0') + +def test_forward_crop_and_resize(): + """ CropAndResize """ + _test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, 1, 1]], [0], [5, 5]) + _test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5]) + _test_forward_crop_and_resize([1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5]) + _test_forward_crop_and_resize([1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4]) + _test_forward_crop_and_resize([1, 106, 106, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3]) + _test_forward_crop_and_resize([10, 11, 11, 3], + [[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]], + [0, 1], + [5, 5]) + _test_forward_crop_and_resize([3, 11, 11, 3], + [[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8],[0, 0, 1, 1]], + [0, 1, 2], + [3, 3]) + _test_forward_crop_and_resize([3, 11, 11, 3], + [[0, 0, 1, 0.8], [0, 0, 0.9, 0.9], [0, 0, 1, 0.8]], + [2, 1, 0], + [3, 3]) + + +####################################################################### # LSTM # ---- @@ -1979,10 +2030,11 @@ if __name__ == '__main__': test_forward_depthtospace() test_forward_squeeze() test_forward_pack() - test_forward_resize_bilinear() test_forward_broadcast_to() test_forward_fill() test_forward_crop() + test_forward_resize() + test_forward_crop_and_resize() test_forward_pad() test_forward_unpack() test_forward_gather() -- 2.7.4