From c52ac787c6b716f5abbcebca2a57e3dc3f157200 Mon Sep 17 00:00:00 2001
From: Alexandre Passos
Date: Thu, 8 Feb 2018 13:30:17 -0800
Subject: [PATCH] Internal functional _If and _While ops.

PiperOrigin-RevId: 185042663
---
 tensorflow/core/BUILD                     |   1 +
 tensorflow/core/kernels/BUILD             |  11 +
 tensorflow/core/kernels/functional_ops.cc | 322 ++++++++++++++++++++++++++++++
 tensorflow/core/ops/functional_ops.cc     |  59 ++++++
 4 files changed, 393 insertions(+)
 create mode 100644 tensorflow/core/kernels/functional_ops.cc

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index d0c9a72..a7f8533 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -838,6 +838,7 @@ cc_library(
         "//tensorflow/core/kernels:dataset_ops",
         "//tensorflow/core/kernels:fake_quant_ops",
         "//tensorflow/core/kernels:function_ops",
+        "//tensorflow/core/kernels:functional_ops",
         "//tensorflow/core/kernels:histogram_op",
         "//tensorflow/core/kernels:image",
         "//tensorflow/core/kernels:io",
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index e7192ec..523e395 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1942,6 +1942,17 @@ tf_kernel_library(
     ],
 )
 
+tf_kernel_library(
+    name = "functional_ops",
+    prefix = "functional_ops",
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:functional_ops_op_lib",
+        "//tensorflow/core:lib",
+        "//third_party/eigen3",
+    ],
+)
+
 cc_library(
     name = "image",
     deps = [
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
new file mode 100644
index 0000000..b687088
--- /dev/null
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -0,0 +1,322 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef FunctionLibraryRuntime::Handle FHandle;
+typedef std::vector<Tensor> TensorVec;
+
+namespace {
+
+// Helper to instantiate function "func" in the library "lib".
+Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
+                   FunctionLibraryRuntime::Handle* handle) {
+  return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
+}
+
+// If "t" is a scalar of a supported type, returns t != 0 in "*v".
+Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
+  if (t.size() != 1) {
+    return errors::InvalidArgument(
+        "Expected a single scalar which can be converted to a boolean, got ",
+        t.size(), " tensors.");
+  }
+  if (TensorShapeUtils::IsScalar(t[0].shape())) {
+    switch (t[0].dtype()) {
+#define CASE(T)                   \
+  case DataTypeToEnum<T>::value:  \
+    *v = t[0].scalar<T>()() != 0; \
+    break;
+
+      CASE(float);
+      CASE(double);
+      CASE(int32);
+      CASE(uint8);
+      CASE(int16);
+      CASE(int8);
+      CASE(int64);
+#undef CASE
+      case DT_BOOL:
+        *v = t[0].scalar<bool>()();
+        break;
+      case DT_STRING:
+        *v = !t[0].scalar<string>()().empty();
+        break;
+      default:
+        return errors::InvalidArgument(DataTypeString(t[0].dtype()),
+                                       " cannot be converted to a boolean");
+    }
+  } else {
+    *v = t[0].NumElements() > 0;
+  }
+  return Status::OK();
+}
+
+// Sets "rets" to be the output of "ctx". Validates rets' types based
+// on "kernel".
+Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx,
+                  gtl::ArraySlice<Tensor> rets) {
+  if (rets.size() != ctx->num_outputs()) {
+    return errors::Internal("Expect to produce ", ctx->num_outputs(),
+                            " tensors, but only get ", rets.size());
+  }
+  for (int i = 0; i < rets.size(); ++i) {
+    if (rets[i].dtype() != kernel->output_type(i)) {
+      return errors::Internal("Expect ", i, "-th output is of type ",
+                              DataTypeString(kernel->output_type(i)),
+                              " but get ", DataTypeString(rets[i].dtype()));
+    }
+    ctx->set_output(i, rets[i]);
+  }
+  return Status::OK();
+}
+
+void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
+                   bool always_collect_stats) {
+  opts->step_id = ctx->step_id();
+  opts->rendezvous = ctx->rendezvous();
+  opts->cancellation_manager = ctx->cancellation_manager();
+  if (always_collect_stats) {
+    opts->stats_collector = ctx->stats_collector();
+  }
+  opts->runner = ctx->runner();
+}
+
+}  // end namespace
+
+class FunctionalIf : public AsyncOpKernel {
+ public:
+  explicit FunctionalIf(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+    auto lib = ctx->function_library();
+    OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
+    const NameAttrList* func;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &func));
+    OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &then_handle_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &func));
+    OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &else_handle_));
+  }
+
+  ~FunctionalIf() override {}
+
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+    bool cond;
+    OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond));
+    (new State(this, ctx, cond, done))->Start();
+  }
+
+ private:
+  FHandle then_handle_;
+  FHandle else_handle_;
+
+  class State {
+   public:
+    State(FunctionalIf* kernel, OpKernelContext* ctx, bool cond,
+          DoneCallback done)
+        : kernel_(kernel),
+          ctx_(ctx),
+          cond_(cond),
+          done_(done),
+          lib_(CHECK_NOTNULL(ctx_->function_library())) {
+      SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
+      for (int i = 1; i < ctx_->num_inputs(); ++i) {
+        args_.push_back(ctx_->input(i));
+      }
+    }
+
+    ~State() {}
+
+    void Start() {
+      FHandle handle = cond_ ? kernel_->then_handle_ : kernel_->else_handle_;
+      rets_.clear();
+      lib_->Run(
+          // Evaluate one of the branches.
+          opts_, handle, args_, &rets_,
+          // Done callback
+          [this](Status s) {
+            if (s.ok()) {
+              s = SetOutputs(kernel_, ctx_, rets_);
+            }
+            ctx_->SetStatus(s);
+            auto done = done_;
+            delete this;
+            done();
+          });
+    }
+
+   private:
+    FunctionalIf* const kernel_;
+    OpKernelContext* const ctx_;
+    const bool cond_;
+    const DoneCallback done_;
+    FunctionLibraryRuntime* const lib_;
+    FunctionLibraryRuntime::Options opts_;
+    TensorVec args_;
+    TensorVec rets_;
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), FunctionalIf);
+REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"),
+                        FunctionalIf);
+
+class FunctionalWhile : public AsyncOpKernel {
+ public:
+  explicit FunctionalWhile(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_));
+  }
+
+  ~FunctionalWhile() override {}
+
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+    auto lib = ctx->function_library();
+    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
+                      errors::Internal("No function library"), done);
+
+    // TODO(b/37549631): Because this op has `SetIsStateful()` in its
+    // op registration, this kernel may be shared by multiple
+    // subgraphs, which have different associated
+    // `FunctionLibraryRuntime` objects and hence different `FHandle`
+    // namespaces. We currently work around this by caching the map
+    // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
+    // functions this op uses.
+    FHandle cond_handle;
+    FHandle body_handle;
+    {
+      mutex_lock l(mu_);
+      const auto iter = handles_.find(lib);
+      if (iter == handles_.end()) {
+        OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle),
+                             done);
+        OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle),
+                             done);
+        handles_[lib] = {cond_handle, body_handle};
+      } else {
+        cond_handle = iter->second.first;
+        body_handle = iter->second.second;
+      }
+    }
+
+    (new State(this, ctx, cond_handle, body_handle, done))->Start();
+  }
+
+ private:
+  NameAttrList cond_func_;
+  NameAttrList body_func_;
+
+  mutex mu_;
+  std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
+      handles_ GUARDED_BY(mu_);
+
+  class State {
+   public:
+    State(FunctionalWhile* kernel, OpKernelContext* ctx, FHandle cond_handle,
+          FHandle body_handle, DoneCallback done)
+        : kernel_(kernel),
+          ctx_(ctx),
+          cond_handle_(cond_handle),
+          body_handle_(body_handle),
+          done_(done),
+          lib_(CHECK_NOTNULL(ctx_->function_library())) {
+      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);
+      for (int i = 0; i < ctx_->num_inputs(); ++i) {
+        args_.push_back(ctx_->input(i));
+      }
+    }
+
+    ~State() {}
+
+    void Start() { EvalCond(); }
+
+   private:
+    FunctionalWhile* const kernel_;
+    OpKernelContext* const ctx_;
+    const FHandle cond_handle_;
+    const FHandle body_handle_;
+    const DoneCallback done_;
+    FunctionLibraryRuntime* const lib_;
+    FunctionLibraryRuntime::Options opts_;
+    TensorVec args_;
+    TensorVec rets_;
+
+    void EvalCond() {
+      lib_->Run(
+          // Evaluate the condition.
+          opts_, cond_handle_, args_, &rets_,
+          // Done cb.
+          [this](const Status& s) {
+            if (!s.ok()) {
+              return Finish(s);
+            }
+            StartBody();
+          });
+    }
+
+    void StartBody() {
+      bool cond;
+      Status s = ToBool(rets_, &cond);
+      if (!s.ok()) {
+        return Finish(s);
+      }
+      if (!cond) {
+        return Finish(Status::OK());
+      }
+      rets_.clear();
+      lib_->Run(
+          // Evaluate the body.
+          opts_, body_handle_, args_, &rets_,
+          // Done callback
+          [this](const Status& s) {
+            if (!s.ok()) {
+              return Finish(s);
+            }
+            if (args_.size() != rets_.size()) {
+              return Finish(errors::InvalidArgument(
+                  "While loop body returned ", rets_.size(),
+                  " arguments. Expected: ", args_.size()));
+            }
+            args_.clear();
+            using std::swap;
+            swap(args_, rets_);
+            EvalCond();
+          });
+    }
+
+    void Finish(Status s) {
+      if (s.ok()) {
+        s = SetOutputs(kernel_, ctx_, args_);
+      }
+      ctx_->SetStatus(s);
+      done_();
+      delete this;
+    }
+  };
+};
+REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), FunctionalWhile);
+REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), FunctionalWhile);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index 515b316..9e18d20 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -48,4 +48,63 @@ REGISTER_OP("RemoteCall")
     .Attr("Tout: list(type)")
     .Attr("f: func")
     .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("_If")
+    .Input("cond: Tcond")
+    .Input("input: Tin")
+    .Output("output: Tout")
+    .Attr("Tcond: type")
+    .Attr("Tin: list(type)")
+    .Attr("Tout: list(type)")
+    .Attr("then_branch: func")
+    .Attr("else_branch: func")
+    .SetShapeFn(shape_inference::UnknownShape)
+    .Doc(R"doc(
+output = cond ? then_branch(input) : else_branch(input)
+
+cond: A Tensor. If the tensor is a scalar of non-boolean type, the
+    scalar is converted to a boolean according to the
+    following rule: if the scalar is a numerical value, non-zero means
+    True and zero means False; if the scalar is a string, non-empty
+    means True and empty means False. If the tensor is not a scalar,
+    being empty means False and being non-empty means True.
+input: A list of input tensors.
+then_branch: A function that takes 'inputs' and returns a list of
+    tensors, whose types are the same as what else_branch returns.
+else_branch: A function that takes 'inputs' and returns a list of
+    tensors, whose types are the same as what then_branch returns.
+)doc");
+
+// TODO(b/37549631): Setting the While op to always be stateful is too
+// conservative.
+REGISTER_OP("_While")
+    .Input("input: T")
+    .Output("output: T")
+    .Attr("T: list(type) >= 0")
+    .Attr("cond: func")
+    .Attr("body: func")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      for (int i = 0; i < c->num_outputs(); ++i) {
+        c->set_output(i, c->input(i));
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+output = input; While (Cond(output)) { output = Body(output) }
+
+input: A list of input tensors whose types are T.
+output: A list of output tensors whose types are T.
+cond: A function that takes 'input' and returns a tensor. If the
+    tensor is a scalar of non-boolean type, the scalar is converted to
+    a boolean according to the following rule: if the scalar is a
+    numerical value, non-zero means True and zero means False; if the
+    scalar is a string, non-empty means True and empty means False. If
+    the tensor is not a scalar, being non-empty means True and being
+    empty means False.
+body: A function that takes a list of tensors and returns another
+    list of tensors. Both lists have the same types as specified
+    by T.
+)doc");
+
 }  // end namespace tensorflow
-- 
2.7.4
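
For readers who want the control flow of the FunctionalWhile kernel at a glance: it evaluates the cond function on the current arguments, converts the single result to a boolean via ToBool, and, if true, runs the body function and feeds the body's outputs back in as the next iteration's arguments. The self-contained C++ sketch below is not part of the patch; CondFn, BodyFn, and RunWhileSketch are hypothetical stand-ins, and it uses a synchronous loop rather than the kernel's FunctionLibraryRuntime callbacks, but the data flow is the same.

#include <functional>
#include <utility>
#include <vector>

// Stand-in for TensorVec (std::vector<Tensor>) so the sketch compiles on its own.
using ArgVec = std::vector<double>;
// Hypothetical stand-ins for the instantiated cond/body functions that the
// real kernel runs through FunctionLibraryRuntime::Run.
using CondFn = std::function<bool(const ArgVec&)>;
using BodyFn = std::function<ArgVec(const ArgVec&)>;

// output = input; While (Cond(output)) { output = Body(output) }
ArgVec RunWhileSketch(const CondFn& cond, const BodyFn& body, ArgVec args) {
  while (cond(args)) {           // EvalCond() + ToBool() in the kernel
    ArgVec rets = body(args);    // StartBody(); must return as many values as it took
    using std::swap;
    swap(args, rets);            // the kernel swaps rets_ into args_ for the next iteration
  }
  return args;                   // Finish() copies args_ into the op's outputs
}

The real kernel performs the same swap of args_ and rets_, but it schedules each step from the previous step's done callback, so no thread blocks while a branch, condition, or loop body is running.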