
Commit 2dd6707

Support creating shards for Text files (#2390)
* Create shards for a CSV file
* Reader to partition CSV files
* Create CSV reader
* Rename csv to text
* Polish the elasticdl job service
* Move the thread that checks for timed-out tasks into the task manager
* Delete unused imports
* Fix conflicts
* Pre-commit
* Set up flake8
* Address review comments
* Delete the method to read records
* Implement read_records
* Fix shards to be a list
1 parent b9e443e commit 2dd6707

File tree

6 files changed: +100 −108

.flake8 (+4)
@@ -0,0 +1,4 @@
+[flake8]
+ignore = E203, E266, W503
+max-line-length = 79
+
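For context: E203 (whitespace before ":") and W503 (line break before a binary operator) are the two checks conventionally ignored when flake8 runs alongside Black-style formatting, and E266 relaxes the "##" block-comment rule; max-line-length stays at the PEP 8 default of 79.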

elasticdl/python/data/reader/csv_reader.py (−75)

This file was deleted.

elasticdl/python/data/reader/data_reader_factory.py (+9 −3)
@@ -15,9 +15,9 @@
 
 from elasticdl.python.common.constants import MaxComputeConfig, ReaderType
 from elasticdl.python.data.odps_io import is_odps_configured
-from elasticdl.python.data.reader.csv_reader import CSVDataReader
 from elasticdl.python.data.reader.odps_reader import ODPSDataReader
 from elasticdl.python.data.reader.recordio_reader import RecordIODataReader
+from elasticdl.python.data.reader.text_reader import TextDataReader
 
 
 def create_data_reader(data_origin, records_per_task=None, **kwargs):
@@ -45,11 +45,17 @@ def create_data_reader(data_origin, records_per_task=None, **kwargs):
                 **kwargs,
             )
         elif data_origin and data_origin.endswith(".csv"):
-            return CSVDataReader(data_dir=data_origin, **kwargs)
+            return TextDataReader(
+                filename=data_origin,
+                records_per_task=records_per_task,
+                **kwargs,
+            )
         else:
             return RecordIODataReader(data_dir=data_origin)
     elif reader_type == ReaderType.CSV_READER:
-        return CSVDataReader(data_dir=data_origin, **kwargs)
+        return TextDataReader(
+            filename=data_origin, records_per_task=records_per_task, **kwargs
+        )
     elif reader_type == ReaderType.ODPS_READER:
         if not is_odps_configured():
             raise ValueError(
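
As a usage sketch (not part of the commit): with a ".csv" data origin, no forced reader type, and ODPS not configured, the factory now hands back a TextDataReader sized by records_per_task. The file path and shard size below are hypothetical.

from elasticdl.python.data.reader.data_reader_factory import create_data_reader

# Hypothetical path; any filename ending in ".csv" takes the
# TextDataReader branch in the factory.
reader = create_data_reader(data_origin="/tmp/iris.csv", records_per_task=20)
shards = reader.create_shards()  # [(filename, start_index, record_count), ...]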

elasticdl/python/data/reader/text_reader.py (+72)
@@ -0,0 +1,72 @@
+# Copyright 2020 The ElasticDL Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import csv
+import linecache
+
+import tensorflow as tf
+
+from elasticdl.python.data.reader.data_reader import (
+    AbstractDataReader,
+    Metadata,
+)
+
+
+class TextDataReader(AbstractDataReader):
+    """This reader is used to create shards for a text file and
+    read records from a shard.
+    """
+
+    def __init__(self, filename, records_per_task, **kwargs):
+        """
+        Args:
+            kwargs should contain "filename" and "records_per_task".
+        """
+        AbstractDataReader.__init__(self, **kwargs)
+        self._kwargs = kwargs
+        self._filename = filename
+        self._records_per_task = records_per_task
+
+    def read_records(self, task):
+        records = linecache.getlines(task.shard.name)[
+            task.shard.start : task.shard.end
+        ]
+        return records
+
+    def create_shards(self):
+        size = self.get_size()
+        shards = []
+        num_shards = size // self._records_per_task
+        start_ind = 0
+        for shard_id in range(num_shards):
+            shards.append((self._filename, start_ind, self._records_per_task))
+            start_ind += self._records_per_task
+        # Create a shard with the remaining records
+        num_records_left = size % self._records_per_task
+        if num_records_left != 0:
+            shards.append((self._filename, start_ind, num_records_left))
+        return shards
+
+    def get_size(self):
+        with open(self._filename) as file:
+            reader = csv.reader(file)
+            line_num = len(list(reader))
+            return line_num
+
+    @property
+    def records_output_types(self):
+        return tf.string
+
+    @property
+    def metadata(self):
+        return Metadata(column_names=None)
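
A minimal sketch of the shard/record flow (not part of the commit). Shard and Task below are hypothetical stand-ins that model only the task.shard.{name, start, end} fields that read_records touches; data.csv is a made-up file.

import collections

from elasticdl.python.data.reader.text_reader import TextDataReader

# Hypothetical stand-ins for ElasticDL's internal task objects.
Shard = collections.namedtuple("Shard", ["name", "start", "end"])
Task = collections.namedtuple("Task", ["shard"])

reader = TextDataReader(filename="data.csv", records_per_task=20)
shards = reader.create_shards()  # [(filename, start_index, record_count), ...]

# Read the lines covered by the first shard.
name, start, count = shards[0]
records = reader.read_records(Task(Shard(name, start, start + count)))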

elasticdl/python/tests/data_reader_test.py (+12 −28)
@@ -26,11 +26,11 @@
 from elasticdl.python.common.constants import MaxComputeConfig
 from elasticdl.python.common.model_utils import load_module
 from elasticdl.python.data.odps_io import is_odps_configured
-from elasticdl.python.data.reader.csv_reader import CSVDataReader
 from elasticdl.python.data.reader.data_reader import Metadata
 from elasticdl.python.data.reader.data_reader_factory import create_data_reader
 from elasticdl.python.data.reader.odps_reader import ODPSDataReader
 from elasticdl.python.data.reader.recordio_reader import RecordIODataReader
+from elasticdl.python.data.reader.text_reader import TextDataReader
 from elasticdl.python.master.task_manager import _Task
 from elasticdl.python.tests.test_utils import (
     IRIS_TABLE_COLUMN_NAMES,
@@ -73,7 +73,7 @@ def test_recordio_data_reader(self):
             self.assertEqual(len(v.numpy()), 1)
 
 
-class CSVDataReaderTest(unittest.TestCase):
+class TextDataReaderTest(unittest.TestCase):
     def test_csv_data_reader(self):
         with tempfile.TemporaryDirectory() as temp_dir_name:
             num_records = 128
@@ -87,33 +87,17 @@ def test_csv_data_reader(self):
             iris_file_name = create_iris_csv_file(
                 size=num_records, columns=columns, temp_dir=temp_dir_name
             )
-            csv_data_reader = CSVDataReader(columns=columns, sep=",")
-            task = _Task(
-                iris_file_name, 0, num_records, elasticdl_pb2.TRAINING
+            csv_data_reader = TextDataReader(
+                filename=iris_file_name, records_per_task=20
             )
-
-            def _gen():
-                for record in csv_data_reader.read_records(task):
-                    yield record
-
-            def _feed(dataset, mode, metadata):
-                def _parse_data(record):
-                    features = tf.strings.to_number(record[0:-1], tf.float32)
-                    label = tf.strings.to_number(record[-1], tf.float32)
-                    return features, label
-
-                dataset = dataset.map(_parse_data)
-                dataset = dataset.batch(10)
-                return dataset
-
-            dataset = tf.data.Dataset.from_generator(
-                _gen, csv_data_reader.records_output_types
-            )
-            dataset = _feed(dataset, None, None)
-            for features, labels in dataset:
-                self.assertEqual(features.shape.as_list(), [10, 4])
-                self.assertEqual(labels.shape.as_list(), [10])
-                break
+            shards = csv_data_reader.create_shards()
+            self.assertEqual(len(shards), 7)
+            task = _Task(iris_file_name, 0, 20, elasticdl_pb2.TRAINING)
+            record_count = 0
+            for record in csv_data_reader.read_records(task):
+                record_count += 1
+            self.assertEqual(csv_data_reader.get_size(), num_records)
+            self.assertEqual(record_count, 20)
 
 
 @unittest.skipIf(
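
The assertions follow from the shard arithmetic: 128 records at 20 records per task give 128 // 20 = 6 full shards plus one shard for the 128 % 20 = 8 leftover lines, hence 7 shards, and the task covering records [0, 20) yields exactly 20 records.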

elasticdl/python/tests/test_utils.py (+3 −2)
@@ -270,7 +270,7 @@ def create_recordio_file(size, dataset_name, shape, temp_dir=None):
     return temp_file.name
 
 
-def create_iris_csv_file(size, columns, temp_dir=None):
+def create_iris_csv_file(size, columns, with_heads=False, temp_dir=None):
     """Creates a temporary CSV file.
 
     Args:
@@ -291,7 +291,8 @@ def create_iris_csv_file(size, columns, temp_dir=None):
     csv_file_name = temp_file.name + ".csv"
     with open(csv_file_name, "w", newline="") as csv_file:
         csv_writer = csv.writer(csv_file)
-        csv_writer.writerow(columns)
+        if with_heads:
+            csv_writer.writerow(columns)
         csv_writer.writerows(value_data)
 
     return csv_file_name
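
The with_heads flag defaults to False, presumably because TextDataReader.get_size() counts every line of the file, so a header row would be tallied as a record and skew the shard arithmetic in the test above.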
