if tree_stats is not None:
kwargs.update(dict(tree_stat=tree_stats[i]))
self.variables.append(tree_variables_class(
- params, i, training, **kwargs))
+ params, i, training, **kwargs))
def __setitem__(self, t, val):
self.variables[t] = val
max_nodes=1000,
split_after_samples=25).fill()
tree_weight = {'decisionTree':
- {'nodes':
+ {'nodes':
[{'binaryNode':
- {'rightChildId': 2,
- 'leftChildId': 1,
- 'inequalityLeftChildTest':
- {'featureId': {'id': '0'},
- 'threshold': {'floatValue': 0}}}},
+ {'rightChildId': 2,
+ 'leftChildId': 1,
+ 'inequalityLeftChildTest':
+ {'featureId': {'id': '0'},
+ 'threshold': {'floatValue': 0}}}},
{'leaf': {'vector':
- {'value': [{'floatValue': 0.0},
- {'floatValue': 1.0}]}},
- 'nodeId': 1},
- {'leaf': {'vector':
- {'value': [{'floatValue': 0.0},
- {'floatValue': 1.0}]}},
- 'nodeId': 2}]}}
- restored_tree_param = ParseDict(tree_weight, _tree_proto.Model()).SerializeToString()
- graph_builder = tensor_forest.RandomForestGraphs(hparams, [restored_tree_param])
+ {'value': [{'floatValue': 0.0},
+ {'floatValue': 1.0}]}},
+ 'nodeId': 1},
+ {'leaf': {'vector':
+ {'value': [{'floatValue': 0.0},
+ {'floatValue': 1.0}]}},
+ 'nodeId': 2}]}}
+ restored_tree_param = ParseDict(tree_weight,
+ _tree_proto.Model()).SerializeToString()
+ graph_builder = tensor_forest.RandomForestGraphs(hparams,
+ [restored_tree_param])
probs, paths, var = graph_builder.inference_graph(input_data)
self.assertTrue(isinstance(probs, ops.Tensor))
self.assertTrue(isinstance(paths, ops.Tensor))