Blacklist certain op types when doing bound shape inference (#18290)
author Yinghai Lu <yinghai@fb.com>
Thu, 21 Mar 2019 22:28:20 +0000 (15:28 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 21 Mar 2019 22:43:05 +0000 (15:43 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18290

Some ops, such as `Tile`, will mess up our tracking of the batch size, so for now it makes sense to stop shape inference at these ops so that we don't lower them and their downstream ops without proper batch info.
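
As a minimal sketch of the mechanism (the `unsupported` set and the skip in the op loop mirror the diff below; `shouldSkipBoundShapeInference` is a hypothetical helper added here only for illustration):

```cpp
#include <string>
#include <unordered_set>

// Ops like Tile change the batch dimension in ways the inferencer
// cannot track, so inference stops at them instead of guessing.
static const std::unordered_set<std::string> kUnsupported{"Tile"};

// Hypothetical helper: true if bound shape inference should skip this op.
bool shouldSkipBoundShapeInference(const std::string& op_type) {
  return kUnsupported.count(op_type) > 0;
}
```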

Reviewed By: zrphercule

Differential Revision: D14463550

fbshipit-source-id: 2792481efa540f2a7dd310e677c213860c3053ca

caffe2/opt/bound_shape_inference_test.cc
caffe2/opt/bound_shape_inferencer.cc

index ee59fe7..d8f77cf 100644 (file)
@@ -244,8 +244,7 @@ TEST(BoundShapeInference, FC) {
       {spec.max_batch_size, 1024});
 }
 
-// We don't support inference input shape when Weight is not 2D
-TEST(BoundShapeInference, UnsupportedFC) {
+TEST(BoundShapeInference, FC3D) {
   NetDef net;
   net.add_op()->CopyFrom(
       CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"Out0"}, {}));
@@ -255,7 +254,12 @@ TEST(BoundShapeInference, UnsupportedFC) {
   shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
   BoundShapeSpec spec(20, 1000);
   BoundShapeInferencer eng(spec);
-  EXPECT_THROW(eng.InferBoundShapeAndType(net, shape_map), EnforceNotMet);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  verifyShapeInfo(
+      out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
+  verifyShapeInfo(
+      out_shape, "Out0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
 }
 
 TEST(BoundShapeInference, Combo0) {
index dea1eeb..717f843 100644 (file)
@@ -44,10 +44,15 @@ void EnsureShapeNames(std::unordered_map<std::string, ShapeInfo>* info) {
 void BoundShapeInferencer::InferBoundShapeAndType(
     const NetDef& net,
     const std::unordered_map<std::string, ShapeInfo>& info) {
+  const static std::unordered_set<std::string> unsupported{"Tile"};
   shape_info_ = info;
 
   for (const auto& op : net.op()) {
     VLOG(1) << op.type();
+    if (unsupported.count(op.type())) {
+      continue;
+    }
+
     if (op.type() == "SparseLengthsSum" ||
         op.type() == "SparseLengthsSumFused8BitRowwise" ||
         op.type() == "SparseLengthsWeightedSum" ||
@@ -316,34 +321,22 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) {
     ArgumentHelper helper(op);
     auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
     auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
-    CAFFE_ENFORCE_EQ(
-        axis,
-        1,
-        "Don't know how to deduce input of FC with axis not equal to 1: ",
-        op.input(0));
-    CAFFE_ENFORCE_EQ(
-        axis_w,
-        1,
-        "Don't know how to deduce input of FC with axis_w not equal to 1: ",
-        op.input(0));
     const TensorShape w_shape = w_shape_info.shape;
-    CAFFE_ENFORCE_EQ(
-        w_shape.dims_size(),
-        2,
-        "Don't know how to deduce input of FC other than of dim size 2: ",
-        op.input(0));
     bool transposed = (op.type() == "FC") ? false : true;
     const int canonical_axis_w =
         canonical_axis_index_(axis_w, w_shape.dims().size());
     const int64_t K = transposed ? SizeToDim(w_shape, canonical_axis_w)
                                  : SizeFromDim(w_shape, canonical_axis_w);
+    std::vector<int64_t> dims;
+    for (int i = 0; i < axis - 1; ++i) {
+      dims.push_back(1);
+    }
+    dims.push_back(spec_.max_batch_size);
+    dims.push_back(K);
     current_dim_type_ = ShapeInfo::DimType::BATCH;
     current_max_batch_size_ = spec_.max_batch_size;
     CheckAndSetTensorShapeAndType(
-        op.input(0),
-        ShapeInfo::DimType::BATCH,
-        {spec_.max_batch_size, K},
-        w_shape.data_type());
+        op.input(0), ShapeInfo::DimType::BATCH, dims, w_shape.data_type());
   } else {
     ShapeInfo& x_shape_info = x_it->second;
     if (x_shape_info.dim_type != ShapeInfo::DimType::BATCH) {
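
For context on the `InferFC` change above: with the `axis == 1` and 2D-weight enforcements removed, the inferred input shape pads the leading axes with 1s before the bounded batch dimension and `K`. Below is a standalone sketch of that dims construction (a hypothetical free function, not part of the commit; `axis`, `K`, and `max_batch_size` play the same roles as in the diff):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical standalone version of the dims construction in InferFC.
std::vector<int64_t> inferFCInputDims(int32_t axis, int64_t K,
                                      int64_t max_batch_size) {
  std::vector<int64_t> dims;
  for (int32_t i = 0; i < axis - 1; ++i) {
    dims.push_back(1);  // pad leading axes with 1
  }
  dims.push_back(max_batch_size);  // bounded batch dimension
  dims.push_back(K);               // inner dim derived from the weight
  return dims;
}

int main() {
  // axis = 1 (the common 2D case) yields {20, 1024};
  // axis = 2 yields {1, 20, 1024}.
  for (int64_t d : inferFCInputDims(2, 1024, 20)) {
    std::cout << d << ' ';
  }
  std::cout << '\n';
  return 0;
}
```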