Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / reshape.cpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "reshape_inst.h"
19 #include "primitive_type_base.h"
20 #include "memory_impl.h"
21 #include "error_handler.h"
22 #include "json_object.h"
23
24 namespace cldnn
25 {
26
27 primitive_type_id reshape_type_id()
28 {
29     static primitive_type_base<reshape> instance;
30     return &instance;
31 }
32
33 layout reshape_inst::calc_output_layout(reshape_node const& node)
34 {
35     assert((bool)node.get_primitive()->output_data_type == false
36            && "Output data type forcing is not supported for reshape_node!");
37     auto input_layout = node.input().get_non_padded_output_layout();
38     auto sizes = node.get_primitive()->output_shape.sizes();
39     auto input_sizes = input_layout.size.sizes();
40     size_t need_recalc = 0;
41     uint32_t shape_count = 1;
42
43     for (size_t i = 0; i < sizes.size(); i++) {
44         if (sizes[i] == -1) {
45             if (need_recalc) {
46                 CLDNN_ERROR_MESSAGE(node.id(), "Only one dimension of the new shape can be -1");
47             }
48             need_recalc = i;
49             continue;
50         }
51         if (sizes[i] == 0) {
52             sizes[i] = input_sizes[i];
53         }
54         shape_count *= sizes[i];
55     }
56     if (need_recalc)
57         sizes[need_recalc] = (int)input_layout.size.count() / shape_count;
58
59     input_layout.size = tensor(sizes);
60     return input_layout;
61 }
62
63 std::string reshape_inst::to_string(reshape_node const& node)
64 {
65     auto desc      = node.get_primitive();
66     auto node_info = node.desc_to_json();
67     auto& input    = node.input();
68
69     std::stringstream primitive_description;
70
71     json_composite reshape_info;
72     reshape_info.add("input id", input.id());
73     reshape_info.add("output shape", desc->output_shape);
74
75     node_info->add("reshape info", reshape_info);
76     node_info->dump(primitive_description);
77
78     return primitive_description.str();
79 }
80
81 reshape_inst::typed_primitive_inst(network_impl& network, reshape_node const& node)
82     : parent(network, node, false)
83 {
84     auto input_layout = node.input().get_output_layout();
85     auto output_layout = node.get_output_layout();
86     CLDNN_ERROR_DATA_TYPES_MISMATCH(node.id(), "Input layout data typr", input_layout.data_type, "output layout data type", output_layout.data_type, "");
87     CLDNN_ERROR_NOT_EQUAL(node.id(), "Output layout count", output_layout.count(), "input layout count", input_layout.count(), "Output layout of reshape primitive changes size of input buffer");
88
89     //if reshape operated in-place, postpone creation of the output until network run,
90     //then create new memory object as the reinterpreted output of the previous primitive
91     if (!node.is_in_place())
92         _output = allocate_output();
93     else
94         reuse_input();
95 }
96
97 void reshape_inst::on_execute()
98 {
99     if (!node.is_in_place())
100         return;
101
102     if (_output && _network.get_engine().is_the_same_buffer(output_memory(), input_memory()))
103         return;
104
105     reuse_input();
106 }
107
108 void reshape_inst::reuse_input()
109 {
110     build_deps(); //reshape need deps
111     _output = _network.get_engine().reinterpret_buffer(input_memory(), node.get_output_layout());
112 }
113
114 }