From: eellison
Date: Wed, 27 Mar 2019 21:29:45 +0000 (-0700)
Subject: Rename isTensor api -> isCompleteTensor (#18437)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~593
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e4f1681c82fe6a8c11400bec80b520aba9ceb495;p=platform%2Fupstream%2Fpytorch.git

Rename isTensor api -> isCompleteTensor (#18437)

Summary:
isTensor has been brought up as misleading a couple of times; rename it to isCompleteTensor for clarity. A usage sketch of the renamed binding follows the diff.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18437

Differential Revision: D14605223

Pulled By: eellison

fbshipit-source-id: 189f67f12cbecd76516a04e67d8145c260c79036
---

diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index afcadb6..3190203 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -167,7 +167,7 @@ struct Value {
   bool requires_grad() const {
     return type()->requires_grad();
   }
-  bool isTensor() const {
+  bool isCompleteTensor() const {
     return type()->kind() == TypeKind::CompleteTensorType;
   }
   TORCH_API bool mustBeNone() const;
diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp
index 83f062a..ecb687d 100644
--- a/torch/csrc/jit/passes/onnx/peephole.cpp
+++ b/torch/csrc/jit/passes/onnx/peephole.cpp
@@ -1,5 +1,5 @@
-#include
 #include
+#include
 
 #include
 
@@ -77,7 +77,9 @@ const std::vector<size_t>& getBroadcastPositions(Node* node) {
 // Determine whether `from` can broadcast to `to`, and if so at which
 // position. `from` must be a suffix of `to`, except that any
 // occurences of 1 in `from` are treated as wildcards.
-c10::optional<size_t> fusibleExpandTo(at::IntArrayRef from, at::IntArrayRef to) {
+c10::optional<size_t> fusibleExpandTo(
+    at::IntArrayRef from,
+    at::IntArrayRef to) {
   if (from.size() > to.size()) {
     return c10::nullopt;
   }
@@ -120,7 +122,8 @@ void fuseBroadcast(Block* b) {
     // always have this information (because expands are only ever traced,
     // not generated from symbolic), but if for some reason we don't
     // have it, we need to skip.
-    if (!unexpanded_input->isTensor() || !n->output()->isTensor())
+    if (!unexpanded_input->isCompleteTensor() ||
+        !n->output()->isCompleteTensor())
       continue;
 
     // Not all broadcasts are supported by ONNX broadcast.
@@ -372,7 +375,9 @@ void fixDefaultRNNState(Graph* graph, Node* n, int input_index) {
 
   Node* gather_indices = graph->create(onnx::Constant, 1);
   gather_indices->insertBefore(n);
-  gather_indices->t_(attr::value, autograd::make_variable(at::scalar_to_tensor(at::Scalar(1))));
+  gather_indices->t_(
+      attr::value,
+      autograd::make_variable(at::scalar_to_tensor(at::Scalar(1))));
 
   Node* batch_size = graph->create(onnx::Gather, 1);
   batch_size->insertBefore(n);
@@ -563,16 +568,18 @@ static void eraseListConstruct(Block* block) {
   }
 }
 
-static void fuseSplitListUnpack(Block *b) {
-  for(auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
-    for (auto *child_block : it->blocks()) {
+static void fuseSplitListUnpack(Block* b) {
+  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
+    for (auto* child_block : it->blocks()) {
       fuseSplitListUnpack(child_block);
     }
-    if (it->kind() == prim::ListUnpack && it->input()->node()->kind() == onnx::Split) {
+    if (it->kind() == prim::ListUnpack &&
+        it->input()->node()->kind() == onnx::Split) {
       auto origSplitNode = it->input()->node();
-      Node * splitNode = b->owningGraph()->create(onnx::Split, it->outputs().size());
-      for (size_t i=0; i<splitNode->outputs().size(); ++i) {
+      Node* splitNode =
+          b->owningGraph()->create(onnx::Split, it->outputs().size());
+      for (size_t i = 0; i < splitNode->outputs().size(); ++i) {
         splitNode->outputs()[i]->copyMetadata(it->outputs()[i]);
       }
       splitNode->copyAttributes(*origSplitNode);
@@ -588,7 +595,7 @@ static void fuseSplitListUnpack(Block *b) {
 }
 
 void removeMaxPoolUnusedOutput(Block* b) {
-  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
+  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
     auto n = *it;
     for (auto* child_block : n->blocks()) {
       removeMaxPoolUnusedOutput(child_block);
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 8131d6b..95c06fe 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -384,7 +384,7 @@ void initPythonIRBindings(PyObject* module_) {
         return node;
       })
       .VS(copyMetadata)
-      .VS(isTensor)
+      .VS(isCompleteTensor)
       .VS(requires_grad)
       .def("toIValue", [](Value& n) { return toIValue(&n); })
       .def("type", [](Value& v) { return v.type(); });
@@ -422,16 +422,8 @@ void initPythonIRBindings(PyObject* module_) {
           [](Block& b) {
             return py::make_iterator(b.outputs().begin(), b.outputs().end());
           })
-      .def(
-          "returnNode",
-          [](Block& b) {
-            return b.return_node();
-          })
-      .def(
-          "paramNode",
-          [](Block& b) {
-            return b.param_node();
-          });
+      .def("returnNode", [](Block& b) { return b.return_node(); })
+      .def("paramNode", [](Block& b) { return b.param_node(); });
 
 #define NS(name) def(#name, &Node ::name)
   py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 157b895..fbb8d97 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -475,7 +475,7 @@ def view(g, self, size):
     if _is_value(size):
         shape = size
     else:
-        if self.isTensor():
+        if self.isCompleteTensor():
             self_sizes = self.type().sizes()
             if self_sizes and len(size) == 2 and self_sizes[0] == size[0]:
                 return g.op("Flatten", self, axis_i=1)
@@ -1061,10 +1061,10 @@ def index_put(g, self, indices_list_value, values, accumulate):
 
 
 def type_as(g, self, other):
-    if self.isTensor() and other.isTensor() and self.type().scalarType() == other.type().scalarType():
+    if self.isCompleteTensor() and other.isCompleteTensor() and self.type().scalarType() == other.type().scalarType():
         return self
-    if other.isTensor():
+    if other.isCompleteTensor():
         other_type_name = other.type().scalarType()
         return g.op("Cast", self, to_i=cast_pytorch_to_onnx[other_type_name])
     else:
@@ -1390,7 +1390,7 @@ def repeat(g, self, repeats):
         repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
     const_repeats = _maybe_get_const(repeats, 'is')
 
-    if self.isTensor() and not _is_value(const_repeats):
+    if self.isCompleteTensor() and not _is_value(const_repeats):
         sizes = self.type().sizes()
         diff_dims = len(const_repeats) - len(sizes)
         if diff_dims > 0:
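
A minimal usage sketch of the renamed binding, not part of the commit: it assumes a PyTorch build containing this change, and the throwaway function f below is hypothetical. Tracing records concrete shape, dtype, and device information on graph values, which is exactly when isCompleteTensor() (formerly isTensor()) returns true, per the CompleteTensorType check in ir.h above:

    import torch

    def f(x):
        # Any traceable function works for illustration.
        return x + 1

    # torch.jit.trace records complete tensor types on the traced graph.
    traced = torch.jit.trace(f, torch.randn(2, 3))

    for value in traced.graph.inputs():
        # True only when the value's type is a CompleteTensorType,
        # i.e. all sizes, dtype, and device are statically known.
        print(value.isCompleteTensor())

The new name makes explicit that the check is about having complete static type information, not merely about the value being a tensor.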