Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gpu / lstm_elt_gpu.cpp
index 69baa64..b9f8eef 100644 (file)
@@ -47,12 +47,6 @@ protected:
         return args;
     }
 
-    virtual bool validate(typed_primitive_inst<lstm_elt>& instance) const override
-    {
-        bool res = parent::validate(instance);
-
-        return res;
-    }
 public:
 
     static primitive_impl* create(const lstm_elt_node& arg)
@@ -64,11 +58,16 @@ public:
         {
             const auto& cell_layout = arg.cell().get_output_layout();
             lstm_elt_params.SetCell(convert_data_tensor(cell_layout));
+            // TODO: make a generic function to get the direction
+            if (cell_layout.size.spatial[1] > 1) {
+                lstm_elt_params.cell_direction = arg.direction();
+            }
         }
 
         lstm_elt_params.SetOffsetOrder(arg.offset_order());
         lstm_elt_params.clip = arg.clip();
         lstm_elt_params.input_forget = arg.input_forget();
+        lstm_elt_params.direction = arg.direction();
 
         auto& kernel_selector = kernel_selector::lstm_elt_kernel_selector::Instance();
         auto best_kernels = kernel_selector.GetBestKernels(lstm_elt_params, lstm_elt_optional_params);
@@ -90,6 +89,8 @@ namespace {
             implementation_map<lstm_elt>::add({
                 { std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw },
                 { std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw },
+                { std::make_tuple(engine_types::ocl, data_types::f32, format::fyxb), val_fw },
+                { std::make_tuple(engine_types::ocl, data_types::f16, format::fyxb), val_fw },
             });
         }
         ~attach() {}