import os
import shutil
import itertools
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Callable, Optional, Tuple
import numpy as np
from ..stats import resampling as rsp
from ..stats import statistics as stt
__all__ = ["Handler"]
# Public API: the abstract base class exported through ``__all__``.
class Handler(ABC):
    """
    Abstract class for handlers.

    A handler keeps a registry of named statistics and resamplings, a sample
    getter and a save directory. Concrete subclasses implement how saves are
    named (`get_filename`, `parse_filename`), how values are computed and
    written (`store`), read back (`read`) and deleted (`remove`,
    `delete_stats`), as well as `__str__`.
    """

    save_path: Optional[str]  #: Save directory. None until `set_path` is called.
    ext: str  #: Save files extension. Never assigned in this class; presumably set by subclasses.
    getter: Optional[
        Callable[
            [List[str], List[str], Dict[str, Tuple[float, float]]],
            Tuple[np.ndarray, np.ndarray],
        ]
    ]  #: Function providing samples for further computation. None until `set_getter` is called.
    stats: Dict[str, Callable]  #: Dictionary of available statistics.
    resamplings: Dict[str, "rsp.Resampling"]  #: Dictionary of available resamplings.

    def __init__(self):
        """
        Registers the built-in statistics and resamplings. The save path and
        the getter are left unset (None).
        """
        self.stats = {
            "mi": stt.MI(),
            "condh": stt.Condh(),
            "corr": stt.Corr(),
            "gaussinfo": stt.GaussInfo(),
            "gaussinforeparam": stt.GaussInfoReparam(),
        }
        self.resamplings = {
            "bootstrapping": rsp.Bootstrapping(),
            "subsampling": rsp.Subsampling(),
        }

        self.save_path = None
        self.getter = None

        # Not read anywhere in this class; kept so the instance attributes
        # stay available to code that may rely on them.
        self.additional_stats = None
        self.additional_resamplings = None

    # Setter

    def set_path(self, save_path: Optional[str] = None) -> None:
        """
        Defines a new path to a save directory. Must be called at least once
        before calling other functions such as `store`.

        Parameters
        ----------
        save_path : Optional[str], optional
            New save path. Default None.
        """
        self.save_path = save_path

    def set_getter(
        self,
        getter: Callable[
            [List[str], List[str], Dict[str, Tuple[float, float]]],
            Tuple[np.ndarray, np.ndarray],
        ],
    ) -> None:
        """
        Defines the function (getter) that provides samples for statistical
        relationship calculations. In most cases, this will correspond to the
        `get` method of the `StandardGetter`, but users can define their own
        implementation.

        Parameters
        ----------
        getter : Callable[ [List[str], List[str], Dict[str, Tuple[float, float]]], Tuple[np.ndarray, np.ndarray] ]
            Function providing samples for further computation. Must not be None.
        """
        assert getter is not None, "getter must not be None"
        self.getter = getter

    def set_additional_stats(
        self,
        additional_stats: Optional[Dict[str, "stt.Statistic"]] = None,
    ) -> None:
        """
        Add new statistics (instances of Statistic) to estimate the
        informativity of variables. Each statistic has a user-defined name.
        This name will then be reused, for example in the `store` function and
        its variants. An entry whose name is already registered replaces the
        existing one.

        Parameters
        ----------
        additional_stats : Optional[Dict[str, Statistic]], optional
            Additional statistics to be used. None (default) or an empty
            dictionary adds nothing.
        """
        if additional_stats is None:
            additional_stats = {}
        assert all(
            isinstance(el, stt.Statistic) for el in additional_stats.values()
        ), "additional_stats values must be instances of Statistic"
        self.stats.update(additional_stats)

    def set_additional_resamplings(
        self,
        additional_resamplings: Optional[Dict[str, "rsp.Resampling"]] = None,
    ) -> None:
        """
        Add new resamplings (instances of Resampling) to estimate the variance
        of some estimators. Each resampling has a user-defined name. This name
        will then be reused, for example in the `store` function and its
        variants. An entry whose name is already registered replaces the
        existing one.

        Parameters
        ----------
        additional_resamplings : Optional[Dict[str, Resampling]], optional
            Additional resamplings to be used. None (default) or an empty
            dictionary adds nothing.
        """
        if additional_resamplings is None:
            additional_resamplings = {}
        assert all(
            isinstance(el, rsp.Resampling) for el in additional_resamplings.values()
        ), "additional_resamplings values must be instances of Resampling"
        # Resamplings go to their own registry, not to the statistics one.
        self.resamplings.update(additional_resamplings)

    # Saves

    @abstractmethod
    def get_filename(self, *args, **kwargs) -> str:
        """
        Builds a save filename from data names.

        Returns
        -------
        str
            Filename.
        """

    @abstractmethod
    def parse_filename(self, filename: str) -> Any:
        """
        Identifies data names from save filename.

        Parameters
        ----------
        filename : str
            Save filename.

        Returns
        -------
        Any
            Data names.
        """

    def get_existing_saves(self) -> List[str]:
        """
        Returns the filenames (basenames) of any existing saves at
        `self.save_path`. Any file whose name ends with `self.ext` is
        considered a valid save.

        Returns
        -------
        List[str]
            Existing saves, in the order given by `os.listdir`. Empty if no
            save path has been set yet or if it does not exist.
        """
        # No path configured yet means there is nothing to list.
        if self.save_path is None or not os.path.exists(self.save_path):
            return []
        return [f for f in os.listdir(self.save_path) if f.endswith(self.ext)]

    # Creation/removal

    def create(self) -> None:
        """
        Creates `self.save_path` directory (including missing parents) if it
        does not exist.
        """
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

    @abstractmethod
    def remove(self, *args, **kwargs) -> None:
        """
        Removes `self.save_path` directory, with all its content, if it exists.

        Although abstract, this method carries a default implementation that
        subclasses can reuse through `super().remove()`.

        Raises
        ------
        FileNotFoundError
            If `self.save_path` does not exist.
        """
        if os.path.exists(self.save_path):
            shutil.rmtree(self.save_path)
        else:
            raise FileNotFoundError(f"Save directory {self.save_path} does not exist")

    @abstractmethod
    def delete_stats(self, *args, **kwargs) -> None:
        """
        Removes saved results for a given statistics.
        """

    # Writing access

    def update(self, *args, **kwargs) -> None:
        """
        Calls `self.store` method with `overwrite=False`. All other arguments
        are forwarded unchanged, so `overwrite` must not be passed here.
        """
        self.store(*args, **kwargs, overwrite=False)

    def overwrite(self, *args, **kwargs) -> None:
        """
        Calls `self.store` method with `overwrite=True`. All other arguments
        are forwarded unchanged, so `overwrite` must not be passed here.
        """
        self.store(*args, **kwargs, overwrite=True)

    @abstractmethod
    def store(self, overwrite: bool = False, **kwargs) -> None:
        """
        Compute and save values. The behavior depends on the values of the
        `overwrite` argument.

        Parameters
        ----------
        overwrite : bool, optional
            If True, overwrite the current computed value, if exists.
            Default: False.
        """

    # Reading access

    @abstractmethod
    def read(self, **kwargs) -> Dict[str, Any]:
        """
        Accesses saved values.

        Returns
        -------
        Dict[str, Any]
            Read entries.
        """

    # Display

    def overview(self) -> None:
        """
        Describes the handler and existing backups by printing its string
        representation, its save path and the list of existing saves.
        """
        print(str(self))
        print("Save path:", self.save_path)
        files = self.get_existing_saves()
        if not files:
            print("No existing saves.")
        else:
            print("Existing saves:")
            for filename in files:
                print(filename)

    @abstractmethod
    def __str__(self):
        """
        Returns the description of the handler printed by `overview`.
        """

    # Helpers

    @staticmethod
    def _check_dict_type(
        d: Dict, key_type: Any, value_type: Any
    ) -> Optional[Tuple[Any, Any]]:
        """
        Returns the first entry that does not match the types given as
        arguments or None if all entries are valid.

        Parameters
        ----------
        d : Dict
            Dictionary to test.
        key_type : Any
            Expected key type (anything accepted by `isinstance`).
        value_type : Any
            Expected value type (anything accepted by `isinstance`).

        Returns
        -------
        Optional[Tuple[Any, Any]]
            First (key, value) entry that does not match the types given as
            arguments. None if all entries are valid.
        """
        assert isinstance(d, dict)
        for k, v in d.items():
            if not (isinstance(k, key_type) and isinstance(v, value_type)):
                return k, v
        return None

    @staticmethod
    def drop_duplicates(ls: List[List[str]]) -> List[List[str]]:
        """
        Returns the entries of `ls` sorted and without duplicates. The input
        list is left untouched.

        Parameters
        ----------
        ls : List[List[str]]
            Entries to deduplicate.

        Returns
        -------
        List[List[str]]
            Sorted entries, each distinct value appearing once.
        """
        # Sorting first puts equal entries next to each other, which is what
        # `groupby` needs to merge them into a single group.
        return [entry for entry, _ in itertools.groupby(sorted(ls))]