Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / include / primitive_type_base.h
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18 #pragma once
19 #include "meta_utils.h"
20 #include "primitive_type.h"
21 #include "program_node.h"
22 #include "primitive_inst.h"
23 #include "network_impl.h"
24 #include "engine_impl.h"
25 #include <memory>
26
27 namespace cldnn
28 {
29 template<class PType>
30 struct primitive_type_base : ::cldnn_primitive_type
31 {
32     static_assert(meta::is_api_primitive<PType>::value, "Primitive type passed to primitive_type_base should derive from cldnn::primitive");
33
34     std::shared_ptr<primitive> from_dto(const CLDNN_PRIMITIVE_DESC(primitive)* dto) const override
35     {
36         if (dto->type != this)
37             throw std::invalid_argument("primitive_type_base::from_dto: primitive type mismatch");
38
39         return std::make_shared<PType>(as_dto<PType>(dto));
40     }
41
42     std::shared_ptr<cldnn::program_node> create_node(program_impl& program, const std::shared_ptr<primitive> prim) const override
43     {
44         if (prim->type != this)
45             throw std::invalid_argument("primitive_type_base::create_node: primitive type mismatch");
46
47         return std::make_shared<typed_program_node<PType>>(std::static_pointer_cast<PType>(prim), program);
48     }
49
50     std::shared_ptr<cldnn::primitive_inst> create_instance(network_impl& network, const cldnn::program_node& node) const override
51     {
52         if (node.type() != this)
53             throw std::invalid_argument("primitive_type_base::create_instance: primitive type mismatch");
54
55         return std::make_shared<typed_primitive_inst<PType>>(network, node);
56     }
57
58     std::unique_ptr<primitive_impl> choose_impl(engine_impl& engine, const cldnn::program_node& node) const override
59     {
60         if (node.type() != this)
61             throw std::invalid_argument("primitive_type_base::choose_impl: primitive type mismatch");
62
63         return engine.create_primitive_impl(node.as<PType>());
64     }
65
66     bool does_an_implementation_exist(engine_impl& engine, const cldnn::program_node& node) const override
67     {
68         if (node.type() != this)
69             throw std::invalid_argument("primitive_type_base::choose_impl: primitive type mismatch");
70         return engine.does_an_implementation_exist(node.as<PType>());
71     }
72
73     bool does_possible_implementation_exist(engine_impl& engine, const cldnn::program_node& node) const override
74     {
75         if (node.type() != this)
76             throw std::invalid_argument("primitive_type_base::choose_impl: primitive type mismatch");
77         return engine.does_possible_implementation_exist(node.as<PType>());
78     }
79
80     cldnn::layout calc_output_layout(const cldnn::program_node& node) const override
81     {
82         if (node.type() != this)
83             throw std::invalid_argument("primitive_type_base::calc_output_layout: primitive type mismatch");
84
85         return typed_primitive_inst<PType>::calc_output_layout(node);
86     }
87
88     std::string to_string(const cldnn::program_node& node) const override
89     {
90         if (node.type() != this)
91             throw std::invalid_argument("primitive_type_base::to_string: primitive type mismatch");
92
93         return typed_primitive_inst<PType>::to_string(node);
94     }
95
96 };
97
98 }