@@ -48,14 +48,9 @@ def load_data(self):
48
48
return (x_train , y_train ),(x_test , y_test )
49
49
50
50
51
- def to_nc (train_samples , train_labels , test_samples , test_labels , out_file_path = 'mnist_images.nc' ):
52
- if os .path .exists (out_file_path ):
53
- os .remove (out_file_path )
51
+ def to_nc (train_samples , train_labels , test_samples , test_labels , out_file ):
54
52
55
- train_labels = list (train_labels )
56
- test_labels = list (test_labels )
57
-
58
- with pnetcdf .File (out_file_path , mode = "w" , format = "NC_64BIT_DATA" ) as fnc :
53
+ with pnetcdf .File (out_file , mode = "w" , format = "NC_64BIT_DATA" ) as fnc :
59
54
60
55
# add MNIST dataset URL as a global attribute
61
56
fnc .url = "https://yann.lecun.com/exdb/mnist/"
@@ -116,6 +111,8 @@ def to_nc(train_samples, train_labels, test_samples, test_labels, out_file_path=
116
111
default = "t10k-images-idx3-ubyte" )
117
112
parser .add_argument ("--test-label-file" , nargs = 1 , type = str , help = "(Optional) input file name of testing labels" ,\
118
113
default = "t10k-labels-idx1-ubyte" )
114
+ parser .add_argument ("--out-file" , nargs = 1 , type = str , help = "(Optional) output NetCDF file name" ,\
115
+ default = "mnist_images.nc" )
119
116
args = parser .parse_args ()
120
117
121
118
verbose = True if args .verbose else False
@@ -152,6 +149,6 @@ def to_nc(train_samples, train_labels, test_samples, test_labels, out_file_path=
152
149
# create mini MNIST file in NetCDF format
153
150
#
154
151
to_nc (train_data [0 :n_train ], train_label [0 :n_train ],
155
- test_data [0 :n_test ], test_label [0 :n_test ], "mnist_images.nc" )
152
+ test_data [0 :n_test ], test_label [0 :n_test ], args . out_file )
156
153
157
154
0 commit comments