Q: What is k-Means? What are its pros and cons? Please provide a minimal PyTorch implementation from scratch, with comments on the challenging parts.
A: Here’s a crisp overview + a from-scratch PyTorch k-Means (with k-means++ init). I added comments exactly where people usually get stuck (distance calc, assignments, centroid updates, empty clusters).
What is k-Means (in one breath)
Unsupervised clustering that partitions data into k clusters by iteratively:
1. assigning each point to its nearest centroid (Euclidean distance),
2. recomputing each centroid as the mean of its assigned points,
until assignments stabilize.
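To make those two steps concrete, here is a minimal one-iteration sketch (illustrative only; X, centers, and k=3 are throwaway placeholders, not the implementation given below):

import torch

X = torch.randn(10, 2)          # toy data: 10 points in 2D (placeholder)
centers = X[:3].clone()         # pretend these are the k=3 current centroids

# Step 1: assign each point to its nearest centroid (Euclidean distance).
labels = torch.cdist(X, centers).argmin(dim=1)   # [10]

# Step 2: recompute each centroid as the mean of its assigned points.
for j in range(3):
    if (labels == j).any():
        centers[j] = X[labels == j].mean(dim=0)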
Pros / Cons (fast soundbites)
Pros
· Simple, fast, scalable (vectorizes well, mini-batchable).
· Works nicely when clusters are roughly spherical/separated.
· Few parameters (k, max_iter).
Cons
· Must choose k (see the elbow-method sketch after this list).
· Sensitive to initialization; can get stuck in local minima.
· Assumes Euclidean geometry; struggles with non-spherical/imbalanced clusters.
· Sensitive to scale (standardize features!).
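Because choosing k is the most common pain point, here is a hedged elbow-heuristic sketch (it assumes the kmeans function and data tensor X from the implementation below; the inertia is recomputed from the returned centers/labels):

# Elbow heuristic (sketch): run k-Means for several k and look for the bend.
inertias = []
for k in range(1, 9):
    centers, labels = kmeans(X, k=k)
    # Inertia = sum of squared distances from each point to its assigned center.
    inertias.append(((X - centers[labels]) ** 2).sum().item())
# Plot k vs. inertias (e.g., with matplotlib); pick the k where adding another
# cluster stops reducing the inertia substantially.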
Minimal PyTorch Implementation (with comments)
import torch


def kmeans_plus_plus_init(X, k, rng):
    """
    k-means++ initialization:
      1) pick one center uniformly at random
      2) pick each remaining center with probability ∝ squared distance
         to the nearest already-chosen center
    """
    N, D = X.shape
    centers = torch.empty(k, D, device=X.device, dtype=X.dtype)

    # Pick the first center uniformly at random.
    i0 = torch.randint(low=0, high=N, size=(1,), generator=rng, device=X.device)
    centers[0] = X[i0]

    # d2[i] = squared distance from point i to its nearest chosen center so far.
    # Start at +inf; it shrinks each time a new center is folded in below.
    d2 = torch.full((N,), float("inf"), device=X.device, dtype=X.dtype)  # [N]

    for c in range(1, k):
        # Fold the most recently chosen center into d2.
        # ||X - centers[c-1]||^2 via broadcasting: (N,D) - (D,) -> (N,D)
        diff = X - centers[c - 1]                         # [N, D]
        d2 = torch.minimum(d2, (diff * diff).sum(dim=1))  # [N]

        # Sample the next center with probability proportional to distance^2.
        probs = d2 / (d2.sum() + 1e-12)
        next_idx = torch.multinomial(probs, num_samples=1, generator=rng)
        centers[c] = X[next_idx]
    return centers
def kmeans(X, k, max_iter=100, tol=1e-4, use_pp=True, seed=0):
    """
    X: [N, D] data tensor (float32/float64). Standardized features recommended.
    k: number of clusters
    Returns:
      centers: [k, D]
      labels:  [N] cluster assignments
    """
    assert X.ndim == 2
    device = X.device
    rng = torch.Generator(device=device).manual_seed(seed)

    # ----- init centers -----
    if use_pp:
        centers = kmeans_plus_plus_init(X, k, rng)
    else:
        # Random init: sample k data points as the starting centers.
        idx = torch.randint(0, X.size(0), (k,), generator=rng, device=device)  # [k]
        centers = X[idx].clone()  # [k, D]

    prev_inertia = float("inf")
    for it in range(max_iter):
        # ===== E-step: assign points to the nearest center =====
        # d2 = torch.cdist(X, centers, p=2) ** 2  # OK, but a bit slower than below.
        # Pairwise squared distances via ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x·c
        # X2: [N,1], C2: [k], X @ centers.T: [N,k]
        X2 = (X * X).sum(dim=1, keepdim=True)               # [N, 1]
        C2 = (centers * centers).sum(dim=1)                 # [k]
        d2 = X2 + C2.unsqueeze(0) - 2.0 * (X @ centers.T)   # [N, k]

        # argmin along each row = index of the nearest center.
        labels = d2.argmin(dim=1)  # [N]

        # Inertia = sum of min squared distances (the k-means objective).
        inertia = d2.gather(1, labels.unsqueeze(1)).sum().item()

        # ===== M-step: recompute centers as the mean of assigned points =====
        # Accumulate per-cluster sums and counts with index_add_, then divide
        # (no dense one-hot matrix needed).
        centers_sum = torch.zeros_like(centers)                # [k, D]
        counts = torch.zeros(k, device=device, dtype=X.dtype)  # [k]
        centers_sum.index_add_(0, labels, X)                   # row-wise sums per cluster
        counts.index_add_(0, labels, torch.ones_like(labels, dtype=X.dtype))  # Note3

        # Handle empty clusters: re-seed them with random points (rare but important).
        empty = counts == 0
        if empty.any():
            repl_idx = torch.randint(0, X.size(0), (int(empty.sum()),),
                                     generator=rng, device=device)
            counts[empty] = 1.0
            centers_sum[empty] = X[repl_idx]

        centers = centers_sum / counts.clamp_min(1.0).unsqueeze(1)

        # ===== stopping criterion =====
        # Relative improvement of the inertia; prev_inertia is inf on the first
        # pass, so the check is skipped once.
        if prev_inertia < float("inf"):
            rel_impr = (prev_inertia - inertia) / (abs(prev_inertia) + 1e-12)
            if 0.0 <= rel_impr < tol:
                break
        prev_inertia = inertia

    return centers, labels
# -----------------------------
# Tiny demo
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Make three blobs in 2D.
    N = 600
    X = torch.cat([
        torch.randn(N // 3, 2) * 0.5 + torch.tensor([0.0, 0.0]),
        torch.randn(N // 3, 2) * 0.6 + torch.tensor([4.0, 4.0]),
        torch.randn(N // 3, 2) * 0.4 + torch.tensor([-4.0, 3.5]),
    ], dim=0).to(device)

    centers, labels = kmeans(X, k=3, max_iter=100, tol=1e-4, use_pp=True, seed=42)
    print("Centers:\n", centers.cpu())
    print("Counts per cluster:", torch.bincount(labels.cpu(), minlength=3).tolist())
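One usage note, since k-Means is scale-sensitive (a hedged sketch; the standardization step is illustrative and not part of the demo above): standardize each feature before clustering and, if needed, map the centers back to the original units.

# Standardize features (mean 0, var 1) before k-Means.
mean = X.mean(dim=0, keepdim=True)
std = X.std(dim=0, keepdim=True).clamp_min(1e-8)   # guard against zero variance
X_std = (X - mean) / std
centers_std, labels = kmeans(X_std, k=3)
# Undo the scaling to report centers in the original feature units.
centers_orig = centers_std * std + mean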
Notes on the “tricky” parts
· Pairwise distances: we use the identity ‖x − c‖² = ‖x‖² + ‖c‖² − 2 x·c to compute all distances in one shot (vectorized, GPU-friendly); a quick check against torch.cdist is sketched below.
· k-means++ picks diverse initial centers with probabilities ∝ distance², which often dramatically improves convergence and final quality.
· Empty clusters can happen; we re-seed those centers to random points (simple, robust).
· Stopping: we check the relative improvement of the inertia; tol=1e-4 is a good default.
· Scaling: always standardize features (mean 0, var 1) before k-Means in practice.
Note3: counts.index_add_(0, labels, torch.ones(X.size(0), device=X.device, dtype=X.dtype)) is an equivalent way to accumulate the per-cluster counts.
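Quick sanity check on the distance trick (an illustrative snippet with throwaway random X and centers): the expanded identity should match torch.cdist up to floating-point noise.

X = torch.randn(100, 5)
centers = torch.randn(4, 5)
d2_trick = (X * X).sum(1, keepdim=True) + (centers * centers).sum(1) - 2.0 * (X @ centers.T)
d2_cdist = torch.cdist(X, centers) ** 2
print(torch.allclose(d2_trick, d2_cdist, atol=1e-5))   # expected: True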