def scaled_dot_attn(Q, K, V, mask=None):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        Q: queries, shape [B, T_q, d_k] (any leading batch dims work, since
            torch.matmul broadcasts them).
        K: keys, shape [B, T_k, d_k].
        V: values, shape [B, T_k, d_v].
        mask: optional tensor broadcastable to the score shape [B, T_q, T_k],
            e.g. [B, 1, T_k] or [B, T_q, T_k]. True / nonzero = keep,
            False / zero = block.

    Returns:
        Tensor of shape [B, T_q, d_v]. A query row whose keys are all blocked
        produces an all-zero output instead of NaN.
    """
    # Scale by sqrt(d_k) so the logits' variance does not grow with d_k.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))  # [B, T_q, T_k]
    if mask is None:
        return torch.matmul(torch.softmax(scores, dim=-1), V)

    # Accept 0/1 integer or float masks as well as bool masks.
    keep = mask.to(torch.bool)
    scores = scores.masked_fill(~keep, float("-inf"))
    A = torch.softmax(scores, dim=-1)  # [B, T_q, T_k]
    # Softmax over a row that is entirely -inf is NaN; such a query attends to
    # nothing, so give it zero weight everywhere instead of poisoning the output.
    # The mask's last dim is T_k or 1, so any() over it is correct either way.
    A = A.masked_fill(~keep.any(dim=-1, keepdim=True), 0.0)
    return torch.matmul(A, V)  # [B, T_q, d_v]
def kmeans(X, k, iterations=200, ratio=1e-4):
    """Cluster the rows of ``X`` into ``k`` groups with Lloyd's k-means.

    Initial centers are ``k`` distinct rows of ``X`` chosen with a generator
    seeded to 0, so results are deterministic for a given input and device.

    Args:
        X: 2-D floating-point tensor of shape [N, D].
        k: number of clusters, 1 <= k <= N.
        iterations: maximum number of assignment/update rounds.
        ratio: relative tolerance. Iteration stops once the inertia (sum of
            squared distances to the assigned center) improves by no more
            than ``ratio * previous_inertia``.

    Returns:
        (centers, labels): centers of shape [k, D], and labels of shape [N].
        labels[i] is the index of the center nearest to X[i].

    Raises:
        ValueError: if X is not 2-D or k is outside [1, N].
    """
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D [N, D], got shape {tuple(X.shape)}")
    N, D = X.shape
    if not 1 <= k <= N:
        raise ValueError(f"k must be in [1, {N}], got {k}")
    device = X.device
    rng = torch.Generator(device=device).manual_seed(0)

    # randperm gives k *distinct* starting points (randint could repeat one).
    init_idx = torch.randperm(N, generator=rng, device=device)[:k]
    centers = X[init_idx].clone()  # [k, D]

    prev_inertia = None
    for _ in range(iterations):
        # E-step: assign each point to its nearest center.
        d2 = torch.cdist(X, centers, p=2) ** 2  # [N, k]
        min_d2, labels = d2.min(dim=1)  # [N], [N]
        inertia = min_d2.sum().item()
        if prev_inertia is not None and prev_inertia - inertia <= ratio * prev_inertia:
            break
        prev_inertia = inertia

        # M-step: move each center to the mean of its assigned points.
        sums = torch.zeros_like(centers)  # [k, D]
        counts = torch.zeros(k, device=device, dtype=X.dtype)  # [k]
        sums.index_add_(0, labels, X)
        counts.index_add_(0, labels, torch.ones_like(labels, dtype=X.dtype))

        # An empty cluster has no mean; re-seed it at a random data point.
        empty = counts == 0
        if empty.any():
            idxs = torch.randint(0, N, (int(empty.sum()),), generator=rng, device=device)
            sums[empty] = X[idxs]
            counts[empty] = 1.0
        centers = sums / counts.unsqueeze(1)

    # Re-assign against the final centers so labels and centers agree
    # (also defines labels when iterations == 0).
    labels = torch.cdist(X, centers, p=2).argmin(dim=1)
    return centers, labels
class KNNClassifier:
    """k-nearest-neighbour classifier using squared Euclidean distance.

    ``fit`` only stores the training data. All the work happens in
    ``predict``, which labels each query by majority vote among its ``k``
    closest training points.
    """

    def __init__(self, k=3):
        """Create an unfitted classifier that votes among ``k`` neighbours."""
        self.k = k
        self.x_train = None  # [N_train, D] training features, set by fit()
        self.y_train = None  # [N_train] training labels, set by fit()

    def fit(self, x_train, y_train):
        """Store training features [N_train, D] and labels [N_train]."""
        self.x_train = x_train
        self.y_train = y_train

    def predict(self, x_test):
        """Return the majority-vote label for each row of ``x_test``.

        Args:
            x_test: query features of shape [N_test, D].

        Returns:
            Tensor of shape [N_test] with the predicted labels. If ``k``
            exceeds the number of training points, every training point votes.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.x_train is None or self.y_train is None:
            raise RuntimeError("KNNClassifier.predict called before fit")
        # Squared distance via ||a||^2 + ||b||^2 - 2 a.b. Only the ordering
        # matters, so the square root is skipped.
        dists = (
            x_test.pow(2).sum(dim=1, keepdim=True)  # [N_test, 1]
            + self.x_train.pow(2).sum(dim=1).unsqueeze(0)  # [1, N_train]
            - 2 * x_test @ self.x_train.T  # [N_test, N_train]
        )
        # topk raises when k > N_train; vote over all training points instead.
        k = min(self.k, self.x_train.size(0))
        knn_idx = dists.topk(k, dim=1, largest=False).indices  # [N_test, k]
        knn_labels = self.y_train[knn_idx]  # [N_test, k]
        # Majority vote across each test point's k nearest neighbours.
        return torch.mode(knn_labels, dim=1).values  # [N_test]
if __name__ == "__main__":
    # Toy problem: two well-separated 2-D clusters of three points each,
    # plus one query point near the middle of each cluster.
    train_points = torch.tensor(
        [
            [1.0, 1.0], [1.0, 2.0], [2.0, 1.0],  # class 0
            [5.0, 5.0], [5.0, 6.0], [6.0, 5.0],  # class 1
        ]
    )
    train_labels = torch.tensor([0] * 3 + [1] * 3)
    query_points = torch.tensor([[1.5, 1.5], [5.5, 5.5]])

    classifier = KNNClassifier(k=3)
    classifier.fit(train_points, train_labels)
    print("Predictions:", classifier.predict(query_points).tolist())