
Monday, November 3, 2025

What is k-Means? What are its pros and cons?

 

Q: What is k-Means? What are its pros and cons? Please provide a minimal PyTorch implementation from scratch, with comments on the challenging parts.


A: Here’s a crisp overview + a from-scratch PyTorch k-Means (with k-means++ init). I added comments exactly where people usually get stuck (distance calc, assignments, centroid updates, empty clusters).


What is k-Means (in one breath)

Unsupervised clustering that partitions data into k clusters by iteratively:

1.       assigning each point to its nearest centroid (Euclidean distance),

2.       recomputing each centroid as the mean of its assigned points,
until assignments stabilize.
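
The quantity being minimized is the inertia

(J = \sum_i \min_j |x_i - \mu_j|^2)

where the \mu_j are the centroids; both steps above can only lower it, which is why the iteration converges. This is exactly the value the code below reports as inertia.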


Pros / Cons (fast soundbites)

Pros

·       Simple, fast, scalable (vectorizes well, mini-batchable).

·       Works nicely when clusters are roughly spherical/separated.

·       Few parameters (k, max_iter).

Cons

·       Must choose k (a quick elbow-method sketch for picking k appears at the end of this post).

·       Sensitive to initialization; can get stuck in local minima.

·       Assumes Euclidean geometry; struggles with non-spherical/imbalanced clusters.

·       Sensitive to feature scale (standardize features first; see the sketch right below).
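
A minimal sketch of that preprocessing step (the toy tensor here is only for illustration; use your own data):

import torch

# Toy data with wildly different feature scales
X = torch.randn(100, 3) * torch.tensor([1.0, 10.0, 100.0])

# Standardize each feature to zero mean / unit variance before clustering;
# clamp_min avoids division by zero for constant features.
X_std = (X - X.mean(dim=0)) / X.std(dim=0).clamp_min(1e-8)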


Minimal PyTorch Implementation (with comments)



import torch


def kmeans_plus_plus_init(X, k, rng):
    """
    k-means++ initialization:
    1) pick one center uniformly at random
    2) pick each remaining center with probability proportional to the squared
       distance to the nearest already-chosen center
    """
    N, D = X.shape
    centers = torch.empty(k, D, device=X.device, dtype=X.dtype)

    # Pick the first center uniformly at random from the data.
    i0 = torch.randint(low=0, high=N, size=(1,), generator=rng, device=X.device)
    centers[0] = X[i0]

    # Squared distance from each point to its nearest chosen center so far.
    # Start at +inf; tightened each time a new center is added.
    d2 = torch.full((N,), float("inf"), device=X.device, dtype=X.dtype)  # [N]

    for c in range(1, k):
        # Fold in the most recently chosen center:
        # ||X - centers[c-1]||^2 via broadcasting (N,D) - (1,D) -> (N,D)
        diff = X - centers[c - 1]                           # [N, D]
        d2 = torch.minimum(d2, (diff * diff).sum(dim=1))    # [N]

        # Sample the next center with probability proportional to distance^2
        probs = d2 / (d2.sum() + 1e-12)
        next_idx = torch.multinomial(probs, num_samples=1, generator=rng)
        centers[c] = X[next_idx]

    return centers


def kmeans(X, k, max_iter=100, tol=1e-4, use_pp=True, seed=0):
    """
    X: [N, D] data tensor (float32/float64). Standardized features recommended.
    k: number of clusters
    Returns:
      centers: [k, D]
      labels:  [N] cluster assignments
    """
    assert X.ndim == 2
    device = X.device
    rng = torch.Generator(device=device).manual_seed(seed)

    # ----- init centers -----
    if use_pp:
        centers = kmeans_plus_plus_init(X, k, rng)
    else:
        # Plain random init: sample k data points as starting centers
        idx = torch.randint(0, X.size(0), (k,), generator=rng, device=device)  # [k]
        centers = X[idx].clone()  # [k, D]

    prev_inertia = float("inf")

    for it in range(max_iter):
        # ===== E-step: assign each point to its nearest center =====
        # d2 = torch.cdist(X, centers, p=2) ** 2   # works too, but a bit slower than the expansion below
        # Pairwise squared distances via ||x||^2 + ||c||^2 - 2 x.c
        X2 = (X * X).sum(dim=1, keepdim=True)                # [N, 1]
        C2 = (centers * centers).sum(dim=1)                   # [k]
        d2 = X2 + C2.unsqueeze(0) - 2.0 * (X @ centers.T)     # [N, k]
        # Index of the smallest value in each row = assigned cluster
        labels = d2.argmin(dim=1)  # [N]
        # Inertia = sum of squared distances to assigned centers (the k-means objective)
        inertia = d2.gather(1, labels.unsqueeze(1)).sum().item()
        # (equivalently: torch.gather(d2, 1, labels.unsqueeze(1)).sum().item())

        # ===== M-step: recompute centers as the mean of assigned points =====
        # Accumulate per-cluster sums and counts with index_add_,
        # avoiding a dense [N, k] one-hot matrix.
        centers_sum = torch.zeros_like(centers)                 # [k, D]
        counts = torch.zeros(k, device=device, dtype=X.dtype)   # [k]
        centers_sum.index_add_(0, labels, X)                    # per-cluster row sums
        counts.index_add_(0, labels, torch.ones_like(labels, dtype=X.dtype))  # Note3

        # Handle empty clusters: re-seed them with random data points (rare but important)
        empty = counts == 0
        if empty.any():
            repl_idx = torch.randint(0, X.size(0), (int(empty.sum()),), generator=rng, device=device)
            counts[empty] = 1.0
            centers_sum[empty] = X[repl_idx]

        centers = centers_sum / counts.clamp_min(1.0).unsqueeze(1)

        # ===== stopping criterion =====
        # Relative improvement of the inertia; simple and effective.
        # (On the first iteration prev_inertia is inf, so rel_impr is nan and we never break early.)
        rel_impr = (prev_inertia - inertia) / (abs(prev_inertia) + 1e-12)
        if 0 <= rel_impr < tol:
            break
        prev_inertia = inertia

    return centers, labels


# -----------------------------
# Tiny demo
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Three Gaussian blobs in 2D
    N = 600
    X = torch.cat([
        torch.randn(N // 3, 2) * 0.5 + torch.tensor([0.0, 0.0]),
        torch.randn(N // 3, 2) * 0.6 + torch.tensor([4.0, 4.0]),
        torch.randn(N // 3, 2) * 0.4 + torch.tensor([-4.0, 3.5]),
    ], dim=0).to(device)

    centers, labels = kmeans(X, k=3, max_iter=100, tol=1e-4, use_pp=True, seed=42)
    print("Centers:\n", centers.cpu())
    print("Counts per cluster:", torch.bincount(labels.cpu(), minlength=3).tolist())

Notes on the “tricky” parts

·       Pairwise distances: we used the identity
(|x-c|^2 = |x|^2 + |c|^2 - 2x\cdot c)
to compute all distances in one shot (vectorized, GPU-friendly); a quick numerical check against torch.cdist is sketched after these notes.

·       k-means++ picks diverse initial centers with probabilities proportional to distance²; this often dramatically improves convergence speed and final quality.

·       Empty clusters can happen; we re-seed those centers to random points (simple, robust).

·       Stopping: we check relative improvement of inertia; tol=1e-4 is a good default.

·       Scaling: always standardize features (mean 0, var 1) before k-Means in practice.
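
A quick numerical sanity check of that distance identity against torch.cdist (a small sketch; the shapes and tolerance are arbitrary):

import torch

X = torch.randn(100, 5)
C = torch.randn(4, 5)
# Expansion used in the kmeans loop above
d2_fast = (X * X).sum(1, keepdim=True) + (C * C).sum(1) - 2.0 * (X @ C.T)
# Reference: plain pairwise Euclidean distances, squared
d2_ref = torch.cdist(X, C) ** 2
print(torch.allclose(d2_fast, d2_ref, atol=1e-4))  # expected: True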

Note3: an equivalent way to accumulate the per-cluster counts is counts.index_add_(0, labels, torch.ones(X.size(0), device=X.device, dtype=X.dtype)).
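
Finally, since "must choose k" tops the cons list, here is a sketch of the common elbow heuristic built on the kmeans function above; the toy X, the helper inertia_of, and the candidate range 1..7 are illustrative choices, not part of the original code. Run k-Means for several k and pick the k where the inertia curve stops dropping sharply.

# Elbow-method sketch: watch where the inertia curve "bends".
def inertia_of(X, centers, labels):
    # Sum of squared distances from each point to its assigned center.
    return ((X - centers[labels]) ** 2).sum().item()

X = torch.randn(300, 2)  # replace with your own (standardized) data
for k_try in range(1, 8):
    c, lab = kmeans(X, k=k_try, max_iter=100, tol=1e-4, use_pp=True, seed=0)
    print(k_try, round(inertia_of(X, c, lab), 2))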