From cea25b276d3a975a1ae5152aa4809d2c73a5f9d4 Mon Sep 17 00:00:00 2001 From: Peng Yu Date: Fri, 18 May 2018 17:05:45 -0400 Subject: [PATCH] address lint --- .../contrib/tensor_forest/python/tensor_forest.py | 13 ++++++++----- .../contrib/tensor_forest/python/tensor_forest_test.py | 17 ++++++++++++++++- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 0feca52..ba1755e 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -335,7 +335,8 @@ class ForestVariables(object): """ def __init__(self, params, device_assigner, training=True, - tree_variables_class=TreeVariables, tree_configs=None, tree_stats=None): + tree_variables_class=TreeVariables, + tree_configs=None, tree_stats=None): self.variables = [] # Set up some scalar variables to run through the device assigner, then # we can use those to colocate everything related to a tree. @@ -349,10 +350,11 @@ class ForestVariables(object): with ops.device(self.device_dummies[i].device): kwargs = {} if tree_configs is not None: - kwargs.update(dict(tree_config=tree_configs[i])) + kwargs.update(dict(tree_config=tree_configs[i])) if tree_stats is not None: - kwargs.update(dict(tree_stat=tree_stats[i])) - self.variables.append(tree_variables_class(params, i, training, **kwargs)) + kwargs.update(dict(tree_stat=tree_stats[i])) + self.variables.append(tree_variables_class( + params, i, training, **kwargs)) def __setitem__(self, t, val): self.variables[t] = val @@ -380,7 +382,8 @@ class RandomForestGraphs(object): logging.info(self.params.__dict__) self.variables = variables or ForestVariables( self.params, device_assigner=self.device_assigner, training=training, - tree_variables_class=tree_variables_class, tree_configs=tree_configs, tree_stats=tree_stats) + tree_variables_class=tree_variables_class, + tree_configs=tree_configs, tree_stats=tree_stats) tree_graph_class = tree_graphs or RandomTreeGraphs self.trees = [ tree_graph_class(self.variables[i], self.params, i) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index cf50ba2..7c5883d 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -125,7 +125,22 @@ class TensorForestTest(test_util.TensorFlowTestCase): num_trees=1, max_nodes=1000, split_after_samples=25).fill() - tree_weight = {'decisionTree': {'nodes': [{'binaryNode': {'rightChildId': 2, 'leftChildId': 1, 'inequalityLeftChildTest': {'featureId': {'id': '0'}, 'threshold': {'floatValue': 0}}}}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 1}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 2}]}} + tree_weight = {'decisionTree': + {'nodes': + [{'binaryNode': + {'rightChildId': 2, + 'leftChildId': 1, + 'inequalityLeftChildTest': + {'featureId': {'id': '0'}, + 'threshold': {'floatValue': 0}}}}, + {'leaf': {'vector': + {'value': [{'floatValue': 0.0}, + {'floatValue': 1.0}]}}, + 'nodeId': 1}, + {'leaf': {'vector': + {'value': [{'floatValue': 0.0}, + {'floatValue': 1.0}]}}, + 'nodeId': 2}]}} restored_tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString() graph_builder = tensor_forest.RandomForestGraphs(hparams, [restored_tree_param]) probs, paths, var = graph_builder.inference_graph(input_data) -- 2.7.4