Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / index_select.hpp
index 11ff25a..5897533 100644 (file)
 
 namespace cldnn
 {
-/// @brief Axis which index_select primitive will index.
-enum class index_select_axis_name : int32_t
-{
-    along_b,
-    along_f,
-    along_y,
-    along_x
-};
-
 /// @brief Select index, which will be copied to the output..
 ///
 /// @details Applies index selecting along specified dimension. The indices, which will be copied are specifed by 
@@ -63,7 +54,7 @@ struct index_select : public primitive_base<index_select, CLDNN_PRIMITIVE_DESC(i
     /// @param input              An identifier of primitive, which is an input for newly created
     ///                           index_select primitive.
     /// @param indicies           An identifer of primitive, which have indices in memory distributed along x. 
-    /// @param type               Axis of index selecting.
+    /// @param axis               Axis of index selecting.
     /// @param output_padding     Optional padding for output from primitive.
     index_select(
         const primitive_id& id,
@@ -72,23 +63,65 @@ struct index_select : public primitive_base<index_select, CLDNN_PRIMITIVE_DESC(i
         index_select_axis_name axis = index_select_axis_name::along_b,
         const padding& output_padding = padding()
     )
-        : primitive_base(id, {input, indices}, output_padding)
+        : primitive_base(id, { input, indices }, output_padding)
+        , axis( { axis } )
+        , reverse(false)
+    {}
+
+    /// @brief Constructs index_select primitive / layer.
+    ///
+    /// @param id                 An identifier of new primitive.
+    /// @param input              An identifier of primitive, which is an input for newly created
+    ///                           index_select primitive.
+    /// @param axis               Axis of index selecting.
+    /// @param output_padding     Optional padding for output from primitive.
+    index_select(
+        const primitive_id& id,
+        const primitive_id& input,
+        index_select_axis_name axis = index_select_axis_name::along_b,
+        const padding& output_padding = padding()
+    )
+        : primitive_base(id, { input }, output_padding)
+        , axis( { axis } )
+        , reverse(true)
+    {}
+
+    /// @brief Constructs index_select primitive / layer.
+    ///
+    /// @param id                 An identifier of new primitive.
+    /// @param input              An identifier of primitive, which is an input for newly created
+    ///                           index_select primitive.
+    /// @param axis               Vector of axes of index selecting.
+    /// @param output_padding     Optional padding for output from primitive.
+    index_select(
+        const primitive_id& id,
+        const primitive_id& input,
+        const std::vector<index_select_axis_name>& axis = { index_select_axis_name::along_b },
+        const padding& output_padding = padding()
+    )
+        : primitive_base(id, { input }, output_padding)
         , axis(axis)
+        , reverse(true)
     {}
 
     /// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{broadcast}
     index_select(const dto* dto)
         : primitive_base(dto)
-        , axis(static_cast<index_select_axis_name>(dto->axis))
+        , axis(dto->axis, dto->axis + dto->axis_num)
+        , reverse(dto->reverse)
     {}
 
-    /// @brief Axis of index selecting.
-    index_select_axis_name axis;
+    /// @brief A list of axes of index selecting
+    std::vector<index_select_axis_name> axis;
+    /// @brief Do index_select in reverse order on axis/axes.
+    bool reverse;
 
 protected:
     void update_dto(dto& dto) const override
     {
-        dto.axis = static_cast<cldnn_index_select_axis>(axis);
+        dto.axis = axis.data();
+        dto.axis_num = (int)axis.size();
+        dto.reverse = reverse;
     }
 };
 /// @}