a = next(graph.outputs()).type().kind()
self.assertTrue(next(graph.outputs()).type().kind() != 'TensorType')
+ def test_shape_prop_promotion(self):
+ # Shape propagation must apply dtype promotion for tensor-tensor ops:
+ # float + double inputs should propagate a Double output type.
+ @torch.jit.script
+ def fn(x, y):
+ return x + y
+
+ x, y = torch.rand(3, 4, dtype=torch.float), torch.rand(3, 4, dtype=torch.double)
+ graph = fn._get_method('forward').propagate_shapes((x, y), False)
+ FileCheck().check('Double(*, *) = aten::add').run(graph)
+
+ def test_shape_prop_promote_scalar_arg(self):
+ # A Python number scalar (math.pi) added to a Long tensor must not
+ # promote the propagated result: the output type stays Long.
+ @torch.jit.script
+ def fn(x):
+ return math.pi + x
+
+ x = torch.zeros(3, 4, dtype=torch.long)
+ graph = fn._get_method('forward').propagate_shapes((x,), False)
+ FileCheck().check('Long(*, *) = aten::add').run(graph)
+
def test_integral_shape_inference(self):
cu = torch.jit.CompilationUnit('''
def test_integral_shape_inference(a):
"aten::__ilshift__(Tensor self, Tensor other) -> Tensor",
"aten::__irshift__(Tensor self, Tensor other) -> Tensor",
+ // Ops with Tensor-Tensor overloads only
+ "aten::atan2(Tensor self, Tensor other) -> Tensor",
+ },
+ [this](Node* node) -> type_vec_t {
+ if (auto maybe_tensor_types =
+ gatherTensorTypes<DimensionedTensorType>(node)) {
+ AT_ASSERT(maybe_tensor_types->size() == 2);
+ auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
+ auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType();
+ size_t arg_for_type = 0;
+ if (c10::promoteTypes(first_scalar_type, second_scalar_type) != first_scalar_type) {
+ arg_for_type = 1;
+ }
+ return {broadcast(*maybe_tensor_types, arg_for_type)};
+ }
+ return {};
+ }};
+
+ // Fused accumulate ops (addcdiv/addcmul): the output type is taken from
+ // the first tensor argument (self), broadcast over all tensor inputs.
+ static const register_formula_for fused_accum_binary_ops{
+ {
+ // Non-binary ops
+ "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
+ "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
+ },
+ [this](Node* node) -> type_vec_t {
+ if (auto maybe_tensor_types =
+ gatherTensorTypes<DimensionedTensorType>(node)) {
+ // Index 0 selects self's scalar type as the result type.
+ return {broadcast(*maybe_tensor_types, 0)};
+ }
+ return {};
+ }};
+
+ // NB: we always take the scalar type of the Tensor
+ static const register_formula_for broadcasting_tensor_scalar_ops{
+ {
+
// Tensor-Scalar operators
"aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
"aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
"aten::__ixor__(Tensor self, Scalar other) -> Tensor",
"aten::__ilshift__(Tensor self, Scalar other) -> Tensor",
"aten::__irshift__(Tensor self, Scalar other) -> Tensor",
-
- // Ops with Tensor-Tensor overloads only
- "aten::atan2(Tensor self, Tensor other) -> Tensor",
-
- // Non-binary ops
- "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
- "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
},
[this](Node* node) -> type_vec_t {
if (auto maybe_tensor_types =