Summary:
For #6593 and #9515
This completes support for `optional<ScalarType>` in native functions, the JIT, and autograd.
Note: this mostly follows the existing implementation for `optional<Scalar>` that was added in https://github.com/pytorch/pytorch/pull/12582.
This PR introduces a way for functions to accept an optional dtype, which unblocks #9515 by allowing a `dtype` parameter in the type promotion interface:
```
func: name(inputs, *, ScalarType? dtype=None, Casting casting=same_kind)
```
An alternative approach would have been to use `ScalarType::Undefined` as a sentinel for the same purpose, avoiding the optional, but that would have been a bit hacky:
```
func: name(inputs, *, ScalarType dtype=Undefined, Casting casting=same_kind)
```
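With the optional form, the kernel side just checks whether a value is present. Here is a minimal sketch of what that can look like (the op `my_cast_op` is made up for illustration, not part of this PR):
```
#include <ATen/ATen.h>

// Hypothetical kernel: an empty optional means the caller passed dtype=None.
at::Tensor my_cast_op(const at::Tensor& self, c10::optional<at::ScalarType> dtype) {
  // Fall back to the input's own scalar type when no dtype was given,
  // rather than comparing against a ScalarType::Undefined sentinel.
  at::ScalarType target = dtype.value_or(self.scalar_type());
  return self.to(target);
}
```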
Here's an example use of this in action: https://github.com/pytorch/pytorch/pull/15133/commits/971f69eac69101955ed90078b44dab975d37a4f7
There are already a number of native functions that get an optional `dtype` through function overloading. https://github.com/pytorch/pytorch/pull/15133 attempts to migrate all of those. I will send those changes separately after this one, since some functions (e.g. `sum`) need quite a bit of change across the codebase; see the commits over there.
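To illustrate the kind of change involved (hypothetical signatures, not the actual ones in the codebase), an overload pair like the first two declarations below collapses into the single optional-taking form:
```
#include <ATen/ATen.h>

// Before: an optional dtype expressed as two overloads.
at::Tensor my_sum(const at::Tensor& self);
at::Tensor my_sum(const at::Tensor& self, at::ScalarType dtype);

// After: one signature with an optional dtype that defaults to "not given".
at::Tensor my_sum(const at::Tensor& self, c10::optional<at::ScalarType> dtype = c10::nullopt);
```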
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15154
Differential Revision: D13457760
Pulled By: tugrulates
fbshipit-source-id: 706134f0bd578683edd416b96329b49a1ba8ab48
'const Type &': 'scalartype',
'const THPLayout &': 'layout',
'const Device &': 'device',
- 'optional<ScalarType>': 'scalartypeOptional',
'int64_t': 'toInt64',
'bool': 'toBool',
'double': 'toDouble',
'Layout': 'Layout',
'Device': 'Device',
'ScalarType': 'ScalarType',
+ 'ScalarType?': 'ScalarType?',
'int64_t': 'int',
'double': 'float',
'bool': 'bool',
'Scalar': '{}.toScalar()',
'Scalar?': '{}.toOptional<Scalar>()',
'ScalarType': '{}.toScalarType()',
+ 'ScalarType?': '{}.toOptional<ScalarType>()',
'Tensor': '{}.toTensor()',
'TensorList': '{}.toTensorList()->elements()',
'bool': '{}.toBool()',
using autograd::Variable;
using autograd::variable_list;
using at::Scalar;
+using at::ScalarType;
using at::Tensor;
using at::TensorOptions;
using at::DeviceGuard;
else if(n->kind().is_prim()){
switch(n->kind()){
case prim::Constant:
+ case prim::None:
visitConstant(n, block, res_block);
break;
case prim::NumToTensor:
void addInputs(Node *n, const char * name, at::ScalarType value) {
detail::genericAddInput(n, static_cast<int64_t>(value));
}
+void addInputs(Node *n, const char * name, const c10::optional<at::ScalarType>& value) {
+  if (value) {
+    // A dtype was given: trace it the same way as a plain ScalarType.
+    detail::genericAddInput(n, static_cast<int64_t>(*value));
+  } else {
+    // No dtype was given (None): insert a None value into the graph for this input.
+    Graph* g = n->owningGraph();
+    Value* none =
+        g->insertNode(g->createNone(IntType::get()))
+            ->output();
+    n->addInput(none);
+  }
+}
void addInputs(Node *n, const char * name, at::TensorList value) {
Graph *g = n->owningGraph();
TORCH_API void addInputs(Node *n, const char * name, at::Device value);
TORCH_API void addInputs(Node *n, const char * name, at::Layout value);
TORCH_API void addInputs(Node *n, const char * name, at::ScalarType value);
+TORCH_API void addInputs(Node *n, const char * name, const c10::optional<at::ScalarType>& value);
TORCH_API void addInputs(Node *n, const char * name, at::Generator * value);
template<size_t N>
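As a rough usage sketch (assuming these declarations sit in the torch::jit::tracer namespace and that `node` is a Node* the tracer has already created for the op being recorded), a call where the caller left `dtype` unset routes through the new overload and records a None input rather than a concrete scalar type:
```
c10::optional<at::ScalarType> dtype;                  // empty: the Python caller passed dtype=None
torch::jit::tracer::addInputs(node, "dtype", dtype);  // inserts a None value into the traced graph
```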
{"Storage", ParameterType::STORAGE},
{"PyObject*", ParameterType::PYOBJECT},
{"ScalarType", ParameterType::SCALARTYPE},
- {"optional<ScalarType>", ParameterType::SCALARTYPE},
{"Layout", ParameterType::LAYOUT},
{"Device", ParameterType::DEVICE},
{"std::string", ParameterType::STRING},