from torch.testing import FileCheck
+from torch.testing._internal.common_utils import make_tensor
+
from textwrap import dedent
if __name__ == '__main__':
execWrapper(funcs_str, globals(), scope)
cu = torch.jit.CompilationUnit(funcs_str)
self.checkShapeAnalysis(list(cu.func().size()), cu.func.graph, assert_propagation=True, constant_prop=False)
+
+ def test_shape_embedding_bag(self):
+ # TODO: merge into OpInfos; having difficulties there
+ with torch.no_grad():
+ def make_arg(shape, low=None, high=None):
+ return make_tensor(shape, device='cpu', dtype=torch.int64,
+ low=low, high=high, requires_grad=False)
+
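+ # (indices, module) pairs covering 1-D, 2-D, and 0-D indices, empty
+ # embedding tables, max_norm renormalization, and sparse gradients.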
+ nn_inps = (
+ (make_arg((40,), 0, 9), torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0)),
+ (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 20, sparse=True)),
+ (make_arg(()), torch.nn.Embedding(0, 0, sparse=True)),
+ (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 0, sparse=True)),
+ (make_arg((4,), 0, 21), torch.nn.Embedding(22, 5, max_norm=1.0)),
+ (make_arg((2,), 0, 1), torch.nn.Embedding.from_pretrained(torch.arange(6.).view(2, 3), max_norm=2.,
+ norm_type=.5, scale_grad_by_freq=False, sparse=True)),
+ )
+
+ for inp, module in nn_inps:
+ kwargs = {
+ "weight": module.weight.detach(),
+ "padding_idx": module.padding_idx,
+ "max_norm": module.max_norm,
+ "norm_type": module.norm_type,
+ "scale_grad_by_freq": module.scale_grad_by_freq,
+ "sparse": module.sparse,
+ }
+
+ out_size = torch.nn.functional.embedding(inp, **kwargs).size()
+
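+ # `foo` closes over `inp` instead of using `x`, so tracing bakes the
+ # indices into the graph as a constant tensor; this exercises the
+ # constant-tensor input path in symbolic shape analysis.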
+ def foo(x):
+ return torch.nn.functional.embedding(inp, **kwargs)
+
+ fn = torch.jit.trace(foo, (inp.detach(),), check_trace=False)
+
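+ # Check that symbolic shape analysis recovers the eager output size
+ # without falling back to constant propagation.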
+ self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True, constant_prop=False)
continue;
}
// TODO: remove, all constant tensors should have typed sizes
- if (toIValue(node_->input(i)) && !symbolic_shape_analysis_test_mode) {
- replaceWithIValue(
- graph_->inputs().at(i),
- constant_as<at::Tensor>(node_->input(i))->sizes());
+ if (toIValue(node_->input(i))) {
+ auto size = constant_as<at::Tensor>(node_->input(i))->sizes();
+ if (!symbolic_shape_analysis_test_mode) {
+ replaceWithIValue(graph_->inputs().at(i), size);
+ } else {
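+ // In test mode, record the constant tensor's sizes as a symbolic
+ // shape input instead of inlining them, so that shape propagation
+ // itself is exercised rather than short-circuited.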
+ node_symbolic_input_indices_.emplace_back(
+ i, c10::SymbolicShape(size));
+ }
continue;
}
out.append(elem)
return out
- def unary_five_unused_inputs(self: List[int], inp0: Any, inp1: Any, inp2: Any, inp3: Any, inp4: Any):
+ def unary(self: List[int]):
+ return _copy(self)
+
+ def unary_one_unused_input(self: List[int], inp0: Any):
return _copy(self)
def unary_two_unused_inputs(self: List[int], inp0: Any, inp1: Any):
return _copy(self)
- def unary_one_unused_input(self: List[int], inp0: Any):
+ def unary_three_unused_inputs(self: List[int], inp0: Any, inp1: Any, inp2: Any):
return _copy(self)
def unary_four_unused_inputs(self: List[int], inp0: Any, inp1: Any, inp2: Any, inp3: Any):
return _copy(self)
- def unary(self: List[int]):
+ def unary_five_unused_inputs(self: List[int], inp0: Any, inp1: Any, inp2: Any, inp3: Any, inp4: Any):
return _copy(self)
def expand(self: List[int], sizes: List[int]):
result_size.append(self[i])
return result_size
+ def embedding(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False):
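+ # e.g. weight = [10, 20], indices = [2, 4] -> output size [2, 4, 20];
+ # 1-D indices select rows of `weight`, i.e. an index_select along dim 0.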
+ assert len(weight) == 2
+ if len(indices) == 1:
+ return index_select(weight, 0, indices)
+ size = _copy(indices)
+ size.append(weight[1])
+ return size
+
def max_int():
return 9223372036854775807
return linear(input, weight, bias)
)"
#endif
-;
+ ;
// mapping function schema to shape compute graphs allows multiple functions to
// share the same shape compute graph, which is memory efficient and also will
{"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
"float eps=1e-05, bool cudnn_enable=True) -> Tensor", "unary_five_unused_inputs"},
{"aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "unary_two_unused_inputs"},
+ {"aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor", "unary_three_unused_input"},
+ {"aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", "unary_three_unused_input"},
+ {"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", "embedding"},
{"aten::mm(Tensor self, Tensor mat2) -> Tensor", "mm"},
{"aten::dot(Tensor self, Tensor tensor) -> Tensor", "dot"},
{"aten::mv(Tensor self, Tensor vec) -> Tensor", "mv"},