-
Notifications
You must be signed in to change notification settings - Fork 0
/
MyClassifier.java
395 lines (344 loc) · 11.5 KB
/
MyClassifier.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
// If an attribute's standard deviation is 0, set its probability density function to 1.
import java.io.*;
import java.util.*;
/**
 * Command-line binary ("yes"/"no") classifier offering Gaussian Naive Bayes
 * and k-nearest-neighbour.
 *
 * <p>Usage: {@code java MyClassifier <training-file> <testing-file> <NB|kNN>}
 * where {@code k} is a positive integer, e.g. {@code 5NN} or {@code 10NN}.
 *
 * <p>Both input files are comma-separated, one example per line. Training rows
 * end with a class label ({@code yes} or anything else, which is read as "no");
 * testing rows hold attribute values only. One prediction ({@code yes} or
 * {@code no}) is printed to standard output per testing row.
 */
public class MyClassifier {

    /** Numeric encoding of the "yes" class in a parsed training row. */
    private static final double LABEL_YES = 1;

    /** Numeric encoding of the "no" class in a parsed training row. */
    private static final double LABEL_NO = 0;

    /**
     * Program entry point: loads both data files and runs the requested classifier.
     *
     * @param args training file path, testing file path, and algorithm name
     *             ({@code NB} or {@code kNN})
     */
    public static void main(String[] args) {
        if (args.length != 3) {
            System.err.println("*Too many/little args" + args.length);
            return;
        }
        String trainingPath = args[0];
        String testingPath = args[1];
        String algo = args[2];

        ArrayList<double[]> trainingList;
        ArrayList<double[]> testingList;
        // try-with-resources closes both scanners, including when the second file is missing.
        try (Scanner scanTr = new Scanner(new File(trainingPath));
                Scanner scanTest = new Scanner(new File(testingPath))) {
            trainingList = makeTrainingList(scanTr);
            testingList = makeTestingList(scanTest);
        } catch (FileNotFoundException e) {
            System.err.println("*Could not open input file: " + e.getMessage());
            return;
        }

        if (trainingList.isEmpty()) {
            // Neither classifier can produce a meaningful answer without training data.
            System.err.println("*No training examples found in " + trainingPath + "*");
            return;
        }

        if (algo.equals("NB")) {
            naiveBayes(trainingList, testingList);
            return;
        }
        int k = parseK(algo);
        if (k != 0) {
            kNearestNeighbour(k, trainingList, testingList);
        } else {
            System.err.println("*Something went wrong: kNN format*");
        }
    }

    /**
     * Classifies every testing example by majority vote among its {@code k}
     * nearest training examples (Euclidean distance) and prints the result.
     * A tied vote is resolved as "yes".
     *
     * @param k        number of neighbours to consult; if the training set is
     *                 smaller than {@code k}, every training example votes
     * @param training training rows; the last element of each row is the class label
     * @param testing  testing rows containing attribute values only
     */
    private static void kNearestNeighbour(int k, ArrayList<double[]> training, ArrayList<double[]> testing) {
        for (double[] testingEx : testing) {
            // Each element is {distance, label} for one of the current k closest rows.
            List<double[]> kNearest = new ArrayList<>();
            for (double[] trainingEx : training) {
                int labelIndex = trainingEx.length - 1;
                // Compare only the attributes both rows actually have, instead of a fixed 8.
                int dims = Math.min(testingEx.length, labelIndex);
                double[] candidate = {
                    euclideanDistance(testingEx, trainingEx, dims),
                    trainingEx[labelIndex]
                };
                if (kNearest.size() < k) {
                    kNearest.add(candidate);
                } else {
                    int furthest = findFurthestIndex(kNearest);
                    // Strict comparison: on equal distance the earlier example is kept.
                    if (candidate[0] < kNearest.get(furthest)[0]) {
                        kNearest.set(furthest, candidate);
                    }
                }
            }

            int numYes = 0;
            int numNo = 0;
            for (double[] neighbour : kNearest) {
                if (neighbour[1] == LABEL_YES) {
                    numYes++;
                } else if (neighbour[1] == LABEL_NO) {
                    numNo++;
                } else {
                    System.out.println("Something is wrong with yes/no counting");
                }
            }
            if (numYes >= numNo) {
                System.out.println("yes");
            } else {
                System.out.println("no");
            }
        }
    }

    /**
     * Classifies every testing example with Gaussian Naive Bayes and prints the result.
     *
     * <p>Scores are accumulated as log-probabilities. Multiplying raw densities
     * underflows to 0 for points far from both class means, which made the
     * comparison degenerate into the "yes" tie-break.
     *
     * @param training training rows; element {@code num_attr} of each row is the class label
     * @param testing  testing rows; the length of the first row defines the attribute count
     */
    private static void naiveBayes(ArrayList<double[]> training, ArrayList<double[]> testing) {
        if (testing.isEmpty()) {
            return; // nothing to classify
        }
        int numAttr = testing.get(0).length;
        ArrayList<double[]> stats = getStats(training, numAttr);
        double[] meanYes = stats.get(0);
        double[] meanNo = stats.get(1);
        double[] sdYes = stats.get(2);
        double[] sdNo = stats.get(3);
        double nYes = stats.get(4)[0];
        double nNo = stats.get(4)[1];
        double nTotal = nYes + nNo;

        // log(0) is -Infinity, which correctly rules out a class with no training examples.
        double logPriorYes = Math.log(nYes / nTotal);
        double logPriorNo = Math.log(nNo / nTotal);

        for (double[] testingEx : testing) {
            double yesScore = logPriorYes;
            double noScore = logPriorNo;
            for (int i = 0; i < numAttr; i++) {
                yesScore += logDensity(testingEx[i], meanYes[i], sdYes[i]);
                noScore += logDensity(testingEx[i], meanNo[i], sdNo[i]);
            }
            if (yesScore >= noScore) {
                System.out.println("yes");
            } else if (noScore > yesScore) {
                System.out.println("no");
            } else {
                // Only reachable when a score is NaN (e.g. a NaN attribute value).
                System.out.println("Something is wrong with yes/no choosing");
            }
        }
    }

    /**
     * Natural log of the normal density at {@code val}.
     *
     * @return 0 (the log of a density of 1) when {@code std} is 0 or NaN, so an
     *         attribute without usable spread does not influence the decision
     */
    private static double logDensity(double val, double mean, double std) {
        if (std == 0 || Double.isNaN(std)) {
            return 0;
        }
        double z = (val - mean) / std;
        return -0.5 * z * z - Math.log(std * Math.sqrt(2 * Math.PI));
    }

    /**
     * Evaluates the normal (Gaussian) probability density function.
     *
     * @param val  point at which to evaluate the density
     * @param mean mean of the distribution
     * @param std  standard deviation of the distribution; a value of 0 yields NaN
     * @return the density at {@code val}
     */
    public static double probFunction(double val, double mean, double std) {
        double exponent = -Math.pow(val - mean, 2) / (2 * Math.pow(std, 2));
        return Math.exp(exponent) / (std * Math.sqrt(2 * Math.PI));
    }

    /**
     * Computes per-class, per-attribute mean and sample standard deviation.
     *
     * @param training training rows; element {@code num_attr} of each row is the
     *                 class label (1 for yes, 0 for no)
     * @param num_attr number of attribute columns preceding the label
     * @return a five-element list: {@code [0]} yes-means, {@code [1]} no-means,
     *         {@code [2]} yes-standard-deviations, {@code [3]} no-standard-deviations,
     *         {@code [4]} the pair {yes-count, no-count}. Means are NaN for a class
     *         with no examples; deviations are NaN for a class with fewer than two.
     */
    public static ArrayList<double[]> getStats(ArrayList<double[]> training, int num_attr) {
        double[] sumYes = new double[num_attr];
        double[] sumNo = new double[num_attr];
        double nYes = 0;
        double nNo = 0;

        // First pass: per-class sums and counts.
        for (double[] row : training) {
            double label = row[num_attr];
            if (label == LABEL_YES) {
                for (int i = 0; i < num_attr; i++) {
                    sumYes[i] += row[i];
                }
                nYes++;
            } else if (label == LABEL_NO) {
                for (int i = 0; i < num_attr; i++) {
                    sumNo[i] += row[i];
                }
                nNo++;
            } else {
                System.out.println("Sometihng went wrong: Not class yes or no");
            }
        }

        double[] meanYes = new double[num_attr];
        double[] meanNo = new double[num_attr];
        for (int i = 0; i < num_attr; i++) {
            meanYes[i] = sumYes[i] / nYes;
            meanNo[i] = sumNo[i] / nNo;
        }

        // Second pass: per-class sums of squared deviations from the mean.
        double[] sqDiffYes = new double[num_attr];
        double[] sqDiffNo = new double[num_attr];
        for (double[] row : training) {
            double label = row[num_attr];
            if (label == LABEL_YES) {
                for (int i = 0; i < num_attr; i++) {
                    sqDiffYes[i] += Math.pow(row[i] - meanYes[i], 2);
                }
            } else if (label == LABEL_NO) {
                for (int i = 0; i < num_attr; i++) {
                    sqDiffNo[i] += Math.pow(row[i] - meanNo[i], 2);
                }
            } else {
                System.out.println("Something went wrong: Not class yes or no");
            }
        }

        // Sample standard deviation (divides by n - 1).
        double[] sdYes = new double[num_attr];
        double[] sdNo = new double[num_attr];
        for (int i = 0; i < num_attr; i++) {
            sdYes[i] = Math.sqrt(sqDiffYes[i] / (nYes - 1));
            sdNo[i] = Math.sqrt(sqDiffNo[i] / (nNo - 1));
        }

        ArrayList<double[]> stats = new ArrayList<>();
        stats.add(meanYes);
        stats.add(meanNo);
        stats.add(sdYes);
        stats.add(sdNo);
        stats.add(new double[] {nYes, nNo});
        return stats;
    }

    /**
     * Reads labelled training examples from comma-separated lines.
     *
     * <p>Blank lines are skipped and whitespace around each field is ignored.
     * The last field of every line is the class label: {@code yes} is encoded
     * as 1 and any other text as 0.
     *
     * @param scan_tr scanner positioned at the start of the training data
     * @return one array per example: the attribute values followed by the encoded label
     * @throws NumberFormatException if an attribute field is not a valid number
     */
    public static ArrayList<double[]> makeTrainingList(Scanner scan_tr) {
        ArrayList<double[]> trainingData = new ArrayList<>();
        while (scan_tr.hasNextLine()) {
            String[] fields = splitFields(scan_tr.nextLine());
            if (fields.length == 0) {
                continue;
            }
            int labelIndex = fields.length - 1;
            double[] entry = new double[fields.length];
            parseNumbers(fields, labelIndex, entry);
            // Trim before comparing so " yes" or "yes " is not silently read as "no".
            entry[labelIndex] = fields[labelIndex].trim().equals("yes") ? LABEL_YES : LABEL_NO;
            trainingData.add(entry);
        }
        return trainingData;
    }

    /**
     * Reads unlabelled testing examples from comma-separated lines.
     *
     * <p>Blank lines are skipped and whitespace around each field is ignored.
     *
     * @param scan_test scanner positioned at the start of the testing data
     * @return one array of attribute values per example
     * @throws NumberFormatException if a field is not a valid number
     */
    public static ArrayList<double[]> makeTestingList(Scanner scan_test) {
        ArrayList<double[]> testingData = new ArrayList<>();
        while (scan_test.hasNextLine()) {
            String[] fields = splitFields(scan_test.nextLine());
            if (fields.length == 0) {
                continue;
            }
            double[] entry = new double[fields.length];
            parseNumbers(fields, fields.length, entry);
            testingData.add(entry);
        }
        return testingData;
    }

    /** Splits one input line on commas; returns an empty array for a blank line. */
    private static String[] splitFields(String line) {
        String trimmed = line.trim();
        if (trimmed.isEmpty()) {
            return new String[0];
        }
        return trimmed.split(",");
    }

    /** Parses the first {@code count} fields as doubles into {@code out}. */
    private static void parseNumbers(String[] fields, int count, double[] out) {
        for (int i = 0; i < count; i++) {
            out[i] = Double.parseDouble(fields[i].trim());
        }
    }

    /** Euclidean distance between {@code a} and {@code b} over their first {@code dims} elements. */
    private static double euclideanDistance(double[] a, double[] b, int dims) {
        double sumSquares = 0;
        for (int i = 0; i < dims; i++) {
            double diff = a[i] - b[i];
            sumSquares += diff * diff;
        }
        return Math.sqrt(sumSquares);
    }

    /**
     * Extracts {@code k} from an algorithm name of the exact form {@code <digits>NN}.
     *
     * @param algo algorithm name as given on the command line
     * @return the positive neighbour count, or 0 if {@code algo} does not match the format
     */
    private static int parseK(String algo) {
        if (algo == null || !algo.endsWith("NN")) {
            return 0;
        }
        String digits = algo.substring(0, algo.length() - 2);
        if (digits.isEmpty()) {
            return 0;
        }
        for (char c : digits.toCharArray()) {
            if (c < '0' || c > '9') {
                return 0;
            }
        }
        try {
            return Integer.parseInt(digits);
        } catch (NumberFormatException tooLarge) {
            // All characters are digits, so this can only be an overflow of int.
            return 0;
        }
    }

    /** Index of the entry with the largest distance (element 0); the first such entry on ties. */
    private static int findFurthestIndex(List<double[]> kNearest) {
        int furthest = 0;
        for (int i = 1; i < kNearest.size(); i++) {
            if (kNearest.get(i)[0] > kNearest.get(furthest)[0]) {
                furthest = i;
            }
        }
        return furthest;
    }
}