Skip to content

Commit 9d8c90e

Browse files
committed
Code release
1 parent 8325917 commit 9d8c90e

16 files changed

+2065
-0
lines changed

all_figures.sh

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
python figure_conditions.py &&
2+
python figure_comparison_table.py &&
3+
python figure_privacy_quantities.py &&
4+
python figure_projection_ablation.py &&
5+
python figure_numusers_ablation.py #&&
6+
#python nemenyi/nemenyi.py figures/comparison_private.csv figures/figure_nemenyi_private.tex --h &&
7+
#python nemenyi/nemenyi.py figures/comparison_non_private.csv figures/figure_nemenyi_non_private.tex --h &&
8+
#python nemenyi/nemenyi.py figures/comparison_weak_private.csv figures/figure_nemenyi_weak_private.tex --h

ensembler.py

+224
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
import torch
2+
import numpy as np
3+
4+
from utils import calculate_score
5+
6+
class Ensembler:
7+
8+
def __init__(self, projector, A_t, num_devices, channel_snr_db, participation_probability, client_output, task) -> None:
9+
self.projector = projector
10+
self.num_devices = num_devices
11+
self.participation_probability = participation_probability
12+
self.A_t = A_t
13+
self.channel_snr_db = channel_snr_db
14+
15+
self.channel_snr = 10 ** (0.1 * self.channel_snr_db)
16+
self.client_output = client_output
17+
18+
self.task = task
19+
20+
self.Pavg = 1.0
21+
22+
def forward(self, method, *args):
23+
24+
return getattr(self, f"forward_{method}")(*args)
25+
26+
def find_mu_fp(self, val_beliefs, weights):
27+
28+
res = []
29+
for device_idx in range(self.num_devices):
30+
r = self.client_model(val_beliefs[device_idx], weights[device_idx])
31+
r = self.projector.project_only(r)
32+
r = (r ** 2).sum(dim=1).mean(dim=0)
33+
res.append(r)
34+
35+
res = torch.stack(res, dim=0).mean()
36+
37+
return res
38+
39+
def get_gamma(self, num_participating_clients, mu_fp):
40+
41+
mu_h = 1
42+
var_client = self.projector.get_sigma_client(num_participating_clients, self.num_devices) ** 2
43+
44+
gamma = torch.sqrt(self.Pavg / (mu_h * (mu_fp + self.projector.num_dims * var_client)))
45+
46+
return gamma
47+
48+
def forward_oac(self, beliefs, val_beliefs, y_val_true):
49+
participating_devices = self.sample_participating_devices()
50+
weights = self.find_weights(val_beliefs, y_val_true)
51+
#mu_fp = self.find_mu_fp(val_beliefs, weights)
52+
num_participating_devices = len(participating_devices)
53+
54+
#gamma = self.get_gamma(num_participating_devices, mu_fp)
55+
56+
res = []
57+
for device_idx in participating_devices:
58+
r = self.client_model(beliefs[device_idx], weights[device_idx])
59+
60+
r = self.projector.forward(r, num_participating_devices=num_participating_devices, num_devices=self.num_devices)
61+
62+
r = self.A_t * r / num_participating_devices
63+
64+
r = r / num_participating_devices
65+
66+
res.append(r)
67+
68+
received_signal = self.air_sum(res)
69+
70+
y_test_pred = self.server_model(received_signal)
71+
72+
return y_test_pred
73+
74+
def forward_orthogonal(self, beliefs, val_beliefs, y_val_true):
75+
participating_devices = self.sample_participating_devices()
76+
weights = self.find_weights(val_beliefs, y_val_true)
77+
num_participating_devices = len(participating_devices)
78+
num_classes = beliefs[0].shape[1]
79+
80+
res = []
81+
for device_idx in participating_devices:
82+
r = self.client_model(beliefs[device_idx], weights[device_idx])
83+
84+
r = self.projector.forward(r, num_participating_devices=1, num_devices=1)
85+
86+
r = self.A_t * r
87+
88+
res.append(r)
89+
90+
num_dims = res[0].shape[1]
91+
received_signal = torch.cat(res, dim=1)
92+
93+
final_signal = torch.zeros_like(res[0])
94+
95+
for i in range(num_participating_devices):
96+
cur_signal = received_signal[:, i*num_dims:(i+1)*num_dims]
97+
98+
final_signal += self.add_channel_noise(cur_signal, self.channel_snr)
99+
100+
final_signal = final_signal / num_participating_devices
101+
102+
y_test_pred = self.server_model(final_signal)
103+
104+
return y_test_pred
105+
106+
def find_best_device(self, val_beliefs, y_val_true):
107+
108+
cur_best_valscore = -np.inf
109+
best_device_idx = None
110+
for device_idx in range(self.num_devices):
111+
y_val_pred = val_beliefs[device_idx].argmax(dim=1)
112+
valscore = calculate_score(y_val_true, y_val_pred)
113+
114+
if valscore > cur_best_valscore:
115+
cur_best_valscore = valscore
116+
best_device_idx = device_idx
117+
118+
return best_device_idx
119+
120+
def find_weights(self, val_beliefs, y_val_true):
121+
122+
correct_preds = torch.empty(self.num_devices, val_beliefs[0].shape[1], dtype=torch.int)
123+
num_data = y_val_true.shape[0]
124+
y_val_true = torch.nn.functional.one_hot(y_val_true, val_beliefs[0].shape[1])
125+
126+
for device_idx in range(self.num_devices):
127+
y_val_pred = torch.nn.functional.one_hot(val_beliefs[device_idx].argmax(dim=1), val_beliefs[device_idx].shape[1])
128+
true_indices = (y_val_true == y_val_pred)
129+
130+
correct_preds[device_idx, :] = true_indices.sum(dim=0)
131+
132+
weights = correct_preds / num_data
133+
134+
135+
return weights
136+
137+
def forward_bestmodel(self, beliefs, val_beliefs, y_val_true):
138+
139+
device_idx = self.find_best_device(val_beliefs, y_val_true)
140+
141+
r = self.client_model(beliefs[device_idx])
142+
143+
r = self.projector.forward(r, num_participating_devices=1, num_devices=1)
144+
145+
r = self.A_t * r
146+
147+
r = r
148+
149+
received_signal = self.add_channel_noise(r, self.channel_snr) # air_sum(client_beliefs, channel_snr)
150+
151+
y_test_pred = self.server_model(received_signal)
152+
153+
return y_test_pred
154+
155+
def sample_participating_devices(self):
156+
157+
participating_devices = []
158+
for device_idx in range(self.num_devices):
159+
rnd = np.random.uniform(0, 1)
160+
if rnd < self.participation_probability:
161+
participating_devices.append(device_idx)
162+
163+
if len(participating_devices) == 0:
164+
participating_devices.append(np.random.choice(list(range(self.num_devices))))
165+
166+
return participating_devices
167+
168+
def server_model(self, signal):
169+
signal = signal / self.A_t
170+
171+
signal = self.projector.invert(signal)
172+
173+
if self.task == "multiclass":
174+
signal = torch.nn.functional.one_hot(signal.argmax(dim=1), signal.shape[1]) #(signal > 0.5).int()
175+
elif self.task == "multilabel":
176+
signal = (signal > 0.5).int()
177+
else:
178+
raise NotImplementedError
179+
180+
return signal
181+
182+
def client_model(self, beliefs, client_weights=None):
183+
num_classes = beliefs.shape[1]
184+
185+
if self.client_output == "label":
186+
beliefs = torch.nn.functional.one_hot(beliefs.argmax(dim=1), num_classes)
187+
elif self.client_output =="belief":
188+
beliefs = torch.nn.functional.softmax(beliefs, dim=1)
189+
elif self.client_output == "weighted_belief":
190+
beliefs = client_weights * torch.nn.functional.softmax(beliefs, dim=1)
191+
beliefs = beliefs / beliefs.sum(dim=1, keepdim=True)
192+
else:
193+
raise NotImplementedError
194+
195+
return beliefs.float()
196+
197+
def air_sum(self, signals):
198+
199+
max_sigma_channel = -1
200+
for signal in signals:
201+
sigma = self.calculate_sigma_channel(signal, self.channel_snr)
202+
max_sigma_channel = max(max_sigma_channel, sigma)
203+
204+
signal = torch.sum(torch.stack(signals, dim=0), dim=0)
205+
206+
signal = self.add_channel_noise_with_std(signal, max_sigma_channel)
207+
208+
return signal
209+
210+
def calculate_sigma_channel(self, signal, channel_snr):
211+
212+
return torch.sqrt( torch.mean((signal ** 2)) / channel_snr )
213+
214+
def add_channel_noise_with_std(self, signal, std):
215+
res = signal + torch.normal(0, std, size=signal.shape)
216+
217+
return res
218+
219+
def add_channel_noise(self, signal, channel_snr):
220+
sigma_channel = torch.sqrt( torch.mean((signal ** 2)) / channel_snr )
221+
222+
res = signal + torch.normal(0, sigma_channel, signal.shape)
223+
224+
return res

0 commit comments

Comments
 (0)