From: Jon Soifer Date: Sun, 15 Sep 2019 20:03:19 +0000 (-0700) Subject: [Relay][TensorFlow] Add support for SquaredDifference (#3930) X-Git-Tag: upstream/0.7.0~1910 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=0482623e9c20518fac8e6ceb34a6552f674fed9b;p=platform%2Fupstream%2Ftvm.git [Relay][TensorFlow] Add support for SquaredDifference (#3930) * Add support for SquaredDifference and StopGradient; minor fix in BatchMatMul * Remove stopgradient change * Resolve PR comment * Dummy change to retrigger CI * dummy change to retrigger CI --- diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 7114235..91d3da3 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -469,7 +469,9 @@ def _batch_matmul(): # reshape result back to n-dimensional if len(orig_shape_x) > 3: - final_shape = attr['_output_shapes'][0] + final_shape = list(orig_shape_x) + final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2] + final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1] ret = _op.reshape(ret, newshape=final_shape) return ret @@ -1227,6 +1229,12 @@ def _one_hot(): extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr) return _impl +def _squared_difference(): + def _impl(inputs, attr, params): + difference = _op.subtract(inputs[0], inputs[1]) + return _op.multiply(difference, difference) + return _impl + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1334,6 +1342,7 @@ _convert_map = { 'SplitV' : _split(True), 'Sqrt' : AttrCvt('sqrt'), 'Square' : _square(), + 'SquaredDifference' : _squared_difference(), 'Squeeze' : _squeeze(), 'StridedSlice' : _stridedSlice(), 'Sub' : _elemwise('subtract'), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 43a1809..29dba54 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1852,6 +1852,16 @@ def test_forward_erf(): tf.math.erf(in1) compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Erf:0') +def test_forward_squared_difference(): + ishape = (1, 3, 10, 14) + inp_array_a = np.random.uniform(-5, 5, size=ishape).astype(np.float32) + inp_array_b = np.random.uniform(-5, 5, size=ishape).astype(np.float32) + with tf.Graph().as_default(): + in1 = tf.placeholder(shape=inp_array_a.shape, dtype=inp_array_a.dtype, name="in1") + in2 = tf.placeholder(shape=inp_array_b.shape, dtype=inp_array_b.dtype, name="in2") + out = tf.math.squared_difference(in1, in2) + compare_tf_with_tvm([inp_array_a, inp_array_b], [in1.name, in2.name], out.name) + def _test_forward_reverse_v2(in_shape, axis, dtype): np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype) tf.reset_default_graph() @@ -2253,6 +2263,7 @@ if __name__ == '__main__': test_forward_bias_add() test_forward_zeros_like() test_forward_erf() + test_forward_squared_difference() # Reductions test_forward_argminmax()