Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / lstm.cpp
index 7c80782..fae374a 100644 (file)
@@ -31,18 +31,21 @@ primitive_type_id lstm_type_id()
 
 layout lstm_inst::calc_output_layout(lstm_node const& node)
 {
+    assert((bool)node.get_primitive()->output_data_type == false
+           && "Output data type forcing is not supported for lstm_node!");
     auto input_layout = node.input().get_output_layout();
     auto hidden_layout = node.inital_hidden().get_output_layout();
 
-    // input     = [        1,  sequence,           batch,      input_size ]
-    // weights   = [        1, direction, 4 * hidden_size,      input_size ]
-    // recurrent = [        1, direction, 4 * hidden_size,     hidden_size ]
-    // biases    = [        1,         1,       direction, 4 * hidden_size ]
-    // hidden    = [        1, direction,           batch,     hidden_size ]
-    // cell      = [        1, direction,           batch,     hidden_size ]
-    // output    = [ sequence, direction,           batch,     hidden_size ]    
+    // input     = [ batch,  sequence,       direction,      input_size ]
+    // weights   = [     1, direction, 4 * hidden_size,      input_size ]
+    // recurrent = [     1, direction, 4 * hidden_size,     hidden_size ]
+    // biases    = [     1,         1,       direction, 4 * hidden_size ]
+    // hidden    = [ batch,         1,       direction,     hidden_size ]
+    // cell      = [ batch,         1,       direction,     hidden_size ]
+    // output    = [ batch,  sequence,       direction,     hidden_size ]
        auto result = layout(input_layout.data_type, format::bfyx,
-                  tensor(hidden_layout.size.feature[0], input_layout.size.feature[0], hidden_layout.size.spatial[0], hidden_layout.size.spatial[1]));
+                  tensor(hidden_layout.size.feature[0], input_layout.size.feature[0],
+                         hidden_layout.size.spatial[0], hidden_layout.size.spatial[1]));
     return result;
 }
 
@@ -75,10 +78,8 @@ std::string lstm_inst::to_string(lstm_node const& node)
 lstm_inst::typed_primitive_inst(network_impl& network, lstm_node const& node)
     :parent(network, node)
 {
-    // [ARIEL] TODO: That do we need to check here??
-    auto input_size = node.input().get_output_layout();
-    // auto output_size = output_memory().get_layout();
-    CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "input format", input_size.format.value, "expected format", format::bfyx);
-    //CLDNN_ERROR_NOT_EQUAL(node.id(), "Input size", input_size.size.raw.size(), "output size", output_size.size.raw.size(), "");
+    auto input_layout = node.input().get_output_layout();
+    CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "input format", input_layout.format.value, "expected format", format::bfyx);
 }
+
 }