From 59988f81bda9b0fd3db5cf61992a3b2ec8f4f147 Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Wed, 15 Sep 2021 13:43:12 -0700
Subject: [PATCH] Add embedding shape analysis (#64323)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64323

Test Plan: Imported from OSS

Reviewed By: driazati

Differential Revision: D30738145

Pulled By: eellison

fbshipit-source-id: be12408330d671bc65cf645aa2c20fafd954e6a9
---
 test/jit/test_symbolic_shape_analysis.py           | 38 ++++++++++++++++++++++
 torch/csrc/jit/passes/symbolic_shape_analysis.cpp  | 12 ++++---
 torch/csrc/jit/runtime/symbolic_shape_registry.cpp | 22 ++++++++++---
 3 files changed, 64 insertions(+), 8 deletions(-)

diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py
index 4ee1294..385e1f6 100644
--- a/test/jit/test_symbolic_shape_analysis.py
+++ b/test/jit/test_symbolic_shape_analysis.py
@@ -4,6 +4,8 @@ import operator
 
 from torch.testing import FileCheck
 
+from torch.testing._internal.common_utils import make_tensor
+
 from textwrap import dedent
 
 if __name__ == '__main__':
@@ -217,3 +219,39 @@ class TestSymbolicShapeAnalysis(JitTestCase):
             execWrapper(funcs_str, globals(), scope)
             cu = torch.jit.CompilationUnit(funcs_str)
             self.checkShapeAnalysis(list(cu.func().size()), cu.func.graph, assert_propagation=True, constant_prop=False)
+
+    def test_shape_embedding_bag(self):
+        # TODO: merge into opinfos, having difficulties there
+        with torch.no_grad():
+            def make_arg(shape, low=None, high=None):
+                return make_tensor(shape, device='cpu', dtype=torch.int64,
+                                   low=low, high=high, requires_grad=False)
+
+            nn_inps = (
+                (make_arg((40,), 0, 9), torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0)),
+                (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 20, sparse=True)),
+                (make_arg(()), torch.nn.Embedding(0, 0, sparse=True)),
+                (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 0, sparse=True)),
+                (make_arg((4,), 0, 21), torch.nn.Embedding(22, 5, max_norm=1.0)),
+                (make_arg((2,), 0, 1), torch.nn.Embedding.from_pretrained(torch.arange(6.).view(2, 3), max_norm=2.,
+                                                                          norm_type=.5, scale_grad_by_freq=False, sparse=True)),
+            )
+
+            for inp, module in nn_inps:
+                kwargs = {
+                    "weight": module.weight.detach(),
+                    "padding_idx": module.padding_idx,
+                    "max_norm": module.max_norm,
+                    "norm_type": module.norm_type,
+                    "scale_grad_by_freq": module.scale_grad_by_freq,
+                    "sparse": module.sparse,
+                }
+
+                out_size = torch.nn.functional.embedding(inp, **kwargs).size()
+
+                def foo(x):
+                    return torch.nn.functional.embedding(inp, **kwargs)
+
+                fn = torch.jit.trace(foo, (inp.detach(),), check_trace=False)
+
+                self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True, constant_prop=False)
diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
index d00a717..dcee6ce 100644
--- a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
@@ -210,10 +210,14 @@ struct SymbolicShapeAnalyzer {
         continue;
       }
       // TODO: remove, all constant tensors should have typed sizes
-      if (toIValue(node_->input(i)) && !symbolic_shape_analysis_test_mode) {
-        replaceWithIValue(
-            graph_->inputs().at(i),
-            constant_as<at::Tensor>(node_->input(i))->sizes());
+      if (toIValue(node_->input(i))) {
+        auto size = constant_as<at::Tensor>(node_->input(i))->sizes();
+        if (!symbolic_shape_analysis_test_mode) {
+          replaceWithIValue(graph_->inputs().at(i), size);
+        } else {
+          node_symbolic_input_indices_.emplace_back(
+              i, c10::SymbolicShape(size));
+        }
         continue;
       }
 
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index 3751eb6..0509abd 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
@@ -55,19 +55,22 @@ const std::string shape_compute_functions =
             out.append(elem)
           return out
 
-        def unary_five_unused_inputs(self: List[int], inp0: Any, inp1: Any, inp2: Any, inp3: Any, inp4: Any):
+        def unary(self: List[int]):
+          return _copy(self)
+
+        def unary_one_unused_input(self: List[int], inp0: Any):
           return _copy(self)
 
         def unary_two_unused_inputs(self: List[int], inp0: Any, inp1: Any):
           return _copy(self)
 
-        def unary_one_unused_input(self: List[int], inp0: Any):
+        def unary_three_unused_input(self: List[int], inp0: Any, inp1: Any, inp2: Any):
           return _copy(self)
 
         def unary_four_unused_inputs(self: List[int], inp0: Any, inp1: Any, inp2: Any, inp3: Any):
           return _copy(self)
 
-        def unary(self: List[int]):
+        def unary_five_unused_inputs(self: List[int], inp0: Any, inp1: Any, inp2: Any, inp3: Any, inp4: Any):
           return _copy(self)
 
         def expand(self: List[int], sizes: List[int]):
@@ -267,6 +270,14 @@ const std::string shape_compute_functions =
               result_size.append(self[i])
           return result_size
 
+        def embedding(weight: List[int], indices: List[int], padding_idx:int = -1, scale_grad_by_freq:bool=False, sparse: bool=False):
+          assert len(weight) == 2
+          if len(indices) == 1:
+            return index_select(weight, 0, indices)
+          size = _copy(indices)
+          size.append(weight[1])
+          return size
+
         def max_int():
           return 9223372036854775807
 
@@ -534,7 +545,7 @@ const std::string shape_compute_functions =
           return linear(input, weight, bias)
     )"
 #endif
-;
+    ;
 
 // mapping function schema to shape compute graphs allows multiple functions to
 // share the same shape compute graph, which is memory efficient and also will
@@ -586,6 +597,9 @@ static const OperatorMap<std::string>& get_schema_to_function_graph() {
       {"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
       "float eps=1e-05, bool cudnn_enable=True) -> Tensor", "unary_five_unused_inputs"},
      {"aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "unary_two_unused_inputs"},
+      {"aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor", "unary_three_unused_input"},
+      {"aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", "unary_three_unused_input"},
+      {"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", "embedding"},
       {"aten::mm(Tensor self, Tensor mat2) -> Tensor", "mm"},
       {"aten::dot(Tensor self, Tensor tensor) -> Tensor", "dot"},
       {"aten::mv(Tensor self, Tensor vec) -> Tensor", "mv"},
-- 
2.7.4
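
For reference, a small standalone Python sketch (not part of the patch; the name embedding_shape and the self-check below are illustrative only) of the shape rule this patch registers for aten::embedding: with a 2-D weight, 1-D indices behave like an index_select over dim 0, while N-D (including 0-D) indices produce the index shape with the embedding dimension appended.

from typing import List

import torch


def embedding_shape(weight: List[int], indices: List[int]) -> List[int]:
    # Mirrors the `embedding` entry added to shape_compute_functions above:
    # weight must be 2-D; 1-D indices reduce to an index_select along dim 0,
    # otherwise the result is the indices' shape plus the embedding dim.
    assert len(weight) == 2
    if len(indices) == 1:
        return [indices[0], weight[1]]
    return list(indices) + [weight[1]]


if __name__ == "__main__":
    # Sanity-check the rule against the eager op for a few index shapes.
    weight = torch.randn(10, 3)
    for idx_shape in [(4,), (2, 4), ()]:
        indices = torch.randint(0, 10, idx_shape)
        expected = list(torch.nn.functional.embedding(indices, weight).shape)
        assert embedding_shape(list(weight.shape), list(indices.shape)) == expected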