From 99aeab88ebddcb97221a3e94337b543fb0a3c574 Mon Sep 17 00:00:00 2001 From: Peng Yu Date: Sun, 20 May 2018 21:20:46 -0400 Subject: [PATCH] address lint again --- .../contrib/tensor_forest/python/tensor_forest.py | 2 +- .../tensor_forest/python/tensor_forest_test.py | 32 ++++++++++++---------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index ba1755e..6f62cd1 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -354,7 +354,7 @@ class ForestVariables(object): if tree_stats is not None: kwargs.update(dict(tree_stat=tree_stats[i])) self.variables.append(tree_variables_class( - params, i, training, **kwargs)) + params, i, training, **kwargs)) def __setitem__(self, t, val): self.variables[t] = val diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index 7c5883d..1c9c818 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -126,23 +126,25 @@ class TensorForestTest(test_util.TensorFlowTestCase): max_nodes=1000, split_after_samples=25).fill() tree_weight = {'decisionTree': - {'nodes': + {'nodes': [{'binaryNode': - {'rightChildId': 2, - 'leftChildId': 1, - 'inequalityLeftChildTest': - {'featureId': {'id': '0'}, - 'threshold': {'floatValue': 0}}}}, + {'rightChildId': 2, + 'leftChildId': 1, + 'inequalityLeftChildTest': + {'featureId': {'id': '0'}, + 'threshold': {'floatValue': 0}}}}, {'leaf': {'vector': - {'value': [{'floatValue': 0.0}, - {'floatValue': 1.0}]}}, - 'nodeId': 1}, - {'leaf': {'vector': - {'value': [{'floatValue': 0.0}, - {'floatValue': 1.0}]}}, - 'nodeId': 2}]}} - restored_tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString() - graph_builder = tensor_forest.RandomForestGraphs(hparams, [restored_tree_param]) + {'value': [{'floatValue': 0.0}, + {'floatValue': 1.0}]}}, + 'nodeId': 1}, + {'leaf': {'vector': + {'value': [{'floatValue': 0.0}, + {'floatValue': 1.0}]}}, + 'nodeId': 2}]}} + restored_tree_param = ParseDict(tree_weight, + _tree_proto.Model()).SerializeToString() + graph_builder = tensor_forest.RandomForestGraphs(hparams, + [restored_tree_param]) probs, paths, var = graph_builder.inference_graph(input_data) self.assertTrue(isinstance(probs, ops.Tensor)) self.assertTrue(isinstance(paths, ops.Tensor)) -- 2.7.4