}
};
+template <DataType DTYPE>
+Tensor GetRandomTensor(const TensorShape& shape) {
+ typedef typename EnumToDataType<DTYPE>::Type T;
+ Tensor tensor(DTYPE, shape);
+ tensor.flat<T>() = tensor.flat<T>().random();
+ return tensor;
+}
+
TEST_F(ConstantFoldingTest, SimpleFolding) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
EXPECT_EQ(2, t.tensor_shape().dim(1).size());
}
}
+ auto a_t = GetRandomTensor<DT_FLOAT>(TensorShape({3, 2}));
+ auto b_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
+ auto x_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto y_t = GetRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto bias_t = GetRandomTensor<DT_FLOAT>(TensorShape({2}));
+
+ auto tensors_expected = EvaluateNodes(
+ item.graph, item.fetch,
+ {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
+ EXPECT_EQ(item.fetch.size(), tensors_expected.size());
+ auto tensors = EvaluateNodes(
+ output, item.fetch,
+ {{"x", x_t}, {"y", y_t}, {"a", a_t}, {"b", b_t}, {"bias", bias_t}});
+ EXPECT_EQ(item.fetch.size(), tensors.size());
+ for (int i = 0; i < item.fetch.size(); ++i) {
+ test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
+ }
}
}