import argparse
from concurrent.futures import ProcessPoolExecutor, as_completed
import collections
import functools
from itertools import tee
import logging
import os
from timeit import default_timer as now
from uuid import uuid4
import h5py
import numpy as np
import pysam
import gzip
from fast5_research.fast5 import Fast5, iterate_fast5
from fast5_research.fast5_bulk import BulkFast5
from fast5_research.util import _sanitize_data_for_writing, readtsv, group_vector
def triplewise(iterable):
a, b, c = tee(iterable, 3)
next(b)
next(c)
next(c)
yield from zip(a, b, c)
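# Example: consecutive overlapping triples are yielded, e.g.
#   list(triplewise(range(5)))  # -> [(0, 1, 2), (1, 2, 3), (2, 3, 4)]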
def time_cast(time, sample_rate):
"""
Convert a float time to sample index, or return time unmodified
"""
if isinstance(time, float):
return int(time * sample_rate)
else:
return time
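# Example: with a 4 kHz sampling rate a float time (in seconds) becomes a
# sample index, while an integer index passes through unchanged:
#   time_cast(1.5, 4000)   # -> 6000
#   time_cast(6000, 4000)  # -> 6000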
def build_read_index():
logging.basicConfig(
format='[%(asctime)s - %(name)s] %(message)s',
datefmt='%H:%M:%S', level=logging.INFO
)
logger = logging.getLogger('Index Reads')
parser = argparse.ArgumentParser(description='Build index of reads within .fast5s. Output to stdout.')
parser.add_argument('input', help='.fast5 directory')
parser.add_argument('--recursive', action='store_true',
help='Search recursively under `input` for source files.')
parser.add_argument('--workers', type=int, default=8,
help='Number of worker processes.')
args = parser.parse_args()
src_files = list(iterate_fast5(args.input, paths=True, recursive=args.recursive))
logger.info("Found {} files.".format(len(src_files)))
with ProcessPoolExecutor(args.workers) as executor:
n_reads = 0
for i, (src, read_ids) in enumerate(
zip(src_files, executor.map(reads_in_multi, src_files, chunksize=10))):
n_reads += len(read_ids)
for read in read_ids:
print('\t'.join((read, os.path.abspath(src))))
if i % 10 == 0:
logger.info("Indexed {}/{} files. {} reads".format(i, len(src_files), n_reads))
def filter_file_from_bam():
logging.basicConfig(
format='[%(asctime)s - %(name)s] %(message)s',
datefmt='%H:%M:%S', level=logging.INFO
)
logger = logging.getLogger('Filter')
parser = argparse.ArgumentParser(
description='Create filter file from BAM and sequencing summary')
parser.add_argument('--seperator',
dest="SEP",
default='\t',
help="Seperator in sequencing summary files")
parser.add_argument('--id-col',
dest="READID_COL",
default='read_id',
help="Column name for read_id in sequencing summary files")
parser.add_argument('--fname-col',
dest="FNAME_COL",
default='filename',
help="Column name for fast5 filename in sequencing summary files")
parser.add_argument('-r', '--region',
dest="REGION",
default=None,
help="Print reads only from this region")
parser.add_argument('--workers', type=int, default=4,
help='Number of worker processes.')
parser.add_argument('-p', '--primary-only',
dest="PRIMARY",
action='store_true',
help="Ignore secondary and supplementary alignments")
parser.add_argument('BAM', help='Path to BAM file')
parser.add_argument("SUMMARY",
type=str,
nargs='+',
help="Sequencing summary files")
args = parser.parse_args()
region = args.REGION
primary_only = args.PRIMARY
bam_in = args.BAM
summary_files = args.SUMMARY
threads = args.workers
readid_col = args.READID_COL
fast5_col = args.FNAME_COL
sep = args.SEP
if not region:
logger.info("No region specified. Extracting all reads from BAM file")
else:
logger.info("Extracting read ids from {}".format(region))
read_ids = {}
with pysam.AlignmentFile(bam_in, "rb", threads=threads) as infile:
for read in infile.fetch(region=region):
if read.is_unmapped or (primary_only and (read.is_secondary or read.is_supplementary)):
continue
read_ids[read.query_name] = None
n = len(read_ids)
logger.info("Reads found in BAM file: {}".format(n))
if n == 0:
return
# Print header
print("read_id", "filename", sep='\t')
n_print = 0
for summary_file in summary_files:
logging.info("Opening: {}".format(summary_file))
with gzip.open(summary_file) as fh:
header = fh.readline().decode().strip()
header_cols = header.split(sep)
readid_idx = header_cols.index(readid_col)
path_idx = header_cols.index(fast5_col)
for line in fh:
line = line.decode().strip()
if not line:
continue
cols = line.split(sep)
readid = cols[readid_idx]
f5_path = cols[path_idx]
if readid not in read_ids:
continue
if read_ids[readid]:
logging.error("Two entries found for {} ({} and {})".format(readid, read_ids[readid], f5_path))
continue
n_print += 1
read_ids[readid] = f5_path
print(readid, read_ids[readid], sep='\t')
logging.info("Filename found for {} reads ({}%)".format(n_print, round(n_print * 100.0 / n)))
def filter_multi_reads():
logging.basicConfig(
format='[%(asctime)s - %(name)s] %(message)s',
datefmt='%H:%M:%S', level=logging.INFO
)
logger = logging.getLogger('Filter')
parser = argparse.ArgumentParser(
description='Extract reads from multi-read .fast5 files.')
parser.add_argument('input',
help='Path to input multi-read .fast5 files (or list of files).')
parser.add_argument('output',
help='Output folder.')
parser.add_argument('filter',
help='A .tsv file with column `read_id` defining required reads. '
'If a `filename` column is present, this will be used as the '
'location of the read.')
parser.add_argument('--tsv_field', default='read_id',
help='Field name from `filter` file to obtain read IDs.')
parser.add_argument('--prefix', default="",
help='Read file prefix.')
parser.add_argument('--recursive', action='store_true',
help='Search recursively under `input` for source files.')
parser.add_argument('--workers', type=int, default=4,
help='Number of worker processes.')
out_format = parser.add_mutually_exclusive_group()
out_format.add_argument('--multi', action='store_true', default=True,
help='Output multi-read files.')
out_format.add_argument('--single', action='store_false', dest='multi',
help='Output single-read files.')
#parser.add_argument('--limit', type=int, default=None, help='Limit reads per channel.')
args = parser.parse_args()
if not args.multi:
raise NotImplementedError('Extraction of reads to single read files is on the TODO list.')
if not os.path.exists(args.output):
os.makedirs(args.output)
else:
raise IOError('The output directory must not exist.')
# grab list of source files
logger.info("Searching for input files.")
try:
src_files = list(set(readtsv(args.input)['filename']))
except Exception as e:
logger.info('Failed to read `input` as filelist, assuming path to search. {}'.format(e))
src_files = list(iterate_fast5(args.input, paths=True, recursive=args.recursive))
n_files = len(src_files)
logger.info("Found {} source files.".format(n_files))
logger.info("Reading filter file.")
read_table = readtsv(args.filter, fields=[args.tsv_field])
logger.info("Found {} reads in filter.".format(len(read_table)))
try:
# try to build index from the filter file with 'filename' column
if 'filename' not in read_table.dtype.names:
raise ValueError("'filename' column not present in filter.")
logger.info("Attempting to build read index from input filter.")
src_path_files = {
os.path.basename(x):x for x in src_files
}
if len(src_path_files) != len(src_files):
raise ValueError('Found non-uniquely named source files')
read_index = dict()
for fname, indices in group_vector(read_table['filename']).items():
fpath = src_path_files[os.path.basename(fname)]
read_index[fpath] = read_table[args.tsv_field][indices]
logger.info("Successfully build read index from input filter.")
except Exception as e:
logger.info("Failed to build read index from summary: {}".format(e))
read_index = None
    # If the index could not be built from the filter file, scan the source
    # files for the required reads.
    if read_index is None:
        required_reads = set(read_table[args.tsv_field])
        logger.info("Finding reads within {} source files.".format(n_files))
        index_worker = functools.partial(reads_in_multi, filt=required_reads)
        read_index = dict()
        n_reads = 0
        with ProcessPoolExecutor(args.workers) as executor:
            i = 0
            for src_file, read_ids in zip(src_files, executor.map(index_worker, src_files, chunksize=10)):
                i += 1
                n_reads += len(read_ids)
                read_index[src_file] = read_ids
                if i % 10 == 0:
                    logger.info("Indexed {}/{} files. {}/{} reads".format(i, n_files, n_reads, len(required_reads)))
n_reads = sum(len(x) for x in read_index.values())
    # We don't go via creating Read objects; copying the data verbatim is
    # likely quicker, and nothing downstream should need the verification
    # that the APIs provide (garbage in, garbage out).
logger.info("Extracting {} reads.".format(n_reads))
if args.prefix != '':
args.prefix = '{}_'.format(args.prefix)
with ProcessPoolExecutor(args.workers) as executor:
reads_per_process = np.ceil(n_reads / args.workers)
proc_n_reads = 0
proc_reads = dict()
job = 0
futures = list()
for src in read_index.keys():
proc_reads[src] = read_index[src]
proc_n_reads += len(proc_reads[src])
if proc_n_reads > reads_per_process:
proc_prefix = "{}{}_".format(args.prefix, job)
futures.append(executor.submit(_subset_reads_to_file, proc_reads, args.output, proc_prefix, worker_id=job))
job += 1
proc_n_reads = 0
proc_reads = dict()
if proc_n_reads > 0: # processing remaining reads
proc_prefix = "{}{}_".format(args.prefix, job)
futures.append(executor.submit(_subset_reads_to_file, proc_reads, args.output, proc_prefix, worker_id=job))
for fut in as_completed(futures):
try:
reads_written, prefix = fut.result()
logger.info("Written {} reads to {}.".format(reads_written, prefix))
except Exception as e:
logger.warning("Error: {}".format(e))
logger.info("Done.")
def _subset_reads_to_file(read_index, output, prefix, worker_id=0):
logger = logging.getLogger('Worker-{}'.format(worker_id))
n_reads = sum(len(x) for x in read_index.values())
reads_written = 0
t0 = now()
with MultiWriter(output, None, prefix=prefix) as writer:
for src_file, read_ids in read_index.items():
reads_written += len(read_ids)
t1 = now()
if t1 - t0 > 30: # log update every 30 seconds
logger.info("Written {}/{} reads ({:.0f}% done)".format(
reads_written, n_reads, 100 * reads_written / n_reads
))
t0 = t1
with h5py.File(src_file, 'r') as src_fh:
for read_id in read_ids:
try:
read_grp = src_fh["read_{}".format(read_id)]
                    except KeyError:
logger.warning("Did not find {} in {}.".format(read_id, src_fh.filename))
else:
writer.write_read(read_grp)
return reads_written, prefix
def reads_in_multi(src, filt=None):
"""Get list of read IDs contained within a multi-read file.
:param src: source file.
:param filt: perform filtering by given set.
:returns: set of read UUIDs (as string and recorded in hdf group name).
"""
logger = logging.getLogger(os.path.splitext(os.path.basename(src))[0])
logger.debug("Finding reads.")
prefix = 'read_'
with h5py.File(src, 'r') as fh:
read_ids = set(grp[len(prefix):] for grp in fh if grp.startswith(prefix))
logger.debug("Found {} reads.".format(len(read_ids)))
if filt is not None:
read_ids = read_ids.intersection(filt)
logger.debug("Filtered to {} reads.".format(len(read_ids)))
return read_ids
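# Example usage (hypothetical file name), restricting the scan of one
# multi-read file to a known set of read ids:
#   wanted = {'0ab27159-51b9-4cb5-8a44-0c5b07dca2fd'}
#   found = reads_in_multi('batch_0.fast5', filt=wanted)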
class Read(object):
# Just a sketch to help interchange of format
def __init__(self, read_id, read_number, tracking_id, channel_id, context_tags, raw):
self.read_id = read_id
self.read_number = read_number
self.tracking_id = tracking_id
self.channel_id = channel_id
self.context_tags = context_tags
self.raw = raw
# ensure typing and required fields
self.channel_id = Fast5.convert_channel_id(self.channel_id)
self.tracking_id = Fast5.convert_tracking_id(self.tracking_id)
class ReadWriter(object):
def __init__(self, out_path, by_id, prefix=""):
self.out_path = out_path
self.by_id = by_id
if prefix != "":
prefix = "{}_".format(prefix)
self.prefix = prefix
    def write_read(self, read):
raise NotImplementedError()
def __enter__(self):
return self
def __exit__(self, exception_type, exception_value, traceback):
pass
class SingleWriter(ReadWriter):
    def write_read(self, read):
if self.by_id:
filename = '{}.fast5'.format(read.read_id['read_id'])
else:
filename = '{}read_ch{}_file{}.fast5'.format(
self.prefix, read.channel_id['channel_number'], read.read_number
)
filename = os.path.join(self.out_path, filename)
with Fast5.New(filename, 'a', tracking_id=read.tracking_id, context_tags=read.context_tags, channel_id=read.channel_id) as h:
h.set_raw(read.raw, meta=read.read_id, read_number=read.read_number)
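# Output naming in SingleWriter.write_read, e.g. (hypothetical values): with
# by_id=True the file is named after the read id, otherwise from the prefix,
# channel number and read number:
#   SingleWriter('out', by_id=True).write_read(read)         # -> out/<read_id>.fast5
#   SingleWriter('out', False, prefix='x').write_read(read)  # -> out/x_read_ch<ch>_file<num>.fast5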
MULTI_READ_FILE_VERSION = "2.0"
class MultiWriter(ReadWriter):
def __init__(self, out_path, by_id, prefix="", reads_per_file=4000):
super(MultiWriter, self).__init__(out_path, by_id, prefix=prefix)
self.reads_per_file = reads_per_file
self.current_reads = 0 # reads in open file, used to signal new file condition
self.file_counter = 0
self.current_file = None
self.closed = False
def __exit__(self, exception_type, exception_value, traceback):
self.close()
    def close(self):
if isinstance(self.current_file, h5py.File):
self.current_file.close()
    def write_read(self, read):
"""Write a read.
:param read: either a `Read` object or an hdf group handle from a
source multi-read file.
"""
if self.closed:
raise RuntimeError('Cannot write after closed.')
if self.current_reads == 0:
# start a new file
self.close()
filename = '{}mreads_file{}.fast5'.format(
self.prefix, self.file_counter
)
filename = os.path.join(self.out_path, filename)
self.current_file = h5py.File(filename, 'w')
            self.current_file.attrs[_sanitize_data_for_writing('file_version')] = _sanitize_data_for_writing(MULTI_READ_FILE_VERSION)
self.file_counter += 1
# write data
if isinstance(read, Read):
self._write_read(read)
elif isinstance(read, h5py.Group):
self._copy_read_group(read)
else:
raise TypeError("Cannot write type {} to output file.")
self.current_reads += 1
# update
if self.current_reads == self.reads_per_file:
self.current_reads = 0
def _write_read(self, read):
if read.raw.dtype != np.int16:
raise TypeError('Raw data must be of type int16.')
read_group = '/read_{}'.format(read.read_id['read_id'])
Fast5._add_attrs_to_fh(self.current_file, {'run_id': read.tracking_id['run_id']}, read_group, convert=str)
# add all attributes
for grp_name in ('tracking_id', 'context_tags'):
# spec has all of these as str
data = getattr(read, grp_name)
Fast5._add_attrs_to_fh(self.current_file, data, '{}/{}'.format(read_group, grp_name), convert=str)
Fast5._add_attrs_to_fh(self.current_file, read.channel_id, '{}/channel_id'.format(read_group))
# add the data (and some more attrs)
data_path = '{}/Raw'.format(read_group)
read_id = Fast5._convert_meta_times(read.read_id, read.channel_id['sampling_rate'])
read_id = Fast5.convert_raw_meta(read_id)
Fast5._add_attrs_to_fh(self.current_file, read_id, data_path)
signal_path = '{}/Signal'.format(data_path)
self.current_file.create_dataset(
signal_path, data=read.raw, compression='gzip', shuffle=True, dtype='i2')
def _copy_read_group(self, read):
self.current_file.copy(read, read.name)
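# Usage sketch for MultiWriter (hypothetical paths): copy read groups verbatim
# from an existing multi-read file into batches of up to `reads_per_file` reads
# per output file, mirroring how _subset_reads_to_file drives it above:
#   with MultiWriter('out_dir', None, prefix='subset') as writer:
#       with h5py.File('src.fast5', 'r') as src:
#           for name, grp in src.items():
#               if name.startswith('read_'):
#                   writer.write_read(grp)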