std::unique_ptr<ClientGraph> cg,
const SessionOptions& session_opts,
const StatsPublisherFactory& stats_publisher_factory,
- GraphExecutionState* execution_state, bool is_partial,
- WorkerCacheInterface* worker_cache, bool should_deregister)
+ bool is_partial, WorkerCacheInterface* worker_cache,
+ bool should_deregister)
: session_handle_(handle),
client_graph_(std::move(cg)),
session_opts_(session_opts),
stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
- // Initialize a name to node map for testing that fetches are reachable.
- for (Node* n : execution_state->full_graph()->nodes()) {
+ // Initialize a name to node map for processing device stats.
+ for (Node* n : client_graph_->graph.nodes()) {
name_to_node_.insert({n->name(), n});
}
}
// TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
// on once at setup time to prevent us from computing the dependencies
// every time.
-// TODO(suharshs,mrry): Consider removing the need for execution_state to reduce
-// contention.
Status MasterSession::ReffedClientGraph::CheckFetches(
const RunStepRequestWrapper& req, const RunState* run_state,
GraphExecutionState* execution_state) {
// Skip if already fed.
if (input.second) continue;
TensorId id(ParseTensorName(input.first));
- const auto it = name_to_node_.find(id.first);
- if (it == name_to_node_.end()) {
+ const Node* n = execution_state->get_node_by_name(id.first.ToString());
+ if (n == nullptr) {
return errors::NotFound("Feed ", input.first, ": not found");
}
pending_feeds.insert(id);
for (size_t i = 0; i < req.num_fetches(); ++i) {
const string& fetch = req.fetch_name(i);
const TensorId id(ParseTensorName(fetch));
- auto it = name_to_node_.find(id.first);
- if (it == name_to_node_.end()) {
+ const Node* n = execution_state->get_node_by_name(id.first.ToString());
+ if (n == nullptr) {
return errors::NotFound("Fetch ", fetch, ": not found");
}
- stack.push_back(it->second);
+ stack.push_back(n);
}
// Any tensor needed for fetches can't be in pending_feeds.
WorkerCacheInterface* worker_cache = get_worker_cache();
auto entry = new ReffedClientGraph(
handle_, opts, std::move(client_graph), session_opts_,
- stats_publisher_factory_, execution_state_.get(), is_partial,
- worker_cache, !should_delete_worker_sessions_);
+ stats_publisher_factory_, is_partial, worker_cache,
+ !should_delete_worker_sessions_);
iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph";
}