An introduction to k-means clustering


Clustering is a class of unsupervised techniques that attempt to assign categories to uncategorized data. This is often a preliminary step for data-naive clients (clients who did not develop a plan before collecting data and may not know what they want out of their data), but it can also help you get a better idea of the structure of your data in general. Here we will define k-means clustering (KMC) both technically and qualitatively, review its advantages and disadvantages, and work through an example with synthetic data.

K-means clustering is an iterative process that takes a user-defined number of categories, k, and places every observation in a data set into one of those k categories. As with most machine learning algorithms, the goal is to minimize an error; here, the error is the within-cluster sum of squares (WCSS). That is, if we have n observations of d-dimensional data, we want to split the data into k groups $X_1, \dots, X_k$ that minimize

$$\sum_{i=1}^{k} \sum_{x \in X_i} \lVert x - \mu_i \rVert^2,$$

where $X_i$ is the ith of the k subsets and $\mu_i$ is the mean of the points in $X_i$.

Qualitatively, this means that we want to choose k points (the cluster means) and split the data into k subsets so that every point is as close as possible to the mean of its own subset.
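To make the objective concrete, here is a small sketch that computes the WCSS for a given assignment of points to clusters. The helper name `wcss` and the convention that labels are integers in 0..k-1 are our own assumptions, and the toy data are made up for illustration.

```python
import numpy as np

def wcss(X, labels, k):
    """Within-cluster sum of squares of a partition.

    X      : (n, d) array of observations
    labels : length-n integer array; labels[p] == i means X[p] is in subset X_i
    k      : number of clusters (labels take values 0, ..., k-1)
    """
    total = 0.0
    for i in range(k):
        members = X[labels == i]               # the points in subset X_i
        if len(members) == 0:
            continue                           # an empty subset contributes nothing
        mu = members.mean(axis=0)              # mu_i, the mean of X_i
        total += np.sum((members - mu) ** 2)   # squared distances to mu_i
    return float(total)

# four points on a line, split two different ways
X = np.array([[0.0], [2.0], [10.0], [12.0]])
print(wcss(X, np.array([0, 0, 1, 1]), 2))  # 4.0
print(wcss(X, np.array([0, 1, 1, 1]), 2))  # 56.0
```

The natural split {0, 2} / {10, 12} has a far lower WCSS than the lopsided one, which is the kind of partition k-means is looking for.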

Iteratively, this is done as follows. First, choose k points at random within the range of the data to represent the k classes (we will call them anchor points, $\mu_1^{(1)}, \dots, \mu_k^{(1)}$). Next, assign each point in the data to one of the k sets by calculating its distance to each anchor point (that is, k calculations per point) and choosing the class that minimizes that distance; quantitatively, at iteration t,

$$X_i^{(t)} = \left\{ x : \lVert x - \mu_i^{(t)} \rVert^2 \le \lVert x - \mu_j^{(t)} \rVert^2 \ \text{for all } 1 \le j \le k \right\}.$$

Lastly, calculate the mean of each class and make it the new anchor point,

$$\mu_i^{(t+1)} = \frac{1}{\lvert X_i^{(t)} \rvert} \sum_{x \in X_i^{(t)}} x,$$

and then go back to the assignment step. When the points no longer change classes, the algorithm is considered converged.
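A minimal NumPy sketch of this loop follows. It is not the only way to implement k-means: the function name `kmeans`, the parameters `max_iter` and `seed`, initializing the anchor points from k distinct random observations (rather than arbitrary points in the range of the data), and leaving an anchor point in place if its class ever becomes empty are all our own choices.

```python
import numpy as np

def kmeans(X, k, max_iter=100, seed=0):
    """Sketch of the iterative k-means loop: assign, update, repeat until stable."""
    rng = np.random.default_rng(seed)
    # initial anchor points: k distinct observations chosen at random
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    labels = np.full(len(X), -1)
    for _ in range(max_iter):
        # assignment step: squared distance from every point to every anchor point
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        new_labels = d2.argmin(axis=1)
        if np.array_equal(new_labels, labels):
            break                      # converged: no point changed class
        labels = new_labels
        # update step: move each anchor point to the mean of its class
        for i in range(k):
            members = X[labels == i]
            if len(members) > 0:       # keep the old anchor point if a class is empty
                centroids[i] = members.mean(axis=0)
    return labels, centroids

X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
labels, centroids = kmeans(X, 2)
print(labels, centroids)
```

Which cluster gets index 0 depends on the random initialization, but the first two points and the last two should end up together, with anchor points at their means.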


Visually, this can be demonstrated as follows:

[Figures from Wikimedia Commons illustrating the k-means iterations]

K-means clustering has the following advantages and disadvantages. KMC is typically fast: each iteration costs on the order of O(n*k*d) operations, since each of the n points is compared with each of the k anchor points in d dimensions. It is also easy to explain and straightforward to implement on low-level hardware or in fast, low-level programming languages. The disadvantages arise from assumptions about the structure of the data: primarily, that the clusters are balanced (of similar size), that each cluster has approximately the same variance, and, most importantly, that k is known. The first disadvantage can be mitigated by weighting the data or by replicating points from the smaller clusters (with small perturbations). The second is more difficult to overcome. For sufficiently small data sets, the problem of an unknown k can be handled by sweeping over a range of k values, for example by plotting the WCSS against k and looking for an "elbow" where it stops dropping sharply.
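A sweep over k can be sketched as follows: run k-means for each candidate k and record the final WCSS. The helper name `kmeans_wcss`, the fixed `n_iter=50`, and the three-blob synthetic data are our own assumptions. Because k-means can stop in a local minimum, in practice one might rerun each k from several random starts and keep the lowest score.

```python
import numpy as np

def kmeans_wcss(X, k, n_iter=50, seed=0):
    """Run a basic k-means loop and return the sum of squared distances to the
    nearest centroid (equal to the WCSS once the loop has converged)."""
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for _ in range(n_iter):
        # assign each point to its nearest centroid
        labels = ((X[:, None, :] - centroids[None]) ** 2).sum(axis=2).argmin(axis=1)
        # move each non-empty cluster's centroid to the mean of its points
        for i in range(k):
            if np.any(labels == i):
                centroids[i] = X[labels == i].mean(axis=0)
    labels = ((X[:, None, :] - centroids[None]) ** 2).sum(axis=2).argmin(axis=1)
    return float(((X - centroids[labels]) ** 2).sum())

# synthetic data: three well-separated blobs centered at (0,0), (5,5), (10,10)
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(loc=c, scale=0.5, size=(50, 2)) for c in (0, 5, 10)])

# sweep k and record the score; look for where the decrease levels off
scores = {k: kmeans_wcss(X, k) for k in range(1, 7)}
for k, s in scores.items():
    print(k, round(s, 1))
```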


We will now cover an example of k-means clustering.


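Let’s start by generating some synthetic 2D data: two blobs of 100 points each, one centered at (0, 0) with standard deviation 1 and one centered at (10, 10) with standard deviation 0.8.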

import numpy as np
import matplotlib.pyplot as plt

# first blob: centered at (0, 0) with standard deviation 1
mu = 0
sigma = 1
X_1_1 = np.random.normal(loc=mu, scale=sigma, size=100)
X_1_2 = np.random.normal(loc=mu, scale=sigma, size=100)

# second blob: centered at (10, 10) with standard deviation 0.8
mu = 10
sigma = .8
X_2_1 = np.random.normal(loc=mu, scale=sigma, size=100)
X_2_2 = np.random.normal(loc=mu, scale=sigma, size=100)

Now we stack the coordinates into a single 200 x 2 array, choose two initial centroids at random from among the data points, and assign each point to the class of its nearest centroid.

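# stack the two coordinates of each blob into (100, 2) arrays, then combine them
A = np.hstack((X_1_1.reshape(-1, 1), X_1_2.reshape(-1, 1)))
B = np.hstack((X_2_1.reshape(-1, 1), X_2_2.reshape(-1, 1)))
X_zip = np.vstack((A, B))   # shape (200, 2)

# cluster label for each point; -1 means "not yet assigned"
Y = [-1 for x in range(len(X_zip))]
# number of clusters
k = 2

# define initial centroids: two distinct points drawn at random from all the data
idx = np.random.choice(len(X_zip), size=2, replace=False)
(a, b), (c, d) = X_zip[idx[0]], X_zip[idx[1]]
C1 = np.array([a, b])
C2 = np.array([c, d])

# assign each point to the cluster whose centroid is nearest
for i in range(len(X_zip)):
    pt = X_zip[i]
    Y[i] = np.argmin([np.linalg.norm(pt - C1), np.linalg.norm(pt - C2)])

# plot the centroids (x) and the points, colored by cluster
colorDict = {0: '#3366aa', 1: '#ee7722'}
plt.plot(a, b, 'x', markersize=15, c='#3366aa')
plt.plot(c, d, 'x', markersize=15, c='#ee7722')
for i in range(len(X_zip)):
    plt.plot(X_zip[i, 0], X_zip[i, 1], 'o', c=colorDict[Y[i]])
plt.show()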

This is one step of k-means clustering!

Notice where our centroids are (the X symbols); they were randomly chosen from among all the data. After this single step, the data may not be well clustered yet, but we can repeat the two-step process of finding the centroids and assigning the clusters.
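The per-point loop in the code above computes the k distances one point at a time. As a sketch, the same assignment step can be vectorized with NumPy broadcasting, which also makes the O(n*k*d) cost of one assignment step visible. `assign_clusters` is a hypothetical helper; with the tutorial's variables it could be called as `assign_clusters(X_zip, np.vstack((C1, C2)))`.

```python
import numpy as np

def assign_clusters(X, centroids):
    """Return the index of the nearest centroid for every row of X.

    X has shape (n, d) and centroids has shape (k, d). Broadcasting builds an
    (n, k, d) array of differences, i.e. n*k*d subtractions per assignment step.
    """
    diffs = X[:, None, :] - centroids[None, :, :]   # (n, k, d)
    dists = np.linalg.norm(diffs, axis=2)           # (n, k) Euclidean distances
    return dists.argmin(axis=1)                     # (n,) nearest-centroid labels

X = np.array([[0.0, 0.0], [1.0, 1.0], [9.0, 9.0], [10.0, 11.0]])
centroids = np.array([[0.0, 1.0], [10.0, 10.0]])
print(assign_clusters(X, centroids))  # [0 0 1 1]
```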

The following code repeats this process for 5 iterations (for such well-separated data, 5 is all we need).

numReps = 5
for j in range(numReps):
    # assignment step: label each point with its nearest centroid
    for i in range(len(X_zip)):
        pt = X_zip[i]
        Y[i] = np.argmin([np.linalg.norm(pt - C1), np.linalg.norm(pt - C2)])

    # update step: recalculate each centroid as the mean of its assigned points
    allZip = np.hstack((X_zip, np.array(Y).reshape(-1, 1)))
    C1_arr = [x[0:2] for x in allZip if x[2] == 0]
    C2_arr = [x[0:2] for x in allZip if x[2] == 1]
    C1 = np.mean(C1_arr, 0)
    C2 = np.mean(C2_arr, 0)

# plot the final centroids (x) and the points, colored by cluster
colorDict = {0: '#3366aa', 1: '#ee7722'}
plt.plot(C1[0], C1[1], 'x', markersize=15, c='#3366aa')
plt.plot(C2[0], C2[1], 'x', markersize=15, c='#ee7722')
for i in range(len(X_zip)):
    plt.plot(X_zip[i, 0], X_zip[i, 1], 'o', c=colorDict[Y[i]])
plt.show()

And our final result looks like this:

[Figure: final cluster assignments and centroids after 5 iterations]
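Instead of fixing the number of iterations at 5, one can watch the WCSS from round to round. Each assignment step and each mean-update step can only lower the WCSS or leave it unchanged, so the recorded values never increase and stop changing once no point switches clusters. In the sketch below, `lloyd_history` is a hypothetical helper, the data are regenerated with NumPy's `default_rng` and an arbitrary seed, and the starting centroids are chosen badly on purpose to make the decrease visible.

```python
import numpy as np

def lloyd_history(X, start, n_iter):
    """Run n_iter assignment/update rounds from the given starting centroids and
    return the WCSS recorded after each round."""
    centroids = np.array(start, dtype=float)
    history = []
    for _ in range(n_iter):
        # assignment step: index of the nearest centroid for every point
        labels = ((X[:, None, :] - centroids[None]) ** 2).sum(axis=2).argmin(axis=1)
        # update step: move each non-empty cluster's centroid to its mean
        for i in range(len(centroids)):
            if np.any(labels == i):
                centroids[i] = X[labels == i].mean(axis=0)
        # WCSS of this round's clusters around their new means
        history.append(float(((X - centroids[labels]) ** 2).sum()))
    return history

# two blobs like the ones in the tutorial (regenerated here with a fixed seed)
rng = np.random.default_rng(0)
X = np.vstack((rng.normal(0, 1, size=(100, 2)), rng.normal(10, 0.8, size=(100, 2))))

# deliberately poor start: both centroids near the first blob
history = lloyd_history(X, [[0.0, 0.0], [1.0, 0.0]], n_iter=5)
print([round(h, 1) for h in history])
```

Once two consecutive values are equal, further iterations cannot change the clustering, which gives a convergence test that does not depend on guessing the iteration count.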

So that's k-means clustering in a nutshell. We worked with well-separated, two-dimensional data, so the algorithm was quick and accurate. In reality, data may have dozens of clusters and dimensions, and the number of clusters may not even be known! More advanced techniques include sweeping across a range of values of k, or continuously subsampling the data points to improve speed.
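The subsampling idea mentioned above is the basis of mini-batch k-means, which updates the centroids from small random batches of points rather than from the full data set. Below is a minimal sketch, assuming a per-centroid learning rate of 1/(number of points that centroid has absorbed), which makes each centroid a running mean of the points assigned to it. `minibatch_kmeans`, `batch_size`, and `n_steps` are hypothetical names and defaults; scikit-learn provides a production implementation as `sklearn.cluster.MiniBatchKMeans`.

```python
import numpy as np

def minibatch_kmeans(X, k, batch_size=20, n_steps=200, seed=0):
    """Mini-batch k-means sketch: each step updates the centroids from a small
    random subsample of the data instead of from every point."""
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    counts = np.zeros(k)                      # points absorbed so far by each centroid
    for _ in range(n_steps):
        batch = X[rng.choice(len(X), size=batch_size, replace=False)]
        # nearest centroid for each point in the mini-batch (computed up front)
        labels = ((batch[:, None, :] - centroids[None]) ** 2).sum(axis=2).argmin(axis=1)
        for x, i in zip(batch, labels):
            counts[i] += 1
            eta = 1.0 / counts[i]             # per-centroid learning rate
            centroids[i] = (1 - eta) * centroids[i] + eta * x   # running-mean step
    return centroids

rng = np.random.default_rng(2)
X = np.vstack((rng.normal(0, 1, size=(500, 2)), rng.normal(10, 0.8, size=(500, 2))))
centroids = minibatch_kmeans(X, 2)
print(np.round(centroids, 1))
```

With well-separated blobs like these, one centroid typically lands near each blob, but, as with full k-means, the result depends on the random initialization.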