2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "reshape_inst.h"
19 #include "primitive_type_base.h"
20 #include "memory_impl.h"
21 #include "error_handler.h"
22 #include "json_object.h"
// Accessor for the primitive type id of the reshape primitive.
// Declares a function-local static primitive_type_base<reshape> object.
// NOTE(review): the return statement is not visible in this excerpt;
// presumably the address of `instance` is returned - confirm.
27 primitive_type_id reshape_type_id()
29 static primitive_type_base<reshape> instance;
// Computes the output layout of a reshape node.
//
// Starts from the input's non-padded output layout and the primitive's
// requested output_shape, adjusts the per-dimension sizes inside the loop,
// and stores the resulting tensor back into the layout's size.
// The error message below states that only one dimension may be -1.
// NOTE(review): several lines of this function (the conditions inside the
// loop, the assignment of need_recalc, the return) are not visible in this
// excerpt - the comments below describe only what is shown.
33 layout reshape_inst::calc_output_layout(reshape_node const& node)
// Forcing an output data type on the primitive is rejected (debug builds only, via assert).
35 assert((bool)node.get_primitive()->output_data_type == false
36 && "Output data type forcing is not supported for reshape_node!");
37 auto input_layout = node.input().get_non_padded_output_layout();
38 auto sizes = node.get_primitive()->output_shape.sizes();
39 auto input_sizes = input_layout.size.sizes();
// NOTE(review): need_recalc starts at 0 and is later used as an index into sizes;
// if 0 also means "no dimension to recalculate", a -1 in dimension 0 could not be
// told apart from "none" - verify against the full source.
40 size_t need_recalc = 0;
// Running product of the dimension sizes accumulated in the loop.
41 uint32_t shape_count = 1;
43 for (size_t i = 0; i < sizes.size(); i++) {
// Reported when the requested shape contains more than one -1 entry (per the message).
46 CLDNN_ERROR_MESSAGE(node.id(), "Only one dimension of the new shape can be -1");
// Takes the size of dimension i from the input (guarding condition not visible here).
52 sizes[i] = input_sizes[i];
54 shape_count *= sizes[i];
// Fills the remaining dimension with input element count divided by the product of the others.
// NOTE(review): the left operand is cast to int while shape_count is uint32_t, so the
// division is performed after the usual arithmetic conversions - confirm this is intended.
57 sizes[need_recalc] = (int)input_layout.size.count() / shape_count;
// Reuse the input layout object, replacing only its size with the reshaped tensor.
59 input_layout.size = tensor(sizes);
// Builds a textual description of a reshape node.
//
// Takes the node's JSON description from desc_to_json(), attaches a
// "reshape info" composite holding the input id and the requested output
// shape, dumps the result into a string stream and returns its contents.
63 std::string reshape_inst::to_string(reshape_node const& node)
65 auto desc = node.get_primitive();
66 auto node_info = node.desc_to_json();
67 auto& input = node.input();
// Stream that receives the dumped JSON.
69 std::stringstream primitive_description;
// Reshape-specific fields, added under the "reshape info" key below.
71 json_composite reshape_info;
72 reshape_info.add("input id", input.id());
73 reshape_info.add("output shape", desc->output_shape);
75 node_info->add("reshape info", reshape_info);
76 node_info->dump(primitive_description);
78 return primitive_description.str();
// Constructs the reshape primitive instance for `node` within `network`.
//
// Passes `false` as the third argument to the parent constructor
// (NOTE(review): presumably disables up-front output allocation - confirm
// against the parent's signature), then checks the input and output layouts
// with the error-handler macros and allocates the output only when the node
// is not in-place.
// Fix: the data-type mismatch message previously read "data typr".
81 reshape_inst::typed_primitive_inst(network_impl& network, reshape_node const& node)
82 : parent(network, node, false)
84 auto input_layout = node.input().get_output_layout();
85 auto output_layout = node.get_output_layout();
// Compare the input and output data types through the error-handler macro.
86 CLDNN_ERROR_DATA_TYPES_MISMATCH(node.id(), "Input layout data type", input_layout.data_type, "output layout data type", output_layout.data_type, "");
// Compare the input and output element counts; the message explains a mismatch as a change in buffer size.
87 CLDNN_ERROR_NOT_EQUAL(node.id(), "Output layout count", output_layout.count(), "input layout count", input_layout.count(), "Output layout of reshape primitive changes size of input buffer");
89 //if reshape operated in-place, postpone creation of the output until network run,
90 //then create new memory object as the reinterpreted output of the previous primitive
91 if (!node.is_in_place())
92 _output = allocate_output();
// Execution hook of the reshape instance.
//
// Evaluates two conditions: whether the node is not in-place, and whether an
// output already exists that shares its buffer with the input.
// NOTE(review): the statements guarded by these conditions are not visible in
// this excerpt; presumably both are early exits that skip re-binding the
// output to the input buffer - confirm against the full source.
97 void reshape_inst::on_execute()
99 if (!node.is_in_place())
// True when _output is set and the engine reports output and input memory as the same buffer.
102 if (_output && _network.get_engine().is_the_same_buffer(output_memory(), input_memory()))
// Binds the output of this instance to the input's memory.
//
// Builds the dependencies first (input_memory() is read right after), then
// asks the engine to reinterpret the input buffer using the node's output
// layout and stores the result as _output.
108 void reshape_inst::reuse_input()
110 build_deps(); //reshape needs its dependencies built before reading the input memory
111 _output = _network.get_engine().reinterpret_buffer(input_memory(), node.get_output_layout());