return args;
}
- virtual bool validate(typed_primitive_inst<lstm_gemm>& instance) const override
- {
- bool res = parent::validate(instance);
-
- return res;
- }
public:
static primitive_impl* create(const lstm_gemm_node& arg)
const auto& hidden_layout = arg.hidden().get_output_layout();
lstm_gemm_params.SetHidden(convert_data_tensor(hidden_layout));
+ // TODO: make a generic function to get the direction
+ if (hidden_layout.size.spatial[1] > 1) {
+ lstm_gemm_params.hidden_direction = arg.direction();
+ }
}
lstm_gemm_params.direction = arg.direction();
+
+ // Update the direction of the input for the gemm kernel
+ const auto& input_layout = arg.input().get_output_layout();
+ size_t input_directions = input_layout.size.spatial[1];
+
+ if (input_directions > 1) // For bidirection input, input direction can be 1 or 0
+ {
+ lstm_gemm_params.input_direction = arg.direction();
+ }
+ else // For unidirectional input
+ {
+ lstm_gemm_params.input_direction = 0;
+ }
auto lstm_gemm_optional_params = get_default_optional_params<kernel_selector::lstm_gemm_optional_params>(arg.get_program());
implementation_map<lstm_gemm>::add({
{ std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw },
{ std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw },
+ { std::make_tuple(engine_types::ocl, data_types::f32, format::fyxb), val_fw },
+ { std::make_tuple(engine_types::ocl, data_types::f16, format::fyxb), val_fw },
});
}
~attach() {}