From 049a8364211ca91d73e10b2002c18f10fe89b8b2 Mon Sep 17 00:00:00 2001 From: Peng Yu Date: Fri, 16 Feb 2018 10:59:14 -0500 Subject: [PATCH] add inference support for tree and forest variables --- .../contrib/tensor_forest/python/tensor_forest.py | 31 +++++++++++++--------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 7a35a70..0feca52 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -295,7 +295,7 @@ def get_epoch_variable(): # A simple container to hold the training variables for a single tree. -class TreeTrainingVariables(object): +class TreeVariables(object): """Stores tf.Variables for training a single random tree. Uses tf.get_variable to get tree-specific names so that this can be used @@ -303,7 +303,7 @@ class TreeTrainingVariables(object): then relies on restoring that model to evaluate). """ - def __init__(self, params, tree_num, training): + def __init__(self, params, tree_num, training, tree_config='', tree_stat=''): if (not hasattr(params, 'params_proto') or not isinstance(params.params_proto, _params_proto.TensorForestParams)): @@ -315,27 +315,27 @@ class TreeTrainingVariables(object): # TODO(gilberth): Manually shard this to be able to fit it on # multiple machines. self.stats = stats_ops.fertile_stats_variable( - params, '', self.get_tree_name('stats', tree_num)) + params, tree_stat, self.get_tree_name('stats', tree_num)) self.tree = model_ops.tree_variable( - params, '', self.stats, self.get_tree_name('tree', tree_num)) + params, tree_config, self.stats, self.get_tree_name('tree', tree_num)) def get_tree_name(self, name, num): return '{0}-{1}'.format(name, num) -class ForestTrainingVariables(object): +class ForestVariables(object): """A container for a forests training data, consisting of multiple trees. - Instantiates a TreeTrainingVariables object for each tree. We override the + Instantiates a TreeVariables object for each tree. We override the __getitem__ and __setitem__ function so that usage looks like this: - forest_variables = ForestTrainingVariables(params) + forest_variables = ForestVariables(params) ... forest_variables.tree ... """ def __init__(self, params, device_assigner, training=True, - tree_variables_class=TreeTrainingVariables): + tree_variables_class=TreeVariables, tree_configs=None, tree_stats=None): self.variables = [] # Set up some scalar variables to run through the device assigner, then # we can use those to colocate everything related to a tree. @@ -347,7 +347,12 @@ class ForestTrainingVariables(object): for i in range(params.num_trees): with ops.device(self.device_dummies[i].device): - self.variables.append(tree_variables_class(params, i, training)) + kwargs = {} + if tree_configs is not None: + kwargs.update(dict(tree_config=tree_configs[i])) + if tree_stats is not None: + kwargs.update(dict(tree_stat=tree_stats[i])) + self.variables.append(tree_variables_class(params, i, training, **kwargs)) def __setitem__(self, t, val): self.variables[t] = val @@ -361,9 +366,11 @@ class RandomForestGraphs(object): def __init__(self, params, + tree_configs=None, + tree_stats=None, device_assigner=None, variables=None, - tree_variables_class=TreeTrainingVariables, + tree_variables_class=TreeVariables, tree_graphs=None, training=True): self.params = params @@ -371,9 +378,9 @@ class RandomForestGraphs(object): device_assigner or framework_variables.VariableDeviceChooser()) logging.info('Constructing forest with params = ') logging.info(self.params.__dict__) - self.variables = variables or ForestTrainingVariables( + self.variables = variables or ForestVariables( self.params, device_assigner=self.device_assigner, training=training, - tree_variables_class=tree_variables_class) + tree_variables_class=tree_variables_class, tree_configs=tree_configs, tree_stats=tree_stats) tree_graph_class = tree_graphs or RandomTreeGraphs self.trees = [ tree_graph_class(self.variables[i], self.params, i) -- 2.7.4