cbc302f7040590f08cc12764f052bd2c8df9c164
[platform/core/ml/nnfw.git] / compiler / luci / service / src / ShapeDescription.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Service/ShapeDescription.h"
18
19 #include <oops/InternalExn.h>
20
21 #include <cassert>
22
23 namespace luci
24 {
25
26 ShapeDescription to_shape_description(const loco::TensorShape &shape)
27 {
28   ShapeDescription res;
29
30   res._rank_known = true;
31
32   res._dims.resize(shape.rank());
33   for (uint32_t axis = 0; axis < shape.rank(); ++axis)
34   {
35     // All the dimensions SHOULD be known
36     assert(shape.dim(axis).known());
37     res._dims.at(axis) = shape.dim(axis).value();
38   }
39
40   return res;
41 }
42
43 ShapeDescription to_shape_description(const loco::FeatureShape &shape)
44 {
45   ShapeDescription res;
46
47   res._rank_known = true;
48
49   // T/F Lite encodes a feature map as a NHWC tensor
50   res._dims.resize(4);
51   res._dims.at(0) = shape.count().value();
52   res._dims.at(1) = shape.height().value();
53   res._dims.at(2) = shape.width().value();
54   res._dims.at(3) = shape.depth().value();
55
56   return res;
57 }
58
59 ShapeDescription to_shape_description(const loco::FilterShape &shape)
60 {
61   ShapeDescription res;
62
63   res._rank_known = true;
64
65   // T/F Lite encodes a convolution filter as a NHWC tensor
66   res._dims.resize(4);
67   res._dims.at(0) = shape.count().value();
68   res._dims.at(1) = shape.height().value();
69   res._dims.at(2) = shape.width().value();
70   res._dims.at(3) = shape.depth().value();
71
72   return res;
73 }
74
75 ShapeDescription to_shape_description(const loco::DepthwiseFilterShape &shape)
76 {
77   ShapeDescription res;
78
79   res._rank_known = true;
80
81   // T/F Lite encodes a depthwise convolution filter as a [1, H, W, C*M] tensor
82   res._dims.resize(4);
83   res._dims.at(0) = 1;
84   res._dims.at(1) = shape.height().value();
85   res._dims.at(2) = shape.width().value();
86   res._dims.at(3) = shape.depth().value() * shape.multiplier().value();
87
88   return res;
89 }
90
91 ShapeDescription to_shape_description(const loco::BiasShape &shape)
92 {
93   ShapeDescription res;
94
95   res._rank_known = true;
96
97   res._dims.resize(1);
98   res._dims.at(0) = shape.length().value();
99
100   return res;
101 }
102
103 ShapeDescription to_shape_description(const loco::MatrixShape &shape)
104 {
105   ShapeDescription res;
106
107   res._rank_known = true;
108
109   res._dims.resize(2);
110   res._dims.at(0) = shape.height().value();
111   res._dims.at(1) = shape.width().value();
112
113   return res;
114 }
115
116 ShapeDescription to_shape_description(const loco::NodeShape &shape)
117 {
118   switch (shape.domain())
119   {
120     case loco::Domain::Tensor:
121       return to_shape_description(shape.as<loco::TensorShape>());
122     case loco::Domain::Feature:
123       return to_shape_description(shape.as<loco::FeatureShape>());
124     case loco::Domain::Filter:
125       return to_shape_description(shape.as<loco::FilterShape>());
126     case loco::Domain::DepthwiseFilter:
127       return to_shape_description(shape.as<loco::DepthwiseFilterShape>());
128     case loco::Domain::Bias:
129       return to_shape_description(shape.as<loco::BiasShape>());
130     case loco::Domain::Matrix:
131       return to_shape_description(shape.as<loco::MatrixShape>());
132     default:
133       break;
134   }
135
136   INTERNAL_EXN_V("Unsupported loco domain", oops::to_uint32(shape.domain()));
137 }
138
139 } // namespace luci