Skip to content

Commit 679ee1a

Browse files
committed
add commane-line option --out-file
1 parent f4682e8 commit 679ee1a

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

examples/MNIST/create_mnist_netcdf.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,9 @@ def load_data(self):
4848
return (x_train, y_train),(x_test, y_test)
4949

5050

51-
def to_nc(train_samples, train_labels, test_samples, test_labels, out_file_path='mnist_images.nc'):
52-
if os.path.exists(out_file_path):
53-
os.remove(out_file_path)
51+
def to_nc(train_samples, train_labels, test_samples, test_labels, out_file):
5452

55-
train_labels = list(train_labels)
56-
test_labels = list(test_labels)
57-
58-
with pnetcdf.File(out_file_path, mode = "w", format = "NC_64BIT_DATA") as fnc:
53+
with pnetcdf.File(out_file, mode = "w", format = "NC_64BIT_DATA") as fnc:
5954

6055
# add MNIST dataset URL as a global attribute
6156
fnc.url = "https://yann.lecun.com/exdb/mnist/"
@@ -116,6 +111,8 @@ def to_nc(train_samples, train_labels, test_samples, test_labels, out_file_path=
116111
default = "t10k-images-idx3-ubyte")
117112
parser.add_argument("--test-label-file", nargs=1, type=str, help="(Optional) input file name of testing labels",\
118113
default = "t10k-labels-idx1-ubyte")
114+
parser.add_argument("--out-file", nargs=1, type=str, help="(Optional) output NetCDF file name",\
115+
default = "mnist_images.nc")
119116
args = parser.parse_args()
120117

121118
verbose = True if args.verbose else False
@@ -152,6 +149,6 @@ def to_nc(train_samples, train_labels, test_samples, test_labels, out_file_path=
152149
# create mini MNIST file in NetCDF format
153150
#
154151
to_nc(train_data[0:n_train], train_label[0:n_train],
155-
test_data[0:n_test], test_label[0:n_test], "mnist_images.nc")
152+
test_data[0:n_test], test_label[0:n_test], args.out_file)
156153

157154

0 commit comments

Comments
 (0)