[feat]add udf for label (#303)
* add udf for label
dawn310826 authored Nov 25, 2022
1 parent 426cfef commit 56b6f93
Showing 7 changed files with 397 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .git_bin_path
@@ -37,5 +37,5 @@
{"leaf_name": "data/test/movielens_1m", "leaf_file": ["data/test/movielens_1m/ml_test_data", "data/test/movielens_1m/ml_train_data"]}
{"leaf_name": "data/test/mt_ckpt", "leaf_file": ["data/test/mt_ckpt/model.ckpt-100.data-00000-of-00001", "data/test/mt_ckpt/model.ckpt-100.index", "data/test/mt_ckpt/model.ckpt-100.meta"]}
{"leaf_name": "data/test/rtp", "leaf_file": ["data/test/rtp/taobao_fg_pred.out", "data/test/rtp/taobao_test_bucketize_feature.txt", "data/test/rtp/taobao_test_feature.txt", "data/test/rtp/taobao_test_input.txt", "data/test/rtp/taobao_train_bucketize_feature.txt", "data/test/rtp/taobao_train_feature.txt", "data/test/rtp/taobao_train_input.txt", "data/test/rtp/taobao_valid.csv", "data/test/rtp/taobao_valid_feature.txt"]}
{"leaf_name": "data/test/tb_data", "leaf_file": ["data/test/tb_data/taobao_ad_feature_gl", "data/test/tb_data/taobao_clk_edge_gl", "data/test/tb_data/taobao_multi_seq_test_data", "data/test/tb_data/taobao_multi_seq_train_data", "data/test/tb_data/taobao_noclk_edge_gl", "data/test/tb_data/taobao_test_data", "data/test/tb_data/taobao_test_data_compress.gz", "data/test/tb_data/taobao_test_data_for_expr", "data/test/tb_data/taobao_test_data_kd", "data/test/tb_data/taobao_train_data", "data/test/tb_data/taobao_train_data_for_expr", "data/test/tb_data/taobao_train_data_kd", "data/test/tb_data/taobao_user_profile_gl"]}
{"leaf_name": "data/test/tb_data", "leaf_file": ["data/test/tb_data/taobao_ad_feature_gl", "data/test/tb_data/taobao_clk_edge_gl", "data/test/tb_data/taobao_multi_seq_test_data", "data/test/tb_data/taobao_multi_seq_train_data", "data/test/tb_data/taobao_noclk_edge_gl", "data/test/tb_data/taobao_test_data", "data/test/tb_data/taobao_test_data_compress.gz", "data/test/tb_data/taobao_test_data_for_expr", "data/test/tb_data/taobao_test_data_kd", "data/test/tb_data/taobao_test_data_remap_label", "data/test/tb_data/taobao_train_data", "data/test/tb_data/taobao_train_data_for_expr", "data/test/tb_data/taobao_train_data_kd", "data/test/tb_data/taobao_train_data_remap_label", "data/test/tb_data/taobao_user_profile_gl"]}
{"leaf_name": "data/test/tb_data_with_time", "leaf_file": ["data/test/tb_data_with_time/taobao_test_data_with_time", "data/test/tb_data_with_time/taobao_train_data_with_time"]}
2 changes: 1 addition & 1 deletion .git_bin_url
@@ -37,5 +37,5 @@
{"leaf_path": "data/test/movielens_1m", "sig": "99badbeec64f2fcabe0dfa1d2bfd8fb5", "remote_path": "data/git_oss_sample_data/data_test_movielens_1m_99badbeec64f2fcabe0dfa1d2bfd8fb5"}
{"leaf_path": "data/test/mt_ckpt", "sig": "803499f48e2df5e51ce5606e9649c6d4", "remote_path": "data/git_oss_sample_data/data_test_mt_ckpt_803499f48e2df5e51ce5606e9649c6d4"}
{"leaf_path": "data/test/rtp", "sig": "76cda60582617ddbb7cd5a49eb68a4b9", "remote_path": "data/git_oss_sample_data/data_test_rtp_76cda60582617ddbb7cd5a49eb68a4b9"}
{"leaf_path": "data/test/tb_data", "sig": "126c375d6aa666633fb3084aa27ff9f7", "remote_path": "data/git_oss_sample_data/data_test_tb_data_126c375d6aa666633fb3084aa27ff9f7"}
{"leaf_path": "data/test/tb_data", "sig": "f1279ca42de1734be321e88f85775d5f", "remote_path": "data/git_oss_sample_data/data_test_tb_data_f1279ca42de1734be321e88f85775d5f"}
{"leaf_path": "data/test/tb_data_with_time", "sig": "1a7648f4ae55faf37855762bccbb70cc", "remote_path": "data/git_oss_sample_data/data_test_tb_data_with_time_1a7648f4ae55faf37855762bccbb70cc"}
30 changes: 30 additions & 0 deletions docs/source/feature/data.md
@@ -45,6 +45,9 @@ input_fields field:
- If the input is of type INT32 and the default value is 6, then default_val is "6";
- If the input is of type DOUBLE and the default value is 0.5, then default_val is "0.5";
- input_dim: currently only applies to RawFeature; it allows specifying multi-dimensional data, such as an image embedding vector.
- user_define_fn: currently only applies to labels; specifies the name of a user-defined function used to process the label, as in the examples below.
- user_define_fn_path: required when the user-defined function is loaded from oss/hdfs; specifies the path of the function file.
- user_define_fn_res_type: specifies the output value type of the user-defined function.

```protobuf
input_fields: {
}
```

@@ -54,6 +57,33 @@

```protobuf
input_fields {
input_name:'clk'
input_type: DOUBLE
user_define_fn: 'tf.math.log1p'
}
```
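When user_define_fn is a TensorFlow function such as tf.math.log1p and no user_define_fn_path is given, the function is applied to the label tensor directly, so user_define_fn_res_type is not required. A minimal sketch of the transform itself (illustrative only, shown with eager execution):

```python
import tensorflow as tf

# log1p(x) = log(1 + x): a common smoothing transform for count-like labels.
clk = tf.constant([0.0, 1.0, 3.0], dtype=tf.float64)
print(tf.math.log1p(clk))  # approx. [0.0, 0.693, 1.386]
```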

```protobuf
input_fields {
input_name:'clk'
input_type: INT64
user_define_fn: 'remap_lbl'
user_define_fn_path: 'samples/demo_script/process_lbl.py'
user_define_fn_res_type: INT64
}
```

process_lbl.py:

```python
import numpy as np
def remap_lbl(labels):
res = np.where(labels<5, 0, 1)
return res
```
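A quick check of what remap_lbl produces (a sketch; assumes process_lbl.py is importable from the working directory):

```python
import numpy as np

from process_lbl import remap_lbl  # hypothetical import; adjust to where process_lbl.py lives

# Labels below 5 map to 0, all others to 1.
print(remap_lbl(np.array([1, 4, 5, 9])))  # [0 0 1 1]
```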

- **Note:**
  - The order of input_fields does not need to match the order of the columns in the odps table.
  - The order of input_fields must match the order of the columns in the csv file one-to-one (the csv file has no header).
47 changes: 47 additions & 0 deletions easy_rec/python/input/input.py
@@ -6,6 +6,7 @@

import six
import tensorflow as tf
from tensorflow.python.platform import gfile

from easy_rec.python.core import sampler as sampler_lib
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
@@ -72,6 +73,17 @@ def __init__(self,
for x in range(len(self._label_fields) - len(self._label_dim)):
self._label_dim.append(1)

self._label_udf_map = {}
for config in self._data_config.input_fields:
if config.HasField('user_define_fn'):
user_define_fn_path = config.user_define_fn_path if config.HasField(
'user_define_fn_path') else None
user_define_fn_res_type = config.user_define_fn_res_type if config.HasField(
'user_define_fn_res_type') else None
self._label_udf_map[config.input_name] = (config.user_define_fn,
user_define_fn_path,
user_define_fn_res_type)

self._batch_size = data_config.batch_size
self._prefetch_size = data_config.prefetch_size
self._feature_configs = list(feature_configs)
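For the remap_lbl example from the docs above, _label_udf_map would end up roughly like this (a sketch; the result type is stored as the FieldType enum value, written here by name):

```python
# input_name -> (user_define_fn, user_define_fn_path or None, user_define_fn_res_type or None)
_label_udf_map = {
    'clk': ('remap_lbl', 'samples/demo_script/process_lbl.py', 'INT64'),  # 'INT64' stands in for the enum value
}
```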
@@ -722,6 +734,41 @@ def _preprocess(self, field_dict):
for input_id, input_name in enumerate(self._label_fields):
if input_name not in field_dict:
continue
if input_name in self._label_udf_map:
udf_class, udf_path, dtype = self._label_udf_map[input_name]
if udf_path:
assert dtype is not None, 'must set user_define_fn_res_type'
if udf_path.startswith('oss://') or udf_path.startswith('hdfs://'):
with gfile.GFile(udf_path, 'r') as fin:
udf_content = fin.read()
final_udf_tmp_path = '/udf/'
final_udf_path = final_udf_tmp_path + udf_path.split('/')[-1]
logging.info('final udf path %s' % final_udf_path)
logging.info('udf content: %s' % udf_content)
if not gfile.Exists(final_udf_tmp_path):
gfile.MkDir(final_udf_tmp_path)
with gfile.GFile(final_udf_path, 'w') as fin:
fin.write(udf_content)
else:
final_udf_path = udf_path
final_udf_path = final_udf_path[:-3].replace('/', '.')
final_udf_path = final_udf_path + '.' + udf_class
logging.info('apply udf %s' % final_udf_path)
udf = load_by_path(final_udf_path)
field_dict[input_name] = tf.py_func(
udf, [field_dict[input_name]], Tout=get_tf_type(dtype))
field_dict[input_name].set_shape(tf.TensorShape([None]))
else:
logging.info('apply udf %s' % udf_class)
udf = load_by_path(udf_class)
if udf_class.split('.')[0] in ['tf', 'tensorflow']:
field_dict[input_name] = udf(field_dict[input_name])
else:
assert dtype is not None, 'must set user_define_fn_res_type'
field_dict[input_name] = tf.py_func(
udf, [field_dict[input_name]], Tout=get_tf_type(dtype))
field_dict[input_name].set_shape(tf.TensorShape([None]))

if field_dict[input_name].dtype == tf.string:
if self._label_dim[input_id] > 1:
logging.info('will split labels[%d]=%s' % (input_id, input_name))
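For labels that are not handled by a built-in TF function, the code above boils down to a dotted-path lookup plus a tf.py_func wrapper. A minimal standalone sketch of that pattern (not the project's actual helpers; load_by_path and get_tf_type are assumed to do the dotted-name resolution and the FieldType-to-tf.DType mapping, respectively):

```python
import importlib

import tensorflow as tf


def load_udf(dotted_path):
  # e.g. 'samples.demo_script.process_lbl.remap_lbl' -> import module, fetch attribute.
  module_name, fn_name = dotted_path.rsplit('.', 1)
  return getattr(importlib.import_module(module_name), fn_name)


def apply_label_udf(labels, dotted_path, out_dtype=tf.int64):
  """Wrap a numpy-based label UDF so it can run inside the input pipeline (TF 1.x style)."""
  udf = load_udf(dotted_path)
  out = tf.py_func(udf, [labels], Tout=out_dtype)
  # py_func loses static shape information; restore the expected 1-D shape.
  out.set_shape(tf.TensorShape([None]))
  return out
```

In the diff itself, the same py_func call is used whether the UDF file is local or first copied from oss/hdfs via gfile.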
6 changes: 6 additions & 0 deletions easy_rec/python/protos/dataset.proto
@@ -150,6 +150,12 @@ message DatasetConfig {
optional string default_val = 3;
optional uint32 input_dim = 4 [default=1];
optional uint32 input_shape = 5 [default = 1];
// user-defined function for label. eg: tf.math.log1p, remap_lbl
optional string user_define_fn = 6;
// user-defined function path. eg: /samples/demo_script/process_lbl.py
optional string user_define_fn_path = 7;
// output field type of user-defined function.
optional FieldType user_define_fn_res_type = 8;
}

// set auto_expand_input_fields to true to
6 changes: 6 additions & 0 deletions samples/demo_script/process_lbl.py
@@ -0,0 +1,6 @@
import numpy as np


def remap_lbl(labels):
res = np.where(labels < 5, 0, 1)
return res