Skip to content

Commit 9182537

Browse files
committed
reformat: typehint and black
1 parent c996fb8 commit 9182537

24 files changed

+386
-284
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ __pycache__/
1111
# logs
1212
wandb/
1313

14-
14+
rotated_ship_data.py
1515

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66
author="Sri Datta Budaraju",
77
author_email="b.sridatta@gmail.com",
88
packages=setuptools.find_packages(),
9-
python_requires='>=3.6'
10-
)
9+
python_requires=">=3.6",
10+
)

src/callbacks/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from src.callbacks.base import CallbackList, Callback
2-
from src.callbacks.model_checkpoint import ModelCheckpoint
3-
from src.callbacks.logging import Logging
1+
from .base import CallbackList, Callback
2+
from .model_checkpoint import ModelCheckpoint
3+
from .logging import Logging
44

55

66
__all__ = [
7-
'CallbackList',
8-
'ModelCheckpoint',
9-
'Logging',
10-
]
7+
"CallbackList",
8+
"ModelCheckpoint",
9+
"Logging",
10+
]

src/callbacks/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
'''
1+
"""
2+
23
Callback inspritations from PyTorch Lightning - https://github.com/PyTorchLightning/PyTorch-Lightning
34
and https://github.com/devforfu/pytorch_playground/blob/master/loop.ipynb
4-
'''
5+
"""
56

67
import abc
78

@@ -73,7 +74,6 @@ def on_test_end(self, **kwargs):
7374

7475

7576
class CallbackList(Callback):
76-
7777
def __init__(self, callbacks):
7878
self.callbacks = callbacks
7979

src/callbacks/logging.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,66 @@
1-
from src.callbacks.base import Callback
1+
from .base import Callback
22
import torch
33

44

55
class Logging(Callback):
66
"""Logging and printing metrics"""
77

88
def setup(self, opt, model, **kwargs):
9-
print(
10-
f'[INFO]: Start training procedure using device: {opt.device}')
9+
print(f"[INFO]: Start training procedure using device: {opt.device}")
1110
# log gradients and parameters of the model during training
1211
if opt.use_wandb:
13-
opt.logger.watch(model, log='all')
12+
opt.logger.watch(model, log="all")
1413

15-
def on_train_batch_end(self, opt, batch_idx, batch, dataloader, output, l_ship, l_bbox, **kwargs):
16-
batch_len = len(batch['input'])
14+
def on_train_batch_end(
15+
self, opt, batch_idx, batch, dataloader, output, l_ship, l_bbox, **kwargs
16+
):
17+
batch_len = len(batch["input"])
1718
dataset_len = len(dataloader.dataset)
1819
n_batches = len(dataloader)
1920

2021
# print to console
21-
print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}\tL_ship: {:.4f}\tL_bbox: {:.4f}'.format(
22-
opt.epoch, batch_idx * batch_len,
23-
dataset_len, 100. * batch_idx / n_batches,
24-
output, l_ship, l_bbox),
25-
end='\n')
22+
print(
23+
"Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}\tL_ship: {:.4f}\tL_bbox: {:.4f}".format(
24+
opt.epoch,
25+
batch_idx * batch_len,
26+
dataset_len,
27+
100.0 * batch_idx / n_batches,
28+
output,
29+
l_ship,
30+
l_bbox,
31+
),
32+
end="\n",
33+
)
2634

2735
# log to wandb
2836
if opt.use_wandb:
29-
opt.logger.log({"train_loss": output,
30-
"l_ship": l_ship,
31-
"l_bbox": l_bbox})
37+
opt.logger.log({"train_loss": output, "l_ship": l_ship, "l_bbox": l_bbox})
3238

3339
def on_validation_end(self, opt, output, metrics, l_ship, l_bbox, **kwargs):
3440
# print and log metrics and loss after validation epoch
35-
print("Valiation - Loss: {:.4f}\tL_ship: {:.4f}\tL_bbox: {:.4f}".format(
36-
output, l_ship, l_bbox), end="\t")
41+
print(
42+
"Valiation - Loss: {:.4f}\tL_ship: {:.4f}\tL_bbox: {:.4f}".format(
43+
output, l_ship, l_bbox
44+
),
45+
end="\t",
46+
)
3747

3848
for k in metrics.keys():
3949
print(f"{k}: {metrics[k]}", end="\t")
4050
if opt.use_wandb:
4151
opt.logger.log(metrics, commit=False)
42-
opt.logger.log({"val_loss": output,
43-
"epoch": opt.epoch,
44-
"val_l_ship": l_ship,
45-
"val_l_bbox": l_bbox})
52+
opt.logger.log(
53+
{
54+
"val_loss": output,
55+
"epoch": opt.epoch,
56+
"val_l_ship": l_ship,
57+
"val_l_bbox": l_bbox,
58+
}
59+
)
4660
print("")
4761

4862
def on_epoch_end(self, opt, optimizer, **kwargs):
49-
lr = optimizer.param_groups[0]['lr']
63+
lr = optimizer.param_groups[0]["lr"]
5064
if opt.use_wandb:
5165
opt.logger.log({f"LR": lr})
5266
print("lr @ ", lr)

src/callbacks/model_checkpoint.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from src.callbacks.base import Callback
2-
import torch
31
import os
42

3+
import torch
4+
5+
from callbacks.base import Callback
6+
57

68
class ModelCheckpoint(Callback):
79
def __init__(self):
@@ -10,25 +12,27 @@ def __init__(self):
1012
def setup(self, opt, model, optimizer, **kwargs):
1113
# Save model code to wandb
1214
if opt.use_wandb:
13-
opt.logger.save(
14-
f"{os.path.dirname(os.path.abspath(__file__))}/models/*")
15+
opt.logger.save(f"{os.path.dirname(os.path.abspath(__file__))}/models/*")
1516

1617
# Resume training
1718
if opt.resume_run not in "None":
1819
state = torch.load(
19-
f'{opt.save_dir}/{opt.resume_run}.pt', map_location=opt.device)
20+
f"{opt.save_dir}/{opt.resume_run}.pt", map_location=opt.device
21+
)
2022
print(
21-
f'[INFO] Loaded Checkpoint {opt.resume_run}: @ epoch {state["epoch"]}')
22-
model.load_state_dict(state['model_state_dict'])
23+
f'[INFO] Loaded Checkpoint {opt.resume_run}: @ epoch {state["epoch"]}'
24+
)
25+
model.load_state_dict(state["model_state_dict"])
2326

24-
# Optimizers
27+
# Optimizers
2528
optimizer_state_dic = torch.load(
26-
f'{opt.save_dir}/{opt.resume_run}_optimizer.pt', map_location=opt.device)
29+
f"{opt.save_dir}/{opt.resume_run}_optimizer.pt", map_location=opt.device
30+
)
2731
optimizer.load_state_dict(optimizer_state_dic)
2832

2933
def on_epoch_end(self, opt, val_loss, model, optimizer, epoch, **kwargs):
3034
# track val loss and save model when it decreases
31-
if val_loss < self.val_loss_min and opt.device != 'cpu':
35+
if val_loss < self.val_loss_min and opt.device != "cpu":
3236
self.val_loss_min = val_loss
3337

3438
try:
@@ -37,28 +41,23 @@ def on_epoch_end(self, opt, val_loss, model, optimizer, epoch, **kwargs):
3741
state_dict = model.state_dict()
3842

3943
state = {
40-
'epoch': epoch,
41-
'val_loss': val_loss,
42-
'model_state_dict': state_dict,
44+
"epoch": epoch,
45+
"val_loss": val_loss,
46+
"model_state_dict": state_dict,
4347
}
4448

4549
# model
46-
torch.save(
47-
state, f'{opt.save_dir}/{opt.run_name}.pt')
50+
torch.save(state, f"{opt.save_dir}/{opt.run_name}.pt")
4851
if opt.use_wandb:
49-
opt.logger.save(
50-
f'{opt.save_dir}/{opt.run_name}.pt')
51-
print(
52-
f'[INFO] Saved pt: {opt.save_dir}/{opt.run_name}.pt')
52+
opt.logger.save(f"{opt.save_dir}/{opt.run_name}.pt")
53+
print(f"[INFO] Saved pt: {opt.save_dir}/{opt.run_name}.pt")
5354

5455
del state
5556

5657
# Optimizer
5758
torch.save(
58-
optimizer.state_dict(),
59-
f'{opt.save_dir}/{opt.run_name}_optimizer.pt')
59+
optimizer.state_dict(), f"{opt.save_dir}/{opt.run_name}_optimizer.pt"
60+
)
6061
if opt.use_wandb:
61-
opt.logger.save(
62-
f'{opt.save_dir}/{opt.run_name}_optimizer.pt')
63-
print(
64-
f'[INFO] Saved pt: {opt.save_dir}/{opt.run_name}_optimizer.pt')
62+
opt.logger.save(f"{opt.save_dir}/{opt.run_name}_optimizer.pt")
63+
print(f"[INFO] Saved pt: {opt.save_dir}/{opt.run_name}_optimizer.pt")

src/checkpoints/runX.pt

0 Bytes
Binary file not shown.

src/checkpoints/runX_optimizer.pt

0 Bytes
Binary file not shown.

src/dataloader.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,56 @@
1+
from argparse import Namespace
2+
13
from torch.utils.data import DataLoader
4+
25
from src.dataset import Ships
36

47

5-
def train_dataloader(opt):
8+
def train_dataloader(opt: Namespace) -> DataLoader:
69
print("[INFO]: Train dataloader called")
710
dataset = Ships(n_samples=opt.train_len)
811
sampler = None
912
shuffle = True
10-
loader = DataLoader(dataset=dataset,
11-
batch_size=opt.batch_size,
12-
num_workers=opt.num_workers,
13-
pin_memory=opt.pin_memory,
14-
sampler=sampler,
15-
shuffle=shuffle)
13+
loader = DataLoader(
14+
dataset=dataset,
15+
batch_size=opt.batch_size,
16+
num_workers=opt.num_workers,
17+
pin_memory=opt.pin_memory,
18+
sampler=sampler,
19+
shuffle=shuffle,
20+
)
1621
print("samples - ", len(dataset))
1722
return loader
1823

1924

20-
def val_dataloader(opt):
25+
def val_dataloader(opt: Namespace) -> DataLoader:
2126
print("[INFO]: Validation dataloader called")
2227
dataset = Ships(n_samples=opt.val_len)
2328
sampler = None
2429
shuffle = True
25-
loader = DataLoader(dataset=dataset,
26-
batch_size=opt.batch_size,
27-
num_workers=opt.num_workers,
28-
pin_memory=opt.pin_memory,
29-
sampler=sampler,
30-
shuffle=shuffle)
30+
loader = DataLoader(
31+
dataset=dataset,
32+
batch_size=opt.batch_size,
33+
num_workers=opt.num_workers,
34+
pin_memory=opt.pin_memory,
35+
sampler=sampler,
36+
shuffle=shuffle,
37+
)
3138
print("samples - ", len(dataset))
3239
return loader
3340

3441

35-
def test_dataloader(opt):
42+
def test_dataloader(opt: Namespace) -> DataLoader:
3643
print("[INFO]: Test dataloader called")
3744
dataset = Ships(n_samples=opt.test_len)
3845
sampler = None
3946
shuffle = True
40-
loader = DataLoader(dataset=dataset,
41-
batch_size=opt.batch_size,
42-
num_workers=opt.num_workers,
43-
pin_memory=opt.pin_memory,
44-
sampler=sampler,
45-
shuffle=shuffle)
47+
loader = DataLoader(
48+
dataset=dataset,
49+
batch_size=opt.batch_size,
50+
num_workers=opt.num_workers,
51+
pin_memory=opt.pin_memory,
52+
sampler=sampler,
53+
shuffle=shuffle,
54+
)
4655
print("samples - ", len(dataset))
4756
return loader

src/dataset.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
from torch.utils.data import Dataset
2-
from src.rotated_ship_data import make_data
3-
import torch
1+
from typing import Dict
2+
43
import numpy as np
4+
import torch
5+
from torch.functional import Tensor
6+
from torch.utils.data import Dataset
57
from tqdm import tqdm
68

9+
from src.rotated_ship_data import make_data
10+
711

812
class Ships(Dataset):
913
"""ship datasets with has ship labels
@@ -15,7 +19,7 @@ class Ships(Dataset):
1519
sample {Tenosr} -- p_ship, x, y, yaw, h, w
1620
"""
1721

18-
def __init__(self, n_samples=1000, pre_load=False):
22+
def __init__(self, n_samples: int = 1000, pre_load: bool = False):
1923
self.n_samples = n_samples
2024
self.pre_load = pre_load
2125
if pre_load:
@@ -32,7 +36,7 @@ def __init__(self, n_samples=1000, pre_load=False):
3236
def __len__(self):
3337
return self.n_samples
3438

35-
def __getitem__(self, idx):
39+
def __getitem__(self, idx: int) -> Dict[str, Tensor]:
3640
if self.pre_load:
3741
inp = self.inps[idx]
3842
target = self.targets[idx]
@@ -52,10 +56,11 @@ def __getitem__(self, idx):
5256

5357
return sample
5458

59+
5560
# Used for simple experiment
5661

5762

58-
def make_batch(batch_size):
63+
def make_batch(batch_size: int):
5964
"""Used only when pre_load = True
6065
6166
Arguments:

0 commit comments

Comments
 (0)