case ONNXOpCode::opMul:
case ONNXOpCode::opRelu:
case ONNXOpCode::opReshape:
+ case ONNXOpCode::opUnsqueeze:
case ONNXOpCode::opSigmoid:
case ONNXOpCode::opScale:
case ONNXOpCode::opSoftmax:
case ONNXOpCode::opRelu:
outputs = _opCreator.convertRelu(input_nodes);
break;
+ case ONNXOpCode::opUnsqueeze:
+ outputs = _opCreator.convertUnsqueeze(input_nodes[0], onnx_node);
+ break;
case ONNXOpCode::opSigmoid:
outputs = _opCreator.convertSigmoid(input_nodes);
break;
return outputs;
}
+std::vector<mir::Operation*>
+ONNXOpCreator::convertUnsqueeze(Operation* input_data, const onnx::NodeProto& onnx_node) {
+ auto* axes = findAttribute(onnx_node, "axes");
+ assert(axes && axes->ints_size());
+ const int out_rank = input_data->getOutputShape(0).rank() + axes->ints_size();
+ Shape out_shape(out_rank);
+ const Shape& input_shape = input_data->getOutputShape(0);
+ auto ints_iterator = axes->ints().begin();
+ int j = 0;
+ for (int i = 0; i < out_rank; i++) {
+ if (ints_iterator < axes->ints().end() && i == *ints_iterator) {
+ out_shape.dim(i) = 1;
+ ints_iterator++;
+ } else {
+ out_shape.dim(i) = input_shape.dim(j);
+ j++;
+ }
+ }
+ auto outputs = createOp<ops::ReshapeOp>(input_data->getOutput(0), out_shape);
+ return outputs;
+}
+
std::vector<Operation*> ONNXOpCreator::convertRelu(InputOps& inputs) {
assert(inputs.size() == 1);
return createOp<ops::ReluOp>(inputs[0]->getOutput(0));
std::vector<mir::Operation*> convertReshape(mir::Operation* input_data, mir::Shape output_shape);
std::vector<mir::Operation*> convertRelu(InputOps& inputs);
std::vector<mir::Operation*> convertSigmoid(InputOps& inputs);
+
+ std::vector<mir::Operation*>
+ convertUnsqueeze(mir::Operation* inputs, const onnx::NodeProto& onnx_node);
std::vector<mir::Operation*> convertElementwise(InputOps& inputs,
mir::ops::ElementwiseOp::OpType op_type);
std::vector<mir::Operation*> convertScale(InputOps& inputs, const onnx::NodeProto& node);
ASSERT_EQ(result_shape_expand, op->getOutputShape(0));
}
+TEST(ShapeInferenceTest, ReshapeAutoDimensionUnsqueeze) {
+ Graph g;
+
+ Shape input_shape{10, 2, 10};
+ Shape result_shape_expand{1, 10, 2, 1, 10, 1};
+
+ auto input = g.create<ops::VariableOp>("input", input_shape);
+ auto op = g.create<ops::ReshapeOp>("reshape", input->getOutput(0),
+ Shape{1, Shape::autoDim, 2, 1, 10, 1});
+
+ ASSERT_EQ(result_shape_expand, op->getOutputShape(0));
+}
+
TEST(ShapeInferenceTest, SqueezeTestAllDims) {
Graph g;