import sys
from logging import getLogger
from pathlib import Path
import numpy as np
import pandas as pd
from streamz import Stream
from pore_c.datasources import Fastq
from pore_c.io import BatchedFastqWriter, FastqWriter
from pore_c.utils import DataFrameProgress, mean_qscore
# Module-level logger, named after this module via ``__name__``.
logger = getLogger(__name__)
def read_length_stats(lengths, percentiles=(25, 50, 75)):
    """Summarise a collection of read lengths.

    Args:
        lengths: Read lengths as a pandas Series, numpy array or any other
            sized array-like of numbers.
        percentiles: Percentiles (0-100) to report. Each one is emitted under
            a ``"Q<p>"`` key, where ``<p>`` is the percentile formatted with
            ``%d``.

    Returns:
        An empty dict when ``lengths`` is empty, otherwise a dict with the
        keys ``num_sequences``, ``total_bases``, ``mean``, ``min``, ``max``,
        ``N50`` and one ``Q<p>`` entry per requested percentile.
    """
    values = np.asarray(lengths)
    num_sequences = len(values)
    if num_sequences == 0:
        return {}

    sorted_lengths = np.sort(values)
    # Accumulate integer lengths in 64 bits explicitly: narrow inputs such as
    # uint32 otherwise accumulate in the platform default integer, which is
    # not guaranteed to be 64 bits wide.
    acc_dtype = np.int64 if np.issubdtype(sorted_lengths.dtype, np.integer) else None
    cumsum = sorted_lengths.cumsum(dtype=acc_dtype)
    total_bases = cumsum[-1]

    # N50: walking up from the shortest read, the first length at which the
    # running total exceeds half of all bases.
    n50_idx = np.argmax(cumsum > (total_bases * 0.5))
    n50 = sorted_lengths[n50_idx]

    stats = {
        "num_sequences": num_sequences,
        "total_bases": int(total_bases),
        "mean": float(total_bases / num_sequences),
        "min": int(sorted_lengths[0]),
        "max": int(sorted_lengths[-1]),
        "N50": int(n50),
    }
    quantiles = np.percentile(sorted_lengths, percentiles)
    stats.update(("Q%d" % p, float(q)) for p, q in zip(percentiles, quantiles))
    return stats
def prepare_fastq(
    input_fastq: Path,
    pass_fastq: Path = None,
    fail_fastq: Path = None,
    read_metadata: Path = None,
    summary: Path = None,
    min_read_length: int = 50,
    max_read_length: int = 5000000,
    min_qscore: int = 0,
    max_qscore: int = 266,
    chunksize: int = 10000,
):
    """Filter a FASTQ file into pass/fail files and write read statistics.

    Records are read from ``input_fastq`` in chunks of ``chunksize`` and
    classified by ``filter_records``: a read passes when
    ``min_read_length <= length < max_read_length`` and
    ``min_qscore <= qscore < max_qscore``.

    Args:
        input_fastq: FASTQ file to read.
        pass_fastq: Destination handed to ``BatchedFastqWriter`` for passing reads.
        fail_fastq: Destination handed to ``FastqWriter`` for failing reads.
        read_metadata: Parquet file receiving the per-read metadata table.
        summary: CSV file receiving the per-subset length statistics.
        min_read_length: Inclusive lower bound on read length.
        max_read_length: Exclusive upper bound on read length.
        min_qscore: Inclusive lower bound on the read's q-score.
        max_qscore: Exclusive upper bound on the read's q-score.
        chunksize: Number of records requested per chunk from the reader.

    Returns:
        A dict with the keys ``"all"``, ``"pass"`` and ``"fail"``, each holding
        the ``read_length_stats`` result for that subset (an empty dict when
        the subset has no reads).

    Raises:
        ValueError: If the input yields no reads, or if no read passes the filters.
    """
    from contextlib import ExitStack

    fastq_stream = Stream()
    filtered_stream = fastq_stream.map(filter_records, min_read_length, max_read_length, min_qscore, max_qscore)

    pass_writer = BatchedFastqWriter(pass_fastq)
    fail_writer = FastqWriter(fail_fastq)
    read_prog = ReadFilterProgress()

    # NOTE(review): assumes the progress object, used as the accumulator,
    # re-emits each metadata frame so the sink collects one DataFrame per chunk.
    df_sink = (
        filtered_stream.pluck("metadata").accumulate(read_prog, returns_state=True, start=read_prog).sink_to_list()
    )
    pass_sink = filtered_stream.pluck("pass").sink(pass_writer)  # noqa: F841
    fail_sink = filtered_stream.pluck("fail").sink(fail_writer)  # noqa: F841

    # Close the writers and the progress bar on every exit path, including a
    # failure while streaming. Callbacks run in reverse registration order and
    # each one is attempted even if an earlier one raises.
    with ExitStack() as cleanup:
        cleanup.callback(read_prog.close)
        cleanup.callback(fail_writer.close)
        cleanup.callback(pass_writer.close)
        # split reads into chunks for processing
        for records in Fastq(input_fastq).read_chunked(chunksize):
            fastq_stream.emit(records)
    sys.stderr.write("\n")
    logger.debug("Finished processing reads")

    if not df_sink:
        # pd.concat would raise an opaque "No objects to concatenate" here.
        raise ValueError(f"No reads found in {input_fastq}")

    metadata_df = pd.concat(df_sink, ignore_index=True)
    metadata_df.to_parquet(read_metadata, index=False)

    pass_rate = metadata_df["pass_filter"].mean()
    if pass_rate == 0:
        raise ValueError("No reads passed filter")

    summary_stats = {
        "all": read_length_stats(metadata_df["read_length"]),
        "pass": read_length_stats(metadata_df.query("pass_filter == True")["read_length"]),
        "fail": read_length_stats(metadata_df.query("pass_filter == False")["read_length"]),
    }
    # Internal consistency check: every read seen must have been written to
    # exactly one of the two output files.
    assert (pass_writer._counter + fail_writer._counter) == summary_stats["all"]["num_sequences"]

    # Only subsets that actually contain reads get a row in the summary table.
    non_empty = {subset: stats for subset, stats in summary_stats.items() if stats}
    df = pd.DataFrame(list(non_empty.values()), index=list(non_empty.keys()))
    df.index.name = "read_subset"
    logger.info(f"Finished processing {input_fastq}:\n{str(df)}\n")
    df.to_csv(summary)
    return summary_stats
def filter_records(list_of_records, min_read_length, max_read_length, min_qscore, max_qscore):
    """Classify a batch of FASTQ records as pass or fail.

    A record passes when ``min_read_length <= length < max_read_length`` and
    ``min_qscore <= qscore < max_qscore``, where ``qscore`` is the result of
    ``mean_qscore`` applied to the record's quality array.

    Returns:
        A dict with three entries: ``"metadata"`` is a DataFrame with the
        columns ``read_id``, ``read_length``, ``qscore`` and ``pass_filter``
        (one row per record); ``"pass"`` and ``"fail"`` are lists holding
        ``str(record)`` for the records in each class, in input order.
    """
    rows = [(rec.name, len(rec.sequence), mean_qscore(rec.get_quality_array())) for rec in list_of_records]
    metadata = pd.DataFrame(rows, columns=["read_id", "read_length", "qscore"])
    metadata = metadata.astype({"read_length": np.uint32, "qscore": np.float32})
    # The @-prefixed names are resolved from this function's arguments.
    metadata = metadata.eval(
        "pass_filter = (@min_read_length <= read_length < @max_read_length) & (@min_qscore <= qscore < @max_qscore)"
    )

    passed, failed = [], []
    for record, keep in zip(list_of_records, metadata["pass_filter"].values):
        bucket = passed if keep else failed
        bucket.append(str(record))
    return {"metadata": metadata, "pass": passed, "fail": failed}
class ReadFilterProgress(DataFrameProgress):
    """Progress tracker that keeps running totals of reads, bases and passing reads."""

    def __init__(self, **kwds):
        """Forward ``kwds`` to the base class, always setting ``desc`` and ``unit``."""
        kwds.update(desc="Reads processed", unit=" reads")
        super().__init__(**kwds)

    def update_data(self, read_df):
        """Add the totals of one per-read metadata frame to the running totals.

        ``read_df`` must have ``read_length`` and ``pass_filter`` columns. The
        totals are stored in ``self._data`` as a one-row DataFrame with the
        columns ``read_length``, ``pass_filter`` and ``reads``.
        """
        totals = read_df[["read_length", "pass_filter"]].sum(axis=0)
        totals["reads"] = len(read_df)
        totals_row = totals.astype(int).to_frame().T
        if self._data is None:
            self._data = totals_row
            return
        self._data += totals_row

    @staticmethod
    def summarize(df):
        """Return ``df`` with derived ``Gb`` and ``perc_pass`` columns appended."""
        with_gb = df.eval("Gb = read_length * 1e-9")
        return with_gb.eval("perc_pass = 100 * (pass_filter / reads)")

    def update_progress_bar(self, num_aligns):
        """Advance the bar by ``num_aligns`` and show the current totals as its postfix."""
        self._bar.update(num_aligns)
        first_row = self.summarize(self._data).iloc[0, :]
        self._bar.set_postfix(first_row.to_dict())