From 2d2f0455d215c7a3f4353ab879cbb513b7946a78 Mon Sep 17 00:00:00 2001 From: Peng Yu Date: Thu, 12 Apr 2018 15:38:48 -0400 Subject: [PATCH] address comments --- tensorflow/contrib/tensor_forest/python/tensor_forest_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index b6e70d4..cf50ba2 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -114,7 +114,7 @@ class TensorForestTest(test_util.TensorFlowTestCase): self.assertTrue(isinstance(paths, ops.Tensor)) self.assertTrue(isinstance(var, ops.Tensor)) - def testInfrenceWithPreTrainedParams(self): + def testInfrenceFromRestoredModel(self): input_data = [[-1., 0.], [-1., 2.], # node 1 [1., 0.], [1., -2.]] # node 2 expected_prediction = [[0.0, 1.0], [0.0, 1.0], @@ -126,8 +126,8 @@ class TensorForestTest(test_util.TensorFlowTestCase): max_nodes=1000, split_after_samples=25).fill() tree_weight = {'decisionTree': {'nodes': [{'binaryNode': {'rightChildId': 2, 'leftChildId': 1, 'inequalityLeftChildTest': {'featureId': {'id': '0'}, 'threshold': {'floatValue': 0}}}}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 1}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 2}]}} - tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString() - graph_builder = tensor_forest.RandomForestGraphs(hparams, [tree_param]) + restored_tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString() + graph_builder = tensor_forest.RandomForestGraphs(hparams, [restored_tree_param]) probs, paths, var = graph_builder.inference_graph(input_data) self.assertTrue(isinstance(probs, ops.Tensor)) self.assertTrue(isinstance(paths, ops.Tensor)) -- 2.7.4