
Commit e29c9d1

A pytorch example to read original images with the custom dataloader (#2415)
* A pytorch example to read original images with the custom dataloader
* Pre-commit
* Add a example to submit the training job
* Use print
* Merge develop
* Pre-commit
* fix the argument name
* Fix by comments
* PyTorch example
* PyTorch example
* pre-commit
* Add evaluation
* Calculate the number of batch per epoch
* Polish the docstring
* polish the docstring
* Set the epoch
* Pre-commit
* Add annotations
* Rename the example
1 parent d5687b9 commit e29c9d1

File tree

3 files changed: +245 -119 lines


.isort.cfg

+1 -1

@@ -1,5 +1,5 @@
 [settings]
 multi_line_output=3
 line_length=79
-known_third_party = PIL,deepctr,docker,google,grpc,horovod,jinja2,kubernetes,matplotlib,numpy,odps,pandas,recordio,requests,setuptools,six,sklearn,tensorflow,torch,yaml
+known_third_party = PIL,cv2,deepctr,docker,google,grpc,horovod,jinja2,kubernetes,matplotlib,numpy,odps,pandas,recordio,requests,setuptools,six,sklearn,tensorflow,torch,torchvision,yaml
 include_trailing_comma=True
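For context, known_third_party tells isort which modules to force into the third-party import section even when they are not importable in the linting environment; cv2 and torchvision are added here because the new example imports them. A minimal sketch of the grouping this setting produces for the example's imports (ordering per isort defaults; shown only for illustration):

# Standard library
import argparse
import sys

# Third party (cv2 and torchvision are placed here because they are
# listed in known_third_party)
import cv2
import torchvision

# First party (not listed in known_third_party, so isort treats it as
# part of the project)
from elasticai_api.pytorch.controller import create_elastic_controller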

model_zoo/mnist/mnist_pytorch.py

+244
@@ -0,0 +1,244 @@
# Copyright 2020 The ElasticDL Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Download the mnist dataset from
https://s3.amazonaws.com/fast-ai-imageclas/mnist_png.tgz
and then untar it into ${data_store_dir}. Using minikube, we can use the
following command to submit a training job with the script.

elasticdl train \
    --image_name=elasticdl:pt_mnist_allreduce \
    --job_command="python -m model_zoo.mnist.mnist_pytorch_custom_dataset \
        --training_data=/local_data/mnist_png/training \
        --validation_data=/local_data/mnist_png/testing" \
    --num_minibatches_per_task=2 \
    --num_workers=2 \
    --worker_pod_priority=0.5 \
    --master_resource_request="cpu=0.2,memory=1024Mi" \
    --master_resource_limit="cpu=1,memory=2048Mi" \
    --worker_resource_request="cpu=0.3,memory=1024Mi" \
    --worker_resource_limit="cpu=1,memory=2048Mi" \
    --envs="USE_TORCH=true,HOROVOD_GLOO_TIMEOUT_SECONDS=60,PYTHONUNBUFFERED=0" \
    --job_name=test-mnist-allreduce \
    --image_pull_policy=Never \
    --volume="host_path=${data_store_dir},mount_path=/local_data" \
    --custom_training_loop=true \
    --distribution_strategy=AllreduceStrategy \
"""

import argparse
import sys

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset

from elasticai_api.pytorch.controller import create_elastic_controller
from elasticai_api.pytorch.optimizer import DistributedOptimizer

class ElasticDataset(Dataset):
    def __init__(self, images, data_shard_service=None):
        """The dataset supports elastic training.

        Args:
            images: A list with tuples like (image_path, label_index).
                For example, we can use `torchvision.datasets.ImageFolder`
                to get the list.
            data_shard_service: If we want to use elastic training, we
                need to use the `data_shard_service` of the elastic
                controller in elasticai_api.
        """
        self.data_shard_service = data_shard_service
        self._images = images

    def __len__(self):
        if self.data_shard_service:
            # Set the maxsize because the size of dataset is not fixed
            # when using dynamic sharding
            return sys.maxsize
        else:
            return len(self._images)

    def __getitem__(self, index):
        if self.data_shard_service:
            index = self.data_shard_service.fetch_record_index()
            return self.read_image(index)
        else:
            return self.read_image(index)

    def read_image(self, index):
        image_path, label = self._images[index]
        image = cv2.imread(image_path)
        image = np.array(image / 255.0, np.float32)
        image = image.reshape(3, 28, 28)
        return image, label

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

def train(args):
    """The function to run the training loop.
    Args:
        dataset: The dataset is provided by ElasticDL for the elastic
            training. Now, the dataset is tf.data.Dataset and we need to
            convert the data in dataset to torch.tensor. Later, ElasticDL
            will pass a torch.utils.data.DataLoader.
        elastic_controller: The controller for elastic training.
    """
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    train_data = torchvision.datasets.ImageFolder(args.training_data)
    test_data = torchvision.datasets.ImageFolder(args.validation_data)
    batch_num_per_epoch = int(len(train_data.imgs) / args.batch_size)

    allreduce_controller = create_elastic_controller(
        batch_size=args.batch_size,
        dataset_size=len(train_data.imgs),
        num_epochs=args.num_epochs,
        shuffle=True,
    )
    train_dataset = ElasticDataset(
        train_data.imgs, allreduce_controller.data_shard_service
    )
    train_loader = DataLoader(
        dataset=train_dataset, batch_size=args.batch_size, num_workers=2
    )

    test_dataset = ElasticDataset(test_data.imgs)
    test_loader = DataLoader(
        dataset=test_dataset, batch_size=args.batch_size, num_workers=2
    )

    model = Net()
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate)
    optimizer = DistributedOptimizer(optimizer, fixed_global_batch_size=True)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

    # Set the model and optimizer to broadcast.
    allreduce_controller.set_broadcast_model(model)
    allreduce_controller.set_broadcast_optimizer(optimizer)
    epoch = 0
    # Use the elastic function to wrap the training function with a batch.
    elastic_train_one_batch = allreduce_controller.elastic_run(train_one_batch)
    with allreduce_controller.scope():
        for batch_idx, (data, target) in enumerate(train_loader):
            model.train()
            target = target.type(torch.LongTensor)
            data, target = data.to(device), target.to(device)
            loss = elastic_train_one_batch(model, optimizer, data, target)
            print("loss = {}, step = {}".format(loss, batch_idx))
            new_epoch = int(
                allreduce_controller.global_completed_batch_num
                / batch_num_per_epoch
            )
            if new_epoch > epoch:
                epoch = new_epoch
                # Set epoch of the scheduler
                scheduler.last_epoch = epoch - 1
                scheduler.step()
                test(model, device, test_loader)

def train_one_batch(model, optimizer, data, target):
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    return loss


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss,
            correct,
            len(test_loader.dataset),
            100.0 * correct / len(test_loader.dataset),
        )
    )

def arg_parser():
    parser = argparse.ArgumentParser(description="Process training parameters")
    parser.add_argument("--batch_size", type=int, default=64, required=False)
    parser.add_argument("--num_epochs", type=int, default=1, required=False)
    parser.add_argument(
        "--learning_rate", type=float, default=0.1, required=False
    )
    parser.add_argument(
        "--no-cuda",
        action="store_true",
        default=False,
        help="disable CUDA training",
    )
    parser.add_argument("--training_data", type=str, required=True)
    parser.add_argument(
        "--validation_data", type=str, default="", required=False
    )
    return parser


if __name__ == "__main__":
    parser = arg_parser()
    args = parser.parse_args()
    train(args)
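
For readers who want to exercise the new example outside of an ElasticDL cluster, the sketch below runs a random batch through Net and reads one fake sample through ElasticDataset without a data_shard_service. It is only a local sanity check and assumes the module's own dependencies (cv2, torchvision, elasticai_api) are installed so the import succeeds; the temporary PNG, the label 7, and the batch size of 4 are illustrative values, not part of the commit.

import os
import tempfile

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader

from model_zoo.mnist.mnist_pytorch import ElasticDataset, Net

# Shape check for the model: a random batch of four 28x28 RGB images
# should produce log-probabilities over the 10 digit classes.
model = Net()
out = model(torch.randn(4, 3, 28, 28))
print(out.shape)  # torch.Size([4, 10])

# Non-elastic path of ElasticDataset: without a data_shard_service it
# behaves like a plain map-style dataset over (image_path, label) tuples.
tmp_dir = tempfile.mkdtemp()
image_path = os.path.join(tmp_dir, "digit.png")
cv2.imwrite(image_path, np.zeros((28, 28, 3), dtype=np.uint8))
dataset = ElasticDataset([(image_path, 7)])
loader = DataLoader(dataset, batch_size=1)
image, label = next(iter(loader))
print(image.shape, label)  # torch.Size([1, 3, 28, 28]) tensor([7])

With a data_shard_service attached, as in train() above, the same dataset instead asks the elastic controller for the next record index on every __getitem__ call, which is why its __len__ returns sys.maxsize.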

model_zoo/mnist/mnist_train_torch.py

-118
This file was deleted.
