if self._saver:
return self._saver
- savers = ops.get_collection(ops.GraphKeys.SAVERS)[0]
+ savers = ops.get_collection(ops.GraphKeys.SAVERS)
if not savers:
return None
if not isinstance(savers, list):
return savers
- assert len(savers) == 1, 'Only one saver supported.'
+ if len(savers) > 1:
+ logging.error(
+ 'Multiple savers in the SAVERS collection. On-demand checkpointing '
+ 'will be disabled. Pass an explicit `saver` to the constructor to '
+ 'override this behavior.'
+ )
+ return None
+
return savers[0]
def after_run(self, run_context, run_values):