const primitive_type_id& type,
const primitive_id& id,
const std::vector<primitive_id>& input,
- const padding& output_padding = padding()
+ const padding& output_padding = padding(),
+ const optional_data_type output_data_type = optional_data_type()
)
- :type(type), id(id), input(_input.cpp_ids), output_padding(output_padding), _input(input)
+ : type(type)
+ , id(id)
+ , input(_input.cpp_ids)
+ , output_padding(output_padding)
+ , output_data_type(output_data_type)
+ , _input(input)
{}
/// @brief Constructs a copy from basic C API @CLDNN_PRIMITIVE_DESC{primitive}
- primitive(const CLDNN_PRIMITIVE_DESC(primitive)* dto)
- :type(dto->type), id(dto->id), input(_input.cpp_ids), output_padding(dto->output_padding), _input(dto->input)
- {}
+ primitive(const CLDNN_PRIMITIVE_DESC(primitive) * dto)
+ : type(dto->type)
+ , id(dto->id)
+ , input(_input.cpp_ids)
+ , output_padding(dto->output_padding)
+ , output_data_type(dto->output_data_type.enabled
+ ? optional_data_type{static_cast<data_types>(
+ dto->output_data_type.data_type)}
+ : optional_data_type{})
+ , _input(dto->input)
+ {
+ }
virtual ~primitive() = default;
{
std::vector<std::reference_wrapper<primitive_id>> result;
auto&& deps = get_dependencies();
-
+
result.reserve(_input.size() + deps.size());
for (auto& pid : _input.cpp_ids)
result.push_back(std::ref(pid));
/// @brief Requested output padding.
padding output_padding;
+ /// @brief Requested output precision, if any.
+ optional_data_type output_data_type;
+
protected:
struct primitive_id_arr
{
_dto.type = type;
_dto.input = _input.ref();
_dto.output_padding = output_padding;
+ _dto.output_data_type.enabled = (bool)output_data_type;
+ _dto.output_data_type.data_type =
+ static_cast<cldnn_data_type>(*output_data_type);
//call abstract method to update primitive-specific fields
update_dto(_dto);
explicit primitive_base(
const primitive_id& id,
const std::vector<primitive_id>& input,
- const padding& output_padding = padding())
- : primitive(PType::type_id(), id, input, output_padding)
+ const padding& output_padding = padding(),
+ optional_data_type output_data_type = optional_data_type())
+ : primitive(PType::type_id(), id, input, output_padding, output_data_type)
{}
primitive_base(const DTO* dto)
: primitive(reinterpret_cast<const CLDNN_PRIMITIVE_DESC(primitive)*>(dto))
{
- if (dto->type != PType::type_id())
+ if (dto->type != PType::type_id())
throw std::invalid_argument("DTO type mismatch");
}