2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "meta_utils.h"
20 #include "primitive_type.h"
21 #include "program_node.h"
22 #include "primitive_inst.h"
23 #include "network_impl.h"
24 #include "engine_impl.h"
30 struct primitive_type_base : ::cldnn_primitive_type
// Compile-time guard: rejects any PType for which meta::is_api_primitive<PType>::value is false,
// i.e. (per the diagnostic text) any type that does not derive from cldnn::primitive.
32 static_assert(meta::is_api_primitive<PType>::value, "Primitive type passed to primitive_type_base should derive from cldnn::primitive");
34 std::shared_ptr<primitive> from_dto(const CLDNN_PRIMITIVE_DESC(primitive)* dto) const override
36 if (dto->type != this)
37 throw std::invalid_argument("primitive_type_base::from_dto: primitive type mismatch");
39 return std::make_shared<PType>(as_dto<PType>(dto));
42 std::shared_ptr<cldnn::program_node> create_node(program_impl& program, const std::shared_ptr<primitive> prim) const override
44 if (prim->type != this)
45 throw std::invalid_argument("primitive_type_base::create_node: primitive type mismatch");
47 return std::make_shared<typed_program_node<PType>>(std::static_pointer_cast<PType>(prim), program);
50 std::shared_ptr<cldnn::primitive_inst> create_instance(network_impl& network, const cldnn::program_node& node) const override
52 if (node.type() != this)
53 throw std::invalid_argument("primitive_type_base::create_instance: primitive type mismatch");
55 return std::make_shared<typed_primitive_inst<PType>>(network, node);
58 std::unique_ptr<primitive_impl> choose_impl(engine_impl& engine, const cldnn::program_node& node) const override
60 if (node.type() != this)
61 throw std::invalid_argument("primitive_type_base::choose_impl: primitive type mismatch");
63 return engine.create_primitive_impl(node.as<PType>());
66 bool does_an_implementation_exist(engine_impl& engine, const cldnn::program_node& node) const override
68 if (node.type() != this)
69 throw std::invalid_argument("primitive_type_base::choose_impl: primitive type mismatch");
70 return engine.does_an_implementation_exist(node.as<PType>());
73 bool does_possible_implementation_exist(engine_impl& engine, const cldnn::program_node& node) const override
75 if (node.type() != this)
76 throw std::invalid_argument("primitive_type_base::choose_impl: primitive type mismatch");
77 return engine.does_possible_implementation_exist(node.as<PType>());
80 cldnn::layout calc_output_layout(const cldnn::program_node& node) const override
82 if (node.type() != this)
83 throw std::invalid_argument("primitive_type_base::calc_output_layout: primitive type mismatch");
85 return typed_primitive_inst<PType>::calc_output_layout(node);
88 std::string to_string(const cldnn::program_node& node) const override
90 if (node.type() != this)
91 throw std::invalid_argument("primitive_type_base::to_string: primitive type mismatch");
93 return typed_primitive_inst<PType>::to_string(node);