address lint
authorPeng Yu <peng.yu@shopify.com>
Fri, 18 May 2018 21:05:45 +0000 (17:05 -0400)
committerPeng Yu <peng.yu@shopify.com>
Mon, 21 May 2018 17:34:57 +0000 (13:34 -0400)
tensorflow/contrib/tensor_forest/python/tensor_forest.py
tensorflow/contrib/tensor_forest/python/tensor_forest_test.py

index 0feca52..ba1755e 100644 (file)
@@ -335,7 +335,8 @@ class ForestVariables(object):
   """
 
   def __init__(self, params, device_assigner, training=True,
-               tree_variables_class=TreeVariables, tree_configs=None, tree_stats=None):
+               tree_variables_class=TreeVariables,
+               tree_configs=None, tree_stats=None):
     self.variables = []
     # Set up some scalar variables to run through the device assigner, then
     # we can use those to colocate everything related to a tree.
@@ -349,10 +350,11 @@ class ForestVariables(object):
       with ops.device(self.device_dummies[i].device):
         kwargs = {}
         if tree_configs is not None:
-            kwargs.update(dict(tree_config=tree_configs[i]))
+          kwargs.update(dict(tree_config=tree_configs[i]))
         if tree_stats is not None:
-            kwargs.update(dict(tree_stat=tree_stats[i]))
-        self.variables.append(tree_variables_class(params, i, training, **kwargs))
+          kwargs.update(dict(tree_stat=tree_stats[i]))
+        self.variables.append(tree_variables_class(
+                params, i, training, **kwargs))
 
   def __setitem__(self, t, val):
     self.variables[t] = val
@@ -380,7 +382,8 @@ class RandomForestGraphs(object):
     logging.info(self.params.__dict__)
     self.variables = variables or ForestVariables(
         self.params, device_assigner=self.device_assigner, training=training,
-        tree_variables_class=tree_variables_class, tree_configs=tree_configs, tree_stats=tree_stats)
+        tree_variables_class=tree_variables_class,
+        tree_configs=tree_configs, tree_stats=tree_stats)
     tree_graph_class = tree_graphs or RandomTreeGraphs
     self.trees = [
         tree_graph_class(self.variables[i], self.params, i)
index cf50ba2..7c5883d 100644 (file)
@@ -125,7 +125,22 @@ class TensorForestTest(test_util.TensorFlowTestCase):
         num_trees=1,
         max_nodes=1000,
         split_after_samples=25).fill()
-    tree_weight = {'decisionTree': {'nodes': [{'binaryNode': {'rightChildId': 2, 'leftChildId': 1, 'inequalityLeftChildTest': {'featureId': {'id': '0'}, 'threshold': {'floatValue': 0}}}}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 1}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 2}]}}
+    tree_weight = {'decisionTree':
+                    {'nodes':
+                        [{'binaryNode':
+                            {'rightChildId': 2,
+                             'leftChildId': 1,
+                             'inequalityLeftChildTest':
+                             {'featureId': {'id': '0'},
+                             'threshold': {'floatValue': 0}}}},
+                         {'leaf': {'vector':
+                            {'value': [{'floatValue': 0.0},
+                                       {'floatValue': 1.0}]}},
+                            'nodeId': 1},
+                        {'leaf': {'vector':
+                            {'value': [{'floatValue': 0.0},
+                                       {'floatValue': 1.0}]}},
+                            'nodeId': 2}]}}
     restored_tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString()
     graph_builder = tensor_forest.RandomForestGraphs(hparams, [restored_tree_param])
     probs, paths, var = graph_builder.inference_graph(input_data)