Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / lstm_elt.cpp
index 718939d..d809f86 100644 (file)
@@ -30,6 +30,8 @@ primitive_type_id lstm_elt_type_id()
 
 layout lstm_elt_inst::calc_output_layout(lstm_elt_node const& node)
 {
+    assert((bool)node.get_primitive()->output_data_type == false
+           && "Output data type forcing is not supported for lstm_elt_node!");
     auto desc = node.get_primitive();
     auto input_layout = node.input().get_output_layout();
 
@@ -38,7 +40,7 @@ layout lstm_elt_inst::calc_output_layout(lstm_elt_node const& node)
     // output{bfyx}   = [b: batch, f: 2,         x: direction, y: hidden_size ] output
     // The output of the lstm_elt node is the concatenation of the intermediate [hidden, cell] tensors.
     // A crop/split node is needed to extract each individual tensors
-    auto result = layout(input_layout.data_type, format::bfyx,
+    auto result = layout(input_layout.data_type, input_layout.format,
                     tensor(input_layout.size.batch[0], 2, input_layout.size.spatial[0] / 4, input_layout.size.feature[0]));
     return result;
 }
@@ -63,6 +65,6 @@ lstm_elt_inst::typed_primitive_inst(network_impl& network, lstm_elt_node const&
     :parent(network, node)
 {
     auto input_size = node.input().get_output_layout();
-    CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "input format", input_size.format.value, "expected format", format::bfyx);
+    CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "input format", input_size.format.value, "expected format", format::bfyx, format::fyxb);
 }
 }