Skip to content

Commit 84cd318

Browse files
committed
update plot synthetic
1 parent 6656fa0 commit 84cd318

File tree

6 files changed

+53
-110
lines changed

6 files changed

+53
-110
lines changed

main_mnist.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

main_plot.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

plot_femnist.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/usr/bin/env python
2+
import h5py
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
import argparse
6+
import importlib
7+
import random
8+
import os
9+
from flearn.servers.serveravg import FedAvg
10+
from flearn.servers.serverfedl import FEDL
11+
from flearn.trainmodel.models import *
12+
from utils.plot_utils import *
13+
import torch
14+
torch.manual_seed(0)
15+

plot_mnist.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/usr/bin/env python
2+
import h5py
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
import argparse
6+
import importlib
7+
import random
8+
import os
9+
from flearn.servers.serveravg import FedAvg
10+
from flearn.servers.serverfedl import FEDL
11+
from flearn.trainmodel.models import *
12+
from utils.plot_utils import *
13+
import torch
14+
torch.manual_seed(0)
15+
16+
def main(dataset, algorithm, model, batch_size, learning_rate, hyper_learning_rate, L, num_glob_iters,
17+
local_epochs, optimizer, clients_per_round, times):

plot_synthetic.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/usr/bin/env python
2+
import h5py
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
from utils.plot_utils import *
6+
import torch
7+
torch.manual_seed(0)
8+
9+
algorithms_list = ["FEDL","FEDL","FEDL","FEDL","FEDL","FEDL","FEDL","FEDL","FEDL","FEDL","FEDL","FEDL"]
10+
rho = [1.4, 1.4, 1.4, 1.4, 2 ,2 , 2, 2, 5, 5, 5, 5]
11+
lamb_value = [0, 0, 0, 0, 0, 0, 0, 0 , 0, 0, 0 ,0]
12+
learning_rate = [0.04,0.04,0.04,0.04, 0.04,0.04,0.04,0.04, 0.04,0.04,0.04,0.04]
13+
hyper_learning_rate = [0.01,0.03,0.05,0.07, 0.01,0.03,0.05,0.07, 0.01,0.03,0.05,0.07]
14+
local_ep = [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
15+
batch_size = [0,0,0,0 ,0,0,0,0, 0,0,0,0]
16+
DATA_SET = "Linear_synthetic"
17+
number_users = 100
18+
19+
plot_summary_linear(num_users=number_users, loc_ep1=local_ep, Numb_Glob_Iters=200, lamb=lamb_value, learning_rate=learning_rate, hyper_learning_rate = hyper_learning_rate, algorithms_list=algorithms_list, batch_size=batch_size, rho = rho, dataset=DATA_SET)

utils/plot_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ def get_training_data_value(num_users=100, loc_ep1=5, Numb_Glob_Iters=10, lamb=[
2828
string_learning_rate = string_learning_rate + "_" +str(hyper_learning_rate[i])
2929
algorithms_list[i] = algorithms_list[i] + \
3030
"_" + string_learning_rate + "_" + str(num_users) + \
31-
"u" + "_" + str(batch_size[i]) + "b"
31+
"u" + "_" + str(batch_size[i]) + "b" + "_" + str(loc_ep1[i])
3232
if(rho[i] > 0):
3333
algorithms_list[i] += "_" + str(rho[i])+"p"
3434

3535
train_acc[i, :], train_loss[i, :], glob_acc[i, :] = np.array(
36-
simple_read_data(str(loc_ep1[i]) + "_avg", dataset + "_" + algorithms_list[i]))[:, :Numb_Glob_Iters]
36+
simple_read_data("avg", dataset + "_" + algorithms_list[i]))[:, :Numb_Glob_Iters]
3737
algs_lbl[i] = algs_lbl[i]
3838
return glob_acc, train_acc, train_loss
3939

0 commit comments

Comments
 (0)