import re
from contextlib import AbstractContextManager
from logging import getLogger
from time import sleep
from typing import Optional
import numpy as np
from dask.distributed import Client, LocalCluster
from tqdm import tqdm
# Module-level logger, named after this module.
logger = getLogger(__name__)
# Lookup table with 256 float entries: entry q holds 10 ** (-q / 10), i.e. the
# probability corresponding to Phred score q. It is indexed directly by
# quality values in mean_qscore().
PHRED_TO_PROB = np.power(10, (np.arange(256, dtype=float) / -10.0))
def mean_qscore(quals):
    """Return the mean of *quals* computed in probability space.

    Each quality value is mapped to a probability through the
    ``PHRED_TO_PROB`` lookup table, the probabilities are averaged, and the
    average is converted back with ``-10 * log10(p)``.

    Args:
        quals: Index (or array of indices) into ``PHRED_TO_PROB``.
            NOTE(review): presumably an integer array of Phred scores in
            the range 0..255 -- confirm against callers.

    Returns:
        The Phred-scaled value of the mean probability.
    """
    error_probs = PHRED_TO_PROB[quals]
    mean_error = error_probs.mean()
    return -10 * np.log10(mean_error)
class DaskExecEnv(AbstractContextManager):
    """Context manager that owns a local Dask cluster and a connected client.

    Entering the context starts a ``LocalCluster`` built from the constructor
    arguments and connects a ``Client`` to it. Leaving the context polls the
    client (with exponential backoff) until no worker reports tasks in
    progress or the polling budget is exhausted, then closes the client and
    the cluster.

    Args:
        n_workers: Forwarded to ``LocalCluster`` as ``n_workers``.
        processes: Forwarded to ``LocalCluster`` as ``processes``.
        threads_per_worker: Forwarded to ``LocalCluster`` as
            ``threads_per_worker``.
        scheduler_port: Forwarded to ``LocalCluster`` as ``scheduler_port``.
        dashboard_port: When given, the cluster receives
            ``dashboard_address="127.0.0.1:<port>"``; when ``None``,
            ``dashboard_address=None`` is passed instead.
    """

    # Shutdown polling budget: the client is polled at most _MAX_TRIES - 1
    # times, sleeping _INITIAL_DELAY seconds after the first busy poll and
    # multiplying the delay by _BACKOFF after each subsequent one.
    _MAX_TRIES = 10
    _BACKOFF = 2
    _INITIAL_DELAY = 1

    def __init__(
        self,
        n_workers: int = 1,
        processes: bool = True,
        threads_per_worker: int = 1,
        scheduler_port: int = 0,
        dashboard_port: Optional[int] = None,
    ):
        dashboard_address = (
            None if dashboard_port is None else f"127.0.0.1:{dashboard_port}"
        )
        self._cluster_kwds = {
            "processes": processes,
            "n_workers": n_workers,
            "scheduler_port": scheduler_port,
            "dashboard_address": dashboard_address,
            "threads_per_worker": threads_per_worker,
        }
        self._cluster, self._client = None, None

    def scatter(self, data):
        """Scatter *data* through the client and return the client's result.

        Only usable inside the ``with`` block: outside of it there is no
        client and the attribute access on ``None`` raises ``AttributeError``.
        """
        return self._client.scatter(data)

    def __enter__(self):
        self._cluster = LocalCluster(**self._cluster_kwds)
        try:
            self._client = Client(self._cluster)
        except Exception:
            # __exit__ is never called when __enter__ raises, so the cluster
            # that was just started has to be released here.
            self._cluster.close()
            self._cluster = None
            raise
        logger.debug("Cluster started: %s", self._cluster)
        logger.debug("Client started: %s", self._client)
        return self

    def __exit__(self, *args):
        if self._cluster is None:
            return
        try:
            self._wait_until_idle()
        finally:
            # Always release both resources, even if polling (or closing the
            # client) raised; the exception still propagates afterwards.
            try:
                self._client.close()
            finally:
                self._cluster.close()
                # Forget the closed handles so a repeated __exit__ is a no-op.
                self._cluster, self._client = None, None

    def _wait_until_idle(self):
        """Poll the client until no worker reports tasks in progress.

        Gives up silently once the polling budget is used up.
        """
        delay = self._INITIAL_DELAY
        for _ in range(self._MAX_TRIES - 1):
            in_flight = self._client.processing()
            if not any(len(tasks) > 0 for tasks in in_flight.values()):
                # Nothing is running: pause one second before shutdown
                # (kept as-is; presumably lets in-flight results settle).
                sleep(1)
                return
            sleep(delay)
            delay *= self._BACKOFF
class DataFrameProgress:
    """Base class that drives a tqdm progress bar from a stream of dataframes.

    Instances are callable as ``progress(_, df)``: the dataframe is handed to
    :meth:`update_data`, the bar advances by ``len(df)``, and ``(self, df)``
    is returned. Subclasses must implement :meth:`update_data` and keep their
    running summary in ``self._data``.

    NOTE(review): ``self._data`` must provide ``to_dict()`` and
    ``to_csv(path, index=False)`` -- presumably a pandas Series/DataFrame;
    confirm against the subclasses.

    Args:
        save_to: Optional path; when truthy, :meth:`close` writes the
            accumulated data there.
        **kwds: Forwarded unchanged to ``tqdm``.
    """

    def __init__(self, save_to=None, **kwds):
        self._bar = tqdm(**kwds)
        self._data = None
        self._save_to = save_to

    def __call__(self, _, df):
        self.update_data(df)
        self.update_progress_bar(len(df))
        return self, df

    def update_data(self, df):
        """Fold *df* into ``self._data``; must be provided by subclasses."""
        raise NotImplementedError

    def update_progress_bar(self, num_rows):
        """Advance the bar by *num_rows* and show the current data as postfix."""
        self._bar.update(num_rows)
        # A subclass may not have produced any data yet; advancing the bar
        # must not fail on ``None.to_dict()`` in that case.
        if self._data is not None:
            self._bar.set_postfix(self._data.to_dict())

    def close(self):
        """Close the bar and, if a save path was given, write the data to it.

        Nothing is written when no data has been accumulated, so closing a
        reporter that never received a dataframe does not raise.
        """
        self._bar.close()
        if self._save_to and self._data is not None:
            self.save(self._save_to)

    def save(self, path):
        """Write the accumulated data to *path* via ``to_csv(path, index=False)``."""
        self._data.to_csv(path, index=False)
# "<number><K|M|G>" with an optional "b"/"bp" unit suffix, e.g. "1.5k", "2Mb".
_KMG_PATTERN = re.compile(r"(\d+(?:\.\d+)?)([KkMmGg])(?:[bB][pP]?)?")
_KMG_MULTIPLIER = {"k": 10 ** 3, "m": 10 ** 6, "g": 10 ** 9}


def kmg_bases_to_int(value: str) -> int:
    """Convert a base count such as ``"1500"``, ``"1.5k"`` or ``"2Mb"`` to an int.

    Anything ``int()`` accepts is returned as-is. Otherwise the string must
    be, after stripping surrounding whitespace, a decimal number followed by
    a case-insensitive ``K``/``M``/``G`` multiplier (10**3, 10**6, 10**9) and
    an optional ``b`` or ``bp`` suffix.

    Args:
        value: The text to convert.

    Returns:
        The number of bases as an ``int``. Fractions of a base left after
        scaling (e.g. ``"1.2345k"``) are rounded to the nearest integer.

    Raises:
        ValueError: If *value* is neither a plain integer nor a valid
            suffixed number; trailing characters are not tolerated.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        pass  # not a plain integer; fall through to the suffixed form

    m = _KMG_PATTERN.fullmatch(value.strip())
    if not m:
        raise ValueError(f"Invalid string: {value}")
    magnitude = float(m.group(1))
    multiplier = _KMG_MULTIPLIER[m.group(2).lower()]
    # round() rather than int(): float products can land just below the
    # intended integer, and truncation would lose a base.
    return int(round(magnitude * multiplier))