};
template <typename T>
-bool AllValuesAre(const TensorProto& tensor, const T& value) {
- // TensorProto represents the content of the tensor in either <type>_val or
- // tensor_content.
- typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
- checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));
- if (!tensor_values->empty()) {
- for (const T& tensor_value : *tensor_values) {
- if (tensor_value != value) {
- return false;
- }
- }
- return true;
+bool AllValuesAre(const TensorProto& proto, const T& value) {
+ Tensor tensor;
+ if (!tensor.FromProto(proto)) {
+ return false;
}
- const auto tensor_content_size = tensor.tensor_content().size();
- if (tensor_content_size > 0) {
- CHECK_EQ(0, tensor_content_size % sizeof(T));
- std::vector<T> raw_values(tensor_content_size / sizeof(T));
- port::CopyToArray(tensor.tensor_content(),
- reinterpret_cast<char*>(raw_values.data()));
- for (int i = 0; i < tensor_content_size / sizeof(T); ++i) {
- if (raw_values[i] != value) {
- return false;
- }
+ auto values = tensor.flat<T>();
+ for (int i = 0; i < tensor.NumElements(); ++i) {
+ if (values(i) != value) {
+ return false;
}
- return true;
}
- return false;
+ return true;
}
// Add new_input as a control input to node if it does not already depend on it.
t->set_dtype(type);
*t->mutable_tensor_shape() = shape;
switch (type) {
- SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
- SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
- SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
- SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64);
- SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
- SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
- SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
- SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
- SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
- SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
- SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
+ case DT_HALF:
+ t->add_half_val(static_cast<Eigen::half>(value).x);
+ break;
+ case DT_BFLOAT16:
+ t->add_half_val(static_cast<bfloat16>(value).value);
+ break;
+ SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
+ SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
+ SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
+ SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64);
+ SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
+ SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
+ SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
+ SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
+ SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
+ SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
+ SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
default:
return errors::InvalidArgument("Unsupported type: ", type);
}
}
const auto dtype = node.attr().at("dtype").type();
switch (dtype) {
- // TODO(rmlarsen): Make DT_HALF case compile.
- // IS_ONES_CASE(DT_HALF);
+ IS_ONES_CASE(DT_HALF);
+ IS_ONES_CASE(DT_BFLOAT16);
IS_ONES_CASE(DT_FLOAT);
IS_ONES_CASE(DT_DOUBLE);
IS_ONES_CASE(DT_COMPLEX64);
}
const auto dtype = node.attr().at("dtype").type();
switch (dtype) {
- // TODO(rmlarsen): Make DT_HALF case compile.
- // IS_ZEROS_CASE(DT_HALF);
+ IS_ZEROS_CASE(DT_HALF);
+ IS_ZEROS_CASE(DT_BFLOAT16);
IS_ZEROS_CASE(DT_FLOAT);
IS_ZEROS_CASE(DT_DOUBLE);
IS_ZEROS_CASE(DT_COMPLEX64);
}
Status ConstantFolding::ReplaceOperationWithConstant(
- double value, const TensorShapeProto& shape, NodeDef* node,
- GraphDef* graph) {
- AttrValue dtype_attr = node->attr().at("T");
+ double value, const AttrValue& dtype_attr, const TensorShapeProto& shape,
+ NodeDef* node, GraphDef* graph) {
AttrValue tensor_attr;
TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(dtype_attr.type(), value,
shape, &tensor_attr));
(is_mul || is_matmul || optimize_zeros_divided_by_y)) {
const PartialTensorShape shp(output_shape);
if (shp.IsFullyDefined()) {
- TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(0, output_shape, node,
- optimized_graph));
+ AttrValue dtype_attr;
+ if (node->op() == "SparseMatMul") {
+ dtype_attr.set_type(DT_FLOAT);
+ } else {
+ dtype_attr = node->attr().at("T");
+ }
+ TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+ 0, dtype_attr, output_shape, node, optimized_graph));
continue;
}
    // Even if an input shape is only partially known, we may know that it
namespace grappler {
namespace {
-class ConstantFoldingTest : public GrapplerTest {};
+class ConstantFoldingTest : public GrapplerTest {
+ protected:
+ template <DataType DTYPE>
+ void SimpleNeutralElementTest() {
+ typedef typename EnumToDataType<DTYPE>::Type T;
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DTYPE,
+ ops::Placeholder::Shape(TensorShape({2, 2})));
+ Tensor zeros_t(DTYPE, TensorShape({2, 2}));
+ Tensor ones_t(DTYPE, TensorShape({2, 2}));
+ Tensor x_t(DTYPE, TensorShape({2, 2}));
+ for (int i = 0; i < 4; ++i) {
+ zeros_t.flat<T>()(i) = T(0);
+ ones_t.flat<T>()(i) = T(1);
+ x_t.flat<T>()(i) = T(i + 1);
+ }
+ Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t);
+ Output ones = ops::Const(s.WithOpName("ones"), ones_t);
+ Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
+ Output mul2 = ops::Mul(s.WithOpName("mul2"), x, ones);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"mul1", "mul2"};
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ LOG(INFO) << output.DebugString();
+ EXPECT_EQ(5, output.node_size());
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ const string& name = node.name();
+ if (name == "mul1") {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^x", node.input(0));
+ EXPECT_EQ("^zeros", node.input(1));
+ } else if (name == "mul2") {
+ EXPECT_EQ("Snapshot", node.op());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^ones", node.input(1));
+ }
+ }
+ auto tensors_expected =
+ EvaluateNodes(item.graph, {"mul1", "mul2"}, {{"x", x_t}});
+ auto tensors = EvaluateNodes(output, {"mul1", "mul2"}, {{"x", x_t}});
+ EXPECT_EQ(2, tensors_expected.size());
+ EXPECT_EQ(2, tensors.size());
+ for (int i = 0; i < 2; ++i) {
+ test::ExpectTensorEqual<T>(tensors_expected[i], tensors[i]);
+ }
+ }
+};
TEST_F(ConstantFoldingTest, SimpleFolding) {
// Build a simple graph with a few trivially prunable ops.
EXPECT_EQ("Const", node_d.op());
std::vector<string> fetch = {"d"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, fetch, {});
+ auto tensors = EvaluateNodes(output, fetch, {});
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
// Check that the result nodes have the expected value.
std::vector<string> fetch = {"c3", "c20"};
- auto tensor_expected = EvaluateNodes(item.graph, fetch);
+ auto tensor_expected = EvaluateNodes(item.graph, fetch, {});
EXPECT_EQ(fetch.size(), tensor_expected.size());
fetch = {"add_child", "mul_child"};
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors = EvaluateNodes(output, fetch, {});
EXPECT_EQ(fetch.size(), tensors.size());
for (int i = 0; i < fetch.size(); i++) {
test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
}
}
+TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) {
+ SimpleNeutralElementTest<DT_HALF>();
+ SimpleNeutralElementTest<DT_BFLOAT16>();
+}
+
TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
// Check that the reciprocals have the expected value.
std::vector<string> fetch = {"cf_half"};
- auto tensor_expected = EvaluateNodes(item.graph, fetch);
+ auto tensor_expected = EvaluateNodes(item.graph, fetch, {});
EXPECT_EQ(fetch.size(), tensor_expected.size());
fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"};
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors = EvaluateNodes(output, fetch, {});
EXPECT_EQ(fetch.size(), tensors.size());
for (int i = 0; i < fetch.size(); i++) {
test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
EXPECT_EQ("Const", new_d.op());
std::vector<string> fetch = {"e", "f"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, fetch, {});
+ auto tensors = EvaluateNodes(output, fetch, {});
EXPECT_EQ(fetch.size(), tensors_expected.size());
EXPECT_EQ(fetch.size(), tensors.size());
for (int i = 0; i < fetch.size(); i++) {
GrapplerItem item;
item.fetch.push_back("e");
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
- auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {});
EXPECT_EQ(1, tensors_expected.size());
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
if (node.name() == "e") {
EXPECT_EQ("Const", node.op());
++found;
- auto folded = EvaluateNodes(output, {"e"});
- auto expected = EvaluateNodes(item.graph, {"e"});
+ auto folded = EvaluateNodes(output, {"e"}, {});
+ auto expected = EvaluateNodes(item.graph, {"e"}, {});
EXPECT_EQ(1, expected.size());
EXPECT_EQ(1, folded.size());
test::ExpectTensorEqual<int>(folded[0], expected[0]);
}
}
EXPECT_EQ(1, found);
- auto tensors = EvaluateNodes(output, item.fetch);
+ auto tensors = EvaluateNodes(output, item.fetch, {});
EXPECT_EQ(1, tensors.size());
test::ExpectTensorEqual<int>(tensors_expected[0], tensors[0]);
}
if (node.name() == "i1") {
EXPECT_EQ("Const", node.op());
++found;
- auto folded = EvaluateNodes(output, {"i1"});
- auto expected = EvaluateNodes(item.graph, {"i1"});
+ auto folded = EvaluateNodes(output, {"i1"}, {});
+ auto expected = EvaluateNodes(item.graph, {"i1"}, {});
EXPECT_EQ(1, expected.size());
EXPECT_EQ(1, folded.size());
test::ExpectTensorEqual<int>(folded[0], expected[0]);
if (node.name() == "i2") {
EXPECT_EQ("Const", node.op());
++found;
- auto folded = EvaluateNodes(output, {"i2"});
- auto expected = EvaluateNodes(item.graph, {"i2"});
+ auto folded = EvaluateNodes(output, {"i2"}, {});
+ auto expected = EvaluateNodes(item.graph, {"i2"}, {});
EXPECT_EQ(1, expected.size());
EXPECT_EQ(1, folded.size());
test::ExpectTensorEqual<int>(folded[0], expected[0]);
}
EXPECT_EQ(8, constant_folded);
- auto expected = EvaluateNodes(item.graph, outputs);
- auto optimized = EvaluateNodes(output, outputs);
+ auto expected = EvaluateNodes(item.graph, outputs, {});
+ auto optimized = EvaluateNodes(output, outputs, {});
ASSERT_EQ(expected.size(), optimized.size());
for (int i = 0; i < expected.size(); ++i) {
test::ExpectTensorEqual<int>(expected[i], optimized[i]);
EXPECT_EQ(6, found_nodes);
std::vector<string> fetch = {"out1", "idx1"};
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors = EvaluateNodes(output, fetch, {});
EXPECT_EQ(2, tensors.size());
const Tensor& out_value = tensors[0];
EXPECT_EQ(3 * 5, out_value.NumElements());
}
std::vector<string> fetch = {"acc0"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, fetch, {});
+ auto tensors = EvaluateNodes(output, fetch, {});
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4",
"concat5", "concat6", "concat7", "concat8", "concat9"};
- auto tensors_expected = EvaluateNodes(item.graph, {"concat0"});
+ auto tensors_expected = EvaluateNodes(item.graph, {"concat0"}, {});
EXPECT_EQ(1, tensors_expected.size());
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
}
}
- auto tensors = EvaluateNodes(output, {"concat0"});
+ auto tensors = EvaluateNodes(output, {"concat0"}, {});
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
}
std::vector<string> fetch = {"stack"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, fetch, {});
+ auto tensors = EvaluateNodes(output, fetch, {});
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());