Adding support for RandomUniform. Basic support for op import/export of RandomUniform...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 2 Apr 2018 15:00:03 +0000 (08:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 2 Apr 2018 15:03:31 +0000 (08:03 -0700)
PiperOrigin-RevId: 191293897

tensorflow/contrib/lite/toco/BUILD
tensorflow/contrib/lite/toco/export_tensorflow.cc
tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc [new file with mode: 0644]
tensorflow/contrib/lite/toco/import_tensorflow.cc
tensorflow/contrib/lite/toco/model.h
tensorflow/contrib/lite/toco/toco_tooling.cc
tensorflow/contrib/lite/toco/tooling_util.cc

index d552de3..2dd689a 100644 (file)
@@ -259,6 +259,7 @@ cc_library(
         "graph_transformations/resolve_constant_fake_quant.cc",
         "graph_transformations/resolve_constant_fill.cc",
         "graph_transformations/resolve_constant_gather.cc",
+        "graph_transformations/resolve_constant_random_uniform.cc",
         "graph_transformations/resolve_constant_range.cc",
         "graph_transformations/resolve_constant_shape_or_rank.cc",
         "graph_transformations/resolve_constant_stack.cc",
index 22a2335..e88357f 100644 (file)
@@ -1711,6 +1711,23 @@ void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
   (*topk_op->mutable_attr())["sorted"].set_b(true);
 }
 
+void ConvertRandomUniformOperator(const Model& model,
+                                  const RandomUniformOperator& src_op,
+                                  GraphDef* tensorflow_graph) {
+  CHECK(tensorflow_graph != nullptr);
+  auto* new_op = tensorflow_graph->add_node();
+  new_op->set_op("RandomUniform");
+  CHECK_EQ(src_op.inputs.size(), 1);
+  new_op->set_name(src_op.outputs[0]);
+  *new_op->add_input() = src_op.inputs[0];
+  const auto shape_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+  (*new_op->mutable_attr())["T"].set_type(shape_type);
+  (*new_op->mutable_attr())["dtype"].set_type(
+      GetTensorFlowDataType(src_op.dtype));
+  (*new_op->mutable_attr())["seed"].set_i(src_op.seed);
+  (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
+}
+
 void ConvertOperator(const Model& model, const Operator& src_op,
                      GraphDef* tensorflow_graph) {
   if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -1897,6 +1914,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
     ConvertTransposeConvOperator(
         model, static_cast<const TransposeConvOperator&>(src_op),
         tensorflow_graph);
+  } else if (src_op.type == OperatorType::kRandomUniform) {
+    ConvertRandomUniformOperator(
+        model, static_cast<const RandomUniformOperator&>(src_op),
+        tensorflow_graph);
   } else {
     LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
   }
index 640afc7..76ec02a 100644 (file)
@@ -173,6 +173,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
 DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
 DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes)
 DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform)
 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange)
 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank)
 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack)
index 778da39..89ad58f 100644 (file)
@@ -50,78 +50,108 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
     old_output_data_types[output] = model->GetArray(output).data_type;
   }
   // Do the actual output data types propagation.
-  if (op->type == OperatorType::kDequantize ||
-      op->type == OperatorType::kResizeBilinear) {
-    // These operators unconditionally produce float outputs
-    SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
-  } else if (op->type == OperatorType::kTensorFlowLess ||
-             op->type == OperatorType::kTensorFlowLessEqual ||
-             op->type == OperatorType::kTensorFlowGreater ||
-             op->type == OperatorType::kTensorFlowGreaterEqual) {
-    // These operators unconditionally produce bool outputs
-    SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
-  } else if (op->type == OperatorType::kRank ||
-             op->type == OperatorType::kTensorFlowShape) {
-    // These operators only produce int32 outputs.
-    SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
-  } else if (op->type == OperatorType::kTensorFlowSplit ||
-             op->type == OperatorType::kTensorFlowConcat ||
-             op->type == OperatorType::kFill) {
-    // These operators produce an output with the same type as their 2nd input
-    CHECK_GE(op->inputs.size(), 2);
-    const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
-    SetDataTypeForAllOutputs(model, op, data_type);
-  } else if (op->type == OperatorType::kTransposeConv) {
-    // These operators produce an output with the same type as their 3rd input
-    CHECK_GE(op->inputs.size(), 3);
-    const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type;
-    SetDataTypeForAllOutputs(model, op, data_type);
-  } else if (op->type == OperatorType::kCast) {
-    // Data type of the Cast op is specified.
-    CHECK_EQ(op->outputs.size(), 1);
-    auto* cast_op = static_cast<CastOperator*>(op);
-    model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type;
-  } else if (op->type == OperatorType::kArgMax) {
-    // Data type of the ArgMax op is specified.
-    CHECK_EQ(op->outputs.size(), 1);
-    auto* argmax_op = static_cast<ArgMaxOperator*>(op);
-    model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
-  } else if (op->type == OperatorType::kRange) {
-    auto* range_op = static_cast<RangeOperator*>(op);
-    // Output type of the Range op can be set via an attribute
-    ArrayDataType data_type;
-    if (range_op->dtype != ArrayDataType::kNone) {
-      // Use the type if specified
-      data_type = range_op->dtype;
-    } else {
-      // Otherwise use the first input
-      CHECK_GE(op->inputs.size(), 1);
-      data_type = model->GetArray(op->inputs[0]).data_type;
+  switch (op->type) {
+    case OperatorType::kDequantize:
+    case OperatorType::kResizeBilinear:
+      // These operators unconditionally produce float outputs
+      SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
+      break;
+    case OperatorType::kTensorFlowLess:
+    case OperatorType::kTensorFlowLessEqual:
+    case OperatorType::kTensorFlowGreater:
+    case OperatorType::kTensorFlowGreaterEqual:
+      // These operators unconditionally produce bool outputs
+      SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
+      break;
+    case OperatorType::kRank:
+    case OperatorType::kTensorFlowShape:
+      // These operators only produce int32 outputs.
+      SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
+      break;
+    case OperatorType::kTensorFlowSplit:
+    case OperatorType::kTensorFlowConcat:
+    case OperatorType::kFill: {
+      // These operators produce an output with the same type as their 2nd input
+      CHECK_GE(op->inputs.size(), 2);
+      const ArrayDataType data_type = model->GetArray(op->inputs[1]).data_type;
+      SetDataTypeForAllOutputs(model, op, data_type);
+      break;
     }
-    CHECK_EQ(op->outputs.size(), 1);
-    SetDataTypeForAllOutputs(model, op, data_type);
-  } else if (op->type == OperatorType::kTensorFlowUnsupported) {
-    auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
-    // Some output tensors from the op could be eliminated by optimization.
-    // This can make unsupported_op->output_data_types have more elements than
-    // op->outputs.
-    if (unsupported_op->output_data_types.size() < op->outputs.size()) {
+    case OperatorType::kTransposeConv: {
+      // These operators produce an output with the same type as their 3rd input
+      CHECK_GE(op->inputs.size(), 3);
+      const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type;
+      SetDataTypeForAllOutputs(model, op, data_type);
+      break;
+    }
+    case OperatorType::kCast: {
+      // Data type of the Cast op is specified.
+      CHECK_EQ(op->outputs.size(), 1);
+      auto* cast_op = static_cast<CastOperator*>(op);
+      model->GetArray(op->outputs[0]).data_type = cast_op->dst_data_type;
+      break;
+    }
+    case OperatorType::kArgMax: {
+      // Data type of the ArgMax op is specified.
+      CHECK_EQ(op->outputs.size(), 1);
+      auto* argmax_op = static_cast<ArgMaxOperator*>(op);
+      model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
+      break;
+    }
+    case OperatorType::kRange: {
+      auto* range_op = static_cast<RangeOperator*>(op);
+      // Output type of the Range op can be set via an attribute
+      ArrayDataType data_type;
+      if (range_op->dtype != ArrayDataType::kNone) {
+        // Use the type if specified
+        data_type = range_op->dtype;
+      } else {
+        // Otherwise use the first input
+        CHECK_GE(op->inputs.size(), 1);
+        data_type = model->GetArray(op->inputs[0]).data_type;
+      }
+      CHECK_EQ(op->outputs.size(), 1);
+      SetDataTypeForAllOutputs(model, op, data_type);
+      break;
+    }
+    case OperatorType::kRandomUniform: {
+      auto* rand_op = static_cast<RandomUniformOperator*>(op);
+      // The output type of RandomUniform is specified with an attribute
+      if (rand_op->dtype == ArrayDataType::kNone) {
+        return false;
+      }
+      CHECK_EQ(op->outputs.size(), 1);
+      SetDataTypeForAllOutputs(model, op, rand_op->dtype);
+      break;
+    }
+    case OperatorType::kTensorFlowUnsupported: {
+      auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
+      // Some output tensors from the op could be eliminated by optimization.
+      // This can make unsupported_op->output_data_types have more elements than
+      // op->outputs.
+      if (unsupported_op->output_data_types.size() < op->outputs.size()) {
+        return false;
+      }
+      for (int i = 0; i < op->outputs.size(); ++i) {
+        auto output = op->outputs[i];
+        auto data_type = unsupported_op->output_data_types[i];
+        model->GetArray(output).data_type = data_type;
+      }
+      break;
+    }
+    case OperatorType::kExpandDims: {
+      // Yield on ExpandDim until it is converted to Reshape
       return false;
     }
-    for (int i = 0; i < op->outputs.size(); ++i) {
-      auto output = op->outputs[i];
-      auto data_type = unsupported_op->output_data_types[i];
-      model->GetArray(output).data_type = data_type;
+    default: {
+      // These operators produce outputs with the same type as their 1st input
+      CHECK_GT(op->inputs.size(), 0);
+      const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
+      SetDataTypeForAllOutputs(model, op, data_type);
+      break;
     }
-  } else if (op->type == OperatorType::kExpandDims) {
-    // Yield on ExpandDim until it is converted to Reshape
-    return false;
-  } else {
-    // These operators produce outputs with the same type as their 1st input
-    CHECK_GT(op->inputs.size(), 0);
-    const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
-    SetDataTypeForAllOutputs(model, op, data_type);
   }
+
   // Return true if any output data type changed, false if none changed.
   for (const auto& output : op->outputs) {
     if (old_output_data_types[output] != model->GetArray(output).data_type) {
index 676736c..b96d698 100644 (file)
@@ -392,8 +392,7 @@ void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
                          depth * block_size * block_size}));
 }
 
-void ProcessFillOperator(Model* model, FillOperator* op) {
-  CHECK_EQ(op->inputs.size(), 2);
+void ProcessOpWithShapeInput(Model* model, Operator* op) {
   CHECK_EQ(op->outputs.size(), 1);
   auto& output_array = model->GetArray(op->outputs[0]);
   if (output_array.has_shape()) {
@@ -1529,7 +1528,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
                                   static_cast<SpaceToDepthOperator*>(op));
       break;
     case OperatorType::kFill:
-      ProcessFillOperator(model, static_cast<FillOperator*>(op));
+      CHECK_EQ(op->inputs.size(), 2);
+      ProcessOpWithShapeInput(model, op);
       break;
     case OperatorType::kFullyConnected:
       ProcessFullyConnectedOperator(model,
@@ -1659,6 +1659,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
       // transforms that remove them, so we avoid propagating shapes through
       // them and let things settle once they've been removed.
       break;
+    case OperatorType::kRandomUniform:
+      CHECK_EQ(op->inputs.size(), 1);
+      ProcessOpWithShapeInput(model, op);
+      break;
     default:
       // Unimplemented, another graph transformation should drop it.
       LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc
new file mode 100644 (file)
index 0000000..88d06d7
--- /dev/null
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace toco {
+
+template <ArrayDataType Type>
+bool ComputeRandomUniformArray(Model* model, RandomUniformOperator* op) {
+  typedef tensorflow::random::UniformDistribution<
+      tensorflow::random::PhiloxRandom, DataType<Type>>
+      Distribution;
+
+  // Allocate output
+  auto& output_array = model->GetArray(op->outputs[0]);
+  CHECK(output_array.data_type == Type);
+  std::vector<DataType<Type>>& data =
+      output_array.GetMutableBuffer<Type>().data;
+  data.resize(RequiredBufferSizeForShape(output_array.shape()));
+
+  // We use the same random number generator and distribution as TensorFlow to
+  // produce the exact same values given the same seeds. See
+  // tensorflow::functor::FillPhiloxRandomTask<Distribution, false> in
+  // //third_party/tensorflow/core/kernels/random_op.cc for the implementation.
+  tensorflow::random::PhiloxRandom generator(op->seed, op->seed2);
+  Distribution dist;
+
+  // The generator creates Distribution::kResultElementCount samples at a time.
+  size_t offset = 0;
+  size_t num_samples = Distribution::kResultElementCount;
+  while (offset < data.size()) {
+    const typename Distribution::ResultType samples = dist(&generator);
+    std::copy(&samples[0],
+              &samples[0] + std::min(num_samples, data.size() - offset),
+              &data[0] + offset);
+    offset += num_samples;
+  }
+
+  return true;
+}
+
+bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
+  const auto it = model->operators.begin() + op_index;
+  auto* base_op = it->get();
+  if (base_op->type != OperatorType::kRandomUniform) {
+    return false;
+  }
+  auto* op = static_cast<RandomUniformOperator*>(base_op);
+
+  CHECK_EQ(op->inputs.size(), 1);
+  CHECK_EQ(op->outputs.size(), 1);
+
+  auto& output_array = model->GetArray(op->outputs[0]);
+  if (output_array.data_type == ArrayDataType::kNone) {
+    // Yield until the output type has been set by PropagateArrayDataTypes
+    return false;
+  }
+
+  if (!output_array.has_shape()) {
+    // Yield until the output shape has been set by PropagateFixedShapes
+    return false;
+  }
+
+  if ((op->seed == 0) && (op->seed2 == 0)) {
+    LOG(WARNING) << "RandomUniform op outputting \"" << op->outputs[0]
+                 << "\" is truly random (using /dev/random system entropy). "
+                    "Therefore, cannot resolve as constant. Set \"seed\" or "
+                    "\"seed2\" attr non-zero to fix this";
+    return false;
+  }
+
+  switch (output_array.data_type) {
+    case ArrayDataType::kFloat:
+      if (!ComputeRandomUniformArray<ArrayDataType::kFloat>(model, op)) {
+        return false;
+      }
+      break;
+    // For future support of double or half.
+    // case ArrayDataType::kDouble...
+    default:
+      LOG(FATAL)
+          << "Unsupported data type given to RandomUniform op with output \""
+          << op->outputs[0] << "\"";
+      break;
+  }
+
+  // Erase input arrays if no longer used
+  toco::DeleteArrayIfUsedOnce(op->inputs[0], model);
+
+  // Erase the operator
+  model->operators.erase(it);
+
+  return true;
+}
+
+}  // namespace toco
index c26e4bd..8764790 100644 (file)
@@ -74,7 +74,7 @@ const string& GetStringAttr(const NodeDef& node, const string& attr_name) {
   return attr.s();
 }
 
-int GetIntAttr(const NodeDef& node, const string& attr_name) {
+int64 GetIntAttr(const NodeDef& node, const string& attr_name) {
   CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n"
                                   << node.DebugString();
   const auto& attr = node.attr().at(attr_name);
@@ -569,6 +569,23 @@ void ConvertBiasAddOperator(const NodeDef& node,
   model->operators.emplace_back(biasadd);
 }
 
+void ConvertRandomUniform(const NodeDef& node,
+                          const TensorFlowImportFlags& tf_import_flags,
+                          Model* model) {
+  CHECK_EQ(node.op(), "RandomUniform");
+  CheckInputsCount(node, tf_import_flags, 1);
+
+  CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32);
+  auto op = absl::make_unique<RandomUniformOperator>();
+  op->inputs.push_back(node.input(0));
+  op->outputs.push_back(node.name());
+  op->dtype = ConvertDataType(GetDataTypeAttr(node, "dtype"));
+  op->seed = GetIntAttr(node, "seed");
+  op->seed2 = GetIntAttr(node, "seed2");
+  CHECK(model != nullptr);
+  model->operators.emplace_back(std::move(op));
+}
+
 void ConvertReluOperator(const NodeDef& node,
                          const TensorFlowImportFlags& tf_import_flags,
                          Model* model) {
@@ -1931,7 +1948,7 @@ void ConvertTopKV2Operator(const NodeDef& node,
   // K can be encoded as attr (TopK) convert it to a const.
   if (HasAttr(node, "k")) {
     string k_array = CreateConstArray<ArrayDataType::kInt32>(
-        model, node.name() + "k", {GetIntAttr(node, "k")});
+        model, node.name() + "k", {static_cast<int32>(GetIntAttr(node, "k"))});
     op->inputs.push_back(k_array);
   } else {
     CheckInputsCount(node, tf_import_flags, 2);
@@ -2168,6 +2185,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
     } else if (node.op() == "DynamicStitch" ||
                node.op() == "ParallelDynamicStitch") {
       ConvertDynamicStitchOperator(node, tf_import_flags, model);
+    } else if (node.op() == "RandomUniform") {
+      ConvertRandomUniform(node, tf_import_flags, model);
     } else {
       ConvertUnsupportedOperator(node, tf_import_flags, model);
     }
index 5199e29..64269d3 100644 (file)
@@ -60,6 +60,7 @@ enum class OperatorType {
   kMaxPool,
   kFakeQuant,
   kMul,
+  kRandomUniform,
   kRange,
   kRank,
   kRelu,
@@ -946,6 +947,13 @@ struct FloorModOperator : Operator {
   FloorModOperator() : Operator(OperatorType::kFloorMod) {}
 };
 
+struct RandomUniformOperator : Operator {
+  RandomUniformOperator() : Operator(OperatorType::kRandomUniform) {}
+  ArrayDataType dtype = ArrayDataType::kNone;
+  int64 seed;
+  int64 seed2;
+};
+
 // Creates a sequence of numbers that begins at start and extends by increments
 // of delta up to but not including limit.
 //
index 30dd6fa..0c52f50 100644 (file)
@@ -79,6 +79,7 @@ void MakeGeneralGraphTransformationsSet(
   transformations->Add(new ResolveConstantBinaryOperator);
   transformations->Add(new ResolveConstantFill);
   transformations->Add(new ResolveConstantGather);
+  transformations->Add(new ResolveConstantRandomUniform);
   transformations->Add(new ResolveConstantRange);
   transformations->Add(new ResolveConstantStack);
   transformations->Add(new ResolveConstantStridedSlice);
index f3f5048..060c52e 100644 (file)
@@ -297,6 +297,7 @@ const char* OperatorTypeName(OperatorType type) {
     HANDLE_OPERATORTYPENAME_CASE(L2Pool)
     HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
     HANDLE_OPERATORTYPENAME_CASE(Mul)
+    HANDLE_OPERATORTYPENAME_CASE(RandomUniform)
     HANDLE_OPERATORTYPENAME_CASE(Relu)
     HANDLE_OPERATORTYPENAME_CASE(Relu1)
     HANDLE_OPERATORTYPENAME_CASE(Relu6)