// Copyright (c) 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

///////////////////////////////////////////////////////////////////////////////////////////////////
18 #include "deformable_convolution_inst.h"
19 #include "primitive_type_base.h"
20 #include "sliding_window_utils.h"
21 #include "error_handler.h"
22 #include "json_object.h"
26 primitive_type_id deformable_conv_type_id() {
27 static primitive_type_base<deformable_conv> instance;
31 layout deformable_conv_inst::calc_output_layout(deformable_conv_node const& node) {
32 auto desc = node.get_primitive();
34 auto input_layout = node.input().get_output_layout();
36 auto input_type = input_layout.data_type;
37 auto output_type = node.get_primitive()->output_data_type ? *node.get_primitive()->output_data_type : input_type;
39 tensor output_size(input_layout.size.batch[0],
40 desc->output_size.feature[0],
41 desc->output_size.spatial[0],
42 desc->output_size.spatial[1],
43 desc->output_size.spatial[2]);
45 return {output_type, input_layout.format, output_size};
48 std::string deformable_conv_inst::to_string(deformable_conv_node const& node) {
49 auto desc = node.get_primitive();
50 auto split = node.get_split();
51 auto node_info = node.desc_to_json();
53 std::stringstream primitive_description;
55 json_composite conv_info;
56 conv_info.add("split", split);
57 conv_info.add("groups", desc->groups);
59 json_composite ud_out_size_info;
60 ud_out_size_info.add("size", desc->output_size.to_string());
61 conv_info.add("with user defined output size", ud_out_size_info);
63 node_info->add("deformable_convolution info", conv_info);
64 node_info->dump(primitive_description);
66 return primitive_description.str();
69 deformable_conv_inst::typed_primitive_inst(network_impl& network, deformable_conv_node const& node) : parent(network, node) {
73 primitive_type_id deformable_interp_type_id() {
74 static primitive_type_base<deformable_interp> instance;
78 layout deformable_interp_inst::calc_output_layout(deformable_interp_node const& node) {
79 auto desc = node.get_primitive();
81 auto input_layout = node.input().get_output_layout();
83 auto kernel_size = desc->kernel_size;
84 auto input_type = input_layout.data_type;
85 auto output_type = node.get_primitive()->output_data_type ? *node.get_primitive()->output_data_type : input_type;
87 tensor output_size(input_layout.size.batch[0],
88 input_layout.size.feature[0]*kernel_size.spatial[0]*kernel_size.spatial[1],
89 desc->output_size.spatial[0],
90 desc->output_size.spatial[1],
91 desc->output_size.spatial[2]);
93 return {output_type, input_layout.format, output_size};
96 std::string deformable_interp_inst::to_string(deformable_interp_node const& node) {
97 auto desc = node.get_primitive();
98 auto strd = desc->stride;
99 auto split = node.get_split();
100 auto dilation = desc->dilation;
101 auto node_info = node.desc_to_json();
103 std::stringstream primitive_description;
105 json_composite interp_info;
106 interp_info.add("stride", strd.to_string());
107 interp_info.add("input offset", desc->input_offset.to_string());
108 interp_info.add("split", split);
109 interp_info.add("dilation", dilation.to_string());
110 interp_info.add("deformable_groups", desc->deformable_groups);
111 interp_info.add("groups", desc->groups);
113 json_composite ud_out_size_info;
114 ud_out_size_info.add("size", desc->output_size.to_string());
115 interp_info.add("with user defined output size", ud_out_size_info);
117 node_info->add("deformable_interpolation info", interp_info);
118 node_info->dump(primitive_description);
120 return primitive_description.str();
123 deformable_interp_inst::typed_primitive_inst(network_impl& network, deformable_interp_node const& node) : parent(network, node) {