K-Means clustering is one of the most widely used clustering algorithms in Data Science and Machine Learning. It is simple, yet it often delivers good results. And because clustering is an important step in understanding a dataset, in this article we are going to discuss what clustering is, why we need it, and how k-means clustering can help us in data science.

Interested in more stories like this? Follow me on Twitter at @b_dmarius and I'll post there every new article.

Article overview:

  • What is Clustering
  • Supervised and Unsupervised Machine Learning - What is Unsupervised Machine Learning
  • Clustering applications
  • K-Means Clustering explained
  • K-Means Clustering Algorithm
  • K-Means Clustering Implementation using Scikit-Learn and Python

What is Clustering

Clustering is the task of grouping data into two or more groups based on the properties of the data or, more precisely, based on patterns in the data that may be more or less obvious. The goal is to find patterns that let us confidently place any given item of our dataset in the right group: one where it is similar to the other items in that group but different from the items in other groups.

This means clustering actually consists of two parts: identifying the groups, and placing every item in the correct group as accurately as possible.

Ideally, a clustering algorithm produces groups in which two items from the same group are as similar to each other as possible, while two items from different groups are as different as possible.
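To make "similar" and "different" concrete, here is a small sketch that uses Euclidean distance as the similarity measure (an assumption; other measures exist) on a few made-up 2-D points that are already split into two groups. A good grouping has a small average distance inside the groups and a large average distance between them.

```python
from itertools import combinations
from math import dist

# Made-up 2-D points, already split into two groups
groups = {
    "A": [(1.0, 1.0), (1.5, 2.0), (2.0, 1.5)],
    "B": [(8.0, 8.0), (8.5, 9.0), (9.0, 8.0)],
}

def mean_within(groups):
    """Average distance between two points of the same group."""
    pairs = [dist(p, q) for pts in groups.values() for p, q in combinations(pts, 2)]
    return sum(pairs) / len(pairs)

def mean_between(groups):
    """Average distance between two points from different groups."""
    pairs = [
        dist(p, q)
        for g1, g2 in combinations(groups.values(), 2)
        for p in g1
        for q in g2
    ]
    return sum(pairs) / len(pairs)

within = mean_within(groups)
between = mean_between(groups)
print(f"within: {within:.2f}, between: {between:.2f}")
```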

Cluster example - Source: Wikipedia

A real-world example is customer segmentation. A business selling various types of products or services would find it very difficult to come up with the perfect business strategy for each and every customer. But we can be smart about it: group our customers into a few subgroups, understand what the customers in each subgroup have in common, and adapt our business strategy for every group. Applying the wrong business strategy to a customer could mean losing that customer, so it's important to achieve a good clustering of our market.

Supervised and Unsupervised Machine Learning - What is Unsupervised Machine Learning

Unsupervised Machine Learning is a type of Machine Learning that tries to infer patterns in the data without labeled examples, that is, without knowing in advance what the right answer is. The opposite is Supervised Machine Learning, where we have a labeled training set and the algorithm learns the patterns in the data by matching inputs to predefined outputs.
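The difference shows up directly in the scikit-learn API. In this sketch (made-up points; LogisticRegression is just one example of a supervised model), the supervised model needs a label for every input, while KMeans receives only the inputs and finds the groups by itself.

```python
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression

# Made-up 2-D points forming two well-separated groups
X = [[1.0, 1.0], [1.2, 0.8], [0.9, 1.1], [8.0, 8.0], [8.2, 7.9], [7.9, 8.1]]

# Supervised: we must provide the expected output (label) for every input
y = [0, 0, 0, 1, 1, 1]
classifier = LogisticRegression().fit(X, y)
print(classifier.predict([[1.1, 0.9]]))

# Unsupervised: only the inputs are provided; the groups come from the data
clusterer = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
print(clusterer.labels_)
```

Note that the cluster numbers KMeans assigns are arbitrary names: the first group may be called 0 or 1.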

I am bringing this up because clustering is an Unsupervised Machine Learning task. When applying a clustering algorithm, we don't know the categories a priori (although we can set the number of categories that we want the algorithm to identify).

The categories emerge from the algorithm's analysis of the data. This is why clustering can be called an exploratory machine learning task: we only know the number of categories, not their properties. We can then experiment with different numbers of categories and see whether our data is clustered better or worse.
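One common way to experiment with the number of categories is to fit k-means for several values of K and compare the within-cluster sum of squared distances, which scikit-learn exposes as `inertia_`. This is a sketch on made-up data (three blobs of 20 points each); the blob centres and spread are arbitrary assumptions.

```python
import numpy as np
from sklearn.cluster import KMeans

# Made-up data: three blobs of 20 points each around different centres
rng = np.random.default_rng(0)
centres = np.array([[0, 0], [5, 5], [0, 5]])
X = np.vstack([c + rng.normal(scale=0.5, size=(20, 2)) for c in centres])

# Fit k-means for several values of K and record the within-cluster
# sum of squared distances for each one
inertias = {}
for k in range(1, 7):
    model = KMeans(n_clusters=k, n_init=10, random_state=0).fit(X)
    inertias[k] = model.inertia_
    print(k, round(model.inertia_, 1))
```

The inertia tends to shrink as K grows (with one cluster per point it reaches zero), so it cannot be minimized blindly. A popular heuristic, often called the elbow method, is to pick the K after which the improvement flattens out; for this data that should happen at K = 3.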

Then we have to understand our clusters, which may actually be the most difficult part. Let's reuse the customer segmentation example. Say we have run a clustering algorithm and our customers are now split into 3 groups. But what are those groups? Why has the algorithm decided that these customers belong to one group and those customers to another? This is where you need skilled data scientists working alongside people who understand your business very well. They will look at the data, analyze a few items in each group and try to guess the criteria behind it. Once they find a valid pattern, they can extrapolate from there.
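Interpreting clusters is mostly human work, but a common starting point is to summarize each cluster, for example by the average of every feature. The sketch below does this for hypothetical customer data; the two features (annual spend and store visits per year) and all values are made up for illustration.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical customer data: [annual spend in $, store visits per year]
customers = np.array([
    [200, 2], [250, 3], [180, 1],        # occasional, low spend
    [1200, 25], [1100, 30], [1300, 28],  # frequent, moderate spend
    [5000, 6], [5500, 5], [4800, 4],     # rare, high spend
])

kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(customers)

# Summarise each cluster so a domain expert can try to name it
for label in range(3):
    members = customers[kmeans.labels_ == label]
    spend, visits = members.mean(axis=0)
    print(f"cluster {label}: {len(members)} customers, "
          f"avg spend {spend:.0f}, avg visits {visits:.1f}")
```

Because the two features have very different ranges, spend dominates the distances here; in practice you would usually standardize the features first (for example with scikit-learn's StandardScaler).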

What happens when we get a new customer? We have to place this customer into one of the clusters we already have, so we pass the new customer's data to the trained model, which assigns the customer to one of the existing clusters. Also, after we acquire a large number of new customers, we might need to rebuild our clusters: new clusters may appear and old ones may disappear.
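With scikit-learn, placing a new customer is a call to `predict`, which assigns the point to the cluster with the nearest centroid without changing the existing clusters. Rebuilding means fitting again on the larger dataset. A sketch with the same kind of hypothetical customer data:

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical customer data: [annual spend in $, store visits per year]
customers = np.array([
    [200, 2], [250, 3], [180, 1],
    [1200, 25], [1100, 30], [1300, 28],
    [5000, 6], [5500, 5], [4800, 4],
])
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(customers)

# A new customer goes to the cluster whose centroid is nearest;
# predict() does not move the existing centroids
new_customer = np.array([[1150, 27]])
label = kmeans.predict(new_customer)[0]
print("new customer goes to cluster", label)

# Later, after many new customers, rebuild the clusters by fitting again
updated = KMeans(n_clusters=3, n_init=10, random_state=0).fit(
    np.vstack([customers, new_customer])
)
print(updated.labels_)
```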

Clustering applications

What are some common clustering applications? Before we fall in love with clustering algorithms, we need to understand when we should use them and when we shouldn't.

The most common use case is the one we've already discussed: customer/market segmentation. Companies run this type of analysis all the time so they can understand their customers and markets and tailor their business strategies, services and products for a better fit.

Another common use case is information extraction. In information extraction tasks we often need to find relations between entities, words, documents and so on. If your intuition tells you that relations are more likely to exist between items that are similar to each other, you're right, and that is why clustering our data points can help us figure out where to look for relations. (Note: if you want to read more about information extraction, you can also try this article: Python NLP Tutorial: Information Extraction and Knowledge Graphs).

Another very popular use case is image segmentation. Image segmentation is the task of looking at an image and identifying the different objects in it. We can use clustering to analyze the pixels of the image and determine which pixels belong to which object.
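A minimal way to try this is to treat every pixel as a data point with three features (its red, green and blue values) and cluster the pixels by colour. The sketch below uses a tiny made-up 4x4 image; it clusters by colour only, whereas real segmentation pipelines often add pixel coordinates or use more specialized methods.

```python
import numpy as np
from sklearn.cluster import KMeans

# A tiny made-up 4x4 RGB "image": left half reddish, right half bluish
image = np.zeros((4, 4, 3), dtype=float)
image[:, :2] = [200, 30, 30]
image[:, 2:] = [30, 30, 200]
image += np.random.default_rng(0).normal(scale=5, size=image.shape)  # a little noise

# Treat every pixel as one data point with 3 features (R, G, B)
pixels = image.reshape(-1, 3)
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0).fit(pixels)

# Reshape the labels back to the image grid: each value names a pixel's segment
segments = kmeans.labels_.reshape(image.shape[:2])
print(segments)
```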

K-Means Clustering explained

The K-Means clustering algorithm is an iterative clustering algorithm that tries to assign every data point to exactly one of K clusters, where K is a number we choose in advance.

Like any other clustering algorithm, it tries to make the items in one cluster as similar as possible, while keeping the clusters as different from each other as possible. It does so by minimizing the sum of squared distances between the data points in each cluster and that cluster's centroid. The centroid of a cluster is the mean of all the data points in the cluster, which is where the name K-Means comes from: K clusters, each represented by a mean.
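Written as formulas, with notation of our own choosing (C_k is the set of points in cluster k, mu_k its centroid, n the number of points and x-bar the mean of all of them), k-means minimizes J below. The second line is the standard decomposition of the total spread of the data into a within-cluster part and a between-cluster part.

```latex
J = \sum_{k=1}^{K} \sum_{x_i \in C_k} \lVert x_i - \mu_k \rVert^2,
\qquad
\mu_k = \frac{1}{|C_k|} \sum_{x_i \in C_k} x_i

\sum_{i=1}^{n} \lVert x_i - \bar{x} \rVert^2
= \underbrace{\sum_{k=1}^{K} \sum_{x_i \in C_k} \lVert x_i - \mu_k \rVert^2}_{\text{within clusters } (J)}
+ \underbrace{\sum_{k=1}^{K} |C_k| \, \lVert \mu_k - \bar{x} \rVert^2}_{\text{between clusters}}
```

The left-hand side depends only on the data, so minimizing the within-cluster term automatically maximizes the between-cluster term. This is why minimizing J alone also pushes the clusters apart.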

In more technical terms, we try to make the data within each cluster as homogeneous as possible, while making the clusters as heterogeneous as possible relative to each other. K is the number of clusters we want to obtain, and we can experiment with different values of K until we are satisfied with the results.
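We can check both ideas against scikit-learn on a toy dataset: each final centroid (`cluster_centers_`) is the mean of the points assigned to it (`labels_`), and the reported objective (`inertia_`) is the sum of squared distances to those centroids. The data below is made up.

```python
import numpy as np
from sklearn.cluster import KMeans

# Made-up data with two obvious groups
X = np.array([[1.0, 2.0], [1.5, 1.8], [1.2, 2.2],
              [8.0, 8.0], [8.5, 7.5], [7.8, 8.3]])

kmeans = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)

# Each centroid is the mean of the points assigned to it ...
for k in range(2):
    assert np.allclose(kmeans.cluster_centers_[k], X[kmeans.labels_ == k].mean(axis=0))

# ... and the objective is the sum of squared distances to those centroids
sse = sum(np.sum((X[kmeans.labels_ == k] - kmeans.cluster_centers_[k]) ** 2)
          for k in range(2))
print(sse, kmeans.inertia_)
```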

K-Means Clustering algorithm

The K-Means Clustering algorithm works with a few simple steps.

  1. Choose the number of clusters K.
  2. Initialize K centroids, for example by picking K data points at random, or by randomly assigning each data point to one of the K clusters and taking the mean of each cluster.
  3. Calculate the squared distance between each data point and every centroid.
  4. Reassign each data point to the closest centroid, based on the distances computed in step 3.
  5. Recompute each centroid as the mean of the data points assigned to it.
  6. Repeat steps 3, 4 and 5 until the cluster assignments no longer change.
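The steps above can be sketched in plain NumPy. This is a minimal version of the classic iteration (often called Lloyd's algorithm), assuming random data points as initial centroids; libraries such as scikit-learn add smarter initialization (k-means++ by default) and several restarts, which this sketch leaves out. The function name `kmeans` is our own.

```python
import numpy as np

def kmeans(X, k, max_iter=100, seed=0):
    """Plain k-means following the numbered steps; returns (labels, centroids)."""
    rng = np.random.default_rng(seed)
    # Step 2: pick k distinct data points at random as the initial centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    labels = None
    for _ in range(max_iter):
        # Step 3: squared Euclidean distance from every point to every centroid
        sq_dist = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        # Step 4: assign each point to its closest centroid
        new_labels = sq_dist.argmin(axis=1)
        # Step 6: stop once no assignment changes
        if labels is not None and np.array_equal(new_labels, labels):
            break
        labels = new_labels
        # Step 5: move each centroid to the mean of its points
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:  # keep the old centroid if a cluster empties
                centroids[j] = members.mean(axis=0)
    return labels, centroids

# Usage on made-up data with two obvious groups
X = np.array([[1, 1], [1.5, 2], [1, 1.8], [8, 8], [8.5, 9], [9, 8]])
labels, centroids = kmeans(X, k=2)
print(labels)
print(centroids)
```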

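The time needed to run the K-Means Clustering algorithm depends on the size of the dataset, the number of features, the value of K, and the number of iterations needed before the assignments stop changing, which in turn depends on the patterns in the data and on the initial centroids. A single iteration computes the distance from each of the n data points to each of the K centroids, so its cost grows roughly with n × K × d, where d is the number of features.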
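In scikit-learn, the parameters that bound this work are `max_iter` (iterations per run), `tol` (the convergence tolerance on how much the centroids move between iterations) and `n_init` (how many times the algorithm is restarted with different initial centroids). After fitting, `n_iter_` reports how many iterations the best run actually needed. A sketch on random made-up data:

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))  # made-up data: 1,000 points, 5 features

# Up to 300 iterations per run, 5 runs with different initial centroids
model = KMeans(n_clusters=4, n_init=5, max_iter=300, tol=1e-4, random_state=0).fit(X)

# Number of iterations the best run needed before it stopped
print("iterations:", model.n_iter_)
```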

K-Means Clustering Implementation using Scikit-Learn and Python

We are going to use the Scikit-Learn Python library to run the K-Means Clustering algorithm on a small dataset.

Dataset for K-Means Clustering algorithm

The data consists of 3 texts about London, Paris and Berlin. We are going to extract the summary sections of the Wikipedia articles about these 3 cities and run them through our clustering algorithm.

We will then provide a few new sentences of our own and check whether they are assigned to the right clusters. If they are, we have good evidence that our clustering works.

K-Means Clustering implementation

First let's install our dependencies.

# Sklearn library for our cluster
pip3 install scikit-learn
# We will use nltk(Natural Language Toolkit) to remove stopwords from the text
pip3 install nltk
# We will use the wikipedia library to download our texts from the Wikipedia pages
pip3 install wikipedia

Now let's define a small class to help us gather the texts from the Wikipedia pages; save it as text_fetcher.py so the main script can import it. The class stores each text in a local file named after the page title, so we don't download the texts again every time we run the algorithm: the first run downloads and caches the summaries, and later runs read them from disk.

import os
import wikipedia

class TextFetcher:

    def __init__(self, title):
        self.title = title
        filename = title + ".txt"
        if os.path.exists(filename):
            # Reuse the summary cached by a previous run
            with open(filename, "r", encoding="utf-8") as f:
                self.text = f.read()
        else:
            # auto_suggest=False stops the library from "correcting" the title to another page
            page = wikipedia.page(title, auto_suggest=False)
            self.text = page.summary
            # Cache the summary locally for the next run
            with open(filename, "w", encoding="utf-8") as f:
                f.write(self.text)

    def getText(self):
        return self.text

Now let's build the dataset. We will take the text about each city and remove stopwords. Stopwords are very common words, such as "the", "is" or "and", which carry little meaning on their own, so we usually filter them out before many text processing tasks. Because they appear in almost every text, they would blur the differences between our texts and make it harder to cluster them correctly.

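from text_fetcher import TextFetcher
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
import nltk

# Download the tokenizer models and the stopword list (only needed once);
# newer NLTK releases use "punkt_tab", older ones "punkt"
nltk.download("punkt")
nltk.download("punkt_tab")
nltk.download("stopwords")

# Build the English stopword set once instead of re-reading the list for every token
STOP_WORDS = set(stopwords.words("english"))

def preprocessor(text):
    tokens = word_tokenize(text)
    # Compare in lowercase so that "The" is removed just like "the"
    return " ".join([word for word in tokens if word.lower() not in STOP_WORDS])

if __name__ == "__main__":
    textFetcher = TextFetcher("London")
    text1 = preprocessor(textFetcher.getText())
    textFetcher = TextFetcher("Paris")
    text2 = preprocessor(textFetcher.getText())
    textFetcher = TextFetcher("Berlin")
    text3 = preprocessor(textFetcher.getText())

    docs = [text1, text2, text3]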

Word vectorization techniques

Computers are typically bad at understanding text, but they are very good at working with numbers. Because our dataset is made of words, we need to transform the words into numbers.

Text vectorization techniques map words, or whole documents, to vectors of real numbers that machine learning algorithms can work with, for example for text clustering. Word embeddings are one well-known family of such techniques.

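The Scikit-Learn library contains a few text vectorizers (such as CountVectorizer and TfidfVectorizer); for this article we are going to choose the TfidfVectorizer. TF-IDF (term frequency-inverse document frequency) weights each word by how often it appears in a document, scaled down by how many documents contain it, so words that are common across all documents count for less.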

    tfidf_vectorizer = TfidfVectorizer()
    tfidf = tfidf_vectorizer.fit_transform(docs)
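To see what the vectorizer computes, here is a standalone sketch that rebuilds TF-IDF by hand with scikit-learn's default settings (smoothed idf, raw term counts, each document vector scaled to unit length) and compares it with TfidfVectorizer. The three short documents are made up and are not the Wikipedia summaries.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

# Made-up mini corpus
docs = ["london city river thames", "paris city river seine", "berlin city wall"]

# Raw term counts; CountVectorizer uses the same tokenizer and sorted vocabulary
counts = CountVectorizer().fit_transform(docs).toarray()

# Default smoothed idf in scikit-learn: ln((1 + n) / (1 + df)) + 1
n = counts.shape[0]
df = (counts > 0).sum(axis=0)
idf = np.log((1 + n) / (1 + df)) + 1

# tf * idf, then scale every document vector to unit length (L2 norm)
manual = counts * idf
manual = manual / np.linalg.norm(manual, axis=1, keepdims=True)

sklearn_tfidf = TfidfVectorizer().fit_transform(docs).toarray()
print(np.allclose(manual, sklearn_tfidf))  # should print True
```

A word that appears in every document, like "city" here, gets the smallest idf, so the words specific to one city dominate each document's vector.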

Now it's time to apply the K-Means clustering algorithm. Luckily, Scikit-Learn has a very good implementation of K-Means, so we are going to use that. Because we know that we want to group our texts into 3 categories (one for each city), we will set K to 3.

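    # n_init and random_state are set explicitly for repeatable results
    kmeans = KMeans(n_clusters=3, n_init=10, random_state=42).fit(tfidf)
    # labels_ holds the cluster assigned to each of our 3 documents
    print(kmeans.labels_)

    # Output: [0 1 2]  (the label numbers may be ordered differently on your machine)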

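I know, it's that simple! Now what does our output mean? Those 3 values are the cluster labels of our 3 documents: each city's text ended up in its own cluster. The numbers themselves are just names for the clusters and carry no meaning.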

To test the clusters, we can now provide a few new sentences for which we know which city they are about, and check that sentences about the same city end up in the same cluster while sentences about different cities do not. We must not forget to vectorize these sentences with the same fitted vectorizer, so that the model can understand them.

    # Label numbers may differ between runs; what matters is which sentences share a label
    test = ["This one is about London.", "London is a beautiful city", "I love London"]
    results = kmeans.predict(tfidf_vectorizer.transform(test))
    print(results)
    # Prints [0 0 0]

    test = ["This one is about Paris.", "Paris is a beautiful city", "I love Paris"]
    results = kmeans.predict(tfidf_vectorizer.transform(test))
    print(results)
    # Prints [2 2 2]

    test = ["This one is about Berlin.", "Berlin is a beautiful city", "I love Berlin"]
    results = kmeans.predict(tfidf_vectorizer.transform(test))
    print(results)
    # Prints [1 1 1]

    # Vienna matches none of the three cities, but k-means always assigns
    # a text to its nearest centroid, so it still receives a label
    test = ["This is about London", "This is about Paris", "This is about Vienna"]
    results = kmeans.predict(tfidf_vectorizer.transform(test))
    print(results)
    # Prints [0 2 1]

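And it seems our clustering worked! Now suppose we get another text that we know nothing about. We can pass it through the fitted model and see which cluster it falls into. Keep in mind that k-means always picks one of the existing clusters, even for a text about none of our cities, as the Vienna example shows. With that caveat, this is a simple and efficient way to group texts without any labeled data.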
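One way to spot such texts is to look at how far they are from the centroids, which `KMeans.transform` returns. In this standalone sketch the three documents are made-up stand-ins for the preprocessed city summaries. A sentence whose words are all outside the vocabulary becomes an all-zero TF-IDF vector; since each centroid here is a single unit-length document vector, the sentence is at the same distance (1.0) from all of them, so the label it receives is essentially arbitrary.

```python
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer

# Made-up stand-ins for the three preprocessed city summaries
docs = [
    "london capital england thames",
    "paris capital france seine",
    "berlin capital germany spree",
]
vectorizer = TfidfVectorizer()
tfidf = vectorizer.fit_transform(docs)
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(tfidf)

# A sentence that shares no vocabulary with the training documents
unknown = vectorizer.transform(["Vienna opera coffee"])
print("non-zero features:", unknown.nnz)  # every word is out of vocabulary
print("assigned cluster:", kmeans.predict(unknown)[0])  # still gets a label

# Distance to each centroid: equal distances mean no cluster is a real match
print(kmeans.transform(unknown))
```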


Today we discussed the K-Means Clustering algorithm. We first went through a general overview of clustering algorithms and Unsupervised Machine Learning, then we discussed the K-Means algorithm and implemented it using the Scikit-Learn Python library.

Thank you so much for reading this!