Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / index_select.cpp
index 9c14470..88acded 100644 (file)
@@ -30,36 +30,44 @@ namespace cldnn
 
        layout index_select_inst::calc_output_layout(index_select_node const& node)
        {
-               auto desc = node.get_primitive();
+        assert((bool)node.get_primitive()->output_data_type == false
+               && "Output data type forcing is not supported for "
+                  "index_select_node!");
+        auto desc = node.get_primitive();
 
         auto input_layout = node.input().get_output_layout();
-        auto indices_layout = node.indices().get_output_layout();
-        auto indices_size = indices_layout.size.spatial[0];
-
-        auto axis = node.get_axis();
+        
         
         int32_t output_b = input_layout.size.batch[0];
         int32_t output_f = input_layout.size.feature[0];
         int32_t output_x = input_layout.size.spatial[0];
         int32_t output_y = input_layout.size.spatial[1];
 
-        switch (axis)
-        {
-        case index_select_axis_name::along_b:
-            output_b = indices_size;
-            break;
-        case index_select_axis_name::along_f:
-            output_f = indices_size;
-            break;
-        case index_select_axis_name::along_x:
-            output_x = indices_size;
-            break;
-        case index_select_axis_name::along_y:
-            output_y = indices_size;
-            break;
-        default:
-            CLDNN_ERROR_MESSAGE(node.id(), "UNSPORTTED AXIS");
-            break;
+        if (!node.get_reverse()) {
+            auto indices_layout = node.indices().get_output_layout();
+            auto indices_size = indices_layout.size.spatial[0];
+            auto axes = node.get_axes();
+            for (size_t i = 0; i < axes.size(); i++)
+            {
+                switch (axes[i])
+                {
+                case index_select_axis_name::along_b:
+                    output_b = indices_size;
+                    break;
+                case index_select_axis_name::along_f:
+                    output_f = indices_size;
+                    break;
+                case index_select_axis_name::along_x:
+                    output_x = indices_size;
+                    break;
+                case index_select_axis_name::along_y:
+                    output_y = indices_size;
+                    break;
+                default:
+                    CLDNN_ERROR_MESSAGE(node.id(), "UNSUPPORTED AXIS");
+                    break;
+                }
+            }
         }
         return layout{ input_layout.data_type, input_layout.format, { output_b, output_f, output_x, output_y } };
        }
@@ -71,27 +79,30 @@ namespace cldnn
                std::stringstream primitive_description;
 
         std::string axis_str = "";
-        switch (desc->axis)
+        for (size_t i = 0; i < desc->axis.size(); i++)
         {
-        case index_select_axis_name::along_b:
-            axis_str = "along_b";
-            break;
-        case index_select_axis_name::along_f:
-            axis_str = "along_f";
-            break;
-        case index_select_axis_name::along_y:
-            axis_str = "along_y";
-            break;
-        case index_select_axis_name::along_x:
-            axis_str = "along_x";
-            break;
-        default:
-            axis_str = "not supported axis";
-            break;
+            switch (desc->axis.at(i))
+            {
+            case index_select_axis_name::along_b:
+                axis_str += "along_b, ";
+                break;
+            case index_select_axis_name::along_f:
+                axis_str += "along_f, ";
+                break;
+            case index_select_axis_name::along_y:
+                axis_str += "along_y, ";
+                break;
+            case index_select_axis_name::along_x:
+                axis_str += "along_x, ";
+                break;
+            default:
+                axis_str += "not supported axis, ";
+                break;
+            }
         }
 
         json_composite index_select_info;
-        index_select_info.add("axis", axis_str);
+        index_select_info.add("axes", axis_str);
 
         node_info->add("index_select_info", index_select_info);
                node_info->dump(primitive_description);
@@ -104,17 +115,21 @@ namespace cldnn
        {
         auto& input = node.input();
         auto input_layout = input.get_output_layout();
-        auto& indices = node.indices();
-        auto indices_layout = indices.get_output_layout();
         auto const node_id = node.id();
 
-        CLDNN_ERROR_DATA_TYPES_MISMATCH(node_id, "indicies data_type", indices_layout.data_type, "i32 data_type ", data_types::i32, "");
         CLDNN_ERROR_NOT_PROPER_FORMAT(node_id, "input_format", input_layout.format, "supported input format", format::bfyx, format::yxfb);
-        CLDNN_ERROR_NOT_PROPER_FORMAT(node_id, "input_format", indices_layout.format, "supported indicies format", format::bfyx, format::yxfb);
-        CLDNN_ERROR_NOT_EQUAL(node_id, "indicies batch_size", indices_layout.size.batch[0], "expected size", 1, "");
-        CLDNN_ERROR_NOT_EQUAL(node_id, "indicies feature_size", indices_layout.size.feature[0], "expected size", 1, "");
-        CLDNN_ERROR_NOT_EQUAL(node_id, "indicies y_size", indices_layout.size.spatial[1], "expected size", 1, "");
-        CLDNN_ERROR_LESS_THAN(node_id, "indicies x_size", indices_layout.size.spatial[0], "expected size", 1, "");
+        
+        if (!node.get_reverse())
+        {
+            auto& indices = node.indices();
+            auto indices_layout = indices.get_output_layout();
 
+            CLDNN_ERROR_DATA_TYPES_MISMATCH(node_id, "indicies data_type", indices_layout.data_type, "i32 data_type ", data_types::i32, "");
+            CLDNN_ERROR_NOT_EQUAL(node_id, "indicies batch_size", indices_layout.size.batch[0], "expected size", 1, "");
+            CLDNN_ERROR_NOT_EQUAL(node_id, "indicies feature_size", indices_layout.size.feature[0], "expected size", 1, "");
+            CLDNN_ERROR_NOT_EQUAL(node_id, "indicies y_size", indices_layout.size.spatial[1], "expected size", 1, "");
+            CLDNN_ERROR_LESS_THAN(node_id, "indicies x_size", indices_layout.size.spatial[0], "expected size", 1, "");
+            CLDNN_ERROR_NOT_PROPER_FORMAT(node_id, "input_format", indices_layout.format, "supported indicies format", format::bfyx, format::yxfb);
+        }
        }
 }