// Computes the output layout of an lstm_gemm primitive from its input and
// weights layouts. The result keeps the input's data type and format and
// has shape [b: batch, f: direction, x: 4*hidden_size, y: 1].
layout lstm_gemm_inst::calc_output_layout(lstm_gemm_node const& node)
{
    // Output data type forcing is not supported for this primitive; the
    // output always inherits the input's data type.
    assert((bool)node.get_primitive()->output_data_type == false
           && "Output data type forcing is not supported for lstm_gemm_node!");
    auto input_layout = node.input().get_output_layout();
    auto weights_layout = node.weights().get_output_layout();
    // biases{bfyx}   = [b: 1,     f: 1,         x: direction,      y: 4 * hidden_size]
    // hidden{bfyx}   = [b: batch, f: direction, x: 1,              y: hidden_size] optional
    // tempGEMM{bfyx} = [b: batch, f: direction, x: 4*hidden_size,  y: 1] output
    //
    // Preserve the input's format (bfyx or fyxb) instead of forcing bfyx, so
    // downstream primitives see the layout the kernel actually produces.
    auto result = layout(input_layout.data_type,
                         input_layout.format,
                         tensor(input_layout.size.batch[0],      // batch
                                weights_layout.size.feature[0],  // direction
                                weights_layout.size.spatial[1],  // 4 * hidden_size
                                1));
    return result;
}
// Constructs the lstm_gemm primitive instance, validating that the input
// node's format is one the GEMM kernels support (bfyx or fyxb).
lstm_gemm_inst::typed_primitive_inst(network_impl& network, lstm_gemm_node const& node)
    : parent(network, node)
{
    auto input_layout = node.input().get_output_layout();
    // Reject unsupported input formats early with a descriptive error.
    CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "input format", input_layout.format.value, "expected format", format::bfyx, format::fyxb);
}
}