using namespace caffe2;
namespace {
-using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
-ShapeInfo MakeTensorInfo(
+ShapeInfo makeTensorInfo(
ShapeInfo::DimType t,
const std::vector<int64_t>& dims,
TensorProto::DataType dtype = TensorProto_DataType_FLOAT) {
return info;
}
-void PrintShape(const ShapeInfoMap& map) {
- for (const auto& kv : map) {
- const auto& s = kv.second;
- std::stringstream ss;
- ss << s.shape.name() << ": dim_type: " << s.dim_type << ", dims: [";
- for (const auto d : s.shape.dims()) {
- ss << d << ", ";
- }
- ss << "], dtype: " << s.shape.data_type();
- LOG(INFO) << ss.str();
- }
-}
-
-void VerifyShapeInfo(
+void verifyShapeInfo(
const ShapeInfoMap& info,
const std::string& name,
ShapeInfo::DimType t,
"SparseLengthsSum", "", {"Weights", "Data", "Lengths"}, {"Out"}, {}));
ShapeInfoMap shape_map;
shape_map.emplace(
- "Weights", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1000}));
+ "Weights", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {1000, 16}));
BoundShapeSpec spec(20, 1000);
BoundShapeInferencer eng(spec);
eng.InferBoundShapeAndType(net, shape_map);
const auto& out_shape = eng.shape_info();
- VerifyShapeInfo(
- out_shape, "Weights", ShapeInfo::DimType::CONSTANT, {16, 1000});
- VerifyShapeInfo(
+ verifyShapeInfo(
+ out_shape, "Weights", ShapeInfo::DimType::CONSTANT, {1000, 16});
+ verifyShapeInfo(
out_shape,
"Data",
ShapeInfo::DimType::SEQ,
{spec.max_seq_size},
- TensorProto_DataType_INT32);
- VerifyShapeInfo(
+ TensorProto_DataType_INT64);
+ verifyShapeInfo(
out_shape,
"Lengths",
ShapeInfo::DimType::BATCH,
{spec.max_batch_size},
TensorProto_DataType_INT32);
- VerifyShapeInfo(
+ verifyShapeInfo(
out_shape, "Out", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
}
+TEST(BoundShapeInference, SparseLengthsSumFused8BitRowwise) {
+  // Shape inference for SparseLengthsSumFused8BitRowwise: each weight row
+  // carries 8 extra bytes (4-byte scale + 4-byte bias) that must be stripped
+  // from the inferred output width.
+  NetDef net;
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "SparseLengthsSumFused8BitRowwise",
+      "",
+      {"Weights", "Data", "Lengths"},
+      {"Out"},
+      {}));
+  ShapeInfoMap shape_map;
+  // 1000 rows x 58 bytes: 50 data bytes per row plus the fused scale/bias.
+  shape_map.emplace(
+      "Weights",
+      makeTensorInfo(
+          ShapeInfo::DimType::CONSTANT, {1000, 58}, TensorProto_DataType_INT8));
+  BoundShapeSpec spec(20, 1000);
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  // Weights keep the constant shape and int8 dtype they were given.
+  verifyShapeInfo(
+      out_shape,
+      "Weights",
+      ShapeInfo::DimType::CONSTANT,
+      {1000, 58},
+      TensorProto_DataType_INT8);
+  // Indices (Data) are bounded by max_seq_size and expected as int64.
+  verifyShapeInfo(
+      out_shape,
+      "Data",
+      ShapeInfo::DimType::SEQ,
+      {spec.max_seq_size},
+      TensorProto_DataType_INT64);
+  verifyShapeInfo(
+      out_shape,
+      "Lengths",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size},
+      TensorProto_DataType_INT32);
+  // Output width is 58 - 8 = 50 after removing the fused scale/bias bytes.
+  verifyShapeInfo(
+      out_shape, "Out", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 50});
+}
+
+TEST(BoundShapeInference, ConcatMissingInput) {
+  // Concat with add_axis=1 implies all inputs share one shape, so the
+  // inferencer can fill in a missing input's shape from a known sibling.
+  NetDef net;
+  net.add_op()->CopyFrom(CreateOperatorDef(
+      "Concat",
+      "",
+      {"I0", "I1"},
+      {"Cout", "split_info"},
+      {MakeArgument<int>("axis", 1), MakeArgument<int>("add_axis", 1)}));
+  BoundShapeSpec spec(20, 1000);
+  ShapeInfoMap shape_map;
+  // Only I0 is given a shape; I1 is deliberately left out so InferConcat
+  // must copy I0's shape onto it before running common inference.
+  shape_map.emplace(
+      "I0",
+      makeTensorInfo(ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60}));
+  BoundShapeInferencer eng(spec);
+  eng.InferBoundShapeAndType(net, shape_map);
+  const auto& out_shape = eng.shape_info();
+  verifyShapeInfo(
+      out_shape, "I0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 60});
+  // With add_axis=1 on axis 1 a new dimension of size 2 (the input count)
+  // is inserted: [batch, 60] -> [batch, 2, 60].
+  verifyShapeInfo(
+      out_shape,
+      "Cout",
+      ShapeInfo::DimType::BATCH,
+      {spec.max_batch_size, 2, 60});
+}
+
TEST(BoundShapeInference, FC) {
NetDef net;
net.add_op()->CopyFrom(
CreateOperatorDef("FCTransposed", "", {"X1", "W1", "B1"}, {"Out1"}, {}));
ShapeInfoMap shape_map;
shape_map.emplace(
- "W0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
- shape_map.emplace("B0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
+ "W0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
+ shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
shape_map.emplace(
- "W1", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
- shape_map.emplace("B1", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {1024}));
+ "W1", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1024}));
+ shape_map.emplace("B1", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {1024}));
BoundShapeSpec spec(20, 1000);
BoundShapeInferencer eng(spec);
eng.InferBoundShapeAndType(net, shape_map);
const auto& out_shape = eng.shape_info();
- VerifyShapeInfo(
+ verifyShapeInfo(
out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
- VerifyShapeInfo(
+ verifyShapeInfo(
out_shape, "Out0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
- VerifyShapeInfo(
+ verifyShapeInfo(
out_shape, "X1", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
- VerifyShapeInfo(
+ verifyShapeInfo(
out_shape,
"Out1",
ShapeInfo::DimType::BATCH,
CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"Out0"}, {}));
ShapeInfoMap shape_map;
shape_map.emplace(
- "W0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1, 1024}));
- shape_map.emplace("B0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
+ "W0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1, 1024}));
+ shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
BoundShapeSpec spec(20, 1000);
BoundShapeInferencer eng(spec);
EXPECT_THROW(eng.InferBoundShapeAndType(net, shape_map), EnforceNotMet);
CreateOperatorDef("BatchGather", "", {"Fout", "Indices"}, {"Gout"}, {}));
ShapeInfoMap shape_map;
shape_map.emplace(
- "Weights0", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 1000}));
+ "Weights0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {1000, 16}));
shape_map.emplace(
- "Weights1", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {16, 20000}));
+ "Weights1", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {20000, 16}));
shape_map.emplace(
- "Indices", MakeTensorInfo(ShapeInfo::DimType::CONSTANT, {2}));
+ "Indices", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {2}));
BoundShapeSpec spec(20, 1000);
BoundShapeInferencer eng(spec);
eng.InferBoundShapeAndType(net, shape_map);
const auto& out_shape = eng.shape_info();
- PrintShape(out_shape);
- VerifyShapeInfo(
+ LOG(INFO) << eng.PrintShapeInfo();
+ verifyShapeInfo(
out_shape, "Gout", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 2});
}
visited_tensors_.clear();
for (const auto& op : net.op()) {
+ LOG(INFO) << op.type();
if (op.type() == "SparseLengthsSum" ||
op.type() == "SparseLengthsSumFused8BitRowwise") {
InferSparseLengthsSum(op);
} else if (op.type() == "FC" || op.type() == "FCTransposed") {
InferFC(op);
+ } else if (op.type() == "Concat") {
+ InferConcat(op);
} else {
InferCommonOp(op);
}
"Shape of DATA input of SparseLengthsSum ",
op.input(0),
" needs to be presented");
+ CAFFE_ENFORCE_EQ(
+ it->second.shape.dims().size(),
+ 2,
+ "DATA input ",
+ op.input(0),
+ "needs to be 2D");
// Bound inputs
CheckAndSetTensorShapeAndType(
op.input(1),
ShapeInfo::DimType::SEQ,
{spec_.max_seq_size},
- TensorProto_DataType_INT32);
+ TensorProto_DataType_INT64);
CheckAndSetTensorShapeAndType(
op.input(2),
ShapeInfo::DimType::BATCH,
CAFFE_ENFORCE_EQ(it->second.shape.dims_size(), 2);
current_dim_type_ = ShapeInfo::DimType::BATCH;
current_max_batch_size_ = spec_.max_batch_size;
+ auto output_dim1 = it->second.shape.dims(1);
+  // If the op is SparseLengthsSumFused8BitRowwise, we need to subtract 4 bytes
+  // for the scale and 4 bytes for the bias (https://fburl.com/t6dp9tsc)
+ if (op.type() == "SparseLengthsSumFused8BitRowwise") {
+ output_dim1 -= 8;
+ }
CheckAndSetTensorShapeAndType(
op.output(0),
ShapeInfo::DimType::BATCH,
- {spec_.max_batch_size, it->second.shape.dims(0)},
- it->second.shape.data_type());
+ {spec_.max_batch_size, output_dim1},
+ TensorProto_DataType_FLOAT);
+}
+
+// For concat net, if some inputs are missing and we have add_axis argument, it
+// means that all the inputs should be of the same dimension. In this case, we
+// can infer the shape of the missing inputs
+void BoundShapeInferencer::InferConcat(const OperatorDef& op) {
+  ArgumentHelper helper(op);
+  auto add_axis = helper.GetSingleArgument<int32_t>("add_axis", 0);
+  if (add_axis) {
+    // First pass over the inputs: pick the first input with a known shape as
+    // the reference, verify every other known input matches it exactly, and
+    // remember which inputs have no shape yet.
+    ShapeInfo* ref_input_shape = nullptr;
+    std::string ref_name;
+    std::unordered_set<std::string> missing_shape_inputs;
+    for (const auto& i : op.input()) {
+      const auto it = shape_info_.find(i);
+      if (it != shape_info_.end()) {
+        const auto& current_input_shape = it->second;
+        if (ref_input_shape) {
+          // Bug fix: CAFFE_ENFORCE(a, b) only tests `a` for truthiness and
+          // treats `b` as a message, so the ranks were never compared. Use
+          // CAFFE_ENFORCE_EQ (as the per-dim checks below already do).
+          CAFFE_ENFORCE_EQ(
+              ref_input_shape->shape.dims_size(),
+              current_input_shape.shape.dims_size());
+          for (int j = 0; j < ref_input_shape->shape.dims_size(); ++j) {
+            CAFFE_ENFORCE_EQ(
+                ref_input_shape->shape.dims(j),
+                current_input_shape.shape.dims(j),
+                "Mismatched size on dim ",
+                j,
+                " between ",
+                ref_name,
+                " and ",
+                i,
+                " (",
+                ref_input_shape->shape.dims(j),
+                " vs ",
+                current_input_shape.shape.dims(j),
+                ")");
+          }
+        } else {
+          ref_input_shape = &it->second;
+          ref_name = i;
+        }
+      } else {
+        missing_shape_inputs.emplace(i);
+      }
+    }
+
+    if (ref_input_shape) {
+      current_dim_type_ = ref_input_shape->dim_type;
+      // Copy the reference shape onto every input that was missing one, so
+      // the generic inference below sees a fully-shaped op.
+      for (const auto& i : missing_shape_inputs) {
+        shape_info_.emplace(i, *ref_input_shape);
+      }
+    }
+  }
+  InferCommonOp(op);
}
void BoundShapeInferencer::InferFC(const OperatorDef& op) {
std::vector<TensorShape> input_shapes;
for (const auto& input : op.input()) {
const auto it = shape_info_.find(input);
- CAFFE_ENFORCE(it != shape_info_.end());
+ CAFFE_ENFORCE(
+ it != shape_info_.end(), "Cannot find shape info for ", input);
input_shapes.emplace_back(it->second.shape);
}