Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / lstm_gemm.cpp
index 31d36fa..e39a271 100644 (file)
@@ -31,6 +31,8 @@ primitive_type_id lstm_gemm_type_id()
 
 layout lstm_gemm_inst::calc_output_layout(lstm_gemm_node const& node)
 {
+    assert((bool)node.get_primitive()->output_data_type == false
+           && "Output data type forcing is not supported for lstm_gemm_node!");
     auto desc = node.get_primitive();
     auto input_layout = node.input().get_output_layout();
     auto weights_layout = node.weights().get_output_layout();
@@ -41,8 +43,7 @@ layout lstm_gemm_inst::calc_output_layout(lstm_gemm_node const& node)
     //   biases{bfyx}    = [b: 1,     f:1 ,          x: direction,       y:  4 * hidden_size ]
     //   hidden{bfyx}    = [b: batch, f:  direction, x: 1 ,              y: hidden_size ] optional
     //   tempGEMM{bfyx}  = [b: batch, f: direction,  x: 4*hidden_size,   y: 1] output
-
-    auto result = layout(input_layout.data_type, format::bfyx, tensor(input_layout.size.batch[0], weights_layout.size.feature[0], weights_layout.size.spatial[1], 1));
+    auto result = layout(input_layout.data_type, input_layout.format, tensor(input_layout.size.batch[0], weights_layout.size.feature[0], weights_layout.size.spatial[1], 1));
     return result;
 }
 
@@ -71,7 +72,7 @@ std::string lstm_gemm_inst::to_string(lstm_gemm_node const& node)
 lstm_gemm_inst::typed_primitive_inst(network_impl& network, lstm_gemm_node const& node)
     :parent(network, node)
 {
-    auto input_size = node.input().get_output_layout();
-    CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "input format", input_size.format.value, "expected format", format::bfyx);
+    auto input_layout = node.input().get_output_layout();
+    CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "input format", input_layout.format.value, "expected format", format::bfyx, format::fyxb);
 }
 }