Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gpu / lstm_gemm_gpu.cpp
index 7cb6b11..40d601a 100644 (file)
@@ -50,12 +50,6 @@ protected:
         return args;
     }
 
-    virtual bool validate(typed_primitive_inst<lstm_gemm>& instance) const override
-    {
-        bool res = parent::validate(instance);
-
-        return res;
-    }
 public:
 
     static primitive_impl* create(const lstm_gemm_node& arg)
@@ -78,8 +72,25 @@ public:
 
             const auto& hidden_layout = arg.hidden().get_output_layout();
             lstm_gemm_params.SetHidden(convert_data_tensor(hidden_layout));
+            // TODO: make a generic function to get the direction
+            if (hidden_layout.size.spatial[1] > 1) {
+                lstm_gemm_params.hidden_direction = arg.direction();
+            }
         }
         lstm_gemm_params.direction = arg.direction();
+        
+        // Update the direction of the input for the gemm kernel
+        const auto& input_layout = arg.input().get_output_layout();
+        size_t input_directions = input_layout.size.spatial[1];
+
+        if (input_directions > 1)  // For bidirection input, input direction can be 1 or 0
+        {
+            lstm_gemm_params.input_direction = arg.direction();
+        }
+        else  // For unidirectional input
+        {
+            lstm_gemm_params.input_direction = 0;
+        }
 
         auto lstm_gemm_optional_params = get_default_optional_params<kernel_selector::lstm_gemm_optional_params>(arg.get_program());
 
@@ -103,6 +114,8 @@ namespace {
             implementation_map<lstm_gemm>::add({
                 { std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw },
                 { std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw },
+                { std::make_tuple(engine_types::ocl, data_types::f32, format::fyxb), val_fw },
+                { std::make_tuple(engine_types::ocl, data_types::f16, format::fyxb), val_fw },
             });
         }
         ~attach() {}