import logging
from itertools import combinations
import networkx as nx
import numpy as np
from pore_c.model import (
    AlignmentRecordDf,
    FragmentRecordDf,
    PoreCContactRecord,
    PoreCContactRecordDf,
    PoreCRecordDf,
)
logger = logging.getLogger(__name__)


def assign_fragments(
    align_table: AlignmentRecordDf,
    fragment_df: FragmentRecordDf,
    mapping_quality_cutoff: int = 1,
    min_overlap_length: int = 10,
    containment_cutoff: float = 99.0,
) -> PoreCRecordDf:
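    """Assign each alignment to the restriction fragment it overlaps best and
    flag alignments that fail the alignment-level and per-read filters.

    Every row of the returned table has ``pass_filter`` and ``filter_reason``
    set (``"pass"`` for alignments that survive all filters).
    """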
    from pore_c.model import PoreCRecord
    # initialise everything as having passed filter
    pore_c_table = PoreCRecord.init_dataframe(align_table)
    dtype = pore_c_table.dtypes
    align_types = pore_c_table.align_type.value_counts()
    some_aligns = align_types["unmapped"] != align_types.sum()
    if not some_aligns:
        logger.warning("No alignments in dataframe")
        pore_c_table["pass_filter"] = False
        pore_c_table["filter_reason"] = "unmapped"
        return pore_c_table
    fragment_assignments = assign_fragment(pore_c_table, fragment_df, min_overlap_length, containment_cutoff)
    pore_c_table = pore_c_table.set_index("align_idx", drop=False)
    pore_c_table.update(fragment_assignments)
    pore_c_table = pore_c_table.astype(dtype)
    # apply alignment-level filters
    unmapped_mask = pore_c_table.align_type == "unmapped"
    pore_c_table.loc[unmapped_mask, "pass_filter"] = False
    pore_c_table.loc[unmapped_mask, "filter_reason"] = "unmapped"
    # if not unmapped, but mapping quality at or below the cutoff, then fail
    fail_mq_mask = ~unmapped_mask & (pore_c_table.mapping_quality <= mapping_quality_cutoff)
    pore_c_table.loc[fail_mq_mask, "pass_filter"] = False
    pore_c_table.loc[fail_mq_mask, "filter_reason"] = "low_mq"
    # short overlap filter
    short_overlap_mask = ~(unmapped_mask | fail_mq_mask) & (pore_c_table.overlap_length < min_overlap_length)
    pore_c_table.loc[short_overlap_mask, "pass_filter"] = False
    pore_c_table.loc[short_overlap_mask, "filter_reason"] = "short_overlap"
    # no need to do other checks if nothing left
    if pore_c_table["pass_filter"].any():
        # for the remaining alignments filtering happens on a per-read basis
        by_read_res = (
            pore_c_table[pore_c_table.pass_filter]
            .groupby("read_idx", sort=False, as_index=False)
            .apply(apply_per_read_filters)
        )
        pore_c_table.update(by_read_res[["pass_filter", "filter_reason"]])
        pore_c_table = pore_c_table.astype(dtype)
    else:
        logger.warning("No alignments passed filter")
    pore_c_table.loc[pore_c_table.pass_filter, "filter_reason"] = "pass"
    return pore_c_table.reset_index(drop=True).sort_values(["read_idx", "align_idx"], ascending=True) 


def assign_fragment(pore_c_table, fragment_df, min_overlap_length: int, containment_cutoff: float):
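    """Intersect alignments with restriction fragments using pyranges and, for
    each alignment, keep the fragment with the longest overlap together with
    overlap and containment statistics.
    """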
    import pyranges as pr
    align_range = pr.PyRanges(
        pore_c_table[["chrom", "start", "end", "align_idx"]].rename(
            columns={"chrom": "Chromosome", "start": "Start", "end": "End"}
        )
    )
    fragment_range = pr.PyRanges(
        fragment_df[["chrom", "start", "end", "fragment_id"]].rename(
            columns={"chrom": "Chromosome", "start": "Start", "end": "End"}
        )
    )
    # all overlaps, one to many
    # TODO: pyranges API has changed so that the new_position call overwrites the alignment start and end
    overlaps = align_range.join(fragment_range).new_position("intersection")
    if len(overlaps) == 0:
        raise ValueError("No overlaps found between alignments and fragments, this shouldn't happen")
    overlaps = (
        overlaps.df.rename(
            columns={
                "Start": "start",
                "End": "end",
                "Start_a": "align_start",
                "End_a": "align_end",
                "Start_b": "fragment_start",
                "End_b": "fragment_end",
            }
        )
        .eval("overlap_length = (end - start)")
        # .query(f"overlap_length >= {min_overlap_length}")  # TODO: what if restriction fragment < minimum
        .eval("perc_of_alignment = (100.0 * overlap_length) / (align_end - align_start)")
        .eval("perc_of_fragment = (100.0 * overlap_length) / (fragment_end - fragment_start)")
        .eval(f"is_contained = (perc_of_fragment >= {containment_cutoff})")
    )
    # per-alignment statistics
    by_align = overlaps.groupby("align_idx", sort=True)
    rank = by_align["overlap_length"].rank(method="first", ascending=False).astype(int)
    overlaps["overlap_length_rank"] = rank
    best_overlap = overlaps[overlaps.overlap_length_rank == 1].set_index(["align_idx"])
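    # count how many fragments each alignment overlaps, and how many of those
    # fragments are (almost) fully contained within the alignment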
    contained_fragments = (
        by_align["is_contained"]
        .agg(["size", "sum"])
        .astype({"sum": int})
        .rename(columns={"size": "num_overlapping_fragments", "sum": "num_contained_fragments"})
    )
    align_df = contained_fragments.join(
        best_overlap[
            [
                "fragment_id",
                "fragment_start",
                "fragment_end",
                "overlap_length",
                "perc_of_alignment",
                "perc_of_fragment",
                "is_contained",
            ]
        ]
    )
    dtype = {col: dtype for col, dtype in pore_c_table.dtypes.items() if col in align_df.columns}
    align_df = align_df.astype(dtype)
    return align_df 


def apply_per_read_filters(read_df):
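    """Apply the read-level filters (singleton, exact query overlap, shortest
    path) to the alignments of a single read."""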
    return read_df.pipe(filter_singleton).pipe(filter_exact_overlap_on_query).pipe(filter_shortest_path) 


def filter_singleton(read_df):
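    """Fail a read's only alignment: a lone alignment cannot form a contact."""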
    if len(read_df) == 1:  # if you have a single alignment at this point you fail
        read_df.loc[:, "pass_filter"] = False
        read_df.loc[:, "filter_reason"] = "singleton"
    return read_df 


def filter_exact_overlap_on_query(read_df):
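    """Where multiple alignments cover exactly the same read interval, keep the
    one with the highest align_score and fail the rest."""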
    overlap_on_read = read_df.duplicated(subset=["read_start", "read_end"], keep=False)
    if overlap_on_read.any():
        best_align_idx = read_df.loc[overlap_on_read, :].groupby(["read_start", "read_end"])["align_score"].idxmax()
        overlap_on_read[best_align_idx.values] = False
        read_df.loc[overlap_on_read, "pass_filter"] = False
        read_df.loc[overlap_on_read, "filter_reason"] = "overlap_on_read"
    return read_df 


def minimap_gapscore(length, o1=4, o2=24, e1=2, e2=1):
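    """Minimap2-style dual affine gap cost: min(o1 + length * e1, o2 + length * e2)."""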
    return min([o1 + int(length) * e1, o2 + int(length) * e2]) 


def bwa_gapscore(length, O=5, E=2):  # noqa: E741
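    """BWA-style affine gap cost O + length * E (defaults noted below)."""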
    # O=5 E=2 is default for bwa bwasw
    # O=6 E=1 is default for bwa mem
    return O + length * E 


def create_align_graph(aligns, gap_fn):
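    """Build a weighted DAG over a read's alignments.

    Edge weights combine read-gap penalties with the (negated) alignment score
    of the destination alignment, so the minimum-weight ROOT -> SINK path
    corresponds to the highest-scoring, least-gapped chain of alignments
    along the read.
    """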
    # visit the alignments in order of increasing end coordinate on the read;
    # keep the original index ids so the failing alignments can be filtered out later
    aligns = aligns[["read_start", "read_end", "read_length", "align_score"]].copy().sort_values(["read_end"])
    node_ids = list(aligns.index)
    graph = nx.DiGraph()
    # initialise graph with root and sink node, and one for each alignment
    # edges in the graph will represent transitions from one alignment segment
    # to the next
    graph.add_nodes_from(["ROOT", "SINK"] + node_ids)
    for align in aligns.itertuples():
        align_idx = align.Index
        align_score = align.align_score
        gap_penalty_start = gap_fn(align.read_start)
        graph.add_edge("ROOT", align_idx, weight=gap_penalty_start - align_score)
        gap_penalty_end = gap_fn(int(align.read_length - align.read_end))
        graph.add_edge(align_idx, "SINK", weight=gap_penalty_end)
    # for each pair of aligned segments add an edge
    for idx_a, align_idx_a in enumerate(node_ids[:-1]):
        align_a_end = aligns.at[align_idx_a, "read_end"]
        for align_idx_b in node_ids[idx_a + 1 :]:  # noqa: E203 black does this
            align_b_score = aligns.at[align_idx_b, "align_score"]
            align_b_read_start = aligns.at[align_idx_b, "read_start"]
            gap_penalty = gap_fn(abs(int(align_b_read_start) - int(align_a_end)))
            graph.add_edge(align_idx_a, align_idx_b, weight=gap_penalty - align_b_score)
    return graph 


def filter_shortest_path(read_df, aligner="minimap2"):
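    """Fail alignments that do not lie on the minimum-weight path through the
    alignment graph built by create_align_graph."""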
    aligns = read_df[read_df.pass_filter]
    num_aligns = len(aligns)
    if num_aligns < 2:
        # can't build a graph, so nothing is filtered by this method
        return read_df
    if aligner == "minimap2":
        gap_fn = minimap_gapscore
    elif aligner == "bwa":
        gap_fn = bwa_gapscore
    else:
        raise ValueError(f"Unrecognised aligner: {aligner}")
    graph = create_align_graph(aligns, gap_fn)
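    # edge weights can be negative (alignment scores are subtracted), so use
    # Bellman-Ford rather than Dijkstra for the shortest path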
    distance, shortest_path = nx.single_source_bellman_ford(graph, "ROOT", "SINK")
    for idx in aligns.index:
        if idx not in shortest_path:
            read_df.at[idx, "pass_filter"] = False
            read_df.at[idx, "filter_reason"] = "not_on_shortest_path"
    return read_df
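

# A minimal usage sketch (variable names are illustrative; the inputs are
# assumed to follow the AlignmentRecordDf / FragmentRecordDf schemas):
#
#     pore_c_table = assign_fragments(align_table, fragment_df)
#     passing = pore_c_table[pore_c_table.pass_filter]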