# type: (bool) -> Dict[str, str]
return {
'Tensor': 'const Tensor &' if const else 'Tensor &',
+ 'BoolTensor': 'const Tensor &' if const else 'Tensor &',
+ 'IndexTensor': 'const Tensor &' if const else 'Tensor &',
'Type': 'const Type &' if const else 'Type &',
'TensorOptions': 'const TensorOptions &' if const else 'TensorOptions &',
'TensorList': 'TensorList',
`Tensor!` - shorthand for Tensor(fresh\_identifier!)
`Tensor(a! -> a|b)` - Tensor is in set `a`, written to, and after the write is in set `a` AND `b`.
For more details on when and why this needs to happen, please see the section on annotations.
+- Tensors of specific types. At the moment, valid type names are:
+ - `IntegerTensor` (a.k.a. `LongTensor`)
+ - `BoolTensor` (a.k.a. `ByteTensor`)
+ - `IndexTensor` (a.k.a. `IntTensor`)
+ These type names were inherited from TH, and may be renamed soon, so
+ don't commit them to memory.
- `Tensor[]`. A `Tensor[]` argument translates into a C++ argument of type `ArrayRef<Tensor>`
(a.k.a. `TensorList`)
- `int[]`. `int[]` accepts an optional length specifier, e.g., `int[2]`, which
base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name
view_info = VIEW_FUNCTIONS.get(base_name, None)
+ # These exclude things like BoolTensor, int64_t, and Scalar
def is_differentiable(arg):
if 'TensorOptions' in arg['type']:
return False
if 'Tensor' not in arg['type']:
return False
if arg['dynamic_type'] in {'IndexTensor', 'BoolTensor'}:
+ # TODO: Enable this after native_functions.yaml schema unification.
# These are necessary for legacy code and should be
# used by legacy code only!
- assert declaration['mode'] == 'TH' or declaration['mode'] == 'NN', \
- "IndexTensor and BoolTensor are restricted to legacy TH/THNN functions only."
+ # assert name.startswith('_th_'), \
+ # "IndexTensor and BoolTensor are restricted to legacy _th_ functions only.
return False
return True