-
Notifications
You must be signed in to change notification settings - Fork 355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
r1.15.5-deeprec2302 incr ev 在 restore过程中不能正确加载 #999
Comments
A temporary fix based on deeprec2302.
This issue has already been fixed in release deeprec2402.
使用了partitioner后,问题仍然存在,用下面的代码可以复现 (2302版本) 测试代码说明: 测试模型经过设计使得它具有以下几个特点(具体实现方法参见model_fn函数)
训练的key依次为: 20步0, 40步1, 30步2, 10步3。训练命令: `python3 <script>.py train`;评测命令: `python3 <script>.py eval --value <key>`。修改checkpoint_dir/checkpoint文件,可以分别评测增量checkpoint和全量checkpoint。代码中的其它内容主要是为了让DeepRec能够在使用estimator api时也能正确生成、加载增量checkpoint。
#!/usr/bin/env python3
import argparse
import functools
import os.path
import time
import tensorflow as tf
import numpy
# Incremental-checkpoint intervals injected by patch_incr_ckpt(); None means
# "do not override the caller's setting". A `global` statement at module level
# is a no-op and binds nothing, so the names are bound here explicitly to keep
# _create_session from raising NameError if it runs before patch_incr_ckpt().
_incr_ckpt_secs = None
_incr_ckpt_steps = None
class DelayHook(tf.train.SessionRunHook):
    """Session hook that pauses for 20 ms after every training step.

    Presumably used to stretch training wall-clock time so the 1-second
    incremental checkpoint saver gets a chance to fire between full saves.
    """

    def after_run(self, run_context, run_values):
        del run_context, run_values  # not needed; only the delay matters
        time.sleep(0.02)
def get_ev_option():
    """Return the EmbeddingVariableOption used by the test model.

    New embedding entries are initialised to the constant 1; no feature
    filter and no eviction policy are configured.
    """
    return tf.EmbeddingVariableOption(
        init_option=tf.InitializerOption(
            initializer=tf.constant_initializer(1)),
        filter_option=None,
        evict_option=None,
    )
def model_fn(features, labels, mode, params):
    """Estimator model_fn for the incremental-checkpoint restore repro.

    The model is a single embedding variable ``embedding_table`` with
    ``embedding_dim=1`` and int64 keys. New entries are initialised to 1 by
    ``get_ev_option``, and the table is built with
    ``fixed_size_partitioner(num_shards=1)``.

    The loss is ``reduce_mean(reduce_mean(embedding[x], axis=1) - labels)``.
    Its gradient with respect to each looked-up entry is positive. With the
    one-key-per-step batches of ``train_input_fn``, each SGD step
    (learning rate 0.01) therefore lowers that key's value by 0.01. Labels
    are all zero in both input functions, so the evaluation loss equals the
    stored value of the key being evaluated.

    Args:
      features: dict whose ``'x'`` entry holds the int64 keys to look up.
      labels: float32 targets subtracted inside the loss.
      mode: a ``tf.estimator.ModeKeys`` value. Only EVAL is special-cased;
        every other mode builds the training spec.
      params: unused.

    Returns:
      A ``tf.estimator.EstimatorSpec``. In EVAL it carries only the loss and
      scaffold. Otherwise it carries an SGD train op, a chief-only
      checkpoint-saver hook, a per-step logging hook and ``DelayHook``.
    """
    id_ = features['x']
    # The partitioner is deliberate: the issue still reproduces with a
    # partitioned embedding variable.
    weights = tf.get_embedding_variable(
        name='embedding_table', embedding_dim=1,
        value_dtype=tf.float32,
        ev_option=get_ev_option(), key_dtype=tf.int64,
        partitioner=tf.fixed_size_partitioner(num_shards=1))
    x = tf.nn.embedding_lookup(weights, id_)
    y = tf.reduce_mean(x, axis=1)
    loss = tf.reduce_mean(y - labels)
    # Sharded saver with DeepRec incremental save/restore enabled. The same
    # scaffold is used for both evaluation and training.
    saver = tf.train.Saver(
        sharded=True, incremental_save_restore=True,
        save_relative_paths=True)
    scaffold = tf.train.Scaffold(saver=saver, incremental_save_restore=True)
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss, scaffold=scaffold)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    # Incremental checkpoint every second, full checkpoint every 50 steps.
    # NOTE(review): checkpoint_dir is hard-coded; presumably it has to match
    # the Estimator's model_dir ('checkpoint_dir' in main()).
    saver_hook = tf.train.CheckpointSaverHook(
        incremental_save_secs=1, checkpoint_dir='checkpoint_dir',
        save_steps=50, scaffold=scaffold, listeners=[])
    log_hook = tf.train.LoggingTensorHook(
        {'loss': loss, 'step': tf.train.get_or_create_global_step()},
        every_n_iter=1)
    minimize = optimizer.minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(
        mode, loss=loss, train_op=minimize,
        training_chief_hooks=[saver_hook], scaffold=scaffold,
        training_hooks=[log_hook, DelayHook()])
def train_input_fn():
    """Training input: a fixed sequence of single-key examples.

    One example per step, in this order: key 0 for 20 steps, key 1 for 40,
    key 2 for 30 and key 3 for 10 (100 examples in total). Every label is 0.
    """
    schedule = ((0, 20), (1, 40), (2, 30), (3, 10))

    def generator():
        for key, repeats in schedule:
            for _ in range(repeats):
                yield ({'x': numpy.array([key], dtype=numpy.int64)},
                       numpy.zeros([1], dtype=numpy.float32))

    types = ({'x': tf.int64}, tf.float32)
    shapes = ({'x': tf.TensorShape([None])}, tf.TensorShape([None]))
    return tf.data.Dataset.from_generator(
        generator, output_types=types, output_shapes=shapes)
def eval_input_fn(value):
    """Evaluation input: ten identical examples that look up ``value`` twice.

    Args:
      value: embedding key to evaluate. Each example's ``'x'`` is
        ``[value, value]`` with label ``[0, 0]``.
    """
    def generator():
        ids = [value, value]
        for _ in range(10):
            yield ({'x': numpy.array(ids, dtype=numpy.int64)},
                   numpy.zeros([2], dtype=numpy.float32))

    types = ({'x': tf.int64}, tf.float32)
    shapes = ({'x': tf.TensorShape([None])}, tf.TensorShape([None]))
    return tf.data.Dataset.from_generator(
        generator, output_types=types, output_shapes=shapes)
def _patch_session_creator(checkpoint_dir):
    """Route ChiefSessionCreator through ``_session_creator``.

    Replaces ``monitored_session.ChiefSessionCreator`` with a partial of
    ``_session_creator`` that defaults ``checkpoint_dir``. It also installs
    the evaluate_and_export and recover_session patches.

    The genuine class is saved as ``ChiefSessionCreator__`` only once.
    Calling this function again therefore just updates ``checkpoint_dir``.
    Previously a second call overwrote the saved original with the partial
    itself, and ``_session_creator`` then recursed forever.

    Args:
      checkpoint_dir: directory to restore from when the caller passes none.
    """
    tf.logging.info('Patching monitored_session.ChiefSessionCreator')
    from tensorflow.python.training import monitored_session
    if not hasattr(monitored_session, 'ChiefSessionCreator__'):
        monitored_session.ChiefSessionCreator__ = (
            monitored_session.ChiefSessionCreator)
    monitored_session.ChiefSessionCreator = functools.partial(
        _session_creator, checkpoint_dir=checkpoint_dir)
    _patch_evaluate_and_export()
    _patch_evaluate_recover_session()
def _create_session(*args, **kwargs):
    """Replacement for ``MonitoredTrainingSession`` installed by patch_incr_ckpt.

    Fills in ``save_incremental_checkpoint_secs`` / ``_steps`` from the
    module-level intervals when those are set and the caller left the
    argument missing or None. Then delegates to the original saved as
    ``tf.train.MonitoredTrainingSession__``.
    """
    injected_defaults = (
        ('save_incremental_checkpoint_secs', _incr_ckpt_secs),
        ('save_incremental_checkpoint_steps', _incr_ckpt_steps),
    )
    for arg_name, interval in injected_defaults:
        if interval is not None and kwargs.get(arg_name) is None:
            kwargs[arg_name] = interval
    tf.logging.info("Creating MonitoredTrainingSession, %s, %s", args, kwargs)
    return tf.train.MonitoredTrainingSession__(*args, **kwargs)
def _session_creator(**kwargs):
    """Replacement for ``monitored_session.ChiefSessionCreator``.

    Forwards all keyword arguments to the original class (saved as
    ``ChiefSessionCreator__``), with ``checkpoint_filename_with_path``
    forced to None. Presumably this makes restores go through
    ``checkpoint_dir``, where incremental checkpoints can be found.
    """
    from tensorflow.python.training import monitored_session
    forwarded = dict(kwargs, checkpoint_filename_with_path=None)
    tf.logging.info('Creating ChiefSessionCreator: %s', forwarded)
    return monitored_session.ChiefSessionCreator__(**forwarded)
def patch_incr_ckpt(secs=0, steps=0):
    """Make every MonitoredTrainingSession save incremental checkpoints.

    Replaces ``MonitoredTrainingSession`` in both
    ``tensorflow.python.training.training`` and ``tf.train`` with
    ``_create_session``, which injects the intervals given here.

    Each original is saved as ``MonitoredTrainingSession__`` only once.
    Repeated calls therefore just update the intervals. Previously a second
    call saved ``_create_session`` as the "original", and it then recursed
    forever.

    Args:
      secs: incremental checkpoint interval in seconds; 0, None or negative
        values leave the caller's setting untouched.
      steps: incremental checkpoint interval in steps; same convention.
    """
    global _incr_ckpt_secs
    global _incr_ckpt_steps
    _incr_ckpt_secs = secs if secs and secs > 0 else None
    _incr_ckpt_steps = steps if steps and steps > 0 else None
    tf.logging.info("Patching MonitoredTrainingSession.")
    from tensorflow.python.training import training
    if not hasattr(training, 'MonitoredTrainingSession__'):
        training.MonitoredTrainingSession__ = training.MonitoredTrainingSession
        training.MonitoredTrainingSession = _create_session
    if not hasattr(tf.train, 'MonitoredTrainingSession__'):
        tf.train.MonitoredTrainingSession__ = tf.train.MonitoredTrainingSession
        tf.train.MonitoredTrainingSession = _create_session
def _patch_evaluate_and_export():
    """Install an incremental-checkpoint-aware ``evaluate_and_export``.

    The replacement on ``_TrainingExecutor._Evaluator`` follows the
    estimator's evaluate/export flow. It additionally reads the latest
    checkpoint in ``<model_dir>/.incremental_checkpoint``.

    A pass is skipped only when the checkpoint it would use has the same
    global step as the one recorded last time. That checkpoint is the
    incremental one if any, otherwise the latest full one.

    The previous condition had two problems:
    - It treated step 0 as "no checkpoint".
    - It never skipped when no incremental checkpoint existed, so an
      unchanged full checkpoint was evaluated again on every pass.

    The pre-patch method is kept as ``evaluate_and_export__`` and is not
    overwritten by repeated calls.
    """
    from tensorflow.python.training import checkpoint_management
    from tensorflow_estimator.python.estimator.training import _EvalResult, _EvalStatus, _TrainingExecutor
    from tensorflow.python.framework import ops
    from tensorflow.python.eager import context

    def _ckpt_version(ckpt_path):
        """Return the step suffix of ``<prefix>-<step>``, or None if no path."""
        return int(ckpt_path.split('-')[-1]) if ckpt_path else None

    def evaluate_and_export(self):
        tf.logging.info('custom evaluate_and_export')
        latest_ckpt_path = self._estimator.latest_checkpoint()
        if not latest_ckpt_path:
            self._log_err_msg('Estimator is not trained yet. Will start an '
                              'evaluation when a checkpoint is ready.')
            return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT), []
        # Latest checkpoint under <model_dir>/.incremental_checkpoint, if any.
        with context.graph_mode():
            incremental_dir = os.path.join(self._estimator.model_dir,
                                           '.incremental_checkpoint')
            incremental_ckpt = checkpoint_management.latest_checkpoint(
                incremental_dir)
        base_version = _ckpt_version(latest_ckpt_path)
        incremental_version = _ckpt_version(incremental_ckpt)
        previous_version = _ckpt_version(self._previous_ckpt_path)
        tf.logging.info('now version: %s %s <- %s', base_version,
                        incremental_version, previous_version)
        # Version of the checkpoint this pass would use. It matches the rule
        # used below when recording _previous_ckpt_path. `is not None`
        # keeps step 0 from being mistaken for "missing".
        current_version = (incremental_version
                           if incremental_version is not None
                           else base_version)
        if previous_version is not None and current_version == previous_version:
            self._log_err_msg(
                'No new checkpoint ready for evaluation. Skip the current '
                'evaluation pass as evaluation results are expected to be same '
                'for the same checkpoint.')
            return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT), []
        metrics = self._estimator.evaluate(
            input_fn=self._eval_spec.input_fn,
            steps=self._eval_spec.steps,
            name=self._eval_spec.name,
            checkpoint_path=latest_ckpt_path,
            hooks=self._eval_spec.hooks)
        # _EvalResult validates the metrics.
        eval_result = _EvalResult(
            status=_EvalStatus.EVALUATED,
            metrics=metrics,
            checkpoint_path=latest_ckpt_path)
        is_the_final_export = (
            eval_result.metrics[ops.GraphKeys.GLOBAL_STEP] >=
            self._max_training_steps if self._max_training_steps else False)
        export_results = self._export_eval_result(eval_result,
                                                  is_the_final_export)
        if is_the_final_export:
            tf.logging.debug('Calling exporter with the `is_the_final_export=True`.')
            self._is_final_export_triggered = True
        self._last_warning_time = 0
        self._previous_ckpt_path = incremental_ckpt if incremental_ckpt else latest_ckpt_path
        return eval_result, export_results

    evaluator_cls = _TrainingExecutor._Evaluator
    if not hasattr(evaluator_cls, 'evaluate_and_export__'):
        evaluator_cls.evaluate_and_export__ = evaluator_cls.evaluate_and_export
    evaluator_cls.evaluate_and_export = evaluate_and_export
def _patch_evaluate_recover_session():
    """Replace ``tf.train.SessionManager.recover_session``.

    The replacement builds DeepRec's incremental saver from the
    SessionManager's own ``_saver`` and ``_incremental_save_restore`` and
    passes it to ``_restore_checkpoint``. It always runs the local init op,
    then checks model readiness.

    The original method is kept as ``recover_session__``.
    NOTE(review): compare with DeepRec's stock recover_session to confirm
    exactly which difference fixes incremental restore.
    """
    def recover_session(self,
                        master,
                        saver=None,
                        checkpoint_dir=None,
                        checkpoint_filename_with_path=None,
                        wait_for_checkpoint=False,
                        max_wait_secs=7200,
                        config=None):
        from tensorflow.python.training import incremental_saver
        incr_saver = incremental_saver._get_incremental_saver(
            self._incremental_save_restore, self._saver)
        tf.logging.info("custom recover_session")
        sess, restored = self._restore_checkpoint(
            master,
            saver,
            incr_saver,
            checkpoint_dir=checkpoint_dir,
            checkpoint_filename_with_path=checkpoint_filename_with_path,
            wait_for_checkpoint=wait_for_checkpoint,
            max_wait_secs=max_wait_secs,
            config=config)
        # local_init_op runs unconditionally, even if nothing was restored.
        init_ok, init_msg = self._try_run_local_init_op(sess)
        if not restored:
            # Nothing restored: skip the readiness checks entirely.
            return sess, False
        source = checkpoint_dir or checkpoint_filename_with_path
        if init_ok:
            ready, ready_msg = self._model_ready(sess)
            if ready:
                tf.logging.info("Restored model from %s", source)
                return sess, restored
            tf.logging.info("Restoring model from %s did not make model ready: %s",
                            source, ready_msg)
        else:
            tf.logging.info(
                "Restoring model from %s did not make model ready for local init:"
                " %s", source, init_msg)
        return sess, False

    tf.train.SessionManager.recover_session__ = tf.train.SessionManager.recover_session
    tf.train.SessionManager.recover_session = recover_session
def parse_cmdline(argv=None):
    """Parse the script's command line.

    Args:
      argv: argument list without the program name. When None (the default,
        as before), argparse reads ``sys.argv[1:]``. Passing a list makes
        the function usable from tests and other callers.

    Returns:
      argparse.Namespace with ``mode`` ('train' or 'eval') and ``value``
      (int, default 0), the embedding key fed to the evaluation input.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices=('train', 'eval'))
    parser.add_argument('--value', type=int, default=0)
    return parser.parse_args(argv)
def main():
    """Script entry point.

    ``train`` runs ``tf.estimator.train_and_evaluate``, evaluating key
    ``--value``. ``eval`` runs a single ``Estimator.evaluate`` on key
    ``--value``. Both use the model directory ``checkpoint_dir``.
    """
    args = parse_cmdline()
    tf.logging.set_verbosity(tf.logging.INFO)
    # Install the session/evaluator patches before the Estimator is used.
    patch_incr_ckpt(secs=1)
    model_dir = 'checkpoint_dir'
    _patch_session_creator(model_dir)

    eval_input = functools.partial(eval_input_fn, value=args.value)
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input)
    run_config = tf.estimator.RunConfig(
        model_dir=model_dir,
        tf_random_seed=2020,
        save_summary_steps=1,
        save_checkpoints_steps=50,
        keep_checkpoint_max=20,
        experimental_max_worker_delay_secs=2000,
    )
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

    if args.mode != 'eval':
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
        return
    estimator.evaluate(eval_input)
# Run only when executed as a script. The stray " |" that followed main()
# in the scraped source made the file a SyntaxError.
if __name__ == '__main__':
    main()
经验证,candyzone@261ccfb 能解决这个问题。
System information
Describe the current behavior
restore的时候加载 incremental_ckpt ev变量不能正确加载覆盖base里的ev变量
Describe the expected behavior
正确加载incr ev 覆盖对应的变量
Code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate the problem.
Other info / logs
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
The text was updated successfully, but these errors were encountered: