Remove second output of Reshape during ONNXIFI transform (#17027)
authorYinghai Lu <yinghai@fb.com>
Wed, 13 Feb 2019 02:21:09 +0000 (18:21 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 13 Feb 2019 02:31:53 +0000 (18:31 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17027

Glow doesn't support second output of Reshape right now and it's useless. For correctness, we do make sure that the second output of Reshape is of Constant type during bound shape inference.

Reviewed By: ipiszy

Differential Revision: D14056555

fbshipit-source-id: f39cca7ba941bf5a5cc3adc96e2b1f943cc0be93

caffe2/opt/bound_shape_inference_test.cc
caffe2/opt/bound_shape_inferencer.cc
caffe2/opt/bound_shape_inferencer.h
caffe2/opt/onnxifi_transformer.cc

index b07dbca..40e99e5 100644 (file)
@@ -110,6 +110,46 @@ TEST(BoundShapeInference, SparseLengthsSumFused8BitRowwise) {
       out_shape, "Out", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 50});
 }
 
+TEST(BoundShapeInference, Reshape) {
+  NetDef net;
+  std::vector<int> new_shape{-1, 8};
+  std::vector<int> new_shape2{2, 8};
+  net.add_op()->CopyFrom(
+      CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"X1"}, {}));
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Reshape",
+      "",
+      {"X1"},
+      {"Y1", "old_shape"},
+      {MakeArgument<std::vector<int>>("shape", new_shape)}));
+
+  // Cannot infer shape for this one because input/output shape doesn't match
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Reshape",
+      "",
+      {"X1"},
+      {"Y2", "old_shape2"},
+      {MakeArgument<std::vector<int>>("shape", new_shape2)}));
+  ShapeInfoMap shape_map;
+  shape_map.emplace(
+      "W0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
+  shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
+  BoundShapeSpec spec(20, 1000);
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  verifyShapeInfo(
+      out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
+  verifyShapeInfo(
+      out_shape, "X1", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
+  verifyShapeInfo(
+      out_shape,
+      "Y1",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size * 16 / 8, 8});
+  EXPECT_TRUE(out_shape.find("Y2") == out_shape.end());
+}
+
 TEST(BoundShapeInference, ConcatMissingInput) {
   NetDef net;
   net.add_op()->CopyFrom(CreateOperatorDef(
index adfff97..cb3693a 100644 (file)
@@ -56,6 +56,8 @@ void BoundShapeInferencer::InferBoundShapeAndType(
       InferFC(op);
     } else if (op.type() == "Concat") {
       InferConcat(op);
+    } else if (op.type() == "Reshape") {
+      InferReshape(op);
     } else if (op.type() == "LengthsRangeFill") {
       InferLengthsRangeFill(op);
     } else {
@@ -198,6 +200,13 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
       TensorProto_DataType_FLOAT);
 }
 
+void BoundShapeInferencer::InferReshape(const OperatorDef& op) {
+  InferCommonOp(op);
+  // old_shape should be a constant
+  if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
+    shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
+  }
+}
 // For concat net, if some inputs are missing and we have add_axis argument, it
 // means that all the inputs should be of the same dimension. In this case, we
 // can infer the shape of the missing inputs
@@ -253,7 +262,7 @@ void BoundShapeInferencer::InferConcat(const OperatorDef& op) {
   }
   InferCommonOp(op);
   // split_info should be a constant
-  if (op.output_size() > 1) {
+  if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
     shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
   }
 }
@@ -345,7 +354,12 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
 
   const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
   CAFFE_ENFORCE(schema);
-  auto output_shapes = schema->InferTensor(op, input_shapes);
+  std::vector<TensorShape> output_shapes;
+  try {
+    output_shapes = schema->InferTensor(op, input_shapes);
+  } catch (const std::exception& e) {
+    LOG(WARNING) << "Caught exception while inferring shapes for " << op.type();
+  }
   int i = 0;
   for (const auto& shape : output_shapes) {
     if (shape.unknown_shape()) {
index e42472b..dafac5b 100644 (file)
@@ -66,6 +66,7 @@ class CAFFE2_API BoundShapeInferencer {
   void InferSparseLengthsSum(const OperatorDef& op);
   void InferFC(const OperatorDef& op);
   void InferConcat(const OperatorDef& op);
+  void InferReshape(const OperatorDef& op);
   void InferLengthsRangeFill(const OperatorDef& op);
 
   // Standard shape/type inference using op schema registered shape inference
index 344f58a..3079d26 100644 (file)
@@ -489,15 +489,16 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOpViaC2(
   // We already have all the ops and external inputs and outputs!
   NetDef onnxifi_net(net);
 
-  // Remove the second output of Concat from external_output. In addition, we
-  // remove those outputs from the Onnxifi op too.
+  // Remove the second output of Concat/Reshape from external_output. In
+  // addition, we remove those outputs from the Onnxifi op too.
   // TODO: This approach is a bit hacky as we assume that the second output is
   // never used. A more appropriate approach can be learned from the ONNX path,
   // where we statically computes the split_info given input shape and insert a
   // GivenTensorIntFill op
   std::unordered_set<std::string> split_infos;
   for (auto& op : *onnxifi_net.mutable_op()) {
-    if (op.type() == "Concat" && op.output_size() == 2) {
+    if ((op.type() == "Concat" || op.type() == "Reshape") &&
+        op.output_size() == 2) {
       split_infos.emplace(op.output(1));
     }
   }
@@ -802,8 +803,9 @@ NetDef OnnxifiTransformer::TransformViaC2(
       for (const auto& o : op.output()) {
         net.add_external_output(o);
       }
-      // Remove the second output of Concat from the external_output
-      if (op.type() == "Concat" && op.output_size() == 2) {
+      // Remove the second output of Concat/Reshape from the external_output
+      if ((op.type() == "Concat" || op.type() == "Reshape") &&
+          op.output_size() == 2) {
         net.mutable_external_output()->RemoveLast();
       }