Add embedding shape analysis (#64323)
author Elias Ellison <eellison@devfair044.h1.fair>
Wed, 15 Sep 2021 20:43:12 +0000 (13:43 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 15 Sep 2021 20:45:48 +0000 (13:45 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64323

Test Plan: Imported from OSS

Reviewed By: driazati

Differential Revision: D30738145

Pulled By: eellison

fbshipit-source-id: be12408330d671bc65cf645aa2c20fafd954e6a9

test/jit/test_symbolic_shape_analysis.py
torch/csrc/jit/passes/symbolic_shape_analysis.cpp
torch/csrc/jit/runtime/symbolic_shape_registry.cpp

diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py
index 4ee1294..385e1f6 100644
@@ -4,6 +4,8 @@ import operator
 
 
 from torch.testing import FileCheck
+from torch.testing._internal.common_utils import make_tensor
+
 from textwrap import dedent
 
 if __name__ == '__main__':
@@ -217,3 +219,39 @@ class TestSymbolicShapeAnalysis(JitTestCase):
             execWrapper(funcs_str, globals(), scope)
             cu = torch.jit.CompilationUnit(funcs_str)
             self.checkShapeAnalysis(list(cu.func().size()), cu.func.graph, assert_propagation=True, constant_prop=False)
+
+    def test_shape_embedding_bag(self):
+        # TODO: merge into opinfos, having difficulties there
+        with torch.no_grad():
+            def make_arg(shape, low=None, high=None):
+                return make_tensor(shape, device='cpu', dtype=torch.int64,
+                                   low=low, high=high, requires_grad=False)
+
+            nn_inps = (
+                (make_arg((40,), 0, 9), torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0)),
+                (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 20, sparse=True)),
+                (make_arg(()), torch.nn.Embedding(0, 0, sparse=True)),
+                (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 0, sparse=True)),
+                (make_arg((4,), 0, 21), torch.nn.Embedding(22, 5, max_norm=1.0)),
+                (make_arg((2,), 0, 1), torch.nn.Embedding.from_pretrained(torch.arange(6.).view(2, 3), max_norm=2.,
+                                                                          norm_type=.5, scale_grad_by_freq=False, sparse=True)),
+            )
+
+            for inp, module in nn_inps:
+                kwargs = {
+                    "weight": module.weight.detach(),
+                    "padding_idx": module.padding_idx,
+                    "max_norm": module.max_norm,
+                    "norm_type": module.norm_type,
+                    "scale_grad_by_freq": module.scale_grad_by_freq,
+                    "sparse": module.sparse,
+                }
+
+                out_size = torch.nn.functional.embedding(inp, **kwargs).size()
+
+                def foo(x):
+                    return torch.nn.functional.embedding(inp, **kwargs)
+
+                fn = torch.jit.trace(foo, (inp.detach(),), check_trace=False)
+
+                self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True, constant_prop=False)
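
For reference (not part of the commit): the relationship the new test exercises is that, for a 2-D weight, torch.nn.functional.embedding returns a tensor whose shape is the indices' shape with embedding_dim appended. A minimal standalone sketch using only public torch APIs:

    # minimal sketch, not from the commit: expected embedding output shape
    import torch
    import torch.nn.functional as F

    weight = torch.randn(10, 20)             # num_embeddings=10, embedding_dim=20
    indices = torch.randint(0, 10, (2, 4))   # arbitrary index shape
    out = F.embedding(indices, weight)
    assert out.size() == torch.Size([2, 4, 20])  # indices shape + [embedding_dim]
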
diff --git a/torch/csrc/jit/passes/symbolic_shape_analysis.cpp b/torch/csrc/jit/passes/symbolic_shape_analysis.cpp
index d00a717..dcee6ce 100644
@@ -210,10 +210,14 @@ struct SymbolicShapeAnalyzer {
           continue;
         }
         // TODO: remove, all constant tensors should have typed sizes
-        if (toIValue(node_->input(i)) && !symbolic_shape_analysis_test_mode) {
-          replaceWithIValue(
-              graph_->inputs().at(i),
-              constant_as<at::Tensor>(node_->input(i))->sizes());
+        if (toIValue(node_->input(i))) {
+          auto size = constant_as<at::Tensor>(node_->input(i))->sizes();
+          if (!symbolic_shape_analysis_test_mode) {
+            replaceWithIValue(graph_->inputs().at(i), size);
+          } else {
+            node_symbolic_input_indices_.emplace_back(
+                i, c10::SymbolicShape(size));
+          }
           continue;
         }
 
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index 3751eb6..0509abd 100644
@@ -55,19 +55,22 @@ const std::string shape_compute_functions =
             out.append(elem)
           return out
 
-        def unary_five_unused_inputs(self: List[int], inp0: Any, inp1: Any, inp2: Any, inp3: Any, inp4: Any):
+        def unary(self: List[int]):
+          return _copy(self)
+
+        def unary_one_unused_input(self: List[int], inp0: Any):
           return _copy(self)
 
         def unary_two_unused_inputs(self: List[int], inp0: Any, inp1: Any):
           return _copy(self)
 
-        def unary_one_unused_input(self: List[int], inp0: Any):
+        def unary_three_unused_input(self: List[int], inp0: Any, inp1: Any, inp2: Any):
           return _copy(self)
 
         def unary_four_unused_inputs(self: List[int], inp0: Any, inp1: Any, inp2: Any, inp3: Any):
           return _copy(self)
 
-        def unary(self: List[int]):
+        def unary_five_unused_inputs(self: List[int], inp0: Any, inp1: Any, inp2: Any, inp3: Any, inp4: Any):
           return _copy(self)
 
         def expand(self: List[int], sizes: List[int]):
@@ -267,6 +270,14 @@ const std::string shape_compute_functions =
               result_size.append(self[i])
           return result_size
 
+        def embedding(weight: List[int], indices: List[int], padding_idx:int = -1, scale_grad_by_freq:bool=False, sparse: bool=False):
+          assert len(weight) == 2
+          if len(indices) == 1:
+            return index_select(weight, 0, indices)
+          size = _copy(indices)
+          size.append(weight[1])
+          return size
+
         def max_int():
           return 9223372036854775807
 
@@ -534,7 +545,7 @@ const std::string shape_compute_functions =
           return linear(input, weight, bias)
     )"
 #endif
-;
+    ;
 
 // mapping function schema to shape compute graphs allows multiple functions to
 // share the same shape compute graph, which is memory efficient and also will
@@ -586,6 +597,9 @@ static const OperatorMap<std::string>& get_schema_to_function_graph() {
       {"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, "
        "float eps=1e-05, bool cudnn_enable=True) -> Tensor", "unary_five_unused_inputs"},
       {"aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "unary_two_unused_inputs"},
+      {"aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor", "unary_three_unused_input"},
+      {"aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", "unary_three_unused_input"},
+      {"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", "embedding"},
       {"aten::mm(Tensor self, Tensor mat2) -> Tensor", "mm"},
       {"aten::dot(Tensor self, Tensor tensor) -> Tensor", "dot"},
       {"aten::mv(Tensor self, Tensor vec) -> Tensor", "mv"},