Snakefile "ymp/rules/blast.rules"

# Declares the software environment named "blast"; the rules in this file
# select it through their `conda: "blast"` directive.
# NOTE(review): presumably a conda environment built from the bioconda channel
# containing the `blast` package -- confirm against ymp's Env implementation.
Env(name="blast", base="bioconda", packages="blast")

# "std" options output:
BLASTFMT = (
    "7 qacc sacc pident length mismatch gapopen qstart qend sstart send evalue bitscore"
    # extra columns appended after the standard set
    " sstrand sframe qframe score qlen stitle staxids"
)

# File extensions produced by the `blast_makedb` rule (nucleotide database).
BLASTIDX_SUFFIXES = "nin nhr nsi nsd nsq nog".split()
# Alternative extension set accepted by the *_SPLIT rule variants.
BLASTSPLITIDX_SUFFIXES = "nal ndb nos not ntf nto".split()
BLASTN_SUFFIXES = "nin nhr nsq".split()

# Per-target result files and the regex used to recover the target name
# from such a file name.
blast_assembly_results = "{dir}.blast/{: targets :}.{query}.blast7"
blast_assembly_results_fa  = "{dir}.blast/{: targets :}.{query}.{gene}.fasta"
blast_assembly_results_csv = "{dir}.blast/{: targets :}.{query}.{gene}.csv"
blast_assembly_results_re = "{dir}.blast/(.+).{query}.blast7"

# Results merged over all targets, one file per query gene.
blast_results_by_gene  = "{dir}.blast/{query}.{gene}.blast7"
blast_results_by_gene_fa  = "{dir}.blast/{query}.{gene}.fasta.gz"
blast_results_by_gene_csv = "{dir}.blast/{query}.{gene}.csv"


with Stage("index_blast"):
    # Builds a nucleotide BLAST database for one target. The gzipped FASTA
    # input is decompressed and piped into `makeblastdb` on stdin; the output
    # files are `{params.db}` plus each extension in BLASTIDX_SUFFIXES.
    # Output of `makeblastdb` (stdout and stderr) is written to the log file.
    # NOTE(review): `{:prev:}` / `{:this:}` are ymp placeholders, presumably the
    # previous and the current stage directory -- confirm.
    rule blast_makedb:
        "Build Blast index"
        message:
            "BLAST: indexing {params.db}"
        input:
            ref   = "{:prev:}/{:target:}.fasta.gz"
        output:
            db    = expand("{{:this:}}/{{target}}.{ext}", ext=BLASTIDX_SUFFIXES)
        params:
            db    = "{:this:}/{target}",
            typ   = "nucl",
            title = "{target}"
        log:
            "{:this:}/{target}.blast.log"
        threads:
            1
        conda:
            "blast"
        shell: """
        gunzip -c {input.ref} | \
        makeblastdb \
        -in - \
        -dbtype {params.typ} \
        -parse_seqids \
        -out {params.db} \
        -title {params.title} \
        &> {log} 2>&1
        """
        # FIXME: check for "-" in fasta header - blast does not like those

    # Aggregate rule: requires the BLASTIDX_SUFFIXES database files for
    # `{:targets:}` (presumably all targets of this stage) and touches a stamp
    # file once they exist.
    rule blast_makedb_all:
        input: expand("{{:this:}}/{{:targets:}}.{ext}", ext=BLASTIDX_SUFFIXES)
        output: touch("{:this:}/all_targets.stamp")



with Stage("annotate_tblastn") as S:
    S.doc("""
    Runs ``tblastn``
    """)
    rule tblastn_query:
        """
        Runs a TBLASTN search against an assembly.

        The gzipped query file is decompressed and streamed to ``tblastn``
        on stdin. The name passed to ``-db`` is the first index file with
        its last four characters (the dot and the three-letter extension)
        removed. Hits are written to the output file in the column layout
        given by ``BLASTFMT``; messages from ``tblastn`` go to the log file.
        """
        message:
            "TBLASTN: searching {params.db_name} for {input.query}"
        output:
            "{:this:}/{target}.blast7"
        input:
            index = expand("{{:prev:}}/{{:target:}}.{ext}", ext=BLASTIDX_SUFFIXES),
            query = "{:prev:}/{:target:}.fastp.gz"
        log:
            "{:this:}/{target}.blast7.log"
        params:
            # strip ".nin" (4 characters) from the first index file name
            db_name = lambda wc, input: input.index[0][:-4],
            blastfmt = BLASTFMT
        threads:
            24
        conda:
            "blast"
        shell: """
        gunzip -c {input.query} |
        tblastn \
          -query - \
          -db {params.db_name} \
          -outfmt "{params.blastfmt}" \
          -out {output} \
          -num_threads {threads} \
          &> {log} 2>&1
        """

    # Aggregate rule: requires the GTF file (see blast7_to_gtf) for
    # `{:targets:}` (presumably every target of this stage) and touches a
    # stamp file once they exist.
    rule tblastn_all:
        output:
            touch("{:this:}/all_targets.stamp")
        input:
            "{:this:}/{:targets:}.gtf"

    rule blast7_to_gtf:
        """
        Convert from Blast Format 7 to GFF/GTF format

        Every hit read from the blast7 file is written as one ``CDS``
        feature located on the subject sequence. Coordinates are ordered
        so that start <= end; the strand is derived from the sign of the
        subject frame and the e-value is stored as the feature score.
        """
        message:
            "BLAST7 -> GFF/GTF: {output}"
        input:
            "{:this:}/{target}.blast7"
        output:
            "{:this:}/{target}.gtf"
        run:
            from ymp import blast, gff

            with open(input[0], "r") as blast_in, open(output[0], "w") as gtf_out:
                gtf = gff.writer(gtf_out)
                for hit in blast.reader(blast_in):
                    # subject coordinates may be given in either order
                    lower, upper = sorted((hit.sstart, hit.send))
                    if hit.sframe > 0:
                        strand = '+'
                    else:
                        strand = '-'
                    attributes = "ID={}_{}_{};Name={}".format(
                        hit.sacc, hit.sstart, hit.send, hit.qacc)
                    gtf.write(gff.Feature(
                        seqid=hit.sacc,
                        source='BLAST',
                        type='CDS',
                        start=lower,
                        end=upper,
                        score=hit.evalue,
                        strand=strand,
                        phase='0',
                        attributes=attributes,
                    ))

with Stage("annotate_blast") as S:
    S.doc("""
    Annotate sequences with BLAST

    Searches a reference database for hits with ``blastn``. Use `E`
    flag to specify exponent to required E-value. Use ``N`` or
    ``Mega`` to specify default. Use ``Best`` to add
    ``-subject_besthit`` flag.
    """)
    # Stage parameters. The `name` of each one is what the blastn_query rule
    # reads from `params`:
    #   evalue_exp -> exponent in "-evalue 1e-<exp>"
    #   task       -> N / Mega, mapped to blastn's "-task blastn|megablast"
    #   besthit    -> inserted verbatim into the blastn command line
    S.add_param("E", typ="int", name="evalue_exp", default=0)
    S.add_param("", typ="choice", name="task",
                value=['N','Mega'], default='N')
    S.add_param("Best", typ="flag", name="besthit",
                value="-subject_besthit")
    # Inputs this stage needs: contigs as fasta.gz, and a BLAST database.
    # NOTE(review): the two suffix lists for `db` are presumably alternatives
    # (regular vs. split index) -- confirm against ymp's Stage.require.
    S.require(
        contigs = [["fasta.gz"]],
        db = [BLASTIDX_SUFFIXES, BLASTSPLITIDX_SUFFIXES]
    )

    # Records the size of a target's BLAST database. `blastdbcmd -list` is run
    # on the directory holding the index files with output format '%f %l';
    # the first listed line matching the database path is kept and its second
    # space-separated field is written to the temporary output file.
    # NOTE(review): per the blastdbcmd documentation %f is the database file
    # name and %l the number of bases/residues -- confirm for the installed
    # version. The database path is handed to grep unquoted, as a pattern.
    localrules: blast_db_size, blast_db_size_SPLIT
    rule blast_db_size:
        message:
            "BLASTDBCMD: Getting database size"
        input:
            db = expand("{{:prev:}}/{{:target:}}.{ext}",
                        ext=BLASTIDX_SUFFIXES)
        output:
            temp("{:this:}/{target}.blast_db_size")
        params:
            # first index file name without its 4-character ".nin" extension
            db_name = lambda wc, input: input.db[0][:-4]
        threads:
            4
        conda:   "blast"
        shell:
            "blastdbcmd  -list $(dirname {input.db[0]}) -list_outfmt '%f %l'"
            " | grep {params.db_name}"
            " | head -n 1"
            " |cut -d ' ' -f 2"
            " >{output}"

    # Variant of blast_db_size that takes the BLASTSPLITIDX_SUFFIXES files as
    # database input. Only `input` is given here; the "ymp: extends" marker on
    # the rule line is presumably what makes ymp inherit the remaining
    # directives from blast_db_size -- keep it unchanged.
    rule blast_db_size_SPLIT: # ymp: extends blast_db_size
        input:
            db = expand("{{:prev:}}/{{:target:}}.{ext}",
                         ext=BLASTSPLITIDX_SUFFIXES)

    # Splits the gzipped contig FASTA into uncompressed part files holding at
    # most `params.nseqs` sequences each. Parts are numbered from 0 and named
    # `<prefix><n>.fasta`; the last (possibly empty) part is always written.
    localrules: blastn_split_query_fasta
    rule blastn_split_query_fasta:
        message:
            "BLASTN: preparing query fasta file(s)"
        input:
            contigs = "{:prev:}/{:target:}.fasta.gz",
        output:
            queries = temp(dynamic("{:this:}/{target}.blast_query.{blastn_query_part}.fasta"))
        params:
            prefix = "{:this:}/{target}.blast_query.",
            nseqs = 1000  # number of queries to blast in one job
        run:
            import gzip

            part_name = params.prefix + "{index}.fasta"

            def write_part(part_no, records):
                # Dump the collected raw FASTA lines into one numbered file.
                with open(part_name.format(index=part_no), "wb") as part_file:
                    part_file.write(b"".join(records))

            with gzip.open(input.contigs) as fasta:
                pending = []
                in_part = 0
                part_no = 0
                for raw in fasta:
                    if raw.startswith(b">"):
                        # A new record begins: close the current part first
                        # if it already holds the maximum number of records.
                        if in_part == params.nseqs:
                            write_part(part_no, pending)
                            pending = []
                            in_part = 0
                            part_no += 1
                        in_part += 1
                    pending.append(raw)
                write_part(part_no, pending)

    # Runs blastn for one chunk of query sequences against the target's
    # database. The shell `case` maps the stage parameter `task` (N / Mega)
    # to the value given to `-task` (blastn / megablast); `evalue_exp` is the
    # exponent in `-evalue 1e-<exp>` and `besthit` is inserted verbatim.
    # The per-chunk result file is temporary; blastn's messages go to the log.
    # NOTE(review): params.task, params.evalue_exp and params.besthit are not
    # declared in this rule; presumably ymp injects them from the
    # S.add_param() declarations of the enclosing stage -- confirm.
    rule blastn_query:
        message:
            "BLASTN: {input.contigs} vs {params.db_name}"
        input:
            contigs = "{:this:}/{target}.blast_query.{blastn_query_part}.fasta",
            db = expand("{{:prev:}}/{{:target:}}.{ext}",
                        ext=BLASTIDX_SUFFIXES)
        output:
            temp("{:this:}/{target}.blast7.{blastn_query_part}")
        log:
            "{:this:}/{target}.blast7.{blastn_query_part}.log"
        benchmark:
            "benchmarks/blastn_query/{:this:}/{target}.{blastn_query_part}.txt"
        params:
            # first index file name without its 4-character ".nin" extension
            db_name = lambda wc, input: input.db[0][:-4],
            blastfmt = BLASTFMT,
        threads:
            24
        conda:   "blast"
        shell:
            'case {params.task} in'
            ' N) TASK="blastn";;'
            ' Mega) TASK="megablast";;'
            'esac;'
            'blastn'
            ' -query {input.contigs}'
            ' -db {params.db_name}'
            ' -out {output}'
            ' -outfmt "{params.blastfmt}"'
            ' -evalue 1e-{params.evalue_exp}'
            ' -num_threads {threads}'
            ' {params.besthit}'
            ' -task $TASK'
            ' >{log} 2>&1'

    # Variant of blastn_query that takes the BLASTSPLITIDX_SUFFIXES files as
    # database input. Only `db` is given here; the "ymp: extends" marker on
    # the rule line is presumably what makes ymp inherit everything else from
    # blastn_query -- keep it unchanged.
    rule blastn_query_SPLIT: # ymp: extends blastn_query
        input:
            db = expand("{{:prev:}}/{{:target:}}.{ext}",
                        ext=BLASTSPLITIDX_SUFFIXES)

    # Concatenates the per-chunk blast7 files of one target (however many the
    # dynamic split produced) into a single `{target}.blast7` file, in the
    # order Snakemake lists them in `{input}`.
    localrules: blastn_merge_result
    rule blastn_merge_result:
        message:
            "BLASTN: merging result {output}"
        input:
            dynamic("{:this:}/{target}.blast7.{blastn_query_part}")
        output:
            "{:this:}/{target}.blast7"
        shell:
            "cat {input} > {output}"

    # Aggregate rule: requires the merged blast7 result for `{:targets:}`
    # (presumably every target of this stage) plus `ALL.blast_db_size`, then
    # touches a stamp file. `params.this` is the directory of that stamp file
    # and is only used in the message. Relies on `os` being available in the
    # Snakefile namespace.
    rule blastn_query_all:
        message:
            "Completed {params.this}"
        params:
            this = lambda wc, output: os.path.dirname(output[0])
        input:
            "{:this:}/{:targets:}.blast7",
            "{:this:}/ALL.blast_db_size"
        output:
            touch("{:this:}/all_targets.stamp")


localrules: blast7_merge
rule blast7_merge:
    """
    Merges blast results from all samples into single file

    Reads the per-target blast7 files one after another and keeps only hit
    lines that start with ``{gene}``. The target name, recovered from the
    name of the file a line came from, is inserted with a ``_`` separator
    in front of the second tab-separated column. The comment header of a
    query block is written only if that block has at least one retained
    hit, and only once.
    """
    message:
        "Merging BLAST results for {wildcards.query} in {wildcards.dir}"
    input:   blast_assembly_results
    output:  blast_results_by_gene
    params:  re="{dir}.blast/(.+).{query}.blast7",
             gene="{gene}"
    threads: 1
    run:
        import fileinput
        import re
        sample_re=re.compile(params.re)
        # Start with an empty header: without this, a hit line (or a comment
        # line lacking "BLAST") appearing before the first "# BLAST..." line
        # would reference `header` before assignment.
        header = ""
        with fileinput.input(input) as f, open(output[0], "w") as out:
            for line in f:
                if f.isfirstline():
                    # extract target name from filename
                    sample = sample_re.match(f.filename()).group(1)
                if line[0] == "#":
                    # a comment containing "BLAST" opens a new query block
                    if "BLAST" in line:
                        header = ""
                    header += line
                elif line.startswith(params.gene):
                    # emit the pending block header before its first hit
                    if len(header) > 0:
                        out.write(header)
                        header = ""
                    line_parts = line.split('\t', 1)
                    out.write("".join([line_parts[0], '\t', sample, '_', line_parts[1]]))


localrules: blast7_extract
rule blast7_extract:
    """
    Generates meta-data csv and sequence fasta pair from blast7 file for one gene.

    For every hit line starting with ``{gene}``, the matched subject range
    is fetched from the BLAST database with ``blastdbcmd`` and written as
    one FASTA record named ``<sample>.<sacc>.<sstart>.<send>``; a CSV row
    with the same id and the hit's coordinates and statistics is written
    alongside.

    NOTE(review): the wildcards ``{ref}`` (input), ``{path}`` and ``{file}``
    (params.db) do not occur in any output pattern of this rule, and
    ``params.re`` is not used by the run block -- verify that this rule can
    still be resolved by Snakemake.
    """
    message:
        "Collecting hits for {wildcards.query}/{wildcards.gene}"
    input:   db     = expand("{{:dir.references:}}.blast/{{ref}}.fasta.{ext}", \
                             ext=BLASTIDX_SUFFIXES),
             blast  = "{dir}.blast/{sample}.{query}.blast7"
    output:  fasta  = "{dir}.blast/{sample}.{query}.{gene}.fasta",
             csv    = "{dir}.blast/{sample}.{query}.{gene}.csv",
    params:  re     = blast_assembly_results_re,
             gene   = "{gene}",
             db     = "{path}.index/{file}",
             sample = "{sample}"
    threads: 1
    run:
        import subprocess, csv
        blastfmt = BLASTFMT.split()
        # BLASTFMT starts with the format number "7", which is not a column
        # of the output; hence the -1 when turning a name into a list index.
        get_field = lambda sline, col: sline[blastfmt.index(col)-1]
        get_fields = lambda sline, cols: (get_field(sline, col) for col in cols)
        with open(input.blast, "r") as blast, \
             open(output.fasta, "w") as out, \
             open(output.csv, "w") as out_csv_f:
            out_csv = csv.writer(out_csv_f)
            out_csv.writerow([
                "fasta_id", "sample_id", "sequence_id", "gene_id",
                "start", "end", "evalue", "pident", "length"])
            for line in blast:
                # skip comment lines and hits of other genes
                if line[0] == "#": continue
                if not line.startswith(params.gene): continue
                line_parts = line.strip().split('\t')
                sacc, sstart, send = get_fields(line_parts,
                                                ['sacc','sstart','send',])
                qacc, evalue, pident, length = get_fields(line_parts,
                                                          ['qacc','evalue', 'pident', 'length'])
                sstart = int(sstart)
                send = int(send)
                
                # Fetch the hit region; the range is always given low-high and
                # the orientation is passed separately via -strand.
                seq = subprocess.check_output(
                    ['blastdbcmd',
                     '-db', params.db,
                     '-entry', sacc,
                     '-range', "{}-{}".format(min(sstart,send), max(sstart,send)),
                     '-strand', "plus" if sstart < send else "minus"
                     ]).decode('ascii')
                # drop the first output line (stored in the unused `seqbla`)
                # and join the remaining lines into one sequence string
                seqbla, seq = seq.split("\n", 1)
                seq = seq.replace("\n","")
                seqname = "{sample}.{sacc}.{start}.{stop}".format(
                    sample=params.sample,
                    sacc=sacc,
                    start=sstart,
                    stop=send,
                )
                out.write(">{seqname}\n{seq}\n".format(seq=seq, seqname=seqname))
                out_csv.writerow([seqname, params.sample, sacc, qacc,
                                  sstart, send, evalue, pident, length])

        
localrules: blast7_extract_merge
rule blast7_extract_merge:
    """
    Merges extracted csv/fasta pairs over all samples.

    The FASTA files are concatenated and gzip-compressed (level 9). For the
    CSV files, a single input is copied as is; with several inputs the
    header line of the first file is written once, followed by every file's
    content from its second line on.
    """
    message:
        "Merging {wildcards.query}/{wildcards.gene}"
    input:  fa  = blast_assembly_results_fa,
            csv = blast_assembly_results_csv
    output: fa  = blast_results_by_gene_fa,
            csv = blast_results_by_gene_csv
    threads: 1
    shell: """
    cat {input.fa} | gzip -c9 > {output.fa}

    if test "$(echo {input.csv} | wc -w)" -eq 1; then
       cp {input.csv} {output.csv}
    else
        (
            head -n1 {input.csv[0]};
            tail -n +2 -q {input.csv};
        ) > {output.csv}
    fi
    """


# Requests the per-gene CSV files for every name returned by
# `fasta_names(<query>.faa)` (helper defined outside this file section).
# NOTE(review): the shell command only echoes the input names and does not
# create the declared output `{dir}.blast/{query}.csv` -- verify whether this
# rule is still in use and what the output is meant to contain.
rule blast7_all:
    output:
        "{dir}.blast/{query}.csv"
    input:
        lambda wc: expand("{{dir}}.blast/{{query}}.{gene}.csv",gene=fasta_names(wc.query+".faa"))
    shell: """
    echo {input}
    """



##### Reports ##########

# Marker rule: requires the HTML report for one query/gene pair from the
# reports directory and touches a `reports_<query>_<gene>` stamp file in the
# blast result directory.
rule blast7_reports:
    input:
        "{:dir.reports:}/{dir}.blast.{query}.{gene}.html"
    output:
        touch("{dir}.blast/reports_{query}_{gene}")


rule blast7_eval_hist:
    """
    Plots a histogram of one column of a merged blast7 file to PDF.

    ``{type}`` selects the column by its name in ``BLASTFMT``; both axes
    use a log10 scale.
    """
    input:
        "{dir}.blast/{query}.{gene}.blast7"
    wildcard_constraints:
       type="(evalue|bitscore|score|length|pident)"
    output:
        "{dir}.blast/{query}.{gene}.blast7.{type}_hist.pdf"
    run:
        # BLASTFMT starts with the format number "7", so the list index of a
        # column name equals its 1-based column number in the R data frame.
        col = BLASTFMT.split().index(wildcards.type)
        from ymp.util import R

        # The plot is assigned to `g`, which suppresses implicit drawing; it
        # has to be printed explicitly while the PDF device is open,
        # otherwise the output file contains no plot.
        # NOTE(review): the axis breaks count down from 10^0 to the smallest
        # positive value, which suits e-values below 1; behaviour for columns
        # whose values are all >= 1 should be verified.
        R("""
        df <- read.csv("{input}", header=FALSE, sep="\t", comment.char="#")
        minval = min(df[df[,{col}]>0,{col}])
        print(minval)
        print(log10(minval))
        maxval = max(df[,{col}])
        library(ggplot2)

        f = function(x,y) {{x[x == -Inf]=y[1]; x}}
        pdf("{output}")

        g <- ggplot(df, aes(x=V{col})) + \
        geom_histogram(bins=100) + \
        scale_x_log10("{wildcards.gene}",oob=f,minor_breaks=10^seq(0,log10(minval),-1),
                      breaks=10^seq(0,log10(minval),-10)) + \
        scale_y_log10();
        print(g)

        dev.off()
        """)


# Renders the HTML report for one query. The runner script (`rmdrun`) is
# executed with the R Markdown file, the output path, the space-separated
# list of per-gene CSV files and the gene names as arguments.
# `fasta_names` is defined outside this file section; both `inputs` and
# `names` call it with "<query>.faa".
localrules: blast7_eval_plot
rule blast7_eval_plot:
    input:
        inputs=lambda wc: expand("{{dir}}.blast/{{query}}.{gene}.csv",
                                 gene=fasta_names(wc.query+".faa")),
        rmd=srcdir("../R/blast.Rmd"),
        rmdrun=srcdir("../R/RmdRunner.R")
    output:
        "{:dir.reports:}/{dir}.blast.{query}.html"
    params:
        names=lambda wc: fasta_names(wc.query+".faa")
    shell: """
    {input.rmdrun} {input.rmd} {output} input="{input.inputs}" names="{params.names}"
    """