import argparse
from concurrent.futures import ProcessPoolExecutor
import functools
import logging
import multiprocessing
import os
from intervaltree import IntervalTree, Interval
import numpy as np
import pysam
from pomoxis.util import parse_regions, Region
from pomoxis.coverage_from_bam import coverage_summary_of_region
from pomoxis.stats_from_bam import stats_from_aligned_read
def main():
    """Command-line entry point: subsample a bam to uniform or proportional depth."""
    logging.basicConfig(format='[%(asctime)s - %(name)s] %(message)s', datefmt='%H:%M:%S', level=logging.INFO)
    parser = argparse.ArgumentParser(
        prog='subsample_bam',
        description='Subsample a bam to uniform or proportional depth',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('bam',
        help='input bam file.')
    parser.add_argument('depth', nargs='+', type=int,
        help='Target depth.')
    parser.add_argument('-o', '--output_prefix', default='sub_sampled',
        help='Output prefix')
    parser.add_argument('-r', '--regions', nargs='+',
        help='Only process given regions.')
    parser.add_argument('-p', '--profile', type=int, default=1000,
        help='Stride in genomic coordinates for depth profile.')
    parser.add_argument('-O', '--orientation', choices=['fwd', 'rev'],
        help='Sample only forward or reverse reads.')
    parser.add_argument('-t', '--threads', type=int, default=-1,
        help='Number of threads to use.')
    parser.add_argument('-q', '--quality', type=float,
        help='Filter reads by mean qscore.')
    parser.add_argument('-a', '--accuracy', type=float,
        help='Filter reads by accuracy.')
    parser.add_argument('-c', '--coverage', type=float,
        help='Filter reads by coverage (what fraction of the read aligns).')
    parser.add_argument('-l', '--length', type=int, default=None,
        help='Filter reads by read length.')

    eparser = parser.add_mutually_exclusive_group()
    eparser.add_argument('--any_fail', action='store_true',
        help='Exit with an error if any region has insufficient coverage.')
    eparser.add_argument('--all_fail', action='store_true',
        help='Exit with an error if all regions have insufficient coverage.')

    uparser = parser.add_argument_group('Uniform sampling options')
    uparser.add_argument('-x', '--patience', default=5, type=int,
        help='Maximum iterations with no change in median coverage before aborting.')
    uparser.add_argument('-s', '--stride', type=int, default=1000,
        help='Stride in genomic coordinates when searching for new reads. Smaller can lead to more compact pileup.')

    pparser = parser.add_argument_group('Proportional sampling options')
    pparser.add_argument('-P', '--proportional', default=False, action='store_true',
        help='Activate proportional sampling, thus keeping depth variations of the pileup.')
    pparser.add_argument('-S', '--seed', default=None, type=int,
        help='Random seed for proportional downsampling of reads.')

    args = parser.parse_args()
    if args.threads == -1:
        # -1 means "use everything available"
        args.threads = multiprocessing.cpu_count()

    with pysam.AlignmentFile(args.bam) as bam:
        ref_lengths = dict(zip(bam.references, bam.lengths))
        if args.regions is not None:
            regions = parse_regions(args.regions, ref_lengths=ref_lengths)
        else:
            # default to processing every reference in full
            regions = [Region(ref_name=r, start=0, end=ref_lengths[r]) for r in bam.references]

    if args.proportional:
        worker = functools.partial(subsample_region_proportionally, args=args)
    else:
        worker = functools.partial(subsample_region_uniformly, args=args)

    # one boolean per region: did the region reach the requested depth(s)?
    with ProcessPoolExecutor(max_workers=args.threads) as executor:
        enough_depth = list(executor.map(worker, regions))

    if args.any_fail and not all(enough_depth):
        raise RuntimeError('Insufficient read coverage for one or more requested regions.')
    if args.all_fail and not any(enough_depth):
        raise RuntimeError('Insufficient read coverage for all requested regions.')
def subsample_region_proportionally(region, args):
    """Subsample a region to each requested depth, preserving relative depth variation.

    :param region: `Region` named tuple (ref_name, start, end).
    :param args: parsed command line arguments.
    :returns: bool, False if any target depth exceeded the region's median coverage.
    """
    if args.quality is not None or args.coverage is not None or args.accuracy is not None:
        # was `raise NotImplemented(...)`: the NotImplemented singleton is not an
        # exception and calling it raises TypeError; NotImplementedError is correct.
        raise NotImplementedError('Read filtering is not currently supported for proportion subsampling')
    logger = logging.getLogger(region.ref_name)
    coverage_summary = coverage_summary_of_region(region, args.bam, args.stride)
    col = 'depth_{}'.format(args.orientation) if args.orientation is not None else 'depth'
    median_coverage = coverage_summary[col].T['50%']
    with pysam.AlignmentFile(args.bam) as bam:
        def _read_iter():
            # yield only reads that pass the orientation/primary filters
            for r in bam.fetch(region.ref_name, region.start, region.end):
                if not filter_read(r, bam, args, logger):
                    yield r
        reads = _read_iter()
        # query names cannot be longer than 251
        dtype = [('name', 'U251'), ('start', int), ('end', int)]
        read_data = np.fromiter(
            ((r.query_name, r.reference_start, r.reference_end) for r in reads),
            dtype=dtype
        )

    # process targets in ascending order so we can stop at the first unreachable one
    targets = sorted(args.depth)
    found_enough_depth = True
    coverage = np.zeros(region.end - region.start, dtype=np.uint16)
    if args.seed is not None:
        np.random.seed(args.seed)
    for target in targets:
        if target > median_coverage:
            msg = 'Target depth {} exceeds median coverage {}, skipping this depth and higher depths.'
            logger.info(msg.format(target, median_coverage))
            found_enough_depth = False
            break
        fraction = target / median_coverage
        n_reads = int(round(fraction * len(read_data), 0))
        target_reads = np.random.choice(read_data, n_reads, replace=False)
        prefix = '{}_{}X'.format(args.output_prefix, target)
        _write_bam(args.bam, prefix, region, target_reads['name'])
        coverage.fill(0)  # reset coverage for each target depth (integer fill for uint16 array)
        for read in target_reads:
            coverage[read['start'] - region.start:read['end'] - region.start] += 1
        _write_coverage(prefix, region, coverage, args.profile)
        logger.info('Processed {}X: {} reads ({:.2f} %).'.format(target, n_reads, 100 * fraction))
    return found_enough_depth
def filter_read(r, bam, args, logger):
    """Decide whether a read should be filtered out, returning a bool.

    :param r: `pysam.AlignedSegment`.
    :param bam: `pysam.AlignmentFile` the read came from (used for alignment stats).
    :param args: parsed command line arguments (orientation/quality/accuracy/coverage/length).
    :param logger: logger for reporting filtered reads.
    :returns: True if the read should be discarded, False otherwise.
    """
    # keep primary alignments only
    if (r.is_secondary or r.is_supplementary):
        return True
    # filter orientation
    if (r.is_reverse and args.orientation == 'fwd') or \
        (not r.is_reverse and args.orientation == 'rev'):
        return True
    # filter quality
    if args.quality is not None:
        mean_q = np.mean(r.query_qualities)
        if mean_q < args.quality:
            logger.debug("Filtering {} by quality ({:.2f}).".format(r.query_name, mean_q))
            return True
    # filter accuracy or alignment coverage
    if args.accuracy is not None or args.coverage is not None or args.length is not None:
        stats = stats_from_aligned_read(r, bam.references, bam.lengths)
        if args.accuracy is not None and stats['acc'] < args.accuracy:
            logger.info("Filtering {} by accuracy ({:.2f}).".format(r.query_name, stats['acc']))
            return True
        if args.coverage is not None and stats['coverage'] < args.coverage:
            logger.info("Filtering {} by coverage ({:.2f}).".format(r.query_name, stats['coverage']))
            return True
        # report the same value the condition tests ('read_length', not 'length')
        if args.length is not None and stats['read_length'] < args.length:
            logger.info("Filtering {} by length ({:.2f}).".format(r.query_name, stats['read_length']))
            return True
    # don't filter
    return False
def _nearest_overlapping_point(src, point):
"""Find the interval with the closest start point to a given point.
:param src: IntervalTree instance.
:param point: query point.
:returns: Interval instance of interval with closest start.
"""
items = src.at(point)
if len(items) == 0:
return None
items = sorted(items, key=lambda x: x.end - x.begin, reverse=True)
items.sort(key=lambda x: abs(x.begin - point))
return items[0]
def _write_bam(bam, prefix, region, sequences):
    """Write reads with the given query names overlapping a region to a new, indexed bam.

    :param bam: path to source bam.
    :param prefix: output filename prefix.
    :param region: `Region` named tuple (ref_name, start, end).
    :param sequences: iterable of query names to keep (one alignment per name).
    """
    sequences = set(sequences)
    taken = set()
    output = '{}_{}.{}'.format(prefix, region.ref_name, os.path.basename(bam))
    # context managers ensure both files are closed even if fetch/write raises
    # (the original left handles open on exception)
    with pysam.AlignmentFile(bam, "rb") as src_bam:
        with pysam.AlignmentFile(output, "wb", template=src_bam) as out_bam:
            for read in src_bam.fetch(region.ref_name, region.start, region.end):
                # write each selected query name at most once
                if read.query_name in sequences and read.query_name not in taken:
                    out_bam.write(read)
                    taken.add(read.query_name)
    pysam.index(output)
def _write_coverage(prefix, region, coverage, profile):
# depth profile
output = '{}_{}.depth'.format(prefix, region.ref_name)
end = profile * (len(coverage) // profile)
cov_blocks = coverage[0:end].reshape(-1, profile)
depth_profile = np.mean(cov_blocks, axis=1, dtype=np.uint32)
start = region.start + profile // 2
positions = (start + profile * x for x in range(len(depth_profile)))
with open(output, 'w') as fh:
fh.write("position\tdepth\n")
for pos, depth in zip(positions, depth_profile):
fh.write("{}\t{}\n".format(pos, depth))
# Allow invocation as a plain script in addition to the console entry point.
if __name__ == '__main__':
    main()