Basic Concepts

The \(k\)-means clustering algorithm attempts to solve the following optimization problem:

\[ \min_{C, \, \{m_k\}_{k=1}^K} \sum_{k=1}^K \sum_{C(i) = k} \lVert x_i - m_k \rVert^2, \] where \(C(\cdot): \{1, \ldots, n\} \rightarrow \{1, \ldots, K\}\) is a cluster assignment function, and the \(m_k\)’s are the cluster means. To solve this problem, \(k\)-means uses an iterative approach that updates \(C(\cdot)\) and the \(m_k\)’s alternately. Suppose we have a set of six observations.

We first randomly assign them into two clusters (initialize a random \(C\) function). Based on this cluster assignment, we can calculate the corresponding cluster means \(m_k\).

Then we assign each observation to the closest cluster mean. In this example, only the blue point at the top moves to a new cluster. The cluster means are then recalculated.
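
To make these two alternating steps concrete, here is a minimal sketch of the algorithm on six made-up two-dimensional points. The data and the loop structure are purely illustrative (this is not the kmeans() implementation), and empty clusters are not handled:

    # six made-up 2-d observations (illustrative only)
    set.seed(1)
    x = matrix(rnorm(12), 6, 2)
    K = 2
    
    # start from a random cluster assignment C
    C = rep(1:K, length.out = nrow(x))[sample(nrow(x))]
    
    for (iter in 1:10) {
      # update step: recompute each cluster mean from its current members
      m = t(sapply(1:K, function(k) colMeans(x[C == k, , drop = FALSE])))
      
      # assignment step: move each observation to its closest cluster mean
      d2 = sapply(1:K, function(k) colSums((t(x) - m[k, ])^2))
      C_new = apply(d2, 1, which.min)
      
      # the objective from the formula above: total within-cluster sum of squares
      cat("iteration", iter, "objective", sum(d2[cbind(1:nrow(x), C_new)]), "\n")
      
      # stop when no observation changes cluster
      if (all(C_new == C)) break
      C = C_new
    }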

When no observation changes cluster anymore, the algorithm stops. Keep in mind that we started with a random cluster assignment, and the objective function is not convex. Hence we may obtain different results if we start from different initial values. The practical remedy is to try several starting points and keep the best final result. This is controlled by the nstart argument of the kmeans() function.

    # some random data
    set.seed(1)
    mat = matrix(rnorm(1000), 50, 20)
    
    # if we use only one starting point
    kmeans(mat, centers = 3, nstart = 1)$tot.withinss
## [1] 885.8913
    # if we use multiple starting points and pick the best one
    kmeans(mat, centers = 3, nstart = 100)$tot.withinss
## [1] 883.8241

Example 1: iris data

We use the classic iris data as an example. This dataset contains three known classes (species), but the goal here is to learn the clusters without using the class labels.

    # plot the original data using two variables
    head(iris)
##   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1          5.1         3.5          1.4         0.2  setosa
## 2          4.9         3.0          1.4         0.2  setosa
## 3          4.7         3.2          1.3         0.2  setosa
## 4          4.6         3.1          1.5         0.2  setosa
## 5          5.0         3.6          1.4         0.2  setosa
## 6          5.4         3.9          1.7         0.4  setosa
    library(ggplot2)
    ggplot(iris, aes(Petal.Length, Petal.Width, color = Species)) + geom_point()

The last two variables in the iris data carry more information for separating the three classes, as the parallel coordinates plot below shows. Hence we will use only Petal.Length and Petal.Width.

    library(colorspace)
    par(mar = c(3, 2, 4, 2), xpd = TRUE)
    MASS::parcoord(iris[, -5], col = rainbow_hcl(3)[iris$Species], 
                   var.label = TRUE, lwd = 2)
    legend(x = 1.2, y = 1.3, cex = 1,
           legend = as.character(levels(iris$Species)),
           fill = rainbow_hcl(3), horiz = TRUE)

Let’s perform the \(k\)-means clustering.

    set.seed(1)

    # k-means clustering
    iris.kmean <- kmeans(iris[, 3:4], centers = 3, nstart = 20)
    
    # the center of each cluster
    iris.kmean$centers
##   Petal.Length Petal.Width
## 1     1.462000    0.246000
## 2     5.595833    2.037500
## 3     4.269231    1.342308
    # the within-cluster variation
    iris.kmean$withinss
## [1]  2.02200 16.29167 13.05769
    # the between-cluster variation
    iris.kmean$betweenss
## [1] 519.524
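
As a quick sanity check, the total variation decomposes exactly into the within- and between-cluster parts (totss, withinss and betweenss are all components of the kmeans() return value):

    # total variation = within-cluster + between-cluster variation
    iris.kmean$totss
    sum(iris.kmean$withinss) + iris.kmean$betweenss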
    # plot the fitted clusters vs. the truth
    iris.kmean$cluster <- as.factor(iris.kmean$cluster)
    
    ggplot(iris, aes(Petal.Length, Petal.Width, color = Species)) + # true cluster
            geom_point(alpha = 0.3, size = 3.5) + 
            scale_color_manual(values = c('red', 'green', 'blue')) +
            geom_point(col = c("blue", "green", "red")[iris.kmean$cluster]) # fitted cluster 
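
Since the true species labels are available here (although they were never used in the fit), we can also cross-tabulate them against the fitted clusters; keep in mind that the cluster numbering itself is arbitrary:

    # compare fitted cluster labels with the true species
    table(iris.kmean$cluster, iris$Species)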

Example 2: clustering of image pixels

Let’s first load and plot an image of Leo.

    library(jpeg)
    img <- readJPEG("../data/leo.jpg")
    
    # generate a blank image
    par(mar=rep(0.2, 4))
    plot(c(0, 400), c(0, 500), xaxt = 'n', yaxt = 'n', bty = 'n', pch = '', ylab = '', xlab = '')

    rasterImage(img, 0, 0, 400, 500)

For a jpg file, each pixel is stored as a vector of 3 elements, representing the red, green and blue intensities. However, because of the way this object img is constructed, it is stored as a 3-d array: the first two dimensions are the height and width of the image. We need to vectorize these two dimensions and treat each pixel as an observation.

    dim(img)
## [1] 500 400   3
    # this apply function vectorizes each layer (r/g/b) of the image
    img_expand = apply(img, 3, c)

    # and now we have the desired data matrix
    dim(img_expand)
## [1] 200000      3
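
As a quick check, the number of rows in the expanded matrix is exactly the height times the width of the image:

    # 500 x 400 pixels are flattened into 200000 rows
    prod(dim(img)[1:2]) == nrow(img_expand)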

Before performing the \(k\)-means clustering, let’s have a quick peek at the data in a 3-d view. Since there are too many observations, we randomly sample a subset of pixels.

    library(scatterplot3d)
    set.seed(1)
    sub_pixels = sample(1:nrow(img_expand), 1000)
    sub_img_expand = img_expand[sub_pixels, ]
    
    scatterplot3d(sub_img_expand, pch = 19, xlab = "red", ylab = "green", zlab = "blue", 
                  color = rgb(sub_img_expand[,1], sub_img_expand[,2], sub_img_expand[,3]))

The next step is to perform the \(k\)-means clustering and obtain the cluster labels. For example, let’s try 5 clusters.

    kmeanfit <- kmeans(img_expand, 5)

    # to produce the new image, we simply replicate the cluster mean for all observations in the same cluster
    new_img_expand = kmeanfit$centers[kmeanfit$cluster, ]
    
    # now we need to convert this back to an array that can be plotted as an image.
    # this is a lazy way to do it, but it gets the job done
    new_img = img
    new_img[, , 1] = matrix(new_img_expand[,1], 500, 400)
    new_img[, , 2] = matrix(new_img_expand[,2], 500, 400)
    new_img[, , 3] = matrix(new_img_expand[,3], 500, 400)
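
A more compact alternative to the three matrix() calls is to reshape the whole matrix back into an array in one step; this should give the same result, since array() also fills elements in column-major order:

    # equivalent one-step reshape back to a 500 x 400 x 3 array
    new_img_alt = array(new_img_expand, dim = dim(img))
    all.equal(new_img, new_img_alt, check.attributes = FALSE)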

    # plot the new image
    par(mar=rep(0.2, 4))
    plot(c(0, 400), c(0, 500), xaxt = 'n', yaxt = 'n', bty = 'n', pch = '', ylab = '', xlab = '')

    rasterImage(new_img, 0, 0, 400, 500)

With this technique, we can easily reproduce the results with different \(k\) values. As \(k\) increases, we get finer color resolution, and \(k = 30\) seems to recover the original image fairly well. With a larger \(k\), kmeans() may produce the warning “did not converge in 10 iterations”; increasing the iter.max argument (the default is 10) takes care of this.
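
To experiment with different \(k\) values conveniently, we can wrap the recoloring steps above into a small helper function. The sketch below (the name recolor_kmeans is ours, not from any package) raises iter.max so that larger \(k\) values converge without the warning:

    recolor_kmeans <- function(img, k) {
      # vectorize the pixels, cluster them, and replace each pixel by its cluster mean
      img_expand = apply(img, 3, c)
      fit = kmeans(img_expand, centers = k, iter.max = 50, nstart = 3)
      new_img_expand = fit$centers[fit$cluster, ]
      
      # reshape back into the original height x width x 3 array
      array(new_img_expand, dim = dim(img))
    }
    
    # for example, k = 30 recovers the original image fairly well
    par(mar = rep(0.2, 4))
    plot(c(0, 400), c(0, 500), xaxt = 'n', yaxt = 'n', bty = 'n', pch = '', ylab = '', xlab = '')
    rasterImage(recolor_kmeans(img, 30), 0, 0, 400, 500)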