Skip to content

Commit c212ac0

Browse files
committed
added batch capabilities by passing cluster_centers to kmeans
1 parent 514404c commit c212ac0

9 files changed

+2267
-671
lines changed

.ipynb_checkpoints/batch_processing-checkpoint.ipynb

+334
Large diffs are not rendered by default.

.ipynb_checkpoints/cpu_vs_gpu-checkpoint.ipynb

+364
Large diffs are not rendered by default.

.ipynb_checkpoints/example-checkpoint.ipynb

+343
Large diffs are not rendered by default.

batch_processing.ipynb

+334
Large diffs are not rendered by default.

cpu_vs_gpu.ipynb

+357-346
Large diffs are not rendered by default.

example.ipynb

+334-323
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import numpy as np
2+
import torch
3+
from tqdm import tqdm
4+
5+
# ToDo: Can't choose a cluster if two points are too close to each other; that's where the NaNs come from
6+
7+
8+
def initialize(X, num_clusters):
    """
    initialize cluster centers by sampling distinct rows of X
    :param X: (torch.tensor) matrix, one sample per row
    :param num_clusters: (int) number of clusters
    :return: (torch.tensor) initial state, num_clusters rows taken from X
    :raises ValueError: if num_clusters is not between 1 and the number of samples
    """
    num_samples = len(X)
    if not 1 <= num_clusters <= num_samples:
        raise ValueError(
            f'num_clusters must be between 1 and the number of samples '
            f'({num_samples}), got {num_clusters}'
        )

    # sample without replacement so the same row is never picked twice;
    # np.random is kept so np.random.seed() still controls the initialization
    indices = np.random.choice(num_samples, num_clusters, replace=False)

    # indexing with an index array yields a new tensor, not a view of X
    initial_state = X[indices]
    return initial_state
19+
20+
21+
def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        cluster_centers=None,
        tol=1e-4,
        device=torch.device('cpu')
):
    """
    perform kmeans
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param cluster_centers: (torch.tensor or None) centers from a previous run to resume
        from; None or an empty list starts from a random initialization [default: None]
    :param tol: (float) threshold [default: 0.0001]
    :param device: (torch.device) device [default: cpu]
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
    :raises NotImplementedError: if distance is not one of the supported options
    """
    print(f'running k-means on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = pairwise_distance
    elif distance == 'cosine':
        pairwise_distance_function = pairwise_cosine
    else:
        raise NotImplementedError

    # convert to float and transfer to device
    X = X.float().to(device)
    num_samples = X.shape[0]

    # initialize
    # None is the "not provided" marker; an empty list is still accepted
    # because that was the previous default value of this parameter
    no_centers_given = cluster_centers is None or (
        isinstance(cluster_centers, (list, tuple)) and len(cluster_centers) == 0
    )
    if no_centers_given:
        initial_state = initialize(X, num_clusters)
    else:
        print('resuming')
        # find data point closest to each of the given cluster centers
        given_centers = torch.as_tensor(cluster_centers, dtype=X.dtype, device=device)
        dis = pairwise_distance_function(X, given_centers, device=device)
        # force a (samples, centers) layout in case the helper squeezed a
        # singleton dimension
        dis = dis.reshape(num_samples, -1)
        choice_points = torch.argmin(dis, dim=0)
        initial_state = X[choice_points]
    initial_state = initial_state.to(device)

    iteration = 0
    tqdm_meter = tqdm(desc='[running kmeans]')
    try:
        while True:
            # pass the device explicitly: the helpers default to the cpu and
            # would otherwise move the data off the requested device
            dis = pairwise_distance_function(X, initial_state, device=device)
            dis = dis.reshape(num_samples, -1)

            choice_cluster = torch.argmin(dis, dim=1)

            initial_state_pre = initial_state.clone()

            for index in range(num_clusters):
                members = X[choice_cluster == index]

                # an empty cluster has no mean (it would be NaN and the
                # convergence test below could never succeed), so its
                # previous center is kept
                if members.shape[0] == 0:
                    continue

                initial_state[index] = members.mean(dim=0)

            center_shift = torch.sum(
                torch.sqrt(
                    torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
                ))

            # increment iteration
            iteration = iteration + 1

            # update tqdm meter
            tqdm_meter.set_postfix(
                iteration=f'{iteration}',
                center_shift=f'{center_shift ** 2:0.6f}',
                tol=f'{tol:0.6f}'
            )
            tqdm_meter.update()
            if center_shift ** 2 < tol:
                break
    finally:
        # release the progress meter even if an iteration raises
        tqdm_meter.close()

    return choice_cluster.cpu(), initial_state.cpu()
101+
102+
103+
def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean',
        device=torch.device('cpu')
):
    """
    predict using cluster centers
    :param X: (torch.tensor) matrix
    :param cluster_centers: (torch.tensor) cluster centers
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param device: (torch.device) device [default: 'cpu']
    :return: (torch.tensor) cluster ids
    :raises NotImplementedError: if distance is not one of the supported options
    """
    print(f'predicting on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = pairwise_distance
    elif distance == 'cosine':
        pairwise_distance_function = pairwise_cosine
    else:
        raise NotImplementedError

    # convert to float and transfer to device
    X = X.float().to(device)

    # pass the device explicitly: the helpers default to the cpu and would
    # otherwise move the data off the requested device
    dis = pairwise_distance_function(X, cluster_centers, device=device)

    # force a (samples, centers) layout in case the helper squeezed a
    # singleton dimension (single sample or single center)
    dis = dis.reshape(X.shape[0], -1)
    choice_cluster = torch.argmin(dis, dim=1)

    return choice_cluster.cpu()
136+
137+
138+
def pairwise_distance(data1, data2, device=torch.device('cpu')):
    """
    squared euclidean distance between every row of data1 and every row of data2
    :param data1: (torch.tensor) matrix with N rows and M columns
    :param data2: (torch.tensor) matrix with K rows and M columns
    :param device: (torch.device) device the computation runs on [default: cpu]
    :return: (torch.tensor) N*K matrix of squared distances
    """
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N*1*M
    A = data1.unsqueeze(dim=1)

    # 1*K*M
    B = data2.unsqueeze(dim=0)

    # broadcast to N*K*M, then reduce over the feature axis -> N*K.
    # No squeeze here: it would drop the row or column axis whenever N or K
    # is 1 and break callers that take argmin along a fixed dimension.
    dis = ((A - B) ** 2.0).sum(dim=-1)
    return dis
152+
153+
154+
def pairwise_cosine(data1, data2, device=torch.device('cpu')):
    """
    cosine distance (1 - cosine similarity) between every row of data1 and every row of data2
    :param data1: (torch.tensor) matrix with N rows and M columns
    :param data2: (torch.tensor) matrix with K rows and M columns
    :param device: (torch.device) device the computation runs on [default: cpu]
    :return: (torch.tensor) N*K matrix of cosine distances
    """
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N*1*M
    A = data1.unsqueeze(dim=1)

    # 1*K*M
    B = data2.unsqueeze(dim=0)

    # lower bound on the norms so an all-zero row does not divide by zero
    # and fill its distances with NaN
    eps = 1e-8

    # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5]
    A_normalized = A / A.norm(dim=-1, keepdim=True).clamp_min(eps)
    B_normalized = B / B.norm(dim=-1, keepdim=True).clamp_min(eps)

    # broadcast to N*K*M; summing over the feature axis gives the dot product
    cosine = A_normalized * B_normalized

    # No squeeze here: it would drop the row or column axis whenever N or K
    # is 1 and break callers that take argmin along a fixed dimension.
    cosine_dis = 1 - cosine.sum(dim=-1)
    return cosine_dis
173+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
4+
__version__ = "0.3"
5+
6+
7+
def main():
    """Entry point used when this module is run as a script; prints a placeholder."""
    placeholder = "TODO"
    print(placeholder)


if __name__ == "__main__":
    main()

kmeans_pytorch/__init__.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import torch
33
from tqdm import tqdm
44

5+
# ToDo: Can't choose a cluster if two points are too close to each other; that's where the NaNs come from
6+
57

68
def initialize(X, num_clusters):
79
"""
@@ -20,6 +22,7 @@ def kmeans(
2022
X,
2123
num_clusters,
2224
distance='euclidean',
25+
cluster_centers = [],
2326
tol=1e-4,
2427
device=torch.device('cpu')
2528
):
@@ -48,11 +51,21 @@ def kmeans(
4851
X = X.to(device)
4952

5053
# initialize
51-
initial_state = initialize(X, num_clusters)
52-
54+
if type(cluster_centers) == list: #ToDo: make this less annoyingly weird
55+
initial_state = initialize(X, num_clusters)
56+
else:
57+
print('resuming')
58+
# find data point closest to the initial cluster center
59+
initial_state = cluster_centers
60+
dis = pairwise_distance_function(X, initial_state)
61+
choice_points = torch.argmin(dis, dim=0)
62+
initial_state = X[choice_points]
63+
initial_state = initial_state.to(device)
64+
5365
iteration = 0
5466
tqdm_meter = tqdm(desc='[running kmeans]')
5567
while True:
68+
5669
dis = pairwise_distance_function(X, initial_state)
5770

5871
choice_cluster = torch.argmin(dis, dim=1)
@@ -63,6 +76,7 @@ def kmeans(
6376
selected = torch.nonzero(choice_cluster == index).squeeze().to(device)
6477

6578
selected = torch.index_select(X, 0, selected)
79+
6680
initial_state[index] = selected.mean(dim=0)
6781

6882
center_shift = torch.sum(

0 commit comments

Comments
 (0)