From 9ba9cf259b38af8425f4ee3b8967b811575fd149 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 26 Feb 2018 19:46:27 -0800 Subject: [PATCH] Make sure rounding and handling of denormals in Grappler is the same as in TensorFlow. Enable constant folding for more types, particularly on GPUs. PiperOrigin-RevId: 187120456 --- tensorflow/core/grappler/op_types.cc | 6 +- .../core/grappler/optimizers/constant_folding.cc | 96 +++++++++++++--------- tensorflow/core/kernels/constant_op.cc | 11 +++ 3 files changed, 74 insertions(+), 39 deletions(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index e225e99..9b3755d 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -354,7 +354,8 @@ bool IsFreeOfSideEffect(const NodeDef& node) { return false; } const OpDef* op_def = nullptr; - Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + const string& op_name = node.op(); + Status status = OpRegistry::Global()->LookUpOpDef(op_name, &op_def); if (!status.ok()) { return false; } @@ -368,7 +369,8 @@ bool IsFreeOfSideEffect(const NodeDef& node) { } } // Some nodes do in-place updates on regular tensor inputs. - if (GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace")) { + if (GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace") || + StringPiece(op_name).starts_with("Inplace")) { return false; } return true; diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 10ca7dc..a5417aa 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -35,7 +35,9 @@ limitations under the License. #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/denormal.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/setround.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/bcast.h" @@ -51,7 +53,14 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {} ~EigenThreadPoolWrapper() override {} void Schedule(std::function fn) override { - pool_->Schedule(std::move(fn)); + auto wrapped = [=]() { + // TensorFlow flushes denormals to zero and rounds to nearest, so we do + // the same here. + port::ScopedFlushDenormal flush; + port::ScopedSetRound round(FE_TONEAREST); + fn(); + }; + pool_->Schedule(std::move(wrapped)); } int NumThreads() const override { return pool_->NumThreads(); } int CurrentThreadId() const override { return pool_->CurrentThreadId(); } @@ -292,16 +301,16 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { // graph. const int node_count = graph_->node_size(); for (int i = 0; i < node_count; ++i) { - NodeDef& node = *graph_->mutable_node(i); - const string op = node.op(); + NodeDef* node = graph_->mutable_node(i); + const string op = node->op(); if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") { continue; } const std::vector& output = - properties.GetOutputProperties(node.name()); + properties.GetOutputProperties(node->name()); const std::vector& input = - properties.GetInputProperties(node.name()); + properties.GetInputProperties(node->name()); if (input.empty() || output.empty()) { continue; } @@ -328,35 +337,35 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { // could have multiple outputs). if (op == "Shape" || op == "Size" || op == "Rank") { // Replace the node with the corresponding constant. - node.set_op("Const"); - node.clear_attr(); - (*node.mutable_attr())["dtype"].set_type(type); + node->set_op("Const"); + node->clear_attr(); + (*node->mutable_attr())["dtype"].set_type(type); value.AsProtoTensorContent( - (*node.mutable_attr())["value"].mutable_tensor()); + (*node->mutable_attr())["value"].mutable_tensor()); // Turn the data input into a control dependency: this is needed to // ensure that the constant value will only be run in the // cases where the shape/rank/size would have been run in // the original graph. Additional inputs are extra control string ctrl_dep = - AddControlDependency(node.input(0), graph_, node_map_.get()); - node.set_input(0, ctrl_dep); - node_map_->AddOutput(NodeName(ctrl_dep), node.name()); + AddControlDependency(node->input(0), graph_, node_map_.get()); + node->set_input(0, ctrl_dep); + node_map_->AddOutput(NodeName(ctrl_dep), node->name()); } else { - auto outputs = node_map_->GetOutputs(node.name()); + auto outputs = node_map_->GetOutputs(node->name()); for (const auto& output : outputs) { for (int k = 0; k < output->input_size(); ++k) { int port; string node_name = ParseNodeName(output->input(k), &port); - if (node_name == node.name() && port == j) { + if (node_name == node->name() && port == j) { // Create a const node as ShapeN's output if not already. const string const_name = - OptimizedNodeName(node, strings::StrCat("-matshapes-", j)); + OptimizedNodeName(*node, strings::StrCat("-matshapes-", j)); if (node_map_->GetNode(const_name) == nullptr) { NodeDef* added_node = graph_->add_node(); added_node->set_name(const_name); added_node->set_op("Const"); - added_node->set_device(node.device()); + added_node->set_device(node->device()); node_map_->AddNode(added_node->name(), added_node); (*added_node->mutable_attr())["dtype"].set_type(type); value.AsProtoTensorContent( @@ -364,7 +373,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { // We add a control dependency to the original ShapeN node, // so that the node will only be run if all inputs of the // original ShapeN node are run. - string ctrl_dep = AddControlDependency(node.name(), graph_, + string ctrl_dep = AddControlDependency(node->name(), graph_, node_map_.get()); *added_node->add_input() = ctrl_dep; node_map_->AddOutput(NodeName(ctrl_dep), added_node->name()); @@ -679,7 +688,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) { return false; } - // Skip control flow nodes, they can't be folded + // Skip control flow nodes, they can't be folded. if (ModifiesFrameInfo(node)) { return false; } @@ -688,12 +697,16 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } - // Skips ops that don't benefit from folding. - const string& op = node.op(); + // Don't fold stateful ops such as TruncatedNormal. + if (!IsFreeOfSideEffect(node)) { + return false; + } - if (op.find("Placeholder") == 0) { + // Skips ops that don't benefit from folding. + if (IsPlaceholder(node)) { return false; } + const string& op = node.op(); if (op.find("Save") != string::npos || op.find("Restore") != string::npos || op.find("Reader") != string::npos) { return false; @@ -705,16 +718,12 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } - // Don't fold stateful ops such as TruncatedNormal. const OpDef* op_def = nullptr; Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); if (!status.ok()) { return false; } - if (op_def->is_stateful()) { - return false; - } - + // Don't fold ops without outputs. if (op_def->output_arg_size() == 0) { return false; } @@ -779,8 +788,11 @@ Status CreateConstantTensorAttrValue(DataType type, double value, SET_TENSOR_VAL_CASE(DT_FLOAT, float, float); SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double); SET_TENSOR_VAL_CASE(DT_INT64, int64, int64); + SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64); SET_TENSOR_VAL_CASE(DT_INT32, int32, int); + SET_TENSOR_VAL_CASE(DT_UINT32, int32, int); SET_TENSOR_VAL_CASE(DT_INT16, int32, int); + SET_TENSOR_VAL_CASE(DT_UINT16, int32, int); SET_TENSOR_VAL_CASE(DT_INT8, int32, int); SET_TENSOR_VAL_CASE(DT_UINT8, int32, int); SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool); @@ -843,10 +855,16 @@ Status ConstantFolding::CreateNodeDef(const string& name, POPULATE_TENSOR_PROTO(tensor, t, double, double); case DT_INT64: POPULATE_TENSOR_PROTO(tensor, t, int64, int64); + case DT_UINT64: + POPULATE_TENSOR_PROTO(tensor, t, uint64, int64); case DT_INT32: POPULATE_TENSOR_PROTO(tensor, t, int32, int); + case DT_UINT32: + POPULATE_TENSOR_PROTO(tensor, t, uint32, int); case DT_INT16: POPULATE_TENSOR_PROTO(tensor, t, int16, int); + case DT_UINT16: + POPULATE_TENSOR_PROTO(tensor, t, uint16, int); case DT_INT8: POPULATE_TENSOR_PROTO(tensor, t, int8, int); case DT_UINT8: @@ -1166,9 +1184,8 @@ Status ConstantFolding::FoldGraph(GraphDef* output) { std::unordered_set processed_nodes; std::deque queue; for (int i = 0; i < graph_->node_size(); i++) { - auto node = graph_->mutable_node(i); - if (IsFoldable(*node)) { - queue.push_back(node); + if (IsFoldable(graph_->node(i))) { + queue.push_back(graph_->mutable_node(i)); } } while (!queue.empty()) { @@ -1203,8 +1220,8 @@ Status ConstantFolding::FoldGraph(GraphDef* output) { int last = output->node_size() - 1; for (int i = output->node_size() - 1; i >= 0; --i) { const NodeDef& node = output->node(i); - auto outputs = node_map_->GetOutputs(node.name()); - if (outputs.empty()) { + auto fanout = node_map_->GetOutputs(node.name()); + if (fanout.empty()) { output->mutable_node()->SwapElements(i, last); last--; } @@ -1216,8 +1233,8 @@ Status ConstantFolding::FoldGraph(GraphDef* output) { // If no fetch nodes is provided, we conservatively // keep all nodes in the original graph in case users need to fetch // their values. - auto outputs = node_map_->GetOutputs(node.name()); - if (!outputs.empty() || !has_fetch_ || + auto fanout = node_map_->GetOutputs(node.name()); + if (!fanout.empty() || !has_fetch_ || nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { auto added_node = output->add_node(); *added_node = node; @@ -1331,14 +1348,14 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const { // IS_ONES_CASE(DT_HALF); IS_ONES_CASE(DT_FLOAT); IS_ONES_CASE(DT_DOUBLE); + IS_ONES_CASE(DT_COMPLEX64); + IS_ONES_CASE(DT_COMPLEX128); IS_ONES_CASE(DT_UINT8); IS_ONES_CASE(DT_INT8); IS_ONES_CASE(DT_UINT16); IS_ONES_CASE(DT_INT16); IS_ONES_CASE(DT_INT32); IS_ONES_CASE(DT_INT64); - IS_ONES_CASE(DT_COMPLEX64); - IS_ONES_CASE(DT_COMPLEX128); default: VLOG(1) << "Unsupported type " << DataTypeString(dtype); return false; @@ -1362,14 +1379,14 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const { // IS_ZEROS_CASE(DT_HALF); IS_ZEROS_CASE(DT_FLOAT); IS_ZEROS_CASE(DT_DOUBLE); + IS_ZEROS_CASE(DT_COMPLEX64); + IS_ZEROS_CASE(DT_COMPLEX128); IS_ZEROS_CASE(DT_UINT8); IS_ZEROS_CASE(DT_INT8); IS_ZEROS_CASE(DT_UINT16); IS_ZEROS_CASE(DT_INT16); IS_ZEROS_CASE(DT_INT32); IS_ZEROS_CASE(DT_INT64); - IS_ZEROS_CASE(DT_COMPLEX64); - IS_ZEROS_CASE(DT_COMPLEX128); default: VLOG(1) << "Unsupported type " << DataTypeString(dtype); return false; @@ -1869,6 +1886,11 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { + // TensorFlow flushes denormals to zero and rounds to nearest, so we do + // the same here. + port::ScopedFlushDenormal flush; + port::ScopedSetRound round(FE_TONEAREST); + nodes_to_preserve_ = item.NodesToPreserve(); for (const auto& feed : item.feed) { feed_nodes_.insert(NodeName(feed.first)); diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index fdb03a5..312c1a4 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -105,7 +105,12 @@ REGISTER_KERNEL(GPU, int8); REGISTER_KERNEL(GPU, qint8); REGISTER_KERNEL(GPU, uint16); REGISTER_KERNEL(GPU, int16); +REGISTER_KERNEL(GPU, qint16); +REGISTER_KERNEL(GPU, quint16); +REGISTER_KERNEL(GPU, uint32); +REGISTER_KERNEL(GPU, qint32); REGISTER_KERNEL(GPU, int64); +REGISTER_KERNEL(GPU, uint64); REGISTER_KERNEL(GPU, complex64); REGISTER_KERNEL(GPU, complex128); REGISTER_KERNEL(GPU, bool); @@ -122,9 +127,15 @@ REGISTER_SYCL_KERNEL(SYCL, float); REGISTER_SYCL_KERNEL(SYCL, double); REGISTER_SYCL_KERNEL(SYCL, uint8); REGISTER_SYCL_KERNEL(SYCL, int8); +REGISTER_SYCL_KERNEL(SYCL, qint8); REGISTER_SYCL_KERNEL(SYCL, uint16); REGISTER_SYCL_KERNEL(SYCL, int16); +REGISTER_SYCL_KERNEL(SYCL, qint16); +REGISTER_SYCL_KERNEL(SYCL, quint16); +REGISTER_SYCL_KERNEL(SYCL, uint32); +REGISTER_SYCL_KERNEL(SYCL, qint32); REGISTER_SYCL_KERNEL(SYCL, int64); +REGISTER_SYCL_KERNEL(SYCL, uint64); REGISTER_SYCL_KERNEL(SYCL, bool); #undef REGISTER_SYCL_KERNEL #endif -- 2.7.4