        inputs[1].setType(inputs[1].type().with_sizes(size_2))
        torch._C._jit_pass_propagate_shapes_on_graph(t.graph)
        self.assertEqual(next(t.graph.outputs()).type().symbolic_sizes(), [4, 4, 8])
+
+    def test_size_and_sizes(self):
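+        # foo builds its view shape from y.size(0) and y.size(-1); foo2 passes y.size() directly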
+        @torch.jit.script
+        def foo(x, y):
+            return x.view(y.size(0), 8, y.size(-1))
+
+        @torch.jit.script
+        def foo2(x, y):
+            return x.view(y.size())
+
+        for graph in [foo.graph, foo2.graph]:
+            inputs = list(graph.inputs())
+            sym1 = torch._C._new_symbolic_shape_symbol()
+
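+            # pin y to the shape [5, 8, <symbolic>]; the view output should inherit it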
+            inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1]))
+            torch._C._jit_pass_propagate_shapes_on_graph(graph)
+            self.assertEqual(next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1])
        }
        // we can't optimize a tensor without fixed rank
        if (symbolic_shapes.rank()) {
-          node_input_tensor_indices.push_back(i);
+          node_symbolic_input_indices.emplace_back(i, symbolic_shapes);
        }
      } else if (
          type->cast<ListType>() &&
          type->cast<ListType>()->getElementType()->cast<TensorType>()) {
        TORCH_INTERNAL_ASSERT(false); // not handled yet
-      } else {
-        if (auto ival = toIValue(node_->input(i))) {
-          replaceWithIValue(graph_->inputs().at(i), *ival);
+      } else if (auto ival = toIValue(node_->input(i))) {
+        replaceWithIValue(graph_->inputs().at(i), *ival);
+      } else if (
+          type->cast<ListType>() &&
+          type->cast<ListType>()->getElementType()->cast<IntType>()) {
+        if (node_->input(i)->node()->kind() == prim::ListConstruct &&
+            node_->input(i)->uses().size() == 1) {
+          // it is very common in graphs to see patterns like:
+          // z = x.view(y.size())
+          // or:
+          // z = x.view(1, 10, y.size(0), y.size(1))
+          // We want to propagate symbolic dimensions and concrete sizes
+          // from y to z. To do this, we try to associate symbolic dimensions
+          // or concrete sizes with integer list inputs that are either
+          // constructed from constants and calls like y.size(0), or taken
+          // directly from y.size()
+          auto list_construct = node_->input(i)->node();
+          std::vector<c10::optional<int64_t>> shape;
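+          // each entry is a concrete size or a symbolic dimension value;
+          // c10::nullopt marks a dimension we cannot resolve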
+          for (Value* v : list_construct->inputs()) {
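+            // a literal constant, e.g. the 8 in x.view(..., 8, ...)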
+            if (auto constant = constant_as<int64_t>(v)) {
+              shape.emplace_back(*constant);
+            } else if (
+                v->node()->kind() == aten::size &&
+                v->node()->input(1)->node()->kind() == prim::Constant) {
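+              // v comes from y.size(dim) with a constant dim: look that
+              // dimension up in y's symbolic shape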
+              auto tt = v->node()->input(0)->type()->expect<TensorType>();
+              auto ss = tt->symbolic_sizes();
+              if (!ss.rank()) {
+                shape.emplace_back(c10::nullopt);
+                continue;
+              }
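+              // normalize a possibly negative index, e.g. y.size(-1),
+              // against the known rank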
+              auto norm_index = normIndex(
+                  *constant_as<int64_t>(v->node()->input(1)), *ss.rank());
+              if (!norm_index) {
+                shape.emplace_back(c10::nullopt);
+                continue;
+              }
+              shape.emplace_back(ss[*norm_index].value());
+            } else {
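+              // any other producer of a list element is opaque to this analysis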
+              shape.emplace_back(c10::nullopt);
+            }
+          }
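+          // record the reconstructed shape for this integer list input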
+          node_symbolic_input_indices.emplace_back(
+              i, c10::SymbolicShape(shape));
+        } else if (
+            node_->input(i)->node()->kind() == aten::size &&
+            node_->input(i)->uses().size() == 1) {
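+          // the z = x.view(y.size()) pattern: adopt y's full symbolic shape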
+          auto ten_inp = node_->input(i)->node()->input();
+          auto ss = ten_inp->type()->expect<TensorType>()->symbolic_sizes();
+          node_symbolic_input_indices.emplace_back(i, ss);
        }
      }
    }
    std::unordered_map<int64_t, std::vector<Value*>> symbolic_shape_map;
-    for (auto tensor_index : node_input_tensor_indices) {
-      auto tensor_value = node_->inputs().at(tensor_index);
-      auto tensor_shape =
-          tensor_value->type()->expect<TensorType>()->symbolic_sizes();
+    for (const auto& index_symbolic_shape : node_symbolic_input_indices) {
+      auto index = index_symbolic_shape.first;
+      auto tensor_shape = index_symbolic_shape.second;
      TORCH_INTERNAL_ASSERT(tensor_shape.rank().has_value());
-      for (const auto& use : graph_->inputs().at(tensor_index)->uses()) {
+      for (const auto& use : graph_->inputs().at(index)->uses()) {
        // TODO: either decompose composite ops like slice or add handling here
        switch (use.user->kind()) {
          case aten::len: {
  // and non-fixed dimensions we cannot reason about at all
  // TODO: might be cleaner to store as a pair of index -> symbolic shape
  // but there were weird lifetime issues
-  std::vector<int64_t> node_input_tensor_indices;
+  std::vector<std::pair<int64_t, c10::SymbolicShape>>
+      node_symbolic_input_indices;
  std::shared_ptr<Graph> graph_;
  Node* node_;
};