-
Notifications
You must be signed in to change notification settings - Fork 3
/
cluster.py
165 lines (124 loc) · 4.56 KB
/
cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import matplotlib.pyplot as plt
from scipy.spatial import cKDTree
import sys
import time
class Point(object):
    """One location-history sample.

    ``date``, ``time`` and ``timezone`` are stored exactly as given; ``lon``
    and ``lat`` are converted with ``float``. The integer fields ``visited``
    and ``cluster`` both start at 0.
    """

    def __init__(self, date, time, timezone, lon, lat):
        self.date = date
        self.time = time
        self.timezone = timezone
        self.lon = float(lon)
        self.lat = float(lat)
        self.visited = 0
        self.cluster = 0

    def __str__(self):
        # A leading newline, then one "Label: value" line per field.
        # date/time/timezone are concatenated directly, so they must already
        # be strings; the numeric fields go through str().
        fields = [
            '',
            'Date: ' + self.date,
            'Time: ' + self.time,
            'TZ: ' + self.timezone,
            'Lat: ' + str(self.lat),
            'Lon: ' + str(self.lon),
            'Visit: ' + str(self.visited),
            'Clust: ' + str(self.cluster),
        ]
        return '\n'.join(fields)

    def coord(self):
        """Return this point as a ``[lon, lat]`` list (longitude first)."""
        lon, lat = self.lon, self.lat
        return [lon, lat]
class DBSCAN(object):
    """DBSCAN clustering over Point objects.

    Usage:
      - construct with ``eps`` (neighbourhood radius) and ``minPts``
        (minimum neighbourhood size, the point itself included, for a
        point to count as a core point);
      - call ``addPoints`` to add points and rebuild the spatial index;
      - call ``solve`` to compute, or re-compute, the clusters.

    Distances are Euclidean on the raw ``[lon, lat]`` pairs returned by
    ``Point.coord`` (the cKDTree default metric), so ``eps`` is in the same
    units as the coordinates, not a geodesic distance.

    After ``solve``:
      - ``clusters`` is a list of sets of indices into ``points``;
      - each point's ``cluster`` attribute is the index of its cluster in
        ``clusters``, or ``NOISE`` if it belongs to no cluster.
    """

    # Label written to Point.cluster for points that belong to no cluster.
    # It must differ from every real cluster number (0, 1, ...).
    NOISE = -1

    def __init__(self, eps, minPts):
        self.eps, self.minPts = eps, minPts
        self.points = []
        # Derived state, rebuilt by addPoints/solve. It is initialised here so
        # that solve() on an instance with no points yields no clusters
        # instead of raising AttributeError.
        self.coords = []
        self.tree = None
        self.neighbors = []
        self.clusters = []

    def addPoints(self, data):
        """Append points and rebuild the KD-tree and neighbour lists.

        :param data: iterable of Point instances. Any iterable, including a
            generator, is accepted.
        :raises ValueError: if any element is not a Point. In that case no
            points are added.
        """
        # Materialise first so the validation pass cannot exhaust a generator.
        data = list(data)
        if not all(isinstance(p, Point) for p in data):
            raise ValueError('Error in data passed in. Not array of points.')
        self.points.extend(data)
        self.coords = [p.coord() for p in self.points]
        if not self.coords:
            # Nothing to index; cKDTree expects a 2-D array of points.
            self.tree = None
            self.neighbors = []
            return
        self.tree = cKDTree(self.coords)
        # neighbors[i] lists the indices of all points within eps of point i,
        # including i itself (distance 0).
        self.neighbors = self.tree.query_ball_point(self.coords, self.eps)

    def solve(self):
        """(Re)compute clusters from the current neighbour lists.

        Every point's ``cluster`` attribute is first reset to ``NOISE``.
        Without this reset, labels left over from an earlier ``solve``, or
        Point's initial 0 (which collides with cluster 0), would survive on
        points that end up in no cluster.

        Clusters are numbered 0, 1, ... in the order their first core point
        appears in ``points``. A border point reachable from several clusters
        joins the first cluster that reaches it.
        """
        noise = self.NOISE
        count = len(self.points)
        # labels[k] is the cluster index of point k, or noise. It replaces a
        # linear scan over all clusters on every membership check.
        labels = [noise] * count
        for p in self.points:
            p.cluster = noise
        visited = set()
        self.clusters = []
        for i in range(count):
            if i in visited:
                continue
            visited.add(i)
            if len(self.neighbors[i]) < self.minPts:
                # Not a core point. It stays noise unless a later cluster
                # reaches it as a border point.
                continue
            cluster_id = len(self.clusters)
            members = set()
            self.clusters.append(members)
            frontier = {i}
            frontier.update(self.neighbors[i])
            while frontier:
                j = frontier.pop()
                if j not in visited:
                    visited.add(j)
                    # Only core points extend the cluster further.
                    if len(self.neighbors[j]) >= self.minPts:
                        frontier.update(self.neighbors[j])
                # Claim j unless an earlier cluster (or this one) already has it.
                # This also picks up border points first visited as noise.
                if labels[j] == noise:
                    labels[j] = cluster_id
                    members.add(j)
                    self.points[j].cluster = cluster_id
# filter data
def filterPoints(x1, x2, y1, y2, Data):
    """Return the points of ``Data`` strictly inside a lon/lat box.

    A point is kept when ``x1 < lon < x2`` and ``y1 < lat < y2``. Points on
    the box edges are excluded. Input order is preserved.
    """
    return [pt for pt in Data if x1 < pt.lon < x2 and y1 < pt.lat < y2]
# plot cluster data
def plot(DF, clusters=None):
    """Scatter-plot all points in grey, then overlay each cluster in colour.

    :param DF: sequence of objects whose ``coord()`` returns ``[lon, lat]``
        (Point instances).
    :param clusters: iterable of sets of indices into ``DF``, in the same
        form as ``DBSCAN.clusters``. When omitted, ``scan.clusters`` from
        the module-level ``scan`` object is used, which is what this
        function always did before the parameter existed.
    """
    if clusters is None:
        # Legacy behaviour: fall back to the script-level DBSCAN run.
        clusters = scan.clusters
    points = [p.coord() for p in DF]
    if points:
        # Background layer: every point in light grey.
        plt.scatter(*zip(*points), color=[.75] * 3, alpha=.5, s=15)
    colors = 'rbgycm'
    for cluster_idx, members in enumerate(clusters):
        # Ascending index order matches the original drawing order.
        # Indices outside DF are ignored, as before.
        cluster_points = [points[k] for k in sorted(members) if 0 <= k < len(points)]
        if not cluster_points:
            # plt.scatter() with no coordinates would raise.
            continue
        # Cycle the palette instead of running off the end of a finite string.
        plt.scatter(*zip(*cluster_points), color=colors[cluster_idx % len(colors)], alpha=1, s=15)
    plt.xlabel('Longitude')
    plt.ylabel('Latitude')
    plt.title('Plotting Clusters')
    plt.show()
# save data to csv for mapping
def save(DF, name):
    """Write ``<name>.csv`` with a ``lat,lon,cluster`` header and one row per point.

    Any existing file is overwritten. Each value is formatted with ``str``.
    """
    rows = (
        ','.join([str(pt.lat), str(pt.lon), str(pt.cluster)]) + '\n'
        for pt in DF
    )
    with open(name + '.csv', 'w') as csv_file:
        csv_file.write('lat,lon,cluster\n')
        csv_file.writelines(rows)
# Bounding boxes [lon_min, lon_max, lat_min, lat_max], selectable by name.
filters = {
    'NE': [-80, -65, 35, 50],
    'SF': [-125, -119, 35, 40]
}

# Command line: REGION EPS MIN_PTS. Validate before doing any work so bad
# input gives a usage message instead of an IndexError/KeyError/ValueError.
USAGE = 'usage: %s {%s} EPS MIN_PTS' % (sys.argv[0], '|'.join(sorted(filters)))
if len(sys.argv) < 4:
    sys.exit(USAGE)
name = sys.argv[1]
if name not in filters:
    sys.exit('unknown region %r; choose one of: %s' % (name, ', '.join(sorted(filters))))
try:
    eps = float(sys.argv[2])
    minPts = int(sys.argv[3])
except ValueError:
    sys.exit(USAGE)

# Read the history file into a list of split rows. Splitting on ',' is naive;
# presumably no field contains a quoted comma (verify against the data).
with open('history-2014.csv') as history_file:
    HISTORY = [line.strip().split(',') for line in history_file]
Data = []
# Skip the header row and blank lines (a blank line splits to ['']).
# Columns used: date, time, timezone, lon, lat.
for row in HISTORY[1:]:
    if row == ['']:
        continue
    Data.append(Point(row[0], row[1], row[2], row[3], row[4]))

# Keep only the points inside the chosen bounding box.
x1, x2, y1, y2 = filters[name]
DF = filterPoints(x1, x2, y1, y2, Data)

# Time index construction plus clustering.
scan = DBSCAN(eps, minPts)
t0 = time.time()
scan.addPoints(DF)
scan.solve()
t1 = time.time()

# Single-argument print(...) produces the same output on Python 2 and 3.
print('Epsilon value: ' + str(eps))
print('Minimum points: ' + str(minPts))
print('Number of Points: ' + str(len(scan.points)))
print('Number of Clusters: ' + str(len(scan.clusters)))
print('Time taken: ' + str(t1 - t0))

# save data frame with clusters
save(DF, name)
# plot the points
plot(scan.points)