Also add quantization step node to MODEL_VARIABLES collection.
authorSuharsh Sivakumar <suharshs@google.com>
Mon, 12 Feb 2018 20:50:07 +0000 (12:50 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Feb 2018 20:54:19 +0000 (12:54 -0800)
PiperOrigin-RevId: 185420228

tensorflow/contrib/quantize/python/common.py

index 3a1fa61..9e76549 100644 (file)
@@ -114,7 +114,9 @@ def CreateOrGetQuantizationStep():
           dtype=dtypes.int64,
           initializer=init_ops.zeros_initializer(),
           trainable=False,
-          collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+          collections=[
+              ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES
+          ])
       with g.name_scope(quantization_step_tensor.op.name + '/'):
         # We return the incremented variable tensor. Since this is used in conds
         # for quant_delay and freeze_bn_delay, it will run once per graph