Fix warnings in tf.contrib.tensor_forest
authorYong Tang <yong.tang.github@outlook.com>
Sun, 22 Apr 2018 17:55:06 +0000 (17:55 +0000)
committerYong Tang <yong.tang.github@outlook.com>
Sun, 22 Apr 2018 17:55:06 +0000 (17:55 +0000)
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
tensorflow/contrib/tensor_forest/client/eval_metrics.py
tensorflow/contrib/tensor_forest/hybrid/python/layers/fully_connected.py
tensorflow/contrib/tensor_forest/python/tensor_forest.py

index 9003301..e893e1d 100644 (file)
@@ -37,7 +37,7 @@ def _top_k_generator(k):
   def _top_k(probabilities, targets):
     targets = math_ops.to_int32(targets)
     if targets.get_shape().ndims > 1:
-      targets = array_ops.squeeze(targets, squeeze_dims=[1])
+      targets = array_ops.squeeze(targets, axis=[1])
     return metric_ops.streaming_mean(nn.in_top_k(probabilities, targets, k))
   return _top_k
 
@@ -57,7 +57,7 @@ def _r2(probabilities, targets, weights=None):
 
 
 def _squeeze_and_onehot(targets, depth):
-  targets = array_ops.squeeze(targets, squeeze_dims=[1])
+  targets = array_ops.squeeze(targets, axis=[1])
   return array_ops.one_hot(math_ops.to_int32(targets), depth)
 
 
index ff3ab21..745a5b1 100644 (file)
@@ -55,7 +55,7 @@ class ManyToOneLayer(hybrid_layer.HybridLayer):
 
       # There is always one activation per instance by definition, so squeeze
       # away the extra dimension.
-      return array_ops.squeeze(nn_activations, squeeze_dims=[1])
+      return array_ops.squeeze(nn_activations, axis=[1])
 
 
 class FlattenedFullyConnectedLayer(hybrid_layer.HybridLayer):
index b9bcbb1..7a35a70 100644 (file)
@@ -445,7 +445,7 @@ class RandomForestGraphs(object):
           mask = math_ops.less(
               r, array_ops.ones_like(r) * self.params.bagging_fraction)
           gather_indices = array_ops.squeeze(
-              array_ops.where(mask), squeeze_dims=[1])
+              array_ops.where(mask), axis=[1])
           # TODO(thomaswc): Calculate out-of-bag data and labels, and store
           # them for use in calculating statistics later.
           tree_data = array_ops.gather(processed_dense_features, gather_indices)