[Relay][Frontend][Tensorflow] Fix GatherV2, Add StopGradient (#4238)
authorTrevor Morris <trevoraidanmorris@gmail.com>
Mon, 4 Nov 2019 18:37:41 +0000 (10:37 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Mon, 4 Nov 2019 18:37:41 +0000 (10:37 -0800)
* Add StopGradient. Add batch_dims attr to ignore list for GatherV2

* Trigger CI

python/tvm/relay/frontend/tensorflow.py

index 4bee712..0abcb09 100644 (file)
@@ -872,11 +872,14 @@ def _gather():
             axis = _get_num_param(params, inputs.pop(2))
         else:
             axis = 0
+        if int(attr.get('batch_dims', 0)) != 0:
+            raise tvm.error.OpAttributeUnImplemented(
+                'Attribute batch_dims is not supported')
         new_input = inputs[0:2]
         return AttrCvt(op_name="take",
                        extras={'axis': tvm.const(axis, 'int32')},
                        ignores=['Tindices', 'Tparams', 'validate_indices',
-                                'Taxis', '_class'])(new_input, attr)
+                                'Taxis', '_class', 'batch_dims'])(new_input, attr)
     return _impl
 
 def _gather_nd():
@@ -1472,6 +1475,7 @@ _convert_map = {
     'Square'                            : _square(),
     'SquaredDifference'                 : _squared_difference(),
     'Squeeze'                           : _squeeze(),
+    'StopGradient'                      : _identity(),
     'StridedSlice'                      : _stridedSlice(),
     'Sub'                               : _elemwise('subtract'),
     'Sum'                               : _sum(),