out_shape, "Out", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 50});
}
+TEST(BoundShapeInference, Reshape) {
+ NetDef net;
+ // {-1, 8}: the -1 dim is inferred from the element count, so bound shape
+ // inference can propagate a batch-dependent first dimension through Reshape.
+ std::vector<int> new_shape{-1, 8};
+ // {2, 8}: a fully fixed target shape; its element count cannot track the
+ // batch-dependent input, so inference is expected to skip this output.
+ std::vector<int> new_shape2{2, 8};
+ net.add_op()->CopyFrom(
+ CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"X1"}, {}));
+ net.add_op()->CopyFrom(CreateOperatorDef(
+ "Reshape",
+ "",
+ {"X1"},
+ {"Y1", "old_shape"},
+ {MakeArgument<std::vector<int>>("shape", new_shape)}));
+
+ // Cannot infer shape for this one because input/output shape doesn't match
+ net.add_op()->CopyFrom(CreateOperatorDef(
+ "Reshape",
+ "",
+ {"X1"},
+ {"Y2", "old_shape2"},
+ {MakeArgument<std::vector<int>>("shape", new_shape2)}));
+ ShapeInfoMap shape_map;
+ // Seed only the FC weights/bias as CONSTANT; X0/X1 shapes must be derived.
+ shape_map.emplace(
+ "W0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
+ shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
+ // spec(20, 1000): first arg is max_batch_size (asserted below); second is
+ // presumably the max sequence/length bound — confirm against BoundShapeSpec.
+ BoundShapeSpec spec(20, 1000);
+ BoundShapeInferencer eng(spec);
+ eng.InferBoundShapeAndType(net, shape_map);
+ const auto& out_shape = eng.shape_info();
+ // FC input/output pick up the BATCH dim bounded by max_batch_size.
+ verifyShapeInfo(
+ out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
+ verifyShapeInfo(
+ out_shape, "X1", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
+ // Reshape redistributes the batch*16 elements into rows of 8.
+ verifyShapeInfo(
+ out_shape,
+ "Y1",
+ ShapeInfo::DimType::BATCH,
+ {spec.max_batch_size * 16 / 8, 8});
+ // The fixed-shape Reshape must yield no inferred shape for Y2.
+ EXPECT_TRUE(out_shape.find("Y2") == out_shape.end());
+}
+
TEST(BoundShapeInference, ConcatMissingInput) {
NetDef net;
net.add_op()->CopyFrom(CreateOperatorDef(
InferFC(op);
} else if (op.type() == "Concat") {
InferConcat(op);
+ } else if (op.type() == "Reshape") {
+ InferReshape(op);
} else if (op.type() == "LengthsRangeFill") {
InferLengthsRangeFill(op);
} else {
TensorProto_DataType_FLOAT);
}
+// Shape inference for Reshape: delegate to the schema-based common inference,
+// then mark the optional second output ("old_shape", the input's original
+// shape tensor) as CONSTANT — it records a shape and does not scale with the
+// batch dimension.
+void BoundShapeInferencer::InferReshape(const OperatorDef& op) {
+ InferCommonOp(op);
+ // old_shape should be a constant; guard with count() since common inference
+ // may have failed to produce an entry for output(1).
+ if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
+ shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
+ }
+}
// For concat net, if some inputs are missing and we have add_axis argument, it
// means that all the inputs should be of the same dimension. In this case, we
// can infer the shape of the missing inputs
}
InferCommonOp(op);
// split_info should be a constant
- if (op.output_size() > 1) {
+ if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
}
}
const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
CAFFE_ENFORCE(schema);
- auto output_shapes = schema->InferTensor(op, input_shapes);
+ std::vector<TensorShape> output_shapes;
+ try {
+ output_shapes = schema->InferTensor(op, input_shapes);
+ } catch (const std::exception& e) {
+ LOG(WARNING) << "Caught exception while inferring shapes for " << op.type();
+ }
int i = 0;
for (const auto& shape : output_shapes) {
if (shape.unknown_shape()) {
void InferSparseLengthsSum(const OperatorDef& op);
void InferFC(const OperatorDef& op);
void InferConcat(const OperatorDef& op);
+ void InferReshape(const OperatorDef& op);
void InferLengthsRangeFill(const OperatorDef& op);
// Standard shape/type inference using op schema registered shape inference
// We already have all the ops and external inputs and outputs!
NetDef onnxifi_net(net);
- // Remove the second output of Concat from external_output. In addition, we
- // remove those outputs from the Onnxifi op too.
+ // Remove the second output of Concat/Reshape from external_output. In
+ // addition, we remove those outputs from the Onnxifi op too.
// TODO: This approach is a bit hacky as we assume that the second output is
// never used. A more appropriate approach can be learned from the ONNX path,
// where we statically computes the split_info given input shape and insert a
// GivenTensorIntFill op
std::unordered_set<std::string> split_infos;
for (auto& op : *onnxifi_net.mutable_op()) {
- if (op.type() == "Concat" && op.output_size() == 2) {
+ if ((op.type() == "Concat" || op.type() == "Reshape") &&
+ op.output_size() == 2) {
split_infos.emplace(op.output(1));
}
}
for (const auto& o : op.output()) {
net.add_external_output(o);
}
- // Remove the second output of Concat from the external_output
- if (op.type() == "Concat" && op.output_size() == 2) {
+ // Remove the second output of Concat/Reshape from the external_output
+ if ((op.type() == "Concat" || op.type() == "Reshape") &&
+ op.output_size() == 2) {
net.mutable_external_output()->RemoveLast();
}