2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/MigrateLegacyShapeDtypePass.h"
19 #include <loco/Service/ShapeInference.h>
20 #include <loco/Service/TypeInference.h>
22 #include <luci/IR/CircleNodes.h>
// Checks whether the shape stored on a circle node matches a loco tensor shape.
// The visible checks compare the rank first and then every dimension in index order.
// `shape` is taken by value, so the tensor shape is copied on each call.
// NOTE(review): the statements executed on mismatch / full match are not visible in
// this excerpt - presumably `false` on any difference and `true` otherwise; confirm.
29 bool has_same_shape(luci::CircleNode *node, loco::TensorShape shape)
// Rank is compared before the loop below, which indexes node->dim(i) for i < shape.rank().
31 if (node->rank() != shape.rank())
// Visit every axis of `shape` and compare it with the node's dimension at the same index.
34 for (uint32_t i = 0; i < shape.rank(); ++i)
// The comparison is written as a negated operator== on the dimension type.
35 if (!(node->dim(i) == shape.dim(i)))
// Module-level entry point of the pass: loops over the indices 0 .. m->size() - 1.
// NOTE(review): the loop body and the return statement are not visible in this
// excerpt - presumably each graph of the module is handed to the loco::Graph overload
// and the results are combined into the returned flag; confirm against the full file.
46 bool MigrateLegacyShapeDtypePass::run(luci::Module *m)
50 for (size_t g = 0; g < m->size(); ++g)
// Graph-level pass: for every node of `g`, copies the shape and dtype that loco has
// recorded for the node onto the node's own luci::CircleNode fields when they differ.
// NOTE(review): the return statement is not visible in this excerpt - presumably a
// "something changed" flag is returned; confirm against the full file.
59 bool MigrateLegacyShapeDtypePass::run(loco::Graph *g)
63 for (auto node : loco::all_nodes(g))
// Every node of the graph is treated as a luci::CircleNode from here on.
65 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
// Shape migration: only performed when loco reports a known shape for this node.
66 if (loco::shape_known(node))
68 auto loco_shape = loco::shape_get(node).as<loco::TensorShape>();
// The node's shape signature must be either empty (rank 0) or of the same rank as the
// loco shape; the per-axis signature lookup below relies on that.
70 assert(circle_node->shape_signature().rank() == 0 ||
71 circle_node->shape_signature().rank() == loco_shape.rank());
73 // When shape of loco is copied to circle node, ShapeSignature should be applied.
// Build the shape to store, axis by axis, with the same rank as the loco shape.
74 loco::TensorShape new_shape;
75 new_shape.rank(loco_shape.rank());
76 for (uint32_t i = 0; i < loco_shape.rank(); ++i)
// A non-empty signature with -1 on this axis selects special handling for the axis.
// NOTE(review): the statement for that case is not visible in this excerpt, and -1
// presumably marks a dynamic dimension - confirm both against the full file.
78 if (circle_node->shape_signature().rank() > 0 &&
79 circle_node->shape_signature().dim(i) == -1)
// The loco dimension is carried over unchanged.
// NOTE(review): presumably the else-branch of the signature check above; confirm.
82 new_shape.dim(i) = loco_shape.dim(i);
// Write back only when the node has no defined shape yet or its shape differs from new_shape.
85 if (circle_node->shape_status() == luci::ShapeStatus::UNDEFINED ||
86 !has_same_shape(circle_node, new_shape))
88 circle_node->rank(new_shape.rank());
89 for (uint32_t i = 0; i < new_shape.rank(); ++i)
90 circle_node->dim(i) = new_shape.dim(i);
// Only an UNDEFINED status is promoted to VALID; any other status is left untouched.
92 if (circle_node->shape_status() == luci::ShapeStatus::UNDEFINED)
93 circle_node->shape_status(luci::ShapeStatus::VALID);
// Dtype migration: only performed when loco reports a known dtype for this node.
99 if (loco::dtype_known(node))
// Skip the write when the node already carries the same dtype.
101 if (loco::dtype_get(node) != circle_node->dtype())
103 circle_node->dtype(loco::dtype_get(node));