Snakefile "ymp/rules/diamond.rules"

# Software environment named "diamond"; the rules in this file select it
# via their `conda: "diamond"` directive.
Env(
    name="diamond",
    base="bioconda",
    packages=[
        "diamond",
        "samtools",
    ],
)

with Stage("index_diamond"):
    rule diamond_makedb:
        """
        Create a Diamond database (``.dmnd``) from the ``.fastp.gz``
        file provided by the previous stage.

        All output of ``diamond makedb`` (stdout and stderr) is
        redirected to the rule's log file.
        """
        message:
            "Diamond: running makedb on {input}"
        conda:
            "diamond"
        threads:
            16
        resources:
            mem = "32g",
        input:
            ref = "{:prev:}/{:target:}.fastp.gz"
        output:
            db = "{:this:}/{target}.dmnd"
        # Database path without file extension, as passed to `--db`; the
        # declared output above is this same path plus ".dmnd".
        params:
            db = "{:this:}/{target}"
        log:
            "{:this:}/{target}.log"
        shell: """
        diamond makedb \
        --in {input.ref} \
        --db {params.db} \
        --threads {threads} \
        >{log} 2>&1
        """


with Stage("annotate_diamond") as S:
    S.doc("""
    Align the sequences in ``<target>.fasta.gz`` against a Diamond
    database (``<target>.dmnd``) using ``diamond blastx``.

    The alignments are written in Diamond's binary ``.daa`` format and
    can be converted to BLAST tabular format (``.blast6``).
    """)
    rule diamond_blastx_fasta:
        """
        Run ``diamond blastx`` with a FASTA file as query.

        Writes the alignments as Diamond alignment archive (``.daa``,
        ``--outfmt 100``); stdout and stderr go to the log file.
        """
        message:
            "Diamond: running blastx on {input.fa}"
        input:
            fa = "{:prev:}/{:target:}.fasta.gz",
            db = "{:prev:}/{:target:}.dmnd"
        output:
            "{:this:}/{target}.daa"
        log:
            "{:this:}/{target}.log"
        resources:
            mem = "32g",
        threads:
            16
        conda:
            "diamond"
        shell: """
        diamond blastx \
        --db {input.db} \
        --query {input.fa} \
        --out {output} \
        --outfmt 100 \
        --threads {threads} >{log} 2>&1
        """

    rule diamond_view:
        """
        Convert Diamond binary output (daa) to BLAST6 format

        The result is marked ``temp()`` and is therefore removed by
        Snakemake once no remaining job needs it.
        """
        message:
            "Diamond: converting {input} to BLAST6"
        input:
            "{:this:}/{target}.daa"
        output:
            temp("{:this:}/{target}.blast6")
        conda:
            "diamond"
        # The output format is requested explicitly (as diamond_view_2
        # does with `-f 6`) so that the promised BLAST6 layout does not
        # depend on the tool's default.
        shell: """
        diamond view --daa {input} --outfmt 6 > {output}
        """


with Stage("map_diamond") as S:
    rule diamond_blastx_fastq:
        """
        Run ``diamond blastx`` with the first mate's FASTQ file as query.

        Writes the alignments as Diamond alignment archive (``.daa``,
        ``--outfmt 100``); stdout and stderr go to the log file.
        """
        message:
            "Diamond: running blastx on {input.fa}"
        input:
            fa = "{:prev:}/{target}.{:pairnames[0]:}.fq.gz",
            db = "{:prev:}/{:target:}.dmnd"
        output:
            "{:this:}/{target}.{:pairnames[0]:}.daa"
        log:
            "{:this:}/{target}.{:pairnames[0]:}.log"
        resources:
            mem = "32g",
        threads:
            16
        conda:
            "diamond"
        shell: """
        diamond blastx \
        --db {input.db} \
        --query {input.fa} \
        --out {output} \
        --outfmt 100 \
        --threads {threads} >{log} 2>&1
        """

    # Same job for the second mate: only input, output and log are
    # redefined here. Everything else is expected to come from
    # diamond_blastx_fastq through the `ymp: extends` marker on the rule
    # line, which must therefore be kept intact.
    rule diamond_blastx_fastq2:  # ymp: extends diamond_blastx_fastq
        input:
            fa = "{:prev:}/{target}.{:pairnames[1]:}.fq.gz",
            db = "{:prev:}/{:target:}.dmnd"
        output:
            "{:this:}/{target}.{:pairnames[1]:}.daa"
        log:
            "{:this:}/{target}.{:pairnames[1]:}.log"

    rule diamond_view_2:
        """
        Convert Diamond binary output (daa) to BLAST6 format

        Handles both mates in one job: input/output index 0 and 1 are
        converted one after the other. The results are marked
        ``temp()``.
        """
        message:
            "Diamond: converting {input} to BLAST6"
        input:
            "{:this:}/{target}.{:pairnames:}.daa"
        output:
            temp("{:this:}/{target}.{:pairnames[0]:}.blast6"),
            temp("{:this:}/{target}.{:pairnames[1]:}.blast6"),
        conda:
            "diamond"
        shell: """
        diamond view -a {input[0]} -f 6  > {output[0]}
        diamond view -a {input[1]} -f 6  > {output[1]}
        """

             
with Stage("count_diamond"):
    rule diamond_count:
        """
        Count, per reference sequence, the reads whose best hit it is.

        Reads the BLAST6 files of the previous stage (one file per
        mate). Hits are grouped by read pair; the pair id is the query
        id without its last two characters (presumably a mate suffix
        such as "/1" or "/2" -- TODO confirm against the read naming in
        use). For every pair the hit with the lowest e-value among all
        mates is selected. If that e-value is below 1e-5, the count of
        the hit's subject sequence is raised by the number of mate files
        in which the pair had at least one hit.

        Output is a tab separated table with the columns ``name`` and
        ``count``.
        """
        message:
            "Calculating hits per reference"
        input:
            "{:prev:}/{:target:}.{:pairnames:}.blast6"
        output:
            "{:this:}/{target}.counts.tsv"
        run:
            import collections
            import csv
            import ymp.blast

            # Only best hits with an e-value below this are counted.
            max_evalue = 10**-5

            # Best hit seen so far for each read pair:
            #   pair id -> [e-value, sseqid, number of mate files with
            #               hits, index of the last file that had a hit]
            # One entry per pair with hits is kept in memory. This
            # replaces a streaming merge that also remembered every pair
            # id and in addition depended on both files listing the
            # reads in the same order.
            best_hits = {}
            for file_no, blast6 in enumerate(input):
                # Context manager: each input file is closed again, and
                # an empty file simply contributes no hits.
                with open(blast6) as handle:
                    for hit in ymp.blast.reader(handle, t=6):
                        pair_id = hit.qseqid[:-2]
                        entry = best_hits.get(pair_id)
                        if entry is None:
                            best_hits[pair_id] = [
                                hit.evalue, hit.sseqid, 1, file_no
                            ]
                            continue
                        if entry[3] != file_no:
                            # First hit of this pair in the current file:
                            # one more mate contributes to the weight.
                            entry[2] += 1
                            entry[3] = file_no
                        if hit.evalue < entry[0]:
                            # Strict comparison: on equal e-values the
                            # hit encountered first is kept.
                            entry[0] = hit.evalue
                            entry[1] = hit.sseqid

            hits = collections.Counter()
            for evalue, sseqid, n_mates, _ in best_hits.values():
                if evalue < max_evalue:
                    hits[sseqid] += n_mates

            with open(output[0], "w") as out:
                w = csv.writer(out, delimiter="\t")
                w.writerow(["name", "count"])
                for pair in hits.items():
                    w.writerow(pair)