Improve support for DT_HALF and DT_BFLOAT16 in Grappler graph optimizations.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Mar 2018 23:48:31 +0000 (16:48 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 27 Mar 2018 23:51:01 +0000 (16:51 -0700)
Update GrapplerTest::EvaluateNodes to take feeds as an argument, to make it easier to write tests with placeholders.

PiperOrigin-RevId: 190696386

tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h
tensorflow/core/grappler/optimizers/constant_folding_test.cc
tensorflow/core/grappler/optimizers/function_optimizer_test.cc
tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
tensorflow/core/grappler/utils.cc
tensorflow/core/grappler/utils/grappler_test.cc
tensorflow/core/grappler/utils/grappler_test.h

index 792f675..ad3edc1 100644 (file)
@@ -158,7 +158,7 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
 
   ArithmeticOptimizer optimizer;
   GraphDef output;
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});
   EXPECT_EQ(1, tensors_expected.size());
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -176,7 +176,7 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
   EXPECT_EQ("c1", new_div.input(0));
   EXPECT_EQ("c1", new_div.input(1));
 
-  auto tensors = EvaluateNodes(output, item.fetch);
+  auto tensors = EvaluateNodes(output, item.fetch, {});
   EXPECT_EQ(1, tensors.size());
   test::ExpectTensorNear<double>(tensors_expected[0], tensors[0], 1e-6);
 }
index bdec73e..22ede19 100644 (file)
@@ -109,33 +109,18 @@ class DeviceSimple : public DeviceBase {
 };
 
 template <typename T>
-bool AllValuesAre(const TensorProto& tensor, const T& value) {
-  // TensorProto represents the content of the tensor in either <type>_val or
-  // tensor_content.
-  typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
-      checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));
-  if (!tensor_values->empty()) {
-    for (const T& tensor_value : *tensor_values) {
-      if (tensor_value != value) {
-        return false;
-      }
-    }
-    return true;
+bool AllValuesAre(const TensorProto& proto, const T& value) {
+  Tensor tensor;
+  if (!tensor.FromProto(proto)) {
+    return false;
   }
-  const auto tensor_content_size = tensor.tensor_content().size();
-  if (tensor_content_size > 0) {
-    CHECK_EQ(0, tensor_content_size % sizeof(T));
-    std::vector<T> raw_values(tensor_content_size / sizeof(T));
-    port::CopyToArray(tensor.tensor_content(),
-                      reinterpret_cast<char*>(raw_values.data()));
-    for (int i = 0; i < tensor_content_size / sizeof(T); ++i) {
-      if (raw_values[i] != value) {
-        return false;
-      }
+  auto values = tensor.flat<T>();
+  for (int i = 0; i < tensor.NumElements(); ++i) {
+    if (values(i) != value) {
+      return false;
     }
-    return true;
   }
-  return false;
+  return true;
 }
 
 // Add new_input as a control input to node if it does not already depend on it.
@@ -825,17 +810,23 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
   t->set_dtype(type);
   *t->mutable_tensor_shape() = shape;
   switch (type) {
-    SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
-    SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
-    SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
-    SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64);
-    SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
-    SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
-    SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
-    SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
-    SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
-    SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
-    SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
+    case DT_HALF:
+      t->add_half_val(static_cast<Eigen::half>(value).x);
+      break;
+    case DT_BFLOAT16:
+      t->add_half_val(static_cast<bfloat16>(value).value);
+      break;
+      SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
+      SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
+      SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
+      SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64);
+      SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
+      SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
+      SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
+      SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
+      SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
+      SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
+      SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
     default:
       return errors::InvalidArgument("Unsupported type: ", type);
   }
@@ -1388,8 +1379,8 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
   }
   const auto dtype = node.attr().at("dtype").type();
   switch (dtype) {
-    // TODO(rmlarsen): Make DT_HALF case compile.
-    //    IS_ONES_CASE(DT_HALF);
+    IS_ONES_CASE(DT_HALF);
+    IS_ONES_CASE(DT_BFLOAT16);
     IS_ONES_CASE(DT_FLOAT);
     IS_ONES_CASE(DT_DOUBLE);
     IS_ONES_CASE(DT_COMPLEX64);
@@ -1423,8 +1414,8 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   }
   const auto dtype = node.attr().at("dtype").type();
   switch (dtype) {
-    // TODO(rmlarsen): Make DT_HALF case compile.
-    //    IS_ZEROS_CASE(DT_HALF);
+    IS_ZEROS_CASE(DT_HALF);
+    IS_ZEROS_CASE(DT_BFLOAT16);
     IS_ZEROS_CASE(DT_FLOAT);
     IS_ZEROS_CASE(DT_DOUBLE);
     IS_ZEROS_CASE(DT_COMPLEX64);
@@ -1511,9 +1502,8 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
 }
 
 Status ConstantFolding::ReplaceOperationWithConstant(
-    double value, const TensorShapeProto& shape, NodeDef* node,
-    GraphDef* graph) {
-  AttrValue dtype_attr = node->attr().at("T");
+    double value, const AttrValue& dtype_attr, const TensorShapeProto& shape,
+    NodeDef* node, GraphDef* graph) {
   AttrValue tensor_attr;
   TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(dtype_attr.type(), value,
                                                    shape, &tensor_attr));
@@ -1947,8 +1937,14 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
           (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
         const PartialTensorShape shp(output_shape);
         if (shp.IsFullyDefined()) {
-          TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(0, output_shape, node,
-                                                          optimized_graph));
+          AttrValue dtype_attr;
+          if (node->op() == "SparseMatMul") {
+            dtype_attr.set_type(DT_FLOAT);
+          } else {
+            dtype_attr = node->attr().at("T");
+          }
+          TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+              0, dtype_attr, output_shape, node, optimized_graph));
           continue;
         }
         // Even if an input shape is only partially known, we may known that it
index b6645d3..f8a9e90 100644 (file)
@@ -83,7 +83,7 @@ class ConstantFolding : public GraphOptimizer {
   void ReplaceOperationWithSnapshot(int input_to_forward, NodeDef* node,
                                     GraphDef* graph);
   void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
-  Status ReplaceOperationWithConstant(double value,
+  Status ReplaceOperationWithConstant(double value, const AttrValue& dtype_attr,
                                       const TensorShapeProto& shape,
                                       NodeDef* node, GraphDef* graph);
   void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
index dc9c105..85f8778 100644 (file)
@@ -28,7 +28,59 @@ namespace tensorflow {
 namespace grappler {
 namespace {
 
-class ConstantFoldingTest : public GrapplerTest {};
+class ConstantFoldingTest : public GrapplerTest {
+ protected:
+  template <DataType DTYPE>
+  void SimpleNeutralElementTest() {
+    typedef typename EnumToDataType<DTYPE>::Type T;
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
+                                ops::Placeholder::Shape(TensorShape({2, 2})));
+    Tensor zeros_t(DTYPE, TensorShape({2, 2}));
+    Tensor ones_t(DTYPE, TensorShape({2, 2}));
+    Tensor x_t(DTYPE, TensorShape({2, 2}));
+    for (int i = 0; i < 4; ++i) {
+      zeros_t.flat<T>()(i) = T(0);
+      ones_t.flat<T>()(i) = T(1);
+      x_t.flat<T>()(i) = T(i + 1);
+    }
+    Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
+    Output ones = ops::Const(s.WithOpName("ones"), ones_t);
+    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
+    Output mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
+
+    GrapplerItem item;
+    TF_CHECK_OK(s.ToGraphDef(&item.graph));
+    item.fetch = {"mul1", "mul2"};
+    ConstantFolding optimizer(nullptr /* cpu_device */);
+    GraphDef output;
+    Status status = optimizer.Optimize(nullptr, item, &output);
+    TF_EXPECT_OK(status);
+    LOG(INFO) << output.DebugString();
+    EXPECT_EQ(5, output.node_size());
+    for (int i = 0; i < output.node_size(); ++i) {
+      const NodeDef& node = output.node(i);
+      const string& name = node.name();
+      if (name == "mul1") {
+        EXPECT_EQ("Const", node.op());
+        EXPECT_EQ("^x", node.input(0));
+        EXPECT_EQ("^zeros", node.input(1));
+      } else if (name == "mul2") {
+        EXPECT_EQ("Snapshot", node.op());
+        EXPECT_EQ("x", node.input(0));
+        EXPECT_EQ("^ones", node.input(1));
+      }
+    }
+    auto tensors_expected =
+        EvaluateNodes(item.graph, {"mul1", "mul2"}, {{"x", x_t}});
+    auto tensors = EvaluateNodes(output, {"mul1", "mul2"}, {{"x", x_t}});
+    EXPECT_EQ(2, tensors_expected.size());
+    EXPECT_EQ(2, tensors.size());
+    for (int i = 0; i < 2; ++i) {
+      test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
+    }
+  }
+};
 
 TEST_F(ConstantFoldingTest, SimpleFolding) {
   // Build a simple graph with a few trivially prunable ops.
@@ -55,8 +107,8 @@ TEST_F(ConstantFoldingTest, SimpleFolding) {
   EXPECT_EQ("Const", node_d.op());
 
   std::vector<string> fetch = {"d"};
-  auto tensors_expected = EvaluateNodes(item.graph, fetch);
-  auto tensors = EvaluateNodes(output, fetch);
+  auto tensors_expected = EvaluateNodes(item.graph, fetch, {});
+  auto tensors = EvaluateNodes(output, fetch, {});
   EXPECT_EQ(1, tensors_expected.size());
   EXPECT_EQ(1, tensors.size());
   test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
@@ -141,10 +193,10 @@ TEST_F(ConstantFoldingTest, AddTree) {
 
   // Check that the result nodes have the expected value.
   std::vector<string> fetch = {"c3", "c20"};
-  auto tensor_expected = EvaluateNodes(item.graph, fetch);
+  auto tensor_expected = EvaluateNodes(item.graph, fetch, {});
   EXPECT_EQ(fetch.size(), tensor_expected.size());
   fetch = {"add_child", "mul_child"};
-  auto tensors = EvaluateNodes(output, fetch);
+  auto tensors = EvaluateNodes(output, fetch, {});
   EXPECT_EQ(fetch.size(), tensors.size());
   for (int i = 0; i < fetch.size(); i++) {
     test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
@@ -322,6 +374,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
   }
 }
 
+TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
+  SimpleNeutralElementTest<DT_HALF>();
+  SimpleNeutralElementTest<DT_BFLOAT16>();
+}
+
 TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
@@ -379,10 +436,10 @@ TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
 
   // Check that the reciprocals have the expected value.
   std::vector<string> fetch = {"cf_half"};
-  auto tensor_expected = EvaluateNodes(item.graph, fetch);
+  auto tensor_expected = EvaluateNodes(item.graph, fetch, {});
   EXPECT_EQ(fetch.size(), tensor_expected.size());
   fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"};
-  auto tensors = EvaluateNodes(output, fetch);
+  auto tensors = EvaluateNodes(output, fetch, {});
   EXPECT_EQ(fetch.size(), tensors.size());
   for (int i = 0; i < fetch.size(); i++) {
     test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
@@ -590,8 +647,8 @@ TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
   EXPECT_EQ("Const", new_d.op());
 
   std::vector<string> fetch = {"e", "f"};
-  auto tensors_expected = EvaluateNodes(item.graph, fetch);
-  auto tensors = EvaluateNodes(output, fetch);
+  auto tensors_expected = EvaluateNodes(item.graph, fetch, {});
+  auto tensors = EvaluateNodes(output, fetch, {});
   EXPECT_EQ(fetch.size(), tensors_expected.size());
   EXPECT_EQ(fetch.size(), tensors.size());
   for (int i = 0; i < fetch.size(); i++) {
@@ -614,7 +671,7 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
   GrapplerItem item;
   item.fetch.push_back("e");
   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});
   EXPECT_EQ(1, tensors_expected.size());
   ConstantFolding optimizer(nullptr /* cpu_device */);
   GraphDef output;
@@ -631,8 +688,8 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
     if (node.name() == "e") {
       EXPECT_EQ("Const", node.op());
       ++found;
-      auto folded = EvaluateNodes(output, {"e"});
-      auto expected = EvaluateNodes(item.graph, {"e"});
+      auto folded = EvaluateNodes(output, {"e"}, {});
+      auto expected = EvaluateNodes(item.graph, {"e"}, {});
       EXPECT_EQ(1, expected.size());
       EXPECT_EQ(1, folded.size());
       test::ExpectTensorEqual<int>(folded[0], expected[0]);
@@ -642,7 +699,7 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
     }
   }
   EXPECT_EQ(1, found);
-  auto tensors = EvaluateNodes(output, item.fetch);
+  auto tensors = EvaluateNodes(output, item.fetch, {});
   EXPECT_EQ(1, tensors.size());
   test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
 }
@@ -678,8 +735,8 @@ TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
     if (node.name() == "i1") {
       EXPECT_EQ("Const", node.op());
       ++found;
-      auto folded = EvaluateNodes(output, {"i1"});
-      auto expected = EvaluateNodes(item.graph, {"i1"});
+      auto folded = EvaluateNodes(output, {"i1"}, {});
+      auto expected = EvaluateNodes(item.graph, {"i1"}, {});
       EXPECT_EQ(1, expected.size());
       EXPECT_EQ(1, folded.size());
       test::ExpectTensorEqual<int>(folded[0], expected[0]);
@@ -689,8 +746,8 @@ TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) {
     if (node.name() == "i2") {
       EXPECT_EQ("Const", node.op());
       ++found;
-      auto folded = EvaluateNodes(output, {"i2"});
-      auto expected = EvaluateNodes(item.graph, {"i2"});
+      auto folded = EvaluateNodes(output, {"i2"}, {});
+      auto expected = EvaluateNodes(item.graph, {"i2"}, {});
       EXPECT_EQ(1, expected.size());
       EXPECT_EQ(1, folded.size());
       test::ExpectTensorEqual<int>(folded[0], expected[0]);
@@ -808,8 +865,8 @@ TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) {
   }
   EXPECT_EQ(8, constant_folded);
 
-  auto expected = EvaluateNodes(item.graph, outputs);
-  auto optimized = EvaluateNodes(output, outputs);
+  auto expected = EvaluateNodes(item.graph, outputs, {});
+  auto optimized = EvaluateNodes(output, outputs, {});
   ASSERT_EQ(expected.size(), optimized.size());
   for (int i = 0; i < expected.size(); ++i) {
     test::ExpectTensorEqual<int>(expected[i], optimized[i]);
@@ -1236,7 +1293,7 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
   EXPECT_EQ(6, found_nodes);
 
   std::vector<string> fetch = {"out1", "idx1"};
-  auto tensors = EvaluateNodes(output, fetch);
+  auto tensors = EvaluateNodes(output, fetch, {});
   EXPECT_EQ(2, tensors.size());
   const Tensor& out_value = tensors[0];
   EXPECT_EQ(3 * 5, out_value.NumElements());
@@ -1891,8 +1948,8 @@ TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) {
     }
 
     std::vector<string> fetch = {"acc0"};
-    auto tensors_expected = EvaluateNodes(item.graph, fetch);
-    auto tensors = EvaluateNodes(output, fetch);
+    auto tensors_expected = EvaluateNodes(item.graph, fetch, {});
+    auto tensors = EvaluateNodes(output, fetch, {});
     EXPECT_EQ(1, tensors_expected.size());
     EXPECT_EQ(1, tensors.size());
     test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
@@ -1926,7 +1983,7 @@ TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
   item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
                 "concat5", "concat6", "concat7", "concat8", "concat9"};
 
-  auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
+  auto tensors_expected = EvaluateNodes(item.graph, {"concat0"}, {});
   EXPECT_EQ(1, tensors_expected.size());
   ConstantFolding optimizer(nullptr /* cpu_device */);
   GraphDef output;
@@ -1977,7 +2034,7 @@ TEST_F(ConstantFoldingTest, PartialFolding_Concat) {
     }
   }
 
-  auto tensors = EvaluateNodes(output, {"concat0"});
+  auto tensors = EvaluateNodes(output, {"concat0"}, {});
   EXPECT_EQ(1, tensors.size());
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
@@ -2075,8 +2132,8 @@ TEST_F(ConstantFoldingTest, TrivialPack) {
   }
 
   std::vector<string> fetch = {"stack"};
-  auto tensors_expected = EvaluateNodes(item.graph, fetch);
-  auto tensors = EvaluateNodes(output, fetch);
+  auto tensors_expected = EvaluateNodes(item.graph, fetch, {});
+  auto tensors = EvaluateNodes(output, fetch, {});
   EXPECT_EQ(1, tensors_expected.size());
   EXPECT_EQ(1, tensors.size());
   EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
index 52a1118..deb2fab 100644 (file)
@@ -414,8 +414,9 @@ TEST_F(FunctionOptimizerTest, SymbolicGradients) {
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
 
-  std::vector<Tensor> expected = EvaluateNodes(item.graph, {"out1", "out2"});
-  std::vector<Tensor> optimized = EvaluateNodes(output, {"out1", "out2"});
+  std::vector<Tensor> expected =
+      EvaluateNodes(item.graph, {"out1", "out2"}, {});
+  std::vector<Tensor> optimized = EvaluateNodes(output, {"out1", "out2"}, {});
   test::ExpectTensorEqual<float>(expected[0], optimized[0]);
   test::ExpectTensorEqual<float>(expected[1], optimized[1]);
 }
@@ -478,8 +479,8 @@ TEST_F(FunctionOptimizerTest, SymbolicGradientsIdentity) {
     EXPECT_EQ("Identity", output.node(i).op());
   }
 
-  std::vector<Tensor> expected = EvaluateNodes(item.graph, {"out"});
-  std::vector<Tensor> optimized = EvaluateNodes(output, {"out"});
+  std::vector<Tensor> expected = EvaluateNodes(item.graph, {"out"}, {});
+  std::vector<Tensor> optimized = EvaluateNodes(output, {"out"}, {});
   test::ExpectTensorEqual<float>(expected[0], optimized[0]);
 }
 
index 9595936..a1f8080 100644 (file)
@@ -426,7 +426,7 @@ TEST_F(MemoryOptimizerTest, AccumulationRewrites) {
   EXPECT_EQ(4, count);
 
   std::vector<string> fetch = {"a", "b", "c", "e"};
-  auto tensors = EvaluateNodes(output, fetch);
+  auto tensors = EvaluateNodes(output, fetch, {});
   EXPECT_EQ(4, tensors.size());
 
   for (int i = 0; i < tensors[0].NumElements(); ++i) {
index 829bfe9..86a6d50 100644 (file)
@@ -33,8 +33,8 @@ namespace {
 template <typename T>
 bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
   using RealType = typename Eigen::NumTraits<T>::Real;
-  if (value > std::numeric_limits<RealType>::max() ||
-      value < std::numeric_limits<RealType>::min()) {
+  if (value > static_cast<double>(std::numeric_limits<RealType>::max()) ||
+      value < static_cast<double>(std::numeric_limits<RealType>::min())) {
     return false;
   }
   tensor->flat<T>()(0) = static_cast<T>(value);
@@ -473,8 +473,8 @@ Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
         "Expected scalar tensor, got num_elements = ", tensor->NumElements());
   }
   switch (dtype) {
-    // TODO(rmlarsen): Handle DT_HALF.
-    //    HANDLE_CASE(DT_HALF);
+    HANDLE_CASE(DT_HALF);
+    HANDLE_CASE(DT_BFLOAT16);
     HANDLE_CASE(DT_BOOL);
     HANDLE_CASE(DT_FLOAT);
     HANDLE_CASE(DT_DOUBLE);
index ee126f4..5c96359 100644 (file)
@@ -40,12 +40,13 @@ GrapplerTest::GrapplerTest() {
 }
 
 std::vector<Tensor> GrapplerTest::EvaluateNodes(
-    const GraphDef& graph, const std::vector<string>& node_names) const {
+    const GraphDef& graph, const std::vector<string>& node_names,
+    const std::vector<std::pair<string, Tensor>>& inputs) const {
   std::unique_ptr<tensorflow::Session> session(NewSession(options_));
   TF_CHECK_OK(session->Create(graph));
   RunOptions run_options;
   std::vector<Tensor> output_tensors;
-  TF_CHECK_OK(session->Run(run_options, {}, node_names, node_names,
+  TF_CHECK_OK(session->Run(run_options, inputs, node_names, node_names,
                            &output_tensors, nullptr));
   TF_CHECK_OK(session->Close());
   return output_tensors;
index e0c6738..4b160e7 100644 (file)
@@ -35,7 +35,8 @@ class GrapplerTest : public ::testing::Test {
 
  protected:
   std::vector<Tensor> EvaluateNodes(
-      const GraphDef& graph, const std::vector<string>& node_names) const;
+      const GraphDef& graph, const std::vector<string>& node_names,
+      const std::vector<std::pair<string, Tensor>>& inputs) const;
 
   std::vector<Tensor> EvaluateFetchNodes(const GrapplerItem& item) const;