layout index_select_inst::calc_output_layout(index_select_node const& node)
{
- auto desc = node.get_primitive();
+ assert((bool)node.get_primitive()->output_data_type == false
+ && "Output data type forcing is not supported for "
+ "index_select_node!");
+ auto desc = node.get_primitive();
auto input_layout = node.input().get_output_layout();
- auto indices_layout = node.indices().get_output_layout();
- auto indices_size = indices_layout.size.spatial[0];
-
- auto axis = node.get_axis();
+
int32_t output_b = input_layout.size.batch[0];
int32_t output_f = input_layout.size.feature[0];
int32_t output_x = input_layout.size.spatial[0];
int32_t output_y = input_layout.size.spatial[1];
- switch (axis)
- {
- case index_select_axis_name::along_b:
- output_b = indices_size;
- break;
- case index_select_axis_name::along_f:
- output_f = indices_size;
- break;
- case index_select_axis_name::along_x:
- output_x = indices_size;
- break;
- case index_select_axis_name::along_y:
- output_y = indices_size;
- break;
- default:
- CLDNN_ERROR_MESSAGE(node.id(), "UNSPORTTED AXIS");
- break;
+ if (!node.get_reverse()) {
+ auto indices_layout = node.indices().get_output_layout();
+ auto indices_size = indices_layout.size.spatial[0];
+ auto axes = node.get_axes();
+ for (size_t i = 0; i < axes.size(); i++)
+ {
+ switch (axes[i])
+ {
+ case index_select_axis_name::along_b:
+ output_b = indices_size;
+ break;
+ case index_select_axis_name::along_f:
+ output_f = indices_size;
+ break;
+ case index_select_axis_name::along_x:
+ output_x = indices_size;
+ break;
+ case index_select_axis_name::along_y:
+ output_y = indices_size;
+ break;
+ default:
+ CLDNN_ERROR_MESSAGE(node.id(), "UNSUPPORTED AXIS");
+ break;
+ }
+ }
}
return layout{ input_layout.data_type, input_layout.format, { output_b, output_f, output_x, output_y } };
}
std::stringstream primitive_description;
std::string axis_str = "";
- switch (desc->axis)
+ for (size_t i = 0; i < desc->axis.size(); i++)
{
- case index_select_axis_name::along_b:
- axis_str = "along_b";
- break;
- case index_select_axis_name::along_f:
- axis_str = "along_f";
- break;
- case index_select_axis_name::along_y:
- axis_str = "along_y";
- break;
- case index_select_axis_name::along_x:
- axis_str = "along_x";
- break;
- default:
- axis_str = "not supported axis";
- break;
+ switch (desc->axis.at(i))
+ {
+ case index_select_axis_name::along_b:
+ axis_str += "along_b, ";
+ break;
+ case index_select_axis_name::along_f:
+ axis_str += "along_f, ";
+ break;
+ case index_select_axis_name::along_y:
+ axis_str += "along_y, ";
+ break;
+ case index_select_axis_name::along_x:
+ axis_str += "along_x, ";
+ break;
+ default:
+ axis_str += "not supported axis, ";
+ break;
+ }
}
json_composite index_select_info;
- index_select_info.add("axis", axis_str);
+ index_select_info.add("axes", axis_str);
node_info->add("index_select_info", index_select_info);
node_info->dump(primitive_description);
{
auto& input = node.input();
auto input_layout = input.get_output_layout();
- auto& indices = node.indices();
- auto indices_layout = indices.get_output_layout();
auto const node_id = node.id();
- CLDNN_ERROR_DATA_TYPES_MISMATCH(node_id, "indicies data_type", indices_layout.data_type, "i32 data_type ", data_types::i32, "");
CLDNN_ERROR_NOT_PROPER_FORMAT(node_id, "input_format", input_layout.format, "supported input format", format::bfyx, format::yxfb);
- CLDNN_ERROR_NOT_PROPER_FORMAT(node_id, "input_format", indices_layout.format, "supported indicies format", format::bfyx, format::yxfb);
- CLDNN_ERROR_NOT_EQUAL(node_id, "indicies batch_size", indices_layout.size.batch[0], "expected size", 1, "");
- CLDNN_ERROR_NOT_EQUAL(node_id, "indicies feature_size", indices_layout.size.feature[0], "expected size", 1, "");
- CLDNN_ERROR_NOT_EQUAL(node_id, "indicies y_size", indices_layout.size.spatial[1], "expected size", 1, "");
- CLDNN_ERROR_LESS_THAN(node_id, "indicies x_size", indices_layout.size.spatial[0], "expected size", 1, "");
+
+ if (!node.get_reverse())
+ {
+ auto& indices = node.indices();
+ auto indices_layout = indices.get_output_layout();
+ CLDNN_ERROR_DATA_TYPES_MISMATCH(node_id, "indicies data_type", indices_layout.data_type, "i32 data_type ", data_types::i32, "");
+ CLDNN_ERROR_NOT_EQUAL(node_id, "indicies batch_size", indices_layout.size.batch[0], "expected size", 1, "");
+ CLDNN_ERROR_NOT_EQUAL(node_id, "indicies feature_size", indices_layout.size.feature[0], "expected size", 1, "");
+ CLDNN_ERROR_NOT_EQUAL(node_id, "indicies y_size", indices_layout.size.spatial[1], "expected size", 1, "");
+ CLDNN_ERROR_LESS_THAN(node_id, "indicies x_size", indices_layout.size.spatial[0], "expected size", 1, "");
+ CLDNN_ERROR_NOT_PROPER_FORMAT(node_id, "input_format", indices_layout.format, "supported indicies format", format::bfyx, format::yxfb);
+ }
}
}