Status Sync() override { return Status::OK(); }
};
-Status InlineSymbolicGradient(const NodeDef& node,
- const FunctionDefLibrary& library,
- GraphDef* inlined_graph) {
- Env* env = Env::Default();
- DeviceAttributes attr;
- attr.set_name("/device:CPU:0");
- attr.set_device_type("CPU");
- FakeCPUDevice* dev = new FakeCPUDevice(env, attr);
- std::vector<Device*> devices;
- devices.push_back(dev);
- DeviceMgr dvc_mgr(devices);
- FunctionLibraryDefinition function_library(OpRegistry::Global(), library);
+class SymbolicGradientEnv {
+ public:
+ SymbolicGradientEnv(int graph_version, const FunctionDefLibrary& library)
+ : graph_version_(graph_version), library_(library) {}
+
+ FunctionLibraryDefinition* function_library() {
+ InitializeIfNeeded();
+ return fld_.get();
+ }
+ FunctionLibraryRuntime* function_library_runtime() {
+ InitializeIfNeeded();
+ return flr_;
+ }
+
+ private:
+ // This initialization is expensive. Do it lazily to avoid paying for it
+ // unless it's needed.
+ void InitializeIfNeeded() {
+ if (flr_) {
+ return;
+ }
+ Env* env = Env::Default();
+ DeviceAttributes attr;
+ attr.set_name("/device:CPU:0");
+ attr.set_device_type("CPU");
+ FakeCPUDevice* dev = new FakeCPUDevice(env, attr);
+ std::vector<Device*> devices;
+ devices.push_back(dev);
+ dvc_mgr_.reset(new DeviceMgr(devices));
+ fld_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), library_));
+ OptimizerOptions optimizer_opts;
+ optimizer_opts.set_do_function_inlining(true);
+ pflr_.reset(new ProcessFunctionLibraryRuntime(
+ dvc_mgr_.get(), env, graph_version_, fld_.get(), optimizer_opts));
+ flr_ = pflr_->GetFLR(dev->name());
+ }
+
+ const int graph_version_;
+ const FunctionDefLibrary& library_;
+ std::unique_ptr<DeviceMgr> dvc_mgr_;
+ std::unique_ptr<FunctionLibraryDefinition> fld_;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ FunctionLibraryRuntime* flr_ = nullptr;
+};
+Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env,
+ GraphDef* inlined_graph) {
GraphDef graph_def;
// Create a node to anchor the gradient inputs
}
// Convert the graphdef to a graph
- OptimizerOptions optimizer_opts;
- optimizer_opts.set_do_function_inlining(true);
- ProcessFunctionLibraryRuntime pflr(&dvc_mgr, env,
- inlined_graph->versions().producer(),
- &function_library, optimizer_opts);
- FunctionLibraryRuntime* flr = pflr.GetFLR(dev->name());
- CHECK(flr);
GraphConstructorOptions graph_ctor_opts;
graph_ctor_opts.allow_internal_ops = true;
graph_ctor_opts.expect_device_spec = false;
- Graph graph(function_library);
+ Graph graph(env->function_library());
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(graph_ctor_opts, graph_def, &graph));
// Recursively inline the functions until there is nothing more to inline. We
// should at least expand one function.
int counter = 0;
- while (counter < 50 && ExpandInlineFunctions(flr, &graph)) {
+ while (counter < 50 &&
+ ExpandInlineFunctions(env->function_library_runtime(), &graph)) {
++counter;
}
return Status::OK();
}
- *optimized_graph->mutable_versions() = item.graph.versions();
+ SymbolicGradientEnv env(item.graph.versions().producer(),
+ item.graph.library());
+
for (const NodeDef& node : item.graph.node()) {
if (node.op() == "SymbolicGradient") {
- TF_RETURN_IF_ERROR(
- InlineSymbolicGradient(node, item.graph.library(), optimized_graph));
+ TF_RETURN_IF_ERROR(InlineSymbolicGradient(node, &env, optimized_graph));
continue;
}
auto it = functions.find(node.op());
// inlined based on the context in which they're instantiated.
// TODO(bsteiner): trim the library to remove unused function definitions
+ *optimized_graph->mutable_versions() = item.graph.versions();
*optimized_graph->mutable_library() = item.graph.library();
return Status::OK();