import argparse
from collections import defaultdict, Counter, namedtuple
import concurrent.futures
from functools import partial
import itertools
import logging
from operator import attrgetter
import os
import pickle
import re
import unittest
import warnings
import matplotlib; matplotlib.use('Agg', force=True) # enforce non-interactive backend
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import pysam
from pomoxis.util import get_trimmed_pairs, intervaltrees_from_bed
# Immutable records used throughout the module:
#   AlignSeg: one alignment segment (ref name, query name, aligned pairs, ref length).
#   Error:    one classified error with its ref/query contexts and class labels.
#   Context:  a local alignment window around an error (index of the error
#             within the window, query bases, ref bases).
AlignSeg = namedtuple('AlignSeg', ('rname', 'qname', 'pairs', 'rlen'))
Error = namedtuple('Error', ('rp', 'rname', 'qp', 'qname', 'ref', 'match', 'read', 'counts', 'klass', 'aggr_klass'))
Context = namedtuple('Context', ('p_i', 'qb', 'rb'))

# Characters used to render the match line between ref and query contexts,
# and the column separator for the output database.
_match_ = ' '
_indel_ = ':'
_sub_ = '|'
_sep_ = '\t'
# Upper bin edges used to bucket indel sizes in class labels; indels of
# length >= the first value are not considered as HP splitting/joining.
_indel_sizes_=(5,10,100)

# (name, predicate) pairs mapping detailed class strings onto aggregate
# classes; order matters -- the first matching predicate wins in
# `get_aggr_klass`, so more specific patterns come first.
_error_groups_ = [
    ('HP sub swp', lambda x: 'HP sub swp' in x),
    ('HP sub ext', lambda x: 'HP sub ext' in x),
    ('HP sub trc', lambda x: 'HP sub trc' in x),
    ('sub HP split', lambda x: 'sub HP split' in x),
    ('ins HP split', lambda x: 'ins HP split' in x),
    ('sub HP join', lambda x: 'sub HP join' in x),
    ('del HP join', lambda x: 'del HP join' in x),
    ('multi ins >', lambda x: bool(re.search('multi .*ins >', x))),
    ('multi ins <', lambda x: bool(re.search('multi .*ins <', x))),
    ('multi del >', lambda x: bool(re.search('multi .*del >', x))),
    ('multi del <', lambda x: bool(re.search('multi .*del <', x))),
    ('HP del', lambda x: 'HP del' in x),
    ('HP ins', lambda x: 'HP ins' in x),
    ('HP sub', lambda x: 'HP sub' in x),
    ('fwd repeat ins', lambda x: 'fwd repeat ins' in x),
    ('rev repeat ins', lambda x: 'rev repeat ins' in x),
    ('fwd repeat del', lambda x: 'fwd repeat del' in x),
    ('rev repeat del', lambda x: 'rev repeat del' in x),
    ('ins', lambda x: 'ins' in x),
    ('del', lambda x: 'del' in x),
    ('sub', lambda x: 'sub' in x),
]
def get_errors(aln, tree=None):
    """Find positions of errors in an alignment.

    :param aln: iterable of `AlignPos` objects (qi, qb, ri, rb).
    :param tree: `intervaltree.IntervalTree` object of regions to analyse,
        or None to analyse everything.
    :returns: ([(ri, qi, error_type, (last_ri, last_qi))], aligned_ref_len, n_masked)
        ri, qi: ref and query positions
        error_type: 'D', 'I' or 'S'
        last_ri, last_qi: ref and query positions of the last match
        aligned_ref_len: total aligned reference length (taking account of
            the masking tree)
        n_masked: number of alignment columns skipped due to masking.
    """
    err = []
    last_qi = None
    last_ri = None
    pos = None
    n_masked = 0
    aligned_ref_len = 0
    for (qi, qb, ri, rb) in aln:
        if tree is not None:
            # track the most recent reference position so that insertion
            # columns (ri is None) can still be tested against the mask
            pos = ri if ri is not None else pos
            if not tree.overlaps(pos) or (ri is None and not tree.overlaps(pos + 1)):
                # if ri is None, we are in an insertion, check if pos + 1 overlaps
                # (ref position of ins is arbitrary)
                n_masked += 1
                continue
        if qi is None:  # deletion: ref base with no query base
            last_ri = ri  # ri will not be None
            err.append((ri, qi, 'D', (last_ri, last_qi)))
            aligned_ref_len += 1
        elif ri is None:  # insertion: query base with no ref base
            last_qi = qi  # qi will not be None
            err.append((ri, qi, 'I', (last_ri, last_qi)))
        else:  # aligned column: match or substitution
            last_qi = qi
            last_ri = ri
            if qb != rb:
                err.append((ri, qi, 'S', (last_ri, last_qi)))
            aligned_ref_len += 1
    return err, aligned_ref_len, n_masked
def rle(it):
    """Calculate a run length encoding (rle), of an input vector.

    :param it: iterable.
    :returns: structured array with fields `length`, `start`, and `value`.
    """
    dtype = [('length', int), ('start', int),
             ('value', np.array(it[0]).dtype)]

    def runs():
        pos = 0
        for value, grp in itertools.groupby(it):
            run_len = len(list(grp))
            yield run_len, pos, value
            pos += run_len

    return np.fromiter(runs(), dtype=dtype)
def get_run(i, runs):
    """Find run to which the i'th element belongs.

    :param i: int, element index within input to `rle`.
    :param runs: structured array as returned by `rle`.
    :returns: int, element index within runs to which i belongs.
    """
    # exclusive end position of each run
    ends = runs['start'] + runs['length']
    # candidate run indices found by bisecting on the run starts and on the
    # run ends (both clamped to the last run); the two candidates can differ
    # when i falls strictly inside a run
    start_i = min(np.searchsorted(runs['start'], i), len(runs) - 1)
    end_i = min(np.searchsorted(ends, i), len(runs) - 1)
    start = runs['start'][start_i]
    end = ends[end_i] - 1  # inclusive last element index of the end-candidate
    # pick whichever candidate boundary lies closer to i
    run_i = start_i if np.argmin(np.abs([start - i, end - i])) == 0 else end_i
    return run_i
def _get_context_bounds(p, aln, search_by_q, offset):
"""Find start and en d of context.
In the simplest case this will be p-offset:p+offset, but we adjust
for boundaries and to ensure we don't start/end on an error or within a HP.
:param p: int, position (ref position, or query position)
:param aln: iterable of `AlignPos` objects.
:param search_by_q: bool, whether to search by query position (typically done if qi is None).
:returns: (int start index, int end index, int index of p within aln[start:end])
"""
if search_by_q:
pos = [x.qpos for x in aln]
else:
pos = [x.rpos for x in aln]
offset_tmp = offset
start_p = max(p - offset, pos[0])
end_p = min(p + offset, pos[-1])
s = pos.index(start_p)
e = pos.index(end_p)
# extend context until it does not start/end with an error or in a HP.
while True:
s_is_match = aln[s].qbase == aln[s].rbase
sq_not_hp = aln[s+1].qbase != aln[s].qbase
sr_not_hp = aln[s+1].rbase != aln[s].rbase
if (s_is_match and sq_not_hp and sr_not_hp) or s == 0:
break
else:
s -= 1
while True:
e_is_match = aln[e-1].qbase == aln[e-1].rbase
eq_not_hp = aln[e-1].qbase != aln[e-2].qbase
er_not_hp = aln[e-1].rbase != aln[e-2].rbase
if (e_is_match and eq_not_hp and er_not_hp) or e == len(aln):
break
else:
e += 1
p_i = pos.index(p)
return s, e, p_i - s
def are_adjacent(inds):
    """Check if all int indices in the iterable are consecutive.

    :param inds: sequence of ints (must be non-empty).
    :returns: bool
    """
    prev = inds[0]  # raises IndexError on empty input, as before
    for cur in inds[1:]:
        if cur - prev != 1:
            return False
        prev = cur
    return True
def is_in_hp(seq, p_i):
    """Check whether position `p_i` of an aligned (gapped) sequence lies
    within a homopolymer, skipping over '-' gap characters.

    :param seq: sequence of single-character strings (bases or '-').
    :param p_i: int, position to check.
    :returns: bool
    """
    # count gap characters immediately after and before p_i
    gaps_fwd = 0
    for c in seq[p_i + 1:]:
        if c != '-':
            break
        gaps_fwd += 1
    gaps_rev = 0
    for c in reversed(seq[:p_i]):
        if c != '-':
            break
        gaps_rev += 1
    if seq[p_i] == '-':
        # a gap is "in" a HP if its flanking bases match
        return seq[p_i - gaps_rev - 1].upper() == seq[p_i + gaps_fwd + 1].upper()
    base = seq[p_i].upper()
    # short-circuit: only index past the end if the previous base differs
    return (base == seq[p_i - 1 - gaps_rev].upper() or
            base == seq[p_i + 1 + gaps_fwd].upper())
def classify_hp_sub(p_i, adjacent, errors, match_line, rb_runs,
                    qb_runs, qp_is_hp, rp_is_hp):
    """Classify a substitution error with respect to homopolymer (HP) changes.

    Sub error archetypes:
      * 'swap',  e.g. r TTCCC -> q TTTCC or r CCTTCCC -> q CTTTTCC
      * 'split', e.g. r TTTTTT -> q TTCCTT or r TTTTTT -> q TCTTCT
      * 'join',  e.g. r TTCCTT -> TTTTTT
      * 'trunc', e.g. r TTTTTT -> q TTTTTC, r TTTTTT -> q CTTTTC,
                 r TT -> q TC, r TT -> q GC
      * 'ext',   e.g. r TTTTTC -> q TTTTTT

    :param p_i: int, index of the central error within the context.
    :param adjacent: bool, whether the sub errors are adjacent (currently unused).
    :param errors: dict mapping error type to list of context indices.
    :param match_line: str, per-column match/sub/indel markers of the context.
    :param rb_runs: structured array, `rle` of the ref context bases.
    :param qb_runs: structured array, `rle` of the query context bases.
    :param qp_is_hp: bool, query position is within a HP (currently unused).
    :param rp_is_hp: bool, ref position is within a HP (currently unused).
    :returns: str classification, or None if the HP change is not at this position.
    """
    hp_kls = None
    # runs containing the central position in ref and query
    rp_run_ind = get_run(p_i, rb_runs)
    qp_run_ind = get_run(p_i, qb_runs)
    rp_run = rb_runs[rp_run_ind]
    qp_run = qb_runs[qp_run_ind]
    # run membership of every substitution error
    sub_rb_inds = [get_run(i, rb_runs) for i in errors['sub']]
    sub_qb_inds = [get_run(i, qb_runs) for i in errors['sub']]
    # query runs spanning the extent of the ref run, and vice versa
    # (previously the code also materialised the runs themselves into unused
    # locals; only the index sets are needed)
    q_run_inds_in_r_run = set([get_run(i, qb_runs) for i in range(rp_run['start'], rp_run['start'] + rp_run['length'])])
    r_run_inds_in_q_run = set([get_run(i, rb_runs) for i in range(qp_run['start'], qp_run['start'] + qp_run['length'])])
    cols = ['value', 'length']
    if qp_run[cols] == rp_run[cols]:
        # this is a false alarm, though some HP in the context has changed,
        # it was not at this position
        return hp_kls
    # if more than 1 query run is in ref run, we have a split / trunc
    # swap is a special case of truncation in which another HP has gained a base
    if len(q_run_inds_in_r_run) > 1:
        # all subs should belong to same RHP, otherwise we have a mess.
        if all([i == rp_run_ind for i in sub_rb_inds]):
            hp_start_ind = rp_run['start']
            hp_end_ind = hp_start_ind + rp_run['length'] - 1
            match_runs = rle(match_line)
            match_sub_inds = [get_run(i, match_runs) for i in errors['sub']]
            # if subs are not flanking we have a split
            if min(errors['sub']) > hp_start_ind and max(errors['sub']) < hp_end_ind:
                hp_kls = 'sub HP split ({}{})'.format(rp_run['value'], rp_run['length'])
            # if (possibly multiple adjacent) subs are all flanking the HP
            elif set(match_sub_inds).issubset([get_run(hp_start_ind, match_runs),
                                               get_run(hp_end_ind, match_runs)]):
                # if query is a HP, this is a swap
                if qp_run['length'] > 1:
                    hp_kls = 'HP sub swp ({}{}->{}{},{})'.format(
                        rp_run['value'], rp_run['length'],
                        qp_run['value'], qp_run['length'] - len(errors['sub']), len(errors['sub']))
                else:
                    hp_kls = 'flk HP sub trc ({}{}->{})'.format(
                        rp_run['value'], rp_run['length'],
                        rp_run['length'] - len(errors['sub']))
            else:
                hp_kls = 'complex HP sub trc ({}{})'.format(rp_run['value'], rp_run['length'])
        else:
            hp_kls = 'messy HP sub trc'
    # if more than 1 ref run is in query run, we have a join or extension
    # swap is a special case of extension in which another HP has lost a base
    elif len(r_run_inds_in_q_run) > 1:
        # all subs should belong to same QHP, else we have a mess
        if all([i == qp_run_ind for i in sub_qb_inds]):
            hp_start_ind = qp_run['start']
            hp_end_ind = hp_start_ind + qp_run['length'] - 1
            match_runs = rle(match_line)
            match_sub_inds = [get_run(i, match_runs) for i in errors['sub']]
            # if subs are not flanking we have a join
            if min(errors['sub']) > hp_start_ind and max(errors['sub']) < hp_end_ind:
                hp_kls = 'sub HP join ({}{})'.format(qp_run['value'], qp_run['length'])
            # if (possibly multiple adjacent) subs are all flanking the HP
            elif set(match_sub_inds).issubset([get_run(hp_start_ind, match_runs),
                                               get_run(hp_end_ind, match_runs)]):
                # if ref is a HP, this is a swap
                if rp_run['length'] > 1:
                    hp_kls = 'HP sub swp ({}{}->{}{},{})'.format(
                        rp_run['value'], rp_run['length'],
                        qp_run['value'], qp_run['length'] - len(errors['sub']), len(errors['sub']))
                else:
                    hp_kls = 'flk HP sub ext ({}{}->{})'.format(
                        qp_run['value'], qp_run['length'] - len(errors['sub']), qp_run['length'])
            else:
                hp_kls = 'complex HP sub ext ({}{})'.format(qp_run['value'], qp_run['length'])
        else:
            hp_kls = 'messy HP sub ext'
    # if we just have 1 run of each, we might have a complete sub of the HP
    # e.g. TT->CC
    elif len(r_run_inds_in_q_run) == 1 and len(q_run_inds_in_r_run) == 1:
        if rp_run['value'] != qp_run['value']:
            hp_kls = 'HP sub swp ({}{}->{}{},{})'.format(
                rp_run['value'], rp_run['length'],
                qp_run['value'], 0, qp_run['length'])
    return hp_kls
def classify_hp_indel(p_i, key, errors, runs1, seq2):
    """Look for a specific kind of HP indel that splits or joins two HPs.

    :param p_i: int, index of error.
    :param key: key of error type within errors (should be 'ins' or 'del').
    :param errors: dict of error positions.
    :param runs1: np.ndarray, rle encoding of sequence1 (query if deletions
        join two HPs, or ref if insertions split a HP).
    :param seq2: iterable of str of sequence2 (ref if deletions join two HPs,
        or query if insertions split a HP).
    :returns: str classification or None.
    """
    # look for deletions in query HPs which indicate two ref HPs are joined,
    # or look for insertions in ref HPs which split them
    assert key in ['ins', 'del']
    _type_ = {'ins': 'split', 'del': 'join'}
    hp_kls = None
    p_run_ind = get_run(p_i, runs1)
    # if ref/query is not a HP, return
    # NOTE(review): indexing p_run_ind +/- 1 assumes the gap run is flanked
    # within the context; contexts are built to start/end on matches, but
    # confirm for pathological inputs at the context edges.
    if runs1[p_run_ind - 1]['value'] != runs1[p_run_ind + 1]['value']:
        logging.debug('Not a HP {} {}'.format(runs1[p_run_ind - 1]['value'],
                                              runs1[p_run_ind + 1]['value']))
        return hp_kls
    # if we have multiple non-adjacent indels, we could have
    # HP which should be split into >2 runs or HP joined from >2 HPs
    hp_base = runs1[p_run_ind - 1]['value']
    # gather all runs belonging to the gap-interrupted HP either side of p_i
    helper = lambda x: runs1[x]['value'] in ['-', hp_base]
    run_inds_in_hp = [i for i in itertools.takewhile(helper, range(p_run_ind - 1, 0, -1))]
    run_inds_in_hp += [i for i in itertools.takewhile(helper, range(p_run_ind, len(runs1)))]
    runs1_in_hp = runs1[run_inds_in_hp]
    runs1_in_hp = runs1_in_hp[np.where(runs1_in_hp['value'] == hp_base)]
    hp_len = np.sum(runs1_in_hp['length'])
    # NOTE(review): run_inds_in_hp is ordered backwards-then-forwards from
    # p_i, so element [0] is the run just before p_i, not necessarily the
    # leftmost run of the HP -- confirm intent if the HP spans >2 runs.
    hp_start_ind = runs1_in_hp[0]['start']
    hp_end_ind = runs1_in_hp[-1]['start'] + runs1_in_hp[-1]['length'] - 1
    # check that the HPs in the query and ref actually differ (some insertions
    # might preserve the HP length e.g. CG--GC -> CGGCGC, in which case the HP
    # is not a split/join
    runs2_in_hp = rle(seq2[hp_start_ind: hp_end_ind + 1])
    runs2_in_hp = runs2_in_hp[np.where(runs2_in_hp['length'] > 1)]
    if (len(runs2_in_hp) > 0 and
            np.any(np.logical_and(runs2_in_hp['length'] == hp_len,
                                  runs2_in_hp['value'] == hp_base))):
        return hp_kls
    # we have a split/join if any indels are within hp
    if any([i > hp_start_ind and i < hp_end_ind for i in errors[key]]):
        # if any indels are outside HP run, we have a mess
        if min(errors[key]) < hp_start_ind or max(errors[key]) > hp_end_ind:
            subtype = 'messy'
        elif are_adjacent(errors[key]):
            if len(errors[key]) == 1:
                subtype = 'simple'
            else:
                subtype = 'multi'
        else:
            subtype = 'complex'
        hp_kls = '{} {} HP {} ({}{})'.format(subtype, key, _type_[key], hp_base, hp_len)
    return hp_kls
def get_match_line_and_err_index(context):
    """Render the match line for a context and index its error positions.

    :param context: `Context` object.
    :returns: (str match line, dict of error type -> list of indices,
        str error type at `context.p_i`).
    """
    line_chars = []
    errors = defaultdict(list)
    kind = ''
    for i, (q, r) in enumerate(zip(context.qb, context.rb)):
        if q == r:
            line_chars.append(_match_)
        elif q == '-' or r == '-':  # indel column
            line_chars.append(_indel_)
            kind = 'del' if q == '-' else 'ins'
        else:  # substitution column
            line_chars.append(_sub_)
            kind = 'sub'
        if line_chars[-1] != _match_:
            errors[kind].append(i)
        if i == context.p_i:  # this is the central error we are classifying
            p_k = kind
    return ''.join(line_chars), errors, p_k
def preprocess_error(p, aln, search_by_q, offset=10):
    """Extract a local alignment context around an error position.

    :param p: int, position (ref position, or query position)
    :param aln: iterable of `AlignPos` objects.
    :param search_by_q: bool, whether to search by query position (typically done if qi is None).
    :param offset: int, initial half-width of the context window.
    :returns: `Context` object
    """
    start, end, p_i = _get_context_bounds(p, aln, search_by_q, offset)
    _, qb, _, rb = zip(*aln[start:end])
    return Context(p_i, qb, rb)
def simple_klass(adjacent, n, err_type, indel_sizes):
    """Build a basic classification string for an error.

    :param adjacent: bool, whether the errors are adjacent.
    :param n: int, number of errors.
    :param err_type: str, error type (may contain 'ins'/'del'/'sub').
    :param indel_sizes: iterable of int, bin edges for indel sizes.
    :returns: str classification.
    """
    if adjacent:
        descr = 'simple' if n == 1 else 'multi'
    else:
        descr = 'complex'
    # indels additionally get a binned size label
    if 'ins' not in err_type and 'del' not in err_type:
        return "{} {}".format(descr, err_type)
    return "{} {} {}".format(descr, err_type, _get_size(n, indel_sizes))
def classify_error(context, indel_sizes=None):
    """Classify error within an alignment.

    :param context: `Context` object.
    :param indel_sizes: iterable of int, for binning indel sizes.
        indels >= to indel_sizes[0] will not be considered as HP
        splitting/joining indels.
    :returns: (str reference_context,
               str match_line,
               str query_context,
               dict counts of sub/ins/del within context,
               str classification)
    """
    if indel_sizes is None:
        indel_sizes = _indel_sizes_
    p_i, qb, rb = context
    match_line, errors, p_k = get_match_line_and_err_index(context)
    # find homopolymers within context
    qb_runs = rle([b.upper() for b in qb])
    rb_runs = rle([b.upper() for b in rb])
    qb_hps = qb_runs[np.where(qb_runs['length'] > 1)]
    rb_hps = rb_runs[np.where(rb_runs['length'] > 1)]
    cols = ['value', 'length']
    # any difference in the set of HPs between query and ref?
    hp_changed = len(qb_hps) != len(rb_hps) or np.any(qb_hps[cols] != rb_hps[cols])
    # hp might also be changed if we have an ins in the middle of a ref hp
    if not hp_changed:
        hp_changed = any([(rb_runs[i]['value'] == '-' and
                           rb_runs[i - 1]['value'] == rb_runs[i + 1]['value'])
                          for i in range(1, len(rb_runs) - 1)])
    # hp might also be changed if we have an del in the middle of a query hp
    if not hp_changed:
        hp_changed = any([(qb_runs[i]['value'] == '-' and
                           qb_runs[i - 1]['value'] == qb_runs[i + 1]['value'])
                          for i in range(1, len(qb_runs) - 1)])
    hp_kls = None
    # Try to class the error, make it up as we go along!
    dels = len(errors['del'])
    ins = len(errors['ins'])
    indels = dels + ins
    subs = len(errors['sub'])
    # is this position part of a hp in query or ref?
    qp_is_hp = is_in_hp(qb, p_i)
    rp_is_hp = is_in_hp(rb, p_i)
    hp_subtypes = {(False, False): '',
                   (True, False): 'RHP',
                   (False, True): 'QHP',
                   (True, True): 'HP',
                   }
    hp_subtype = hp_subtypes[(rp_is_hp, qp_is_hp)]
    err_type = '{} {}'.format(hp_subtype, p_k) if hp_subtype != '' else p_k
    kls = 'unknown ({})'.format(err_type)  # fallback classification
    if (subs > 0 and indels > 0) or (ins > 0 and dels > 0):
        # mixed error types within one context: give up on detail
        kls = 'complex mess ({})'.format(err_type)
    elif indels == 0:  # we have a sub
        adjacent = are_adjacent(errors['sub'])
        kls = simple_klass(adjacent, subs, err_type, indel_sizes)
        if hp_changed:
            hp_kls = classify_hp_sub(p_i, adjacent, errors, match_line, rb_runs,
                                     qb_runs, qp_is_hp, rp_is_hp)
            kls = hp_kls if hp_kls is not None else kls
    elif dels == 0:  # insertions only
        adjacent = are_adjacent(errors['ins'])
        # only short insertions are candidates for HP splitting
        if hp_changed and ins < indel_sizes[0]:
            hp_kls = classify_hp_indel(p_i, 'ins', errors, rb_runs, qb)
            kls = hp_kls if hp_kls is not None else kls
        if hp_kls is None:
            kls = simple_klass(adjacent, ins, err_type, indel_sizes)
            if adjacent:
                i = errors['ins'][0]
                b = qb[i]
                # length of the HP containing the first inserted base in
                # query, and in ref just after the insertion
                q_hp_len = len([x for x in itertools.takewhile(lambda x: x == b, qb[i:])])
                r_hp_len = len([x for x in itertools.takewhile(lambda x: x == b, rb[i + ins:])])
                inserted = qb[i:i + ins]
                # get ref_after insertions to check if we have a repeat
                repeat_start = i + ins
                repeat_end = repeat_start + ins
                if repeat_end <= len(rb):
                    ref_after = rb[repeat_start: repeat_end]
                else:  # we don't have enough context after the insertion
                    ref_after = []
                if (q_hp_len > 1 or r_hp_len > 1) and all([qb[j] == b and rb[j] == '-' for j in errors['ins']]):
                    kls = "HP ins ({}{}->{})".format(b, r_hp_len, q_hp_len)
                elif inserted == ref_after:
                    kls = "fwd repeat ins len {}".format(ins)
                elif inserted == ref_after[::-1]:
                    kls = "rev repeat ins len {}".format(ins)
    elif ins == 0:  # deletions only
        adjacent = are_adjacent(errors['del'])
        # only short deletions are candidates for HP joining
        if hp_changed and dels < indel_sizes[0]:
            hp_kls = classify_hp_indel(p_i, 'del', errors, qb_runs, rb)
            kls = hp_kls if hp_kls is not None else kls
        if hp_kls is None:
            kls = simple_klass(adjacent, dels, err_type, indel_sizes)
            if adjacent:
                i = errors['del'][0]
                b = rb[i]
                # NOTE(review): comment below refers to an `aln_rb` that is
                # not present here; the HP length is measured within the
                # context only -- HPs extending beyond it are truncated.
                hp_len = len([x for x in itertools.takewhile(lambda x: x == b, rb[i:])])
                deleted = rb[i:i + dels]
                # get ref_after deletions to check if we have a repeat
                repeat_start = i + dels
                repeat_end = repeat_start + dels
                if repeat_end <= len(rb):
                    ref_after = rb[repeat_start: repeat_end]
                else:  # we don't have enough context after the deletion
                    ref_after = []
                if hp_len > 1:
                    if all([rb[j] == b for j in errors['del']]):
                        kls = "HP del ({}{}->{})".format(b, hp_len, hp_len - dels)
                elif deleted == ref_after:
                    kls = "fwd repeat del len {}".format(dels)
                elif deleted == ref_after[::-1]:
                    kls = "rev repeat del len {}".format(dels)
    return ''.join(rb), match_line, ''.join(qb), errors, kls
def _get_size(n, sizes):
l_i = np.searchsorted(sizes, n)
if l_i == len(sizes):
size = "> {}".format(sizes[-1])
else:
size = "<= {}".format(sizes[l_i])
return size
def _process_read(bam, read_num, bed_file=None):
    """Load an alignment from bam and return result of `_process_seg`.

    :param bam: str, bam file.
    :param read_num: int, index of alignment to process.
    :param bed_file: path to .bed file of regions to include in analysis.
    :returns: result of `_process_seg`, or None if the record is
        unmapped/secondary/supplementary or outside the .bed regions.
    """
    trees = None
    if bed_file is not None:
        # NOTE(review): the interval trees are rebuilt on every call (once
        # per read); consider hoisting to the caller if this proves slow.
        trees = intervaltrees_from_bed(bed_file)
    with pysam.AlignmentFile(bam, 'rb') as bam_obj:
        # advance to the read_num'th record; raises StopIteration if the
        # file has fewer records than expected
        gen = (r for r in bam_obj)
        for i in range(read_num + 1):
            rec = next(gen)
        if rec.is_unmapped or rec.is_supplementary or rec.is_secondary:
            return
        if bed_file is not None:
            tree = trees[rec.reference_name]
            if not tree.overlaps(rec.reference_start, rec.reference_end):
                # read does not overlap with any region in the bedfile
                return
        else:
            tree = None
        seg = AlignSeg(rname=rec.reference_name, qname=rec.query_name,
                       pairs=list(get_trimmed_pairs(rec)), rlen=rec.reference_length
                       )
        logging.debug('Loaded query {}'.format(seg.qname))
        return _process_seg(seg, tree)
def _process_seg(seg, tree=None):
    """Classify and count errors within an `AlignSeg` object.

    :param seg: `AlignSeg` object.
    :param tree: `intervaltree.IntervalTree` object of regions to analyse.
    :returns: (seg.rname, aligned_ref_len, error_count, errors, n_masked)
        error_count: `Counter` of error classes
        errors: list of `Error` objects
        n_masked: number of reference positions excluded by tree.
    """
    error_count = Counter()
    errors = []
    pos_and_errors, aligned_ref_len, n_masked = get_errors(seg.pairs, tree)
    for ri, qi, error, approx_pos in pos_and_errors:
        # classify each error within a local context; for insertions
        # (ri is None) the context is located by query position instead
        ref, match, read, counts, klass = classify_error(preprocess_error(
            ri if ri is not None else qi, seg.pairs, search_by_q=(ri is None)
        ))
        rp = ri
        qp = qi
        # a position inside an indel has no exact coordinate on one side;
        # report the last-match position prefixed with '~'
        if rp is None:
            rp = "~{}".format(approx_pos[0])
        if qp is None:
            qp = "~{}".format(approx_pos[1])
        errors.append(Error(rp=rp, rname=seg.rname, qp=qp, qname=seg.qname, ref=ref,
                            match=match, read=read, counts=counts, klass=klass,
                            aggr_klass=get_aggr_klass(klass)))
        error_count[klass] += 1
    if tree is None:
        # without masking every reference position should have been counted
        assert seg.rlen == aligned_ref_len
    logging.debug('Done processing {} aligned to {}'.format(seg.qname, seg.rname))
    return seg.rname, aligned_ref_len, error_count, errors, n_masked
def qscore(d):
    """Calculate a Phred-style quality score, ``-10 * log10(d)``.

    :param d: float or array, error rate(s).
    :returns: q-score(s).
    """
    with warnings.catch_warnings():
        # d may contain zeros; silence the resulting divide warning
        warnings.simplefilter("ignore")
        return -10 * np.log10(d)
def analyze_counts(counts, total_ref_length):
    """Summarise error-class counts as a DataFrame of rates and q-scores.

    :param counts: mapping of class name -> count.
    :param total_ref_length: int, total aligned reference length.
    :returns: `pd.DataFrame` sorted by descending remaining count.
    """
    df = pd.DataFrame(list(counts.items()), columns=['klass', 'count'])
    df['err_rate'] = df['count'] / total_ref_length
    df['fraction_of_total'] = df['count'] / df['count'].sum()
    # cumulate from the rarest class upwards: 'remaining' columns give the
    # error burden left after removing all more-frequent classes
    df = df.sort_values('count', ascending=True)
    df['remaining_count'] = df['count'].cumsum()
    df['remaining_err_rate'] = df['err_rate'].cumsum()
    with warnings.catch_warnings():
        # rates of zero would warn inside log10
        warnings.simplefilter("ignore")
        df['err_rate_q'] = -10 * np.log10(df['err_rate'])
        df['remaining_err_rate_q'] = -10 * np.log10(df['remaining_err_rate'])
    return df.sort_values('remaining_count', ascending=False)
def plot_summary(df, outdir, prefix, ref_len):
    """Create a plot showing Q-scores as largest remaining
    error klass is removed.

    :param df: `pd.DataFrame` with 'klass' and 'remaining_err_rate_q' columns.
    :param outdir: str, output directory (must exist).
    :param prefix: str, prefix for the output filename.
    :param ref_len: int, total reference length (sets the error-free ceiling).
    """
    fig, ax = plt.subplots()
    fig.subplots_adjust(left=0.3)
    y_pos = np.arange(len(df) + 1)
    # q-score if every classified error were removed (one error in ref_len)
    no_error_score = -10 * np.log10(1 / ref_len)
    # Series.append was removed in pandas 2.0; use pd.concat instead
    bar_vals = pd.concat([df['remaining_err_rate_q'], pd.Series([no_error_score])])
    ax.barh(y_pos, bar_vals, align='center', color='green', ecolor='black')
    ax.set_xlabel('Q(Accuracy)')
    ax.set_ylabel('Error Class')
    ax.set_ylim((y_pos[0] - 0.5, y_pos[-1] + 0.5))
    ax.set_yticks(y_pos)
    ax.set_yticklabels(['total error'] + list(df['klass']))
    ax.invert_yaxis()  # labels read top-to-bottom
    xstart, xend = ax.get_xlim()
    ystart, yend = ax.get_ylim()
    ax.text(xend - 2.25, ystart - 0.25, '+')
    ax.set_title('Q-score after removing error class')
    fp = os.path.join(outdir, '{}_remaining_errors.png'.format(prefix))
    fig.savefig(fp)
    plt.close(fig)  # close this specific figure to free memory
def get_aggr_counts(total_counts):
    """Aggregate per-class error counts into broader error classes.

    :param total_counts: `Counter` (or mapping) keyed by detailed class string.
    :returns: `Counter` keyed by aggregate class (see `get_aggr_klass`).
    """
    # (an unused max-indel local was removed; sizing is handled inside
    # get_aggr_klass)
    aggregate_counts = Counter()
    for key, val in total_counts.items():
        aggregate_counts[get_aggr_klass(key)] += val
    return aggregate_counts
def get_aggr_klass(klass):
    """Map a detailed error class string onto an aggregate class.

    The first matching predicate in `_error_groups_` wins; size-binned
    groups ('>'/'<' in the name) are suffixed with the largest indel bin.

    :param klass: str, detailed class as produced by `classify_error`.
    :returns: str aggregate class; unrecognised classes pass through
        unchanged (previously this raised NameError).
    """
    max_indel = max(_indel_sizes_)
    aggr_klass = klass  # fallback for classes matching no group
    for error_type, is_type in _error_groups_:
        if is_type(klass):
            if '>' in error_type or '<' in error_type:
                aggr_klass = '{} {}'.format(error_type, max_indel)
            else:
                aggr_klass = error_type
            break
    return aggr_klass
def main():
    """Entry point: create a catalogue of all query errors in a bam."""
    logging.basicConfig(format='[%(asctime)s - %(name)s] %(message)s', datefmt='%H:%M:%S', level=logging.INFO)
    parser = argparse.ArgumentParser(
        prog='catalogue_errors',
        description='Create a catalogue of all query errors in a bam.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('bam', help='Input alignments (aligned to ref).')
    parser.add_argument('--bed', default=None, help='.bed file of reference regions to include.')
    parser.add_argument('-t', '--threads', type=int, default=1, help='Number of threads for parallel execution.')
    parser.add_argument('-o', '--outdir', default='error_catalogue', help='Output directory.')
    args = parser.parse_args()
    os.mkdir(args.outdir)
    with pysam.AlignmentFile(args.bam, 'rb') as bam:
        n_reads = bam.count()
    total_ref_length = defaultdict(int)
    total_n_ref_sites_masked = defaultdict(int)
    error_count = defaultdict(Counter)
    f = partial(_process_read, args.bam, bed_file=args.bed)
    # make an approximate position (e.g. '~123') into an int
    helper = lambda x: int(x.replace('~', '')) if isinstance(x, str) else x
    db_fh = open(os.path.join(args.outdir, 'error_catalogue_db.txt'), 'w')
    txt_fh = open(os.path.join(args.outdir, 'error_catalogue.txt'), 'w')
    # (column name, accessor) pairs for the tab-separated error database.
    # Positions go through `helper` so approximate positions are written as
    # plain ints. (Bug fix: previously `helper` was applied to the
    # attrgetter object itself -- not a str, so it was returned unchanged
    # and the '~'-stripping conversion never happened.)
    headers = [('ref_name', attrgetter('rname')),
               ('ref_pos', lambda e: helper(e.rp)),
               ('ref_context', attrgetter('ref')),
               ('query_name', attrgetter('qname')),
               ('query_pos', lambda e: helper(e.qp)),
               ('query_context', attrgetter('read')),
               ('class', attrgetter('klass')),
               ('aggr_class', attrgetter('aggr_klass')),
               ('n_ins', lambda e: len(e.counts['ins'])),
               ('n_del', lambda e: len(e.counts['del'])),
               ('n_sub', lambda e: len(e.counts['sub'])),
               ('context_len', lambda e: len(e.ref)),
               ]
    db_fh.write(_sep_.join((h[0] for h in headers)) + '\n')
    with concurrent.futures.ProcessPoolExecutor(max_workers=args.threads) as ex:
        for returned in ex.map(f, range(n_reads)):
            if returned is None:
                # read was unmapped/secondary/supplementary or masked out
                continue
            ref_name, ref_length, counts, errors, n_masked = returned
            error_count[ref_name].update(counts)
            total_ref_length[ref_name] += ref_length
            total_n_ref_sites_masked[ref_name] += n_masked
            for e in errors:
                db_fh.write(_sep_.join((str(h[1](e)) for h in headers)) + '\n')
                txt_fh.write("Ref Pos: {}, {} Pos {}, {}, {}\n".format(e.rp, e.qname, e.qp, e.klass, e.aggr_klass))
                txt_fh.write(e.ref + "\n")
                txt_fh.write(e.match + "\n")
                txt_fh.write(e.read + "\n")
                txt_fh.write(".\n")
    total_counts = Counter()
    aggr_by_ref = {}
    # per-reference summaries, detailed then aggregated
    for ref_name, counts in error_count.items():
        df = analyze_counts(counts, total_ref_length[ref_name])
        fp = os.path.join(args.outdir, '{}_error_summary.txt'.format(ref_name))
        df.to_csv(fp, sep=_sep_, index=False)
        aggr_by_ref[ref_name] = get_aggr_counts(counts)
        df = analyze_counts(aggr_by_ref[ref_name], total_ref_length[ref_name])
        fp = os.path.join(args.outdir, '{}_aggr_error_summary.txt'.format(ref_name))
        df.to_csv(fp, sep=_sep_, index=False)
        plot_summary(df, args.outdir, '{}_aggr'.format(ref_name),
                     ref_len=total_ref_length[ref_name])
        total_counts.update(counts)
    # totals over all references
    df = analyze_counts(total_counts, sum(total_ref_length.values()))
    fp = os.path.join(args.outdir, '{}_error_summary.txt'.format('total'))
    df.to_csv(fp, sep=_sep_, index=False)
    aggregate_counts = get_aggr_counts(total_counts)
    df = analyze_counts(aggregate_counts, sum(total_ref_length.values()))
    fp = os.path.join(args.outdir, '{}_aggr_error_summary.txt'.format('total'))
    df.to_csv(fp, sep=_sep_, index=False)
    plot_summary(df, args.outdir, '{}_aggr'.format('total'),
                 ref_len=sum(total_ref_length.values()))
    # save counts to pickle for any further analysis
    to_save = {'ref_lengths': total_ref_length,
               'n_ref_sites_masked': total_n_ref_sites_masked,
               'counts': {'by_ref': error_count,
                          'by_ref_aggr': aggr_by_ref,
                          'total': total_counts,
                          'total_aggr': aggregate_counts,
                          }
               }
    with open(os.path.join(args.outdir, 'counts.pkl'), 'wb') as fh:
        pickle.dump(to_save, fh)
    db_fh.close()
    txt_fh.close()
    logging.info('All done, check {} for output.'.format(args.outdir))
# script entry point when executed directly
if __name__ == '__main__':
    main()
[docs]class ClassifyErrorTest(unittest.TestCase):
[docs] def test_hp_del_6_5(self):
rb = 'ACAACAGCAGAAAAAACAGGA'
qb = 'ACAACAGCAG-AAAAACAGGA'
p_i = 10
expected = 'HP del (A6->5)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_hp_del_2_0(self):
rb = 'CACTTTCGGCTTGAGGATCA'
qb = 'CACTTTCGGC--GAGGATCA'
p_i = 10
expected = 'HP del (T2->0)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_multi_ins(self):
rb = 'ATGTAATGCC---AAGCTTA'
qb = 'ATGTAATGCCAGAAAGCTT'
p_i = 10
expected = 'multi ins <= 5'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_multi_del(self):
rb = 'ATGTAATGCCAGAAAGCTT'
qb = 'ATGTAATGCC---AAGCTTA'
p_i = 10
expected = 'multi del <= 5'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_hp_ins(self):
rb = 'CACCTGGTGC-AAAAGAGAG'
qb = 'CACCTGGTGCAAAAAGAGAG'
p_i = 10
expected = 'HP ins (A4->5)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_sub_split(self):
rb = 'TAATCTGGCCcCTGCAATGC'
qb = 'TAATCTGGCCTCTGCAATGC'
p_i = 10
expected = 'sub HP split (C4)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_sub_swap_trc(self):
rb = 'ACTGCGTACCtTTTGTATAAT'
qb = 'ACTGCGTACCCTTTGTATAAT'
p_i = 10
expected = 'HP sub swp (T4->C2,1)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_sub_swap_trc2(self):
rb = 'ACTGCGTACCttTTGTATAAT'
qb = 'ACTGCGTACCCCTTGTATAAT'
p_i = 10
expected = 'HP sub swp (T4->C2,2)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
def test_HP_sub_swap_ext(self):
rb = 'ACTGCGTACcTTTTGTATAAT'
qb = 'ACTGCGTACTTTTTGTATAAT'
p_i = 9
expected = 'HP sub swp (T4->C2,1)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_sub_swap_ext(self):
rb = 'ACTGCGTAccTTTTGTATAAT'
qb = 'ACTGCGTATTTTTTGTATAAT'
p_i = 9
expected = 'HP sub swp (C2->T4,2)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_swap_sub_trc(self):
rb = 'CGGGCCTTCCCCtTGCCATTCA'
qb = 'CGGGCCTTCCCCCTGCCATTCA'
p_i = 12
expected = 'HP sub swp (T2->C4,1)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_flk_sub_trc(self):
rb = 'CGAGAAAATCgGGATCGTTG'
qb = 'CGAGAAAATCAGGATCGTTG'
p_i = 10
expected = 'flk HP sub trc (G3->2)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_flk_sub_ext(self):
rb = 'ATGCAACAAGcTTACGCTGC'
qb = 'ATGCAACAAGTTTACGCTGC'
p_i = 10
expected = 'flk HP sub ext (T2->3)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_complex_sub(self):
rb = 'CGAGAAAAAAAGGATCGTTG'
qb = 'CGAGTATAAAAGGATCGTTG'
p_i = 4
expected = 'complex HP sub trc (A7)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_messy_sub(self):
rb = 'CTAaACtGCcGTG'
qb = 'CTACACCGCTGTG'
p_i = 3
expected = 'messy HP sub trc'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_sub_join(self):
rb = 'AGGAACGAATcTCTGAAGCG'
qb = 'AGGAACGAATTTCTGAAGCG'
p_i = 10
expected = 'sub HP join (T3)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_not_complete_sub(self):
rb = 'AGGAACGAATTTCTGAAGCG'
qb = 'AGGAACGAACCCCTGAAGCG'
p_i = 10
expected = 'HP sub swp (T3->C1,3)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_complete_sub(self):
rb = 'AGGAACGAATTTGTGAAGCG'
qb = 'AGGAACGAACCCGTGAAGCG'
p_i = 10
expected = 'HP sub swp (T3->C0,3)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_messy_sub_ext(self):
rb = 'CGGGTCTTTTctTTTTCATC'
qb = 'CGGGTCTTTTTCTTTTCATC'
p_i = 10
expected = 'messy HP sub ext'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_del_join(self):
rb = 'CGGGTCTTTTCTTTTCATC'
qb = 'CGGGTCTTTT-TTTTCATC'
p_i = 10
expected = 'simple del HP join (T8)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_del_join2(self):
rb = 'CGGGTCTTTTCCTTTTCATC'
qb = 'CGGGTCTTTT--TTTTCATC'
p_i = 10
expected = 'multi del HP join (T8)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_del_join3(self):
rb = 'CGGGTCTTTTCATTTCATC'
qb = 'CGGGTCTTTT--TTTCATC'
p_i = 10
expected = 'multi del HP join (T7)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_del_join4(self):
rb = 'CGGGTCTTTTCATTTCATC'
qb = 'CGG-TCTTTT--TTTCATC'
p_i = 10
expected = 'messy del HP join (T7)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_ins_split(self):
rb = 'CGGGTCTT-TTTCATC'
qb = 'CGGGTCTTGTTTCATC'
p_i = 8
expected = 'simple ins HP split (T5)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_ins_split2(self):
rb = 'CGGGTCTT--TTTCATC'
qb = 'CGGGTCTTGGTTTCATC'
p_i = 8
expected = 'multi ins HP split (T5)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_ins_split3(self):
rb = 'CGGGTCTT--TTTCATC'
qb = 'CGGGTCTTGATTTCATC'
p_i = 8
expected = 'multi ins HP split (T5)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_ins_split4(self):
rb = 'CGGGTCTT--TT-TCATC'
qb = 'CGGGTCTTGATTGTCATC'
p_i = 8
expected = 'complex ins HP split (T5)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
[docs] def test_HP_ins_split5(self):
rb = 'CGGG-TCTT--TT-TCATC'
qb = 'CGGGGTCTTGATTGTCATC'
p_i = 9
expected = 'messy ins HP split (T5)'
found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
self.assertEqual(found, expected)
def test_fwd_repeat_ins(self):
    """An inserted 'GC' copying the downstream GC repeat is a fwd repeat ins of length 2."""
    ref = list('ACCTATAACG--GCGCGCTG')
    query = list('ACCTATAACGGCGCGCGCTG')
    klass = classify_error(Context(p_i=10, qb=query, rb=ref))[-1]
    self.assertEqual(klass, 'fwd repeat ins len 2')
def test_rev_repeat_ins(self):
    """An inserted 'CG' (reverse of the adjacent 'GC' repeat unit) is a rev repeat ins of length 2.

    NOTE(review): this method was previously also named ``test_fwd_repeat_ins``,
    duplicating the preceding test; the duplicate name shadowed the earlier
    method so the fwd-repeat case was never executed by the test runner.
    Renamed to match the 'rev repeat ins' class it actually checks.
    """
    rb = 'ACCTATAACG--GCGCGCTG'
    qb = 'ACCTATAACGCGGCGCGCTG'
    p_i = 10
    expected = 'rev repeat ins len 2'
    found = classify_error(Context(p_i=p_i, qb=list(qb), rb=list(rb)))[-1]
    self.assertEqual(found, expected)
def test_HP_ins_split6(self):
    """A single inserted 'C' splitting the T run is a simple ins HP split (T6)."""
    ref = list('GCCGATTTTT-TCTCCCGTA')
    query = list('GCCGATTTTTCTCTCCCGTA')
    klass = classify_error(Context(p_i=10, qb=query, rb=ref))[-1]
    self.assertEqual(klass, 'simple ins HP split (T6)')
def test_complex_HP_del(self):
    """A deletion from a C homopolymer alongside another nearby deletion: complex RHP del <= 5."""
    ref = list('AGGGGGGGGACTTGAACCCCCACGTC')
    query = list('A-GGGGGGGACTTGAA-CCCCACGTC')
    klass = classify_error(Context(p_i=16, qb=query, rb=ref))[-1]
    self.assertEqual(klass, 'complex RHP del <= 5')
def test_long_multi_del(self):
    """A deletion spanning more than 100 reference bases is classed 'multi del > 100'."""
    ref = list(
        'ACCCACACACCACACCCACACACCACACCCACACCACACCCACACCACACCCACACACCACACCCACACCACACCCACACACCACACCCACACACCACACCCACACCACACCCACACCACACCCACACACCACACCACACCCACACACCCACACACCACACACTCTCTTACATCTACCTCTACTCTCGCTGTCACACCTTACCCGGCTTTCTGACCGAAATTAAAAAAAATGAAAATGAAATCCTCTTCTTTAGCCCTACAACACTTTTACATAGCCCTAAATAGCCCTAAATAGCCCTCATGTACGTCTCCTCCAAGCCCTATTGACTCTTACCCGGAGTTTCAGCTAAAGGCTATACTTACT')
    query = list(
        'ACCCACACACCA---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------T')
    klass = classify_error(Context(p_i=12, qb=query, rb=ref))[-1]
    self.assertEqual(klass, 'multi del > 100')
def test_long_multi_ins(self):
    """The mirror case of test_long_multi_del: >100 inserted bases is 'multi ins > 100'."""
    ref = list(
        'ACCCACACACCA---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------T')
    query = list(
        'ACCCACACACCACACCCACACACCACACCCACACCACACCCACACCACACCCACACACCACACCCACACCACACCCACACACCACACCCACACACCACACCCACACCACACCCACACCACACCCACACACCACACCACACCCACACACCCACACACCACACACTCTCTTACATCTACCTCTACTCTCGCTGTCACACCTTACCCGGCTTTCTGACCGAAATTAAAAAAAATGAAAATGAAATCCTCTTCTTTAGCCCTACAACACTTTTACATAGCCCTAAATAGCCCTAAATAGCCCTCATGTACGTCTCCTCCAAGCCCTATTGACTCTTACCCGGAGTTTCAGCTAAAGGCTATACTTACT')
    klass = classify_error(Context(p_i=12, qb=query, rb=ref))[-1]
    self.assertEqual(klass, 'multi ins > 100')