Fix/Improve bound shape inference with real net tests (#16597)
authorYinghai Lu <yinghai@fb.com>
Wed, 6 Feb 2019 18:23:01 +0000 (10:23 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 6 Feb 2019 18:41:07 +0000 (10:41 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16597

This diff fixes some bugs in shape inference for `SparseLengthsSumFused8BitRowwise`, and adds input shape inference for `Concat` when `add_axis=1`.

Reviewed By: bertmaher

Differential Revision: D13892452

fbshipit-source-id: 6cd95697a6fabe6d78a5ce3cb749a3a1e51c68e7

caffe2/opt/bound_shape_inference_test.cc
caffe2/opt/bound_shape_inferencer.cc
caffe2/opt/bound_shape_inferencer.h
caffe2/predictor/emulator/data_filler.h

index 38ec576..b07dbca 100644 (file)
@@ -6,9 +6,8 @@
 
 using namespace caffe2;
 namespace {
-using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
 
-ShapeInfo MakeTensorInfo(
+ShapeInfo makeTensorInfo(
     ShapeInfo::DimType t,
     const std::vector<int64_t>& dims,
     TensorProto::DataType dtype = TensorProto_DataType_FLOAT) {
@@ -22,20 +21,7 @@ ShapeInfo MakeTensorInfo(
   return info;
 }
 
-void PrintShape(const ShapeInfoMap& map) {
-  for (const auto& kv : map) {
-    const auto& s = kv.second;
-    std::stringstream ss;
-    ss << s.shape.name() << ": dim_type: " << s.dim_type << ", dims: [";
-    for (const auto d : s.shape.dims()) {
-      ss << d << ", ";
-    }
-    ss << "], dtype: " << s.shape.data_type();
-    LOG(INFO) << ss.str();
-  }
-}
-
-void VerifyShapeInfo(
+void verifyShapeInfo(
     const ShapeInfoMap& info,
     const std::string& name,
     ShapeInfo::DimType t,
@@ -62,29 +48,93 @@ TEST(BoundShapeInference, SparseLengthsSum) {
       "SparseLengthsSum", "", {"Weights", "Data", "Lengths"}, {"Out"}, {}));
   ShapeInfoMap shape_map;
   shape_map.emplace(
-      "Weights", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1000}));
+      "Weights", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {1000, 16}));
   BoundShapeSpec spec(20, 1000);
   BoundShapeInferencer eng(spec);
   eng.InferBoundShapeAndType(net, shape_map);
   const auto& out_shape = eng.shape_info();
-  VerifyShapeInfo(
-      out_shape, "Weights", ShapeInfo::DimType::CONSTANT, {16, 1000});
-  VerifyShapeInfo(
+  verifyShapeInfo(
+      out_shape, "Weights", ShapeInfo::DimType::CONSTANT, {1000, 16});
+  verifyShapeInfo(
       out_shape,
       "Data",
       ShapeInfo::DimType::SEQ,
       {spec.max_seq_size},
-      TensorProto_DataType_INT32);
-  VerifyShapeInfo(
+      TensorProto_DataType_INT64);
+  verifyShapeInfo(
       out_shape,
       "Lengths",
       ShapeInfo::DimType::BATCH,
       {spec.max_batch_size},
       TensorProto_DataType_INT32);
-  VerifyShapeInfo(
+  verifyShapeInfo(
       out_shape, "Out", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
 }
 
+TEST(BoundShapeInference, SparseLengthsSumFused8BitRowwise) {
+  NetDef net;
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "SparseLengthsSumFused8BitRowwise",
+      "",
+      {"Weights", "Data", "Lengths"},
+      {"Out"},
+      {}));
+  ShapeInfoMap shape_map;
+  shape_map.emplace(
+      "Weights",
+      makeTensorInfo(
+          ShapeInfo::DimType::CONSTANT, {1000, 58}, TensorProto_DataType_INT8));
+  BoundShapeSpec spec(20, 1000);
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  verifyShapeInfo(
+      out_shape,
+      "Weights",
+      ShapeInfo::DimType::CONSTANT,
+      {1000, 58},
+      TensorProto_DataType_INT8);
+  verifyShapeInfo(
+      out_shape,
+      "Data",
+      ShapeInfo::DimType::SEQ,
+      {spec.max_seq_size},
+      TensorProto_DataType_INT64);
+  verifyShapeInfo(
+      out_shape,
+      "Lengths",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size},
+      TensorProto_DataType_INT32);
+  verifyShapeInfo(
+      out_shape, "Out", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 50});
+}
+
+TEST(BoundShapeInference, ConcatMissingInput) {
+  NetDef net;
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Concat",
+      "",
+      {"I0", "I1"},
+      {"Cout", "split_info"},
+      {MakeArgument<int>("axis", 1), MakeArgument<int>("add_axis", 1)}));
+  BoundShapeSpec spec(20, 1000);
+  ShapeInfoMap shape_map;
+  shape_map.emplace(
+      "I0",
+      makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60}));
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  verifyShapeInfo(
+      out_shape, "I0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60});
+  verifyShapeInfo(
+      out_shape,
+      "Cout",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size, 2, 60});
+}
+
 TEST(BoundShapeInference, FC) {
   NetDef net;
   net.add_op()->CopyFrom(
@@ -93,22 +143,22 @@ TEST(BoundShapeInference, FC) {
       CreateOperatorDef("FCTransposed", "", {"X1", "W1", "B1"}, {"Out1"}, {}));
   ShapeInfoMap shape_map;
   shape_map.emplace(
-      "W0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
-  shape_map.emplace("B0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
+      "W0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
+  shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
   shape_map.emplace(
-      "W1", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
-  shape_map.emplace("B1", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {1024}));
+      "W1", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
+  shape_map.emplace("B1", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {1024}));
   BoundShapeSpec spec(20, 1000);
   BoundShapeInferencer eng(spec);
   eng.InferBoundShapeAndType(net, shape_map);
   const auto& out_shape = eng.shape_info();
-  VerifyShapeInfo(
+  verifyShapeInfo(
       out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
-  VerifyShapeInfo(
+  verifyShapeInfo(
       out_shape, "Out0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
-  VerifyShapeInfo(
+  verifyShapeInfo(
       out_shape, "X1", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
-  VerifyShapeInfo(
+  verifyShapeInfo(
       out_shape,
       "Out1",
       ShapeInfo::DimType::BATCH,
@@ -122,8 +172,8 @@ TEST(BoundShapeInference, UnsupportedFC) {
       CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"Out0"}, {}));
   ShapeInfoMap shape_map;
   shape_map.emplace(
-      "W0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1, 1024}));
-  shape_map.emplace("B0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
+      "W0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1, 1024}));
+  shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
   BoundShapeSpec spec(20, 1000);
   BoundShapeInferencer eng(spec);
   EXPECT_THROW(eng.InferBoundShapeAndType(net, shape_map), EnforceNotMet);
@@ -153,16 +203,16 @@ TEST(BoundShapeInference, Combo0) {
       CreateOperatorDef("BatchGather", "", {"Fout", "Indices"}, {"Gout"}, {}));
   ShapeInfoMap shape_map;
   shape_map.emplace(
-      "Weights0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1000}));
+      "Weights0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {1000, 16}));
   shape_map.emplace(
-      "Weights1", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 20000}));
+      "Weights1", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {20000, 16}));
   shape_map.emplace(
-      "Indices", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {2}));
+      "Indices", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {2}));
   BoundShapeSpec spec(20, 1000);
   BoundShapeInferencer eng(spec);
   eng.InferBoundShapeAndType(net, shape_map);
   const auto& out_shape = eng.shape_info();
-  PrintShape(out_shape);
-  VerifyShapeInfo(
+  LOG(INFO) << eng.PrintShapeInfo();
+  verifyShapeInfo(
       out_shape, "Gout", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2});
 }
index a96ae02..1ea1609 100644 (file)
@@ -47,11 +47,14 @@ void BoundShapeInferencer::InferBoundShapeAndType(
   visited_tensors_.clear();
 
   for (const auto& op : net.op()) {
+    VLOG(1) << op.type();
     if (op.type() == "SparseLengthsSum" ||
         op.type() == "SparseLengthsSumFused8BitRowwise") {
       InferSparseLengthsSum(op);
     } else if (op.type() == "FC" || op.type() == "FCTransposed") {
       InferFC(op);
+    } else if (op.type() == "Concat") {
+      InferConcat(op);
     } else {
       InferCommonOp(op);
     }
@@ -125,13 +128,19 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
       "Shape of DATA input of SparseLengthsSum ",
       op.input(0),
       " needs to be presented");
+  CAFFE_ENFORCE_EQ(
+      it->second.shape.dims().size(),
+      2,
+      "DATA input ",
+      op.input(0),
+      " needs to be 2D");
 
   // Bound inputs
   CheckAndSetTensorShapeAndType(
       op.input(1),
       ShapeInfo::DimType::SEQ,
       {spec_.max_seq_size},
-      TensorProto_DataType_INT32);
+      TensorProto_DataType_INT64);
   CheckAndSetTensorShapeAndType(
       op.input(2),
       ShapeInfo::DimType::BATCH,
@@ -142,11 +151,70 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
   CAFFE_ENFORCE_EQ(it->second.shape.dims_size(), 2);
   current_dim_type_ = ShapeInfo::DimType::BATCH;
   current_max_batch_size_ = spec_.max_batch_size;
+  auto output_dim1 = it->second.shape.dims(1);
+  // If the op is SparseLengthsSumFused8BitRowwise, we need to subtract 8
+  // bytes (4 for scale and 4 for bias) per row (https://fburl.com/t6dp9tsc)
+  if (op.type() == "SparseLengthsSumFused8BitRowwise") {
+    output_dim1 -= 8;
+  }
   CheckAndSetTensorShapeAndType(
       op.output(0),
       ShapeInfo::DimType::BATCH,
-      {spec_.max_batch_size, it->second.shape.dims(0)},
-      it->second.shape.data_type());
+      {spec_.max_batch_size, output_dim1},
+      TensorProto_DataType_FLOAT);
+}
+
+// For concat net, if some inputs are missing and we have add_axis argument, it
+// means that all the inputs should be of the same dimension. In this case, we
+// can infer the shape of the missing inputs
+void BoundShapeInferencer::InferConcat(const OperatorDef& op) {
+  ArgumentHelper helper(op);
+  auto add_axis = helper.GetSingleArgument<int32_t>("add_axis", 0);
+  if (add_axis) {
+    ShapeInfo* ref_input_shape = nullptr;
+    std::string ref_name;
+    std::unordered_set<std::string> missing_shape_inputs;
+    for (const auto& i : op.input()) {
+      const auto it = shape_info_.find(i);
+      if (it != shape_info_.end()) {
+        const auto& current_input_shape = it->second;
+        if (ref_input_shape) {
+          CAFFE_ENFORCE_EQ(
+              ref_input_shape->shape.dims_size(),
+              current_input_shape.shape.dims_size());
+          for (int j = 0; j < ref_input_shape->shape.dims_size(); ++j) {
+            CAFFE_ENFORCE_EQ(
+                ref_input_shape->shape.dims(j),
+                current_input_shape.shape.dims(j),
+                "Mismatched size on dim ",
+                j,
+                " between ",
+                ref_name,
+                " and ",
+                i,
+                " (",
+                ref_input_shape->shape.dims(j),
+                " vs ",
+                current_input_shape.shape.dims(j),
+                ")");
+          }
+        } else {
+          ref_input_shape = &it->second;
+          ref_name = i;
+        }
+      } else {
+        missing_shape_inputs.emplace(i);
+      }
+    }
+
+    if (ref_input_shape) {
+      current_dim_type_ = ref_input_shape->dim_type;
+      for (const auto& i : missing_shape_inputs) {
+        shape_info_.emplace(i, *ref_input_shape);
+      }
+    }
+  }
+  InferCommonOp(op);
 }
 
 void BoundShapeInferencer::InferFC(const OperatorDef& op) {
@@ -226,7 +294,8 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
   std::vector<TensorShape> input_shapes;
   for (const auto& input : op.input()) {
     const auto it = shape_info_.find(input);
-    CAFFE_ENFORCE(it != shape_info_.end());
+    CAFFE_ENFORCE(
+        it != shape_info_.end(), "Cannot find shape info for ", input);
     input_shapes.emplace_back(it->second.shape);
   }
 
index f4b66b1..775c531 100644 (file)
@@ -3,6 +3,7 @@
 #include "caffe2/core/logging.h"
 #include "caffe2/proto/caffe2_pb.h"
 
+#include <sstream>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
@@ -11,11 +12,16 @@ namespace caffe2 {
 
 struct CAFFE2_API ShapeInfo {
   enum DimType : int8_t { UNKNOWN = 0, CONSTANT = 1, BATCH = 2, SEQ = 3 };
+  ShapeInfo() {}
+  ShapeInfo(DimType t, TensorShape&& s) : dim_type(t), shape(std::move(s)) {}
+
   // type of the shape according its first dim
   DimType dim_type{DimType::UNKNOWN};
   TensorShape shape;
 };
 
+using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
+
 // This struct stores the max bound size for batch in the general sense. We have
 // the conventioal batch size and the look-up sequence, which is also batch in a
 // sense.
@@ -48,6 +54,20 @@ class CAFFE2_API BoundShapeInferencer {
     return shape_info_;
   }
 
+  /// Print out all the shape info
+  std::string PrintShapeInfo() const {
+    std::stringstream ss;
+    for (const auto& kv : shape_info_) {
+      const auto& s = kv.second;
+      ss << s.shape.name() << ": dim_type: " << s.dim_type << ", dims: [";
+      for (const auto d : s.shape.dims()) {
+        ss << d << ", ";
+      }
+      ss << "], dtype: " << s.shape.data_type() << "\n";
+    }
+    return ss.str();
+  }
+
  private:
   TensorShape& CheckAndSetTensorShapeAndType(
       const std::string& name,
@@ -57,6 +77,7 @@ class CAFFE2_API BoundShapeInferencer {
 
   void InferSparseLengthsSum(const OperatorDef& op);
   void InferFC(const OperatorDef& op);
+  void InferConcat(const OperatorDef& op);
 
   // Standard shape/type inference using op schema registered shape inference
   // function
index e574ba5..fa8f584 100644 (file)
@@ -36,7 +36,7 @@ class Filler {
     return bytes;
   }
 
-  std::vector<std::string> get_input_names() const {
+  const std::vector<std::string>& get_input_names() const {
     CAFFE_ENFORCE(!input_names_.empty(), "input names is not initialized");
     return input_names_;
   }