diff --git a/Kohonen.java b/Kohonen.java index 5939cdc..d92ea4d 100644 --- a/Kohonen.java +++ b/Kohonen.java @@ -50,6 +50,7 @@ public class Kohonen extends JFrame private JTextField xDim = new JTextField(3); private JTextField yDim = new JTextField(3); private JTextField epochs = new JTextField(4); + private JComboBox colorBox = new JComboBox(new String[]{"Red", "Green", "Blue"}); /** * Read in the input csv data @@ -155,6 +156,10 @@ public void readFile(File input) xVal = Integer.parseInt(xDim.getText()); yVal = Integer.parseInt(yDim.getText()); epochVal = Integer.parseInt(epochs.getText()); + if(xVal <= 0 || yVal <= 0 || epochVal <= 0) + { + throw new NumberFormatException(); + } } catch(NumberFormatException nfe) { @@ -162,20 +167,91 @@ public void readFile(File input) return; } // Create the SOM object + long startTime = System.nanoTime(); + //JFrame status = new JFrame("Training the Kohonen network..."); + //status.setVisible(true); + //status.setSize(100, 100); + JOptionPane.showMessageDialog(null,"Training the Kohonen network. Please be patient.\nThis may take a while for large datasets, networks, and epochs."); SOM training = new SOM(validData, xVal, yVal, epochVal); - // Prepare the object for training (i.e. weights and distances) - //training.init(); - // Train the object and store the result + // Train the Kohonen network training.train(); - - - - // Scale will fail if a column is all the same value - // Ensure non-zero variance of columns - - // Now scale the grid + long endTime = System.nanoTime(); + //System.out.println("" + (endTime - startTime / 1000000)); + // Plot the network as a heatmap + //status.setVisible(false); + plot(xVal, yVal, training.getNodes()); } + /** + * Plot the trained Kohonen network. + * + * This method takes as inputs the x dimension + * of the map, the y dimension of the map, and + * an array for the node labels of each observation + * and produces a grid heatmap showing the count + * of observations assigned to each node. + * + * + * */ + public void plot(int xDim, int yDim, int [] nodes) + { + // Determine the number of observations assigned to each node + int [] counts = new int[xDim * yDim]; + for(int i = 0; i < nodes.length; i++) + { + counts[nodes[i]] += 1; + } + // Determine the maximum and minimum counts for shading + int maxCount = 0; + int minCount = Integer.MAX_VALUE; + for(int i = 0; i < counts.length; i++) + { + if(counts[i] < minCount) + { + minCount = counts[i]; + } + if(counts[i] > maxCount) + { + maxCount = counts[i]; + } + //System.out.println("Node " + i + ": " + counts[i]); + } + // Create the window + JFrame map = new JFrame ("Kohonen network"); + // Set the output resolution, don't let it exceed 1024x768 + //double aspectRatio = (double)xdim / yDim; + String col = colorBox.getSelectedItem().toString(); + map.setSize(800, 600); + map.setLayout (new GridLayout(xDim, yDim)); + + // Plot the counts + JButton [] jB = new JButton[counts.length]; + // Set the colors + float r = 0; + float g = 0; + float b = 0; + for(int i = 0; i < counts.length; i++) + { + if(col.equals("Red")) + { + r = (float)counts[i] / maxCount; + } + else if(col.equals("Green")) + { + g = (float)counts[i] / maxCount; + } + else + { + b = (float)counts[i] / maxCount; + } + jB[i] = new JButton(""); + jB[i].setBackground(new Color(r, g, b)); + jB[i].setOpaque(true); + map.add(jB[i]); + } + map.setVisible(true); + } + class ButtonListener implements ActionListener { /** @@ -214,8 +290,7 @@ public Kohonen() //Color JLabel colorLabel = new JLabel("Shading"); window.add(colorLabel); - String [] colors = new String[]{"Red", "Green", "Blue"}; - JComboBox colorBox = new JComboBox(colors); + //String [] colors = new String[]{"Red", "Green", "Blue"}; window.add(colorBox); //Epochs