{spec.max_batch_size, 1024});
}
-// We don't support inference input shape when Weight is not 2D
-TEST(BoundShapeInference, UnsupportedFC) {
+TEST(BoundShapeInference, FC3D) {
NetDef net;
net.add_op()->CopyFrom(
CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"Out0"}, {}));
shape_map.emplace("B0", makeTensorInfo(ShapeInfo::DimType::CONSTANT, {16}));
BoundShapeSpec spec(20, 1000);
BoundShapeInferencer eng(spec);
- EXPECT_THROW(eng.InferBoundShapeAndType(net, shape_map), EnforceNotMet);
+ eng.InferBoundShapeAndType(net, shape_map);
+ const auto& out_shape = eng.shape_info();
+ verifyShapeInfo(
+ out_shape, "X0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 1024});
+ verifyShapeInfo(
+ out_shape, "Out0", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 16});
}
TEST(BoundShapeInference, Combo0) {
void BoundShapeInferencer::InferBoundShapeAndType(
const NetDef& net,
const std::unordered_map<std::string, ShapeInfo>& info) {
+ const static std::unordered_set<std::string> unsupported{"Tile"};
shape_info_ = info;
for (const auto& op : net.op()) {
VLOG(1) << op.type();
+ if (unsupported.count(op.type())) {
+ continue;
+ }
+
if (op.type() == "SparseLengthsSum" ||
op.type() == "SparseLengthsSumFused8BitRowwise" ||
op.type() == "SparseLengthsWeightedSum" ||
ArgumentHelper helper(op);
auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
- CAFFE_ENFORCE_EQ(
- axis,
- 1,
- "Don't know how to deduce input of FC with axis not equal to 1: ",
- op.input(0));
- CAFFE_ENFORCE_EQ(
- axis_w,
- 1,
- "Don't know how to deduce input of FC with axis_w not equal to 1: ",
- op.input(0));
const TensorShape w_shape = w_shape_info.shape;
- CAFFE_ENFORCE_EQ(
- w_shape.dims_size(),
- 2,
- "Don't know how to deduce input of FC other than of dim size 2: ",
- op.input(0));
bool transposed = (op.type() == "FC") ? false : true;
const int canonical_axis_w =
canonical_axis_index_(axis_w, w_shape.dims().size());
const int64_t K = transposed ? SizeToDim(w_shape, canonical_axis_w)
: SizeFromDim(w_shape, canonical_axis_w);
+ std::vector<int64_t> dims;
+ for (int i = 0; i < axis - 1; ++i) {
+ dims.push_back(1);
+ }
+ dims.push_back(spec_.max_batch_size);
+ dims.push_back(K);
current_dim_type_ = ShapeInfo::DimType::BATCH;
current_max_batch_size_ = spec_.max_batch_size;
CheckAndSetTensorShapeAndType(
- op.input(0),
- ShapeInfo::DimType::BATCH,
- {spec_.max_batch_size, K},
- w_shape.data_type());
+ op.input(0), ShapeInfo::DimType::BATCH, dims, w_shape.data_type());
} else {
ShapeInfo& x_shape_info = x_it->second;
if (x_shape_info.dim_type != ShapeInfo::DimType::BATCH) {