2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "scale_grad_weights_inst.h"
18 #include "primitive_type_base.h"
19 #include "error_handler.h"
20 #include "json_object.h"
24 primitive_type_id scale_grad_weights_type_id()
26 static primitive_type_base<scale_grad_weights> instance;
30 layout scale_grad_weights_inst::calc_output_layout(scale_grad_weights_node const& node)
32 assert((bool)node.get_primitive()->output_data_type == false
33 && "Output data type forcing is not supported for "
34 "scale_grad_weights_node!");
35 //output buffer will not be used in this primitive
36 auto input_grad_layout_size = node.input().get_output_layout();
37 return{ input_grad_layout_size.data_type, input_grad_layout_size.format,{ 1, 1, 1, 1 } };
40 std::string scale_grad_weights_inst::to_string(scale_grad_weights_node const& node)
42 auto desc = node.get_primitive();
43 auto node_info = node.desc_to_json();
44 auto& input = node.input();
45 auto& scale_input = node.weights();
46 auto& input_grad = node.input_grad();
48 std::stringstream primitive_description;
50 json_composite scale_grad_weights_info;
51 scale_grad_weights_info.add("input", input.id());
52 scale_grad_weights_info.add("scale input", scale_input.id());
53 scale_grad_weights_info.add("input grad", input_grad.id());
55 scale_grad_weights_info.add("bias", node.bias().id());
57 node_info->add("scale_grad_weights info", scale_grad_weights_info);
58 node_info->dump(primitive_description);
60 return primitive_description.str();
63 scale_grad_weights_inst::typed_primitive_inst(network_impl& network, scale_grad_weights_node const& node)
64 :parent(network, node)
66 auto scale_layout = node.weights().get_output_layout();
67 auto scale_format = scale_layout.format;
69 auto scale_sizes = scale_layout.size;
70 auto scale_feature_size = scale_layout.size.feature[0];
72 auto input_layout = node.input().get_output_layout();
73 auto input_feature_size = input_layout.size.feature[0];
75 CLDNN_ERROR_NOT_EQUAL(node.id(), "Scale feature size", scale_feature_size, "input feature size", input_feature_size, "");
77 if (scale_sizes.spatial[0] != 1 || scale_sizes.spatial[1] != 1 || scale_sizes.batch[0] != 1) //Remove if support for other scale sizes will be added.
79 CLDNN_ERROR_MESSAGE(node.id(), "All sizes in scale_input except feature should be 1.");
82 if (node.use_momentum())
84 CLDNN_ERROR_LAYOUT_MISMATCH(node.id(), "Scale memory", node.weights().get_output_layout(), "previous scale grad memory", node.prev_scale_grad().get_output_layout(), "");
85 CLDNN_ERROR_LAYOUT_MISMATCH(node.id(), "Bias memory", node.bias().get_output_layout(), "previous bias grad memory", node.prev_bias_grad().get_output_layout(), "");
90 auto bias_layout = node.bias().get_output_layout();
91 auto bias_format = bias_layout.format;
92 auto bias_raw_sizes = bias_layout.size.raw;
94 CLDNN_ERROR_NOT_PROPER_FORMAT(node.id(), "Scale format", scale_format.value, "bias format", bias_format);
96 for (size_t i = 0; i < bias_layout.size.raw.size(); ++i)
98 if (scale_layout.size.raw[i] != bias_raw_sizes[i])
99 CLDNN_ERROR_MESSAGE(node.id(), "Scale input size do not match bias size! Size index:" + std::to_string(i));