2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "apply_adam_inst.h"
18 #include "primitive_type_base.h"
19 #include "error_handler.h"
20 #include "json_object.h"
// Returns the primitive_type_id that identifies the apply_adam primitive.
// The type-descriptor object is a function-local static. It is therefore
// constructed exactly once, on the first call; C++11 guarantees that
// initialization of function-local statics is thread-safe.
// NOTE(review): this listing does not show the opening brace, the return
// statement, or the closing brace. Presumably the function returns &instance;
// confirm against the full file.
24 primitive_type_id apply_adam_type_id()
26 static primitive_type_base<apply_adam> instance;
// Constructs the program-graph node for an apply_adam primitive.
// `prim` is taken as a const shared_ptr by value, so constructing the node
// copies the pointer (a ref-count increment).
// NOTE(review): the member-initializer list and the braces are not visible in
// this listing. Presumably prim/prog are forwarded to the parent node; confirm.
30 apply_adam_node::typed_program_node(const std::shared_ptr<apply_adam> prim, program_impl& prog)
// Buffer sharing is switched off for this node's output. The trailing comment
// gives the reason: the output's initial value must be either 0 or the same
// buffer as the mutable_data that follows this node. Presumably this keeps the
// output from being aliased with another primitive's buffer.
33 can_share_buffer(false); //apply adam's output initial val should be either 0 or use same buffer as mutable_data after it (no allocation needed)
// Computes the output layout of an apply_adam node. The result is whatever the
// input node's get_non_padded_output_layout() returns, so the output layout
// mirrors the input's.
35 layout apply_adam_inst::calc_output_layout(apply_adam_node const& node)
// Forcing an explicit output data type on this primitive is not supported.
// The check uses assert(), so it compiles out when NDEBUG is defined and is
// not enforced in release builds.
37 assert((bool)node.get_primitive()->output_data_type == false
38 && "Output data type forcing is not supported for apply_adam_node!");
39 return node.input().get_non_padded_output_layout();
// Serializes an apply_adam node's description to a string.
// It starts from the generic description returned by node.desc_to_json() and
// adds an "apply adam info" section containing:
//   - the ids of the m, v, beta1_power and beta2_power inputs, and
//   - the lr, beta1, beta2 and epsilon fields of the primitive descriptor
//     (presumably the Adam hyper-parameters).
// The composite is then dumped into a stringstream, whose contents are returned.
42 std::string apply_adam_inst::to_string(apply_adam_node const& node)
44 auto desc = node.get_primitive();
45 auto node_info = node.desc_to_json();
// NOTE(review): `m` and `v` are used below (m.id(), v.id()), but this listing
// does not show their declarations. Presumably they are node.m() / node.v(),
// mirroring beta1_power / beta2_power; confirm against the full file.
48 auto& beta1_power = node.beta1_power();
49 auto& beta2_power = node.beta2_power();
51 std::stringstream primitive_description;
// apply_adam-specific section: input ids first, then the descriptor's scalar
// fields.
53 json_composite apply_adam_info;
54 apply_adam_info.add("m_id", m.id());
55 apply_adam_info.add("v_id", v.id());
56 apply_adam_info.add("beta1_power_id", beta1_power.id());
57 apply_adam_info.add("beta2_power_id", beta2_power.id());
58 apply_adam_info.add("lr", desc->lr);
59 apply_adam_info.add("beta1", desc->beta1);
60 apply_adam_info.add("beta2", desc->beta2);
61 apply_adam_info.add("epsilon", desc->epsilon);
// Attach the section to the generic node description, then render the whole
// thing into the stream.
63 node_info->add("apply adam info", apply_adam_info);
64 node_info->dump(primitive_description);
66 return primitive_description.str();
// Constructs the runtime instance of an apply_adam node within `network`, then
// validates the output-layout formats of the m, v, beta1_power and beta2_power
// inputs.
// NOTE(review): the closing brace is not visible in this listing, so further
// checks may follow. Within the visible lines, the primary input's format is
// not validated here.
69 apply_adam_inst::typed_primitive_inst(network_impl& network, apply_adam_node const& node)
70 :parent(network, node)
72 auto m_format = node.m().get_output_layout().format;
73 auto v_format = node.v().get_output_layout().format;
74 auto beta1_power_format = node.beta1_power().get_output_layout().format;
75 auto beta2_power_format = node.beta2_power().get_output_layout().format;
// Each of these inputs must be in yxfb or bfyx format.
// CLDNN_ERROR_NOT_PROPER_FORMAT is presumably defined in error_handler.h and
// reports an error for any other format; confirm its exact behavior there.
77 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "M format", m_format.value, "supported m formats", format::yxfb, format::bfyx );
78 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "V format", v_format.value, "supported v formats", format::yxfb, format::bfyx );
79 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "beta1_power format", beta1_power_format.value, "supported beta1_power formats", format::yxfb, format::bfyx);
80 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "beta2_power format", beta2_power_format.value, "supported beta2_power formats", format::yxfb, format::bfyx);