continue;
}
- // Remove Shuffle or Reverse op over scalar values.
- if (use_shape_info &&
- !properties->GetInputProperties(node->name()).empty() &&
- (IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
+ // Remove a Shuffle or Transpose op if it only permutes dimensions of size 1.
+ if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
+ properties->GetInputProperties(node->name()).size() >= 2) {
const auto& shape =
properties->GetInputProperties(node->name())[0].shape();
- // The node is replaceable iff
- // unknown_rank == false && (dim_size == 0 || all dims have size 1)
- bool replaceable = !shape.unknown_rank();
- for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
- replaceable &= shape.dim(j).size() == 1;
+ if (shape.unknown_rank()) {
+ // Not optimizable.
+ continue;
}
- if (replaceable) {
+ const auto& p = properties->GetInputProperties(node->name())[1];
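+ // The rewrite requires the permutation to be a statically known constant.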
+ if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+ Tensor perm(p.dtype(), p.shape());
+ if (!perm.FromProto(p.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ p.value().DebugString());
+ }
+ std::vector<int> permutation;
+ for (int j = 0; j < perm.NumElements(); ++j) {
+ if (perm.dtype() == DT_INT64) {
+ permutation.push_back(perm.vec<int64>()(j));
+ } else {
+ permutation.push_back(perm.vec<int>()(j));
+ }
+ }
+ if (permutation.size() != shape.dim_size()) {
+ // The number of elements in perm must equal the input's dim_size; skip if not.
+ continue;
+ }
+ // The node is replaceable iff
+ // dim_size == 0 || all dims have size 1 ||
+ // all dims with > 1 size are not permuted.
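+ // Example: with shape [1, 4, 2, 1] and perm [3, 1, 2, 0], only the size-1
+ // dims 0 and 3 move, so the transpose leaves the data unchanged.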
+ bool replaceable = true;
+ for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+ replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
+ }
+ if (replaceable) {
+ ReplaceOperationWithIdentity(0, node, optimized_graph);
+ continue;
+ }
+ }
+ }
+
+ // Remove RandomShuffle op if its input is a scalar or its first dimension has size 1.
+ if (use_shape_info && IsRandomShuffle(*node) &&
+ !properties->GetInputProperties(node->name()).empty()) {
+ const auto& shape =
+ properties->GetInputProperties(node->name())[0].shape();
+ // The node is replaceable iff
+ // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
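+ // RandomShuffle permutes only along the first dimension, so such inputs
+ // pass through unchanged.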
+ if (!shape.unknown_rank() &&
+ (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
ReplaceOperationWithIdentity(0, node, optimized_graph);
continue;
}
}
+ // Remove a ReverseV2 op if it only reverses dimensions of size 1.
+ if (use_shape_info && node->op() == "ReverseV2" &&
+ properties->GetInputProperties(node->name()).size() >= 2) {
+ const auto& shape =
+ properties->GetInputProperties(node->name())[0].shape();
+ if (shape.unknown_rank()) {
+ // Not optimizable.
+ continue;
+ }
+ const auto& a = properties->GetInputProperties(node->name())[1];
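+ // The rewrite requires the reverse axes to be a statically known constant.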
+ if (TensorShape::IsValid(a.shape()) && a.has_value()) {
+ Tensor axis(a.dtype(), a.shape());
+ if (!axis.FromProto(a.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ a.value().DebugString());
+ }
+ std::set<int> target_axes;
+ for (int j = 0; j < axis.NumElements(); ++j) {
+ // Axis values can be negative; normalize them into [0, dim_size).
+ // (A valid graph guarantees dim_size > 0 whenever axis is non-empty.)
+ if (axis.dtype() == DT_INT64) {
+ target_axes.insert(
+ (axis.vec<int64>()(j) + shape.dim_size()) % shape.dim_size());
+ } else {
+ target_axes.insert(
+ (axis.vec<int>()(j) + shape.dim_size()) % shape.dim_size());
+ }
+ }
+
+ // The node is replaceable iff
+ // dim_size == 0 || all dims have size 1 ||
+ // all dims with > 1 size are not in target_axes.
+ // (unknown_rank was already ruled out above.)
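+ // Example: reversing shape [1, 2, 4, 1] along axes {0, 3} only touches
+ // size-1 dims, so the output equals the input.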
+ bool replaceable = true;
+ for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+ replaceable &= shape.dim(j).size() == 1 ||
+ target_axes.find(j) == target_axes.end();
+ }
+ if (replaceable) {
+ ReplaceOperationWithIdentity(0, node, optimized_graph);
+ continue;
+ }
+ }
+ }
+
if (use_shape_info && IsSlice(*node) &&
properties->GetInputProperties(node->name()).size() == 3) {
const auto& input = properties->GetInputProperties(node->name())[0];
ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1);
ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2);
- LOG(INFO) << s1.output.size();
- LOG(INFO) << s2.output.size();
ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
GrapplerItem item;
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
-TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) {
+TEST_F(ConstantFoldingTest, TransposeOnSize1DimsRemoval) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
+ DT_FLOAT);
+ Output p1 = ops::Const(scope.WithOpName("p1"), {3, 2, 1, 0}, {4});
+ Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 4, 2, 1}),
+ DT_FLOAT);
+ Output p2 = ops::Const(scope.WithOpName("p2"), {3, 1, 2, 0}, {4});
+ ops::Transpose t1(scope.WithOpName("t1"), in1, p1);
+ ops::Transpose t2(scope.WithOpName("t2").WithControlDependencies({in1}), in2,
+ p2);
+
+ ops::Add out1(scope.WithOpName("out1"), t1, t2);
+
+ GrapplerItem item;
+ item.fetch = {"out1"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("p1", "Const", {}, {}, &want);
+ AddNode("p2", "Const", {}, {}, &want);
+ AddNode("t1", "Transpose", {"in1", "p1"}, {}, &want);
+ AddNode("t2", "Identity",
+ {"in2", AsControlDependency("in1"), AsControlDependency("p2")}, {},
+ &want);
+ AddNode("out1", "Add", {"t1", "t2"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
Output in1 =
test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
}
+TEST_F(ConstantFoldingTest, ReverseOnSize1DimsRemoval) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
+ DT_FLOAT);
+ Output a1 = ops::Const(scope.WithOpName("a1"), {3, 2, 1, 0}, {4});
+ Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 2, 4, 1}),
+ DT_FLOAT);
+ Output a2 = ops::Const(scope.WithOpName("a2"), {0, 3}, {2});
+ ops::Reverse r1(scope.WithOpName("r1"), in1, a1);
+ ops::Reverse r2(scope.WithOpName("r2").WithControlDependencies({in1}), in2,
+ a2);
+
+ ops::Add out1(scope.WithOpName("out1"), r1, r2);
+
+ GrapplerItem item;
+ item.fetch = {"out1"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("a1", "Const", {}, {}, &want);
+ AddNode("a2", "Const", {}, {}, &want);
+ AddNode("r1", "ReverseV2", {"in1", "a1"}, {}, &want);
+ AddNode("r2", "Identity",
+ {"in2", AsControlDependency("in1"), AsControlDependency("a2")}, {},
+ &want);
+ AddNode("out1", "Add", {"r1", "r2"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ConstantFoldingTest, SliceWithSameDimensionRemoval) {
{ // size = {3, 5}
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();