namespace cldnn
{
-/// @brief Axis which index_select primitive will index.
-enum class index_select_axis_name : int32_t
-{
- along_b,
- along_f,
- along_y,
- along_x
-};
-
/// @brief Select index, which will be copied to the output..
///
/// @details Applies index selecting along specified dimension. The indices, which will be copied are specifed by
/// @param input An identifier of primitive, which is an input for newly created
/// index_select primitive.
/// @param indicies An identifer of primitive, which have indices in memory distributed along x.
- /// @param type Axis of index selecting.
+ /// @param axis Axis of index selecting.
/// @param output_padding Optional padding for output from primitive.
index_select(
const primitive_id& id,
index_select_axis_name axis = index_select_axis_name::along_b,
const padding& output_padding = padding()
)
- : primitive_base(id, {input, indices}, output_padding)
+ : primitive_base(id, { input, indices }, output_padding)
+ , axis( { axis } )
+ , reverse(false)
+ {}
+
+ /// @brief Constructs index_select primitive / layer.
+ ///
+ /// @param id An identifier of new primitive.
+ /// @param input An identifier of primitive, which is an input for newly created
+ /// index_select primitive.
+ /// @param axis Axis of index selecting.
+ /// @param output_padding Optional padding for output from primitive.
+ index_select(
+ const primitive_id& id,
+ const primitive_id& input,
+ index_select_axis_name axis = index_select_axis_name::along_b,
+ const padding& output_padding = padding()
+ )
+ : primitive_base(id, { input }, output_padding)
+ , axis( { axis } )
+ , reverse(true)
+ {}
+
+ /// @brief Constructs index_select primitive / layer.
+ ///
+ /// @param id An identifier of new primitive.
+ /// @param input An identifier of primitive, which is an input for newly created
+ /// index_select primitive.
+ /// @param axis Vector of axes of index selecting.
+ /// @param output_padding Optional padding for output from primitive.
+ index_select(
+ const primitive_id& id,
+ const primitive_id& input,
+ const std::vector<index_select_axis_name>& axis = { index_select_axis_name::along_b },
+ const padding& output_padding = padding()
+ )
+ : primitive_base(id, { input }, output_padding)
, axis(axis)
+ , reverse(true)
{}
/// @brief Constructs a copy from C API @CLDNN_PRIMITIVE_DESC{broadcast}
index_select(const dto* dto)
: primitive_base(dto)
- , axis(static_cast<index_select_axis_name>(dto->axis))
+ , axis(dto->axis, dto->axis + dto->axis_num)
+ , reverse(dto->reverse)
{}
- /// @brief Axis of index selecting.
- index_select_axis_name axis;
+ /// @brief A list of axes of index selecting
+ std::vector<index_select_axis_name> axis;
+ /// @brief Do index_select in reverse order on axis/axes.
+ bool reverse;
protected:
void update_dto(dto& dto) const override
{
- dto.axis = static_cast<cldnn_index_select_axis>(axis);
+ dto.axis = axis.data();
+ dto.axis_num = (int)axis.size();
+ dto.reverse = reverse;
}
};
/// @}