-
Notifications
You must be signed in to change notification settings - Fork 4
/
test_human.py
141 lines (130 loc) · 4.76 KB
/
test_human.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import PIL.Image
def invert_image(input_image, max_value=255):
    """Invert the pixel intensities of an image.

    Parameters
    ----------
    input_image : array-like or PIL.Image.Image
        Image to invert; anything ``np.asarray`` accepts (NumPy array,
        nested list, PIL image).
    max_value : float, optional
        Largest possible pixel value. Each pixel ``p`` becomes
        ``max_value - p``; the default of 255 matches 8-bit images.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as the input.
    """
    # Work in float: subtracting from an unsigned-int (e.g. uint8) array
    # would wrap around, or raise OverflowError under NumPy 2's promotion
    # rules. No rescaling is applied, so arrays and PIL images behave alike.
    pixels = np.asarray(input_image, dtype=float)
    return max_value - pixels
def normalize_image(input_image, max_value=256):
    """Linearly rescale pixel values so they span ``[0, max_value]``.

    Parameters
    ----------
    input_image : array-like or PIL.Image.Image
        Image to normalize; anything ``np.asarray`` accepts.
    max_value : float, optional
        Value the brightest pixel is mapped to (default 256).

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as the input, minimum 0 and maximum
        ``max_value``. A constant image has no range to stretch and is
        returned as all zeros instead of dividing by zero (which produced
        NaNs).
    """
    pixels = np.asarray(input_image, dtype=float)
    pixels = pixels - pixels.min()
    span = pixels.max()
    if span == 0:
        return np.zeros_like(pixels)
    return pixels / span * max_value
def get_image(id_number):
    """Load the ``.bmp`` image for a proposal and return its pixel data.

    The file ``<id_number>.bmp`` is looked up first in ``./tp_images/`` and
    then in ``./fp_images/``.

    Parameters
    ----------
    id_number : str or int
        Proposal id. It is formatted into the file name, so numeric ids read
        from a CSV work as well as strings.

    Returns
    -------
    numpy.ndarray
        Pixel data of the image.

    Raises
    ------
    FileNotFoundError
        If no matching file exists in either directory.
    """
    filename = '{}.bmp'.format(id_number)
    for directory in ('./tp_images', './fp_images'):
        path = Path(directory) / filename
        if path.is_file():
            # PIL opens files lazily and keeps the handle; the context manager
            # closes it once the pixels have been copied into the array.
            with PIL.Image.open(path) as img:
                return np.array(img)
    raise FileNotFoundError(
        'Error: No file associated with {}'.format(id_number))
def remove_ticks(ax_obj):
    """Hide all ticks and tick labels on a Matplotlib axes object.

    Parameters
    ----------
    ax_obj : matplotlib.axes.Axes
        Axes to modify in place.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax_obj``, so callers can write ``ax = remove_ticks(ax)``.
    """
    # Real booleans are required: the legacy 'on'/'off' strings are no longer
    # interpreted as booleans, and a non-empty string is truthy, so 'off'
    # would leave the ticks visible.
    ax_obj.tick_params(
        axis='both',
        which='both',
        bottom=False,
        top=False,
        left=False,
        right=False,
        labelbottom=False,
        labeltop=False,
        labelleft=False,
        labelright=False,
    )
    return ax_obj
def show_example(id_number):
    """Display one example, natural and colour-inverted, side by side.

    Two figures are shown one after the other, first a taller "zoomed in"
    view and then a smaller "actual size" view. Each figure has the natural
    image on the left and the inverted image on the right, with ticks
    hidden. Each ``plt.show()`` call blocks until its figure is closed.
    """
    picture = get_image(id_number)

    def _show_pair(height, tighten):
        # Build one two-panel figure of the given height and display it.
        _, axes = plt.subplots(nrows=1, ncols=2, figsize=(2, height))
        left, right = axes
        left.set_title('natural')
        left.imshow(picture, cmap='Greys')
        remove_ticks(left)
        right.set_title('inverted')
        right.imshow(invert_image(picture), cmap='Greys')
        remove_ticks(right)
        if tighten:
            plt.tight_layout()
        plt.show()

    _show_pair(56 / 80, tighten=True)    # "zoomed in"
    _show_pair(28 / 80, tighten=False)   # "actual size"
def train(proposals):
    """Interactively show labelled random examples so the user can practise.

    Each round draws one proposal at random (without replacement), says
    whether it is a true crater (``crater == 1``) or not (``crater == 0``),
    displays it, and asks whether to continue. The loop stops when the user
    answers ``n`` or when no proposals are left.

    Parameters
    ----------
    proposals : pandas.DataFrame
        Candidate proposals with at least ``id`` and ``crater`` columns.

    Returns
    -------
    pandas.DataFrame
        ``proposals`` without the examples shown here, so the caller can
        keep them out of the test set.
    """
    # Iterative rather than recursive, so a long session cannot hit the
    # recursion limit; the length check avoids sample(1) raising ValueError.
    while len(proposals) > 0:
        sample = proposals.sample(1)
        proposals = proposals.drop(sample.index, axis=0)
        row = sample.iloc[0]
        if row['crater'] == 1:
            print('This is an example of a positive candidate (true crater).')
        elif row['crater'] == 0:
            print('This is an example of a negative candidate (false crater).')
        print('press q to cycle through images')
        show_example(row['id'])
        response = None
        while response not in ['y', 'n']:
            response = input('Would you like to see another example? (y/n)')
        if response == 'n':
            break
    else:
        # Reached only when the loop ran out of proposals (no 'n' break).
        print('No more examples to show.')
    return proposals
def get_result(id_number):
    """Show one example and ask the user whether it is a crater.

    The example is displayed and the user is prompted:

    * ``y`` returns 1 (crater), ``n`` returns 0 (not a crater);
    * ``a`` shows the example again and re-prompts;
    * ``q`` exits the program;
    * anything else prints an error, shows the example again, and re-prompts.

    Returns
    -------
    int
        1 for "crater", 0 for "not a crater".

    Raises
    ------
    SystemExit
        When the user enters ``q``.
    """
    # A loop instead of self-recursion: repeated 'a' or invalid answers no
    # longer grow the call stack.
    while True:
        print('showing example... (press q to cycle through images)')
        show_example(id_number)
        result = input('Is it a crater? (y/n) (q to quit, a to see again)')
        if result == 'y':
            return 1
        if result == 'n':
            return 0
        if result == 'q':
            # sys.exit is always available; the quit() builtin is injected
            # by the site module and may be missing (e.g. python -S).
            sys.exit()
        if result != 'a':
            print('sorry, invalid input. must be: y, n, q, or a.')
def main():
    """Run the interactive crater-classification test.

    Loads ``proposals.csv``, lets the user browse labelled training examples,
    then asks for a label on each remaining candidate in random order. All
    answers so far are written to ``<name>.csv`` after every response, so
    progress is kept if the user quits with ``q``.
    """
    print('Welcome to the crater identification test!')
    print('The purpose of this program is to test the human ability to classify craters from non-crater proposals.')
    print('You will first be shown as many examples as you wish. You will then be prompted to begin classifying crater candidates.')
    input('OK (press enter)')
    proposals = pd.read_csv('proposals.csv')
    columns = list(proposals.columns) + ['prediction']
    # train() returns the proposals it did not show; keep only those so the
    # labelled training examples are not reused as test items.
    proposals = train(proposals)
    save_path = input('What filename would you like to save the results with? (exclude extension)')
    records = []
    # Stop once every candidate is used; sample(1) on an empty frame raises.
    while len(proposals) > 0:
        sample = proposals.sample(1)
        proposals = proposals.drop(sample.index, axis=0)
        record = sample.iloc[0].to_dict()
        record['prediction'] = get_result(record['id'])
        records.append(record)
        results = pd.DataFrame(records, columns=columns)
        print('{} results recorded so far.'.format(len(results)))
        # Save after every answer so quitting does not lose progress.
        results.to_csv('{}.csv'.format(save_path), index=False)
    print('All candidates have been classified.')


# '==' rather than 'in': a substring test would also fire for modules named
# e.g. 'main' or '__' when imported.
if __name__ == '__main__':
    main()