add test case always predict [0, 1]
authorPeng Yu <peng.yu@shopify.com>
Thu, 12 Apr 2018 02:04:58 +0000 (22:04 -0400)
committerPeng Yu <peng.yu@shopify.com>
Mon, 21 May 2018 17:34:57 +0000 (13:34 -0400)
tensorflow/contrib/tensor_forest/python/tensor_forest_test.py

index bbe627b..b6e70d4 100644 (file)
@@ -18,10 +18,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from google.protobuf.json_format import ParseDict
+from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto
 from tensorflow.contrib.tensor_forest.python import tensor_forest
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 
 
@@ -110,6 +114,30 @@ class TensorForestTest(test_util.TensorFlowTestCase):
     self.assertTrue(isinstance(paths, ops.Tensor))
     self.assertTrue(isinstance(var, ops.Tensor))
 
+  def testInfrenceWithPreTrainedParams(self):
+    input_data = [[-1., 0.], [-1., 2.],  # node 1
+                  [1., 0.], [1., -2.]]  # node 2
+    expected_prediction = [[0.0, 1.0], [0.0, 1.0],
+                           [0.0, 1.0], [0.0, 1.0]]
+    hparams = tensor_forest.ForestHParams(
+        num_classes=2,
+        num_features=2,
+        num_trees=1,
+        max_nodes=1000,
+        split_after_samples=25).fill()
+    tree_weight = {'decisionTree': {'nodes': [{'binaryNode': {'rightChildId': 2, 'leftChildId': 1, 'inequalityLeftChildTest': {'featureId': {'id': '0'}, 'threshold': {'floatValue': 0}}}}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 1}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 2}]}}
+    tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString()
+    graph_builder = tensor_forest.RandomForestGraphs(hparams, [tree_param])
+    probs, paths, var = graph_builder.inference_graph(input_data)
+    self.assertTrue(isinstance(probs, ops.Tensor))
+    self.assertTrue(isinstance(paths, ops.Tensor))
+    self.assertTrue(isinstance(var, ops.Tensor))
+    with self.test_session():
+      variables.global_variables_initializer().run()
+      resources.initialize_resources(resources.shared_resources()).run()
+      self.assertEquals(probs.eval().shape, (4, 2))
+      self.assertEquals(probs.eval().tolist(), expected_prediction)
+
   def testTrainingConstructionClassificationSparse(self):
     input_data = sparse_tensor.SparseTensor(
         indices=[[0, 0], [0, 3], [1, 0], [1, 7], [2, 1], [3, 9]],