Bound shape inference for c2 (#16081)
author     Yinghai Lu <yinghai@fb.com>
Thu, 17 Jan 2019 02:58:08 +0000 (18:58 -0800)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 17 Jan 2019 03:02:56 +0000 (19:02 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16081

A simple version of bound shape inference, conditioned on batch size. In addition to performing normal shape inference, it fixes the batch size (the first dim of the shape) of the inputs and handles batch-size-modulating ops such as `SparseLengthsSum`. Support for more ops, such as `SparseToDense`, will probably be needed; we can build on this.
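
A minimal usage sketch, adapted from the new unit tests in this diff (the wrapper function name is illustrative only):

    #include <unordered_map>

    #include "caffe2/opt/bound_shape_inferencer.h"
    #include "caffe2/utils/proto_utils.h"

    using namespace caffe2;

    void RunBoundShapeInferenceExample() {
      NetDef net;
      net.add_op()->CopyFrom(CreateOperatorDef(
          "SparseLengthsSum", "", {"Weights", "Data", "Lengths"}, {"Out"}, {}));

      // Shape hint: only the constant embedding table shape is known up front.
      std::unordered_map<std::string, ShapeInfo> hints;
      ShapeInfo w;
      w.dim_type = ShapeInfo::DimType::CONSTANT;
      w.shape.add_dims(16);
      w.shape.add_dims(1000);
      w.shape.set_data_type(TensorProto_DataType_FLOAT);
      hints.emplace("Weights", w);

      // Bound all shapes by max_batch_size = 20 and max_seq_size = 1000.
      BoundShapeSpec spec(20, 1000);
      BoundShapeInferencer inferencer(spec);
      inferencer.InferBoundShapeAndType(net, hints);
      // inferencer.shape_info() now contains the bound shapes, e.g.
      // "Out" -> dim_type BATCH, dims [20, 16].
    }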

Reviewed By: jackm321, rdzhabarov

Differential Revision: D13661968

fbshipit-source-id: 6a724a647e109757c26e3e26e15a49725ecc75cc

caffe2/opt/bound_shape_inference_test.cc [new file with mode: 0644]
caffe2/opt/bound_shape_inferencer.cc [new file with mode: 0644]
caffe2/opt/bound_shape_inferencer.h [new file with mode: 0644]

diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc
new file mode 100644 (file)
index 0000000..38ec576
--- /dev/null
@@ -0,0 +1,168 @@
+#include <gtest/gtest.h>
+#include "caffe2/core/common.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/opt/bound_shape_inferencer.h"
+#include "caffe2/utils/proto_utils.h"
+
+using namespace caffe2;
+namespace {
+using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
+
+ShapeInfo MakeTensorInfo(
+    ShapeInfo::DimType t,
+    const std::vector<int64_t>& dims,
+    TensorProto::DataType dtype = TensorProto_DataType_FLOAT) {
+  ShapeInfo info;
+  info.dim_type = t;
+  TensorShape& shape = info.shape;
+  for (const auto d : dims) {
+    shape.add_dims(d);
+  }
+  shape.set_data_type(dtype);
+  return info;
+}
+
+void PrintShape(const ShapeInfoMap& map) {
+  for (const auto& kv : map) {
+    const auto& s = kv.second;
+    std::stringstream ss;
+    ss << s.shape.name() << ": dim_type: " << s.dim_type << ", dims: [";
+    for (const auto d : s.shape.dims()) {
+      ss << d << ", ";
+    }
+    ss << "], dtype: " << s.shape.data_type();
+    LOG(INFO) << ss.str();
+  }
+}
+
+void VerifyShapeInfo(
+    const ShapeInfoMap& info,
+    const std::string& name,
+    ShapeInfo::DimType t,
+    const std::vector<int64_t>& dims,
+    TensorProto::DataType dtype = TensorProto_DataType_FLOAT) {
+  LOG(INFO) << "Checking " << name;
+  const auto it = info.find(name);
+  ASSERT_TRUE(it != info.end());
+  const auto& shape_info = it->second;
+  EXPECT_EQ(shape_info.dim_type, t);
+  const auto& shape = shape_info.shape;
+  ASSERT_EQ(shape.dims_size(), dims.size());
+  for (int i = 0; i < dims.size(); ++i) {
+    EXPECT_EQ(shape.dims(i), dims[i]);
+  }
+  EXPECT_EQ(shape.data_type(), dtype);
+}
+
+} // namespace
+
+TEST(BoundShapeInference, SparseLengthsSum) {
+  NetDef net;
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "SparseLengthsSum", "", {"Weights", "Data", "Lengths"}, {"Out"}, {}));
+  ShapeInfoMap shape_map;
+  shape_map.emplace(
+      "Weights", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1000}));
+  BoundShapeSpec spec(20, 1000);
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  VerifyShapeInfo(
+      out_shape, "Weights", ShapeInfo::DimType::CONSTANT, {16, 1000});
+  VerifyShapeInfo(
+      out_shape,
+      "Data",
+      ShapeInfo::DimType::SEQ,
+      {spec.max_seq_size},
+      TensorProto_DataType_INT32);
+  VerifyShapeInfo(
+      out_shape,
+      "Lengths",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size},
+      TensorProto_DataType_INT32);
+  VerifyShapeInfo(
+      out_shape, "Out", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
+}
+
+TEST(BoundShapeInference, FC) {
+  NetDef net;
+  net.add_op()->CopyFrom(
+      CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"Out0"}, {}));
+  net.add_op()->CopyFrom(
+      CreateOperatorDef("FCTransposed", "", {"X1", "W1", "B1"}, {"Out1"}, {}));
+  ShapeInfoMap shape_map;
+  shape_map.emplace(
+      "W0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
+  shape_map.emplace("B0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
+  shape_map.emplace(
+      "W1", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
+  shape_map.emplace("B1", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {1024}));
+  BoundShapeSpec spec(20, 1000);
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  VerifyShapeInfo(
+      out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
+  VerifyShapeInfo(
+      out_shape, "Out0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
+  VerifyShapeInfo(
+      out_shape, "X1", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
+  VerifyShapeInfo(
+      out_shape,
+      "Out1",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size, 1024});
+}
+
+// We don't support inferring the input shape when the weight is not 2D.
+TEST(BoundShapeInference, UnsupportedFC) {
+  NetDef net;
+  net.add_op()->CopyFrom(
+      CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"Out0"}, {}));
+  ShapeInfoMap shape_map;
+  shape_map.emplace(
+      "W0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1, 1024}));
+  shape_map.emplace("B0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
+  BoundShapeSpec spec(20, 1000);
+  BoundShapeInferencer eng(spec);
+  EXPECT_THROW(eng.InferBoundShapeAndType(net, shape_map), EnforceNotMet);
+}
+
+TEST(BoundShapeInference, Combo0) {
+  NetDef net;
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "SparseLengthsSum", "", {"Weights0", "Data0", "Lengths0"}, {"EB0"}, {}));
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "SparseLengthsSum", "", {"Weights1", "Data1", "Lengths1"}, {"EB1"}, {}));
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Concat",
+      "",
+      {"EB0", "EB1"},
+      {"Cout", "split_info"},
+      {MakeArgument<int>("axis", 1), MakeArgument<int>("add_axis", 1)}));
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "BatchMatMul",
+      "",
+      {"Cout", "Cout"},
+      {"Bout"},
+      {MakeArgument<int>("trans_b", 1)}));
+  net.add_op()->CopyFrom(
+      CreateOperatorDef("Flatten", "", {"Bout"}, {"Fout"}, {}));
+  net.add_op()->CopyFrom(
+      CreateOperatorDef("BatchGather", "", {"Fout", "Indices"}, {"Gout"}, {}));
+  ShapeInfoMap shape_map;
+  shape_map.emplace(
+      "Weights0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1000}));
+  shape_map.emplace(
+      "Weights1", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 20000}));
+  shape_map.emplace(
+      "Indices", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {2}));
+  BoundShapeSpec spec(20, 1000);
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  PrintShape(out_shape);
+  VerifyShapeInfo(
+      out_shape, "Gout", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2});
+}
diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc
new file mode 100644 (file)
index 0000000..a96ae02
--- /dev/null
@@ -0,0 +1,247 @@
+#include "bound_shape_inferencer.h"
+#include "caffe2/core/operator_schema.h"
+#include "caffe2/core/tensor_impl.h"
+#include "caffe2/utils/proto_utils.h"
+
+namespace caffe2 {
+
+namespace {
+std::vector<int64_t> ConvertToVec(
+    const ::google::protobuf::RepeatedField<::google::protobuf::int64>& in) {
+  std::vector<int64_t> out;
+  out.reserve(in.size());
+  for (const auto d : in) {
+    out.push_back(d);
+  }
+  return out;
+}
+
+int64_t SizeFromDim(const TensorShape& shape, int axis) {
+  int64_t r = 1;
+  for (int i = axis; i < shape.dims_size(); ++i) {
+    r *= shape.dims(i);
+  }
+  return r;
+}
+
+int64_t SizeToDim(const TensorShape& shape, int axis) {
+  CAFFE_ENFORCE_LE(axis, shape.dims_size());
+  int64_t r = 1;
+  for (int i = 0; i < axis; ++i) {
+    r *= shape.dims(i);
+  }
+  return r;
+}
+
+void EnsureShapeNames(std::unordered_map<std::string, ShapeInfo>* info) {
+  for (auto& kv : *info) {
+    kv.second.shape.set_name(kv.first);
+  }
+}
+} // namespace
+
+void BoundShapeInferencer::InferBoundShapeAndType(
+    const NetDef& net,
+    const std::unordered_map<std::string, ShapeInfo>& info) {
+  shape_info_ = info;
+  visited_tensors_.clear();
+
+  for (const auto& op : net.op()) {
+    if (op.type() == "SparseLengthsSum" ||
+        op.type() == "SparseLengthsSumFused8BitRowwise") {
+      InferSparseLengthsSum(op);
+    } else if (op.type() == "FC" || op.type() == "FCTransposed") {
+      InferFC(op);
+    } else {
+      InferCommonOp(op);
+    }
+  }
+
+  // Make sure shape has name
+  EnsureShapeNames(&shape_info_);
+}
+
+TensorShape& BoundShapeInferencer::CheckAndSetTensorShapeAndType(
+    const std::string& name,
+    ShapeInfo::DimType t,
+    std::vector<int64_t> bound_dims,
+    TensorProto::DataType type) {
+  if (!visited_tensors_.emplace(name).second) {
+    return shape_info_.at(name).shape;
+  }
+  auto rt = shape_info_.emplace(name, ShapeInfo());
+  ShapeInfo& shape_info = rt.first->second;
+  TensorShape& shape = shape_info.shape;
+  if (!rt.second) {
+    // Check shape consistency
+    CAFFE_ENFORCE_EQ(shape.dims_size(), bound_dims.size());
+    // For shapes that were provided as hints at the input of the net, fix the
+    // batch size first.
+    if (shape.dims_size() > 0 &&
+        shape_info.dim_type == ShapeInfo::DimType::UNKNOWN &&
+        t > ShapeInfo::DimType::CONSTANT) {
+      shape_info.dim_type = t;
+      shape.set_dims(0, bound_dims.front());
+    }
+    for (int i = 0; i < shape.dims_size(); ++i) {
+      CAFFE_ENFORCE_EQ(
+          shape.dims(i),
+          bound_dims[i],
+          "Shape inconsistency found in tensor ",
+          name,
+          " on dim ",
+          i,
+          " (",
+          shape.dims(i),
+          " vs ",
+          bound_dims[i],
+          ")");
+    }
+    return shape;
+  }
+
+  shape_info.dim_type = t;
+  shape.mutable_dims()->Clear();
+  for (const auto d : bound_dims) {
+    shape.add_dims(d);
+  }
+  shape.set_data_type(type);
+  return shape;
+}
+
+std::vector<TensorShape> InferOutput(
+    const OperatorDef& op,
+    const std::vector<TensorShape>& input_shapes) {
+  const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
+  CAFFE_ENFORCE(schema);
+  return schema->InferTensor(op, input_shapes);
+}
+
+void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
+  CAFFE_ENFORCE_EQ(op.input_size(), 3, "SparseLengthsSum has to have 3 inputs");
+  const auto it = shape_info_.find(op.input(0));
+  CAFFE_ENFORCE(
+      it != shape_info_.end(),
+      "Shape of DATA input of SparseLengthsSum ",
+      op.input(0),
+      " needs to be presented");
+
+  // Bound inputs
+  CheckAndSetTensorShapeAndType(
+      op.input(1),
+      ShapeInfo::DimType::SEQ,
+      {spec_.max_seq_size},
+      TensorProto_DataType_INT32);
+  CheckAndSetTensorShapeAndType(
+      op.input(2),
+      ShapeInfo::DimType::BATCH,
+      {spec_.max_batch_size},
+      TensorProto_DataType_INT32);
+
+  // Infer output
+  CAFFE_ENFORCE_EQ(it->second.shape.dims_size(), 2);
+  current_dim_type_ = ShapeInfo::DimType::BATCH;
+  current_max_batch_size_ = spec_.max_batch_size;
+  CheckAndSetTensorShapeAndType(
+      op.output(0),
+      ShapeInfo::DimType::BATCH,
+      {spec_.max_batch_size, it->second.shape.dims(0)},
+      it->second.shape.data_type());
+}
+
+void BoundShapeInferencer::InferFC(const OperatorDef& op) {
+  CAFFE_ENFORCE_EQ(op.input_size(), 3, "FC has to have 3 inputs");
+  const auto w_it = shape_info_.find(op.input(1));
+  CAFFE_ENFORCE(
+      w_it != shape_info_.end(),
+      "Shape of WEIGHT input of FC ",
+      op.input(1),
+      " needs to be presented");
+  const ShapeInfo& w_shape_info = w_it->second;
+  const auto b_it = shape_info_.find(op.input(2));
+  CAFFE_ENFORCE(
+      b_it != shape_info_.end(),
+      "Shape of BIAS input of FC ",
+      op.input(2),
+      " needs to be present");
+  const ShapeInfo& b_shape_info = b_it->second;
+  auto x_it = shape_info_.find(op.input(0));
+  if (x_it == shape_info_.end()) {
+    // No hint for the X input; try to deduce it from the weight shape.
+    ArgumentHelper helper(op);
+    auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
+    auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
+    CAFFE_ENFORCE_EQ(
+        axis,
+        1,
+        "Don't know how to deduce input of FC with axis not equal to 1: ",
+        op.input(0));
+    CAFFE_ENFORCE_EQ(
+        axis_w,
+        1,
+        "Don't know how to deduce input of FC with axis_w not equal to 1: ",
+        op.input(0));
+    const TensorShape w_shape = w_shape_info.shape;
+    CAFFE_ENFORCE_EQ(
+        w_shape.dims_size(),
+        2,
+        "Don't know how to deduce input of FC other than of dim size 2: ",
+        op.input(0));
+    const bool transposed = op.type() != "FC";
+    const int canonical_axis_w =
+        canonical_axis_index_(axis_w, w_shape.dims().size());
+    const int64_t K = transposed ? SizeToDim(w_shape, canonical_axis_w)
+                                 : SizeFromDim(w_shape, canonical_axis_w);
+    current_dim_type_ = ShapeInfo::DimType::BATCH;
+    current_max_batch_size_ = spec_.max_batch_size;
+    CheckAndSetTensorShapeAndType(
+        op.input(0),
+        ShapeInfo::DimType::BATCH,
+        {spec_.max_batch_size, K},
+        w_shape.data_type());
+  } else {
+    ShapeInfo& x_shape_info = x_it->second;
+    if (x_shape_info.dim_type == ShapeInfo::DimType::UNKNOWN) {
+      CAFFE_ENFORCE_GE(x_shape_info.shape.dims_size(), 1);
+      x_shape_info.shape.set_dims(0, spec_.max_batch_size);
+      x_shape_info.dim_type = ShapeInfo::DimType::BATCH;
+    }
+  }
+
+  // Standard shape inference for outputs
+  std::vector<TensorShape> input_shapes{
+      shape_info_[op.input(0)].shape, w_shape_info.shape, b_shape_info.shape};
+  std::vector<TensorShape> output_shapes = InferOutput(op, input_shapes);
+  CAFFE_ENFORCE_EQ(output_shapes.size(), 1);
+  CheckAndSetTensorShapeAndType(
+      op.output(0),
+      ShapeInfo::DimType::BATCH,
+      ConvertToVec(output_shapes[0].dims()),
+      output_shapes[0].data_type());
+}
+
+void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
+  // First, we need to check that the shape/type of every input is
+  // already present.
+  std::vector<TensorShape> input_shapes;
+  for (const auto& input : op.input()) {
+    const auto it = shape_info_.find(input);
+    CAFFE_ENFORCE(it != shape_info_.end());
+    input_shapes.emplace_back(it->second.shape);
+  }
+
+  const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
+  CAFFE_ENFORCE(schema);
+  auto output_shapes = schema->InferTensor(op, input_shapes);
+  CAFFE_ENFORCE_EQ(output_shapes.size(), op.output_size());
+  int i = 0;
+  for (const auto& shape : output_shapes) {
+    CheckAndSetTensorShapeAndType(
+        op.output(i++),
+        current_dim_type_,
+        ConvertToVec(shape.dims()),
+        shape.data_type());
+  }
+}
+
+} // namespace caffe2
diff --git a/caffe2/opt/bound_shape_inferencer.h b/caffe2/opt/bound_shape_inferencer.h
new file mode 100644 (file)
index 0000000..f4b66b1
--- /dev/null
@@ -0,0 +1,77 @@
+#pragma once
+
+#include "caffe2/core/logging.h"
+#include "caffe2/proto/caffe2_pb.h"
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+
+namespace caffe2 {
+
+struct CAFFE2_API ShapeInfo {
+  enum DimType : int8_t { UNKNOWN = 0, CONSTANT = 1, BATCH = 2, SEQ = 3 };
+  // Type of the shape according to its first dim.
+  DimType dim_type{DimType::UNKNOWN};
+  TensorShape shape;
+};
+
+// This struct stores the max bound sizes for "batch" in the general sense: the
+// conventional batch size and the lookup sequence size, which is also a batch
+// in a sense.
+struct CAFFE2_API BoundShapeSpec {
+  explicit BoundShapeSpec(int64_t b, int64_t q)
+      : max_batch_size(b), max_seq_size(q) {}
+  int64_t max_batch_size;
+  int64_t max_seq_size;
+};
+
+/// \class A class that does bound shape inference given a C2 net. Depending on
+/// its type, each op has a maximum shape that it accepts. We define initial
+/// bounds for certain dimensions, for example the max batch size or the max
+/// sequence lookup size. The inference first infers the input sizes and then
+/// propagates the bound shapes down the network. For now the variable (bound)
+/// part is the first dimension of the shape, which usually corresponds to the
+/// batch size or the sequence lookup size.
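+///
+/// For example, with BoundShapeSpec(20, 1000), an FC whose WEIGHT hint is
+/// CONSTANT with shape (16, 1024) and whose X input has no hint gets X
+/// inferred as BATCH (20, 1024) and its output as BATCH (20, 16); see
+/// bound_shape_inference_test.cc for the exercised cases.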
+class CAFFE2_API BoundShapeInferencer {
+ public:
+  explicit BoundShapeInferencer(const BoundShapeSpec& spec) : spec_(spec) {
+    CAFFE_ENFORCE_GT(spec_.max_batch_size, 0);
+    CAFFE_ENFORCE_GT(spec_.max_seq_size, 0);
+  }
+
+  void InferBoundShapeAndType(
+      const NetDef& net,
+      const std::unordered_map<std::string, ShapeInfo>& info);
+
+  const std::unordered_map<std::string, ShapeInfo>& shape_info() const {
+    return shape_info_;
+  }
+
+ private:
+  TensorShape& CheckAndSetTensorShapeAndType(
+      const std::string& name,
+      ShapeInfo::DimType t,
+      std::vector<int64_t> bound_dims,
+      TensorProto::DataType type);
+
+  void InferSparseLengthsSum(const OperatorDef& op);
+  void InferFC(const OperatorDef& op);
+
+  // Standard shape/type inference using op schema registered shape inference
+  // function
+  void InferCommonOp(const OperatorDef& op);
+
+  const BoundShapeSpec spec_;
+  ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::UNKNOWN};
+  int64_t current_max_batch_size_{0};
+  std::unordered_map<std::string, ShapeInfo> shape_info_;
+  std::unordered_set<std::string> visited_tensors_;
+};
+
+} // namespace caffe2