Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / api / CPP / primitive.hpp
index 8314afc..41fa27d 100644 (file)
@@ -92,15 +92,30 @@ public:
         const primitive_type_id& type,
         const primitive_id& id,
         const std::vector<primitive_id>& input,
-        const padding& output_padding = padding()
+        const padding& output_padding = padding(),
+        const optional_data_type output_data_type = optional_data_type()
     )
-        :type(type), id(id), input(_input.cpp_ids), output_padding(output_padding), _input(input)
+        : type(type)
+        , id(id)
+        , input(_input.cpp_ids)
+        , output_padding(output_padding)
+        , output_data_type(output_data_type)
+        , _input(input)
     {}
 
     /// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{primitive}
-    primitive(const CLDNN_PRIMITIVE_DESC(primitive)* dto)
-        :type(dto->type), id(dto->id), input(_input.cpp_ids), output_padding(dto->output_padding), _input(dto->input)
-    {}
+    primitive(const CLDNN_PRIMITIVE_DESC(primitive) * dto)
+        : type(dto->type)
+        , id(dto->id)
+        , input(_input.cpp_ids)
+        , output_padding(dto->output_padding)
+        , output_data_type(dto->output_data_type.enabled
+                               ? optional_data_type{static_cast<data_types>(
+                                     dto->output_data_type.data_type)}
+                               : optional_data_type{})
+        , _input(dto->input)
+    {
+    }
 
     virtual ~primitive() = default;
 
@@ -114,7 +129,7 @@ public:
     {
         std::vector<std::reference_wrapper<primitive_id>> result;
         auto&& deps = get_dependencies();
-
+        
         result.reserve(_input.size() + deps.size());
         for (auto& pid : _input.cpp_ids)
             result.push_back(std::ref(pid));
@@ -148,6 +163,9 @@ public:
     /// @brief Requested output padding.
     padding output_padding;
 
+    /// @brief Requested output precision, if any.
+    optional_data_type output_data_type;
+
 protected:
     struct primitive_id_arr
     {
@@ -198,6 +216,9 @@ public:
         _dto.type = type;
         _dto.input = _input.ref();
         _dto.output_padding = output_padding;
+        _dto.output_data_type.enabled = (bool)output_data_type;
+        _dto.output_data_type.data_type =
+            static_cast<cldnn_data_type>(*output_data_type);
 
         //call abstract method to update primitive-specific fields
         update_dto(_dto);
@@ -208,14 +229,15 @@ protected:
     explicit primitive_base(
         const primitive_id& id,
         const std::vector<primitive_id>& input,
-        const padding& output_padding = padding())
-        : primitive(PType::type_id(), id, input, output_padding)
+        const padding& output_padding = padding(),
+        optional_data_type output_data_type = optional_data_type())
+        : primitive(PType::type_id(), id, input, output_padding, output_data_type)
     {}
 
     primitive_base(const DTO* dto)
         : primitive(reinterpret_cast<const CLDNN_PRIMITIVE_DESC(primitive)*>(dto))
     {
-        if (dto->type != PType::type_id())
+        if (dto->type != PType::type_id()) 
             throw std::invalid_argument("DTO type mismatch");
     }