diff --git a/doc/source/api.rst b/doc/source/api.rst index 05d2e8091..153ac2d79 100644 --- a/doc/source/api.rst +++ b/doc/source/api.rst @@ -77,6 +77,14 @@ Testing Axis.iscompatible Axis.equals +Save +---- + +.. autosummary:: + :toctree: _generated/ + + Axis.to_hdf + .. _api-group: Group @@ -104,6 +112,7 @@ IGroup IGroup.startingwith IGroup.endingwith IGroup.matching + IGroup.to_hdf LGroup ------ @@ -127,6 +136,7 @@ LGroup LGroup.startingwith LGroup.endingwith LGroup.matching + LGroup.to_hdf .. _api-set: diff --git a/doc/source/changes/version_0_29.rst.inc b/doc/source/changes/version_0_29.rst.inc index 16391fc38..d6da65550 100644 --- a/doc/source/changes/version_0_29.rst.inc +++ b/doc/source/changes/version_0_29.rst.inc @@ -23,7 +23,25 @@ New features Miscellaneous improvements -------------------------- -* improved something. +* saving or loading a session from a file now includes `Axis` and `Group` objects in addition to arrays + (closes :issue:`578`): + + Create a session containing axes, groups and arrays + + >>> a, b = Axis("a=a0..a2"), Axis("b=b0..b2") + >>> a01 = a['a0,a1'] >> 'a01' + >>> arr1, arr2 = ndtest((a, b)), ndtest(a) + >>> s = Session([('a', a), ('b', b), ('a01', a01), ('arr1', arr1), ('arr2', arr2)]) + + Saving a session will save axes, groups and arrays + + >>> s.save('session.h5') + + Loading a session will load axes, groups and arrays + + >>> s2 = Session('session.h5') + >>> s2 + Session(arr1, arr2, a, b, a01) Fixes diff --git a/larray/core/array.py b/larray/core/array.py index 70f58d7e9..a8024afd8 100644 --- a/larray/core/array.py +++ b/larray/core/array.py @@ -64,11 +64,11 @@ from larray.core.abstractbases import ABCLArray from larray.core.expr import ExprNode from larray.core.group import (Group, IGroup, LGroup, remove_nested_groups, _to_key, _to_keys, - _range_to_slice, _translate_sheet_name, _translate_key_hdf) + _range_to_slice, _translate_sheet_name, _translate_group_key_hdf) from larray.core.axis import Axis, AxisReference, 
AxisCollection, X, _make_axis from larray.util.misc import (table2str, size2str, basestring, izip, rproduct, ReprString, duplicates, float_error_handler_factory, _isnoneslice, light_product, unique_list, common_type, - renamed_to, deprecate_kwarg) + renamed_to, deprecate_kwarg, LHDFStore) nan = np.nan @@ -5997,7 +5997,7 @@ def to_csv(self, filepath, sep=',', na_rep='', wide=True, value_name='value', dr series = self.to_series(value_name, dropna is not None) series.to_csv(filepath, sep=sep, na_rep=na_rep, header=True, **kwargs) - def to_hdf(self, filepath, key, *args, **kwargs): + def to_hdf(self, filepath, key): """ Writes array to a HDF file. @@ -6009,17 +6009,31 @@ def to_hdf(self, filepath, key, *args, **kwargs): filepath : str Path where the hdf file has to be written. key : str or Group - Name of the array within the HDF file. - *args - **kargs + Key (path) of the array within the HDF file (see Notes below). + + Notes + ----- + Objects stored in a HDF file can be grouped together in `HDF groups`. + If an object 'my_obj' is stored in a HDF group 'my_group', + the key associated with this object is then 'my_group/my_obj'. + Be aware that a HDF group can have subgroups. 
Examples -------- >>> a = ndtest((2, 3)) - >>> a.to_hdf('test.h5', 'a') # doctest: +SKIP + + Save an array + + >>> a.to_hdf('test.h5', 'a') # doctest: +SKIP + + Save an array in a specific HDF group + + >>> a.to_hdf('test.h5', 'arrays/a') # doctest: +SKIP """ - key = _translate_key_hdf(key) - self.to_frame().to_hdf(filepath, key, *args, **kwargs) + key = _translate_group_key_hdf(key) + with LHDFStore(filepath) as store: + store.put(key, self.to_frame()) + store.get_storer(key).attrs.type = 'Array' @deprecate_kwarg('sheet_name', 'sheet') def to_excel(self, filepath=None, sheet=None, position='A1', overwrite_file=False, clear_sheet=False, @@ -6085,7 +6099,7 @@ def to_excel(self, filepath=None, sheet=None, position='A1', overwrite_file=Fals engine = 'xlwings' if xw is not None else None if engine == 'xlwings': - from larray.inout.excel import open_excel + from larray.inout.xw_excel import open_excel close = False new_workbook = False @@ -7022,7 +7036,7 @@ def aslarray(a): elif hasattr(a, '__larray__'): return a.__larray__() elif isinstance(a, pd.DataFrame): - from larray.inout.array import from_frame + from larray.inout.pandas import from_frame return from_frame(a) else: return LArray(a) diff --git a/larray/core/axis.py b/larray/core/axis.py index 00587fd7f..709e26b8e 100644 --- a/larray/core/axis.py +++ b/larray/core/axis.py @@ -7,14 +7,15 @@ from itertools import product import numpy as np +import pandas as pd from larray.core.abstractbases import ABCAxis, ABCAxisReference, ABCLArray from larray.core.expr import ExprNode from larray.core.group import (Group, LGroup, IGroup, IGroupMaker, _to_tick, _to_ticks, _to_key, _seq_summary, - _contain_group_ticks, _seq_group_to_name) + _contain_group_ticks, _seq_group_to_name, _translate_group_key_hdf) from larray.util.oset import * from larray.util.misc import (basestring, PY2, unicode, long, duplicates, array_lookup2, ReprString, index_by_id, - renamed_to, common_type) + renamed_to, common_type, LHDFStore) __all__ = ['Axis', 
'AxisCollection', 'X', 'x'] @@ -1194,6 +1195,57 @@ def align(self, other, join='outer'): other = Axis(other) return other + def to_hdf(self, filepath, key=None): + """ + Writes axis to a HDF file. + + A HDF file can contain multiple axes. + The 'key' parameter is a unique identifier for the axis. + + Parameters + ---------- + filepath : str + Path where the hdf file has to be written. + key : str or Group, optional + Key (path) of the axis within the HDF file (see Notes below). + If None, the name of the axis is used. + Defaults to None. + + Notes + ----- + Objects stored in a HDF file can be grouped together in `HDF groups`. + If an object 'my_obj' is stored in a HDF group 'my_group', + the key associated with this object is then 'my_group/my_obj'. + Be aware that a HDF group can have subgroups. + + Examples + -------- + >>> a = Axis("a=a0..a2") + + Save axis + + >>> # by default, the key is the name of the axis + >>> a.to_hdf('test.h5') # doctest: +SKIP + + Save axis with a specific key + + >>> a.to_hdf('test.h5', 'a') # doctest: +SKIP + + Save axis in a specific HDF group + + >>> a.to_hdf('test.h5', 'axes/a') # doctest: +SKIP + """ + if key is None: + if self.name is None: + raise ValueError("Argument key must be provided explicitly in case of anonymous axis") + key = self.name + key = _translate_group_key_hdf(key) + s = pd.Series(data=self.labels, name=self.name) + with LHDFStore(filepath) as store: + store.put(key, s) + store.get_storer(key).attrs.type = 'Axis' + store.get_storer(key).attrs.wildcard = self.iswildcard + def _make_axis(obj): if isinstance(obj, Axis): diff --git a/larray/core/group.py b/larray/core/group.py index fb48aeaac..0f59d138b 100644 --- a/larray/core/group.py +++ b/larray/core/group.py @@ -11,7 +11,8 @@ from larray.core.abstractbases import ABCAxis, ABCAxisReference, ABCLArray from larray.util.oset import * -from larray.util.misc import basestring, PY2, unique, find_closing_chr, _parse_bound, _seq_summary, renamed_to +from 
larray.util.misc import (basestring, PY2, unique, find_closing_chr, _parse_bound, _seq_summary, + renamed_to, LHDFStore) __all__ = ['Group', 'LGroup', 'LSet', 'IGroup', 'union'] @@ -652,7 +653,7 @@ def _translate_sheet_name(sheet_name): _key_hdf_pattern = re.compile('[\\\/]') -def _translate_key_hdf(key): +def _translate_group_key_hdf(key): if isinstance(key, Group): key = _key_hdf_pattern.sub('_', str(_to_tick(key))) return key @@ -1275,6 +1276,80 @@ def containing(self, substring): substring = substring.eval() return LGroup([v for v in self.eval() if substring in v], axis=self.axis) + def to_hdf(self, filepath, key=None, axis_key=None): + """ + Writes group to a HDF file. + + A HDF file can contain multiple groups. + The 'key' parameter is a unique identifier for the group. + The 'axis_key' parameter is the unique identifier for the associated axis. + The associated axis will be saved if not already present in the HDF file. + + Parameters + ---------- + filepath : str + Path where the hdf file has to be written. + key : str or Group, optional + Key (path) of the group within the HDF file (see Notes below). + If None, the name of the group is used. + Defaults to None. + axis_key : str, optional + Key (path) of the associated axis in the HDF file (see Notes below). + If None, the name of the axis associated with the group is used. + Defaults to None. + + Notes + ----- + Objects stored in a HDF file can be grouped together in `HDF groups`. + If an object 'my_obj' is stored in a HDF group 'my_group', + the key associated with this object is then 'my_group/my_obj'. + Be aware that a HDF group can have subgroups. 
+ + Examples + -------- + >>> from larray import Axis + >>> a = Axis("a=a0..a2") + >>> a.to_hdf('test.h5') # doctest: +SKIP + >>> a01 = a['a0,a1'] >> 'a01' + + Save group + + >>> # by default, the key is the name of the group + >>> # and axis_key the name of the associated axis + >>> a01.to_hdf('test.h5') # doctest: +SKIP + + Save group with a specific key + + >>> a01.to_hdf('test.h5', 'a_01') # doctest: +SKIP + + Save group in a specific HDF group + + >>> a01.to_hdf('test.h5', 'groups/a01') # doctest: +SKIP + + The associated axis is saved with the group if not already present in the HDF file + + >>> b = Axis("b=b0..b2") + >>> b01 = b['b0,b1'] >> 'b01' + >>> # save both the group 'b01' and the associated axis 'b' + >>> b01.to_hdf('test.h5') # doctest: +SKIP + """ + if key is None: + if self.name is None: + raise ValueError("Argument key must be provided explicitly in case of anonymous group") + key = self.name + key = _translate_group_key_hdf(key) + if axis_key is None: + if self.axis.name is None: + raise ValueError("Argument axis_key must be provided explicitly if the associated axis is anonymous") + axis_key = self.axis.name + s = pd.Series(data=self.eval(), name=self.name) + with LHDFStore(filepath) as store: + store.put(key, s) + store.get_storer(key).attrs.type = 'Group' + if axis_key not in store: + self.axis.to_hdf(store, key=axis_key) + store.get_storer(key).attrs.axis_key = axis_key + # this makes range(LGroup(int)) possible def __index__(self): return self.eval().__index__() diff --git a/larray/core/session.py b/larray/core/session.py index 23fde4913..f2d937175 100644 --- a/larray/core/session.py +++ b/larray/core/session.py @@ -8,6 +8,7 @@ import numpy as np +from larray.core.group import Group from larray.core.axis import Axis from larray.core.array import LArray, get_axes, ndtest, zeros, zeros_like, sequence, aslarray from larray.util.misc import float_error_handler_factory, is_interactive_interpreter, renamed_to, inverseop @@ -17,30 +18,36 @@ # XXX: inherit from 
OrderedDict or LArray? class Session(object): """ - Groups several array objects together. + Groups several objects together. Parameters ---------- - args : str or dict of str, array or iterable of tuples (str, array) - Name of file to load or dictionary containing couples (name, array). - kwargs : dict of str, array - List of arrays to add written as 'name'=array, ... + *args : str or dict of {str: object} or iterable of tuples (str, object) + Path to the file containing the session to load or + list/tuple/dictionary containing couples (name, object). + **kwargs : dict of {str: object} + Objects to add written as name=object, ... Examples -------- - >>> arr1, arr2, arr3 = ndtest((2, 2)), ndtest(4), ndtest((3, 2)) + >>> # axes + >>> a, b = Axis("a=a0..a2"), Axis("b=b0..b2") + >>> # groups + >>> a01 = a['a0,a1'] >> 'a01' + >>> # arrays + >>> arr1, arr2 = ndtest((a, b)), ndtest(a) - create a Session by passing a list of pairs (name, array) + create a Session by passing a list of pairs (name, object) - >>> s = Session([('arr1', arr1), ('arr2', arr2), ('arr3', arr3)]) + >>> s = Session([('a', a), ('b', b), ('a01', a01), ('arr1', arr1), ('arr2', arr2)]) create a Session using keyword arguments (but you lose order on Python < 3.6) - >>> s = Session(arr1=arr1, arr2=arr2, arr3=arr3) + >>> s = Session(a=a, b=b, a01=a01, arr1=arr1, arr2=arr2) create a Session by passing a dictionary (but you lose order on Python < 3.6) - >>> s = Session({'arr1': arr1, 'arr2': arr2, 'arr3': arr3}) + >>> s = Session({'a': a, 'b': b, 'a01': a01, 'arr1': arr1, 'arr2': arr2}) load Session from file @@ -72,10 +79,10 @@ def add(self, *args, **kwargs): Parameters ---------- - args : array - List of objects to add. Objects must have an attribute 'name'. - kwargs : dict of str, array - List of objects to add written as 'name'=array, ... + *args : list of object + Objects to add. Objects must have an attribute 'name'. + **kwargs : dict of {str: object} + Objects to add written as 'name'=array, ... 
Examples -------- @@ -113,30 +120,36 @@ def __getitem__(self, key): def get(self, key, default=None): """ - Returns the array object corresponding to the key. - If the key doesn't correspond to any array object, a default one can be returned. + Returns the object corresponding to the key. + If the key doesn't correspond to any object, a default one can be returned. Parameters ---------- key : str - Name of the array. - default : array, optional - Returned array if the key doesn't correspond to any array of the current session. + Name of the object. + default : object, optional + Returned object if the key doesn't correspond to any object of the current session. Returns ------- - LArray - Array corresponding to the given key or a default one if not found. + object + Object corresponding to the given key or a default one if not found. Examples -------- - >>> arr1, arr2, arr3 = ndtest((2, 2)), ndtest(4), ndtest((3, 2)) - >>> s = Session([('arr1', arr1), ('arr2', arr2), ('arr3', arr3)]) + >>> # axes + >>> a, b = Axis("a=a0..a2"), Axis("b=b0..b2") + >>> # groups + >>> a01 = a['a0,a1'] >> 'a01' + >>> # arrays + >>> arr1, arr2 = ndtest((a, b)), ndtest(a) + >>> s = Session([('a', a), ('b', b), ('a01', a01), ('arr1', arr1), ('arr2', arr2)]) >>> arr = s.get('arr1') >>> arr - a\\b b0 b1 - a0 0 1 - a1 2 3 + a\\b b0 b1 b2 + a0 0 1 2 + a1 3 4 5 + a2 6 7 8 >>> arr = s.get('arr4', zeros('a=a0,a1;b=b0,b1', dtype=int)) >>> arr a\\b b0 b1 @@ -180,7 +193,7 @@ def __setstate__(self, d): def load(self, fname, names=None, engine='auto', display=False, **kwargs): """ - Loads array objects from a file, or several .csv files. + Loads LArray, Axis and Group objects from a file, or several .csv files. WARNING: never load a file using the pickle engine (.pkl or .pickle) from an untrusted source, as it can lead to arbitrary code execution. 
@@ -191,7 +204,8 @@ def load(self, fname, names=None, engine='auto', display=False, **kwargs): This can be either the path to a single file, a path to a directory containing .csv files or a pattern representing several .csv files. names : list of str, optional - List of arrays to load. If `fname` is None, list of paths to CSV files. + List of objects to load. + If `fname` is None, list of paths to CSV files. Defaults to all valid objects present in the file/directory. engine : {'auto', 'pandas_csv', 'pandas_hdf', 'pandas_excel', 'xlwings_excel', 'pickle'}, optional Load using `engine`. Defaults to 'auto' (use default engine for the format guessed from the file extension). @@ -200,19 +214,24 @@ def load(self, fname, names=None, engine='auto', display=False, **kwargs): Examples -------- - In one module - - >>> arr1, arr2, arr3 = ndtest((2, 2)), ndtest(4), ndtest((3, 2)) # doctest: +SKIP - >>> s = Session([('arr1', arr1), ('arr2', arr2), ('arr3', arr3)]) # doctest: +SKIP - >>> s.save('input.h5') # doctest: +SKIP - - In another module - - >>> s = Session() # doctest: +SKIP - >>> s.load('input.h5', ['arr1', 'arr2', 'arr3']) # doctest: +SKIP - >>> arr1, arr2, arr3 = s['arr1', 'arr2', 'arr3'] # doctest: +SKIP + In one module: + + >>> # axes + >>> a, b = Axis("a=a0..a2"), Axis("b=b0..b2") # doctest: +SKIP + >>> # groups + >>> a01 = a['a0,a1'] >> 'a01' # doctest: +SKIP + >>> # arrays + >>> arr1, arr2 = ndtest((a, b)), ndtest(a) # doctest: +SKIP + >>> s = Session([('a', a), ('b', b), ('a01', a01), ('arr1', arr1), ('arr2', arr2)]) # doctest: +SKIP + >>> s.save('input.h5') # doctest: +SKIP + + In another module: load only some objects + + >>> s = Session() # doctest: +SKIP + >>> s.load('input.h5', ['a', 'b', 'arr1', 'arr2']) # doctest: +SKIP + >>> a, b, arr1, arr2 = s['a', 'b', 'arr1', 'arr2'] # doctest: +SKIP >>> # only if you know the order of arrays stored in session - >>> arr1, arr2, arr3 = s.values() # doctest: +SKIP + >>> a, b, a01, arr1, arr2 = s.values() # doctest: 
+SKIP Using .csv files (assuming the same session as above) @@ -220,7 +239,7 @@ def load(self, fname, names=None, engine='auto', display=False, **kwargs): >>> s = Session() # doctest: +SKIP >>> # load all .csv files starting with "output" in the data directory >>> s.load('data') # doctest: +SKIP - >>> # or equivalently in this case + >>> # or only arrays (i.e. all CSV files starting with 'arr') >>> s.load('data/arr*.csv') # doctest: +SKIP """ if display: @@ -236,20 +255,22 @@ def load(self, fname, names=None, engine='auto', display=False, **kwargs): engine = ext_default_engine[ext] handler_cls = handler_classes[engine] handler = handler_cls(fname) - arrays = handler.read_arrays(names, display=display, **kwargs) - for k, v in arrays.items(): + objects = handler.read_items(names, display=display, **kwargs) + for k, v in objects.items(): self[k] = v def save(self, fname, names=None, engine='auto', overwrite=True, display=False, **kwargs): """ - Dumps all array objects from the current session to a file. + Dumps LArray, Axis and Group objects from the current session to a file. Parameters ---------- fname : str - Path for the dump. + Path of the file for the dump. + If objects are saved in CSV files, the path corresponds to a directory. names : list of str or None, optional - List of names of objects to dump. If `fname` is None, list of paths to CSV files. + List of names of LArray/Axis/Group objects to dump. + If `fname` is None, list of paths to CSV files. Defaults to all objects present in the Session. engine : {'auto', 'pandas_csv', 'pandas_hdf', 'pandas_excel', 'xlwings_excel', 'pickle'}, optional Dump using `engine`. Defaults to 'auto' (use default engine for the format guessed from the file extension). 
@@ -261,16 +282,21 @@ def save(self, fname, names=None, engine='auto', overwrite=True, display=False, Examples -------- - >>> arr1, arr2, arr3 = ndtest((2, 2)), ndtest(4), ndtest((3, 2)) # doctest: +SKIP - >>> s = Session([('arr1', arr1), ('arr2', arr2), ('arr3', arr3)]) # doctest: +SKIP + >>> # axes + >>> a, b = Axis("a=a0..a2"), Axis("b=b0..b2") # doctest: +SKIP + >>> # groups + >>> a01 = a['a0,a1'] >> 'a01' # doctest: +SKIP + >>> # arrays + >>> arr1, arr2 = ndtest((a, b)), ndtest(a) # doctest: +SKIP + >>> s = Session([('a', a), ('b', b), ('a01', a01), ('arr1', arr1), ('arr2', arr2)]) # doctest: +SKIP - Save all arrays + Save all objects - >>> s.save('output.h5') # doctest: +SKIP + >>> s.save('output.h5') # doctest: +SKIP - Save only some arrays + Save only some objects - >>> s.save('output.h5', ['arr1', 'arr3']) # doctest: +SKIP + >>> s.save('output.h5', ['a', 'b', 'arr1']) # doctest: +SKIP Update file @@ -285,11 +311,14 @@ def save(self, fname, names=None, engine='auto', overwrite=True, display=False, engine = ext_default_engine[ext] handler_cls = handler_classes[engine] handler = handler_cls(fname, overwrite) - items = self.filter(kind=LArray).items() + if engine != 'pandas_hdf': + items = self.filter(kind=LArray).items() + else: + items = self.items() if names is not None: names_set = set(names) items = [(k, v) for k, v in items if k in names_set] - handler.dump_arrays(items, display=display, **kwargs) + handler.dump_items(items, display=display, **kwargs) def to_globals(self, names=None, depth=0, warn=True, inplace=False): """ @@ -394,14 +423,15 @@ def to_pickle(self, fname, names=None, overwrite=True, display=False, **kwargs): def to_hdf(self, fname, names=None, overwrite=True, display=False, **kwargs): """ - Dumps all array objects from the current session to an HDF file. + Dumps LArray, Axis and Group objects from the current session to an HDF file. Parameters ---------- fname : str - Path for the dump. + Path of the file for the dump. 
names : list of str or None, optional - List of names of objects to dump. Defaults to all objects present in the Session. + Names of LArray/Axis/Group objects to dump. + Defaults to all objects present in the Session. overwrite: bool, optional Whether or not to overwrite an existing file, if any. If False, file is updated. Defaults to True. @@ -410,16 +440,21 @@ def to_hdf(self, fname, names=None, overwrite=True, display=False, **kwargs): Examples -------- - >>> arr1, arr2, arr3 = ndtest((2, 2)), ndtest(4), ndtest((3, 2)) # doctest: +SKIP - >>> s = Session([('arr1', arr1), ('arr2', arr2), ('arr3', arr3)]) # doctest: +SKIP + >>> # axes + >>> a, b = Axis("a=a0..a2"), Axis("b=b0..b2") # doctest: +SKIP + >>> # groups + >>> a01 = a['a0,a1'] >> 'a01' # doctest: +SKIP + >>> # arrays + >>> arr1, arr2 = ndtest((a, b)), ndtest(a) # doctest: +SKIP + >>> s = Session([('a', a), ('b', b), ('a01', a01), ('arr1', arr1), ('arr2', arr2)]) # doctest: +SKIP Save all arrays >>> s.to_hdf('output.h5') # doctest: +SKIP - Save only some arrays + Save only some objects - >>> s.to_hdf('output.h5', ['arr1', 'arr3']) # doctest: +SKIP + >>> s.to_hdf('output.h5', ['a', 'b', 'arr1']) # doctest: +SKIP """ self.save(fname, names, ext_default_engine['hdf'], overwrite, display, **kwargs) @@ -489,14 +524,14 @@ def to_csv(self, fname, names=None, display=False, **kwargs): def filter(self, pattern=None, kind=None): """ - Returns a new session with array objects which match some criteria. + Returns a new session with objects which match some criteria. Parameters ---------- pattern : str, optional Only keep arrays whose key match `pattern`. - kind : type, optional - Only keep arrays which are instances of type `kind`. + kind : (tuple of) type, optional + Only keep objects which are instances of type(s) `kind`. 
Returns ------- @@ -506,18 +541,21 @@ def filter(self, pattern=None, kind=None): Examples -------- >>> axis = Axis('a=a0..a2') - >>> test1, test2, zero1 = ndtest((2, 2)), ndtest(4), zeros((3, 2)) - >>> s = Session([('test1', test1), ('test2', test2), ('zero1', zero1), ('axis', axis)]) + >>> group = axis['a0,a1'] >> 'a01' + >>> test1, zero1 = ndtest((2, 2)), zeros((3, 2)) + >>> s = Session([('test1', test1), ('zero1', zero1), ('axis', axis), ('group', group)]) Filter using a pattern argument >>> s.filter(pattern='test').names - ['test1', 'test2'] + ['test1'] Filter using kind argument >>> s.filter(kind=Axis).names ['axis'] + >>> s.filter(kind=(Axis, Group)).names + ['axis', 'group'] """ items = self._objects.items() if pattern is not None: @@ -529,7 +567,7 @@ def filter(self, pattern=None, kind=None): @property def names(self): """ - Returns the list of names of the array objects in the session. + Returns the list of names of the objects in the session. The list is sorted alphabetically and does not follow the internal order. 
Returns @@ -542,15 +580,17 @@ def names(self): Examples -------- - >>> arr1, arr2, arr3 = ndtest((2, 2)), ndtest(4), ndtest((3, 2)) - >>> s = Session([('arr2', arr2), ('arr1', arr1), ('arr3', arr3)]) + >>> axis1 = Axis("a=a0..a2") + >>> group1 = axis1['a0,a1'] >> 'a01' + >>> arr1, arr2 = ndtest((2, 2)), ndtest(4) + >>> s = Session([('arr2', arr2), ('arr1', arr1), ('group1', group1), ('axis1', axis1)]) >>> # print array's names in the alphabetical order >>> s.names - ['arr1', 'arr2', 'arr3'] + ['arr1', 'arr2', 'axis1', 'group1'] >>> # keys() follows the internal order >>> list(s.keys()) - ['arr2', 'arr1', 'arr3'] + ['arr2', 'arr1', 'group1', 'axis1'] """ return sorted(self._objects.keys()) @@ -574,15 +614,17 @@ def keys(self): Examples -------- - >>> arr1, arr2, arr3 = ndtest((2, 2)), ndtest(4), ndtest((3, 2)) - >>> s = Session([('arr2', arr2), ('arr1', arr1), ('arr3', arr3)]) + >>> axis1 = Axis("a=a0..a2") + >>> group1 = axis1['a0,a1'] >> 'a01' + >>> arr1, arr2 = ndtest((2, 2)), ndtest(4) + >>> s = Session([('arr2', arr2), ('arr1', arr1), ('group1', group1), ('axis1', axis1)]) >>> # similar to names by follows the internal order >>> list(s.keys()) - ['arr2', 'arr1', 'arr3'] + ['arr2', 'arr1', 'group1', 'axis1'] - >>> # gives the names of arrays in alphabetical order + >>> # gives the names of objects in alphabetical order >>> s.names - ['arr1', 'arr2', 'arr3'] + ['arr1', 'arr2', 'axis1', 'group1'] """ return self._objects.keys() @@ -596,16 +638,20 @@ def values(self): Examples -------- - >>> arr1, arr2, arr3 = ndtest((2, 2)), ndtest(4), ndtest((3, 2)) - >>> s = Session([('arr2', arr2), ('arr1', arr1), ('arr3', arr3)]) - >>> # assuming you know the order of arrays stored in the session - >>> arr2, arr1, arr3 = s.values() + >>> axis1 = Axis("a=a0..a2") + >>> group1 = axis1['a0,a1'] >> 'a01' + >>> arr1, arr2 = ndtest((2, 2)), ndtest(4) + >>> s = Session([('arr2', arr2), ('arr1', arr1), ('group1', group1), ('axis1', axis1)]) + >>> # assuming you know the order of 
objects stored in the session + >>> arr2, arr1, group1, axis1 = s.values() >>> # otherwise, prefer the following syntax - >>> arr1, arr2, arr3 = s['arr1', 'arr2', 'arr3'] + >>> arr1, arr2, axis1, group1 = s['arr1', 'arr2', 'axis1', 'group1'] >>> arr1 a\\b b0 b1 a0 0 1 a1 2 3 + >>> axis1 + Axis(['a0', 'a1', 'a2'], 'a') """ return self._objects.values() @@ -619,12 +665,14 @@ def items(self): Examples -------- - >>> arr1, arr2, arr3 = ndtest((2, 2)), ndtest(4), ndtest((3, 2)) + >>> axis1 = Axis("a=a0..a2") + >>> group1 = axis1['a0,a1'] >> 'a01' + >>> arr1, arr2 = ndtest((2, 2)), ndtest(4) >>> # make the test pass on both Windows and Linux - >>> arr1, arr2, arr3 = arr1.astype(np.int64), arr2.astype(np.int64), arr3.astype(np.int64) - >>> s = Session([('arr2', arr2), ('arr1', arr1), ('arr3', arr3)]) + >>> arr1, arr2 = arr1.astype(np.int64), arr2.astype(np.int64) + >>> s = Session([('arr2', arr2), ('arr1', arr1), ('group1', group1), ('axis1', axis1)]) >>> for k, v in s.items(): - ... print("{}: {}".format(k, v.info)) + ... print("{}: {}".format(k, v.info if isinstance(v, LArray) else repr(v))) arr2: 4 a [4]: 'a0' 'a1' 'a2' 'a3' dtype: int64 @@ -634,11 +682,8 @@ def items(self): b [2]: 'b0' 'b1' dtype: int64 memory used: 32 bytes - arr3: 3 x 2 - a [3]: 'a0' 'a1' 'a2' - b [2]: 'b0' 'b1' - dtype: int64 - memory used: 48 bytes + group1: a['a0', 'a1'] >> 'a01' + axis1: Axis(['a0', 'a1', 'a2'], 'a') """ return self._objects.items() @@ -971,7 +1016,7 @@ def apply(self, func, *args, **kwargs): def summary(self, template=None): """ - Returns a summary of the content of the session. + Returns a summary of the content of the session (arrays only). Parameters ---------- @@ -982,7 +1027,7 @@ def summary(self, template=None): Returns ------- str - Short representation of the content of the session. + Short representation of the content of the session (arrays only). . 
Examples -------- @@ -1006,7 +1051,8 @@ def summary(self, template=None): template = "{name}: {axes_names}\n {title}\n" templ_kwargs = [{'name': k, 'axes_names': ', '.join(v.axes.display_names), - 'title': v.title} for k, v in self.items()] + 'title': v.title} + for k, v in self.items() if isinstance(v, LArray)] return '\n'.join(template.format(**kwargs) for kwargs in templ_kwargs) diff --git a/larray/inout/__init__.py b/larray/inout/__init__.py index ad1239606..317fdd9e8 100644 --- a/larray/inout/__init__.py +++ b/larray/inout/__init__.py @@ -1,5 +1,9 @@ from __future__ import absolute_import, division, print_function +from larray.inout.pandas import * +from larray.inout.csv import * +from larray.inout.misc import * from larray.inout.excel import * -from larray.inout.array import * -from larray.inout.session import * +from larray.inout.hdf import * +from larray.inout.sas import * +from larray.inout.xw_excel import * diff --git a/larray/inout/array.py b/larray/inout/array.py deleted file mode 100644 index 6f5b12dca..000000000 --- a/larray/inout/array.py +++ /dev/null @@ -1,770 +0,0 @@ -from __future__ import absolute_import, print_function - -import os -import csv -import numpy as np -import pandas as pd -import warnings -from itertools import product - -from larray.core.axis import Axis -from larray.core.array import LArray, ndtest -from larray.core.group import _translate_sheet_name, _translate_key_hdf -from larray.util.misc import (basestring, skip_comment_cells, strip_rows, csv_open, StringIO, decode, unique, - deprecate_kwarg) - -try: - import xlwings as xw -except ImportError: - xw = None - -__all__ = ['from_frame', 'read_csv', 'read_tsv', 'read_eurostat', 'read_hdf', 'read_excel', 'read_sas', - 'from_lists', 'from_string'] - - -def parse(s): - """ - Used to parse the "folded" axis ticks (usually periods). 
- """ - # parameters can be strings or numbers - if isinstance(s, basestring): - s = s.strip() - low = s.lower() - if low == 'true': - return True - elif low == 'false': - return False - elif s.isdigit(): - return int(s) - else: - try: - return float(s) - except ValueError: - return s - else: - return s - - -def df_labels(df, sort=True): - """ - Returns unique labels for each dimension. - """ - idx = df.index - if isinstance(idx, pd.core.index.MultiIndex): - if sort: - return list(idx.levels) - else: - return [list(unique(idx.get_level_values(l))) for l in idx.names] - else: - assert isinstance(idx, pd.core.index.Index) - # use .values if needed - return [idx] - - -def cartesian_product_df(df, sort_rows=False, sort_columns=False, **kwargs): - labels = df_labels(df, sort=sort_rows) - if sort_rows: - new_index = pd.MultiIndex.from_product(labels) - else: - new_index = pd.MultiIndex.from_tuples(list(product(*labels))) - columns = sorted(df.columns) if sort_columns else list(df.columns) - # the prodlen test is meant to avoid the more expensive array_equal test - prodlen = np.prod([len(axis_labels) for axis_labels in labels]) - if prodlen == len(df) and columns == list(df.columns) and np.array_equal(df.index.values, new_index.values): - return df, labels - return df.reindex(new_index, columns, **kwargs), labels - - -def from_series(s, sort_rows=False): - """ - Converts Pandas Series into 1D LArray. - - Parameters - ---------- - s : Pandas Series - Input Pandas Series. - sort_rows : bool, optional - Whether or not to sort the rows alphabetically. Defaults to False. - - Returns - ------- - LArray - """ - name = s.name if s.name is not None else s.index.name - if name is not None: - name = str(name) - if sort_rows: - s = s.sort_index() - return LArray(s.values, Axis(s.index.values, name)) - - -def from_frame(df, sort_rows=False, sort_columns=False, parse_header=False, unfold_last_axis_name=False, **kwargs): - """ - Converts Pandas DataFrame into LArray. 
- - Parameters - ---------- - df : pandas.DataFrame - Input dataframe. By default, name and labels of the last axis are defined by the name and labels of the - columns Index of the dataframe unless argument unfold_last_axis_name is set to True. - sort_rows : bool, optional - Whether or not to sort the rows alphabetically (sorting is more efficient than not sorting). Defaults to False. - sort_columns : bool, optional - Whether or not to sort the columns alphabetically (sorting is more efficient than not sorting). - Defaults to False. - parse_header : bool, optional - Whether or not to parse columns labels. Pandas treats column labels as strings. - If True, column labels are converted into int, float or boolean when possible. Defaults to False. - unfold_last_axis_name : bool, optional - Whether or not to extract the names of the last two axes by splitting the name of the last index column of the - dataframe using ``\\``. Defaults to False. - - Returns - ------- - LArray - - See Also - -------- - LArray.to_frame - - Examples - -------- - >>> df = ndtest((2, 2, 2)).to_frame() - >>> df # doctest: +NORMALIZE_WHITESPACE - c c0 c1 - a b - a0 b0 0 1 - b1 2 3 - a1 b0 4 5 - b1 6 7 - >>> from_frame(df) - a b\\c c0 c1 - a0 b0 0 1 - a0 b1 2 3 - a1 b0 4 5 - a1 b1 6 7 - - Names of the last two axes written as ``before_last_axis_name\\last_axis_name`` - - >>> df = ndtest((2, 2, 2)).to_frame(fold_last_axis_name=True) - >>> df # doctest: +NORMALIZE_WHITESPACE - c0 c1 - a b\\c - a0 b0 0 1 - b1 2 3 - a1 b0 4 5 - b1 6 7 - >>> from_frame(df, unfold_last_axis_name=True) - a b\\c c0 c1 - a0 b0 0 1 - a0 b1 2 3 - a1 b0 4 5 - a1 b1 6 7 - """ - axes_names = [decode(name, 'utf8') for name in df.index.names] - - # handle 2 or more dimensions with the last axis name given using \ - if unfold_last_axis_name: - if isinstance(axes_names[-1], basestring) and '\\' in axes_names[-1]: - last_axes = [name.strip() for name in axes_names[-1].split('\\')] - axes_names = axes_names[:-1] + last_axes - else: - 
axes_names += [None] - else: - axes_names += [df.columns.name] - - df, axes_labels = cartesian_product_df(df, sort_rows=sort_rows, sort_columns=sort_columns, **kwargs) - - # Pandas treats column labels as column names (strings) so we need to convert them to values - last_axis_labels = [parse(cell) for cell in df.columns.values] if parse_header else list(df.columns.values) - axes_labels.append(last_axis_labels) - axes_names = [str(name) if name is not None else name - for name in axes_names] - - axes = [Axis(labels, name) for labels, name in zip(axes_labels, axes_names)] - data = df.values.reshape([len(axis) for axis in axes]) - return LArray(data, axes) - - -def df_aslarray(df, sort_rows=False, sort_columns=False, raw=False, parse_header=True, wide=True, **kwargs): - """ - Prepare Pandas DataFrame and then convert it into LArray. - - Parameters - ---------- - df : Pandas DataFrame - Input dataframe. - sort_rows : bool, optional - Whether or not to sort the rows alphabetically (sorting is more efficient than not sorting). Defaults to False. - sort_columns : bool, optional - Whether or not to sort the columns alphabetically (sorting is more efficient than not sorting). - Defaults to False. - raw : bool, optional - Whether or not to consider the input dataframe as a raw dataframe, i.e. read without index at all. - If True, build the first N-1 axes of the output array from the first N-1 dataframe columns. Defaults to False. - parse_header : bool, optional - Whether or not to parse columns labels. Pandas treats column labels as strings. - If True, column labels are converted into int, float or boolean when possible. Defaults to True. - wide : bool, optional - Whether or not to assume the array is stored in "wide" format. - If False, the array is assumed to be stored in "narrow" format: one column per axis plus one value column. - Defaults to True. 
- - Returns - ------- - LArray - """ - # we could inline df_aslarray into the functions that use it, so that the original (non-cartesian) df is freed from - # memory at this point, but it would be much uglier and would not lower the peak memory usage which happens during - # cartesian_product_df.reindex - - # raw = True: the dataframe was read without index at all (ie 2D dataframe), - # irrespective of the actual data dimensionality - if raw: - columns = df.columns.values.tolist() - if wide: - try: - # take the first column which contains '\' - pos_last = next(i for i, v in enumerate(columns) if isinstance(v, basestring) and '\\' in v) - except StopIteration: - # we assume first column will not contain data - pos_last = 0 - - # This is required to handle int column names (otherwise we can simply use column positions in set_index). - # This is NOT the same as df.columns[list(range(...))] ! - index_columns = [df.columns[i] for i in range(pos_last + 1)] - df.set_index(index_columns, inplace=True) - else: - index_columns = [df.columns[i] for i in range(len(df.columns) - 1)] - df.set_index(index_columns, inplace=True) - series = df[df.columns[-1]] - if isinstance(series.index, pd.core.index.MultiIndex): - fill_value = kwargs.get('fill_value', np.nan) - # TODO: use argument sort=False when it will be available - # (see https://github.com/pandas-dev/pandas/issues/15105) - df = series.unstack(level=-1, fill_value=fill_value) - # pandas (un)stack and pivot(_table) methods return a Dataframe/Series with sorted index and columns - labels = df_labels(series, sort=False) - index = pd.MultiIndex.from_tuples(list(product(*labels[:-1])), names=series.index.names[:-1]) - columns = labels[-1] - df = df.reindex(index=index, columns=columns, fill_value=fill_value) - else: - series.name = series.index.name - if sort_rows: - raise ValueError('sort_rows=True is not valid for 1D arrays. 
Please use sort_columns instead.') - return from_series(series, sort_rows=sort_columns) - - # handle 1D - if len(df) == 1 and (pd.isnull(df.index.values[0]) or - (isinstance(df.index.values[0], basestring) and df.index.values[0].strip() == '')): - if parse_header: - df.columns = pd.Index([parse(cell) for cell in df.columns.values], name=df.columns.name) - series = df.iloc[0] - series.name = df.index.name - if sort_rows: - raise ValueError('sort_rows=True is not valid for 1D arrays. Please use sort_columns instead.') - return from_series(series, sort_rows=sort_columns) - else: - axes_names = [decode(name, 'utf8') for name in df.index.names] - unfold_last_axis_name = isinstance(axes_names[-1], basestring) and '\\' in axes_names[-1] - return from_frame(df, sort_rows=sort_rows, sort_columns=sort_columns, parse_header=parse_header, - unfold_last_axis_name=unfold_last_axis_name, **kwargs) - - -def _get_index_col(nb_axes=None, index_col=None, wide=True): - if not wide: - if nb_axes is not None or index_col is not None: - raise ValueError("`nb_axes` or `index_col` argument cannot be used when `wide` argument is False") - - if nb_axes is not None and index_col is not None: - raise ValueError("cannot specify both `nb_axes` and `index_col`") - elif nb_axes is not None: - index_col = list(range(nb_axes - 1)) - elif isinstance(index_col, int): - index_col = [index_col] - - return index_col - - -@deprecate_kwarg('nb_index', 'nb_axes', arg_converter=lambda x: x + 1) -def read_csv(filepath_or_buffer, nb_axes=None, index_col=None, sep=',', headersep=None, fill_value=np.nan, - na=np.nan, sort_rows=False, sort_columns=False, wide=True, dialect='larray', **kwargs): - """ - Reads csv file and returns an array with the contents. 
- - Notes - ----- - csv file format: - arr,ages,sex,nat\time,1991,1992,1993 - A1,BI,H,BE,1,0,0 - A1,BI,H,FO,2,0,0 - A1,BI,F,BE,0,0,1 - A1,BI,F,FO,0,0,0 - A1,A0,H,BE,0,0,0 - - Parameters - ---------- - filepath_or_buffer : str or any file-like object - Path where the csv file has to be read or a file handle. - nb_axes : int, optional - Number of axes of output array. The first `nb_axes` - 1 columns and the header of the CSV file will be used - to set the axes of the output array. If not specified, the number of axes is given by the position of the - column header including the character `\` plus one. If no column header includes the character `\`, the array - is assumed to have one axis. Defaults to None. - index_col : list, optional - Positions of columns for the n-1 first axes (ex. [0, 1, 2, 3]). Defaults to None (see nb_axes above). - sep : str, optional - Separator. - headersep : str or None, optional - Separator for headers. - fill_value : scalar or LArray, optional - Value used to fill cells corresponding to label combinations which are not present in the input. - Defaults to NaN. - sort_rows : bool, optional - Whether or not to sort the rows alphabetically (sorting is more efficient than not sorting). Defaults to False. - sort_columns : bool, optional - Whether or not to sort the columns alphabetically (sorting is more efficient than not sorting). - Defaults to False. - wide : bool, optional - Whether or not to assume the array is stored in "wide" format. - If False, the array is assumed to be stored in "narrow" format: one column per axis plus one value column. - Defaults to True. - dialect : 'classic' | 'larray' | 'liam2', optional - Name of dialect. Defaults to 'larray'. 
- **kwargs - - Returns - ------- - LArray - - Examples - -------- - >>> tmpdir = getfixture('tmpdir') - >>> fname = os.path.join(tmpdir.strpath, 'test.csv') - >>> a = ndtest('nat=BE,FO;sex=M,F') - >>> a - nat\\sex M F - BE 0 1 - FO 2 3 - >>> a.to_csv(fname) - >>> with open(fname) as f: - ... print(f.read().strip()) - nat\\sex,M,F - BE,0,1 - FO,2,3 - >>> read_csv(fname) - nat\\sex M F - BE 0 1 - FO 2 3 - - Sort columns - - >>> read_csv(fname, sort_columns=True) - nat\\sex F M - BE 1 0 - FO 3 2 - - Read array saved in "narrow" format (wide=False) - - >>> a.to_csv(fname, wide=False) - >>> with open(fname) as f: - ... print(f.read().strip()) - nat,sex,value - BE,M,0 - BE,F,1 - FO,M,2 - FO,F,3 - >>> read_csv(fname, wide=False) - nat\\sex M F - BE 0 1 - FO 2 3 - - Specify the number of axes of the output array (useful when the name of the last axis is implicit) - - >>> a.to_csv(fname, dialect='classic') - >>> with open(fname) as f: - ... print(f.read().strip()) - nat,M,F - BE,0,1 - FO,2,3 - >>> read_csv(fname, nb_axes=2) - nat\\{1} M F - BE 0 1 - FO 2 3 - """ - if not np.isnan(na): - fill_value = na - warnings.warn("read_csv `na` argument has been renamed to `fill_value`. Please use that instead.", - FutureWarning, stacklevel=2) - - if dialect == 'liam2': - # read axes names. This needs to be done separately instead of reading the whole file with Pandas then - # manipulating the dataframe because the header line must be ignored for the column types to be inferred - # correctly. Note that to read one line, this is faster than using Pandas reader. 
- with csv_open(filepath_or_buffer) as f: - reader = csv.reader(f, delimiter=sep) - line_stream = skip_comment_cells(strip_rows(reader)) - axes_names = next(line_stream) - - if nb_axes is not None or index_col is not None: - raise ValueError("nb_axes and index_col are not compatible with dialect='liam2'") - if len(axes_names) > 1: - nb_axes = len(axes_names) - # use the second data line for column headers (excludes comments and blank lines before counting) - kwargs['header'] = 1 - kwargs['comment'] = '#' - - index_col = _get_index_col(nb_axes, index_col, wide) - - if headersep is not None: - if index_col is None: - index_col = [0] - - df = pd.read_csv(filepath_or_buffer, index_col=index_col, sep=sep, **kwargs) - if dialect == 'liam2': - if len(df) == 1: - df.set_index([[np.nan]], inplace=True) - if len(axes_names) > 1: - df.index.names = axes_names[:-1] - df.columns.name = axes_names[-1] - raw = False - else: - raw = index_col is None - - if headersep is not None: - combined_axes_names = df.index.name - df.index = df.index.str.split(headersep, expand=True) - df.index.names = combined_axes_names.split(headersep) - raw = False - - return df_aslarray(df, sort_rows=sort_rows, sort_columns=sort_columns, fill_value=fill_value, raw=raw, wide=wide) - - -def read_tsv(filepath_or_buffer, **kwargs): - return read_csv(filepath_or_buffer, sep='\t', **kwargs) - - -def read_eurostat(filepath_or_buffer, **kwargs): - """Reads EUROSTAT TSV (tab-separated) file into an array. - - EUROSTAT TSV files are special because they use tabs as data separators but comas to separate headers. - - Parameters - ---------- - filepath_or_buffer : str or any file-like object - Path where the tsv file has to be read or a file handle. - kwargs - Arbitrary keyword arguments are passed through to read_csv. 
- - Returns - ------- - LArray - """ - return read_csv(filepath_or_buffer, sep='\t', headersep=',', **kwargs) - - -def read_hdf(filepath_or_buffer, key, fill_value=np.nan, na=np.nan, sort_rows=False, sort_columns=False, **kwargs): - """Reads an array named key from a HDF5 file in filepath (path+name) - - Parameters - ---------- - filepath_or_buffer : str or pandas.HDFStore - Path and name where the HDF5 file is stored or a HDFStore object. - key : str or Group - Name of the array. - fill_value : scalar or LArray, optional - Value used to fill cells corresponding to label combinations which are not present in the input. - Defaults to NaN. - sort_rows : bool, optional - Whether or not to sort the rows alphabetically (sorting is more efficient than not sorting). Defaults to False. - sort_columns : bool, optional - Whether or not to sort the columns alphabetically (sorting is more efficient than not sorting). - Defaults to False. - - Returns - ------- - LArray - """ - if not np.isnan(na): - fill_value = na - warnings.warn("read_hdf `na` argument has been renamed to `fill_value`. Please use that instead.", - FutureWarning, stacklevel=2) - - key = _translate_key_hdf(key) - df = pd.read_hdf(filepath_or_buffer, key, **kwargs) - return df_aslarray(df, sort_rows=sort_rows, sort_columns=sort_columns, fill_value=fill_value, parse_header=False) - - -@deprecate_kwarg('nb_index', 'nb_axes', arg_converter=lambda x: x + 1) -@deprecate_kwarg('sheetname', 'sheet') -def read_excel(filepath, sheet=0, nb_axes=None, index_col=None, fill_value=np.nan, na=np.nan, - sort_rows=False, sort_columns=False, wide=True, engine=None, **kwargs): - """ - Reads excel file from sheet name and returns an LArray with the contents - - Parameters - ---------- - filepath : str - Path where the Excel file has to be read or use -1 to refer to the currently active workbook. - sheet : str, Group or int, optional - Name or index of the Excel sheet containing the array to be read. 
- By default the array is read from the first sheet. - nb_axes : int, optional - Number of axes of output array. The first `nb_axes` - 1 columns and the header of the Excel sheet will be used - to set the axes of the output array. If not specified, the number of axes is given by the position of the - column header including the character `\` plus one. If no column header includes the character `\`, the array - is assumed to have one axis. Defaults to None. - index_col : list, optional - Positions of columns for the n-1 first axes (ex. [0, 1, 2, 3]). Defaults to None (see nb_axes above). - fill_value : scalar or LArray, optional - Value used to fill cells corresponding to label combinations which are not present in the input. - Defaults to NaN. - sort_rows : bool, optional - Whether or not to sort the rows alphabetically (sorting is more efficient than not sorting). Defaults to False. - sort_columns : bool, optional - Whether or not to sort the columns alphabetically (sorting is more efficient than not sorting). - Defaults to False. - wide : bool, optional - Whether or not to assume the array is stored in "wide" format. - If False, the array is assumed to be stored in "narrow" format: one column per axis plus one value column. - Defaults to True. - engine : {'xlrd', 'xlwings'}, optional - Engine to use to read the Excel file. If None (default), it will use 'xlwings' by default if the module is - installed and relies on Pandas default reader otherwise. - **kwargs - """ - if not np.isnan(na): - fill_value = na - warnings.warn("read_excel `na` argument has been renamed to `fill_value`. 
Please use that instead.", - FutureWarning, stacklevel=2) - - sheet = _translate_sheet_name(sheet) - - if engine is None: - engine = 'xlwings' if xw is not None else None - - index_col = _get_index_col(nb_axes, index_col, wide) - - if engine == 'xlwings': - if kwargs: - raise TypeError("'{}' is an invalid keyword argument for this function when using the xlwings backend" - .format(list(kwargs.keys())[0])) - from larray.inout.excel import open_excel - with open_excel(filepath) as wb: - return wb[sheet].load(index_col=index_col, fill_value=fill_value, sort_rows=sort_rows, - sort_columns=sort_columns, wide=wide) - else: - df = pd.read_excel(filepath, sheet, index_col=index_col, engine=engine, **kwargs) - return df_aslarray(df, sort_rows=sort_rows, sort_columns=sort_columns, raw=index_col is None, - fill_value=fill_value, wide=wide) - - -@deprecate_kwarg('nb_index', 'nb_axes', arg_converter=lambda x: x + 1) -def read_sas(filepath, nb_axes=None, index_col=None, fill_value=np.nan, na=np.nan, sort_rows=False, sort_columns=False, - **kwargs): - """ - Reads sas file and returns an LArray with the contents - nb_axes: number of axes of the output array - or - index_col: Positions of columns for the n-1 first axes (ex. [0, 1, 2, 3]) - """ - if not np.isnan(na): - fill_value = na - warnings.warn("read_sas `na` argument has been renamed to `fill_value`. 
Please use that instead.", - FutureWarning, stacklevel=2) - - if nb_axes is not None and index_col is not None: - raise ValueError("cannot specify both nb_axes and index_col") - elif nb_axes is not None: - index_col = list(range(nb_axes - 1)) - elif isinstance(index_col, int): - index_col = [index_col] - - df = pd.read_sas(filepath, index=index_col, **kwargs) - return df_aslarray(df, sort_rows=sort_rows, sort_columns=sort_columns, fill_value=fill_value) - - -@deprecate_kwarg('nb_index', 'nb_axes', arg_converter=lambda x: x + 1) -def from_lists(data, nb_axes=None, index_col=None, fill_value=np.nan, sort_rows=False, sort_columns=False, wide=True): - """ - initialize array from a list of lists (lines) - - Parameters - ---------- - data : sequence (tuple, list, ...) - Input data. All data is supposed to already have the correct type (e.g. strings are not parsed). - nb_axes : int, optional - Number of axes of output array. The first `nb_axes` - 1 columns and the header will be used - to set the axes of the output array. If not specified, the number of axes is given by the position of the - column header including the character `\` plus one. If no column header includes the character `\`, the array - is assumed to have one axis. Defaults to None. - index_col : list, optional - Positions of columns for the n-1 first axes (ex. [0, 1, 2, 3]). Defaults to None (see nb_axes above). - fill_value : scalar or LArray, optional - Value used to fill cells corresponding to label combinations which are not present in the input. - Defaults to NaN. - sort_rows : bool, optional - Whether or not to sort the rows alphabetically (sorting is more efficient than not sorting). Defaults to False. - sort_columns : bool, optional - Whether or not to sort the columns alphabetically (sorting is more efficient than not sorting). - Defaults to False. - wide : bool, optional - Whether or not to assume the array is stored in "wide" format. 
- If False, the array is assumed to be stored in "narrow" format: one column per axis plus one value column. - Defaults to True. - - Returns - ------- - LArray - - Examples - -------- - >>> from_lists([['sex', 'M', 'F'], - ... ['', 0, 1]]) - sex M F - 0 1 - >>> from_lists([['sex\\\\year', 1991, 1992, 1993], - ... [ 'M', 0, 1, 2], - ... [ 'F', 3, 4, 5]]) - sex\\year 1991 1992 1993 - M 0 1 2 - F 3 4 5 - - Read array with missing values + `fill_value` argument - - >>> from_lists([['sex', 'nat\\\\year', 1991, 1992, 1993], - ... [ 'M', 'BE', 1, 0, 0], - ... [ 'M', 'FO', 2, 0, 0], - ... [ 'F', 'BE', 0, 0, 1]]) - sex nat\\year 1991 1992 1993 - M BE 1.0 0.0 0.0 - M FO 2.0 0.0 0.0 - F BE 0.0 0.0 1.0 - F FO nan nan nan - - >>> from_lists([['sex', 'nat\\\\year', 1991, 1992, 1993], - ... [ 'M', 'BE', 1, 0, 0], - ... [ 'M', 'FO', 2, 0, 0], - ... [ 'F', 'BE', 0, 0, 1]], fill_value=42) - sex nat\\year 1991 1992 1993 - M BE 1 0 0 - M FO 2 0 0 - F BE 0 0 1 - F FO 42 42 42 - - Specify the number of axes of the array to be read - - >>> from_lists([['sex', 'nat', 1991, 1992, 1993], - ... [ 'M', 'BE', 1, 0, 0], - ... [ 'M', 'FO', 2, 0, 0], - ... [ 'F', 'BE', 0, 0, 1]], nb_axes=3) - sex nat\\{2} 1991 1992 1993 - M BE 1.0 0.0 0.0 - M FO 2.0 0.0 0.0 - F BE 0.0 0.0 1.0 - F FO nan nan nan - - Read array saved in "narrow" format (wide=False) - - >>> from_lists([['sex', 'nat', 'year', 'value'], - ... [ 'M', 'BE', 1991, 1 ], - ... [ 'M', 'BE', 1992, 0 ], - ... [ 'M', 'BE', 1993, 0 ], - ... [ 'M', 'FO', 1991, 2 ], - ... [ 'M', 'FO', 1992, 0 ], - ... [ 'M', 'FO', 1993, 0 ], - ... [ 'F', 'BE', 1991, 0 ], - ... [ 'F', 'BE', 1992, 0 ], - ... 
[ 'F', 'BE', 1993, 1 ]], wide=False) - sex nat\\year 1991 1992 1993 - M BE 1.0 0.0 0.0 - M FO 2.0 0.0 0.0 - F BE 0.0 0.0 1.0 - F FO nan nan nan - """ - index_col = _get_index_col(nb_axes, index_col, wide) - - df = pd.DataFrame(data[1:], columns=data[0]) - if index_col is not None: - df.set_index([df.columns[c] for c in index_col], inplace=True) - - return df_aslarray(df, raw=index_col is None, parse_header=False, sort_rows=sort_rows, sort_columns=sort_columns, - fill_value=fill_value, wide=wide) - - -@deprecate_kwarg('nb_index', 'nb_axes', arg_converter=lambda x: x + 1) -def from_string(s, nb_axes=None, index_col=None, sep=' ', wide=True, **kwargs): - """Create an array from a multi-line string. - - Parameters - ---------- - s : str - input string. - nb_axes : int, optional - Number of axes of output array. The first `nb_axes` - 1 columns and the header will be used - to set the axes of the output array. If not specified, the number of axes is given by the position of the - column header including the character `\` plus one. If no column header includes the character `\`, the array - is assumed to have one axis. Defaults to None. - index_col : list, optional - Positions of columns for the n-1 first axes (ex. [0, 1, 2, 3]). Defaults to None (see nb_axes above). - sep : str - delimiter used to split each line into cells. - wide : bool, optional - Whether or not to assume the array is stored in "wide" format. - If False, the array is assumed to be stored in "narrow" format: one column per axis plus one value column. - Defaults to True. - \**kwargs - See arguments of Pandas read_csv function. 
- - Returns - ------- - LArray - - Examples - -------- - >>> # to create a 1D array using the default separator ' ', a tabulation character \t must be added in front - >>> # of the data line - >>> from_string("sex M F\\n\\t 0 1") - sex M F - 0 1 - >>> from_string("nat\\\\sex M F\\nBE 0 1\\nFO 2 3") - nat\sex M F - BE 0 1 - FO 2 3 - >>> from_string("period a b\\n2010 0 1\\n2011 2 3") - period\{1} a b - 2010 0 1 - 2011 2 3 - - Each label is stripped of leading and trailing whitespace, so this is valid too: - - >>> from_string('''nat\\\\sex M F - ... BE 0 1 - ... FO 2 3''') - nat\sex M F - BE 0 1 - FO 2 3 - >>> from_string('''age nat\\\\sex M F - ... 0 BE 0 1 - ... 0 FO 2 3 - ... 1 BE 4 5 - ... 1 FO 6 7''') - age nat\sex M F - 0 BE 0 1 - 0 FO 2 3 - 1 BE 4 5 - 1 FO 6 7 - - Empty lines at the beginning or end are ignored, so one can also format the string like this: - - >>> from_string(''' - ... nat\\\\sex M F - ... BE 0 1 - ... FO 2 3 - ... ''') - nat\sex M F - BE 0 1 - FO 2 3 - """ - return read_csv(StringIO(s), nb_axes=nb_axes, index_col=index_col, sep=sep, skipinitialspace=True, - wide=wide, **kwargs) diff --git a/larray/inout/common.py b/larray/inout/common.py new file mode 100644 index 000000000..7433c42a4 --- /dev/null +++ b/larray/inout/common.py @@ -0,0 +1,158 @@ +from __future__ import absolute_import, print_function + +import os +from collections import OrderedDict + +from larray.core.axis import Axis +from larray.core.group import Group +from larray.core.array import LArray + + +def _get_index_col(nb_axes=None, index_col=None, wide=True): + if not wide: + if nb_axes is not None or index_col is not None: + raise ValueError("`nb_axes` or `index_col` argument cannot be used when `wide` argument is False") + + if nb_axes is not None and index_col is not None: + raise ValueError("cannot specify both `nb_axes` and `index_col`") + elif nb_axes is not None: + index_col = list(range(nb_axes - 1)) + elif isinstance(index_col, int): + index_col = [index_col] + + return 
index_col + + +_allowed_types = (LArray, Axis, Group) + + +class FileHandler(object): + """ + Abstract class defining the methods for "file handler" subclasses. + + Parameters + ---------- + fname : str + Filename. + + Attributes + ---------- + fname : str + Filename. + """ + def __init__(self, fname, overwrite_file=False): + self.fname = fname + self.original_file_name = None + self.overwrite_file = overwrite_file + + def _open_for_read(self): + raise NotImplementedError() + + def _open_for_write(self): + raise NotImplementedError() + + def list(self): + """ + Returns the list of objects' names. + """ + raise NotImplementedError() + + def _read_item(self, key, *args, **kwargs): + raise NotImplementedError() + + def _dump(self, key, value, *args, **kwargs): + raise NotImplementedError() + + def save(self): + """ + Saves items in file. + """ + pass + + def close(self): + """ + Closes file. + """ + raise NotImplementedError() + + def _get_original_file_name(self): + if self.overwrite_file and os.path.isfile(self.fname): + self.original_file_name = self.fname + self.fname = '{}~{}'.format(*os.path.splitext(self.fname)) + + def _update_original_file(self): + if self.original_file_name is not None and os.path.isfile(self.fname): + os.remove(self.original_file_name) + os.rename(self.fname, self.original_file_name) + + def read_items(self, keys, *args, **kwargs): + """ + Reads file content (HDF, Excel, CSV, ...) and returns a dictionary containing loaded objects. + + Parameters + ---------- + keys : list of str + List of objects' names. + *args : any + Any other argument is passed through to the underlying read function. + display : bool, optional + Whether or not the function should display a message when starting and ending to load each object. + Defaults to False. + ignore_exceptions : bool, optional + Whether or not an exception should stop the function or be ignored. Defaults to False. 
+ **kwargs : any + Any other keyword argument is passed through to the underlying read function. + + Returns + ------- + OrderedDict(str, LArray/Axis/Group) + Dictionary containing the loaded object. + """ + display = kwargs.pop('display', False) + ignore_exceptions = kwargs.pop('ignore_exceptions', False) + self._open_for_read() + res = OrderedDict() + if keys is None: + keys = self.list() + for key in keys: + if display: + print("loading", key, "...", end=' ') + try: + key, item = self._read_item(key, *args, **kwargs) + res[key] = item + except Exception: + if not ignore_exceptions: + raise + if display: + print("done") + self.close() + return res + + def dump_items(self, key_values, *args, **kwargs): + """ + Dumps objects corresponding to keys in file in HDF, Excel, CSV, ... format + + Parameters + ---------- + key_values : list of (str, LArray/Axis/Group) pairs + Name and data of objects to dump. + kwargs : + * display: whether or not to display when the dump of each object is started/done. + """ + display = kwargs.pop('display', False) + self._get_original_file_name() + self._open_for_write() + key_values = [(k, v) for k, v in key_values if isinstance(v, _allowed_types)] + for key, value in key_values: + if isinstance(value, LArray) and value.ndim == 0: + if display: + print('Cannot dump {}. 
Dumping 0D arrays is currently not supported.'.format(key)) + continue + if display: + print("dumping", key, "...", end=' ') + self._dump(key, value, *args, **kwargs) + if display: + print("done") + self.save() + self.close() + self._update_original_file() \ No newline at end of file diff --git a/larray/inout/csv.py b/larray/inout/csv.py new file mode 100644 index 000000000..f1bece602 --- /dev/null +++ b/larray/inout/csv.py @@ -0,0 +1,248 @@ +from __future__ import absolute_import, print_function + +import os +import csv +import warnings +from glob import glob + +import pandas as pd +import numpy as np + +from larray.core.array import LArray, ndtest +from larray.util.misc import skip_comment_cells, strip_rows, csv_open, deprecate_kwarg +from larray.inout.common import _get_index_col, FileHandler +from larray.inout.pandas import df_aslarray + + +__all__ = ['read_csv', 'read_tsv', 'read_eurostat'] + + +@deprecate_kwarg('nb_index', 'nb_axes', arg_converter=lambda x: x + 1) +def read_csv(filepath_or_buffer, nb_axes=None, index_col=None, sep=',', headersep=None, fill_value=np.nan, + na=np.nan, sort_rows=False, sort_columns=False, wide=True, dialect='larray', **kwargs): + """ + Reads csv file and returns an array with the contents. + + Notes + ----- + csv file format: + arr,ages,sex,nat\time,1991,1992,1993 + A1,BI,H,BE,1,0,0 + A1,BI,H,FO,2,0,0 + A1,BI,F,BE,0,0,1 + A1,BI,F,FO,0,0,0 + A1,A0,H,BE,0,0,0 + + Parameters + ---------- + filepath_or_buffer : str or any file-like object + Path where the csv file has to be read or a file handle. + nb_axes : int, optional + Number of axes of output array. The first `nb_axes` - 1 columns and the header of the CSV file will be used + to set the axes of the output array. If not specified, the number of axes is given by the position of the + column header including the character `\` plus one. If no column header includes the character `\`, the array + is assumed to have one axis. Defaults to None. 
+ index_col : list, optional + Positions of columns for the n-1 first axes (ex. [0, 1, 2, 3]). Defaults to None (see nb_axes above). + sep : str, optional + Separator. + headersep : str or None, optional + Separator for headers. + fill_value : scalar or LArray, optional + Value used to fill cells corresponding to label combinations which are not present in the input. + Defaults to NaN. + sort_rows : bool, optional + Whether or not to sort the rows alphabetically (sorting is more efficient than not sorting). Defaults to False. + sort_columns : bool, optional + Whether or not to sort the columns alphabetically (sorting is more efficient than not sorting). + Defaults to False. + wide : bool, optional + Whether or not to assume the array is stored in "wide" format. + If False, the array is assumed to be stored in "narrow" format: one column per axis plus one value column. + Defaults to True. + dialect : 'classic' | 'larray' | 'liam2', optional + Name of dialect. Defaults to 'larray'. + **kwargs + + Returns + ------- + LArray + + Examples + -------- + >>> tmpdir = getfixture('tmpdir') + >>> fname = os.path.join(tmpdir.strpath, 'test.csv') + >>> a = ndtest('nat=BE,FO;sex=M,F') + >>> a + nat\\sex M F + BE 0 1 + FO 2 3 + >>> a.to_csv(fname) + >>> with open(fname) as f: + ... print(f.read().strip()) + nat\\sex,M,F + BE,0,1 + FO,2,3 + >>> read_csv(fname) + nat\\sex M F + BE 0 1 + FO 2 3 + + Sort columns + + >>> read_csv(fname, sort_columns=True) + nat\\sex F M + BE 1 0 + FO 3 2 + + Read array saved in "narrow" format (wide=False) + + >>> a.to_csv(fname, wide=False) + >>> with open(fname) as f: + ... print(f.read().strip()) + nat,sex,value + BE,M,0 + BE,F,1 + FO,M,2 + FO,F,3 + >>> read_csv(fname, wide=False) + nat\\sex M F + BE 0 1 + FO 2 3 + + Specify the number of axes of the output array (useful when the name of the last axis is implicit) + + >>> a.to_csv(fname, dialect='classic') + >>> with open(fname) as f: + ... 
print(f.read().strip()) + nat,M,F + BE,0,1 + FO,2,3 + >>> read_csv(fname, nb_axes=2) + nat\\{1} M F + BE 0 1 + FO 2 3 + """ + if not np.isnan(na): + fill_value = na + warnings.warn("read_csv `na` argument has been renamed to `fill_value`. Please use that instead.", + FutureWarning, stacklevel=2) + + if dialect == 'liam2': + # read axes names. This needs to be done separately instead of reading the whole file with Pandas then + # manipulating the dataframe because the header line must be ignored for the column types to be inferred + # correctly. Note that to read one line, this is faster than using Pandas reader. + with csv_open(filepath_or_buffer) as f: + reader = csv.reader(f, delimiter=sep) + line_stream = skip_comment_cells(strip_rows(reader)) + axes_names = next(line_stream) + + if nb_axes is not None or index_col is not None: + raise ValueError("nb_axes and index_col are not compatible with dialect='liam2'") + if len(axes_names) > 1: + nb_axes = len(axes_names) + # use the second data line for column headers (excludes comments and blank lines before counting) + kwargs['header'] = 1 + kwargs['comment'] = '#' + + index_col = _get_index_col(nb_axes, index_col, wide) + + if headersep is not None: + if index_col is None: + index_col = [0] + + df = pd.read_csv(filepath_or_buffer, index_col=index_col, sep=sep, **kwargs) + if dialect == 'liam2': + if len(df) == 1: + df.set_index([[np.nan]], inplace=True) + if len(axes_names) > 1: + df.index.names = axes_names[:-1] + df.columns.name = axes_names[-1] + raw = False + else: + raw = index_col is None + + if headersep is not None: + combined_axes_names = df.index.name + df.index = df.index.str.split(headersep, expand=True) + df.index.names = combined_axes_names.split(headersep) + raw = False + + return df_aslarray(df, sort_rows=sort_rows, sort_columns=sort_columns, fill_value=fill_value, raw=raw, wide=wide) + + +def read_tsv(filepath_or_buffer, **kwargs): + return read_csv(filepath_or_buffer, sep='\t', **kwargs) + + +def 
 read_eurostat(filepath_or_buffer, **kwargs): + """Reads EUROSTAT TSV (tab-separated) file into an array. + + EUROSTAT TSV files are special because they use tabs as data separators but commas to separate headers. + + Parameters + ---------- + filepath_or_buffer : str or any file-like object + Path where the tsv file has to be read or a file handle. + kwargs + Arbitrary keyword arguments are passed through to read_csv. + + Returns + ------- + LArray + """ + return read_csv(filepath_or_buffer, sep='\t', headersep=',', **kwargs) + + +class PandasCSVHandler(FileHandler): + def __init__(self, fname, overwrite_file=False): + super(PandasCSVHandler, self).__init__(fname, overwrite_file) + if fname is None: + self.pattern = None + self.directory = None + elif '.csv' in fname or '*' in fname or '?' in fname: + self.pattern = fname + self.directory = os.path.dirname(fname) + else: + # assume fname is a directory. + # Not testing for os.path.isdir(fname) here because when writing, the directory might not exist. 
+ self.pattern = os.path.join(fname, '*.csv') + self.directory = fname + + def _get_original_file_name(self): + pass + + def _open_for_read(self): + if self.directory and not os.path.isdir(self.directory): + raise ValueError("Directory '{}' does not exist".format(self.directory)) + + def _open_for_write(self): + if self.directory is not None: + try: + os.makedirs(self.directory) + except OSError: + if not os.path.isdir(self.directory): + raise ValueError("Path {} must represent a directory".format(self.directory)) + + def list(self): + fnames = glob(self.pattern) if self.pattern is not None else [] + # drop directory + fnames = [os.path.basename(fname) for fname in fnames] + # strip extension from files + # XXX: unsure we should use sorted here + return sorted([os.path.splitext(fname)[0] for fname in fnames]) + + def _to_filepath(self, key): + if self.directory is not None: + return os.path.join(self.directory, '{}.csv'.format(key)) + else: + return key + + def _read_item(self, key, *args, **kwargs): + return key, read_csv(self._to_filepath(key), *args, **kwargs) + + def _dump(self, key, value, *args, **kwargs): + value.to_csv(self._to_filepath(key), *args, **kwargs) + + def close(self): + pass \ No newline at end of file diff --git a/larray/inout/excel.py b/larray/inout/excel.py index 5d4df038c..8381075ec 100644 --- a/larray/inout/excel.py +++ b/larray/inout/excel.py @@ -1,714 +1,138 @@ -# -*- coding: utf8 -*- from __future__ import absolute_import, print_function -__all__ = ['open_excel', 'Workbook'] - - -import os -import atexit +import warnings import numpy as np +import pandas as pd try: import xlwings as xw except ImportError: xw = None from larray.core.group import _translate_sheet_name -from larray.core.axis import Axis -from larray.core.array import LArray, ndtest -from larray.inout.array import df_aslarray, from_lists -from larray.util.misc import PY2 - -string_types = (str,) - - -if xw is not None: - from xlwings.conversion.pandas_conv import 
PandasDataFrameConverter - - global_app = None - - def is_app_alive(app): - try: - app.books - return True - except Exception: - return False - - - def kill_global_app(): - global global_app - - if global_app is not None: - if is_app_alive(global_app): - try: - global_app.kill() - except Exception: - pass - del global_app - global_app = None - - - class LArrayConverter(PandasDataFrameConverter): - writes_types = LArray - - @classmethod - def read_value(cls, value, options): - df = PandasDataFrameConverter.read_value(value, options) - return df_aslarray(df) - - @classmethod - def write_value(cls, value, options): - df = value.to_frame(fold_last_axis_name=True) - return PandasDataFrameConverter.write_value(df, options) - - LArrayConverter.register(LArray) - - # TODO: replace overwrite_file by mode='r'|'w'|'a' the day xlwings will support a read-only mode - class Workbook(object): - def __init__(self, filepath=None, overwrite_file=False, visible=None, silent=None, app=None): - global global_app - - xw_wkb = None - self.delayed_filepath = None - self.filepath = None - self.new_workbook = False - self.active_workbook = filepath == -1 - - if filepath is None: - self.new_workbook = True - - if isinstance(filepath, str): - basename, ext = os.path.splitext(filepath) - if ext: - # XXX: we might want to be more precise than .xl* because I am unsure writing .xls - # (or anything other than .xlsx and .xlsm) would work - if not ext.startswith('.xl'): - raise ValueError("'%s' is not a supported file extension" % ext) - if not os.path.isfile(filepath) and not overwrite_file: - raise ValueError("File {} does not exist. Please give the path to an existing file or set " - "overwrite_file argument to True".format(filepath)) - if os.path.isfile(filepath) and overwrite_file: - self.filepath = filepath - # we create a temporary file to work on. In case of crash, the original is not destroyed. - # the temporary file is renamed as the original file at close. 
- filepath = basename + '~' + ext - if not os.path.isfile(filepath): - self.new_workbook = True - else: - # try to target an open but unsaved workbook. We cannot use the same code path as for other options - # because we do not know which Excel instance has that book - xw_wkb = xw.Book(filepath) - app = xw_wkb.app - - # active workbook use active app by default - if self.active_workbook and app not in {None, "active"}: - raise ValueError("to connect to the active workbook, one must use the 'active' Excel instance " - "(app='active' or app=None)") - - # unless explicitly set, app is set to visible for brand new or active book. - # For unsaved_book it is left intact. - if visible is None: - if filepath is None or self.active_workbook: - visible = True - elif xw_wkb is None: - # filepath is not None but we don't target an unsaved book - visible = False - - if app is None: - if self.active_workbook: - app = "active" - elif visible: - app = "new" - else: - app = "global" - - load_addins = False - if app == "new": - app = xw.App(visible=visible, add_book=False) - load_addins = True - elif app == "active": - app = xw.apps.active - elif app == "global": - if global_app is None: - atexit.register(kill_global_app) - if global_app is None or not is_app_alive(global_app): - global_app = xw.App(visible=visible, add_book=False) - load_addins = True - app = global_app - assert isinstance(app, xw.App) - - if visible: - app.visible = visible - - if silent is None: - silent = not visible - - # activate XLA(M) addins - # (for some reasons, add-ins are not activated when an Excel Workbook is opened from Python) - if load_addins: - for ia in range(1, app.api.Addins.Count + 1): - addin_path = app.api.Addins(ia).FullName - if not '.xll' in addin_path.lower(): - app.api.Workbooks.Open(addin_path) - - update_links_backup = app.api.AskToUpdateLinks - display_alerts_backup = app.display_alerts - if silent: - # try to update links silently instead of asking: "Update", "Don't Update", "Help" - 
app.api.AskToUpdateLinks = False - - # in case some links cannot be updated, continue instead of asking: "Continue" or "Edit Links..." - app.display_alerts = False - - if filepath is None: - # creates a new/blank Book - xw_wkb = app.books.add() - elif self.active_workbook: - xw_wkb = app.books.active - elif xw_wkb is None: - # file already exists (and is a file) - if os.path.isfile(filepath): - xw_wkb = app.books.open(filepath) - else: - # let us remember the path - self.delayed_filepath = filepath - xw_wkb = app.books.add() - - if silent: - app.api.AskToUpdateLinks = update_links_backup - app.display_alerts = display_alerts_backup - - self.xw_wkb = xw_wkb - - def __contains__(self, key): - if isinstance(key, int): - length = len(self) - return -length <= key < length - else: - # I would like to use: "return key in wb.sheets" but as of xlwings 0.10 wb.sheets.__contains__ does not - # work for sheet names (it works with Sheet objects I think) - return key in self.sheet_names() - - def _ipython_key_completions_(self): - return list(self.sheet_names()) - - def __getitem__(self, key): - key = _translate_sheet_name(key) - if key in self: - return Sheet(self, key) - else: - raise KeyError('Workbook has no sheet named {}'.format(key)) - - def __setitem__(self, key, value): - key = _translate_sheet_name(key) - if self.new_workbook: - self.xw_wkb.sheets[0].name = key - self.new_workbook = False - key_in_self = key in self - if isinstance(value, Sheet): - if value.xw_sheet.book.app != self.xw_wkb.app: - raise ValueError("cannot copy a sheet from one instance of Excel to another") - - # xlwings index is 1-based - # TODO: implement Workbook.index(key) - target_idx = self[key].xw_sheet.index - 1 if key_in_self else -1 - target_sheet = self[target_idx].xw_sheet - # add new sheet after target sheet. 
The new sheet will be named something like "value.name (1)" but I - # do not think there is anything we can do about this, except rename it afterwards because Copy has no - # name argument. See https://msdn.microsoft.com/en-us/library/office/ff837784.aspx - value.xw_sheet.api.Copy(None, target_sheet.api) - if key_in_self: - target_sheet.delete() - # rename the new sheet - self[target_idx].name = key - return - if key_in_self: - sheet = self[key] - sheet.clear() - else: - xw_sheet = self.xw_wkb.sheets.add(key, after=self[-1].xw_sheet) - sheet = Sheet(None, None, xw_sheet=xw_sheet) - sheet["A1"] = value - - def __delitem__(self, key): - self[key].delete() - - def sheet_names(self): - return [s.name for s in self] - - def save(self, path=None): - # saved_path = self.xw_wkb.api.Path - # was_saved = saved_path != '' - if path is None and self.delayed_filepath is not None: - path = self.delayed_filepath - self.xw_wkb.save(path=path) - - def close(self): - # Close the workbook in Excel. - # This will not quit the Excel instance, even if this was the last workbook of that Excel instance. - if self.filepath is not None and os.path.isfile(self.xw_wkb.fullname): - tmp_file = self.xw_wkb.fullname - self.xw_wkb.close() - os.remove(self.filepath) - os.rename(tmp_file, self.filepath) - else: - self.xw_wkb.close() - - def __iter__(self): - return iter([Sheet(None, None, xw_sheet) - for xw_sheet in self.xw_wkb.sheets]) - - def __len__(self): - return len(self.xw_wkb.sheets) - - def __dir__(self): - return list(set(dir(self.__class__)) | set(dir(self.xw_wkb))) - - def __getattr__(self, key): - return getattr(self.xw_wkb, key) - - def __enter__(self): - return self - - def __exit__(self, type_, value, traceback): - if not self.active_workbook: - self.close() - - def __repr__(self): - cls = self.__class__ - return '<{}.{} [{}]>'.format(cls.__module__, cls.__name__, self.name) - - - def _fill_slice(s, length): - """ - replaces a slice None bounds by actual bounds. 
- - Parameters - ---------- - s : slice - slice to replace - length : int - length of sequence - - Returns - ------- - slice - """ - return slice(s.start if s.start is not None else 0, s.stop if s.stop is not None else length, s.step) - - - def _concrete_key(key, obj, ndim=2): - """Expand key to ndim and replace None in slices start/stop bounds by 0 or obj.shape[corresponding_dim] - respectively. - - Parameters - ---------- - key : scalar, slice or tuple - input key - obj : object - any object with a 'shape' attribute. - ndim : int - number of dimensions to expand to. We could use len(obj.shape) instead but we avoid it to not trigger - obj.shape, which can be expensive in the case of a sheet with blank cells after the data. - """ - if not isinstance(key, tuple): - key = (key,) - - if len(key) < ndim: - key = key + (slice(None),) * (ndim - len(key)) - - # only compute shape if necessary because it can be expensive in some cases - if any(isinstance(k, slice) and k.stop is None for k in key): - shape = obj.shape - else: - shape = (None, None) - - # We use _fill_slice instead of slice(*k.indices(length)) because the later also clips bounds which exceed - # the length and we do NOT want to do that in this case (see issue #273). - return [_fill_slice(k, length) if isinstance(k, slice) else k - for k, length in zip(key, shape)] - - - class Sheet(object): - def __init__(self, workbook, key, xw_sheet=None): - if xw_sheet is None: - xw_sheet = workbook.xw_wkb.sheets[key] - object.__setattr__(self, 'xw_sheet', xw_sheet) - - # TODO: we can probably scrap this for xlwings 0.9+. We need to have - # a unit test for this though. 
- def __getitem__(self, key): - if isinstance(key, string_types): - return Range(self, key) - - row, col = _concrete_key(key, self) - if isinstance(row, slice) or isinstance(col, slice): - row1, row2 = (row.start, row.stop) if isinstance(row, slice) else (row, row + 1) - col1, col2 = (col.start, col.stop) if isinstance(col, slice) else (col, col + 1) - return Range(self, (row1 + 1, col1 + 1), (row2, col2)) - else: - return Range(self, (row + 1, col + 1)) - - def __setitem__(self, key, value): - if isinstance(value, LArray): - value = value.dump(header=False) - self[key].xw_range.value = value - - @property - def shape(self): - """ - shape of sheet including top-left empty rows/columns but excluding bottom-right ones. - """ - from xlwings.constants import Direction as xldir - - sheet = self.xw_sheet.api - used = sheet.UsedRange - first_row = used.Row - first_col = used.Column - last_row = first_row + used.Rows.Count - 1 - last_col = first_col + used.Columns.Count - 1 - last_cell = sheet.Cells(last_row, last_col) - - # fast path for sheets with a non blank bottom-right value - if last_cell.Value is not None: - return last_row, last_col - - last_row_used = last_cell.End(xldir.xlToLeft).Value is not None - last_col_used = last_cell.End(xldir.xlUp).Value is not None - - # fast path for sheets where last row and last col are not entirely blank - if last_row_used and last_col_used: - return last_row, last_col - else: - LEFT, UP = xldir.xlToLeft, xldir.xlUp - - def line_length(row, col, direction): - last_cell = sheet.Cells(row, col) - if last_cell.Value is not None: - return col if direction is LEFT else row - first_cell = last_cell.End(direction) - pos = first_cell.Column if direction is LEFT else first_cell.Row - return pos - 1 if first_cell.Value is None else pos - - if last_row < last_col: - if last_row_used or last_row == 1: - max_row = last_row - else: - for max_row in range(last_row - 1, first_row - 1, -1): - if line_length(max_row, last_col, LEFT) > 0: - break - 
if last_col_used or last_col == 1: - max_col = last_col - else: - max_col = max(line_length(row, last_col, LEFT) for row in range(first_row, max_row + 1)) - else: - if last_col_used or last_col == 1: - max_col = last_col - else: - for max_col in range(last_col - 1, first_col - 1, -1): - if line_length(last_row, max_col, UP) > 0: - break - if last_row_used or last_row == 1: - max_row = last_row - else: - max_row = max(line_length(last_row, col, UP) for col in range(first_col, max_col + 1)) - return max_row, max_col - - @property - def ndim(self): - return 2 - - def __array__(self, dtype=None): - return np.asarray(self[:], dtype=dtype) - - def __dir__(self): - return list(set(dir(self.__class__)) | set(dir(self.xw_sheet))) - - def __getattr__(self, key): - return getattr(self.xw_sheet, key) - - def __setattr__(self, key, value): - setattr(self.xw_sheet, key, value) - - def load(self, header=True, convert_float=True, nb_index=None, index_col=None, fill_value=np.nan, - sort_rows=False, sort_columns=False, wide=True): - return self[:].load(header=header, convert_float=convert_float, nb_index=nb_index, index_col=index_col, - fill_value=fill_value, sort_rows=sort_rows, sort_columns=sort_columns, wide=wide) - - # TODO: generalize to more than 2 dimensions or scrap it - def array(self, data, row_labels=None, column_labels=None, names=None): - """ - - Parameters - ---------- - data : str - range for data - row_labels : str, optional - range for row labels - column_labels : str, optional - range for column labels - names : list of str, optional - - Returns - ------- - LArray - """ - if row_labels is not None: - row_labels = np.asarray(self[row_labels]) - if column_labels is not None: - column_labels = np.asarray(self[column_labels]) - if names is not None: - labels = (row_labels, column_labels) - axes = [Axis(axis_labels, name) for axis_labels, name in zip(labels, names)] - else: - axes = (row_labels, column_labels) - # _converted_value is used implicitly via Range.__array__ 
- return LArray(np.asarray(self[data]), axes) - - def __repr__(self): - cls = self.__class__ - xw_sheet = self.xw_sheet - return '<{}.{} [{}]{}>'.format(cls.__module__, cls.__name__, xw_sheet.book.name, xw_sheet.name) - - - class Range(object): - def __init__(self, sheet, *args): - xw_range = sheet.xw_sheet.range(*args) - - object.__setattr__(self, 'sheet', sheet) - object.__setattr__(self, 'xw_range', xw_range) - - def _range_key_to_sheet_key(self, key): - # string keys does not make sense in this case - assert not isinstance(key, string_types) - row_offset = self.xw_range.row1 - 1 - col_offset = self.xw_range.col1 - 1 - row, col = _concrete_key(key, self.xw_range) - row = slice(row.start + row_offset, row.stop + row_offset) if isinstance(row, slice) else row + row_offset - col = slice(col.start + col_offset, col.stop + col_offset) if isinstance(col, slice) else col + col_offset - return row, col - - # TODO: we can probably scrap this for xlwings 0.9+. We need to have - # a unit test for this though. - def __getitem__(self, key): - return self.sheet[self._range_key_to_sheet_key(key)] - - def __setitem__(self, key, value): - self.sheet[self._range_key_to_sheet_key(key)] = value - - def _converted_value(self, convert_float=True): - list_data = self.xw_range.value - - # As of version 0.7.2 of xlwings, there is no built-in converter for - # this. The builtin .options(numbers=int) converter converts all - # values to int, whether that would loose information or not, but - # this is not what we want. 
- if convert_float: - # Excel 'numbers' are always floats - def convert(value): - if isinstance(value, float): - int_val = int(value) - if int_val == value: - return int_val - return value - elif isinstance(value, list): - return [convert(v) for v in value] - else: - return value - return convert(list_data) - return list_data - - def __float__(self): - # no need to use _converted_value because we will convert back to a float anyway - return float(self.xw_range.value) - - def __int__(self): - # no need to use _converted_value because we will convert to an int anyway - return int(self.xw_range.value) - - def __index__(self): - v = self._converted_value() - if hasattr(v, '__index__'): - return v.__index__() - else: - raise TypeError("only integer scalars can be converted to a scalar index") - - def __array__(self, dtype=None): - return np.array(self._converted_value(), dtype=dtype) - - def __larray__(self): - return LArray(self._converted_value()) - - def __dir__(self): - return list(set(dir(self.__class__)) | set(dir(self.xw_range))) - - def __getattr__(self, key): - if hasattr(LArray, key): - return getattr(self.__larray__(), key) - else: - return getattr(self.xw_range, key) - - def __setattr__(self, key, value): - setattr(self.xw_range, key, value) - - # TODO: implement all binops - # def __mul__(self, other): - # return self.__larray__() * other - - def __str__(self): - return str(self.__larray__()) - __repr__ = __str__ - - def load(self, header=True, convert_float=True, nb_index=None, index_col=None, fill_value=np.nan, - sort_rows=False, sort_columns=False, wide=True): - if not self.ndim: - return LArray([]) - - list_data = self._converted_value(convert_float=convert_float) - - if header: - return from_lists(list_data, nb_index=nb_index, index_col=index_col, fill_value=fill_value, - sort_rows=sort_rows, sort_columns=sort_columns, wide=wide) - else: - return LArray(list_data) - - # XXX: deprecate this function? 
- def open_excel(filepath=None, overwrite_file=False, visible=None, silent=None, app=None): - return Workbook(filepath, overwrite_file, visible, silent, app) -else: - class Workbook(object): - def __init__(self, filepath=None, overwrite_file=False, visible=None, silent=None, app=None): - raise Exception("Workbook class cannot be instanciated because xlwings is not installed") - - def sheet_names(self): - raise Exception() - - def save(self, path=None): - raise Exception() - - def close(self): - raise Exception() - - def open_excel(filepath=None, overwrite_file=False, visible=None, silent=None, app=None): - raise Exception("open_excel() is not available because xlwings is not installed") - - -# We define Workbook and open_excel documentation here since Readthedocs runs on Linux -if not PY2: - Workbook.__doc__ = """ -Excel Workbook. - -See Also --------- -open_excel -""" - - Workbook.sheet_names.__doc__ = """ -Returns the names of the Excel sheets. - -Examples --------- ->>> arr, arr2, arr3 = ndtest((3, 3)), ndtest((2, 2)), ndtest(4) ->>> with open_excel('excel_file.xlsx', overwrite_file=True) as wb: # doctest: +SKIP -... wb['arr'] = arr.dump() -... wb['arr2'] = arr2.dump() -... wb['arr3'] = arr3.dump() -... wb.save() -... -... wb.sheet_names() -['arr', 'arr2', 'arr3'] -""" - - Workbook.save.__doc__ = """ -Saves the Workbook. - -If a path is being provided, this works like SaveAs() in Excel. -If no path is specified and if the file hasn’t been saved previously, -it’s being saved in the current working directory with the current filename. -Existing files are overwritten without prompting. - -Parameters ----------- -path : str, optional - Full path to the workbook. Defaults to None. - -Examples --------- ->>> arr, arr2, arr3 = ndtest((3, 3)), ndtest((2, 2)), ndtest(4) ->>> with open_excel('excel_file.xlsx', overwrite_file=True) as wb: # doctest: +SKIP -... wb['arr'] = arr.dump() -... wb['arr2'] = arr2.dump() -... wb['arr3'] = arr3.dump() -... 
wb.save() -""" - - Workbook.close.__doc__ = """ -Close the workbook in Excel. - -Need to be called if the workbook has been opened without the `with` statement. - -Examples --------- ->>> arr, arr2, arr3 = ndtest((3, 3)), ndtest((2, 2)), ndtest(4) # doctest: +SKIP ->>> wb = open_excel('excel_file.xlsx', overwrite_file=True) # doctest: +SKIP ->>> wb['arr'] = arr.dump() # doctest: +SKIP ->>> wb['arr2'] = arr2.dump() # doctest: +SKIP ->>> wb['arr3'] = arr3.dump() # doctest: +SKIP ->>> wb.save() # doctest: +SKIP ->>> wb.close() # doctest: +SKIP -""" - -open_excel.__doc__ = """ -Open an Excel workbook - -Parameters ----------- -filepath : None, int or str, optional - path to the Excel file. The file must exist if overwrite_file is False. Use None for a new blank workbook, - -1 for the currently active workbook. Defaults to None. -overwrite_file : bool, optional - whether or not to overwrite an existing file, if any. Defaults to False. -visible : None or bool, optional - whether or not Excel should be visible. Defaults to False for files, True for new/active workbooks and to None - ("unchanged") for existing unsaved workbooks. -silent : None or bool, optional - whether or not to show dialog boxes for updating links or when some links cannot be updated. - Defaults to False if visible, True otherwise. -app : None, "new", "active", "global" or xlwings.App, optional - use "new" for opening a new Excel instance, "active" for the last active instance (including ones opened by the - user) and "global" to (re)use the same instance for all workbooks of a program. None is equivalent to "active" if - filepath is -1, "new" if visible is True and "global" otherwise. Defaults to None. - - The "global" instance is a specific Excel instance for all input from/output to Excel from within a single Python - program (and should not interact with instances manually opened by the user or another program). - -Returns -------- -Excel workbook. 
- -Examples --------- ->>> arr = ndtest((3, 3)) ->>> arr -a\\b b0 b1 b2 - a0 0 1 2 - a1 3 4 5 - a2 6 7 8 - -create a new Excel file and save an array - ->>> # to create a new Excel file, argument overwrite_file must be set to True ->>> with open_excel('excel_file.xlsx', overwrite_file=True) as wb: # doctest: +SKIP -... wb['arr'] = arr.dump() -... wb.save() - -read array from an Excel file - ->>> with open_excel('excel_file.xlsx') as wb: # doctest: +SKIP -... arr2 = wb['arr'].load() ->>> arr2 # doctest: +SKIP -a\\b b0 b1 b2 - a0 0 1 2 - a1 3 4 5 - a2 6 7 8 -""" +from larray.util.misc import deprecate_kwarg +from larray.inout.common import _get_index_col, FileHandler +from larray.inout.pandas import df_aslarray +from larray.inout.xw_excel import open_excel + +__all__ = ['read_excel'] + + +# TODO : add examples +@deprecate_kwarg('nb_index', 'nb_axes', arg_converter=lambda x: x + 1) +@deprecate_kwarg('sheetname', 'sheet') +def read_excel(filepath, sheet=0, nb_axes=None, index_col=None, fill_value=np.nan, na=np.nan, + sort_rows=False, sort_columns=False, wide=True, engine=None, **kwargs): + """ + Reads excel file from sheet name and returns an LArray with the contents + + Parameters + ---------- + filepath : str + Path where the Excel file has to be read or use -1 to refer to the currently active workbook. + sheet : str, Group or int, optional + Name or index of the Excel sheet containing the array to be read. + By default the array is read from the first sheet. + nb_axes : int, optional + Number of axes of output array. The first `nb_axes` - 1 columns and the header of the Excel sheet will be used + to set the axes of the output array. If not specified, the number of axes is given by the position of the + column header including the character `\` plus one. If no column header includes the character `\`, the array + is assumed to have one axis. Defaults to None. + index_col : list, optional + Positions of columns for the n-1 first axes (ex. [0, 1, 2, 3]). 
Defaults to None (see nb_axes above). + fill_value : scalar or LArray, optional + Value used to fill cells corresponding to label combinations which are not present in the input. + Defaults to NaN. + sort_rows : bool, optional + Whether or not to sort the rows alphabetically (sorting is more efficient than not sorting). Defaults to False. + sort_columns : bool, optional + Whether or not to sort the columns alphabetically (sorting is more efficient than not sorting). + Defaults to False. + wide : bool, optional + Whether or not to assume the array is stored in "wide" format. + If False, the array is assumed to be stored in "narrow" format: one column per axis plus one value column. + Defaults to True. + engine : {'xlrd', 'xlwings'}, optional + Engine to use to read the Excel file. If None (default), it will use 'xlwings' by default if the module is + installed and relies on Pandas default reader otherwise. + **kwargs + """ + if not np.isnan(na): + fill_value = na + warnings.warn("read_excel `na` argument has been renamed to `fill_value`. Please use that instead.", + FutureWarning, stacklevel=2) + + sheet = _translate_sheet_name(sheet) + + if engine is None: + engine = 'xlwings' if xw is not None else None + + index_col = _get_index_col(nb_axes, index_col, wide) + + if engine == 'xlwings': + if kwargs: + raise TypeError("'{}' is an invalid keyword argument for this function when using the xlwings backend" + .format(list(kwargs.keys())[0])) + from larray.inout.xw_excel import open_excel + with open_excel(filepath) as wb: + return wb[sheet].load(index_col=index_col, fill_value=fill_value, sort_rows=sort_rows, + sort_columns=sort_columns, wide=wide) + else: + df = pd.read_excel(filepath, sheet, index_col=index_col, engine=engine, **kwargs) + return df_aslarray(df, sort_rows=sort_rows, sort_columns=sort_columns, raw=index_col is None, + fill_value=fill_value, wide=wide) + + +class PandasExcelHandler(FileHandler): + """ + Handler for Excel files using Pandas. 
+ """ + def _open_for_read(self): + self.handle = pd.ExcelFile(self.fname) + + def _open_for_write(self): + self.handle = pd.ExcelWriter(self.fname) + + def list(self): + return self.handle.sheet_names + + def _read_item(self, key, *args, **kwargs): + df = self.handle.parse(key, *args, **kwargs) + return key, df_aslarray(df, raw=True) + + def _dump(self, key, value, *args, **kwargs): + kwargs['engine'] = 'xlsxwriter' + value.to_excel(self.handle, key, *args, **kwargs) + + def close(self): + self.handle.close() + + +class XLWingsHandler(FileHandler): + """ + Handler for Excel files using XLWings. + """ + def _get_original_file_name(self): + # for XLWingsHandler, no need to create a temporary file, the job is already done in the Workbook class + pass + + def _open_for_read(self): + self.handle = open_excel(self.fname) + + def _open_for_write(self): + self.handle = open_excel(self.fname, overwrite_file=self.overwrite_file) + + def list(self): + return self.handle.sheet_names() + + def _read_item(self, key, *args, **kwargs): + return key, self.handle[key].load(*args, **kwargs) + + def _dump(self, key, value, *args, **kwargs): + self.handle[key] = value.dump(*args, **kwargs) + + def save(self): + self.handle.save() + + def close(self): + self.handle.close() \ No newline at end of file diff --git a/larray/inout/hdf.py b/larray/inout/hdf.py new file mode 100644 index 000000000..8ed3ceb81 --- /dev/null +++ b/larray/inout/hdf.py @@ -0,0 +1,114 @@ +from __future__ import absolute_import, print_function + +import warnings + +import numpy as np +from pandas import HDFStore + +from larray.core.axis import Axis +from larray.core.group import Group, LGroup, _translate_group_key_hdf +from larray.core.array import LArray +from larray.util.misc import LHDFStore +from larray.inout.pandas import df_aslarray +from larray.inout.common import FileHandler + + +__all__ = ['read_hdf'] + + +# TODO : add examples +def read_hdf(filepath_or_buffer, key, fill_value=np.nan, na=np.nan, 
sort_rows=False, sort_columns=False, + name=None, **kwargs): + """Reads an array named key from a HDF5 file in filepath (path+name) + + Parameters + ---------- + filepath_or_buffer : str or pandas.HDFStore + Path and name where the HDF5 file is stored or a HDFStore object. + key : str or Group + Name of the array. + fill_value : scalar or LArray, optional + Value used to fill cells corresponding to label combinations which are not present in the input. + Defaults to NaN. + sort_rows : bool, optional + Whether or not to sort the rows alphabetically (sorting is more efficient than not sorting). Defaults to False. + sort_columns : bool, optional + Whether or not to sort the columns alphabetically (sorting is more efficient than not sorting). + Defaults to False. + name : str, optional + Name of the axis or group to return. If None, name is set to passed key. + Defaults to None. + + Returns + ------- + LArray + """ + if not np.isnan(na): + fill_value = na + warnings.warn("read_hdf `na` argument has been renamed to `fill_value`. 
Please use that instead.", + FutureWarning, stacklevel=2) + + key = _translate_group_key_hdf(key) + res = None + with LHDFStore(filepath_or_buffer) as store: + pd_obj = store.get(key) + attrs = store.get_storer(key).attrs + # for backward compatibility but any object read from an hdf file should have an attribute 'type' + _type = attrs.type if 'type' in dir(attrs) else 'Array' + if _type == 'Array': + res = df_aslarray(pd_obj, sort_rows=sort_rows, sort_columns=sort_columns, fill_value=fill_value, parse_header=False) + elif _type == 'Axis': + if name is None: + name = str(pd_obj.name) + if name == 'None': + name = None + res = Axis(labels=pd_obj.values, name=name) + res._iswildcard = attrs['wildcard'] + elif _type == 'Group': + if name is None: + name = str(pd_obj.name) + if name == 'None': + name = None + axis = read_hdf(filepath_or_buffer, attrs['axis_key']) + res = LGroup(key=pd_obj.values, name=name, axis=axis) + return res + + +class PandasHDFHandler(FileHandler): + """ + Handler for HDF5 files using Pandas. 
+ """ + def _open_for_read(self): + self.handle = HDFStore(self.fname, mode='r') + + def _open_for_write(self): + self.handle = HDFStore(self.fname) + + def list(self): + return [key.strip('/') for key in self.handle.keys()] + + def _read_item(self, key, *args, **kwargs): + if '__axes__' in key: + session_key = key.split('/')[-1] + kwargs['name'] = session_key + elif '__groups__' in key: + session_key = key.split('/')[-1] + kwargs['name'] = session_key + else: + session_key = key + key = '/' + key + return session_key, read_hdf(self.handle, key, *args, **kwargs) + + def _dump(self, key, value, *args, **kwargs): + if isinstance(value, Axis): + key = '__axes__/' + key + elif isinstance(value, Group): + key = '__groups__/' + key + # axis_key (see Group.to_hdf) + args = ('__axes__/' + value.axis.name,) + args + else: + key = '/' + key + value.to_hdf(self.handle, key, *args, **kwargs) + + def close(self): + self.handle.close() \ No newline at end of file diff --git a/larray/inout/misc.py b/larray/inout/misc.py new file mode 100644 index 000000000..6c725fb52 --- /dev/null +++ b/larray/inout/misc.py @@ -0,0 +1,198 @@ +from __future__ import absolute_import, print_function + +import numpy as np +from pandas import DataFrame + +from larray.util.misc import StringIO, deprecate_kwarg +from larray.inout.common import _get_index_col +from larray.inout.pandas import df_aslarray +from larray.inout.csv import read_csv + + +__all__ = ['from_lists', 'from_string'] + + +@deprecate_kwarg('nb_index', 'nb_axes', arg_converter=lambda x: x + 1) +def from_lists(data, nb_axes=None, index_col=None, fill_value=np.nan, sort_rows=False, sort_columns=False, wide=True): + """ + initialize array from a list of lists (lines) + + Parameters + ---------- + data : sequence (tuple, list, ...) + Input data. All data is supposed to already have the correct type (e.g. strings are not parsed). + nb_axes : int, optional + Number of axes of output array. 
The first `nb_axes` - 1 columns and the header will be used + to set the axes of the output array. If not specified, the number of axes is given by the position of the + column header including the character `\` plus one. If no column header includes the character `\`, the array + is assumed to have one axis. Defaults to None. + index_col : list, optional + Positions of columns for the n-1 first axes (ex. [0, 1, 2, 3]). Defaults to None (see nb_axes above). + fill_value : scalar or LArray, optional + Value used to fill cells corresponding to label combinations which are not present in the input. + Defaults to NaN. + sort_rows : bool, optional + Whether or not to sort the rows alphabetically (sorting is more efficient than not sorting). Defaults to False. + sort_columns : bool, optional + Whether or not to sort the columns alphabetically (sorting is more efficient than not sorting). + Defaults to False. + wide : bool, optional + Whether or not to assume the array is stored in "wide" format. + If False, the array is assumed to be stored in "narrow" format: one column per axis plus one value column. + Defaults to True. + + Returns + ------- + LArray + + Examples + -------- + >>> from_lists([['sex', 'M', 'F'], + ... ['', 0, 1]]) + sex M F + 0 1 + >>> from_lists([['sex\\\\year', 1991, 1992, 1993], + ... [ 'M', 0, 1, 2], + ... [ 'F', 3, 4, 5]]) + sex\\year 1991 1992 1993 + M 0 1 2 + F 3 4 5 + + Read array with missing values + `fill_value` argument + + >>> from_lists([['sex', 'nat\\\\year', 1991, 1992, 1993], + ... [ 'M', 'BE', 1, 0, 0], + ... [ 'M', 'FO', 2, 0, 0], + ... [ 'F', 'BE', 0, 0, 1]]) + sex nat\\year 1991 1992 1993 + M BE 1.0 0.0 0.0 + M FO 2.0 0.0 0.0 + F BE 0.0 0.0 1.0 + F FO nan nan nan + + >>> from_lists([['sex', 'nat\\\\year', 1991, 1992, 1993], + ... [ 'M', 'BE', 1, 0, 0], + ... [ 'M', 'FO', 2, 0, 0], + ... 
[ 'F', 'BE', 0, 0, 1]], fill_value=42) + sex nat\\year 1991 1992 1993 + M BE 1 0 0 + M FO 2 0 0 + F BE 0 0 1 + F FO 42 42 42 + + Specify the number of axes of the array to be read + + >>> from_lists([['sex', 'nat', 1991, 1992, 1993], + ... [ 'M', 'BE', 1, 0, 0], + ... [ 'M', 'FO', 2, 0, 0], + ... [ 'F', 'BE', 0, 0, 1]], nb_axes=3) + sex nat\\{2} 1991 1992 1993 + M BE 1.0 0.0 0.0 + M FO 2.0 0.0 0.0 + F BE 0.0 0.0 1.0 + F FO nan nan nan + + Read array saved in "narrow" format (wide=False) + + >>> from_lists([['sex', 'nat', 'year', 'value'], + ... [ 'M', 'BE', 1991, 1 ], + ... [ 'M', 'BE', 1992, 0 ], + ... [ 'M', 'BE', 1993, 0 ], + ... [ 'M', 'FO', 1991, 2 ], + ... [ 'M', 'FO', 1992, 0 ], + ... [ 'M', 'FO', 1993, 0 ], + ... [ 'F', 'BE', 1991, 0 ], + ... [ 'F', 'BE', 1992, 0 ], + ... [ 'F', 'BE', 1993, 1 ]], wide=False) + sex nat\\year 1991 1992 1993 + M BE 1.0 0.0 0.0 + M FO 2.0 0.0 0.0 + F BE 0.0 0.0 1.0 + F FO nan nan nan + """ + index_col = _get_index_col(nb_axes, index_col, wide) + + df = DataFrame(data[1:], columns=data[0]) + if index_col is not None: + df.set_index([df.columns[c] for c in index_col], inplace=True) + + return df_aslarray(df, raw=index_col is None, parse_header=False, sort_rows=sort_rows, sort_columns=sort_columns, + fill_value=fill_value, wide=wide) + + +@deprecate_kwarg('nb_index', 'nb_axes', arg_converter=lambda x: x + 1) +def from_string(s, nb_axes=None, index_col=None, sep=' ', wide=True, **kwargs): + """Create an array from a multi-line string. + + Parameters + ---------- + s : str + input string. + nb_axes : int, optional + Number of axes of output array. The first `nb_axes` - 1 columns and the header will be used + to set the axes of the output array. If not specified, the number of axes is given by the position of the + column header including the character `\` plus one. If no column header includes the character `\`, the array + is assumed to have one axis. Defaults to None. 
+ index_col : list, optional + Positions of columns for the n-1 first axes (ex. [0, 1, 2, 3]). Defaults to None (see nb_axes above). + sep : str + delimiter used to split each line into cells. + wide : bool, optional + Whether or not to assume the array is stored in "wide" format. + If False, the array is assumed to be stored in "narrow" format: one column per axis plus one value column. + Defaults to True. + \**kwargs + See arguments of Pandas read_csv function. + + Returns + ------- + LArray + + Examples + -------- + >>> # to create a 1D array using the default separator ' ', a tabulation character \t must be added in front + >>> # of the data line + >>> from_string("sex M F\\n\\t 0 1") + sex M F + 0 1 + >>> from_string("nat\\\\sex M F\\nBE 0 1\\nFO 2 3") + nat\sex M F + BE 0 1 + FO 2 3 + >>> from_string("period a b\\n2010 0 1\\n2011 2 3") + period\{1} a b + 2010 0 1 + 2011 2 3 + + Each label is stripped of leading and trailing whitespace, so this is valid too: + + >>> from_string('''nat\\\\sex M F + ... BE 0 1 + ... FO 2 3''') + nat\sex M F + BE 0 1 + FO 2 3 + >>> from_string('''age nat\\\\sex M F + ... 0 BE 0 1 + ... 0 FO 2 3 + ... 1 BE 4 5 + ... 1 FO 6 7''') + age nat\sex M F + 0 BE 0 1 + 0 FO 2 3 + 1 BE 4 5 + 1 FO 6 7 + + Empty lines at the beginning or end are ignored, so one can also format the string like this: + + >>> from_string(''' + ... nat\\\\sex M F + ... BE 0 1 + ... FO 2 3 + ... 
''') + nat\sex M F + BE 0 1 + FO 2 3 + """ + return read_csv(StringIO(s), nb_axes=nb_axes, index_col=index_col, sep=sep, skipinitialspace=True, + wide=wide, **kwargs) \ No newline at end of file diff --git a/larray/inout/pandas.py b/larray/inout/pandas.py new file mode 100644 index 000000000..c9a8aa434 --- /dev/null +++ b/larray/inout/pandas.py @@ -0,0 +1,263 @@ +from __future__ import absolute_import, print_function + +from itertools import product + +import numpy as np +import pandas as pd + +from larray.core.axis import Axis +from larray.core.array import LArray +from larray.util.misc import basestring, decode, unique + + +__all__ = ['from_frame', 'from_series'] + + +def parse(s): + """ + Used to parse the "folded" axis ticks (usually periods). + """ + # parameters can be strings or numbers + if isinstance(s, basestring): + s = s.strip() + low = s.lower() + if low == 'true': + return True + elif low == 'false': + return False + elif s.isdigit(): + return int(s) + else: + try: + return float(s) + except ValueError: + return s + else: + return s + + +def df_labels(df, sort=True): + """ + Returns unique labels for each dimension. 
+ """ + idx = df.index + if isinstance(idx, pd.core.index.MultiIndex): + if sort: + return list(idx.levels) + else: + return [list(unique(idx.get_level_values(l))) for l in idx.names] + else: + assert isinstance(idx, pd.core.index.Index) + # use .values if needed + return [idx] + + +def cartesian_product_df(df, sort_rows=False, sort_columns=False, **kwargs): + labels = df_labels(df, sort=sort_rows) + if sort_rows: + new_index = pd.MultiIndex.from_product(labels) + else: + new_index = pd.MultiIndex.from_tuples(list(product(*labels))) + columns = sorted(df.columns) if sort_columns else list(df.columns) + # the prodlen test is meant to avoid the more expensive array_equal test + prodlen = np.prod([len(axis_labels) for axis_labels in labels]) + if prodlen == len(df) and columns == list(df.columns) and np.array_equal(df.index.values, new_index.values): + return df, labels + return df.reindex(new_index, columns, **kwargs), labels + + +def from_series(s, sort_rows=False): + """ + Converts Pandas Series into 1D LArray. + + Parameters + ---------- + s : Pandas Series + Input Pandas Series. + sort_rows : bool, optional + Whether or not to sort the rows alphabetically. Defaults to False. + + Returns + ------- + LArray + """ + name = s.name if s.name is not None else s.index.name + if name is not None: + name = str(name) + if sort_rows: + s = s.sort_index() + return LArray(s.values, Axis(s.index.values, name)) + + +def from_frame(df, sort_rows=False, sort_columns=False, parse_header=False, unfold_last_axis_name=False, **kwargs): + """ + Converts Pandas DataFrame into LArray. + + Parameters + ---------- + df : pandas.DataFrame + Input dataframe. By default, name and labels of the last axis are defined by the name and labels of the + columns Index of the dataframe unless argument unfold_last_axis_name is set to True. + sort_rows : bool, optional + Whether or not to sort the rows alphabetically (sorting is more efficient than not sorting). Defaults to False. 
+ sort_columns : bool, optional + Whether or not to sort the columns alphabetically (sorting is more efficient than not sorting). + Defaults to False. + parse_header : bool, optional + Whether or not to parse columns labels. Pandas treats column labels as strings. + If True, column labels are converted into int, float or boolean when possible. Defaults to False. + unfold_last_axis_name : bool, optional + Whether or not to extract the names of the last two axes by splitting the name of the last index column of the + dataframe using ``\\``. Defaults to False. + + Returns + ------- + LArray + + See Also + -------- + LArray.to_frame + + Examples + -------- + >>> from larray import ndtest + >>> df = ndtest((2, 2, 2)).to_frame() + >>> df # doctest: +NORMALIZE_WHITESPACE + c c0 c1 + a b + a0 b0 0 1 + b1 2 3 + a1 b0 4 5 + b1 6 7 + >>> from_frame(df) + a b\\c c0 c1 + a0 b0 0 1 + a0 b1 2 3 + a1 b0 4 5 + a1 b1 6 7 + + Names of the last two axes written as ``before_last_axis_name\\last_axis_name`` + + >>> df = ndtest((2, 2, 2)).to_frame(fold_last_axis_name=True) + >>> df # doctest: +NORMALIZE_WHITESPACE + c0 c1 + a b\\c + a0 b0 0 1 + b1 2 3 + a1 b0 4 5 + b1 6 7 + >>> from_frame(df, unfold_last_axis_name=True) + a b\\c c0 c1 + a0 b0 0 1 + a0 b1 2 3 + a1 b0 4 5 + a1 b1 6 7 + """ + axes_names = [decode(name, 'utf8') for name in df.index.names] + + # handle 2 or more dimensions with the last axis name given using \ + if unfold_last_axis_name: + if isinstance(axes_names[-1], basestring) and '\\' in axes_names[-1]: + last_axes = [name.strip() for name in axes_names[-1].split('\\')] + axes_names = axes_names[:-1] + last_axes + else: + axes_names += [None] + else: + axes_names += [df.columns.name] + + df, axes_labels = cartesian_product_df(df, sort_rows=sort_rows, sort_columns=sort_columns, **kwargs) + + # Pandas treats column labels as column names (strings) so we need to convert them to values + last_axis_labels = [parse(cell) for cell in df.columns.values] if parse_header else 
list(df.columns.values) + axes_labels.append(last_axis_labels) + axes_names = [str(name) if name is not None else name + for name in axes_names] + + axes = [Axis(labels, name) for labels, name in zip(axes_labels, axes_names)] + data = df.values.reshape([len(axis) for axis in axes]) + return LArray(data, axes) + + +def df_aslarray(df, sort_rows=False, sort_columns=False, raw=False, parse_header=True, wide=True, **kwargs): + """ + Prepare Pandas DataFrame and then convert it into LArray. + + Parameters + ---------- + df : Pandas DataFrame + Input dataframe. + sort_rows : bool, optional + Whether or not to sort the rows alphabetically (sorting is more efficient than not sorting). Defaults to False. + sort_columns : bool, optional + Whether or not to sort the columns alphabetically (sorting is more efficient than not sorting). + Defaults to False. + raw : bool, optional + Whether or not to consider the input dataframe as a raw dataframe, i.e. read without index at all. + If True, build the first N-1 axes of the output array from the first N-1 dataframe columns. Defaults to False. + parse_header : bool, optional + Whether or not to parse columns labels. Pandas treats column labels as strings. + If True, column labels are converted into int, float or boolean when possible. Defaults to True. + wide : bool, optional + Whether or not to assume the array is stored in "wide" format. + If False, the array is assumed to be stored in "narrow" format: one column per axis plus one value column. + Defaults to True. 
+ + Returns + ------- + LArray + """ + # we could inline df_aslarray into the functions that use it, so that the original (non-cartesian) df is freed from + # memory at this point, but it would be much uglier and would not lower the peak memory usage which happens during + # cartesian_product_df.reindex + + # raw = True: the dataframe was read without index at all (ie 2D dataframe), + # irrespective of the actual data dimensionality + if raw: + columns = df.columns.values.tolist() + if wide: + try: + # take the first column which contains '\' + pos_last = next(i for i, v in enumerate(columns) if isinstance(v, basestring) and '\\' in v) + except StopIteration: + # we assume first column will not contain data + pos_last = 0 + + # This is required to handle int column names (otherwise we can simply use column positions in set_index). + # This is NOT the same as df.columns[list(range(...))] ! + index_columns = [df.columns[i] for i in range(pos_last + 1)] + df.set_index(index_columns, inplace=True) + else: + index_columns = [df.columns[i] for i in range(len(df.columns) - 1)] + df.set_index(index_columns, inplace=True) + series = df[df.columns[-1]] + if isinstance(series.index, pd.core.index.MultiIndex): + fill_value = kwargs.get('fill_value', np.nan) + # TODO: use argument sort=False when it will be available + # (see https://github.com/pandas-dev/pandas/issues/15105) + df = series.unstack(level=-1, fill_value=fill_value) + # pandas (un)stack and pivot(_table) methods return a Dataframe/Series with sorted index and columns + labels = df_labels(series, sort=False) + index = pd.MultiIndex.from_tuples(list(product(*labels[:-1])), names=series.index.names[:-1]) + columns = labels[-1] + df = df.reindex(index=index, columns=columns, fill_value=fill_value) + else: + series.name = series.index.name + if sort_rows: + raise ValueError('sort_rows=True is not valid for 1D arrays. 
Please use sort_columns instead.') + return from_series(series, sort_rows=sort_columns) + + # handle 1D + if len(df) == 1 and (pd.isnull(df.index.values[0]) or + (isinstance(df.index.values[0], basestring) and df.index.values[0].strip() == '')): + if parse_header: + df.columns = pd.Index([parse(cell) for cell in df.columns.values], name=df.columns.name) + series = df.iloc[0] + series.name = df.index.name + if sort_rows: + raise ValueError('sort_rows=True is not valid for 1D arrays. Please use sort_columns instead.') + return from_series(series, sort_rows=sort_columns) + else: + axes_names = [decode(name, 'utf8') for name in df.index.names] + unfold_last_axis_name = isinstance(axes_names[-1], basestring) and '\\' in axes_names[-1] + return from_frame(df, sort_rows=sort_rows, sort_columns=sort_columns, parse_header=parse_header, + unfold_last_axis_name=unfold_last_axis_name, **kwargs) diff --git a/larray/inout/pickle.py b/larray/inout/pickle.py new file mode 100644 index 000000000..200e3fa26 --- /dev/null +++ b/larray/inout/pickle.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import, division, print_function + +from collections import OrderedDict + +from larray.util.misc import pickle +from larray.inout.common import FileHandler + + +class PickleHandler(FileHandler): + def _open_for_read(self): + with open(self.fname, 'rb') as f: + self.data = pickle.load(f) + + def _open_for_write(self): + self.data = OrderedDict() + + def list(self): + return self.data.keys() + + def _read_item(self, key): + return key, self.data[key] + + def _dump(self, key, value): + self.data[key] = value + + def close(self): + with open(self.fname, 'wb') as f: + pickle.dump(self.data, f) \ No newline at end of file diff --git a/larray/inout/sas.py b/larray/inout/sas.py new file mode 100644 index 000000000..308e7a5db --- /dev/null +++ b/larray/inout/sas.py @@ -0,0 +1,37 @@ +from __future__ import absolute_import, print_function + +import warnings + +import numpy as np +import pandas as pd 
+ +from larray.util.misc import deprecate_kwarg +from larray.inout.pandas import df_aslarray + + +__all__ = ['read_sas'] + + +@deprecate_kwarg('nb_index', 'nb_axes', arg_converter=lambda x: x + 1) +def read_sas(filepath, nb_axes=None, index_col=None, fill_value=np.nan, na=np.nan, sort_rows=False, sort_columns=False, + **kwargs): + """ + Reads sas file and returns an LArray with the contents + nb_axes: number of axes of the output array + or + index_col: Positions of columns for the n-1 first axes (ex. [0, 1, 2, 3]) + """ + if not np.isnan(na): + fill_value = na + warnings.warn("read_sas `na` argument has been renamed to `fill_value`. Please use that instead.", + FutureWarning, stacklevel=2) + + if nb_axes is not None and index_col is not None: + raise ValueError("cannot specify both nb_axes and index_col") + elif nb_axes is not None: + index_col = list(range(nb_axes - 1)) + elif isinstance(index_col, int): + index_col = [index_col] + + df = pd.read_sas(filepath, index=index_col, **kwargs) + return df_aslarray(df, sort_rows=sort_rows, sort_columns=sort_columns, fill_value=fill_value) diff --git a/larray/inout/session.py b/larray/inout/session.py index d799af343..f3fee91ee 100644 --- a/larray/inout/session.py +++ b/larray/inout/session.py @@ -1,312 +1,15 @@ from __future__ import absolute_import, division, print_function -import os -from glob import glob -from collections import OrderedDict -from pandas import ExcelWriter, ExcelFile, HDFStore - -from larray.core.abstractbases import ABCLArray -from larray.util.misc import pickle -from larray.inout.excel import open_excel -from larray.inout.array import df_aslarray, read_csv, read_hdf - -try: - import xlwings as xw -except ImportError: - xw = None +from larray.inout.csv import PandasCSVHandler +from larray.inout.excel import PandasExcelHandler, XLWingsHandler +from larray.inout.hdf import PandasHDFHandler +from larray.inout.pickle import PickleHandler def check_pattern(k, pattern): return k.startswith(pattern) -class 
FileHandler(object): - """ - Abstract class defining the methods for "file handler" subclasses. - - Parameters - ---------- - fname : str - Filename. - - Attributes - ---------- - fname : str - Filename. - """ - def __init__(self, fname, overwrite_file=False): - self.fname = fname - self.original_file_name = None - self.overwrite_file = overwrite_file - - def _open_for_read(self): - raise NotImplementedError() - - def _open_for_write(self): - raise NotImplementedError() - - def list(self): - """ - Returns the list of arrays' names. - """ - raise NotImplementedError() - - def _read_array(self, key, *args, **kwargs): - raise NotImplementedError() - - def _dump(self, key, value, *args, **kwargs): - raise NotImplementedError() - - def save(self): - """ - Saves arrays in file. - """ - pass - - def close(self): - """ - Closes file. - """ - raise NotImplementedError() - - def _get_original_file_name(self): - if self.overwrite_file and os.path.isfile(self.fname): - self.original_file_name = self.fname - self.fname = '{}~{}'.format(*os.path.splitext(self.fname)) - - def _update_original_file(self): - if self.original_file_name is not None and os.path.isfile(self.fname): - os.remove(self.original_file_name) - os.rename(self.fname, self.original_file_name) - - def read_arrays(self, keys, *args, **kwargs): - """ - Reads file content (HDF, Excel, CSV, ...) and returns a dictionary containing loaded arrays. - - Parameters - ---------- - keys : list of str - List of arrays' names. - *args : any - Any other argument is passed through to the underlying read function. - display : bool, optional - Whether or not the function should display a message when starting and ending to load each array. - Defaults to False. - ignore_exceptions : bool, optional - Whether or not an exception should stop the function or be ignored. Defaults to False. - **kwargs : any - Any other keyword argument is passed through to the underlying read function. 
- - Returns - ------- - OrderedDict(str, LArray) - Dictionary containing the loaded arrays. - """ - display = kwargs.pop('display', False) - ignore_exceptions = kwargs.pop('ignore_exceptions', False) - self._open_for_read() - res = OrderedDict() - if keys is None: - keys = self.list() - for key in keys: - if display: - print("loading", key, "...", end=' ') - try: - res[key] = self._read_array(key, *args, **kwargs) - except Exception: - if not ignore_exceptions: - raise - if display: - print("done") - self.close() - return res - - def dump_arrays(self, key_values, *args, **kwargs): - """ - Dumps arrays corresponds to keys in file in HDF, Excel, CSV, ... format - - Parameters - ---------- - key_values : list of (str, LArray) pairs - Name and data of arrays to dump. - kwargs : - * display: whether or not to display when the dump of each array is started/done. - """ - display = kwargs.pop('display', False) - self._get_original_file_name() - self._open_for_write() - for key, value in key_values: - if isinstance(value, ABCLArray) and value.ndim == 0: - if display: - print('Cannot dump {}. Dumping 0D arrays is currently not supported.'.format(key)) - continue - if display: - print("dumping", key, "...", end=' ') - self._dump(key, value, *args, **kwargs) - if display: - print("done") - self.save() - self.close() - self._update_original_file() - - -class PandasHDFHandler(FileHandler): - """ - Handler for HDF5 files using Pandas. 
- """ - def _open_for_read(self): - self.handle = HDFStore(self.fname, mode='r') - - def _open_for_write(self): - self.handle = HDFStore(self.fname) - - def list(self): - return [key.strip('/') for key in self.handle.keys()] - - def _to_hdf_key(self, key): - return '/' + key - - def _read_array(self, key, *args, **kwargs): - return read_hdf(self.handle, self._to_hdf_key(key), *args, **kwargs) - - def _dump(self, key, value, *args, **kwargs): - value.to_hdf(self.handle, self._to_hdf_key(key), *args, **kwargs) - - def close(self): - self.handle.close() - - -class PandasExcelHandler(FileHandler): - """ - Handler for Excel files using Pandas. - """ - def _open_for_read(self): - self.handle = ExcelFile(self.fname) - - def _open_for_write(self): - self.handle = ExcelWriter(self.fname) - - def list(self): - return self.handle.sheet_names - - def _read_array(self, key, *args, **kwargs): - df = self.handle.parse(key, *args, **kwargs) - return df_aslarray(df, raw=True) - - def _dump(self, key, value, *args, **kwargs): - kwargs['engine'] = 'xlsxwriter' - value.to_excel(self.handle, key, *args, **kwargs) - - def close(self): - self.handle.close() - - -class XLWingsHandler(FileHandler): - """ - Handler for Excel files using XLWings. 
- """ - def _get_original_file_name(self): - # for XLWingsHandler, no need to create a temporary file, the job is already done in the Workbook class - pass - - def _open_for_read(self): - self.handle = open_excel(self.fname) - - def _open_for_write(self): - self.handle = open_excel(self.fname, overwrite_file=self.overwrite_file) - - def list(self): - return self.handle.sheet_names() - - def _read_array(self, key, *args, **kwargs): - return self.handle[key].load(*args, **kwargs) - - def _dump(self, key, value, *args, **kwargs): - self.handle[key] = value.dump(*args, **kwargs) - - def save(self): - self.handle.save() - - def close(self): - self.handle.close() - - -class PandasCSVHandler(FileHandler): - def __init__(self, fname, overwrite_file=False): - super(PandasCSVHandler, self).__init__(fname, overwrite_file) - if fname is None: - self.pattern = None - self.directory = None - elif '.csv' in fname or '*' in fname or '?' in fname: - self.pattern = fname - self.directory = os.path.dirname(fname) - else: - # assume fname is a directory. - # Not testing for os.path.isdir(fname) here because when writing, the directory might not exist. 
- self.pattern = os.path.join(fname, '*.csv') - self.directory = fname - - def _get_original_file_name(self): - pass - - def _open_for_read(self): - if self.directory and not os.path.isdir(self.directory): - raise ValueError("Directory '{}' does not exist".format(self.directory)) - - def _open_for_write(self): - if self.directory is not None: - try: - os.makedirs(self.directory) - except OSError: - if not os.path.isdir(self.directory): - raise ValueError("Path {} must represent a directory".format(self.directory)) - - def list(self): - fnames = glob(self.pattern) if self.pattern is not None else [] - # drop directory - fnames = [os.path.basename(fname) for fname in fnames] - # strip extension from files - # XXX: unsure we should use sorted here - return sorted([os.path.splitext(fname)[0] for fname in fnames]) - - def _to_filepath(self, key): - if self.directory is not None: - return os.path.join(self.directory, '{}.csv'.format(key)) - else: - return key - - def _read_array(self, key, *args, **kwargs): - return read_csv(self._to_filepath(key), *args, **kwargs) - - def _dump(self, key, value, *args, **kwargs): - value.to_csv(self._to_filepath(key), *args, **kwargs) - - def close(self): - pass - - -class PickleHandler(FileHandler): - def _open_for_read(self): - with open(self.fname, 'rb') as f: - self.data = pickle.load(f) - - def _open_for_write(self): - self.data = OrderedDict() - - def list(self): - return self.data.keys() - - def _read_array(self, key): - return self.data[key] - - def _dump(self, key, value): - self.data[key] = value - - def close(self): - with open(self.fname, 'wb') as f: - pickle.dump(self.data, f) - - handler_classes = { 'pickle': PickleHandler, 'pandas_csv': PandasCSVHandler, diff --git a/larray/inout/xw_excel.py b/larray/inout/xw_excel.py new file mode 100644 index 000000000..9214712db --- /dev/null +++ b/larray/inout/xw_excel.py @@ -0,0 +1,715 @@ +# -*- coding: utf8 -*- +from __future__ import absolute_import, print_function + +__all__ = 
['open_excel', 'Workbook'] + + +import os +import atexit + +import numpy as np +try: + import xlwings as xw +except ImportError: + xw = None + +from larray.core.group import _translate_sheet_name +from larray.core.axis import Axis +from larray.core.array import LArray, ndtest +from larray.inout.pandas import df_aslarray +from larray.inout.misc import from_lists +from larray.util.misc import PY2 + +string_types = (str,) + + +if xw is not None: + from xlwings.conversion.pandas_conv import PandasDataFrameConverter + + global_app = None + + def is_app_alive(app): + try: + app.books + return True + except Exception: + return False + + + def kill_global_app(): + global global_app + + if global_app is not None: + if is_app_alive(global_app): + try: + global_app.kill() + except Exception: + pass + del global_app + global_app = None + + + class LArrayConverter(PandasDataFrameConverter): + writes_types = LArray + + @classmethod + def read_value(cls, value, options): + df = PandasDataFrameConverter.read_value(value, options) + return df_aslarray(df) + + @classmethod + def write_value(cls, value, options): + df = value.to_frame(fold_last_axis_name=True) + return PandasDataFrameConverter.write_value(df, options) + + LArrayConverter.register(LArray) + + # TODO: replace overwrite_file by mode='r'|'w'|'a' the day xlwings will support a read-only mode + class Workbook(object): + def __init__(self, filepath=None, overwrite_file=False, visible=None, silent=None, app=None): + global global_app + + xw_wkb = None + self.delayed_filepath = None + self.filepath = None + self.new_workbook = False + self.active_workbook = filepath == -1 + + if filepath is None: + self.new_workbook = True + + if isinstance(filepath, str): + basename, ext = os.path.splitext(filepath) + if ext: + # XXX: we might want to be more precise than .xl* because I am unsure writing .xls + # (or anything other than .xlsx and .xlsm) would work + if not ext.startswith('.xl'): + raise ValueError("'%s' is not a supported 
file extension" % ext) + if not os.path.isfile(filepath) and not overwrite_file: + raise ValueError("File {} does not exist. Please give the path to an existing file or set " + "overwrite_file argument to True".format(filepath)) + if os.path.isfile(filepath) and overwrite_file: + self.filepath = filepath + # we create a temporary file to work on. In case of crash, the original is not destroyed. + # the temporary file is renamed as the original file at close. + filepath = basename + '~' + ext + if not os.path.isfile(filepath): + self.new_workbook = True + else: + # try to target an open but unsaved workbook. We cannot use the same code path as for other options + # because we do not know which Excel instance has that book + xw_wkb = xw.Book(filepath) + app = xw_wkb.app + + # active workbook use active app by default + if self.active_workbook and app not in {None, "active"}: + raise ValueError("to connect to the active workbook, one must use the 'active' Excel instance " + "(app='active' or app=None)") + + # unless explicitly set, app is set to visible for brand new or active book. + # For unsaved_book it is left intact. 
+ if visible is None: + if filepath is None or self.active_workbook: + visible = True + elif xw_wkb is None: + # filepath is not None but we don't target an unsaved book + visible = False + + if app is None: + if self.active_workbook: + app = "active" + elif visible: + app = "new" + else: + app = "global" + + load_addins = False + if app == "new": + app = xw.App(visible=visible, add_book=False) + load_addins = True + elif app == "active": + app = xw.apps.active + elif app == "global": + if global_app is None: + atexit.register(kill_global_app) + if global_app is None or not is_app_alive(global_app): + global_app = xw.App(visible=visible, add_book=False) + load_addins = True + app = global_app + assert isinstance(app, xw.App) + + if visible: + app.visible = visible + + if silent is None: + silent = not visible + + # activate XLA(M) addins + # (for some reasons, add-ins are not activated when an Excel Workbook is opened from Python) + if load_addins: + for ia in range(1, app.api.Addins.Count + 1): + addin_path = app.api.Addins(ia).FullName + if not '.xll' in addin_path.lower(): + app.api.Workbooks.Open(addin_path) + + update_links_backup = app.api.AskToUpdateLinks + display_alerts_backup = app.display_alerts + if silent: + # try to update links silently instead of asking: "Update", "Don't Update", "Help" + app.api.AskToUpdateLinks = False + + # in case some links cannot be updated, continue instead of asking: "Continue" or "Edit Links..." 
+ app.display_alerts = False + + if filepath is None: + # creates a new/blank Book + xw_wkb = app.books.add() + elif self.active_workbook: + xw_wkb = app.books.active + elif xw_wkb is None: + # file already exists (and is a file) + if os.path.isfile(filepath): + xw_wkb = app.books.open(filepath) + else: + # let us remember the path + self.delayed_filepath = filepath + xw_wkb = app.books.add() + + if silent: + app.api.AskToUpdateLinks = update_links_backup + app.display_alerts = display_alerts_backup + + self.xw_wkb = xw_wkb + + def __contains__(self, key): + if isinstance(key, int): + length = len(self) + return -length <= key < length + else: + # I would like to use: "return key in wb.sheets" but as of xlwings 0.10 wb.sheets.__contains__ does not + # work for sheet names (it works with Sheet objects I think) + return key in self.sheet_names() + + def _ipython_key_completions_(self): + return list(self.sheet_names()) + + def __getitem__(self, key): + key = _translate_sheet_name(key) + if key in self: + return Sheet(self, key) + else: + raise KeyError('Workbook has no sheet named {}'.format(key)) + + def __setitem__(self, key, value): + key = _translate_sheet_name(key) + if self.new_workbook: + self.xw_wkb.sheets[0].name = key + self.new_workbook = False + key_in_self = key in self + if isinstance(value, Sheet): + if value.xw_sheet.book.app != self.xw_wkb.app: + raise ValueError("cannot copy a sheet from one instance of Excel to another") + + # xlwings index is 1-based + # TODO: implement Workbook.index(key) + target_idx = self[key].xw_sheet.index - 1 if key_in_self else -1 + target_sheet = self[target_idx].xw_sheet + # add new sheet after target sheet. The new sheet will be named something like "value.name (1)" but I + # do not think there is anything we can do about this, except rename it afterwards because Copy has no + # name argument. 
See https://msdn.microsoft.com/en-us/library/office/ff837784.aspx + value.xw_sheet.api.Copy(None, target_sheet.api) + if key_in_self: + target_sheet.delete() + # rename the new sheet + self[target_idx].name = key + return + if key_in_self: + sheet = self[key] + sheet.clear() + else: + xw_sheet = self.xw_wkb.sheets.add(key, after=self[-1].xw_sheet) + sheet = Sheet(None, None, xw_sheet=xw_sheet) + sheet["A1"] = value + + def __delitem__(self, key): + self[key].delete() + + def sheet_names(self): + return [s.name for s in self] + + def save(self, path=None): + # saved_path = self.xw_wkb.api.Path + # was_saved = saved_path != '' + if path is None and self.delayed_filepath is not None: + path = self.delayed_filepath + self.xw_wkb.save(path=path) + + def close(self): + # Close the workbook in Excel. + # This will not quit the Excel instance, even if this was the last workbook of that Excel instance. + if self.filepath is not None and os.path.isfile(self.xw_wkb.fullname): + tmp_file = self.xw_wkb.fullname + self.xw_wkb.close() + os.remove(self.filepath) + os.rename(tmp_file, self.filepath) + else: + self.xw_wkb.close() + + def __iter__(self): + return iter([Sheet(None, None, xw_sheet) + for xw_sheet in self.xw_wkb.sheets]) + + def __len__(self): + return len(self.xw_wkb.sheets) + + def __dir__(self): + return list(set(dir(self.__class__)) | set(dir(self.xw_wkb))) + + def __getattr__(self, key): + return getattr(self.xw_wkb, key) + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + if not self.active_workbook: + self.close() + + def __repr__(self): + cls = self.__class__ + return '<{}.{} [{}]>'.format(cls.__module__, cls.__name__, self.name) + + + def _fill_slice(s, length): + """ + replaces a slice None bounds by actual bounds. 
+ + Parameters + ---------- + s : slice + slice to replace + length : int + length of sequence + + Returns + ------- + slice + """ + return slice(s.start if s.start is not None else 0, s.stop if s.stop is not None else length, s.step) + + + def _concrete_key(key, obj, ndim=2): + """Expand key to ndim and replace None in slices start/stop bounds by 0 or obj.shape[corresponding_dim] + respectively. + + Parameters + ---------- + key : scalar, slice or tuple + input key + obj : object + any object with a 'shape' attribute. + ndim : int + number of dimensions to expand to. We could use len(obj.shape) instead but we avoid it to not trigger + obj.shape, which can be expensive in the case of a sheet with blank cells after the data. + """ + if not isinstance(key, tuple): + key = (key,) + + if len(key) < ndim: + key = key + (slice(None),) * (ndim - len(key)) + + # only compute shape if necessary because it can be expensive in some cases + if any(isinstance(k, slice) and k.stop is None for k in key): + shape = obj.shape + else: + shape = (None, None) + + # We use _fill_slice instead of slice(*k.indices(length)) because the later also clips bounds which exceed + # the length and we do NOT want to do that in this case (see issue #273). + return [_fill_slice(k, length) if isinstance(k, slice) else k + for k, length in zip(key, shape)] + + + class Sheet(object): + def __init__(self, workbook, key, xw_sheet=None): + if xw_sheet is None: + xw_sheet = workbook.xw_wkb.sheets[key] + object.__setattr__(self, 'xw_sheet', xw_sheet) + + # TODO: we can probably scrap this for xlwings 0.9+. We need to have + # a unit test for this though. 
+ def __getitem__(self, key): + if isinstance(key, string_types): + return Range(self, key) + + row, col = _concrete_key(key, self) + if isinstance(row, slice) or isinstance(col, slice): + row1, row2 = (row.start, row.stop) if isinstance(row, slice) else (row, row + 1) + col1, col2 = (col.start, col.stop) if isinstance(col, slice) else (col, col + 1) + return Range(self, (row1 + 1, col1 + 1), (row2, col2)) + else: + return Range(self, (row + 1, col + 1)) + + def __setitem__(self, key, value): + if isinstance(value, LArray): + value = value.dump(header=False) + self[key].xw_range.value = value + + @property + def shape(self): + """ + shape of sheet including top-left empty rows/columns but excluding bottom-right ones. + """ + from xlwings.constants import Direction as xldir + + sheet = self.xw_sheet.api + used = sheet.UsedRange + first_row = used.Row + first_col = used.Column + last_row = first_row + used.Rows.Count - 1 + last_col = first_col + used.Columns.Count - 1 + last_cell = sheet.Cells(last_row, last_col) + + # fast path for sheets with a non blank bottom-right value + if last_cell.Value is not None: + return last_row, last_col + + last_row_used = last_cell.End(xldir.xlToLeft).Value is not None + last_col_used = last_cell.End(xldir.xlUp).Value is not None + + # fast path for sheets where last row and last col are not entirely blank + if last_row_used and last_col_used: + return last_row, last_col + else: + LEFT, UP = xldir.xlToLeft, xldir.xlUp + + def line_length(row, col, direction): + last_cell = sheet.Cells(row, col) + if last_cell.Value is not None: + return col if direction is LEFT else row + first_cell = last_cell.End(direction) + pos = first_cell.Column if direction is LEFT else first_cell.Row + return pos - 1 if first_cell.Value is None else pos + + if last_row < last_col: + if last_row_used or last_row == 1: + max_row = last_row + else: + for max_row in range(last_row - 1, first_row - 1, -1): + if line_length(max_row, last_col, LEFT) > 0: + break + 
if last_col_used or last_col == 1: + max_col = last_col + else: + max_col = max(line_length(row, last_col, LEFT) for row in range(first_row, max_row + 1)) + else: + if last_col_used or last_col == 1: + max_col = last_col + else: + for max_col in range(last_col - 1, first_col - 1, -1): + if line_length(last_row, max_col, UP) > 0: + break + if last_row_used or last_row == 1: + max_row = last_row + else: + max_row = max(line_length(last_row, col, UP) for col in range(first_col, max_col + 1)) + return max_row, max_col + + @property + def ndim(self): + return 2 + + def __array__(self, dtype=None): + return np.asarray(self[:], dtype=dtype) + + def __dir__(self): + return list(set(dir(self.__class__)) | set(dir(self.xw_sheet))) + + def __getattr__(self, key): + return getattr(self.xw_sheet, key) + + def __setattr__(self, key, value): + setattr(self.xw_sheet, key, value) + + def load(self, header=True, convert_float=True, nb_index=None, index_col=None, fill_value=np.nan, + sort_rows=False, sort_columns=False, wide=True): + return self[:].load(header=header, convert_float=convert_float, nb_index=nb_index, index_col=index_col, + fill_value=fill_value, sort_rows=sort_rows, sort_columns=sort_columns, wide=wide) + + # TODO: generalize to more than 2 dimensions or scrap it + def array(self, data, row_labels=None, column_labels=None, names=None): + """ + + Parameters + ---------- + data : str + range for data + row_labels : str, optional + range for row labels + column_labels : str, optional + range for column labels + names : list of str, optional + + Returns + ------- + LArray + """ + if row_labels is not None: + row_labels = np.asarray(self[row_labels]) + if column_labels is not None: + column_labels = np.asarray(self[column_labels]) + if names is not None: + labels = (row_labels, column_labels) + axes = [Axis(axis_labels, name) for axis_labels, name in zip(labels, names)] + else: + axes = (row_labels, column_labels) + # _converted_value is used implicitly via Range.__array__ 
+            return LArray(np.asarray(self[data]), axes)
+
+        def __repr__(self):
+            cls = self.__class__
+            xw_sheet = self.xw_sheet
+            return '<{}.{} [{}]{}>'.format(cls.__module__, cls.__name__, xw_sheet.book.name, xw_sheet.name)
+
+
+    class Range(object):
+        def __init__(self, sheet, *args):
+            xw_range = sheet.xw_sheet.range(*args)
+
+            object.__setattr__(self, 'sheet', sheet)
+            object.__setattr__(self, 'xw_range', xw_range)
+
+        def _range_key_to_sheet_key(self, key):
+            # string keys do not make sense in this case
+            assert not isinstance(key, string_types)
+            row_offset = self.xw_range.row1 - 1
+            col_offset = self.xw_range.col1 - 1
+            row, col = _concrete_key(key, self.xw_range)
+            row = slice(row.start + row_offset, row.stop + row_offset) if isinstance(row, slice) else row + row_offset
+            col = slice(col.start + col_offset, col.stop + col_offset) if isinstance(col, slice) else col + col_offset
+            return row, col
+
+        # TODO: we can probably scrap this for xlwings 0.9+. We need to have
+        # a unit test for this though.
+        def __getitem__(self, key):
+            return self.sheet[self._range_key_to_sheet_key(key)]
+
+        def __setitem__(self, key, value):
+            self.sheet[self._range_key_to_sheet_key(key)] = value
+
+        def _converted_value(self, convert_float=True):
+            list_data = self.xw_range.value
+
+            # As of version 0.7.2 of xlwings, there is no built-in converter for
+            # this. The builtin .options(numbers=int) converter converts all
+            # values to int, whether that would lose information or not, but
+            # this is not what we want.
+ if convert_float: + # Excel 'numbers' are always floats + def convert(value): + if isinstance(value, float): + int_val = int(value) + if int_val == value: + return int_val + return value + elif isinstance(value, list): + return [convert(v) for v in value] + else: + return value + return convert(list_data) + return list_data + + def __float__(self): + # no need to use _converted_value because we will convert back to a float anyway + return float(self.xw_range.value) + + def __int__(self): + # no need to use _converted_value because we will convert to an int anyway + return int(self.xw_range.value) + + def __index__(self): + v = self._converted_value() + if hasattr(v, '__index__'): + return v.__index__() + else: + raise TypeError("only integer scalars can be converted to a scalar index") + + def __array__(self, dtype=None): + return np.array(self._converted_value(), dtype=dtype) + + def __larray__(self): + return LArray(self._converted_value()) + + def __dir__(self): + return list(set(dir(self.__class__)) | set(dir(self.xw_range))) + + def __getattr__(self, key): + if hasattr(LArray, key): + return getattr(self.__larray__(), key) + else: + return getattr(self.xw_range, key) + + def __setattr__(self, key, value): + setattr(self.xw_range, key, value) + + # TODO: implement all binops + # def __mul__(self, other): + # return self.__larray__() * other + + def __str__(self): + return str(self.__larray__()) + __repr__ = __str__ + + def load(self, header=True, convert_float=True, nb_index=None, index_col=None, fill_value=np.nan, + sort_rows=False, sort_columns=False, wide=True): + if not self.ndim: + return LArray([]) + + list_data = self._converted_value(convert_float=convert_float) + + if header: + return from_lists(list_data, nb_index=nb_index, index_col=index_col, fill_value=fill_value, + sort_rows=sort_rows, sort_columns=sort_columns, wide=wide) + else: + return LArray(list_data) + + # XXX: deprecate this function? 
+ def open_excel(filepath=None, overwrite_file=False, visible=None, silent=None, app=None): + return Workbook(filepath, overwrite_file, visible, silent, app) +else: + class Workbook(object): + def __init__(self, filepath=None, overwrite_file=False, visible=None, silent=None, app=None): + raise Exception("Workbook class cannot be instanciated because xlwings is not installed") + + def sheet_names(self): + raise Exception() + + def save(self, path=None): + raise Exception() + + def close(self): + raise Exception() + + def open_excel(filepath=None, overwrite_file=False, visible=None, silent=None, app=None): + raise Exception("open_excel() is not available because xlwings is not installed") + + +# We define Workbook and open_excel documentation here since Readthedocs runs on Linux +if not PY2: + Workbook.__doc__ = """ +Excel Workbook. + +See Also +-------- +open_excel +""" + + Workbook.sheet_names.__doc__ = """ +Returns the names of the Excel sheets. + +Examples +-------- +>>> arr, arr2, arr3 = ndtest((3, 3)), ndtest((2, 2)), ndtest(4) +>>> with open_excel('excel_file.xlsx', overwrite_file=True) as wb: # doctest: +SKIP +... wb['arr'] = arr.dump() +... wb['arr2'] = arr2.dump() +... wb['arr3'] = arr3.dump() +... wb.save() +... +... wb.sheet_names() +['arr', 'arr2', 'arr3'] +""" + + Workbook.save.__doc__ = """ +Saves the Workbook. + +If a path is being provided, this works like SaveAs() in Excel. +If no path is specified and if the file hasn’t been saved previously, +it’s being saved in the current working directory with the current filename. +Existing files are overwritten without prompting. + +Parameters +---------- +path : str, optional + Full path to the workbook. Defaults to None. + +Examples +-------- +>>> arr, arr2, arr3 = ndtest((3, 3)), ndtest((2, 2)), ndtest(4) +>>> with open_excel('excel_file.xlsx', overwrite_file=True) as wb: # doctest: +SKIP +... wb['arr'] = arr.dump() +... wb['arr2'] = arr2.dump() +... wb['arr3'] = arr3.dump() +... 
wb.save() +""" + + Workbook.close.__doc__ = """ +Close the workbook in Excel. + +Need to be called if the workbook has been opened without the `with` statement. + +Examples +-------- +>>> arr, arr2, arr3 = ndtest((3, 3)), ndtest((2, 2)), ndtest(4) # doctest: +SKIP +>>> wb = open_excel('excel_file.xlsx', overwrite_file=True) # doctest: +SKIP +>>> wb['arr'] = arr.dump() # doctest: +SKIP +>>> wb['arr2'] = arr2.dump() # doctest: +SKIP +>>> wb['arr3'] = arr3.dump() # doctest: +SKIP +>>> wb.save() # doctest: +SKIP +>>> wb.close() # doctest: +SKIP +""" + +open_excel.__doc__ = """ +Open an Excel workbook + +Parameters +---------- +filepath : None, int or str, optional + path to the Excel file. The file must exist if overwrite_file is False. Use None for a new blank workbook, + -1 for the currently active workbook. Defaults to None. +overwrite_file : bool, optional + whether or not to overwrite an existing file, if any. Defaults to False. +visible : None or bool, optional + whether or not Excel should be visible. Defaults to False for files, True for new/active workbooks and to None + ("unchanged") for existing unsaved workbooks. +silent : None or bool, optional + whether or not to show dialog boxes for updating links or when some links cannot be updated. + Defaults to False if visible, True otherwise. +app : None, "new", "active", "global" or xlwings.App, optional + use "new" for opening a new Excel instance, "active" for the last active instance (including ones opened by the + user) and "global" to (re)use the same instance for all workbooks of a program. None is equivalent to "active" if + filepath is -1, "new" if visible is True and "global" otherwise. Defaults to None. + + The "global" instance is a specific Excel instance for all input from/output to Excel from within a single Python + program (and should not interact with instances manually opened by the user or another program). + +Returns +------- +Excel workbook. 
+ +Examples +-------- +>>> arr = ndtest((3, 3)) +>>> arr +a\\b b0 b1 b2 + a0 0 1 2 + a1 3 4 5 + a2 6 7 8 + +create a new Excel file and save an array + +>>> # to create a new Excel file, argument overwrite_file must be set to True +>>> with open_excel('excel_file.xlsx', overwrite_file=True) as wb: # doctest: +SKIP +... wb['arr'] = arr.dump() +... wb.save() + +read array from an Excel file + +>>> with open_excel('excel_file.xlsx') as wb: # doctest: +SKIP +... arr2 = wb['arr'].load() +>>> arr2 # doctest: +SKIP +a\\b b0 b1 b2 + a0 0 1 2 + a1 3 4 5 + a2 6 7 8 +""" diff --git a/larray/tests/common.py b/larray/tests/common.py index bb55900d4..01998a0c7 100644 --- a/larray/tests/common.py +++ b/larray/tests/common.py @@ -98,3 +98,7 @@ def nan_equal(a, b): assert_nparray_equiv = assert_nparray_equal_factory(equal) assert_nparray_nan_equiv = assert_nparray_equal_factory(nan_equal) + + +def assert_axis_eq(axis1, axis2): + assert axis1.equals(axis2) diff --git a/larray/tests/test_array.py b/larray/tests/test_array.py index 8f1e0f2b1..c685b8a71 100644 --- a/larray/tests/test_array.py +++ b/larray/tests/test_array.py @@ -17,7 +17,7 @@ from larray import (LArray, Axis, LGroup, union, zeros, zeros_like, ndtest, ones, eye, diag, stack, clip, exp, where, X, mean, isnan, round, read_hdf, read_csv, read_eurostat, read_excel, from_lists, from_string, open_excel, from_frame, sequence, nan_equal) -from larray.inout.array import from_series +from larray.inout.pandas import from_series from larray.core.axis import _to_ticks, _to_key from larray.util.misc import StringIO diff --git a/larray/tests/test_axis.py b/larray/tests/test_axis.py index da7aefb6c..eb6a77aa9 100644 --- a/larray/tests/test_axis.py +++ b/larray/tests/test_axis.py @@ -1,880 +1,439 @@ from __future__ import absolute_import, division, print_function - -from unittest import TestCase - import pytest +import os.path import numpy as np from larray.tests.common import assert_array_equal -from larray import Axis, AxisCollection, 
LGroup, IGroup - - -class TestAxis(TestCase): - def setUp(self): - pass - - def tearDown(self): - pass - - def test_init(self): - sex_tuple = ('M', 'F') - sex_list = ['M', 'F'] - sex_array = np.array(sex_list) - - # wildcard axis - axis = Axis(10, 'axis') - assert len(axis) == 10 - assert list(axis.labels) == list(range(10)) - # tuple of strings - assert_array_equal(Axis(sex_tuple, 'sex').labels, sex_array) - # list of strings - assert_array_equal(Axis(sex_list, 'sex').labels, sex_array) - # array of strings - assert_array_equal(Axis(sex_array, 'sex').labels, sex_array) - # single string - assert_array_equal(Axis('sex=M,F').labels, sex_array) - # list of ints - assert_array_equal(Axis(range(116), 'age').labels, np.arange(116)) - # range-string - axis = Axis('0..115', 'age') - assert_array_equal(axis.labels, np.arange(116)) - # int-like labels with 0 padding - assert_array_equal(Axis('01..12', 'zero_padding').labels, [str(i).zfill(2) for i in range(1, 13)]) - assert_array_equal(Axis('01,02,03,10,11,12', 'zero_padding').labels, ['01', '02', '03', '10', '11', '12']) - - # another axis group - group = axis[:10] - group_axis = Axis(group) - assert_array_equal(group_axis.labels, np.arange(11)) - assert_array_equal(group_axis.name, 'age') - # another axis as labels argument - other = Axis('other=0..10') - axis = Axis(other, 'age') - assert_array_equal(axis.labels, other.labels) - assert_array_equal(axis.name, 'age') - - def test_equals(self): - self.assertTrue(Axis('sex=M,F').equals(Axis('sex=M,F'))) - self.assertTrue(Axis('sex=M,F').equals(Axis(['M', 'F'], 'sex'))) - self.assertFalse(Axis('sex=M,W').equals(Axis('sex=M,F'))) - self.assertFalse(Axis('sex1=M,F').equals(Axis('sex2=M,F'))) - self.assertFalse(Axis('sex1=M,W').equals(Axis('sex2=M,F'))) - - def test_getitem(self): - age = Axis('age=0..10') - # a tuple - a159 = age[1, 5, 9] - self.assertEqual(a159.key, [1, 5, 9]) - self.assertIs(a159.name, None) - self.assertIs(a159.axis, age) - - # a normal list - a159 = age[[1, 
5, 9]] - self.assertEqual(a159.key, [1, 5, 9]) - self.assertIs(a159.name, None) - self.assertIs(a159.axis, age) - - # a string list - a159 = age['1,5,9'] - self.assertEqual(a159.key, [1, 5, 9]) - self.assertIs(a159.name, None) - self.assertIs(a159.axis, age) - - # a normal slice - a10to20 = age[5:9] - self.assertEqual(a10to20.key, slice(5, 9)) - self.assertIs(a10to20.axis, age) - - # a string slice - a10to20 = age['5:9'] - self.assertEqual(a10to20.key, slice(5, 9)) - self.assertIs(a10to20.axis, age) - - # with name - group = age[[1, 5, 9]] >> 'test' - self.assertEqual(group.key, [1, 5, 9]) - self.assertEqual(group.name, 'test') - self.assertIs(group.axis, age) - - # all - group = age[:] >> 'all' - self.assertEqual(group.key, slice(None)) - self.assertIs(group.axis, age) - - # an axis - age2 = Axis('age=0..5') - group = age[age2] - assert list(group.key) == list(age2.labels) - - def test_translate(self): - # an axis with labels having the object dtype - a = Axis(np.array(["a0", "a1"], dtype=object), 'a') - - self.assertEqual(a.index('a1'), 1) - self.assertEqual(a.index('a1 >> A1'), 1) - - def test_getitem_lgroup_keys(self): - def group_equal(g1, g2): - return (g1.key == g2.key and g1.name == g2.name and - g1.axis is g2.axis) - - age = Axis(range(100), 'age') - ages = [1, 5, 9] - - val_only = LGroup(ages) - self.assertTrue(group_equal(age[val_only], LGroup(ages, axis=age))) - self.assertTrue(group_equal(age[val_only] >> 'a_name', LGroup(ages, 'a_name', axis=age))) - - val_name = LGroup(ages, 'val_name') - self.assertTrue(group_equal(age[val_name], LGroup(ages, 'val_name', age))) - self.assertTrue(group_equal(age[val_name] >> 'a_name', LGroup(ages, 'a_name', age))) - - val_axis = LGroup(ages, axis=age) - self.assertTrue(group_equal(age[val_axis], LGroup(ages, axis=age))) - self.assertTrue(group_equal(age[val_axis] >> 'a_name', LGroup(ages, 'a_name', axis=age))) - - val_axis_name = LGroup(ages, 'val_axis_name', age) - self.assertTrue(group_equal(age[val_axis_name], 
LGroup(ages, 'val_axis_name', age))) - self.assertTrue(group_equal(age[val_axis_name] >> 'a_name', LGroup(ages, 'a_name', age))) - - def test_getitem_group_keys(self): - a = Axis('a=a0..a2') - alt_a = Axis('a=a1..a3') - - # a) key is a single LGroup - # ------------------------- - - # a.1) containing a scalar - key = a['a1'] - # use it on the same axis - g = a[key] - self.assertEqual(g.key, 'a1') - self.assertIs(g.axis, a) - # use it on a different axis - g = alt_a[key] - self.assertEqual(g.key, 'a1') - self.assertIs(g.axis, alt_a) - - # a.2) containing a slice - key = a['a1':'a2'] - # use it on the same axis - g = a[key] - self.assertEqual(g.key, slice('a1', 'a2')) - self.assertIs(g.axis, a) - # use it on a different axis - g = alt_a[key] - self.assertEqual(g.key, slice('a1', 'a2')) - self.assertIs(g.axis, alt_a) - - # a.3) containing a list - key = a[['a1', 'a2']] - # use it on the same axis - g = a[key] - self.assertEqual(g.key, ['a1', 'a2']) - self.assertIs(g.axis, a) - # use it on a different axis - g = alt_a[key] - self.assertEqual(g.key, ['a1', 'a2']) - self.assertIs(g.axis, alt_a) - - # b) key is a single IGroup - # ------------------------- - - # b.1) containing a scalar - key = a.i[1] - # use it on the same axis - g = a[key] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, 'a1') - self.assertIs(g.axis, a) - # use it on a different axis - g = alt_a[key] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, 'a1') - self.assertIs(g.axis, alt_a) - - # b.2) containing a slice - key = a.i[1:3] - # use it on the same axis - g = a[key] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, slice('a1', 'a2')) - self.assertIs(g.axis, a) - # use it on a different axis - g = alt_a[key] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, slice('a1', 'a2')) - self.assertIs(g.axis, alt_a) - - # b.3) containing a list - key = a.i[[1, 2]] - # use it on the same axis - g = a[key] - self.assertIsInstance(g, LGroup) - 
self.assertEqual(list(g.key), ['a1', 'a2']) - self.assertIs(g.axis, a) - # use it on a different axis - g = alt_a[key] - self.assertIsInstance(g, LGroup) - self.assertEqual(list(g.key), ['a1', 'a2']) - self.assertIs(g.axis, alt_a) - - # c) key is a slice - # ----------------- - - # c.1) with LGroup bounds - lg_a1 = a['a1'] - lg_a2 = a['a2'] - # use it on the same axis - g = a[lg_a1:lg_a2] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, slice('a1', 'a2')) - self.assertIs(g.axis, a) - # use it on a different axis - g = alt_a[lg_a1:lg_a2] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, slice('a1', 'a2')) - self.assertIs(g.axis, alt_a) - - # c.2) with IGroup bounds - pg_a1 = a.i[1] - pg_a2 = a.i[2] - # use it on the same axis - g = a[pg_a1:pg_a2] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, slice('a1', 'a2')) - self.assertIs(g.axis, a) - # use it on a different axis - g = alt_a[pg_a1:pg_a2] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, slice('a1', 'a2')) - self.assertIs(g.axis, alt_a) - - # d) key is a list of scalar groups => create a single LGroup - # --------------------------------- - - # d.1) with LGroup - key = [a['a1'], a['a2']] - # use it on the same axis - g = a[key] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, ['a1', 'a2']) - self.assertIs(g.axis, a) - # use it on a different axis - g = alt_a[key] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, ['a1', 'a2']) - self.assertIs(g.axis, alt_a) - - # d.2) with IGroup - key = [a.i[1], a.i[2]] - # use it on the same axis - g = a[key] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, ['a1', 'a2']) - self.assertIs(g.axis, a) - # use it on a different axis - g = alt_a[key] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, ['a1', 'a2']) - self.assertIs(g.axis, alt_a) - - # e) key is a list of non-scalar groups => retarget multiple groups to axis - # ------------------------------------- - - # e.1) with LGroup - key 
= [a['a1', 'a2'], a['a2', 'a1']] - # use it on the same axis => nothing happens - g = a[key] - self.assertIsInstance(g, list) - self.assertIsInstance(g[0], LGroup) - self.assertIsInstance(g[1], LGroup) - self.assertEqual(g[0].key, ['a1', 'a2']) - self.assertEqual(g[1].key, ['a2', 'a1']) - self.assertIs(g[0].axis, a) - self.assertIs(g[1].axis, a) - # use it on a different axis => change axis - g = alt_a[key] - self.assertIsInstance(g, list) - self.assertIsInstance(g[0], LGroup) - self.assertIsInstance(g[1], LGroup) - self.assertEqual(g[0].key, ['a1', 'a2']) - self.assertEqual(g[1].key, ['a2', 'a1']) - self.assertIs(g[0].axis, alt_a) - self.assertIs(g[1].axis, alt_a) - - # e.2) with IGroup - key = (a.i[1, 2], a.i[2, 1]) - # use it on the same axis => change to LGroup - g = a[key] - self.assertIsInstance(g, tuple) - self.assertIsInstance(g[0], LGroup) - self.assertIsInstance(g[1], LGroup) - self.assertEqual(list(g[0].key), ['a1', 'a2']) - self.assertEqual(list(g[1].key), ['a2', 'a1']) - self.assertIs(g[0].axis, a) - self.assertIs(g[1].axis, a) - # use it on a different axis => retarget to axis - g = alt_a[key] - self.assertIsInstance(g, tuple) - self.assertIsInstance(g[0], LGroup) - self.assertIsInstance(g[1], LGroup) - self.assertEqual(list(g[0].key), ['a1', 'a2']) - self.assertEqual(list(g[1].key), ['a2', 'a1']) - self.assertIs(g[0].axis, alt_a) - self.assertIs(g[1].axis, alt_a) - - # f) key is a tuple of scalar groups => create a single LGroup - # ---------------------------------- - - # f.1) with LGroups - key = (a['a1'], a['a2']) - # use it on the same axis - g = a[key] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, ['a1', 'a2']) - self.assertIs(g.axis, a) - # use it on a different axis - g = alt_a[key] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, ['a1', 'a2']) - self.assertIs(g.axis, alt_a) - - # f.2) with IGroup - key = (a.i[1], a.i[2]) - # use it on the same axis - g = a[key] - self.assertIsInstance(g, LGroup) - 
self.assertEqual(g.key, ['a1', 'a2']) - self.assertIs(g.axis, a) - # use it on a different axis - g = alt_a[key] - self.assertIsInstance(g, LGroup) - self.assertEqual(g.key, ['a1', 'a2']) - self.assertIs(g.axis, alt_a) - - # g) key is a tuple of non-scalar groups => retarget multiple groups to axis - # -------------------------------------- - - # g.1) with LGroups - key = (a['a1', 'a2'], a['a2', 'a1']) - # use it on the same axis - g = a[key] - self.assertIsInstance(g, tuple) - self.assertIsInstance(g[0], LGroup) - self.assertIsInstance(g[1], LGroup) - self.assertEqual(g[0].key, ['a1', 'a2']) - self.assertEqual(g[1].key, ['a2', 'a1']) - self.assertIs(g[0].axis, a) - self.assertIs(g[1].axis, a) - # use it on a different axis - g = alt_a[key] - self.assertIsInstance(g, tuple) - self.assertIsInstance(g[0], LGroup) - self.assertIsInstance(g[1], LGroup) - self.assertEqual(g[0].key, ['a1', 'a2']) - self.assertEqual(g[1].key, ['a2', 'a1']) - self.assertIs(g[0].axis, alt_a) - self.assertIs(g[1].axis, alt_a) - - # g.2) with IGroup - key = (a.i[1, 2], a.i[2, 1]) - # use it on the same axis - g = a[key] - self.assertIsInstance(g, tuple) - self.assertIsInstance(g[0], LGroup) - self.assertIsInstance(g[1], LGroup) - self.assertEqual(list(g[0].key), ['a1', 'a2']) - self.assertEqual(list(g[1].key), ['a2', 'a1']) - self.assertIs(g[0].axis, a) - self.assertIs(g[1].axis, a) - # use it on a different axis - g = alt_a[key] - self.assertIsInstance(g, tuple) - self.assertIsInstance(g[0], LGroup) - self.assertIsInstance(g[1], LGroup) - self.assertEqual(list(g[0].key), ['a1', 'a2']) - self.assertEqual(list(g[1].key), ['a2', 'a1']) - self.assertIs(g[0].axis, alt_a) - self.assertIs(g[1].axis, alt_a) - - def test_init_from_group(self): - code = Axis('code=C01..C03') - code_group = code[:'C02'] - subset_axis = Axis(code_group, 'code_subset') - assert_array_equal(subset_axis.labels, ['C01', 'C02']) - - def test_matching(self): - sutcode = Axis(['A23', 'A2301', 'A25', 'A2501'], 'sutcode') - 
self.assertEqual(sutcode.matching('^...$'), LGroup(['A23', 'A25'])) - self.assertEqual(sutcode.startingwith('A23'), LGroup(['A23', 'A2301'])) - self.assertEqual(sutcode.endingwith('01'), LGroup(['A2301', 'A2501'])) - - def test_iter(self): - sex = Axis('sex=M,F') - self.assertEqual(list(sex), [IGroup(0, axis=sex), IGroup(1, axis=sex)]) - - def test_positional(self): - age = Axis('age=0..115') - - # these are NOT equivalent (not translated until used in an LArray - # self.assertEqual(age.i[:17], age[':17']) - key = age.i[:-1] - self.assertEqual(key.key, slice(None, -1)) - self.assertIs(key.axis, age) - - def test_contains(self): - # normal Axis - age = Axis('age=0..10') - - age2 = age[2] - age2bis = age[(2,)] - age2ter = age[[2]] - age2qua = '2,' - - age20 = LGroup('20') - age20bis = LGroup('20,') - age20ter = LGroup(['20']) - age20qua = '20,' - - # TODO: move assert to another test - # self.assertEqual(age2bis, age2ter) - - age247 = age['2,4,7'] - age247bis = age[['2', '4', '7']] - age359 = age[['3', '5', '9']] - age468 = age['4,6,8'] >> 'even' - - self.assertTrue(5 in age) - self.assertFalse('5' in age) - - self.assertTrue(age2 in age) - # only single ticks are "contained" in the axis, not "collections" - self.assertFalse(age2bis in age) - self.assertFalse(age2ter in age) - self.assertFalse(age2qua in age) - - self.assertFalse(age20 in age) - self.assertFalse(age20bis in age) - self.assertFalse(age20ter in age) - self.assertFalse(age20qua in age) - self.assertFalse(['3', '5', '9'] in age) - self.assertFalse('3,5,9' in age) - self.assertFalse('3:9' in age) - self.assertFalse(age247 in age) - self.assertFalse(age247bis in age) - self.assertFalse(age359 in age) - self.assertFalse(age468 in age) - - # aggregated Axis - # FIXME: _to_tick(age2) == 2, but then np.asarray([2, '2,4,7', ...]) returns np.array(['2', '2,4,7']) - # instead of returning an object array - agg = Axis((age2, age247, age359, age468, '2,6', ['3', '5', '7'], ('6', '7', '9')), "agg") - # fails because 
of above FIXME - # self.assertTrue(age2 in agg) - self.assertFalse(age2bis in agg) - self.assertFalse(age2ter in agg) - self.assertFalse(age2qua in age) - - self.assertTrue(age247 in agg) - self.assertTrue(age247bis in agg) - self.assertTrue('2,4,7' in agg) - self.assertTrue(['2', '4', '7'] in agg) - - self.assertTrue(age359 in agg) - self.assertTrue('3,5,9' in agg) - self.assertTrue(['3', '5', '9'] in agg) - - self.assertTrue(age468 in agg) - # no longer the case - # self.assertTrue('4,6,8' in agg) - # self.assertTrue(['4', '6', '8'] in agg) - self.assertTrue('even' in agg) - - self.assertTrue('2,6' in agg) - self.assertTrue(['2', '6'] in agg) - self.assertTrue(age['2,6'] in agg) - self.assertTrue(age[['2', '6']] in agg) - - self.assertTrue('3,5,7' in agg) - self.assertTrue(['3', '5', '7'] in agg) - self.assertTrue(age['3,5,7'] in agg) - self.assertTrue(age[['3', '5', '7']] in agg) - - self.assertTrue('6,7,9' in agg) - self.assertTrue(['6', '7', '9'] in agg) - self.assertTrue(age['6,7,9'] in agg) - self.assertTrue(age[['6', '7', '9']] in agg) - - self.assertFalse(5 in agg) - self.assertFalse('5' in agg) - self.assertFalse(age20 in agg) - self.assertFalse(age20bis in agg) - self.assertFalse(age20ter in agg) - self.assertFalse(age20qua in agg) - self.assertFalse('2,7' in agg) - self.assertFalse(['2', '7'] in agg) - self.assertFalse(age['2,7'] in agg) - self.assertFalse(age[['2', '7']] in agg) - - -class TestAxisCollection(TestCase): - def setUp(self): - self.lipro = Axis('lipro=P01..P04') - self.sex = Axis('sex=M,F') - self.sex2 = Axis('sex=F,M') - self.age = Axis('age=0..7') - self.geo = Axis('geo=A11,A12,A13') - self.value = Axis('value=0..10') - self.collection = AxisCollection((self.lipro, self.sex, self.age)) - - def test_init_from_group(self): - lipro_subset = self.lipro[:'P03'] - col2 = AxisCollection((lipro_subset, self.sex)) - self.assertEqual(col2.names, ['lipro', 'sex']) - assert_array_equal(col2.lipro.labels, ['P01', 'P02', 'P03']) - 
assert_array_equal(col2.sex.labels, ['M', 'F']) - - def test_init_from_string(self): - col = AxisCollection('age=10;sex=M,F;year=2000..2017') - assert col.names == ['age', 'sex', 'year'] - assert list(col.age.labels) == [10] - assert list(col.sex.labels) == ['M', 'F'] - assert list(col.year.labels) == [y for y in range(2000, 2018)] - - def test_eq(self): - col = self.collection - self.assertEqual(col, col) - self.assertEqual(col, AxisCollection((self.lipro, self.sex, self.age))) - self.assertEqual(col, (self.lipro, self.sex, self.age)) - self.assertNotEqual(col, (self.lipro, self.age, self.sex)) - - def test_getitem_name(self): - col = self.collection - self.assert_axis_eq(col['lipro'], self.lipro) - self.assert_axis_eq(col['sex'], self.sex) - self.assert_axis_eq(col['age'], self.age) - - def test_getitem_int(self): - col = self.collection - self.assert_axis_eq(col[0], self.lipro) - self.assert_axis_eq(col[-3], self.lipro) - self.assert_axis_eq(col[1], self.sex) - self.assert_axis_eq(col[-2], self.sex) - self.assert_axis_eq(col[2], self.age) - self.assert_axis_eq(col[-1], self.age) - - def test_getitem_slice(self): - col = self.collection[:2] - self.assertEqual(len(col), 2) - self.assert_axis_eq(col[0], self.lipro) - self.assert_axis_eq(col[1], self.sex) - - def test_setitem_name(self): - col = self.collection[:] - # replace an axis with one with another name - col['lipro'] = self.geo - self.assertEqual(len(col), 3) - self.assertEqual(col, [self.geo, self.sex, self.age]) - # replace an axis with one with the same name - col['sex'] = self.sex2 - self.assertEqual(col, [self.geo, self.sex2, self.age]) - col['geo'] = self.lipro - self.assertEqual(col, [self.lipro, self.sex2, self.age]) - col['age'] = self.geo - self.assertEqual(col, [self.lipro, self.sex2, self.geo]) - col['sex'] = self.sex - col['geo'] = self.age - self.assertEqual(col, self.collection) - - def test_setitem_name_axis_def(self): - col = self.collection[:] - # replace an axis with one with another name 
- col['lipro'] = 'geo=A11,A12,A13' - self.assertEqual(len(col), 3) - self.assertEqual(col, [self.geo, self.sex, self.age]) - # replace an axis with one with the same name - col['sex'] = 'sex=F,M' - self.assertEqual(col, [self.geo, self.sex2, self.age]) - col['geo'] = 'lipro=P01..P04' - self.assertEqual(col, [self.lipro, self.sex2, self.age]) - col['age'] = 'geo=A11,A12,A13' - self.assertEqual(col, [self.lipro, self.sex2, self.geo]) - col['sex'] = 'sex=M,F' - col['geo'] = 'age=0..7' - self.assertEqual(col, self.collection) - - def test_setitem_int(self): - col = self.collection[:] - col[1] = self.geo - self.assertEqual(len(col), 3) - self.assertEqual(col, [self.lipro, self.geo, self.age]) - col[2] = self.sex - self.assertEqual(col, [self.lipro, self.geo, self.sex]) - col[-1] = self.age - self.assertEqual(col, [self.lipro, self.geo, self.age]) - - def test_setitem_list_replace(self): - col = self.collection[:] - col[['lipro', 'age']] = [self.geo, self.lipro] - self.assertEqual(col, [self.geo, self.sex, self.lipro]) - - def test_setitem_slice_replace(self): - col = self.collection[:] - # replace by list - col[1:] = [self.geo, self.sex] - self.assertEqual(col, [self.lipro, self.geo, self.sex]) - # replace by collection - col[1:] = self.collection[1:] - self.assertEqual(col, self.collection) - - def test_setitem_slice_insert(self): - col = self.collection[:] - col[1:1] = [self.geo] - self.assertEqual(col, [self.lipro, self.geo, self.sex, self.age]) - - def test_setitem_slice_delete(self): - col = self.collection[:] - col[1:2] = [] - self.assertEqual(col, [self.lipro, self.age]) - col[0:1] = [] - self.assertEqual(col, [self.age]) - - def assert_axis_eq(self, axis1, axis2): - self.assertTrue(axis1.equals(axis2)) - - def test_delitem(self): - col = self.collection[:] - self.assertEqual(len(col), 3) - del col[0] - self.assertEqual(len(col), 2) - self.assert_axis_eq(col[0], self.sex) - self.assert_axis_eq(col[1], self.age) - del col['age'] - self.assertEqual(len(col), 1) - 
self.assert_axis_eq(col[0], self.sex) - del col[self.sex] - self.assertEqual(len(col), 0) - - def test_delitem_slice(self): - col = self.collection[:] - self.assertEqual(len(col), 3) - del col[0:2] - self.assertEqual(len(col), 1) - self.assertEqual(col, [self.age]) - del col[:] - self.assertEqual(len(col), 0) - - def test_pop(self): - col = self.collection[:] - lipro, sex, age = col - self.assertEqual(len(col), 3) - self.assertIs(col.pop(), age) - self.assertEqual(len(col), 2) - self.assertIs(col[0], lipro) - self.assertIs(col[1], sex) - self.assertIs(col.pop(), sex) - self.assertEqual(len(col), 1) - self.assertIs(col[0], lipro) - self.assertIs(col.pop(), lipro) - self.assertEqual(len(col), 0) - - def test_replace(self): - col = self.collection[:] - newcol = col.replace('sex', self.geo) - # original collection is not modified - self.assertEqual(col, self.collection) - self.assertEqual(len(newcol), 3) - self.assertEqual(newcol.names, ['lipro', 'geo', 'age']) - self.assertEqual(newcol.shape, (4, 3, 8)) - newcol = newcol.replace(self.geo, self.sex) - self.assertEqual(len(newcol), 3) - self.assertEqual(newcol.names, ['lipro', 'sex', 'age']) - self.assertEqual(newcol.shape, (4, 2, 8)) - - # from now on, reuse original collection - newcol = col.replace(self.sex, 3) - self.assertEqual(len(newcol), 3) - self.assertEqual(newcol.names, ['lipro', None, 'age']) - self.assertEqual(newcol.shape, (4, 3, 8)) - - newcol = col.replace(self.sex, ['a', 'b', 'c']) - self.assertEqual(len(newcol), 3) - self.assertEqual(newcol.names, ['lipro', None, 'age']) - self.assertEqual(newcol.shape, (4, 3, 8)) - - newcol = col.replace(self.sex, "letters=a,b,c") - self.assertEqual(len(newcol), 3) - self.assertEqual(newcol.names, ['lipro', 'letters', 'age']) - self.assertEqual(newcol.shape, (4, 3, 8)) - - def test_contains(self): - col = self.collection - self.assertTrue('lipro' in col) - self.assertFalse('nonexisting' in col) - - self.assertTrue(0 in col) - self.assertTrue(1 in col) - 
self.assertTrue(2 in col) - self.assertTrue(-1 in col) - self.assertTrue(-2 in col) - self.assertTrue(-3 in col) - self.assertFalse(3 in col) - - # objects actually in col - self.assertTrue(self.lipro in col) - self.assertTrue(self.sex in col) - self.assertTrue(self.age in col) - # other axis with the same name - self.assertTrue(self.sex2 in col) - self.assertFalse(self.geo in col) - self.assertFalse(self.value in col) - - # test anonymous axes - anon = Axis([0, 1]) - col.append(anon) - self.assertTrue(anon in col) - # different object, same values - anon2 = anon.copy() - self.assertTrue(anon2 in col) - # different values - anon3 = Axis([0, 2]) - self.assertFalse(anon3 in col) - - def test_index(self): - col = self.collection - self.assertEqual(col.index('lipro'), 0) - with self.assertRaises(ValueError): - col.index('nonexisting') - self.assertEqual(col.index(0), 0) - self.assertEqual(col.index(1), 1) - self.assertEqual(col.index(2), 2) - self.assertEqual(col.index(-1), -1) - self.assertEqual(col.index(-2), -2) - self.assertEqual(col.index(-3), -3) - with self.assertRaises(ValueError): - col.index(3) - - # objects actually in col - self.assertEqual(col.index(self.lipro), 0) - self.assertEqual(col.index(self.sex), 1) - self.assertEqual(col.index(self.age), 2) - # other axis with the same name - self.assertEqual(col.index(self.sex2), 1) - # non existing - with self.assertRaises(ValueError): - col.index(self.geo) - with self.assertRaises(ValueError): - col.index(self.value) - - # test anonymous axes - anon = Axis([0, 1]) - col.append(anon) - self.assertEqual(col.index(anon), 3) - # different object, same values - anon2 = anon.copy() - self.assertEqual(col.index(anon2), 3) - # different values - anon3 = Axis([0, 2]) - with self.assertRaises(ValueError): - col.index(anon3) - - def test_get(self): - col = self.collection - self.assert_axis_eq(col.get('lipro'), self.lipro) - self.assertIsNone(col.get('nonexisting')) - self.assertIs(col.get('nonexisting', self.value), 
self.value) - - def test_keys(self): - self.assertEqual(self.collection.keys(), ['lipro', 'sex', 'age']) - - def test_getattr(self): - col = self.collection - self.assert_axis_eq(col.lipro, self.lipro) - self.assert_axis_eq(col.sex, self.sex) - self.assert_axis_eq(col.age, self.age) - - def test_append(self): - col = self.collection - geo = Axis('geo=A11,A12,A13') - col.append(geo) - self.assertEqual(col, [self.lipro, self.sex, self.age, geo]) - - def test_extend(self): - col = self.collection - col.extend([self.geo, self.value]) - self.assertEqual(col, - [self.lipro, self.sex, self.age, self.geo, self.value]) - - def test_insert(self): - col = self.collection - col.insert(1, self.geo) - self.assertEqual(col, [self.lipro, self.geo, self.sex, self.age]) - - def test_add(self): - col = self.collection.copy() - lipro, sex, age = self.lipro, self.sex, self.age - geo, value = self.geo, self.value - - # 1) list - # a) no dupe - new = col + [self.geo, value] - self.assertEqual(new, [lipro, sex, age, geo, value]) - # check the original has not been modified - self.assertEqual(col, self.collection) - - # b) with compatible dupe - # the "new" age axis is ignored (because it is compatible) - new = col + [Axis('geo=A11,A12,A13'), Axis('age=0..7')] - self.assertEqual(new, [lipro, sex, age, geo]) - - # c) with incompatible dupe - # XXX: the "new" age axis is ignored. 
We might want to ignore it if it - # is the same but raise an exception if it is different - with self.assertRaises(ValueError): - col + [Axis('geo=A11,A12,A13'), Axis('age=0..6')] - - # 2) other AxisCollection - new = col + AxisCollection([geo, value]) - self.assertEqual(new, [lipro, sex, age, geo, value]) - - def test_combine(self): - col = self.collection.copy() - lipro, sex, age = self.lipro, self.sex, self.age - res = col.combine_axes((lipro, sex)) - self.assertEqual(res.names, ['lipro_sex', 'age']) - self.assertEqual(res.size, col.size) - self.assertEqual(res.shape, (4 * 2, 8)) - print(res.info) - assert_array_equal(res.lipro_sex.labels[0], 'P01_M') - res = col.combine_axes((lipro, age)) - self.assertEqual(res.names, ['lipro_age', 'sex']) - self.assertEqual(res.size, col.size) - self.assertEqual(res.shape, (4 * 8, 2)) - assert_array_equal(res.lipro_age.labels[0], 'P01_0') - res = col.combine_axes((sex, age)) - self.assertEqual(res.names, ['lipro', 'sex_age']) - self.assertEqual(res.size, col.size) - self.assertEqual(res.shape, (4, 2 * 8)) - assert_array_equal(res.sex_age.labels[0], 'M_0') - - def test_info(self): - expected = """\ -4 x 2 x 8 - lipro [4]: 'P01' 'P02' 'P03' 'P04' - sex [2]: 'M' 'F' - age [8]: 0 1 2 ... 
5 6 7""" - self.assertEqual(self.collection.info, expected) - - def test_str(self): - self.assertEqual(str(self.collection), "{lipro, sex, age}") - - def test_repr(self): - self.assertEqual(repr(self.collection), """AxisCollection([ - Axis(['P01', 'P02', 'P03', 'P04'], 'lipro'), - Axis(['M', 'F'], 'sex'), - Axis([0, 1, 2, 3, 4, 5, 6, 7], 'age') -])""") +from larray import Axis, LGroup, IGroup, read_hdf + + +def test_init(): + sex_tuple = ('M', 'F') + sex_list = ['M', 'F'] + sex_array = np.array(sex_list) + axis = Axis(10, 'axis') + assert len(axis) == 10 + assert list(axis.labels) == list(range(10)) + assert_array_equal(Axis(sex_tuple, 'sex').labels, sex_array) + assert_array_equal(Axis(sex_list, 'sex').labels, sex_array) + assert_array_equal(Axis(sex_array, 'sex').labels, sex_array) + assert_array_equal(Axis('sex=M,F').labels, sex_array) + assert_array_equal(Axis(range(116), 'age').labels, np.arange(116)) + axis = Axis('0..115', 'age') + assert_array_equal(axis.labels, np.arange(116)) + assert_array_equal(Axis('01..12', 'zero_padding').labels, [str(i).zfill(2) for i in range(1, 13)]) + assert_array_equal(Axis('01,02,03,10,11,12', 'zero_padding').labels, ['01', '02', '03', '10', '11', '12']) + group = axis[:10] + group_axis = Axis(group) + assert_array_equal(group_axis.labels, np.arange(11)) + assert_array_equal(group_axis.name, 'age') + other = Axis('other=0..10') + axis = Axis(other, 'age') + assert_array_equal(axis.labels, other.labels) + assert_array_equal(axis.name, 'age') + +def test_equals(): + assert Axis('sex=M,F').equals(Axis('sex=M,F')) + assert Axis('sex=M,F').equals(Axis(['M', 'F'], 'sex')) + assert not Axis('sex=M,W').equals(Axis('sex=M,F')) + assert not Axis('sex1=M,F').equals(Axis('sex2=M,F')) + assert not Axis('sex1=M,W').equals(Axis('sex2=M,F')) + +def test_getitem(): + age = Axis('age=0..10') + a159 = age[1, 5, 9] + assert a159.key == [1, 5, 9] + assert a159.name is None + assert a159.axis is age + a159 = age[[1, 5, 9]] + assert a159.key == [1, 
5, 9] + assert a159.name is None + assert a159.axis is age + a159 = age['1,5,9'] + assert a159.key == [1, 5, 9] + assert a159.name is None + assert a159.axis is age + a10to20 = age[5:9] + assert a10to20.key == slice(5, 9) + assert a10to20.axis is age + a10to20 = age['5:9'] + assert a10to20.key == slice(5, 9) + assert a10to20.axis is age + group = age[[1, 5, 9]] >> 'test' + assert group.key == [1, 5, 9] + assert group.name == 'test' + assert group.axis is age + group = age[:] >> 'all' + assert group.key == slice(None) + assert group.axis is age + age2 = Axis('age=0..5') + group = age[age2] + assert list(group.key) == list(age2.labels) + +def test_translate(): + # an axis with labels having the object dtype + a = Axis(np.array(["a0", "a1"], dtype=object), 'a') + assert a.index('a1') == 1 + assert a.index('a1 >> A1') == 1 + +def test_getitem_lgroup_keys(): + def group_equal(g1, g2): + return (g1.key == g2.key and g1.name == g2.name and g1.axis is g2.axis) + + age = Axis(range(100), 'age') + ages=[1,5,9] + val_only=LGroup(ages) + assert group_equal(age[val_only],LGroup(ages,axis=age)) + assert group_equal(age[val_only]>>'a_name',LGroup(ages,'a_name',axis=age)) + val_name=LGroup(ages,'val_name') + assert group_equal(age[val_name],LGroup(ages,'val_name',age)) + assert group_equal(age[val_name]>>'a_name',LGroup(ages,'a_name',age)) + val_axis=LGroup(ages,axis=age) + assert group_equal(age[val_axis],LGroup(ages,axis=age)) + assert group_equal(age[val_axis]>>'a_name',LGroup(ages,'a_name',axis=age)) + val_axis_name=LGroup(ages,'val_axis_name',age) + assert group_equal(age[val_axis_name],LGroup(ages,'val_axis_name',age)) + assert group_equal(age[val_axis_name]>>'a_name',LGroup(ages,'a_name',age)) + +def test_getitem_group_keys(): + a = Axis('a=a0..a2') + alt_a = Axis('a=a1..a3') + key = a['a1'] + g = a[key] + assert g.key == 'a1' + assert g.axis is a + g = alt_a[key] + assert g.key == 'a1' + assert g.axis is alt_a + key = a['a1':'a2'] + g = a[key] + assert g.key == slice('a1', 
'a2') + assert g.axis is a + g = alt_a[key] + assert g.key == slice('a1', 'a2') + assert g.axis is alt_a + key = a[['a1', 'a2']] + g = a[key] + assert g.key == ['a1', 'a2'] + assert g.axis is a + g = alt_a[key] + assert g.key == ['a1', 'a2'] + assert g.axis is alt_a + key = a.i[1] + g = a[key] + assert isinstance(g, LGroup) + assert g.key == 'a1' + assert g.axis is a + g = alt_a[key] + assert isinstance(g, LGroup) + assert g.key == 'a1' + assert g.axis is alt_a + key = a.i[1:3] + g = a[key] + assert isinstance(g, LGroup) + assert g.key == slice('a1', 'a2') + assert g.axis is a + g = alt_a[key] + assert isinstance(g, LGroup) + assert g.key == slice('a1', 'a2') + assert g.axis is alt_a + key = a.i[[1, 2]] + g = a[key] + assert isinstance(g, LGroup) + assert list(g.key) == ['a1', 'a2'] + assert g.axis is a + g = alt_a[key] + assert isinstance(g, LGroup) + assert list(g.key) == ['a1', 'a2'] + assert g.axis is alt_a + lg_a1 = a['a1'] + lg_a2 = a['a2'] + g = a[lg_a1:lg_a2] + assert isinstance(g, LGroup) + assert g.key == slice('a1', 'a2') + assert g.axis is a + g = alt_a[lg_a1:lg_a2] + assert isinstance(g, LGroup) + assert g.key == slice('a1', 'a2') + assert g.axis is alt_a + pg_a1 = a.i[1] + pg_a2 = a.i[2] + g = a[pg_a1:pg_a2] + assert isinstance(g, LGroup) + assert g.key == slice('a1', 'a2') + assert g.axis is a + g = alt_a[pg_a1:pg_a2] + assert isinstance(g, LGroup) + assert g.key == slice('a1', 'a2') + assert g.axis is alt_a + key = [a['a1'], a['a2']] + g = a[key] + assert isinstance(g, LGroup) + assert g.key == ['a1', 'a2'] + assert g.axis is a + g = alt_a[key] + assert isinstance(g, LGroup) + assert g.key == ['a1', 'a2'] + assert g.axis is alt_a + key = [a.i[1], a.i[2]] + g = a[key] + assert isinstance(g, LGroup) + assert g.key == ['a1', 'a2'] + assert g.axis is a + g = alt_a[key] + assert isinstance(g, LGroup) + assert g.key == ['a1', 'a2'] + assert g.axis is alt_a + key = [a['a1', 'a2'], a['a2', 'a1']] + g = a[key] + assert isinstance(g, list) + assert 
isinstance(g[0], LGroup) + assert isinstance(g[1], LGroup) + assert g[0].key == ['a1', 'a2'] + assert g[1].key == ['a2', 'a1'] + assert g[0].axis is a + assert g[1].axis is a + g = alt_a[key] + assert isinstance(g, list) + assert isinstance(g[0], LGroup) + assert isinstance(g[1], LGroup) + assert g[0].key == ['a1', 'a2'] + assert g[1].key == ['a2', 'a1'] + assert g[0].axis is alt_a + assert g[1].axis is alt_a + key = (a.i[1, 2], a.i[2, 1]) + g = a[key] + assert isinstance(g, tuple) + assert isinstance(g[0], LGroup) + assert isinstance(g[1], LGroup) + assert list(g[0].key) == ['a1', 'a2'] + assert list(g[1].key) == ['a2', 'a1'] + assert g[0].axis is a + assert g[1].axis is a + g = alt_a[key] + assert isinstance(g, tuple) + assert isinstance(g[0], LGroup) + assert isinstance(g[1], LGroup) + assert list(g[0].key) == ['a1', 'a2'] + assert list(g[1].key) == ['a2', 'a1'] + assert g[0].axis is alt_a + assert g[1].axis is alt_a + key = (a['a1'], a['a2']) + g = a[key] + assert isinstance(g, LGroup) + assert g.key == ['a1', 'a2'] + assert g.axis is a + g = alt_a[key] + assert isinstance(g, LGroup) + assert g.key == ['a1', 'a2'] + assert g.axis is alt_a + key = (a.i[1], a.i[2]) + g = a[key] + assert isinstance(g, LGroup) + assert g.key == ['a1', 'a2'] + assert g.axis is a + g = alt_a[key] + assert isinstance(g, LGroup) + assert g.key == ['a1', 'a2'] + assert g.axis is alt_a + key = (a['a1', 'a2'], a['a2', 'a1']) + g = a[key] + assert isinstance(g, tuple) + assert isinstance(g[0], LGroup) + assert isinstance(g[1], LGroup) + assert g[0].key == ['a1', 'a2'] + assert g[1].key == ['a2', 'a1'] + assert g[0].axis is a + assert g[1].axis is a + g = alt_a[key] + assert isinstance(g, tuple) + assert isinstance(g[0], LGroup) + assert isinstance(g[1], LGroup) + assert g[0].key == ['a1', 'a2'] + assert g[1].key == ['a2', 'a1'] + assert g[0].axis is alt_a + assert g[1].axis is alt_a + key = (a.i[1, 2], a.i[2, 1]) + g = a[key] + assert isinstance(g, tuple) + assert isinstance(g[0], LGroup) 
+ assert isinstance(g[1], LGroup) + assert list(g[0].key) == ['a1', 'a2'] + assert list(g[1].key) == ['a2', 'a1'] + assert g[0].axis is a + assert g[1].axis is a + g = alt_a[key] + assert isinstance(g, tuple) + assert isinstance(g[0], LGroup) + assert isinstance(g[1], LGroup) + assert list(g[0].key) == ['a1', 'a2'] + assert list(g[1].key) == ['a2', 'a1'] + assert g[0].axis is alt_a + assert g[1].axis is alt_a + +def test_init_from_group(): + code = Axis('code=C01..C03') + code_group = code[:'C02'] + subset_axis = Axis(code_group, 'code_subset') + assert_array_equal(subset_axis.labels, ['C01', 'C02']) + +def test_matching(): + sutcode = Axis(['A23', 'A2301', 'A25', 'A2501'], 'sutcode') + assert sutcode.matching('^...$') == LGroup(['A23', 'A25']) + assert sutcode.startingwith('A23') == LGroup(['A23', 'A2301']) + assert sutcode.endingwith('01') == LGroup(['A2301', 'A2501']) + +def test_iter(): + sex = Axis('sex=M,F') + assert list(sex) == [IGroup(0, axis=sex), IGroup(1, axis=sex)] + +def test_positional(): + age = Axis('age=0..115') + key = age.i[:-1] + assert key.key == slice(None, -1) + assert key.axis is age + +def test_contains(): + # normal Axis + age = Axis('age=0..10') + age2 = age[2] + age2bis = age[(2,)] + age2ter = age[[2]] + age2qua = '2,' + age20 = LGroup('20') + age20bis = LGroup('20,') + age20ter = LGroup(['20']) + age20qua = '20,' + age247 = age['2,4,7'] + age247bis = age[['2', '4', '7']] + age359 = age[['3', '5', '9']] + age468 = age['4,6,8'] >> 'even' + assert 5 in age + assert '5' not in age + assert age2 in age + assert age2bis not in age + assert age2ter not in age + assert age2qua not in age + assert age20 not in age + assert age20bis not in age + assert age20ter not in age + assert age20qua not in age + assert ['3', '5', '9'] not in age + assert '3,5,9' not in age + assert '3:9' not in age + assert age247 not in age + assert age247bis not in age + assert age359 not in age + assert age468 not in age + agg = Axis((age2, age247, age359, age468, 
'2,6', ['3', '5', '7'], ('6', '7', '9')), "agg") + assert age2bis not in agg + assert age2ter not in agg + assert age2qua not in age + assert age247 in agg + assert age247bis in agg + assert '2,4,7' in agg + assert ['2', '4', '7'] in agg + assert age359 in agg + assert '3,5,9' in agg + assert ['3', '5', '9'] in agg + assert age468 in agg + assert 'even' in agg + assert '2,6' in agg + assert ['2', '6'] in agg + assert age['2,6'] in agg + assert age[['2', '6']] in agg + assert '3,5,7' in agg + assert ['3', '5', '7'] in agg + assert age['3,5,7'] in agg + assert age[['3', '5', '7']] in agg + assert '6,7,9' in agg + assert ['6', '7', '9'] in agg + assert age['6,7,9'] in agg + assert age[['6', '7', '9']] in agg + assert 5 not in agg + assert '5' not in agg + assert age20 not in agg + assert age20bis not in agg + assert age20ter not in agg + assert age20qua not in agg + assert '2,7' not in agg + assert ['2', '7'] not in agg + assert age['2,7'] not in agg + assert age[['2', '7']] not in agg + +def test_h5_io(tmpdir): + age = Axis('age=0..10') + lipro = Axis('lipro=P01..P05') + anonymous = Axis(range(3)) + wildcard = Axis(3, 'wildcard') + fpath = os.path.join(str(tmpdir), 'axes.h5') + + # ---- default behavior ---- + # int axis + age.to_hdf(fpath) + age2 = read_hdf(fpath, key=age.name) + assert age.equals(age2) + # string axis + lipro.to_hdf(fpath) + lipro2 = read_hdf(fpath, key=lipro.name) + assert lipro.equals(lipro2) + # anonymous axis + with pytest.raises(ValueError, match="Argument key must be provided explicitly in case of anonymous axis"): + anonymous.to_hdf(fpath) + # wildcard axis + wildcard.to_hdf(fpath) + wildcard2 = read_hdf(fpath, key=wildcard.name) + assert wildcard2.iswildcard + assert wildcard.equals(wildcard2) + + # ---- specific key ---- + # int axis + key = 'axis_age' + age.to_hdf(fpath, key) + age2 = read_hdf(fpath, key=key) + assert age.equals(age2) + # string axis + key = 'axis_lipro' + lipro.to_hdf(fpath, key) + lipro2 = read_hdf(fpath, key=key) + 
assert lipro.equals(lipro2) + # anonymous axis + key = 'axis_anonymous' + anonymous.to_hdf(fpath, key) + anonymous2 = read_hdf(fpath, key=key) + assert anonymous2.name is None + assert_array_equal(anonymous.labels, anonymous2.labels) + # wildcard axis + key = 'axis_wildcard' + wildcard.to_hdf(fpath, key) + wildcard2 = read_hdf(fpath, key=key) + assert wildcard2.iswildcard + assert wildcard.equals(wildcard2) + + # ---- specific hdf group + key ---- + hdf_group = 'my_axes' + # int axis + key = hdf_group + '/axis_age' + age.to_hdf(fpath, key) + age2 = read_hdf(fpath, key=key) + assert age.equals(age2) + # string axis + key = hdf_group + '/axis_lipro' + lipro.to_hdf(fpath, key) + lipro2 = read_hdf(fpath, key=key) + assert lipro.equals(lipro2) + # anonymous axis + key = hdf_group + '/axis_anonymous' + anonymous.to_hdf(fpath, key) + anonymous2 = read_hdf(fpath, key=key) + assert anonymous2.name is None + assert_array_equal(anonymous.labels, anonymous2.labels) + # wildcard axis + key = hdf_group + '/axis_wildcard' + wildcard.to_hdf(fpath, key) + wildcard2 = read_hdf(fpath, key=key) + assert wildcard2.iswildcard + assert wildcard.equals(wildcard2) if __name__ == "__main__": diff --git a/larray/tests/test_axiscollection.py b/larray/tests/test_axiscollection.py new file mode 100644 index 000000000..daec2a261 --- /dev/null +++ b/larray/tests/test_axiscollection.py @@ -0,0 +1,311 @@ +from __future__ import absolute_import, division, print_function +import pytest + +from larray.tests.common import assert_array_equal, assert_axis_eq +from larray import Axis, AxisCollection + + +lipro = Axis('lipro=P01..P04') +sex = Axis('sex=M,F') +sex2 = Axis('sex=F,M') +age = Axis('age=0..7') +geo = Axis('geo=A11,A12,A13') +value = Axis('value=0..10') + +@pytest.fixture +def col(): + return AxisCollection((lipro, sex, age)) + + +def test_init_from_group(): + lipro_subset = lipro[:'P03'] + col2 = AxisCollection((lipro_subset, sex)) + assert col2.names == ['lipro', 'sex'] + 
assert_array_equal(col2.lipro.labels, ['P01', 'P02', 'P03']) + assert_array_equal(col2.sex.labels, ['M', 'F']) + +def test_init_from_string(): + col = AxisCollection('age=10;sex=M,F;year=2000..2017') + assert col.names == ['age', 'sex', 'year'] + assert list(col.age.labels) == [10] + assert list(col.sex.labels) == ['M', 'F'] + assert list(col.year.labels) == [y for y in range(2000, 2018)] + +def test_eq(col): + assert col == col + assert col == AxisCollection((lipro, sex, age)) + assert col == (lipro, sex, age) + assert col != (lipro, age, sex) + +def test_getitem_name(col): + assert_axis_eq(col['lipro'], lipro) + assert_axis_eq(col['sex'], sex) + assert_axis_eq(col['age'], age) + +def test_getitem_int(col): + assert_axis_eq(col[0], lipro) + assert_axis_eq(col[-3], lipro) + assert_axis_eq(col[1], sex) + assert_axis_eq(col[-2], sex) + assert_axis_eq(col[2], age) + assert_axis_eq(col[-1], age) + +def test_getitem_slice(col): + col = col[:2] + assert len(col) == 2 + assert_axis_eq(col[0], lipro) + assert_axis_eq(col[1], sex) + +def test_setitem_name(col): + col2 = col[:] + col2['lipro'] = geo + assert len(col2) == 3 + assert col2 == [geo, sex, age] + col2['sex'] = sex2 + assert col2 == [geo, sex2, age] + col2['geo'] = lipro + assert col2 == [lipro, sex2, age] + col2['age'] = geo + assert col2 == [lipro, sex2, geo] + col2['sex'] = sex + col2['geo'] = age + assert col2 == col + +def test_setitem_name_axis_def(col): + col2 = col[:] + col2['lipro'] = 'geo=A11,A12,A13' + assert len(col2) == 3 + assert col2 == [geo, sex, age] + col2['sex'] = 'sex=F,M' + assert col2 == [geo, sex2, age] + col2['geo'] = 'lipro=P01..P04' + assert col2 == [lipro, sex2, age] + col2['age'] = 'geo=A11,A12,A13' + assert col2 == [lipro, sex2, geo] + col2['sex'] = 'sex=M,F' + col2['geo'] = 'age=0..7' + assert col2 == col + +def test_setitem_int(col): + col[1] = geo + assert len(col) == 3 + assert col == [lipro, geo, age] + col[2] = sex + assert col == [lipro, geo, sex] + col[-1] = age + assert col == 
[lipro, geo, age] + +def test_setitem_list_replace(col): + col[['lipro', 'age']] = [geo, lipro] + assert col == [geo, sex, lipro] + +def test_setitem_slice_replace(col): + col2 = col[:] + col2[1:] = [geo, sex] + assert col2 == [lipro, geo, sex] + col2[1:] = col[1:] + assert col2 == col + +def test_setitem_slice_insert(col): + col[1:1] = [geo] + assert col == [lipro, geo, sex, age] + +def test_setitem_slice_delete(col): + col[1:2] = [] + assert col == [lipro, age] + col[0:1] = [] + assert col == [age] + +def test_delitem(col): + assert len(col) == 3 + del col[0] + assert len(col) == 2 + assert_axis_eq(col[0], sex) + assert_axis_eq(col[1], age) + del col['age'] + assert len(col) == 1 + assert_axis_eq(col[0], sex) + del col[sex] + assert len(col) == 0 + +def test_delitem_slice(col): + assert len(col) == 3 + del col[0:2] + assert len(col) == 1 + assert col == [age] + del col[:] + assert len(col) == 0 + +def test_pop(col): + lipro, sex, age = col + assert len(col) == 3 + assert col.pop() is age + assert len(col) == 2 + assert col[0] is lipro + assert col[1] is sex + assert col.pop() is sex + assert len(col) == 1 + assert col[0] is lipro + assert col.pop() is lipro + assert len(col) == 0 + +def test_replace(col): + col2 = col[:] + newcol = col2.replace('sex', geo) + assert col2 == col + assert len(newcol) == 3 + assert newcol.names == ['lipro', 'geo', 'age'] + assert newcol.shape == (4, 3, 8) + newcol = newcol.replace(geo, sex) + assert len(newcol) == 3 + assert newcol.names == ['lipro', 'sex', 'age'] + assert newcol.shape == (4, 2, 8) + newcol = col2.replace(sex, 3) + assert len(newcol) == 3 + assert newcol.names == ['lipro', None, 'age'] + assert newcol.shape == (4, 3, 8) + newcol = col2.replace(sex, ['a', 'b', 'c']) + assert len(newcol) == 3 + assert newcol.names == ['lipro', None, 'age'] + assert newcol.shape == (4, 3, 8) + newcol = col2.replace(sex, "letters=a,b,c") + assert len(newcol) == 3 + assert newcol.names == ['lipro', 'letters', 'age'] + assert newcol.shape 
== (4, 3, 8) + +def test_contains(col): + assert 'lipro' in col + assert not ('nonexisting' in col) + assert 0 in col + assert 1 in col + assert 2 in col + assert -1 in col + assert -2 in col + assert -3 in col + assert not (3 in col) + assert lipro in col + assert sex in col + assert age in col + assert sex2 in col + assert not (geo in col) + assert not (value in col) + anon = Axis([0, 1]) + col.append(anon) + assert anon in col + anon2 = anon.copy() + assert anon2 in col + anon3 = Axis([0, 2]) + assert not (anon3 in col) + +def test_index(col): + assert col.index('lipro') == 0 + with pytest.raises(ValueError): + col.index('nonexisting') + assert col.index(0)== 0 + assert col.index(1) == 1 + assert col.index(2) == 2 + assert col.index(-1) == -1 + assert col.index(-2) == -2 + assert col.index(-3) == -3 + with pytest.raises(ValueError): + col.index(3) + + # objects actually in col + assert col.index(lipro) == 0 + assert col.index(sex) == 1 + assert col.index(age) == 2 + assert col.index(sex2) == 1 + with pytest.raises(ValueError): + col.index(geo) + with pytest.raises(ValueError): + col.index(value) + + # test anonymous axes + anon = Axis([0, 1]) + col.append(anon) + assert col.index(anon) == 3 + anon2 = anon.copy() + assert col.index(anon2) == 3 + anon3 = Axis([0,2]) + with pytest.raises(ValueError): + col.index(anon3) + +def test_get(col): + assert_axis_eq(col.get('lipro'), lipro) + assert col.get('nonexisting') is None + assert col.get('nonexisting', value) is value + +def test_keys(col): + assert col.keys() == ['lipro', 'sex', 'age'] + +def test_getattr(col): + assert_axis_eq(col.lipro, lipro) + assert_axis_eq(col.sex, sex) + assert_axis_eq(col.age, age) + +def test_append(col): + geo = Axis('geo=A11,A12,A13') + col.append(geo) + assert col == [lipro, sex, age, geo] + +def test_extend(col): + col.extend([geo, value]) + assert col == [lipro, sex, age, geo, value] + +def test_insert(col): + col.insert(1, geo) + assert col == [lipro, geo, sex, age] + +def 
test_add(col): + col2 = col.copy() + new = col2 + [geo, value] + assert new == [lipro, sex, age, geo, value] + assert col2 == col + new = col2 + [Axis('geo=A11,A12,A13'), Axis('age=0..7')] + assert new == [lipro, sex, age, geo] + with pytest.raises(ValueError): + col2 + [Axis('geo=A11,A12,A13'), Axis('age=0..6')] + + # 2) other AxisCollection + new = col2 + AxisCollection([geo, value]) + assert new == [lipro,sex,age,geo,value] + +def test_combine(col): + res = col.combine_axes((lipro, sex)) + assert res.names == ['lipro_sex', 'age'] + assert res.size == col.size + assert res.shape == (4 * 2, 8) + print(res.info) + assert_array_equal(res.lipro_sex.labels[0], 'P01_M') + res = col.combine_axes((lipro, age)) + assert res.names == ['lipro_age', 'sex'] + assert res.size == col.size + assert res.shape == (4 * 8, 2) + assert_array_equal(res.lipro_age.labels[0], 'P01_0') + res = col.combine_axes((sex, age)) + assert res.names == ['lipro', 'sex_age'] + assert res.size == col.size + assert res.shape == (4, 2 * 8) + assert_array_equal(res.sex_age.labels[0], 'M_0') + +def test_info(col): + expected = """\ +4 x 2 x 8 + lipro [4]: 'P01' 'P02' 'P03' 'P04' + sex [2]: 'M' 'F' + age [8]: 0 1 2 ... 
5 6 7""" + assert col.info == expected + +def test_str(col): + assert str(col) == "{lipro, sex, age}" + +def test_repr(col): + assert repr(col) == """AxisCollection([ + Axis(['P01', 'P02', 'P03', 'P04'], 'lipro'), + Axis(['M', 'F'], 'sex'), + Axis([0, 1, 2, 3, 4, 5, 6, 7], 'age') +])""" + + +if __name__ == "__main__": + pytest.main() \ No newline at end of file diff --git a/larray/tests/test_excel.py b/larray/tests/test_excel.py index 0f04eeafa..eb6be5fe7 100644 --- a/larray/tests/test_excel.py +++ b/larray/tests/test_excel.py @@ -11,7 +11,7 @@ xw = None from larray import ndtest, larray_equal, open_excel, aslarray, Axis -from larray.inout import excel +from larray.inout import xw_excel @pytest.mark.skipif(xw is None, reason="xlwings is not available") @@ -26,7 +26,7 @@ def test_open_excel(self): wb1.sheet_names() wb2 = open_excel(visible=False) app2 = wb2.app - assert app1 == app2 == excel.global_app + assert app1 == app2 == xw_excel.global_app # this effectively close all workbooks but leaves the instance intact (this is probably due to us keeping a # reference to it). 
app1.quit() @@ -40,18 +40,18 @@ def test_open_excel(self): def test_repr(self): with open_excel(visible=False) as wb: - assert re.match('', repr(wb)) + assert re.match('', repr(wb)) def test_getitem(self): with open_excel(visible=False) as wb: sheet = wb[0] - assert isinstance(sheet, excel.Sheet) + assert isinstance(sheet, xw_excel.Sheet) # this might not be true on non-English locale assert sheet.name == 'Sheet1' # this might not work on non-English locale sheet = wb['Sheet1'] - assert isinstance(sheet, excel.Sheet) + assert isinstance(sheet, xw_excel.Sheet) assert sheet.name == 'Sheet1' with pytest.raises(KeyError) as e_info: @@ -166,7 +166,7 @@ def test_array_method(self): def test_repr(self): with open_excel(visible=False) as wb: sheet = wb[0] - assert re.match('', repr(sheet)) + assert re.match('', repr(sheet)) @pytest.mark.skipif(xw is None, reason="xlwings is not available") diff --git a/larray/tests/test_group.py b/larray/tests/test_group.py index bbf6f2f14..c01cd0e76 100644 --- a/larray/tests/test_group.py +++ b/larray/tests/test_group.py @@ -1,294 +1,417 @@ from __future__ import absolute_import, division, print_function - -from unittest import TestCase - import pytest +import os.path import numpy as np from larray.tests.common import assert_array_equal -from larray import Axis, LGroup, LSet, ndtest - - -class TestLGroup(TestCase): - def setUp(self): - self.age = Axis('age=0..10') - self.lipro = Axis('lipro=P01..P05') - self.anonymous = Axis(range(3)) - - self.slice_both_named_wh_named_axis = LGroup('1:5', "full", self.age) - self.slice_both_named = LGroup('1:5', "named") - self.slice_both = LGroup('1:5') - self.slice_start = LGroup('1:') - self.slice_stop = LGroup(':5') - self.slice_none_no_axis = LGroup(':') - self.slice_none_wh_named_axis = LGroup(':', axis=self.lipro) - self.slice_none_wh_anonymous_axis = LGroup(':', axis=self.anonymous) - - self.single_value = LGroup('P03') - self.list = LGroup('P01,P03,P04') - self.list_named = LGroup('P01,P03,P04', 
"P134") - - def test_init(self): - self.assertEqual(self.slice_both_named_wh_named_axis.name, "full") - self.assertEqual(self.slice_both_named_wh_named_axis.key, slice(1, 5, None)) - self.assertIs(self.slice_both_named_wh_named_axis.axis, self.age) - - self.assertEqual(self.slice_both_named.name, "named") - self.assertEqual(self.slice_both_named.key, slice(1, 5, None)) - - self.assertEqual(self.slice_both.key, slice(1, 5, None)) - self.assertEqual(self.slice_start.key, slice(1, None, None)) - self.assertEqual(self.slice_stop.key, slice(None, 5, None)) - self.assertEqual(self.slice_none_no_axis.key, slice(None, None, None)) - self.assertIs(self.slice_none_wh_named_axis.axis, self.lipro) - self.assertIs(self.slice_none_wh_anonymous_axis.axis, self.anonymous) - - self.assertEqual(self.single_value.key, 'P03') - self.assertEqual(self.list.key, ['P01', 'P03', 'P04']) - - # passing an axis as name - group = LGroup('1:5', self.age, self.age) - assert group.name == self.age.name - group = self.age['1:5'] >> self.age - assert group.name == self.age.name - # passing an unnamed group as name - group2 = LGroup('1', axis=self.age) - group = LGroup('1', group2, axis=self.age) - assert group.name == '1' - group = self.age['1'] >> group2 - assert group.name == '1' - # passing a named group as name - group2 = LGroup('1:5', 'age', self.age) - group = LGroup('1:5', group2, axis=self.age) - assert group.name == group2.name - group = self.age['1:5'] >> group2 - assert group.name == group2.name - # additional test - axis = Axis('axis=a,a0..a3,b,b0..b3,c,c0..c3') - for code in axis.matching('^.$'): - group = axis.startingwith(code) >> code - assert group == axis.startingwith(code) >> str(code) - - def test_eq(self): - # with axis vs no axis do not compare equal - # self.assertEqual(self.slice_both, self.slice_both_named_wh_named_axis) - self.assertEqual(self.slice_both, self.slice_both_named) - - res = self.slice_both_named_wh_named_axis == self.age[1:5] - self.assertIsInstance(res, 
np.ndarray) - self.assertEqual(res.shape, (5,)) - self.assertTrue(res.all()) - - self.assertEqual(self.slice_both, LGroup(slice(1, 5))) - self.assertEqual(self.slice_start, LGroup(slice(1, None))) - self.assertEqual(self.slice_stop, LGroup(slice(5))) - self.assertEqual(self.slice_none_no_axis, LGroup(slice(None))) - - self.assertEqual(self.single_value, LGroup('P03')) - self.assertEqual(self.list, LGroup(['P01', 'P03', 'P04'])) - self.assertEqual(self.list_named, LGroup(['P01', 'P03', 'P04'])) - - # test with raw objects - self.assertEqual(self.slice_both, slice(1, 5)) - self.assertEqual(self.slice_start, slice(1, None)) - self.assertEqual(self.slice_stop, slice(5)) - self.assertEqual(self.slice_none_no_axis, slice(None)) - - self.assertEqual(self.single_value, 'P03') - self.assertEqual(self.list, ['P01', 'P03', 'P04']) - self.assertEqual(self.list_named, ['P01', 'P03', 'P04']) - - def test_getitem(self): - axis = Axis("a=a0,a1") - assert axis['a0'][0] == 'a' - assert axis['a0'][1] == '0' - assert axis['a0':'a1'][1] == 'a1' - assert axis[:][1] == 'a1' - assert list(axis[:][0:2]) == ['a0', 'a1'] - assert list((axis[:][[1, 0]])) == ['a1', 'a0'] - assert axis[['a0', 'a1', 'a0']][2] == 'a0' - assert axis[('a0', 'a1', 'a0')][2] == 'a0' - assert axis[ndtest("a=a0,a1,a0")][2] == 2 - - def test_sorted(self): - self.assertEqual(sorted(LGroup(['c', 'd', 'a', 'b'])), - [LGroup('a'), LGroup('b'), LGroup('c'), LGroup('d')]) - - def test_asarray(self): - assert_array_equal(np.asarray(self.slice_both_named_wh_named_axis), np.array([1, 2, 3, 4, 5])) - assert_array_equal(np.asarray(self.slice_none_wh_named_axis), np.array(['P01', 'P02', 'P03', 'P04', 'P05'])) - - def test_hash(self): - # this test is a lot less important than what it used to, because we cannot have Group ticks on an axis anymore - d = {self.slice_both: 1, - self.single_value: 2, - self.list_named: 3} - # target a LGroup with an equivalent LGroup - self.assertEqual(d.get(self.slice_both), 1) - 
self.assertEqual(d.get(self.single_value), 2) - self.assertEqual(d.get(self.list), 3) - self.assertEqual(d.get(self.list_named), 3) - - def test_repr(self): - self.assertEqual(repr(self.slice_both_named_wh_named_axis), "age[1:5] >> 'full'") - self.assertEqual(repr(self.slice_both_named), "LGroup(slice(1, 5, None)) >> 'named'") - self.assertEqual(repr(self.slice_both), "LGroup(slice(1, 5, None))") - self.assertEqual(repr(self.list), "LGroup(['P01', 'P03', 'P04'])") - self.assertEqual(repr(self.slice_none_no_axis), "LGroup(slice(None, None, None))") - self.assertEqual(repr(self.slice_none_wh_named_axis), "lipro[:]") - self.assertEqual(repr(self.slice_none_wh_anonymous_axis), - "LGroup(slice(None, None, None), axis=Axis([0, 1, 2], None))") - - def test_to_int(self): - a = Axis(['42'], 'a') - self.assertEqual(int(a['42']), 42) - - def test_to_float(self): - a = Axis(['42'], 'a') - self.assertEqual(float(a['42']), 42.0) - - -class TestLSet(TestCase): - def test_or(self): - # without axis - self.assertEqual(LSet(['a', 'b']) | LSet(['c', 'd']), - LSet(['a', 'b', 'c', 'd'])) - self.assertEqual(LSet(['a', 'b', 'c']) | LSet(['c', 'd']), - LSet(['a', 'b', 'c', 'd'])) - # with axis (pure) - alpha = Axis('alpha=a,b,c,d') - res = alpha['a', 'b'].set() | alpha['c', 'd'].set() - self.assertIs(res.axis, alpha) - self.assertEqual(res, alpha['a', 'b', 'c', 'd'].set()) - self.assertEqual(alpha['a', 'b', 'c'].set() | alpha['c', 'd'].set(), - alpha['a', 'b', 'c', 'd'].set()) - - # with axis (mixed) - alpha = Axis('alpha=a,b,c,d') - res = alpha['a', 'b'].set() | alpha['c', 'd'] - self.assertIs(res.axis, alpha) - self.assertEqual(res, alpha['a', 'b', 'c', 'd'].set()) - self.assertEqual(alpha['a', 'b', 'c'].set() | alpha['c', 'd'], - alpha['a', 'b', 'c', 'd'].set()) - - # with axis & name - alpha = Axis('alpha=a,b,c,d') - res = alpha['a', 'b'].set().named('ab') | alpha['c', 'd'].set().named('cd') - self.assertIs(res.axis, alpha) - self.assertEqual(res.name, 'ab | cd') - 
self.assertEqual(res, alpha['a', 'b', 'c', 'd'].set()) - self.assertEqual(alpha['a', 'b', 'c'].set() | alpha['c', 'd'], - alpha['a', 'b', 'c', 'd'].set()) - - # numeric axis - num = Axis(range(10), 'num') - # single int - self.assertEqual(num[1, 5, 3].set() | 4, num[1, 5, 3, 4].set()) - self.assertEqual(num[1, 5, 3].set() | num[4], num[1, 5, 3, 4].set()) - self.assertEqual(num[4].set() | num[1, 5, 3], num[4, 1, 5, 3].set()) - # slices - self.assertEqual(num[:2].set() | num[8:], num[0, 1, 2, 8, 9].set()) - self.assertEqual(num[:2].set() | num[5], num[0, 1, 2, 5].set()) - - def test_and(self): - # without axis - self.assertEqual(LSet(['a', 'b', 'c']) & LSet(['c', 'd']), LSet(['c'])) - # with axis & name - alpha = Axis('alpha=a,b,c,d') - res = alpha['a', 'b', 'c'].named('abc').set() & alpha['c', 'd'].named('cd') - self.assertIs(res.axis, alpha) - self.assertEqual(res.name, 'abc & cd') - self.assertEqual(res, alpha[['c']].set()) - - def test_sub(self): - self.assertEqual(LSet(['a', 'b', 'c']) - LSet(['c', 'd']), LSet(['a', 'b'])) - self.assertEqual(LSet(['a', 'b', 'c']) - ['c', 'd'], LSet(['a', 'b'])) - self.assertEqual(LSet(['a', 'b', 'c']) - 'b', LSet(['a', 'c'])) - self.assertEqual(LSet([1, 2, 3]) - 4, LSet([1, 2, 3])) - self.assertEqual(LSet([1, 2, 3]) - 2, LSet([1, 3])) - - -class TestIGroup(TestCase): - def _assert_array_equal_is_true_array(self, a, b): - res = a == b - self.assertIsInstance(res, np.ndarray) - self.assertEqual(res.shape, np.asarray(b).shape) - self.assertTrue(res.all()) - - def setUp(self): - self.code_axis = Axis('code=a0..a4') - - self.slice_both_named = self.code_axis.i[1:4] >> 'a123' - self.slice_both = self.code_axis.i[1:4] - self.slice_start = self.code_axis.i[1:] - self.slice_stop = self.code_axis.i[:4] - self.slice_none = self.code_axis.i[:] - - self.first_value = self.code_axis.i[0] - self.last_value = self.code_axis.i[-1] - self.list = self.code_axis.i[[0, 1, -2, -1]] - self.tuple = self.code_axis.i[0, 1, -2, -1] - - def 
test_asarray(self): - assert_array_equal(np.asarray(self.slice_both), np.array(['a1', 'a2', 'a3'])) - - def test_eq(self): - self._assert_array_equal_is_true_array(self.slice_both, ['a1', 'a2', 'a3']) - self._assert_array_equal_is_true_array(self.slice_both_named, ['a1', 'a2', 'a3']) - self._assert_array_equal_is_true_array(self.slice_both, self.slice_both_named) - self._assert_array_equal_is_true_array(self.slice_both_named, self.slice_both) - self._assert_array_equal_is_true_array(self.slice_start, ['a1', 'a2', 'a3', 'a4']) - self._assert_array_equal_is_true_array(self.slice_stop, ['a0', 'a1', 'a2', 'a3']) - self._assert_array_equal_is_true_array(self.slice_none, ['a0', 'a1', 'a2', 'a3', 'a4']) - - self.assertEqual(self.first_value, 'a0') - self.assertEqual(self.last_value, 'a4') - - self._assert_array_equal_is_true_array(self.list, ['a0', 'a1', 'a3', 'a4']) - self._assert_array_equal_is_true_array(self.tuple, ['a0', 'a1', 'a3', 'a4']) - - def test_getitem(self): - axis = Axis("a=a0,a1") - assert axis.i[0][0] == 'a' - assert axis.i[0][1] == '0' - assert axis.i[0:1][1] == 'a1' - assert axis.i[:][1] == 'a1' - assert list(axis.i[:][0:2]) == ['a0', 'a1'] - assert list((axis.i[:][[1, 0]])) == ['a1', 'a0'] - assert axis.i[[0, 1, 0]][2] == 'a0' - assert axis.i[(0, 1, 0)][2] == 'a0' - - def test_getattr(self): - agg = Axis(['a1:a2', ':a2', 'a1:'], 'agg') - self.assertEqual(agg.i[0].split(':'), ['a1', 'a2']) - self.assertEqual(agg.i[1].split(':'), ['', 'a2']) - self.assertEqual(agg.i[2].split(':'), ['a1', '']) - - def test_dir(self): - agg = Axis(['a', 1], 'agg') - self.assertTrue('split' in dir(agg.i[0])) - self.assertTrue('strip' in dir(agg.i[0])) - self.assertTrue('strip' in dir(agg.i[0])) - - def test_repr(self): - self.assertEqual(repr(self.slice_both_named), "code.i[1:4] >> 'a123'") - self.assertEqual(repr(self.slice_both), "code.i[1:4]") - self.assertEqual(repr(self.slice_start), "code.i[1:]") - self.assertEqual(repr(self.slice_stop), "code.i[:4]") - 
self.assertEqual(repr(self.slice_none), "code.i[:]") - self.assertEqual(repr(self.first_value), "code.i[0]") - self.assertEqual(repr(self.last_value), "code.i[-1]") - self.assertEqual(repr(self.list), "code.i[0, 1, -2, -1]") - self.assertEqual(repr(self.tuple), "code.i[0, 1, -2, -1]") - - def test_to_int(self): - a = Axis(['42'], 'a') - self.assertEqual(int(a.i[0]), 42) - - def test_to_float(self): - a = Axis(['42'], 'a') - self.assertEqual(float(a.i[0]), 42.0) - +from larray import Axis, LGroup, LSet, ndtest, read_hdf + + +age = Axis('age=0..10') +lipro = Axis('lipro=P01..P05') +anonymous = Axis(range(3)) +age_wildcard = Axis(10, 'wildcard') + + +# ################## # +# LGroup # +# ################## # + +@pytest.fixture +def lgroups(): + class TestLGroup(): + def __init__(self): + self.slice_both_named_wh_named_axis = LGroup('1:5', "full", age) + self.slice_both_named = LGroup('1:5', "named") + self.slice_both = LGroup('1:5') + self.slice_start = LGroup('1:') + self.slice_stop = LGroup(':5') + self.slice_none_no_axis = LGroup(':') + self.slice_none_wh_named_axis = LGroup(':', axis=lipro) + self.slice_none_wh_anonymous_axis = LGroup(':', axis=anonymous) + self.single_value = LGroup('P03') + self.list = LGroup('P01,P03,P04') + self.list_named = LGroup('P01,P03,P04', "P134") + return TestLGroup() + + +def test_init_lgroup(lgroups): + assert lgroups.slice_both_named_wh_named_axis.name == "full" + assert lgroups.slice_both_named_wh_named_axis.key == slice(1, 5, None) + assert lgroups.slice_both_named_wh_named_axis.axis is age + assert lgroups.slice_both_named.name == "named" + assert lgroups.slice_both_named.key == slice(1, 5, None) + assert lgroups.slice_both.key == slice(1, 5, None) + assert lgroups.slice_start.key == slice(1, None, None) + assert lgroups.slice_stop.key == slice(None, 5, None) + assert lgroups.slice_none_no_axis.key == slice(None, None, None) + assert lgroups.slice_none_wh_named_axis.axis is lipro + assert lgroups.slice_none_wh_anonymous_axis.axis 
is anonymous + assert lgroups.single_value.key == 'P03' + assert lgroups.list.key == ['P01', 'P03', 'P04'] + group = LGroup('1:5', age, age) + assert group.name == age.name + group = age['1:5'] >> age + assert group.name == age.name + group2 = LGroup('1', axis=age) + group = LGroup('1', group2, axis=age) + assert group.name == '1' + group = age['1'] >> group2 + assert group.name == '1' + group2 = LGroup('1:5', 'age', age) + group = LGroup('1:5', group2, axis=age) + assert group.name == group2.name + group = age['1:5'] >> group2 + assert group.name == group2.name + axis = Axis('axis=a,a0..a3,b,b0..b3,c,c0..c3') + for code in axis.matching('^.$'): + group = axis.startingwith(code) >> code + assert group == axis.startingwith(code) >> str(code) + +def test_eq_lgroup(lgroups): + # with axis vs no axis do not compare equal + # lgroups.slice_both == lgroups.slice_both_named_wh_named_axis + assert lgroups.slice_both == lgroups.slice_both_named + res = lgroups.slice_both_named_wh_named_axis == age[1:5] + assert isinstance(res, np.ndarray) + assert res.shape == (5,) + assert res.all() + assert lgroups.slice_both == LGroup(slice(1, 5)) + assert lgroups.slice_start == LGroup(slice(1, None)) + assert lgroups.slice_stop == LGroup(slice(5)) + assert lgroups.slice_none_no_axis == LGroup(slice(None)) + assert lgroups.single_value == LGroup('P03') + assert lgroups.list == LGroup(['P01', 'P03', 'P04']) + assert lgroups.list_named == LGroup(['P01', 'P03', 'P04']) + assert lgroups.slice_both == slice(1, 5) + assert lgroups.slice_start == slice(1, None) + assert lgroups.slice_stop == slice(5) + assert lgroups.slice_none_no_axis == slice(None) + assert lgroups.single_value == 'P03' + assert lgroups.list == ['P01', 'P03', 'P04'] + assert lgroups.list_named == ['P01', 'P03', 'P04'] + +def test_getitem_lgroup(): + axis = Axis("a=a0,a1") + assert axis['a0'][0] == 'a' + assert axis['a0'][1] == '0' + assert axis['a0':'a1'][1] == 'a1' + assert axis[:][1] == 'a1' + assert list(axis[:][0:2]) == 
['a0', 'a1'] + assert list((axis[:][[1, 0]])) == ['a1', 'a0'] + assert axis[['a0', 'a1', 'a0']][2] == 'a0' + assert axis[('a0', 'a1', 'a0')][2] == 'a0' + assert axis[ndtest("a=a0,a1,a0")][2] == 2 + +def test_sorted_lgroup(): + assert sorted(LGroup(['c', 'd', 'a', 'b'])) == [LGroup('a'), LGroup('b'), LGroup('c'), LGroup('d')] + +def test_asarray_lgroup(lgroups): + assert_array_equal(np.asarray(lgroups.slice_both_named_wh_named_axis), np.array([1, 2, 3, 4, 5])) + assert_array_equal(np.asarray(lgroups.slice_none_wh_named_axis), np.array(['P01', 'P02', 'P03', 'P04', 'P05'])) + +def test_hash_lgroup(lgroups): + # this test is a lot less important than what it used to, because we cannot have Group ticks on an axis anymore + d = {lgroups.slice_both: 1, lgroups.single_value: 2, lgroups.list_named: 3} + assert d.get(lgroups.slice_both) == 1 + assert d.get(lgroups.single_value) == 2 + assert d.get(lgroups.list) == 3 + assert d.get(lgroups.list_named) == 3 + +def test_repr_lgroup(lgroups): + assert repr(lgroups.slice_both_named_wh_named_axis) == "age[1:5] >> 'full'" + assert repr(lgroups.slice_both_named) == "LGroup(slice(1, 5, None)) >> 'named'" + assert repr(lgroups.slice_both) == "LGroup(slice(1, 5, None))" + assert repr(lgroups.list) == "LGroup(['P01', 'P03', 'P04'])" + assert repr(lgroups.slice_none_no_axis) == "LGroup(slice(None, None, None))" + assert repr(lgroups.slice_none_wh_named_axis) == "lipro[:]" + assert repr(lgroups.slice_none_wh_anonymous_axis) == "LGroup(slice(None, None, None), axis=Axis([0, 1, 2], None))" + +def test_to_int_lgroup(): + a = Axis(['42'], 'a') + assert int(a['42']) == 42 + +def test_to_float_lgroup(): + a = Axis(['42'], 'a') + assert float(a['42']) == 42.0 + +def test_h5_io_lgroup(tmpdir): + fpath = os.path.join(str(tmpdir), 'lgroups.h5') + age.to_hdf(fpath) + + named = age[':5'] >> 'age_05' + named_axis_not_in_file = lipro['P01,P03,P05'] >> 'P_odd' + anonymous = age[':5'] + wildcard = age_wildcard[':5'] >> 'age_w_05' + + # ---- default 
behavior ---- + # named group + named.to_hdf(fpath) + named2 = read_hdf(fpath, key=named.name) + assert all(named == named2) + # anonymous group + with pytest.raises(ValueError, message="Argument key must be provided explicitly in case of anonymous axis"): + anonymous.to_hdf(fpath) + # wildcard group + wildcard.to_hdf(fpath) + wildcard2 = read_hdf(fpath, key=wildcard.name) + assert all(wildcard == wildcard2) + # associated axis not saved yet + named_axis_not_in_file.to_hdf(fpath) + named2 = read_hdf(fpath, key=named_axis_not_in_file.name) + assert all(named_axis_not_in_file == named2) + + # ---- specific hdf group + key ---- + hdf_group = 'my_groups' + # named group + key = hdf_group + '/named_group' + named.to_hdf(fpath, key) + named2 = read_hdf(fpath, key=key) + assert all(named == named2) + # anonymous group + key = hdf_group + '/anonymous_group' + anonymous.to_hdf(fpath, key) + anonymous2 = read_hdf(fpath, key=key) + assert anonymous2.name is None + assert all(anonymous == anonymous2) + # wildcard group + key = hdf_group + '/wildcard_group' + wildcard.to_hdf(fpath, key) + wildcard2 = read_hdf(fpath, key=key) + assert all(wildcard == wildcard2) + # associated axis not saved yet + key = hdf_group + '/named_group_axis_not_in_file' + named_axis_not_in_file.to_hdf(fpath, key=key) + named2 = read_hdf(fpath, key=key) + assert all(named_axis_not_in_file == named2) + + # ---- specific axis_key ---- + axis_key = 'axes/associated_axis_0' + # named group + named.to_hdf(fpath, axis_key=axis_key) + named2 = read_hdf(fpath, key=named.name) + assert all(named == named2) + # anonymous group + key = 'anonymous' + anonymous.to_hdf(fpath, key=key, axis_key=axis_key) + anonymous2 = read_hdf(fpath, key=key) + assert anonymous2.name is None + assert all(anonymous == anonymous2) + # wildcard group + wildcard.to_hdf(fpath, axis_key=axis_key) + wildcard2 = read_hdf(fpath, key=wildcard.name) + assert all(wildcard == wildcard2) + # associated axis not saved yet + axis_key = 
'axes/associated_axis_1' + named_axis_not_in_file.to_hdf(fpath, axis_key=axis_key) + named2 = read_hdf(fpath, key=named_axis_not_in_file.name) + assert all(named_axis_not_in_file == named2) + + +# ################## # +# LSet # +# ################## # + +def test_or_lset(): + # without axis + assert LSet(['a', 'b']) | LSet(['c', 'd']) == LSet(['a', 'b', 'c', 'd']) + assert LSet(['a', 'b', 'c']) | LSet(['c', 'd']) == LSet(['a', 'b', 'c', 'd']) + alpha = Axis('alpha=a,b,c,d') + res = alpha['a', 'b'].set() | alpha['c', 'd'].set() + assert res.axis is alpha + assert res == alpha['a', 'b', 'c', 'd'].set() + assert alpha['a', 'b', 'c'].set() | alpha['c', 'd'].set() == alpha['a', 'b', 'c', 'd'].set() + alpha = Axis('alpha=a,b,c,d') + res = alpha['a', 'b'].set() | alpha['c', 'd'] + assert res.axis is alpha + assert res == alpha['a', 'b', 'c', 'd'].set() + assert alpha['a', 'b', 'c'].set() | alpha['c', 'd'] == alpha['a', 'b', 'c', 'd'].set() + alpha = Axis('alpha=a,b,c,d') + res = alpha['a', 'b'].set().named('ab') | alpha['c', 'd'].set().named('cd') + assert res.axis is alpha + assert res.name == 'ab | cd' + assert res == alpha['a', 'b', 'c', 'd'].set() + assert alpha['a', 'b', 'c'].set() | alpha['c', 'd'] == alpha['a', 'b', 'c', 'd'].set() + num = Axis(range(10), 'num') + assert num[1, 5, 3].set() | 4 == num[1, 5, 3, 4].set() + assert num[1, 5, 3].set() | num[4] == num[1, 5, 3, 4].set() + assert num[4].set() | num[1, 5, 3] == num[4, 1, 5, 3].set() + assert num[:2].set() | num[8:] == num[0, 1, 2, 8, 9].set() + assert num[:2].set() | num[5] == num[0, 1, 2, 5].set() + +def test_and_lset(): + # without axis + assert LSet(['a', 'b', 'c']) & LSet(['c', 'd']) == LSet(['c']) + alpha = Axis('alpha=a,b,c,d') + res = alpha['a', 'b', 'c'].named('abc').set() & alpha['c', 'd'].named('cd') + assert res.axis is alpha + assert res.name == 'abc & cd' + assert res == alpha[['c']].set() + +def test_sub_lset(): + assert LSet(['a', 'b', 'c']) - LSet(['c', 'd']) == LSet(['a', 'b']) + assert 
LSet(['a', 'b', 'c']) - ['c', 'd'] == LSet(['a', 'b']) + assert LSet(['a', 'b', 'c']) - 'b' == LSet(['a', 'c']) + assert LSet([1, 2, 3]) - 4 == LSet([1, 2, 3]) + assert LSet([1, 2, 3]) - 2 == LSet([1, 3]) + + +# ################## # +# IGroup # +# ################## # + +@pytest.fixture +def igroups(): + class TestIGroup(): + def __init__(self): + self.code_axis = Axis('code=a0..a4') + self.slice_both_named = self.code_axis.i[1:4] >> 'a123' + self.slice_both = self.code_axis.i[1:4] + self.slice_start = self.code_axis.i[1:] + self.slice_stop = self.code_axis.i[:4] + self.slice_none = self.code_axis.i[:] + self.first_value = self.code_axis.i[0] + self.last_value = self.code_axis.i[-1] + self.list = self.code_axis.i[[0, 1, -2, -1]] + self.tuple = self.code_axis.i[0, 1, -2, -1] + return TestIGroup() + +def _assert_array_equal_is_true_array(a, b): + res = a == b + assert isinstance(res, np.ndarray) + assert res.shape == np.asarray(b).shape + assert res.all() + + +def test_asarray_igroup(igroups): + assert_array_equal(np.asarray(igroups.slice_both), np.array(['a1', 'a2', 'a3'])) + +def test_eq_igroup(igroups): + _assert_array_equal_is_true_array(igroups.slice_both, ['a1', 'a2', 'a3']) + _assert_array_equal_is_true_array(igroups.slice_both_named, ['a1', 'a2', 'a3']) + _assert_array_equal_is_true_array(igroups.slice_both, igroups.slice_both_named) + _assert_array_equal_is_true_array(igroups.slice_both_named, igroups.slice_both) + _assert_array_equal_is_true_array(igroups.slice_start, ['a1', 'a2', 'a3', 'a4']) + _assert_array_equal_is_true_array(igroups.slice_stop, ['a0', 'a1', 'a2', 'a3']) + _assert_array_equal_is_true_array(igroups.slice_none, ['a0', 'a1', 'a2', 'a3', 'a4']) + assert igroups.first_value == 'a0' + assert igroups.last_value == 'a4' + _assert_array_equal_is_true_array(igroups.list, ['a0', 'a1', 'a3', 'a4']) + _assert_array_equal_is_true_array(igroups.tuple, ['a0', 'a1', 'a3', 'a4']) + +def test_getitem_igroup(): + axis = Axis("a=a0,a1") + assert axis.i[0][0] 
== 'a' + assert axis.i[0][1] == '0' + assert axis.i[0:1][1] == 'a1' + assert axis.i[:][1] == 'a1' + assert list(axis.i[:][0:2]) == ['a0', 'a1'] + assert list((axis.i[:][[1, 0]])) == ['a1', 'a0'] + assert axis.i[[0, 1, 0]][2] == 'a0' + assert axis.i[(0, 1, 0)][2] == 'a0' + +def test_getattr_igroup(): + agg = Axis(['a1:a2', ':a2', 'a1:'], 'agg') + assert agg.i[0].split(':') == ['a1', 'a2'] + assert agg.i[1].split(':') == ['', 'a2'] + assert agg.i[2].split(':') == ['a1', ''] + +def test_dir_igroup(): + agg = Axis(['a', 1], 'agg') + assert 'split' in dir(agg.i[0]) + assert 'strip' in dir(agg.i[0]) + assert 'strip' in dir(agg.i[0]) + +def test_repr_igroup(igroups): + assert repr(igroups.slice_both_named) == "code.i[1:4] >> 'a123'" + assert repr(igroups.slice_both) == "code.i[1:4]" + assert repr(igroups.slice_start) == "code.i[1:]" + assert repr(igroups.slice_stop) == "code.i[:4]" + assert repr(igroups.slice_none) == "code.i[:]" + assert repr(igroups.first_value) == "code.i[0]" + assert repr(igroups.last_value) == "code.i[-1]" + assert repr(igroups.list) == "code.i[0, 1, -2, -1]" + assert repr(igroups.tuple) == "code.i[0, 1, -2, -1]" + +def test_to_int_igroup(): + a = Axis(['42'], 'a') + assert int(a.i[0]) == 42 + +def test_to_float_igroup(): + a = Axis(['42'], 'a') + assert float(a.i[0]) == 42.0 + +def test_h5_io_igroup(tmpdir): + fpath = os.path.join(str(tmpdir), 'igroups.h5') + age.to_hdf(fpath) + + named = age.i[:6] >> 'age_05' + named_axis_not_in_file = lipro.i[1::2] >> 'P_odd' + anonymous = age.i[:6] + wildcard = age_wildcard.i[:6] >> 'age_w_05' + + # ---- default behavior ---- + # named group + named.to_hdf(fpath) + named2 = read_hdf(fpath, key=named.name) + assert all(named == named2) + # anonymous group + with pytest.raises(ValueError, message="Argument key must be provided explicitly in case of anonymous axis"): + anonymous.to_hdf(fpath) + # wildcard group + wildcard.to_hdf(fpath) + wildcard2 = read_hdf(fpath, key=wildcard.name) + assert all(wildcard == 
wildcard2) + # associated axis not saved yet + named_axis_not_in_file.to_hdf(fpath) + named2 = read_hdf(fpath, key=named_axis_not_in_file.name) + assert all(named_axis_not_in_file == named2) + + # ---- specific hdf group + key ---- + hdf_group = 'my_groups' + # named group + key = hdf_group + '/named_group' + named.to_hdf(fpath, key) + named2 = read_hdf(fpath, key=key) + assert all(named == named2) + # anonymous group + key = hdf_group + '/anonymous_group' + anonymous.to_hdf(fpath, key) + anonymous2 = read_hdf(fpath, key=key) + assert anonymous2.name is None + assert all(anonymous == anonymous2) + # wildcard group + key = hdf_group + '/wildcard_group' + wildcard.to_hdf(fpath, key) + wildcard2 = read_hdf(fpath, key=key) + assert all(wildcard == wildcard2) + # associated axis not saved yet + key = hdf_group + '/named_group_axis_not_in_file' + named_axis_not_in_file.to_hdf(fpath, key=key) + named2 = read_hdf(fpath, key=key) + assert all(named_axis_not_in_file == named2) + + # ---- specific axis_key ---- + axis_key = 'axes/associated_axis_0' + # named group + named.to_hdf(fpath, axis_key=axis_key) + named2 = read_hdf(fpath, key=named.name) + assert all(named == named2) + # anonymous group + key = 'anonymous' + anonymous.to_hdf(fpath, key=key, axis_key=axis_key) + anonymous2 = read_hdf(fpath, key=key) + assert anonymous2.name is None + assert all(anonymous == anonymous2) + # wildcard group + wildcard.to_hdf(fpath, axis_key=axis_key) + wildcard2 = read_hdf(fpath, key=wildcard.name) + assert all(wildcard == wildcard2) + # associated axis not saved yet + axis_key = 'axes/associated_axis_1' + named_axis_not_in_file.to_hdf(fpath, axis_key=axis_key) + named2 = read_hdf(fpath, key=named_axis_not_in_file.name) + assert all(named_axis_not_in_file == named2) if __name__ == "__main__": pytest.main() diff --git a/larray/tests/test_session.py b/larray/tests/test_session.py index ebeb62e14..29109486e 100644 --- a/larray/tests/test_session.py +++ b/larray/tests/test_session.py @@ -9,7 
+9,7 @@ import pytest from larray.tests.common import assert_array_nan_equal, inputpath -from larray import (Session, Axis, LArray, isnan, larray_equal, zeros_like, ndtest, ones_like, +from larray import (Session, Axis, LArray, Group, isnan, zeros_like, ndtest, ones_like, local_arrays, global_arrays, arrays) from larray.util.misc import pickle @@ -27,14 +27,17 @@ def equal(o1, o2): else: return o1 == o2 + global_arr1 = ndtest((2, 2)) _global_arr2 = ndtest((3, 3)) class TestSession(TestCase): def setUp(self): - self.a = Axis([], 'a') - self.b = Axis([], 'b') + self.a = Axis('a=a0..a2') + self.a01 = self.a['a0,a1'] >> 'a01' + self.b = Axis('b=b0..b2') + self.b12 = self.b['b1,b2'] >> 'b12' self.c = 'c' self.d = {} self.e = ndtest([(2, 'a0'), (3, 'a1')]) @@ -42,9 +45,8 @@ def setUp(self): self.f = ndtest([(3, 'a0'), (2, 'a1')]) self.g = ndtest([(2, 'a0'), (4, 'a1')]) self.session = Session([ - ('b', self.b), ('a', self.a), - ('c', self.c), ('d', self.d), - ('e', self.e), ('g', self.g), ('f', self.f), + ('b', self.b), ('b12', self.b12), ('a', self.a), ('a01', self.a01), + ('c', self.c), ('d', self.d), ('e', self.e), ('g', self.g), ('f', self.f), ]) @pytest.fixture(autouse=True) @@ -52,20 +54,19 @@ def output_dir(self, tmpdir_factory): self.tmpdir = tmpdir_factory.mktemp('tmp_session').strpath def get_path(self, fname): - return os.path.join(self.tmpdir, fname) + return os.path.join(str(self.tmpdir), fname) def assertObjListEqual(self, got, expected): - self.assertEqual(len(got), len(expected)) + assert len(got) == len(expected) for e1, e2 in zip(got, expected): - self.assertTrue(equal(e1, e2), "{} != {}".format(e1, e2)) + assert equal(e1, e2), "{} != {}".format(e1, e2) def test_init(self): - s = Session(self.b, self.a, c=self.c, d=self.d, - e=self.e, f=self.f, g=self.g) - self.assertEqual(s.names, ['a', 'b', 'c', 'd', 'e', 'f', 'g']) + s = Session(self.b, self.b12, self.a, self.a01, c=self.c, d=self.d, e=self.e, f=self.f, g=self.g) + assert s.names == ['a', 'a01', 'b', 
'b12', 'c', 'd', 'e', 'f', 'g'] s = Session(inputpath('test_session.h5')) - self.assertEqual(s.names, ['e', 'f', 'g']) + assert s.names == ['e', 'f', 'g'] # this needs xlwings installed # s = Session('test_session_ef.xlsx') @@ -77,18 +78,22 @@ def test_init(self): def test_getitem(self): s = self.session - self.assertIs(s['a'], self.a) - self.assertIs(s['b'], self.b) - self.assertEqual(s['c'], 'c') - self.assertEqual(s['d'], {}) + assert s['a'] is self.a + assert s['b'] is self.b + assert s['a01'] is self.a01 + assert s['b12'] is self.b12 + assert s['c'] == 'c' + assert s['d'] == {} def test_getitem_list(self): s = self.session - self.assertEqual(list(s[[]]), []) - self.assertEqual(list(s[['b', 'a']]), [self.b, self.a]) - self.assertEqual(list(s[['a', 'b']]), [self.a, self.b]) - self.assertEqual(list(s[['a', 'e', 'g']]), [self.a, self.e, self.g]) - self.assertEqual(list(s[['g', 'a', 'e']]), [self.g, self.a, self.e]) + assert list(s[[]]) == [] + assert list(s[['b', 'a']]) == [self.b, self.a] + assert list(s[['a', 'b']]) == [self.a, self.b] + assert list(s[['b12', 'a']]) == [self.b12, self.a] + assert list(s[['e', 'a01']]) == [self.e, self.a01] + assert list(s[['a', 'e', 'g']]) == [self.a, self.e, self.g] + assert list(s[['g', 'a', 'e']]) == [self.g, self.a, self.e] def test_getitem_larray(self): s1 = self.session.filter(kind=LArray) @@ -101,51 +106,58 @@ def test_getitem_larray(self): def test_setitem(self): s = self.session s['g'] = 'g' - self.assertEqual(s['g'], 'g') + assert s['g'] == 'g' def test_getattr(self): s = self.session - self.assertIs(s.a, self.a) - self.assertIs(s.b, self.b) - self.assertEqual(s.c, 'c') - self.assertEqual(s.d, {}) + assert s.a is self.a + assert s.b is self.b + assert s.a01 is self.a01 + assert s.b12 is self.b12 + assert s.c == 'c' + assert s.d == {} def test_setattr(self): s = self.session s.h = 'h' - self.assertEqual(s.h, 'h') + assert s.h == 'h' def test_add(self): s = self.session - h = Axis([], 'h') - s.add(h, i='i') - 
self.assertTrue(h.equals(s.h)) - self.assertEqual(s.i, 'i') + h = Axis('h=h0..h2') + h01 = h['h0,h1'] >> 'h01' + s.add(h, h01, i='i') + assert h.equals(s.h) + assert h01 == s.h01 + assert s.i == 'i' def test_iter(self): - expected = [self.b, self.a, self.c, self.d, self.e, self.g, self.f] + expected = [self.b, self.b12, self.a, self.a01, self.c, self.d, self.e, self.g, self.f] self.assertObjListEqual(self.session, expected) def test_filter(self): s = self.session s.ax = 'ax' - self.assertObjListEqual(s.filter(), [self.b, self.a, 'c', {}, + self.assertObjListEqual(s.filter(), [self.b, self.b12, self.a, self.a01, 'c', {}, self.e, self.g, self.f, 'ax']) - self.assertEqual(list(s.filter('a')), [self.a, 'ax']) - self.assertEqual(list(s.filter('a', dict)), []) - self.assertEqual(list(s.filter('a', str)), ['ax']) - self.assertEqual(list(s.filter('a', Axis)), [self.a]) - self.assertEqual(list(s.filter(kind=Axis)), [self.b, self.a]) + self.assertObjListEqual(s.filter('a'), [self.a, self.a01, 'ax']) + assert list(s.filter('a', dict)) == [] + assert list(s.filter('a', str)) == ['ax'] + assert list(s.filter('a', Axis)) == [self.a] + assert list(s.filter(kind=Axis)) == [self.b, self.a] + assert list(s.filter('a01', Group)) == [self.a01] + assert list(s.filter(kind=Group)) == [self.b12, self.a01] self.assertObjListEqual(s.filter(kind=LArray), [self.e, self.g, self.f]) - self.assertEqual(list(s.filter(kind=dict)), [{}]) + assert list(s.filter(kind=dict)) == [{}] + assert list(s.filter(kind=(Axis, Group))) == [self.b, self.b12, self.a, self.a01] def test_names(self): s = self.session - self.assertEqual(s.names, ['a', 'b', 'c', 'd', 'e', 'f', 'g']) + assert s.names == ['a', 'a01', 'b', 'b12', 'c', 'd', 'e', 'f', 'g'] # add them in the "wrong" order s.add(i='i') s.add(h='h') - self.assertEqual(s.names, ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']) + assert s.names == ['a', 'a01', 'b', 'b12', 'c', 'd', 'e', 'f', 'g', 'h', 'i'] def test_h5_io(self): fpath = 
self.get_path('test_session.h5') @@ -153,79 +165,93 @@ def test_h5_io(self): s = Session() s.load(fpath) - # HDF does *not* keep ordering (ie, keys are always sorted) - self.assertEqual(list(s.keys()), ['e', 'f', 'g']) - - # update an array (overwrite=False) - Session(e=self.e2).save(fpath, overwrite=False) + # HDF does *not* keep ordering (ie, keys are always sorted + + # read LArray objects, then Axis objects and finally Group objects) + assert list(s.keys()) == ['e', 'f', 'g', 'a', 'b', 'a01', 'b12'] + + # update a Group + an Axis + an array (overwrite=False) + a2 = Axis('a=0..2') + a2_01 = a2['0,1'] >> 'a01' + e2 = ndtest((a2, 'b=b0..b2')) + Session(a=a2, a01=a2_01, e=e2).save(fpath, overwrite=False) s.load(fpath) - self.assertEqual(list(s.keys()), ['e', 'f', 'g']) - assert_array_nan_equal(s['e'], self.e2) + assert list(s.keys()) == ['e', 'f', 'g', 'a', 'b', 'a01', 'b12'] + assert s['a'].equals(a2) + assert all(s['a01'] == a2_01) + assert_array_nan_equal(s['e'], e2) s = Session() s.load(fpath, ['e', 'f']) - self.assertEqual(list(s.keys()), ['e', 'f']) + assert list(s.keys()) == ['e', 'f'] def test_xlsx_pandas_io(self): + session = self.session.filter(kind=LArray) + fpath = self.get_path('test_session.xlsx') - self.session.save(fpath, engine='pandas_excel') + session.save(fpath, engine='pandas_excel') s = Session() s.load(fpath, engine='pandas_excel') - self.assertEqual(list(s.keys()), ['e', 'g', 'f']) + assert list(s.keys()) == ['e', 'g', 'f'] # update an array (overwrite=False) - Session(e=self.e2).save(fpath, engine='pandas_excel', overwrite=False) + e2 = ndtest(('a=0..2', 'b=b0..b2')) + Session(e=e2).save(fpath, engine='pandas_excel', overwrite=False) s.load(fpath, engine='pandas_excel') - self.assertEqual(list(s.keys()), ['e', 'g', 'f']) - assert_array_nan_equal(s['e'], self.e2) + assert list(s.keys()) == ['e', 'g', 'f'] + assert_array_nan_equal(s['e'], e2) fpath = self.get_path('test_session_ef.xlsx') self.session.save(fpath, ['e', 'f'], 
engine='pandas_excel') s = Session() s.load(fpath, engine='pandas_excel') - self.assertEqual(list(s.keys()), ['e', 'f']) + assert list(s.keys()) == ['e', 'f'] @pytest.mark.skipif(xw is None, reason="xlwings is not available") def test_xlsx_xlwings_io(self): + session = self.session.filter(kind=LArray) + fpath = self.get_path('test_session_xw.xlsx') # test save when Excel file does not exist - self.session.save(fpath, engine='xlwings_excel') + session.save(fpath, engine='xlwings_excel') s = Session() s.load(fpath, engine='xlwings_excel') # ordering is only kept if the file did not exist previously (otherwise the ordering is left intact) - self.assertEqual(list(s.keys()), ['e', 'g', 'f']) + assert list(s.keys()) == ['e', 'g', 'f'] # update an array (overwrite=False) - Session(e=self.e2).save(fpath, engine='xlwings_excel', overwrite=False) + e2 = ndtest(('a=0..2', 'b=b0..b2')) + Session(e=e2).save(fpath, engine='xlwings_excel', overwrite=False) s.load(fpath, engine='xlwings_excel') - self.assertEqual(list(s.keys()), ['e', 'g', 'f']) - assert_array_nan_equal(s['e'], self.e2) + assert list(s.keys()) == ['e', 'g', 'f'] + assert_array_nan_equal(s['e'], e2) fpath = self.get_path('test_session_ef_xw.xlsx') self.session.save(fpath, ['e', 'f'], engine='xlwings_excel') s = Session() s.load(fpath, engine='xlwings_excel') - self.assertEqual(list(s.keys()), ['e', 'f']) + assert list(s.keys()) == ['e', 'f'] def test_csv_io(self): try: + session = self.session.filter(kind=LArray) + fpath = self.get_path('test_session_csv') - self.session.to_csv(fpath) + session.to_csv(fpath) # test loading a directory s = Session() s.load(fpath, engine='pandas_csv') # CSV cannot keep ordering (so we always sort keys) - self.assertEqual(list(s.keys()), ['e', 'f', 'g']) + assert list(s.keys()) == ['e', 'f', 'g'] # test loading with a pattern pattern = os.path.join(fpath, '*.csv') s = Session(pattern) # s = Session() # s.load(pattern) - self.assertEqual(list(s.keys()), ['e', 'f', 'g']) + assert 
list(s.keys()) == ['e', 'f', 'g'] # create an invalid .csv file invalid_fpath = os.path.join(fpath, 'invalid.csv') @@ -239,23 +265,26 @@ def test_csv_io(self): # test loading a pattern, ignoring invalid/unsupported files s = Session() s.load(pattern, ignore_exceptions=True) - self.assertEqual(list(s.keys()), ['e', 'f', 'g']) + assert list(s.keys()) == ['e', 'f', 'g'] finally: shutil.rmtree(fpath) def test_pickle_io(self): + session = self.session.filter(kind=LArray) + fpath = self.get_path('test_session.pkl') - self.session.save(fpath) + session.save(fpath) s = Session() s.load(fpath, engine='pickle') - self.assertEqual(list(s.keys()), ['e', 'g', 'f']) + assert list(s.keys()) == ['e', 'g', 'f'] # update an array (overwrite=False) - Session(e=self.e2).save(fpath, overwrite=False) + e2 = ndtest(('a=0..2', 'b=b0..b2')) + Session(e=e2).save(fpath, overwrite=False) s.load(fpath, engine='pickle') - self.assertEqual(list(s.keys()), ['e', 'g', 'f']) - assert_array_nan_equal(s['e'], self.e2) + assert list(s.keys()) == ['e', 'g', 'f'] + assert_array_nan_equal(s['e'], e2) def test_to_globals(self): with pytest.warns(RuntimeWarning) as caught_warnings: @@ -410,14 +439,15 @@ def test_rdiv(self): assert_array_nan_equal(res['f'], self.f / self.f) def test_summary(self): + # only arrays sess = self.session.filter(kind=LArray) - self.assertEqual(sess.summary(), - "e: a0*, a1*\n \n\n" - "g: a0*, a1*\n \n\n" - "f: a0*, a1*\n \n") + assert sess.summary() == "e: a0*, a1*\n \n\ng: a0*, a1*\n \n\nf: a0*, a1*\n \n" + # all objects + sess = self.session + assert sess.summary() == "e: a0*, a1*\n \n\ng: a0*, a1*\n \n\nf: a0*, a1*\n \n" def test_pickle_roundtrip(self): - original = self.session + original = self.session.filter(kind=LArray) s = pickle.dumps(original) res = pickle.loads(s) assert res.equals(original) diff --git a/larray/util/misc.py b/larray/util/misc.py index edb2bcf8f..5d7c53897 100644 --- a/larray/util/misc.py +++ b/larray/util/misc.py @@ -25,6 +25,8 @@ except TypeError: 
pass +import pandas as pd + if sys.version_info[0] < 3: basestring = basestring bytes = str @@ -732,3 +734,23 @@ def common_type(arrays): return np.dtype(('U' if need_unicode else 'S', max_size)) else: return object + + +class LHDFStore(object): + """Context manager for pandas HDFStore""" + def __init__(self, filepath_or_buffer, **kwargs): + if isinstance(filepath_or_buffer, pd.HDFStore): + if not filepath_or_buffer.is_open: + raise IOError('The HDFStore must be open for reading.') + self.store = filepath_or_buffer + self.close_store = False + else: + self.store = pd.HDFStore(filepath_or_buffer, **kwargs) + self.close_store = True + + def __enter__(self): + return self.store + + def __exit__(self, type_, value, traceback): + if self.close_store: + self.store.close()