return args;
}
- virtual bool validate(typed_primitive_inst<lstm_elt>& instance) const override
- {
- bool res = parent::validate(instance);
-
- return res;
- }
public:
static primitive_impl* create(const lstm_elt_node& arg)
{
const auto& cell_layout = arg.cell().get_output_layout();
lstm_elt_params.SetCell(convert_data_tensor(cell_layout));
+ // TODO: make a generic function to get the direction
+ if (cell_layout.size.spatial[1] > 1) {
+ lstm_elt_params.cell_direction = arg.direction();
+ }
}
lstm_elt_params.SetOffsetOrder(arg.offset_order());
lstm_elt_params.clip = arg.clip();
lstm_elt_params.input_forget = arg.input_forget();
+ lstm_elt_params.direction = arg.direction();
auto& kernel_selector = kernel_selector::lstm_elt_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(lstm_elt_params, lstm_elt_optional_params);
implementation_map<lstm_elt>::add({
{ std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw },
{ std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw },
+ { std::make_tuple(engine_types::ocl, data_types::f32, format::fyxb), val_fw },
+ { std::make_tuple(engine_types::ocl, data_types::f16, format::fyxb), val_fw },
});
}
~attach() {}