From d52862ca81e419ea718447adec7e407734bf4ec8 Mon Sep 17 00:00:00 2001 From: Yinghai Lu Date: Thu, 14 Feb 2019 13:33:52 -0800 Subject: [PATCH] Moderate the dim type after LengthsRangeFill (#17096) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17096 LengthsRangeFill will take a batch size of lengths input and expand it into sequence. Later op should follow this type until it hits another batch type moderating op, e.g. SparseLengthsSum. Reviewed By: ipiszy Differential Revision: D14079422 fbshipit-source-id: 1a26925d502c32875ea95c160268bf6a256cc955 --- caffe2/opt/bound_shape_inference_test.cc | 39 ++++++++++++++++++++++++++++++++ caffe2/opt/bound_shape_inferencer.cc | 1 + 2 files changed, 40 insertions(+) diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc index 40e99e5..ee59fe7 100644 --- a/caffe2/opt/bound_shape_inference_test.cc +++ b/caffe2/opt/bound_shape_inference_test.cc @@ -110,6 +110,45 @@ TEST(BoundShapeInference, SparseLengthsSumFused8BitRowwise) { out_shape, "Out", ShapeInfo::DimType::BATCH, {spec.max_batch_size, 50}); } +TEST(BoundShapeInference, LengthsRangeFill) { + NetDef net; + net.add_op()->CopyFrom(CreateOperatorDef( + "LengthsRangeFill", + "", + {"X"}, + {"Y"}, + {})); + net.add_op()->CopyFrom(CreateOperatorDef( + "Copy", + "", + {"Y"}, + {"Z"}, + {})); + ShapeInfoMap shape_map; + BoundShapeSpec spec(20, 1000); + BoundShapeInferencer eng(spec); + eng.InferBoundShapeAndType(net, shape_map); + const auto& out_shape = eng.shape_info(); + verifyShapeInfo( + out_shape, + "X", + ShapeInfo::DimType::BATCH, + {spec.max_batch_size}, + TensorProto_DataType_INT32); + verifyShapeInfo( + out_shape, + "Y", + ShapeInfo::DimType::SEQ, + {spec.max_seq_size}, + TensorProto_DataType_INT32); + verifyShapeInfo( + out_shape, + "Z", + ShapeInfo::DimType::SEQ, + {spec.max_seq_size}, + TensorProto_DataType_INT32); +} + TEST(BoundShapeInference, Reshape) { NetDef net; std::vector new_shape{-1, 8}; diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc index cb3693a..f0b7205 100644 --- a/caffe2/opt/bound_shape_inferencer.cc +++ b/caffe2/opt/bound_shape_inferencer.cc @@ -137,6 +137,7 @@ void BoundShapeInferencer::InferLengthsRangeFill(const OperatorDef& op) { ShapeInfo::DimType::SEQ, {spec_.max_seq_size}, TensorProto_DataType_INT32); + current_dim_type_ = ShapeInfo::DimType::SEQ; } void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) { -- 2.7.4