Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / apply_adam.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "apply_adam_inst.h"
18 #include "primitive_type_base.h"
19 #include "error_handler.h"
20 #include "json_object.h"
21
22 namespace cldnn
23 {
24 primitive_type_id apply_adam_type_id()
25 {
26     static primitive_type_base<apply_adam> instance;
27     return &instance;
28 }
29
30 apply_adam_node::typed_program_node(const std::shared_ptr<apply_adam> prim, program_impl& prog)
31     : parent(prim, prog)
32 {
33     can_share_buffer(false); //apply adam's output initial val should be either 0 or use same buffer as mutable_data after it (no allocation needed)
34 }
35 layout apply_adam_inst::calc_output_layout(apply_adam_node const& node)
36 {
37     assert((bool)node.get_primitive()->output_data_type == false
38            && "Output data type forcing is not supported for apply_adam_node!");
39     return node.input().get_non_padded_output_layout();
40 }
41
42 std::string apply_adam_inst::to_string(apply_adam_node const& node)
43 {
44     auto desc      = node.get_primitive();
45     auto node_info = node.desc_to_json();
46     auto& m     = node.m();
47     auto& v     = node.v();
48     auto& beta1_power = node.beta1_power();
49     auto& beta2_power = node.beta2_power();
50
51     std::stringstream primitive_description;
52
53     json_composite apply_adam_info;
54     apply_adam_info.add("m_id", m.id());
55     apply_adam_info.add("v_id", v.id());
56     apply_adam_info.add("beta1_power_id", beta1_power.id());
57     apply_adam_info.add("beta2_power_id", beta2_power.id());
58     apply_adam_info.add("lr", desc->lr);
59     apply_adam_info.add("beta1", desc->beta1);
60     apply_adam_info.add("beta2", desc->beta2);
61     apply_adam_info.add("epsilon", desc->epsilon);
62
63     node_info->add("apply adam info", apply_adam_info);
64     node_info->dump(primitive_description);
65
66     return primitive_description.str();
67 }
68
69 apply_adam_inst::typed_primitive_inst(network_impl& network, apply_adam_node const& node)
70     :parent(network, node) 
71 {
72     auto m_format = node.m().get_output_layout().format;
73     auto v_format = node.v().get_output_layout().format;
74     auto beta1_power_format = node.beta1_power().get_output_layout().format;
75     auto beta2_power_format = node.beta2_power().get_output_layout().format;
76
77     CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "M format", m_format.value, "supported m formats", format::yxfb, format::bfyx );
78     CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "V format", v_format.value, "supported v formats", format::yxfb, format::bfyx );
79     CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "beta1_power format", beta1_power_format.value, "supported beta1_power formats", format::yxfb, format::bfyx);
80     CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "beta2_power format", beta2_power_format.value, "supported beta2_power formats", format::yxfb, format::bfyx);
81 }
82 }