from __future__ import division
from __future__ import print_function
+from google.protobuf.json_format import ParseDict
+from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2 as _tree_proto
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
self.assertTrue(isinstance(paths, ops.Tensor))
self.assertTrue(isinstance(var, ops.Tensor))
+ def testInfrenceWithPreTrainedParams(self):
+ input_data = [[-1., 0.], [-1., 2.], # node 1
+ [1., 0.], [1., -2.]] # node 2
+ expected_prediction = [[0.0, 1.0], [0.0, 1.0],
+ [0.0, 1.0], [0.0, 1.0]]
+ hparams = tensor_forest.ForestHParams(
+ num_classes=2,
+ num_features=2,
+ num_trees=1,
+ max_nodes=1000,
+ split_after_samples=25).fill()
+ tree_weight = {'decisionTree': {'nodes': [{'binaryNode': {'rightChildId': 2, 'leftChildId': 1, 'inequalityLeftChildTest': {'featureId': {'id': '0'}, 'threshold': {'floatValue': 0}}}}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 1}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 2}]}}
+ tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString()
+ graph_builder = tensor_forest.RandomForestGraphs(hparams, [tree_param])
+ probs, paths, var = graph_builder.inference_graph(input_data)
+ self.assertTrue(isinstance(probs, ops.Tensor))
+ self.assertTrue(isinstance(paths, ops.Tensor))
+ self.assertTrue(isinstance(var, ops.Tensor))
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ resources.initialize_resources(resources.shared_resources()).run()
+ self.assertEquals(probs.eval().shape, (4, 2))
+ self.assertEquals(probs.eval().tolist(), expected_prediction)
+
def testTrainingConstructionClassificationSparse(self):
input_data = sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 3], [1, 0], [1, 7], [2, 1], [3, 9]],