-
Notifications
You must be signed in to change notification settings - Fork 34
/
kmeans.py
68 lines (52 loc) · 2.17 KB
/
kmeans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
""" A class for K-Means clustering """
# Module metadata.
__author__ = 'Zhang Zhang'
# NOTE(review): this value looks like an e-mail obfuscation placeholder rather
# than a real address — confirm and restore the author's actual contact.
__email__ = '[email protected]'
import daal.algorithms.kmeans as kmeans
from daal.algorithms.kmeans import init
from daal.data_management import HomogenNumericTable
import numpy as np
class KMeans:
    """K-Means clustering built on the `daal.algorithms.kmeans` batch API.

    The outputs of the most recent `compute` call are stored on the instance:

        centroids_:    final centroids returned by the clustering algorithm
        assignments_:  cluster assignments returned by the algorithm
        goalfunction_: goal-function value returned by the algorithm
        niterations_:  number of iterations reported by the algorithm

    All four are None until `compute` has been called.
    """

    def __init__(self, nclusters, randomseed=None):
        """Initialize class parameters

        Args:
            nclusters: Number of clusters
            randomseed: An integer used to seed the random number generator.
                When None, the fixed seed 1234 is used.
        """
        self.nclusters_ = nclusters
        # `is None` (not truthiness) so that an explicit seed of 0 is honored.
        self.seed_ = 1234 if randomseed is None else randomseed
        # Result slots, populated by compute().
        self.centroids_ = None
        self.assignments_ = None
        self.goalfunction_ = None
        self.niterations_ = None

    def compute(self, data, centroids=None, maxiters=100):
        """Compute K-Means clustering for the input data

        The results are stored in `centroids_`, `assignments_`,
        `goalfunction_` and `niterations_`; nothing is returned.

        Args:
            data: Input data to be clustered (presumably a DAAL numeric
                table, as it is passed straight to the DAAL algorithm
                inputs -- confirm against callers)
            centroids: User defined input centroids. If None then initial
                centroids will be randomly chosen
            maxiters: The maximum number of iterations
        """
        if centroids is None:
            # No starting point supplied: let DAAL pick random initial
            # centroids, seeded for reproducibility.
            init_alg = init.Batch_Float64RandomDense(self.nclusters_)
            init_alg.input.set(init.data, data)
            init_alg.parameter.seed = self.seed_
            self.centroids_ = init_alg.compute().get(init.centroids)
        else:
            # Honor the caller-provided starting centroids.
            self.centroids_ = centroids

        # Create an algorithm object for clustering
        clustering_alg = kmeans.Batch_Float64LloydDense(
            self.nclusters_,
            maxiters)
        # Set input
        clustering_alg.input.set(kmeans.data, data)
        clustering_alg.input.set(kmeans.inputCentroids, self.centroids_)

        # Run the clustering and keep every output on the instance.
        result = clustering_alg.compute()
        self.centroids_ = result.get(kmeans.centroids)
        self.assignments_ = result.get(kmeans.assignments)
        self.goalfunction_ = result.get(kmeans.goalFunction)
        self.niterations_ = result.get(kmeans.nIterations)