Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / mir / src / mir_caffe2_importer / caffe2_op_creator.cpp
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "caffe2_op_creator.h"
18 #include "caffe2_proto_helper.h"
19
20 #include "mir/ops/AddOp.h"
21 #include "mir/ops/AvgPool2DOp.h"
22 #include "mir/ops/CappedReluOp.h"
23 #include "mir/ops/ConcatOp.h"
24 #include "mir/ops/ConstantOp.h"
25 #include "mir/ops/Conv2DOp.h"
26 #include "mir/ops/FullyConnectedOp.h"
27 #include "mir/ops/MaxPool2DOp.h"
28 #include "mir/ops/MulOp.h"
29 #include "mir/ops/ReluOp.h"
30 #include "mir/ops/ReshapeOp.h"
31 #include "mir/ops/ResizeOp.h"
32 #include "mir/ops/SigmoidOp.h"
33 #include "mir/ops/SoftmaxOp.h"
34 #include "mir/ops/TransposeOp.h"
35
36 #include "mir/Index.h"
37 #include "mir/Shape.h"
38 #include "mir/ShapeRange.h"
39 #include "mir/Tensor.h"
40 #include "mir/TensorUtil.h"
41
42 #include <cmath>
43 #include <stdexcept>
44 #include <vector>
45
46 namespace mir_caffe2
47 {
48
49 using namespace ::caffe2;
50 using namespace mir;
51
52 //
53 // Helper functions
54 //
55
56 static std::pair<std::vector<int32_t>, std::vector<int32_t>>
57 getPadding(const ::caffe2::OperatorDef &op)
58 {
59
60   if (hasArgument(op.arg(), "pads"))
61   {
62     // pads order: t l b r
63     auto pads_arg = findArgumentByName(op.arg(), "pads");
64
65     std::vector<int32_t> paddings;
66     for (const auto &pad : pads_arg.ints())
67       paddings.push_back(static_cast<int32_t>(pad));
68
69     assert(paddings.size() == 4);
70
71     int32_t pad_t = paddings[0];
72     int32_t pad_l = paddings[1];
73     int32_t pad_b = paddings[2];
74     int32_t pad_r = paddings[3];
75
76     std::vector<int32_t> padding_before{pad_t, pad_l};
77     std::vector<int32_t> padding_after{pad_b, pad_r};
78     return {padding_before, padding_after};
79   }
80
81   bool has_custom_pad = hasArgument(op.arg(), "pad_l") || hasArgument(op.arg(), "pad_r") ||
82                         hasArgument(op.arg(), "pad_t") || hasArgument(op.arg(), "pad_b");
83
84   if (has_custom_pad)
85   {
86     int32_t pad_l = getSingleArgument(op, "pad_l", 0);
87     int32_t pad_t = getSingleArgument(op, "pad_t", 0);
88     int32_t pad_r = getSingleArgument(op, "pad_r", 0);
89     int32_t pad_b = getSingleArgument(op, "pad_b", 0);
90
91     std::vector<int32_t> padding_before{pad_t, pad_l};
92     std::vector<int32_t> padding_after{pad_b, pad_r};
93     return {padding_before, padding_after};
94   }
95
96   int32_t pad = getSingleArgument(op, "pad", 0);
97   return {{pad, pad}, {pad, pad}};
98 }
99
100 static std::vector<std::int32_t> getStrides(const ::caffe2::OperatorDef &op)
101 {
102   std::vector<std::int32_t> strides;
103
104   if (hasArgument(op.arg(), "stride"))
105   {
106     std::int32_t stride = getSingleArgument(op, "stride", 1);
107     strides = {stride, stride};
108   }
109
110   if (hasArgument(op.arg(), "strides"))
111   {
112     // strides order: h w
113     auto strides_arg = findArgumentByName(op.arg(), "strides");
114     for (const auto &s : strides_arg.ints())
115       strides.push_back(s);
116   }
117
118   assert(!strides.empty() && "Strides not found");
119
120   return strides;
121 }
122
123 static std::vector<std::int32_t> getWindowSize(const ::caffe2::OperatorDef &op,
124                                                const std::vector<mir::Operation::Output *> &inputs)
125 {
126   int is_global_pooling = getSingleArgument(op, "global_pooling", 0);
127   bool has_custom_kernel_size =
128       hasArgument(op.arg(), "kernel_h") || hasArgument(op.arg(), "kernel_w");
129   bool has_custom_kernels_size = hasArgument(op.arg(), "kernels");
130
131   int kernel_h(0), kernel_w(0);
132   if (is_global_pooling)
133   {
134     const auto &input_shape = inputs[0]->getShape();
135     assert(input_shape.rank() == 4 && "getWindowSize() inputs must be of rank 4");
136     kernel_h = input_shape.dim(2);
137     kernel_w = input_shape.dim(3);
138   }
139   else
140   {
141     if (has_custom_kernel_size)
142     {
143       kernel_h = getSingleArgument(op, "kernel_h", 0);
144       kernel_w = getSingleArgument(op, "kernel_w", 0);
145     }
146     else
147     {
148       if (has_custom_kernels_size)
149       {
150         // kernels order: h w
151         std::vector<int32_t> kernels;
152         auto kernels_arg = findArgumentByName(op.arg(), "kernels");
153         for (const auto &ker : kernels_arg.ints())
154           kernels.push_back(static_cast<int32_t>(ker));
155         assert(kernels.size() == 2);
156         kernel_h = kernels[0];
157         kernel_w = kernels[1];
158       }
159       else
160       {
161         kernel_h = kernel_w = getSingleArgument(op, "kernel", 0);
162       }
163     }
164   }
165   return {kernel_h, kernel_w};
166 }
167
168 //
169 // Check functions
170 //
171
172 static void checkLayout(const OperatorDef &op)
173 {
174   if (getSingleArgument(op, "order", "NCHW") != "NCHW")
175     throw std::runtime_error(op.type() + ": only 'NCHW' axis order is supported");
176 }
177
178 static void checkConvLikeOp(const ::caffe2::OperatorDef &op)
179 {
180   checkLayout(op);
181
182   // Padding
183   bool has_custom_pad = hasArgument(op.arg(), "pad_l") || hasArgument(op.arg(), "pad_r") ||
184                         hasArgument(op.arg(), "pad_t") || hasArgument(op.arg(), "pad_b");
185
186   if (has_custom_pad && hasArgument(op.arg(), "pad"))
187     throw std::runtime_error("Custom pad can't be combined with overall pad");
188
189   if (has_custom_pad &&
190       !(hasArgument(op.arg(), "pad_l") && hasArgument(op.arg(), "pad_r") &&
191         hasArgument(op.arg(), "pad_t") && hasArgument(op.arg(), "pad_b")))
192     throw std::runtime_error("If one custom pad specified - all custom pads must be specified");
193
194   // Kernel size
195   bool has_custom_kernel_size =
196       hasArgument(op.arg(), "kernel_h") || hasArgument(op.arg(), "kernel_w");
197
198   if (has_custom_kernel_size && hasArgument(op.arg(), "kernel"))
199     throw std::runtime_error("Custom kernel size can't be combined with overall kernel size");
200
201   if (has_custom_kernel_size &&
202       !(hasArgument(op.arg(), "kernel_h") && hasArgument(op.arg(), "kernel_w")))
203     throw std::runtime_error(
204         "If one custom kernel size specified - all custom kernel sizes must be specified");
205 }
206
207 static mir::TensorVariant createTensor(const OperatorDef &op)
208 {
209   assert(hasArgument(op.arg(), "shape") && hasArgument(op.arg(), "values"));
210
211   const auto &shape = findArgumentByName(op.arg(), "shape");
212   const auto &values = findArgumentByName(op.arg(), "values");
213
214   mir::DataType element_type;
215   const void *src_data;
216   // if values on floats
217   if (!values.floats().empty())
218   {
219     element_type = mir::DataType::FLOAT32;
220     src_data = values.floats().data();
221   }
222   else
223   {
224     assert(!values.ints().empty());
225     if (op.type() == "GivenTensorInt64Fill")
226     {
227       element_type = mir::DataType::INT64;
228     }
229     else
230     {
231       element_type = mir::DataType::INT32;
232     }
233     src_data = values.ints().data();
234   }
235
236   mir::Shape tensor_shape(shape.ints_size());
237
238   for (int i = 0; i < shape.ints_size(); ++i)
239   {
240     tensor_shape.dim(i) = shape.ints(i);
241   }
242
243   return mir::TensorVariant({element_type, tensor_shape}, src_data);
244 }
245
246 //
247 // Convert functions
248 //
249
250 std::vector<mir::Operation::Output *>
251 Caffe2OpCreator::convertConstant(const std::vector<mir::Operation::Output *> &,
252                                  const ::caffe2::OperatorDef &op)
253 {
254   // Constant may not contain any data if it is a fake input.
255   if (!hasArgument(op.arg(), "values"))
256     return {};
257
258   return {createOp<ops::ConstantOp>(createTensor(op))->getOutput(0)};
259 }
260
261 std::vector<mir::Operation::Output *>
262 Caffe2OpCreator::convertAdd(const std::vector<mir::Operation::Output *> &inputs,
263                             const ::caffe2::OperatorDef &op)
264 {
265   assert(inputs.size() == 2);
266   auto lhs = inputs[0];
267   auto rhs = inputs[1];
268
269   if (getSingleArgument(op, "broadcast", 0) != 0)
270   {
271     // FIXME This only works when 'axis' == 1 and the second input is 1-D.
272     rhs = createOp<ops::ReshapeOp>(rhs, Shape{1, rhs->getShape().dim(0), 1, 1})->getOutput(0);
273     auto result = createOp<ops::AddOp>(lhs, rhs)->getOutput(0);
274     return {result};
275   }
276
277   auto result = createOp<ops::AddOp>(lhs, rhs)->getOutput(0);
278   return {result};
279 }
280
281 std::vector<mir::Operation::Output *>
282 Caffe2OpCreator::convertAveragePool(const std::vector<mir::Operation::Output *> &inputs,
283                                     const OperatorDef &op)
284 {
285   checkConvLikeOp(op);
286
287   assert(inputs.size() == 1);
288   auto input = inputs[0];
289
290   AvgPool2DOpAttributes attributes;
291   std::tie(attributes.padding_before, attributes.padding_after) = getPadding(op);
292   attributes.window = getWindowSize(op, inputs);
293   attributes.strides = getStrides(op);
294   attributes.include_pad = false;
295   attributes.data_format = DataFormat::NCHW;
296   auto result = createOp<ops::AvgPool2DOp>(input, attributes)->getOutput(0);
297   return {result};
298 }
299
300 std::vector<mir::Operation::Output *>
301 Caffe2OpCreator::convertConv(const std::vector<mir::Operation::Output *> &inputs,
302                              const ::caffe2::OperatorDef &op)
303 {
304   // dilation order: h w (not used)
305   mir::Conv2DOpAttributes attributes;
306   attributes.strides = getStrides(op);
307   std::tie(attributes.padding_before, attributes.padding_after) = getPadding(op);
308   attributes.num_groups = getSingleArgument(op, "group", 1);
309   attributes.data_format = DataFormat::NCHW;
310
311   std::vector<std::size_t> perm{0, 2, 3, 1}; // OIHW -> OHWI
312   auto kernel = createOp<ops::TransposeOp>(inputs[1], perm)->getOutput(0);
313   auto result = createOp<ops::Conv2DOp>(inputs[0], kernel, attributes)->getOutput(0);
314
315   if (op.input_size() > 2)
316   {
317     auto bias = inputs[2];
318     bias = createOp<ops::ReshapeOp>(bias, Shape{1, bias->getShape().dim(0), 1, 1})->getOutput(0);
319     result = createOp<ops::AddOp>(result, bias)->getOutput(0);
320   }
321
322   return {result};
323 }
324
325 std::vector<mir::Operation::Output *>
326 Caffe2OpCreator::convertConcat(const std::vector<mir::Operation::Output *> &inputs,
327                                const ::caffe2::OperatorDef &op)
328 {
329   checkLayout(op);
330
331   // `1` corresponds to the default (channels) axis.
332   int axis = getSingleArgument(op, "axis", 1);
333   auto result = createOp<ops::ConcatOp>(inputs, axis);
334   return {result->getOutput(0)};
335 }
336
337 std::vector<mir::Operation::Output *>
338 Caffe2OpCreator::convertDropout(const std::vector<mir::Operation::Output *> &inputs,
339                                 const ::caffe2::OperatorDef &)
340 {
341   // This is a no-op in inference mode.
342   return {inputs[0]};
343 }
344
345 std::vector<mir::Operation::Output *>
346 Caffe2OpCreator::convertFC(const std::vector<mir::Operation::Output *> &inputs,
347                            const ::caffe2::OperatorDef &op)
348 {
349   for (auto &s : {"axis", "axis_w", "float16_compute"})
350     if (hasArgument(op.arg(), s))
351       throw std::runtime_error(std::string("FC: only default '") + s + "' value is supported");
352
353   const auto &input_shape = inputs[0]->getShape();
354   // Transform input into 2-D tensor by flattening axes
355   Shape shape{input_shape.dim(0), input_shape.numElements() / input_shape.dim(0)};
356
357   auto reshape = createOp<ops::ReshapeOp>(inputs[0], shape)->getOutput(0);
358   auto weights =
359       createOp<ops::TransposeOp>(inputs[1], std::vector<std::size_t>{1, 0})->getOutput(0);
360   auto result = createOp<ops::FullyConnectedOp>(reshape, weights)->getOutput(0);
361   result = createOp<ops::AddOp>(result, inputs[2])->getOutput(0);
362
363   return {result};
364 }
365
366 std::vector<mir::Operation::Output *>
367 Caffe2OpCreator::convertMaxPool(const std::vector<mir::Operation::Output *> &inputs,
368                                 const OperatorDef &op)
369 {
370   checkConvLikeOp(op);
371
372   assert(inputs.size() == 1);
373   auto input = inputs[0];
374
375   MaxPool2DOpAttributes attributes;
376   std::tie(attributes.padding_before, attributes.padding_after) = getPadding(op);
377   attributes.window = getWindowSize(op, inputs);
378   attributes.strides = getStrides(op);
379   attributes.data_format = DataFormat::NCHW;
380   auto result = createOp<ops::MaxPool2DOp>(input, attributes)->getOutput(0);
381   return {result};
382 }
383
384 std::vector<mir::Operation::Output *>
385 Caffe2OpCreator::convertMul(const std::vector<mir::Operation::Output *> &inputs,
386                             const ::caffe2::OperatorDef &op)
387 {
388   assert(inputs.size() == 2);
389   auto lhs = inputs[0];
390   auto rhs = inputs[1];
391
392   if (getSingleArgument(op, "broadcast", 0) != 0)
393   {
394     // FIXME This only works when `axis` == 1 and the second input is 1-D.
395     rhs = createOp<ops::ReshapeOp>(rhs, Shape{1, rhs->getShape().dim(0), 1, 1})->getOutput(0);
396     auto result = createOp<ops::MulOp>(lhs, rhs)->getOutput(0);
397     return {result};
398   }
399
400   auto result = createOp<ops::MulOp>(lhs, rhs)->getOutput(0);
401   return {result};
402 }
403
404 std::vector<mir::Operation::Output *>
405 Caffe2OpCreator::convertRelu(const std::vector<mir::Operation::Output *> &inputs)
406 {
407   auto relu = createOp<ops::ReluOp>(inputs[0]);
408   return {relu->getOutput(0)};
409 }
410
411 std::vector<mir::Operation::Output *>
412 Caffe2OpCreator::convertResizeNearest(const std::vector<mir::Operation::Output *> &inputs,
413                                       const ::caffe2::OperatorDef &op)
414 {
415   std::vector<float> scales(4);
416   assert(inputs[0]->getShape().rank() == 4 && "only 4d tensors is supported");
417   // Assuming NCHW format.
418   scales[0] = 1.0f;
419   scales[1] = 1.0f;
420   scales[2] = getSingleArgument(op, "height_scale", 1.0f);
421   scales[3] = getSingleArgument(op, "width_scale", 1.0f);
422   auto result =
423       createOp<ops::ResizeOp>(inputs[0], ops::ResizeOp::ResizeMethod::nearestNeighbor, scales)
424           ->getOutput(0);
425   return {result};
426 }
427
428 std::vector<mir::Operation::Output *>
429 Caffe2OpCreator::convertSigmoid(const std::vector<mir::Operation::Output *> &inputs)
430 {
431   auto result = createOp<ops::SigmoidOp>(inputs[0]);
432   return {result->getOutput(0)};
433 }
434
435 std::vector<mir::Operation::Output *>
436 Caffe2OpCreator::convertSoftmax(const std::vector<mir::Operation::Output *> &inputs,
437                                 const ::caffe2::OperatorDef &op)
438 {
439   int axis = getSingleArgument(op, "axis", 1);
440   auto softmax = createOp<ops::SoftmaxOp>(inputs[0], axis);
441   return {softmax->getOutput(0)};
442 }
443
444 std::vector<mir::Operation::Output *>
445 Caffe2OpCreator::convertSpatialBN(const std::vector<mir::Operation::Output *> &inputs,
446                                   const ::caffe2::OperatorDef &op)
447 {
448   checkLayout(op);
449
450   // Sanity checks
451   if (op.input_size() != 5)
452     throw std::runtime_error(
453         "SpatialBN must have exactly 5 inputs ('sums' and 'sumsq' are not supported yet)");
454   if (getSingleArgument(op, "is_test", 1) != 1)
455     throw std::runtime_error("SpatialBN: only test mode supported");
456
457   // overall_res = (X - mean) / sqrt(var + epsilon) * scale + bias
458
459   auto scale_op = dynamic_cast<mir::ops::ConstantOp *>(inputs[1]->getNode());
460   auto bias_op = dynamic_cast<mir::ops::ConstantOp *>(inputs[2]->getNode());
461   auto mean_op = dynamic_cast<mir::ops::ConstantOp *>(inputs[3]->getNode());
462   auto var_op = dynamic_cast<mir::ops::ConstantOp *>(inputs[4]->getNode());
463   if (scale_op == nullptr || bias_op == nullptr || mean_op == nullptr || var_op == nullptr)
464     throw std::runtime_error(
465         "SpatialBN: non-constant 'scale', 'bias', 'mean' and 'var' inputs are not supported yet.");
466
467   const auto &scale_tensor = scale_op->getValue();
468   const auto &bias_tensor = bias_op->getValue();
469   const auto &mean_tensor = mean_op->getValue();
470   const auto &var_tensor = var_op->getValue();
471   float eps = getSingleArgument(op, "epsilon", 1e-5f);
472
473   // res1 = X - mean
474   Tensor<float> bias_data(mean_tensor);
475   for (auto &idx : ShapeRange(bias_data.getShape()))
476     bias_data.at(idx) *= -1;
477
478   auto mean = createOp<ops::ConstantOp>(mean_tensor)->getOutput(0);
479   mean = createOp<ops::ReshapeOp>(mean, Shape{1, mean->getShape().dim(0), 1, 1})->getOutput(0);
480   auto result = createOp<ops::AddOp>(inputs[0], mean)->getOutput(0);
481
482   // res2 = res1 * scale / (var + epsilon)
483   Tensor<float> multiplier(scale_tensor);
484   for (auto &idx : ShapeRange(scale_tensor.getShape()))
485     multiplier.at(idx) /= std::sqrt(*reinterpret_cast<float *>(var_tensor.at(idx)) + eps);
486   auto scale = createOp<ops::ConstantOp>(scale_tensor)->getOutput(0);
487   scale = createOp<ops::ReshapeOp>(scale, Shape{1, scale->getShape().dim(0), 1, 1})->getOutput(0);
488   result = createOp<ops::MulOp>(result, scale)->getOutput(0);
489
490   // overall_res = res2 + bias
491   auto bias = createOp<ops::ConstantOp>(bias_tensor)->getOutput(0);
492   bias = createOp<ops::ReshapeOp>(bias, Shape{1, bias->getShape().dim(0), 1, 1})->getOutput(0);
493   result = createOp<ops::AddOp>(result, bias)->getOutput(0);
494
495   return {result};
496 }
497
498 std::vector<mir::Operation::Output *>
499 Caffe2OpCreator::convertSum(const std::vector<mir::Operation::Output *> &inputs)
500 {
501   auto result = createOp<ops::AddOp>(inputs[0], inputs[1])->getOutput(0);
502   for (int i = 2; i < static_cast<int>(inputs.size()); ++i)
503   {
504     result = createOp<ops::AddOp>(result, inputs[i])->getOutput(0);
505   }
506   return {result};
507 }
508
509 std::vector<mir::Operation::Output *>
510 Caffe2OpCreator::convertClip(const std::vector<mir::Operation::Output *> &inputs,
511                              const ::caffe2::OperatorDef &op)
512 {
513
514   float max = getSingleArgument(op, "max", float(0));
515   float min = getSingleArgument(op, "min", float(0));
516
517   if (min != 0.0f)
518     throw std::runtime_error("Clip: min != 0 is not supported.");
519   if (max <= min)
520     throw std::runtime_error("Clip: max <= min is not supported.");
521   auto cap_relu = createOp<ops::CappedReluOp>(inputs[0], max);
522
523   return {cap_relu->getOutput(0)};
524 }
525
526 std::vector<mir::Operation::Output *>
527 Caffe2OpCreator::convertReshape(const std::vector<mir::Operation::Output *> &inputs,
528                                 const ::caffe2::OperatorDef &)
529 {
530   auto shape_op = dynamic_cast<mir::ops::ConstantOp *>(inputs[1]->getNode());
531   if (shape_op == nullptr)
532     throw std::runtime_error("Reshape: non-constant shape is not supported yet.");
533
534   const auto &shape_tensor = shape_op->getValue();
535
536   Tensor<int64_t> out_shape_tensor(shape_tensor);
537
538   ShapeRange range(out_shape_tensor.getShape());
539   std::vector<int32_t> shape_vec;
540   for (const auto &index : range)
541   {
542     shape_vec.push_back(static_cast<int32_t>(out_shape_tensor.at(index)));
543   }
544   Shape out_shape(shape_vec);
545
546   auto reshape = createOp<ops::ReshapeOp>(inputs[0], out_shape);
547
548   return {reshape->getOutput(0)};
549 }
550
551 } // namespace mir_caffe2