Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / reshape.cpp
index 1825375..0cc6870 100644 (file)
@@ -32,8 +32,31 @@ primitive_type_id reshape_type_id()
 
 layout reshape_inst::calc_output_layout(reshape_node const& node)
 {
+    assert((bool)node.get_primitive()->output_data_type == false
+           && "Output data type forcing is not supported for reshape_node!");
     auto input_layout = node.input().get_non_padded_output_layout();
-    input_layout.size = node.get_primitive()->output_shape;
+    auto sizes = node.get_primitive()->output_shape.sizes();
+    auto input_sizes = input_layout.size.sizes();
+    size_t need_recalc = 0;
+    uint32_t shape_count = 1;
+
+    for (size_t i = 0; i < sizes.size(); i++) {
+        if (sizes[i] == -1) {
+            if (need_recalc) {
+                CLDNN_ERROR_MESSAGE(node.id(), "Only one dimension of the new shape can be -1");
+            }
+            need_recalc = i;
+            continue;
+        }
+        if (sizes[i] == 0) {
+            sizes[i] = input_sizes[i];
+        }
+        shape_count *= sizes[i];
+    }
+    if (need_recalc)
+        sizes[need_recalc] = (int)input_layout.size.count() / shape_count;
+
+    input_layout.size = tensor(sizes);
     return input_layout;
 }
 
@@ -61,7 +84,7 @@ reshape_inst::typed_primitive_inst(network_impl& network, reshape_node const& no
     auto input_layout = node.input().get_output_layout();
     auto output_layout = node.get_output_layout();
     CLDNN_ERROR_DATA_TYPES_MISMATCH(node.id(), "Input layout data typr", input_layout.data_type, "output layout data type", output_layout.data_type, "");
-    CLDNN_ERROR_NOT_EQUAL(node.id(), "Output layout count", output_layout.count(), "input layout count", input_layout.count(), "Output layout of reshape pirmitive changes size of input buffer");
+    CLDNN_ERROR_NOT_EQUAL(node.id(), "Output layout count", output_layout.count(), "input layout count", input_layout.count(), "Output layout of reshape primitive changes size of input buffer");
 
     //if reshape operated in-place, postpone creation of the output until network run,
     //then create new memory object as the reinterpreted output of the previous primitive
@@ -88,4 +111,4 @@ void reshape_inst::reuse_input()
     _output = _network.get_engine().reinterpret_buffer(input_memory(), node.get_output_layout());
 }
 
-}
\ No newline at end of file
+}