Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: create pipeline from dict & support pass easy_rec_config object … #383

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion easy_rec/python/compat/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import os
import threading
import time
from distutils.version import LooseVersion

import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
Expand Down
1 change: 0 additions & 1 deletion easy_rec/python/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ def train_and_evaluate(pipeline_config_path, continue_train=False):
Returns:
None, the model will be saved into pipeline_config.model_dir
"""
assert gfile.Exists(pipeline_config_path), 'pipeline_config_path not exists'
pipeline_config = config_util.get_configs_from_pipeline_file(
pipeline_config_path)

Expand Down
2 changes: 1 addition & 1 deletion easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import threading
import time
import unittest
from distutils.version import LooseVersion

import numpy as np
import six
import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.platform import gfile

from easy_rec.python.main import predict
Expand Down
9 changes: 2 additions & 7 deletions easy_rec/python/utils/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,9 @@ def create_pipeline_proto_from_configs(configs):
Returns:
A fully populated pipeline_pb2.EasyRecConfig.
"""
config_json_str = json.dumps(configs)
pipeline_config = pipeline_pb2.EasyRecConfig()
pipeline_config.model.CopyFrom(configs['model'])
pipeline_config.train_config.CopyFrom(configs['train_config'])
pipeline_config.train_input_reader.CopyFrom(configs['train_input_config'])
pipeline_config.eval_config.CopyFrom(configs['eval_config'])
pipeline_config.eval_input_reader.CopyFrom(configs['eval_input_config'])
if 'graph_rewriter_config' in configs:
pipeline_config.graph_rewriter.CopyFrom(configs['graph_rewriter_config'])
json_format.Parse(config_json_str, pipeline_config)
return pipeline_config


Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ multi_line_output = 7
force_single_line = true
known_standard_library = setuptools
known_first_party = easy_rec
known_third_party = absl,common_io,distutils,docutils,eas_prediction,easyrec_request,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
known_third_party = absl,common_io,distutils,docutils,eas_prediction,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
no_lines_before = LOCALFOLDER
default_section = THIRDPARTY
skip = easy_rec/python/protos
Expand Down