Updating a test in constant_folding_test.cc that uses a graph with placeholder nodes...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 29 Mar 2018 20:24:38 +0000 (13:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 29 Mar 2018 20:27:56 +0000 (13:27 -0700)
PiperOrigin-RevId: 190976595

tensorflow/core/grappler/optimizers/constant_folding_test.cc

index e0ff9b1..16a19ba 100644 (file)
@@ -82,6 +82,14 @@ class ConstantFoldingTest : public GrapplerTest {
   }
 };
 
+template <DataType DTYPE>
+Tensor GetRandomTensor(const TensorShape& shape) {
+  typedef typename EnumToDataType<DTYPE>::Type T;
+  Tensor tensor(DTYPE, shape);
+  tensor.flat<T>() = tensor.flat<T>().random();
+  return tensor;
+}
+
 TEST_F(ConstantFoldingTest, SimpleFolding) {
   // Build a simple graph with a few trivially prunable ops.
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -371,6 +379,23 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
         EXPECT_EQ(2, t.tensor_shape().dim(1).size());
       }
     }
+    auto a_t = GetRandomTensor<DT_FLOAT>(TensorShape({3, 2}));
+    auto b_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
+    auto x_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+    auto y_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+    auto bias_t = GetRandomTensor<DT_FLOAT>(TensorShape({2}));
+
+    auto tensors_expected = EvaluateNodes(
+        item.graph, item.fetch,
+        {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
+    EXPECT_EQ(item.fetch.size(), tensors_expected.size());
+    auto tensors = EvaluateNodes(
+        output, item.fetch,
+        {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
+    EXPECT_EQ(item.fetch.size(), tensors.size());
+    for (int i = 0; i < item.fetch.size(); ++i) {
+      test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
+    }
   }
 }