[IE CLDNN] Fix unsupported dims number error (#2453)
authorSergey Shlyapnikov <sergey.shlyapnikov@intel.com>
Fri, 2 Oct 2020 13:51:09 +0000 (16:51 +0300)
committerGitHub <noreply@github.com>
Fri, 2 Oct 2020 13:51:09 +0000 (16:51 +0300)
inference-engine/src/cldnn_engine/cldnn_common_utils.h
inference-engine/thirdparty/clDNN/src/gpu/scale_gpu.cpp
inference-engine/thirdparty/clDNN/src/select.cpp

index 9f4898c..6423163 100644 (file)
@@ -130,6 +130,7 @@ inline cldnn::format ImageFormatFromLayout(InferenceEngine::Layout l) {
 
 inline cldnn::format defaultFormatForDims(size_t dimensions) {
     switch (dimensions) {
+    case 0:
     case 1:
     case 2:
     case 3:
index f0165f6..9645def 100644 (file)
@@ -85,53 +85,66 @@ attach_scale_gpu::attach_scale_gpu() {
 
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::yxfb), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::byxf), val_fw);
 
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfyx), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx), val_fw);
 
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfzyx), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfzyx), val_fw);
 
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfwzyx), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfwzyx), val_fw);
 
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_yx_fsv16), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_yx_fsv16), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::b_fs_yx_fsv16), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv16), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv16), val_fw);
 
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_zyx_fsv16), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_zyx_fsv16), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::b_fs_zyx_fsv16), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_zyx_fsv16), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_zyx_fsv16), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bs_fs_zyx_bsv16_fsv16), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bs_fs_zyx_bsv16_fsv16), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bs_fs_zyx_bsv16_fsv16), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::fs_b_yx_fsv32), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::fs_b_yx_fsv32), val_fw);
 
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bs_fs_yx_bsv16_fsv16), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bs_fs_yx_bsv16_fsv16), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bs_fs_yx_bsv16_fsv16), val_fw);
 
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv4), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv4), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_yx_fsv4), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_yx_fsv4), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::b_fs_yx_fsv4), val_fw);
 
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv32), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv32), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_yx_fsv32), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_yx_fsv32), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::b_fs_yx_fsv32), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_zyx_fsv32), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_zyx_fsv32), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_zyx_fsv32), val_fw);
     implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_zyx_fsv32), val_fw);
+    implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::b_fs_zyx_fsv32), val_fw);
 }
 
 }  // namespace detail
index 3d94213..d01c1f8 100644 (file)
@@ -67,12 +67,21 @@ select_inst::typed_primitive_inst(network_impl& network, select_node const& node
                                 3,
                                 "");
 
-    CLDNN_ERROR_NOT_EQUAL(node.id(),
-                                "Mask format",
-                                deps[0]->get_output_layout().format,
-                                "Positive input format",
-                                deps[1]->get_output_layout().format,
-                                "");
+    if (deps[1]->get_output_layout().size != cldnn::tensor(1))
+        CLDNN_ERROR_NOT_EQUAL(node.id(),
+                              "Mask format",
+                              deps[0]->get_output_layout().format,
+                              "Positive input format",
+                              deps[1]->get_output_layout().format,
+                              "");
+             
+    if (deps[2]->get_output_layout().size != cldnn::tensor(1))
+        CLDNN_ERROR_NOT_EQUAL(node.id(),
+                              "Mask format",
+                              deps[0]->get_output_layout().format,
+                              "Positive input format",
+                              deps[2]->get_output_layout().format,
+                              "");
 
     if (node.get_primitive()->broadcast_type == "none") {
         CLDNN_ERROR_LAYOUT_MISMATCH(node.id(),
@@ -89,12 +98,13 @@ select_inst::typed_primitive_inst(network_impl& network, select_node const& node
                                 deps[1]->get_output_layout().size,
                                 "");
     } else if (node.get_primitive()->broadcast_type == "numpy") {
-        CLDNN_ERROR_NOT_EQUAL(node.id(),
-                                "Positive input format",
-                                deps[1]->get_output_layout().format,
-                                "Negative input format",
-                                deps[2]->get_output_layout().format,
-                                "");
+        if (deps[1]->get_output_layout().size != cldnn::tensor(1) && deps[2]->get_output_layout().size != cldnn::tensor(1))
+            CLDNN_ERROR_NOT_EQUAL(node.id(),
+                                  "Positive input format",
+                                  deps[1]->get_output_layout().format,
+                                  "Negative input format",
+                                  deps[2]->get_output_layout().format,
+                                  "");
 
         CLDNN_ERROR_DATA_TYPES_MISMATCH(node.id(),
                                 "Positive input data type",