boosted_trees: infer the output shapes of Quantiles Op from the input shapes.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Mar 2018 18:19:08 +0000 (11:19 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 18:22:57 +0000 (11:22 -0700)
PiperOrigin-RevId: 188750079

tensorflow/contrib/boosted_trees/ops/quantile_ops.cc

index ae99d53..6aa5246 100644 (file)
@@ -272,6 +272,20 @@ REGISTER_OP("Quantiles")
     .Input("sparse_indices: num_sparse_features * int64")
     .Output("dense_quantiles: num_dense_features * int32")
     .Output("sparse_quantiles: num_sparse_features * int32")
+    .SetShapeFn([](InferenceContext* c) {
+      int num_dense_features;
+      TF_RETURN_IF_ERROR(c->GetAttr("num_dense_features", &num_dense_features));
+      int num_sparse_features;
+      TF_RETURN_IF_ERROR(
+          c->GetAttr("num_sparse_features", &num_sparse_features));
+      // Set output shapes (dense_quantiles and sparse_quantiles) by the
+      // relevant inputs (dense_values and sparse_values). Note that the output
+      // has an additional dimension for dimension_ids.
+      for (int i = 0; i < num_dense_features + num_sparse_features; ++i) {
+        c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 2}));
+      }
+      return Status::OK();
+    })
     .Doc(R"doc(
 Computes quantile for each a given list of dense and sparse feature values using
 the given buckets.