.Attr("dtype: type")
.Attr("element_shape: shape = { unknown_rank: true }")
.SetShapeFn([](InferenceContext* c) {
+ ShapeHandle indices;
ShapeHandle unused;
DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
- return shape_inference::UnknownShape(c);
+ auto shapes = c->input_handle_shapes_and_types(0);
+ if (shapes != nullptr && !shapes->empty()) {
+ ShapeHandle tensor_shape = shapes->at(0).shape;
+ ShapeHandle output_shape;
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(indices, tensor_shape, &output_shape));
+ c->set_output(0, output_shape);
+ return Status::OK();
+ } else {
+ PartialTensorShape p;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &p));
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
+ ShapeHandle output_shape;
+ TF_RETURN_IF_ERROR(c->Concatenate(indices, s, &output_shape));
+ c->set_output(0, output_shape);
+ return Status::OK();
+ }
});
REGISTER_OP("TensorArrayScatterV3")
.Output("flow_out: float")
.Attr("T: type")
.SetShapeFn([](InferenceContext* c) {
+ ShapeHandle indices;
ShapeHandle unused;
DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ ShapeHandle value_shape;
+ // Assert that the length of the indices tensor is equal to the first
+ // dimension of the value tensor.
+ TF_RETURN_IF_ERROR(
+ c->MergePrefix(c->input(2), indices, &value_shape, &indices));
+ auto shapes = c->input_handle_shapes_and_types(0);
+ if (shapes != nullptr && !shapes->empty()) {
+ ShapeHandle tensor_shape = shapes->at(0).shape;
+ ShapeHandle fed_shape;
+ TF_RETURN_IF_ERROR(c->Subshape(value_shape, 1, &fed_shape));
+ TF_RETURN_IF_ERROR(c->Merge(tensor_shape, fed_shape, &fed_shape));
+ }
return shape_inference::ScalarShape(c);
});