From 1893d25a795e29d276ae3484cfe0727eb657e4ad Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Wed, 7 Feb 2018 12:56:06 -0800 Subject: [PATCH] Fix bug and speed up Grappler constant folding Fix bug in and speed up ConstantFolding::CreateNodeDef(): * Fix bug trying to store more than kint32max values in a repeated proto field. * Speed up populating compressed format. Example: tensorflow/python/kernel_tests/large_concat_op_test with size = 2**29+6 goes from ~30 seconds to ~15 seconds. The fraction of time spent in ConstantFolding::CreateNodeDef() goes down from about 35% to about 12%. --- .../core/grappler/optimizers/constant_folding.cc | 34 +++++++++++++--------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 37a4759..1e6f11c 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -808,20 +808,26 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name, // Use the packed representation whenever possible to avoid generating large // graphdefs. Moreover, avoid repeating the last values if they're equal. if (tensor->NumElements() > 4) { -#define POPULATE_TENSOR_PROTO(tensor, t, TYPE, NAME) \ - optimized = true; \ - TYPE last = tensor->flat()(0); \ - int last_index = 0; \ - for (int i = 0; i < tensor->NumElements(); ++i) { \ - TYPE cur = tensor->flat()(i); \ - t->add_##NAME##_val(cur); \ - if (cur != last) { \ - last = cur; \ - last_index = i; \ - } \ - } \ - /* Remove all identical trailing values to save memory.
*/ \ - t->mutable_##NAME##_val()->Truncate(last_index + 1); +#define POPULATE_TENSOR_PROTO(tensor, t, TYPE, NAME) \ + const TYPE* val_ptr = tensor->flat().data(); \ + TYPE last = *val_ptr; \ + int64 last_index = 0; \ + for (int64 i = 0; i < tensor->NumElements(); ++i) { \ + TYPE cur = *val_ptr++; \ + if (cur != last) { \ + last = cur; \ + last_index = i; \ + } \ + } \ + if (last_index < kint32max) { \ + optimized = true; \ + t->mutable_##NAME##_val()->Reserve(last_index + 1); \ + t->mutable_##NAME##_val()->AddNAlreadyReserved(last_index + 1); \ + val_ptr = tensor->flat().data(); \ + for (int64 i = 0; i <= last_index; ++i) { \ + t->set_##NAME##_val(i, *val_ptr++); \ + } \ + } if (tensor->dtype() == DT_FLOAT) { POPULATE_TENSOR_PROTO(tensor, t, float, float) -- 2.7.4