diff --git a/README.md b/README.md index 192721d..5943c4b 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ kmeans.clusterize(vectors, {k: 4}, (err,res) => { - **options** object: - **k** : number of clusters - **distance** (optional) : custom distance function returning the distance between two points `(a,b) => number`, *default* Euclidian Distance + - **seed** (optional) : value that can be provided to get repeatable cluster generation - **callback** node-style callback taking error and result argument ## Outputs diff --git a/lib/kmeans.js b/lib/kmeans.js index e93c558..bdce711 100644 --- a/lib/kmeans.js +++ b/lib/kmeans.js @@ -18,6 +18,7 @@ The kmeans will return an error if: - The number of different input vectors is smaller than k */ +const seedrandom = require('seedRandom'); const _ = require('underscore'); @@ -60,7 +61,8 @@ class Group { if (this.centroid && this.cluster.length > 0) { this.calculateCentroid(); } else { // random selection - const i = Math.floor(Math.random() * self.indexes.length); + const rand = self.seed == null ? Math.random() : seedrandom(self.seed)(); + const i = Math.floor(rand * self.indexes.length); this.centroidIndex = self.indexes[i]; self.indexes.splice(i, 1); this.centroid = []; @@ -126,6 +128,7 @@ class Clusterize { this.options = options; this.v = this.checkV(vector); this.k = this.options.k; + this.seed = this.options.seed; this.distanceFunction = this.options.distance || euclidianDistance; if (this.v.length < this.k) { const errMessage = `The number of points must be greater than diff --git a/package-lock.json b/package-lock.json index 649db58..14ed099 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1962,6 +1962,11 @@ "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", "dev": true }, + "seedrandom": { + "version": "3.0.5", + "resolved": "https://registry.npmjs.org/seedrandom/-/seedrandom-3.0.5.tgz", + "integrity": "sha512-8OwmbklUNzwezjGInmZ+2clQmExQPvomqjL7LFqOYqtmuxRgQYqOD3mHaU+MvZn5FLUeVxVfQjwLZW/n/JFuqg==" + }, "semver": { "version": "5.7.0", "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.0.tgz", diff --git a/package.json b/package.json index 738c6df..1a72374 100644 --- a/package.json +++ b/package.json @@ -12,6 +12,7 @@ "node": ">= v0.6.0" }, "dependencies": { + "seedrandom": "^3.0.5", "underscore": "^1.9.1" }, "devDependencies": { diff --git a/test/index.js b/test/index.js index 9fde4b2..0f95fb8 100644 --- a/test/index.js +++ b/test/index.js @@ -139,5 +139,17 @@ describe('kmeans', () => { done(); }); }); + + it('should produce consistent output when a seed is provided', (done) => { + kmeans.clusterize(data3D, { k: 3, seed: 42 }, (err, res) => { + + // Verify first value of each centroid is always the same + const cs = res.map(r => r.centroid[0]); + cs[0].should.equal(202.6); + cs[1].should.equal(-10.15); + cs[2].should.equal(39.75); + done(); + }); + }); }); });