.Input("sparse_indices: num_sparse_features * int64")
.Output("dense_quantiles: num_dense_features * int32")
.Output("sparse_quantiles: num_sparse_features * int32")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_dense_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_dense_features", &num_dense_features));
+ int num_sparse_features;
+ TF_RETURN_IF_ERROR(
+ c->GetAttr("num_sparse_features", &num_sparse_features));
+ // Set output shapes (dense_quantiles and sparse_quantiles) by the
+ // relevant inputs (dense_values and sparse_values). Note that the output
+ // has an additional dimension for dimension_ids.
+ for (int i = 0; i < num_dense_features + num_sparse_features; ++i) {
+ c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 2}));
+ }
+ return Status::OK();
+ })
.Doc(R"doc(
Computes quantile for each a given list of dense and sparse feature values using
the given buckets.