axis = _get_num_param(params, inputs.pop(2))
else:
axis = 0
+ if int(attr.get('batch_dims', 0)) != 0:
+ raise tvm.error.OpAttributeUnImplemented(
+ 'Attribute batch_dims is not supported')
new_input = inputs[0:2]
return AttrCvt(op_name="take",
extras={'axis': tvm.const(axis, 'int32')},
ignores=['Tindices', 'Tparams', 'validate_indices',
- 'Taxis', '_class'])(new_input, attr)
+ 'Taxis', '_class', 'batch_dims'])(new_input, attr)
return _impl
def _gather_nd():
'Square' : _square(),
'SquaredDifference' : _squared_difference(),
'Squeeze' : _squeeze(),
+ 'StopGradient' : _identity(),
'StridedSlice' : _stridedSlice(),
'Sub' : _elemwise('subtract'),
'Sum' : _sum(),