2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Service/ShapeDescription.h"
19 #include <oops/InternalExn.h>
26 ShapeDescription to_shape_description(const loco::TensorShape &shape)
30 res._rank_known = true;
32 res._dims.resize(shape.rank());
33 for (uint32_t axis = 0; axis < shape.rank(); ++axis)
35 // All the dimensions SHOULD be known
36 assert(shape.dim(axis).known());
37 res._dims.at(axis) = shape.dim(axis).value();
43 ShapeDescription to_shape_description(const loco::FeatureShape &shape)
47 res._rank_known = true;
49 // T/F Lite encodes a feature map as a NHWC tensor
51 res._dims.at(0) = shape.count().value();
52 res._dims.at(1) = shape.height().value();
53 res._dims.at(2) = shape.width().value();
54 res._dims.at(3) = shape.depth().value();
59 ShapeDescription to_shape_description(const loco::FilterShape &shape)
63 res._rank_known = true;
65 // T/F Lite encodes a convolution filter as a NHWC tensor
67 res._dims.at(0) = shape.count().value();
68 res._dims.at(1) = shape.height().value();
69 res._dims.at(2) = shape.width().value();
70 res._dims.at(3) = shape.depth().value();
75 ShapeDescription to_shape_description(const loco::DepthwiseFilterShape &shape)
79 res._rank_known = true;
81 // T/F Lite encodes a depthwise convolution filter as a [1, H, W, C*M] tensor
84 res._dims.at(1) = shape.height().value();
85 res._dims.at(2) = shape.width().value();
86 res._dims.at(3) = shape.depth().value() * shape.multiplier().value();
91 ShapeDescription to_shape_description(const loco::BiasShape &shape)
95 res._rank_known = true;
98 res._dims.at(0) = shape.length().value();
103 ShapeDescription to_shape_description(const loco::MatrixShape &shape)
105 ShapeDescription res;
107 res._rank_known = true;
110 res._dims.at(0) = shape.height().value();
111 res._dims.at(1) = shape.width().value();
116 ShapeDescription to_shape_description(const loco::NodeShape &shape)
118 switch (shape.domain())
120 case loco::Domain::Tensor:
121 return to_shape_description(shape.as<loco::TensorShape>());
122 case loco::Domain::Feature:
123 return to_shape_description(shape.as<loco::FeatureShape>());
124 case loco::Domain::Filter:
125 return to_shape_description(shape.as<loco::FilterShape>());
126 case loco::Domain::DepthwiseFilter:
127 return to_shape_description(shape.as<loco::DepthwiseFilterShape>());
128 case loco::Domain::Bias:
129 return to_shape_description(shape.as<loco::BiasShape>());
130 case loco::Domain::Matrix:
131 return to_shape_description(shape.as<loco::MatrixShape>());
136 INTERNAL_EXN_V("Unsupported loco domain", oops::to_uint32(shape.domain()));