Q: What is k-Means? What are its pros and cons? Please provide a minimal PyTorch implementation from scratch, with the necessary comments on the challenging parts.
A: Here’s a crisp overview plus a from-scratch PyTorch k-Means (with k-means++ init). I added comments exactly where people usually get stuck (distance calculation, assignments, centroid updates, empty clusters).
What is k-Means (in one breath)
Unsupervised clustering that partitions data into k clusters by iteratively:
1. assigning each point to its nearest centroid (Euclidean distance),
2. recomputing each centroid as the mean of its assigned points,
until assignments stabilize.
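For reference, the objective both steps are driving down is the inertia (within-cluster sum of squared distances); this is the standard k-means objective, stated here for completeness:

J(c_1, \dots, c_k) = \sum_{i=1}^{N} \min_{j} \| x_i - c_j \|^2

The assignment step and the mean-update step can each only decrease J, which is why the iteration converges, though only to a local minimum.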
Pros / Cons (fast soundbites)
Pros
· Simple, fast, scalable (vectorizes well, mini-batchable).
· Works nicely when clusters are roughly spherical and separated.
· Few parameters (k, max_iter).
Cons
· Must choose k.
· Sensitive to initialization; can get stuck in local minima.
· Assumes Euclidean geometry; struggles with non-spherical/imbalanced clusters.
· Sensitive to scale (standardize features! a minimal standardization sketch follows this list).
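Since feature scale matters so much, here is a minimal standardization sketch in PyTorch (my own illustrative helper, not part of the implementation below; it assumes X is an [N, D] float tensor):

import torch

def standardize(X):
    # Per-feature zero mean, unit variance; clamp guards against zero-variance columns
    mean = X.mean(dim=0, keepdim=True)                 # [1, D]
    std = X.std(dim=0, keepdim=True).clamp_min(1e-8)   # [1, D]
    return (X - mean) / std

# X = standardize(X)  # do this before calling kmeans(X, k)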
Minimal PyTorch Implementation (with comments)
import torch

def kmeans_plus_plus_init(X, k, rng):
    """
    k-means++ initialization:
    1) pick one center uniformly at random
    2) pick remaining centers with prob ∝ distance^2 to the nearest chosen center
    """
    N, D = X.shape
    centers = torch.empty(k, D, device=X.device, dtype=X.dtype)
    # Pick the first center uniformly at random.
    i0 = torch.randint(low=0, high=N, size=(1,), generator=rng, device=X.device)
    centers[0] = X[i0]
    # Squared distance of each point to its nearest chosen center so far.
    # Start at +inf; updated to the running minimum as centers are added.
    d2 = torch.full((N,), float("inf"), device=X.device, dtype=X.dtype)  # [N]
    for c in range(1, k):
        # Fold in the most recently chosen center:
        # ||X - centers[c-1]||^2 via broadcasting (N,D) - (1,D) -> (N,D)
        diff = X - centers[c - 1]                         # [N, D]
        d2 = torch.minimum(d2, (diff * diff).sum(dim=1))  # [N]
        # Sample the next center with probability proportional to distance^2.
        probs = d2 / (d2.sum() + 1e-12)
        next_idx = torch.multinomial(probs, num_samples=1, generator=rng)
        centers[c] = X[next_idx]
    return centers

def kmeans(X, k, max_iter=100, tol=1e-4, use_pp=True, seed=0):
    """
    X: [N, D] data tensor (float32/float64). Standardized features recommended.
    k: number of clusters
    Returns:
      centers: [k, D]
      labels:  [N] cluster assignments
    """
    assert X.ndim == 2
    device = X.device
    rng = torch.Generator(device=device).manual_seed(seed)

    # ----- init centers -----
    if use_pp:
        centers = kmeans_plus_plus_init(X, k, rng)
    else:
        # Random init from the data
        idx = torch.randint(0, X.size(0), (k,), generator=rng, device=device)  # [k]
        centers = X[idx].clone()  # [k, D]

    prev_inertia = float("inf")
    for it in range(max_iter):
        # ===== E-step: assign points to the nearest center =====
        # d2 = torch.cdist(X, centers, p=2) ** 2   # OK, but a bit slower than the expansion below
        # Pairwise squared distances via ||x||^2 + ||c||^2 - 2 x·c
        # X2: [N,1], C2: [k], X @ centers.T: [N,k]
        X2 = (X * X).sum(dim=1, keepdim=True)               # [N, 1]
        C2 = (centers * centers).sum(dim=1)                 # [k]
        d2 = X2 + C2.unsqueeze(0) - 2.0 * (X @ centers.T)   # [N, k]
        d2 = d2.clamp_min(0.0)  # guard against tiny negatives from floating-point cancellation
        # argmin along each row finds the index of the nearest center.
        labels = d2.argmin(dim=1)  # [N]
        # Inertia = sum of min squared distances (the objective)
        inertia = d2.gather(1, labels.unsqueeze(1)).sum().item()
        # equivalently: torch.gather(d2, 1, labels.unsqueeze(1)).sum().item()

        # ===== M-step: recompute centers as the mean of assigned points =====
        # Accumulate per-cluster sums and counts, then divide; no dense one-hot matrix needed.
        centers_sum = torch.zeros_like(centers)                # [k, D]
        counts = torch.zeros(k, device=device, dtype=X.dtype)  # [k]
        # index_add_ accumulates the rows of X into centers_sum[labels]
        centers_sum.index_add_(0, labels, X)                                   # sum per cluster
        counts.index_add_(0, labels, torch.ones_like(labels, dtype=X.dtype))   # Note3

        # Handle empty clusters: re-seed them to random points (rare but important)
        empty = counts == 0
        if empty.any():
            # Reinitialize empty centers from random data points
            # (equivalent to setting centers[empty] = X[repl_idx] after the division below)
            repl_idx = torch.randint(0, X.size(0), (int(empty.sum()),), generator=rng, device=device)
            counts[empty] = 1.0
            centers_sum[empty] = X[repl_idx]
        centers = centers_sum / counts.clamp_min(1.0).unsqueeze(1)

        # ===== stopping criterion =====
        # Relative improvement of inertia; simple and effective
        rel_impr = (prev_inertia - inertia) / (abs(prev_inertia) + 1e-12)
        if rel_impr >= 0 and rel_impr < tol:
            break
        prev_inertia = inertia
    return centers, labels
# -----------------------------
# Tiny demo
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Make three blobs in 2D
    N = 600
    X = torch.cat([
        torch.randn(N // 3, 2) * 0.5 + torch.tensor([0.0, 0.0]),
        torch.randn(N // 3, 2) * 0.6 + torch.tensor([4.0, 4.0]),
        torch.randn(N // 3, 2) * 0.4 + torch.tensor([-4.0, 3.5]),
    ], dim=0).to(device)
    centers, labels = kmeans(X, k=3, max_iter=100, tol=1e-4, use_pp=True, seed=42)
    print("Centers:\n", centers.cpu())
    print("Counts per cluster:", torch.bincount(labels.cpu(), minlength=3).tolist())
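If you later need to label new points with the learned centers, a minimal sketch follows (assign_to_centers is a hypothetical helper added here for illustration, not part of the code above; it reuses the same squared-distance expansion as the E-step):

def assign_to_centers(X_new, centers):
    # ||x||^2 + ||c||^2 - 2 x·c, then argmin over centers
    X2 = (X_new * X_new).sum(dim=1, keepdim=True)           # [M, 1]
    C2 = (centers * centers).sum(dim=1)                     # [k]
    d2 = X2 + C2.unsqueeze(0) - 2.0 * (X_new @ centers.T)   # [M, k]
    return d2.argmin(dim=1)                                 # [M] labels

# Example: labels_new = assign_to_centers(torch.randn(5, 2).to(centers.device), centers)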
Notes on the “tricky” parts
· Pairwise distances: we used the identity \|x-c\|^2 = \|x\|^2 + \|c\|^2 - 2x\cdot c to compute all distances in one shot (vectorized, GPU-friendly); a small sanity check against torch.cdist is sketched after these notes.
· k-means++ picks diverse initial centers using probabilities ∝ distance², which often dramatically improves convergence and final quality.
· Empty clusters can happen; we re-seed those centers to random points (simple, robust).
· Stopping: we check the relative improvement of inertia; tol=1e-4 is a good default.
· Scaling: always standardize features (mean 0, var 1) before k-Means in practice.
Note3: the counts update is equivalent to counts.index_add_(0, labels, torch.ones(X.size(0), device=X.device, dtype=X.dtype)); either way, the values added must be a float tensor of shape [N], one per point.
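As a quick sanity check of the pairwise-distance identity used in the E-step, the expansion can be compared against torch.cdist (a sketch assuming X and centers from the demo are in scope; the tolerance is arbitrary):

X2 = (X * X).sum(dim=1, keepdim=True)
C2 = (centers * centers).sum(dim=1)
d2_fast = (X2 + C2.unsqueeze(0) - 2.0 * (X @ centers.T)).clamp_min(0.0)
d2_ref = torch.cdist(X, centers, p=2) ** 2
print(torch.allclose(d2_fast, d2_ref, atol=1e-4))  # expect True up to float rounding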