-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
make_sample.py
124 lines (107 loc) · 3.82 KB
/
make_sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
验证图片尺寸和分离测试集(5%)和训练集(95%)
初始化的时候使用,有新的图片后,可以把图片放在new目录里面使用。
"""
import os
import time
import datetime
import json
import random
import os.path
import shutil
import pymysql
from PIL import Image
def convertjpg(jpgfile, outdir, width=227, height=227):
    """Resize an image and save it to a new path (best effort).

    :param jpgfile: path of the source image.
    :param outdir: destination *file* path (despite the name); Pillow picks
        the output format from its extension.
    :param width: target width in pixels.
    :param height: target height in pixels.

    Any failure (unreadable source, unconvertible mode, unwritable
    destination) is printed and swallowed so a batch run can move on to the
    next image.
    """
    try:
        # `with` releases the source file handle once the pixels are loaded.
        with Image.open(jpgfile) as img:
            # Drop palette/alpha *before* resizing: Pillow always resizes "P"
            # images with NEAREST whatever filter is requested, and JPEG
            # cannot store an alpha channel.
            if img.mode in ("P", "PA", "RGBA", "LA"):
                img = img.convert('RGB')
            new_img = img.resize((width, height), Image.BILINEAR)
        new_img.save(outdir)
    except Exception as e:
        print("图片转换失败", e)
def spilt_train_test(origin_dir, train_dir, test_dir, test_ratio=0.1, seed=None):
    """Randomly move the files of ``origin_dir`` into a test and a train dir.

    ``int(len(files) * test_ratio)`` files go to ``test_dir``; every remaining
    file goes to ``train_dir``, so no regular file is left in ``origin_dir``.
    With the default ratio the split is 9:1 (train:test).

    :param origin_dir: directory holding the samples to split.
    :param train_dir: destination for training files (created if missing).
    :param test_dir: destination for test files (created if missing).
    :param test_ratio: fraction of the files to put in the test set.
    :param seed: optional seed for a reproducible split; ``None`` draws fresh
        randomness from the OS.
    """
    # Only regular files are samples; sub-directories stay where they are.
    # Sorting first makes a seeded shuffle reproducible, because the order
    # returned by os.listdir is arbitrary.
    img_list = sorted(
        name for name in os.listdir(origin_dir)
        if os.path.isfile(os.path.join(origin_dir, name))
    )
    # A private RNG avoids reseeding the module-global `random` state.
    random.Random(seed).shuffle(img_list)
    n_test = int(len(img_list) * test_ratio)

    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)
    for file_name in img_list[:n_test]:
        shutil.move(os.path.join(origin_dir, file_name),
                    os.path.join(test_dir, file_name))
    # img_list[n_test:] is exactly the complement of the test slice, so every
    # file is moved somewhere.
    for file_name in img_list[n_test:]:
        shutil.move(os.path.join(origin_dir, file_name),
                    os.path.join(train_dir, file_name))
def get_date_list(start=None, end=None):
    """Return the 'YYYY-MM-DD' dates in the half-open interval (start, end].

    The start date itself is NOT included, the end date is. An empty list is
    returned when ``end`` is not after ``start``.

    :param start: lower bound, 'YYYY-MM-DD' (exclusive).
    :param end: upper bound, 'YYYY-MM-DD' (inclusive).
    """
    fmt = '%Y-%m-%d'
    first_day = datetime.datetime.strptime(start, fmt)
    last_day = datetime.datetime.strptime(end, fmt)
    # Both values are midnight, so the difference is a whole number of days.
    span = (last_day - first_day).days
    return [
        (first_day + datetime.timedelta(days=offset)).strftime(fmt)
        for offset in range(1, span + 1)
    ]
def get_label(date, typeid):
    """Fetch the OCR ``(result, savedir)`` rows for one type id and one day.

    :param date: day prefix 'YYYY-MM-DD', matched against the ``time`` column
        with ``LIKE '<date>%'``.
    :param typeid: type id to filter on.
    :returns: the rows from ``fetchall()``, or an empty tuple if the query
        fails (the error is printed). A failure to connect still raises.
    """
    # NOTE: connection settings are placeholders; fill them in before running.
    conn = pymysql.connect(host='ip',
                           port=3306,
                           user='**',
                           password='**',
                           db='**',
                           charset='utf8')
    # Values are passed as query parameters so the driver escapes them,
    # instead of being spliced into the SQL text.
    sql = ("SELECT result,savedir FROM new_ocr_dir "
           "WHERE typeid = %s AND time LIKE %s")
    result = ()
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, (typeid, "{0}%".format(date)))
            result = cursor.fetchall()
    except Exception as e:
        print("查询数据库失败:{0}".format(e))
    finally:
        conn.close()
    return result
def set_label(label, dir, typid, id, date):
    """Write a 227x227 copy of an image with its label encoded in the name.

    The output path is ``data/<typid>_<date>_<id>_<label>.jpg``; the resize
    and save are delegated to ``convertjpg``.

    :param label: label text of the sample.
    :param dir: path of the source image.
    :param typid: type id of the sample.
    :param id: sample index, keeps file names unique within a type and date.
    :param date: date string the sample belongs to.
    """
    # Four arguments, so the placeholders must be {0}..{3}.
    outdir = "data/{0}_{1}_{2}_{3}.jpg".format(typid, date, id, label)
    convertjpg(dir, outdir, 227, 227)
def solve_lable_dir(label_dir):
    """Split database rows into parallel lists of labels and image paths.

    Each row's first column is a JSON string whose ``'result'`` key holds the
    label (``''`` when the key is absent). The second column is the saved
    image path with its first 13 characters dropped. Presumably that is a
    fixed storage prefix; TODO confirm.

    :param label_dir: iterable of rows, e.g. as returned by ``get_label``.
    :returns: ``(labels, dirs)``, two lists of equal length.
    """
    prefix_len = 13
    pairs = [
        (json.loads(row[0]).get('result', ''), row[1][prefix_len:])
        for row in label_dir
    ]
    labels = [label for label, _ in pairs]
    dirs = [path for _, path in pairs]
    return labels, dirs
def make_sample(start='2019-10-12', end='2019-12-08', typeids=None):
    """Build the labelled sample set for a date range and a set of type ids.

    For every type id and every date from ``get_date_list(start, end)``, the
    OCR rows are fetched from the database. In each label, ``|`` is replaced
    by ``#`` (presumably to keep it filename-safe; TODO confirm). The image
    is then resized and written by ``set_label``.

    :param start: lower date bound, 'YYYY-MM-DD' (``get_date_list`` excludes it).
    :param end: upper date bound, 'YYYY-MM-DD' (inclusive).
    :param typeids: iterable of type ids to process. Defaults to the
        originally hard-coded list.
    """
    if typeids is None:
        typeids = ['3200', '3060', '3050', '3040', '3000',
                   '2050', '2040', '2000', '1050', '1040']
    dates = get_date_list(start, end)
    for typeid in typeids:
        for date in dates:
            print(date)
            labels, dirs = solve_lable_dir(get_label(date, typeid))
            # The per-date row index doubles as the sample id in the file name.
            for i, (label, img_dir) in enumerate(zip(labels, dirs)):
                set_label(label.replace('|', '#'), img_dir, typeid, i, date)
if __name__ == '__main__':
    # Step 1: build the labelled, resized sample set under data/.
    make_sample()
    # Step 2 (run separately once data/ is populated):
    # spilt_train_test('data','train','test')