"""
def __init__(self, params, device_assigner, training=True,
- tree_variables_class=TreeVariables, tree_configs=None, tree_stats=None):
+ tree_variables_class=TreeVariables,
+ tree_configs=None, tree_stats=None):
self.variables = []
# Set up some scalar variables to run through the device assigner, then
# we can use those to colocate everything related to a tree.
with ops.device(self.device_dummies[i].device):
kwargs = {}
if tree_configs is not None:
- kwargs.update(dict(tree_config=tree_configs[i]))
+ kwargs.update(dict(tree_config=tree_configs[i]))
if tree_stats is not None:
- kwargs.update(dict(tree_stat=tree_stats[i]))
- self.variables.append(tree_variables_class(params, i, training, **kwargs))
+ kwargs.update(dict(tree_stat=tree_stats[i]))
+ self.variables.append(tree_variables_class(
+ params, i, training, **kwargs))
def __setitem__(self, t, val):
self.variables[t] = val
logging.info(self.params.__dict__)
self.variables = variables or ForestVariables(
self.params, device_assigner=self.device_assigner, training=training,
- tree_variables_class=tree_variables_class, tree_configs=tree_configs, tree_stats=tree_stats)
+ tree_variables_class=tree_variables_class,
+ tree_configs=tree_configs, tree_stats=tree_stats)
tree_graph_class = tree_graphs or RandomTreeGraphs
self.trees = [
tree_graph_class(self.variables[i], self.params, i)
num_trees=1,
max_nodes=1000,
split_after_samples=25).fill()
- tree_weight = {'decisionTree': {'nodes': [{'binaryNode': {'rightChildId': 2, 'leftChildId': 1, 'inequalityLeftChildTest': {'featureId': {'id': '0'}, 'threshold': {'floatValue': 0}}}}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 1}, {'leaf': {'vector': {'value': [{'floatValue': 0.0}, {'floatValue': 1.0}]}}, 'nodeId': 2}]}}
+ tree_weight = {'decisionTree':
+ {'nodes':
+ [{'binaryNode':
+ {'rightChildId': 2,
+ 'leftChildId': 1,
+ 'inequalityLeftChildTest':
+ {'featureId': {'id': '0'},
+ 'threshold': {'floatValue': 0}}}},
+ {'leaf': {'vector':
+ {'value': [{'floatValue': 0.0},
+ {'floatValue': 1.0}]}},
+ 'nodeId': 1},
+ {'leaf': {'vector':
+ {'value': [{'floatValue': 0.0},
+ {'floatValue': 1.0}]}},
+ 'nodeId': 2}]}}
restored_tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString()
graph_builder = tensor_forest.RandomForestGraphs(hparams, [restored_tree_param])
probs, paths, var = graph_builder.inference_graph(input_data)