Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / embed.cpp
index b2087b0..b1c6199 100644 (file)
@@ -31,11 +31,13 @@ namespace cldnn
 
        layout embed_inst::calc_output_layout(embed_node const& node)
        {
-               auto input_layout = node.input().get_output_layout();
+        assert((bool)node.get_primitive()->output_data_type == false
+               && "Output data type forcing is not supported for embed_node!");
+        auto input_layout = node.input().get_output_layout();
                auto desc = node.get_primitive();
                auto weights_layout = node.weights().get_output_layout();
 
-               auto result = layout(input_layout.data_type, format::bfyx, tensor(input_layout.size.batch[0], input_layout.size.spatial[0] * input_layout.size.spatial[1], weights_layout.size.batch[0], 1));
+               auto result = layout(input_layout.data_type, format::bfyx, tensor(input_layout.size.batch[0], input_layout.size.spatial[0], weights_layout.size.batch[0], 1));
                return result;
                
        }
@@ -66,5 +68,8 @@ namespace cldnn
                auto output_size = output_memory().get_layout();
                CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "input format", input_size.format.value, "expected format", format::yxfb, format::bfyx);
                CLDNN_ERROR_NOT_EQUAL(node.id(), "Input size", input_size.size.raw.size(), "output size", output_size.size.raw.size(), "");
+        CLDNN_ERROR_NOT_EQUAL(node.id(), "Input batch", input_size.size.batch[0], "output batch", output_size.size.batch[0], "");
+        CLDNN_ERROR_NOT_EQUAL(node.id(), "Input feature", input_size.size.feature[0], "size 1", 1, "");
+        CLDNN_ERROR_NOT_EQUAL(node.id(), "Input y size ", input_size.size.spatial[1], "size 1", 1, "");
        }
 }