Rename isTensor api -> isCompleteTensor (#18437)
authoreellison <elias_ellison@brown.edu>
Wed, 27 Mar 2019 21:29:45 +0000 (14:29 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 27 Mar 2019 21:46:06 +0000 (14:46 -0700)
Summary:
Is Tensor has been brought up as misleading a couple times, rename it isCompleteTensor for clarity.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18437

Differential Revision: D14605223

Pulled By: eellison

fbshipit-source-id: 189f67f12cbecd76516a04e67d8145c260c79036

torch/csrc/jit/ir.h
torch/csrc/jit/passes/onnx/peephole.cpp
torch/csrc/jit/python_ir.cpp
torch/onnx/symbolic.py

index afcadb6..3190203 100644 (file)
@@ -167,7 +167,7 @@ struct Value {
   bool requires_grad() const {
     return type()->requires_grad();
   }
-  bool isTensor() const {
+  bool isCompleteTensor() const {
     return type()->kind() == TypeKind::CompleteTensorType;
   }
   TORCH_API bool mustBeNone() const;
index 83f062a..ecb687d 100644 (file)
@@ -1,5 +1,5 @@
-#include <c10/util/Exception.h>
 #include <torch/csrc/jit/passes/onnx/peephole.h>
+#include <c10/util/Exception.h>
 
 #include <c10/util/Optional.h>
 
@@ -77,7 +77,9 @@ const std::vector<size_t>& getBroadcastPositions(Node* node) {
 // Determine whether `from` can broadcast to `to`, and if so at which
 // position. `from` must be a suffix of `to`, except that any
 // occurences of 1 in `from` are treated as wildcards.
-c10::optional<size_t> fusibleExpandTo(at::IntArrayRef from, at::IntArrayRef to) {
+c10::optional<size_t> fusibleExpandTo(
+    at::IntArrayRef from,
+    at::IntArrayRef to) {
   if (from.size() > to.size()) {
     return c10::nullopt;
   }
@@ -120,7 +122,8 @@ void fuseBroadcast(Block* b) {
       // always have this information (because expands are only ever traced,
       // not generated from symbolic), but if for some reason we don't
       // have it, we need to skip.
-      if (!unexpanded_input->isTensor() || !n->output()->isTensor())
+      if (!unexpanded_input->isCompleteTensor() ||
+          !n->output()->isCompleteTensor())
         continue;
 
       // Not all broadcasts are supported by ONNX broadcast.
@@ -372,7 +375,9 @@ void fixDefaultRNNState(Graph* graph, Node* n, int input_index) {
 
   Node* gather_indices = graph->create(onnx::Constant, 1);
   gather_indices->insertBefore(n);
-  gather_indices->t_(attr::value, autograd::make_variable(at::scalar_to_tensor(at::Scalar(1))));
+  gather_indices->t_(
+      attr::value,
+      autograd::make_variable(at::scalar_to_tensor(at::Scalar(1))));
 
   Node* batch_size = graph->create(onnx::Gather, 1);
   batch_size->insertBefore(n);
@@ -563,16 +568,18 @@ static void eraseListConstruct(Block* block) {
   }
 }
 
-static void fuseSplitListUnpack(Block *b) {
-  for(auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
-    for (auto *child_block : it->blocks()) {
+static void fuseSplitListUnpack(Blockb) {
+  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
+    for (autochild_block : it->blocks()) {
       fuseSplitListUnpack(child_block);
     }
-    if (it->kind() == prim::ListUnpack && it->input()->node()->kind() == onnx::Split) {
+    if (it->kind() == prim::ListUnpack &&
+        it->input()->node()->kind() == onnx::Split) {
       auto origSplitNode = it->input()->node();
 
-      Node * splitNode = b->owningGraph()->create(onnx::Split, it->outputs().size());
-      for (size_t i=0; i<splitNode->outputs().size(); ++i) {
+      Node* splitNode =
+          b->owningGraph()->create(onnx::Split, it->outputs().size());
+      for (size_t i = 0; i < splitNode->outputs().size(); ++i) {
         splitNode->outputs()[i]->copyMetadata(it->outputs()[i]);
       }
       splitNode->copyAttributes(*origSplitNode);
@@ -588,7 +595,7 @@ static void fuseSplitListUnpack(Block *b) {
 }
 
 void removeMaxPoolUnusedOutput(Block* b) {
-    for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
+  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
     auto n = *it;
     for (auto* child_block : n->blocks()) {
       removeMaxPoolUnusedOutput(child_block);
index 8131d6b..95c06fe 100644 (file)
@@ -384,7 +384,7 @@ void initPythonIRBindings(PyObject* module_) {
             return node;
           })
       .VS(copyMetadata)
-      .VS(isTensor)
+      .VS(isCompleteTensor)
       .VS(requires_grad)
       .def("toIValue", [](Value& n) { return toIValue(&n); })
       .def("type", [](Value& v) { return v.type(); });
@@ -422,16 +422,8 @@ void initPythonIRBindings(PyObject* module_) {
           [](Block& b) {
             return py::make_iterator(b.outputs().begin(), b.outputs().end());
           })
-      .def(
-          "returnNode",
-          [](Block& b) {
-            return b.return_node();
-          })
-      .def(
-          "paramNode",
-          [](Block& b) {
-            return b.param_node();
-          });
+      .def("returnNode", [](Block& b) { return b.return_node(); })
+      .def("paramNode", [](Block& b) { return b.param_node(); });
 
 #define NS(name) def(#name, &Node ::name)
   py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")
index 157b895..fbb8d97 100644 (file)
@@ -475,7 +475,7 @@ def view(g, self, size):
     if _is_value(size):
         shape = size
     else:
-        if self.isTensor():
+        if self.isCompleteTensor():
             self_sizes = self.type().sizes()
             if self_sizes and len(size) == 2 and self_sizes[0] == size[0]:
                 return g.op("Flatten", self, axis_i=1)
@@ -1061,10 +1061,10 @@ def index_put(g, self, indices_list_value, values, accumulate):
 
 
 def type_as(g, self, other):
-    if self.isTensor() and other.isTensor() and self.type().scalarType() == other.type().scalarType():
+    if self.isCompleteTensor() and other.isCompleteTensor() and self.type().scalarType() == other.type().scalarType():
         return self
 
-    if other.isTensor():
+    if other.isCompleteTensor():
         other_type_name = other.type().scalarType()
         return g.op("Cast", self, to_i=cast_pytorch_to_onnx[other_type_name])
     else:
@@ -1390,7 +1390,7 @@ def repeat(g, self, repeats):
         repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
     const_repeats = _maybe_get_const(repeats, 'is')
 
-    if self.isTensor() and not _is_value(const_repeats):
+    if self.isCompleteTensor() and not _is_value(const_repeats):
         sizes = self.type().sizes()
         diff_dims = len(const_repeats) - len(sizes)
         if diff_dims > 0: