Snakefile "ymp/rules/blast.rules"

# Declares the "blast" software environment (package "blast", base
# "bioconda"); the rules below select it with ``conda: "blast"``.
Env(name="blast", base="bioconda", packages="blast")

# Value handed to the BLAST ``-outfmt`` option by the search rules below:
# the format number followed by the list of columns to emit.
BLASTFMT = " ".join((
    "7",
    # "std" options output:
    "qacc sacc pident length mismatch gapopen qstart qend sstart send evalue bitscore",
    # extra output
    "sstrand sframe qframe score qlen stitle staxids btop",
))

# File name suffixes making up a BLAST nucleotide index as declared by the
# ``blast_makedb`` rule (built with ``-blastdb_version 5``).
BLASTIDX_SUFFIXES = [
    "nhr", "nin", "nog", "nsq",
    "ntf", "nto", "nos", "not", "ndb",
]
# Suffixes expected for a V4 index.
BLASTIDX_SUFFIXES_V4 = [
    "nhr", "nin", "nog", "nsq",
    "nsi",
]
# Suffixes expected for a multi-file index (+ the 00.nxx files, which are
# not listed here).
BLASTSPLITIDX_SUFFIXES = [
    "ntf", "nto", "not", "nos", "ndb",
    "nal",
]

# Bind the stage object explicitly: without ``as S`` the ``S.doc()`` call
# below would refer to whatever ``S`` happens to exist already (or fail with
# a NameError), instead of documenting this stage.
with Stage("index_blast") as S:
    S.doc("""
    Creates `BLAST <bioconda:blast>` index running ``makeblastdb`` on input fasta.gz files.

    >>> ymp make toy.ref_genome.index_blast
    """)
    rule blast_makedb:
        """
        Build Blast index

        Decompresses the gzipped reference FASTA and pipes it into
        ``makeblastdb``, which writes a nucleotide database (``-dbtype nucl``)
        in database format version 5. One output file is declared per entry
        of ``BLASTIDX_SUFFIXES``.
        """
        message:
            "{:name:}: indexing {params.db}"
        input:
            ref   = "{:prev:}/{:target:}.fasta.gz"
        output:
            db    = expand("{{:this:}}/{{target}}.{ext}", ext=BLASTIDX_SUFFIXES)
        params:
            # Database base name; makeblastdb appends the index suffixes to it.
            db    = "{:this:}/{target}",
            typ   = "nucl",
            title = "{target}"
        log:
            "{:this:}/{target}.blast.log"
        threads:
            1
        conda:
            "blast"
        # stdout and stderr of the whole pipeline tail go to the log file.
        shell: """
        gunzip -c {input.ref} | \
        makeblastdb \
        -in - \
        -dbtype {params.typ} \
        -parse_seqids \
        -out {params.db} \
        -title {params.title} \
        -blastdb_version 5 \
        -hash_index \
        &> {log} 2>&1
        """
        # FIXME: check for "-" in fasta header - blast does not like those


with Stage("annotate_tblastn") as S:
    S.doc("""
    Runs ``tblastn``
    """)
    rule tblastn_query:
        """
        Searches an indexed assembly with ``tblastn``.

        The gzipped query file is streamed into ``tblastn`` on stdin and
        the hits are written using the column layout in ``BLASTFMT``.
        """
        message:
            "{:name:}: searching {params.db_name} for {input.query}"
        input:
            index = expand("{{:prev:}}/{{:target:}}.{ext}", ext=BLASTIDX_SUFFIXES),
            query = "{:prev:}/{:target:}.fastp.gz"
        output:
            "{:this:}/{target}.blast7"
        params:
            # Database base name: first index file minus its 4-character
            # ".xxx" suffix.
            db_name = lambda wc, input: input.index[0][:-4],
            blastfmt = BLASTFMT
        log:
            "{:this:}/{target}.blast7.log"
        threads:
            24
        conda:
            "blast"
        shell: """
        gunzip -c {input.query} |
        tblastn \
          -query - \
          -db {params.db_name} \
          -outfmt "{params.blastfmt}" \
          -out {output} \
          -num_threads {threads} \
          &> {log} 2>&1
        """

    rule blast7_to_gtf:
        """
        Rewrites a BLAST format 7 result file as GFF/GTF features.

        Each hit becomes one ``CDS`` feature on the subject sequence. Start
        and end are ordered so that start <= end, the strand is taken from
        the sign of the subject frame, and the E-value is stored as score.
        """
        message:
            "BLAST7 -> GFF/GTF: {output}"
        input:
            "{:this:}/{target}.blast7"
        output:
            "{:this:}/{target}.gtf"
        run:
            from ymp import blast, gff

            with open(input[0], "r") as blast_fd, open(output[0], "w") as gtf_fd:
                gtf = gff.writer(gtf_fd)
                for hit in blast.reader(blast_fd):
                    # Positive subject frame -> forward strand, else reverse.
                    if hit.sframe > 0:
                        strand = '+'
                    else:
                        strand = '-'
                    # Hits may be reported with sstart > send; order them.
                    first = min(hit.sstart, hit.send)
                    last = max(hit.sstart, hit.send)
                    attrs = "ID={}_{}_{};Name={}".format(
                        hit.sacc, hit.sstart, hit.send, hit.qacc
                    )
                    gtf.write(gff.Feature(
                        seqid=hit.sacc,
                        source='BLAST',
                        type='CDS',
                        start=first,
                        end=last,
                        score=hit.evalue,
                        strand=strand,
                        phase='0',
                        attributes=attrs,
                    ))

with Stage("annotate_blast") as S:
    S.doc("""
    Annotate sequences with BLAST

    Searches a reference database for hits with ``blastn``. Use ``E``
    flag to specify exponent to required E-value. Use ``N`` or
    ``Mega`` to specify default. Use ``Best`` to add
    ``-subject_besthit`` flag.

    This stage produces ``blast7.gz`` files as output.

    >>> ymp make toy.ref_genome.index_blast.annotate_blast
    """)
    # Integer parameter "E": referenced by ``blastn_query`` as
    # ``{params.evalue_exp}`` to build ``-evalue 1e-<value>``.
    S.add_param("E", typ="int", name="evalue_exp", default=0)
    # Choice parameter without a key ("N" or "Mega"): referenced by
    # ``blastn_query`` as ``{params.task}`` and mapped there to the
    # ``blastn`` / ``megablast`` task names.
    S.add_param("", typ="choice", name="task",
                value=['N','Mega'], default='N')
    # Flag "Best": referenced by ``blastn_query`` as ``{params.besthit}``;
    # its value is the literal command line option to insert.
    S.add_param("Best", typ="flag", name="besthit",
                value="-subject_besthit")
    # Inputs of the stage: gzipped FASTA contigs plus a BLAST database.
    # NOTE(review): the three suffix lists given for ``db`` look like
    # alternative index layouts (V5, V4, multi-file), matching the rule
    # variants below - confirm against ``Stage.require``.
    S.require(
        contigs = [["fasta.gz"]],
        db = [BLASTIDX_SUFFIXES, BLASTIDX_SUFFIXES_V4, BLASTSPLITIDX_SUFFIXES],
    )

    # All three variants of the size lookup are cheap one-liners, so all of
    # them (including the V4 variant) are executed locally.
    localrules: blast_db_size, blast_db_size_SPLIT, blast_db_size_V4
    rule blast_db_size:
        """
        Determines size of BLAST database (for splitting)

        Lists the databases found in the directory of the first index file
        with ``blastdbcmd -list``, keeps the first line mentioning this
        database's base name and writes the second space-separated field
        (the ``%l`` column of the list format) to the output file.
        """
        message:
            "{:name:}: Getting database size for {input.db[0]}"
        input:
            db = expand("{{:prev:}}/{{:target:}}.{ext}",
                        ext=BLASTIDX_SUFFIXES)
        output:
            temp("{:this:}/{target}.blast_db_size")
        params:
            # Database base name: first index file minus its 4-character
            # ".xxx" suffix.
            db_name = lambda wc, input: input.db[0][:-4]
        threads:
            4
        conda:   "blast"
        shell:
            "blastdbcmd  -list $(dirname {input.db[0]}) -list_outfmt '%f %l'"
            " | grep {params.db_name}"
            " | head -n 1"
            " |cut -d ' ' -f 2"
            " >{output}"

    rule blast_db_size_SPLIT: # ymp: extends blast_db_size
        """
        `blast_db_size` for multi-file blast indices.

        Only the ``db`` input is redefined here, using the suffix list of
        multi-file indices.
        """
        input:
            db = expand(
                "{{:prev:}}/{{:target:}}.{ext}",
                ext=BLASTSPLITIDX_SUFFIXES
            )

    rule blast_db_size_V4: # ymp: extends blast_db_size
        """
        `blast_db_size` for V4 blast indices.

        Only the ``db`` input is redefined here, using the suffix list of
        V4 indices.
        """
        input:
            db = expand(
                "{{:prev:}}/{{:target:}}.{ext}",
                ext=BLASTIDX_SUFFIXES_V4
            )

    localrules: blastn_split_query_fasta_hack
    rule blastn_split_query_fasta_hack:
        """
        Workaround for a problem with snakemake checkpoints and run: statements

        Writes the path of the contigs file into a small text file; the
        ``blastn_split_query_fasta`` checkpoint reads the path back from it.
        """
        message:
            "Working around Snakemake bug"
        output:
            contig_list = "{:this:}/{target}.fasta_files"
        input:
            contigs = "{:prev:}/{:target:}.fasta.gz"
        shell: """
        echo {input.contigs} > {output.contig_list}
        """

    localrules: blastn_split_query_fasta
    checkpoint blastn_split_query_fasta:
        """
        Cuts the query FASTA into numbered chunk files, one BLAST job each.

        The number of sequences per chunk is 10^14 divided by the database
        size, limited to the range ``nseq_min`` .. ``nseq_max``. Chunks are
        written uncompressed as ``<n>.fasta`` into the output directory.
        """
        message:
            "{:name:}: preparing query fasta file(s)"
        input:
            contigs = "{:prev:}/{:target:}.fasta.gz",
            dbsize = "{:this:}/{target}.blast_db_size",
            contig_list = "{:this:}/{target}.fasta_files"
        output:
            queries = temp(directory(
                "{:this:}/{target}.split_queries"
            ))
        params:
            nseq_max = 100000,
            nseq_min = 10
        run:
            import gzip

            # Derive the chunk size from the database size.
            with open(input.dbsize, "r") as size_fd:
                db_size = int(size_fd.read())
            chunk_size = int(
                min(params.nseq_max, max(params.nseq_min, 1*10**14/db_size))
            )

            # The FASTA path is taken from the helper file, not from
            # ``input.contigs`` directly.
            with open(input.contig_list, "r") as list_fd:
                fasta_path = list_fd.read().strip()

            os.makedirs(output.queries, exist_ok=True)
            chunk_name = os.path.join(output.queries,"{index}.fasta")

            with gzip.open(fasta_path) as fasta:
                pending = []       # raw lines of the chunk being collected
                n_in_chunk = 0     # sequences collected in ``pending``
                chunk_no = 0       # index of the next chunk file
                for raw in fasta:
                    is_header = raw.startswith(b">")
                    # A new record begins while the chunk is full: flush.
                    if is_header and n_in_chunk == chunk_size:
                        with open(chunk_name.format(index=chunk_no), "wb") as chunk_fd:
                            chunk_fd.write(b"".join(pending))
                        pending = []
                        n_in_chunk = 0
                        chunk_no += 1
                    if is_header:
                        n_in_chunk += 1
                    pending.append(raw)
                # Remaining lines always go into one final chunk file.
                with open(chunk_name.format(index=chunk_no), "wb") as chunk_fd:
                    chunk_fd.write(b"".join(pending))

    def blastn_join_input(wildcards):
        """
        Input function listing the per-chunk BLAST results to merge.

        Looks up the output directory of the ``blastn_split_query_fasta``
        checkpoint for the given wildcards, collects the ``{index}`` of
        every ``{index}.fasta`` found in it and returns the matching
        ``{index}.blast7.gz`` paths.
        """
        split_dir = checkpoints.blastn_split_query_fasta.get(**wildcards).output.queries
        chunks = glob_wildcards(os.path.join(split_dir, '{index}.fasta'))
        result_pattern = os.path.join(split_dir, '{index}.blast7.gz')
        return expand(result_pattern, index=chunks.index)

    localrules: blastn_join_result
    rule blastn_join_result:
        """
        Merges BLAST results

        Concatenates the gzipped per-chunk result files returned by
        ``blastn_join_input`` into a single ``blast7.gz`` file using ``cat``.
        """
        message:
            "{:name:}: merging result {output}"
        output:
            "{:this:}/{target}.blast7.gz"
        input:
            results = blastn_join_input,
            folder = "{:this:}/{target}.split_queries"
        shell:
            "cat {input.results} > {output}"

    rule blastn_query:
        """
        Runs BLAST

        Searches one query chunk against the database with ``blastn`` and
        stores the gzip-compressed format 7 result. The ``task`` stage
        parameter selects the ``-task`` value (``N`` -> ``blastn``,
        ``Mega`` -> ``megablast``), ``evalue_exp`` the E-value exponent and
        ``besthit`` an optional extra flag.

        The result is first written to a temporary file below
        ``params.tmpdir`` and only moved to the output path once the
        pipeline has finished; the temporary file is removed on exit.
        """
        message:
            "{:name:}: {input.contigs} vs {params.db_name}"
        input:
            folder = "{:this:}/{target}.split_queries",
            contigs = "{:this:}/{target}.split_queries/{index}.fasta",
            db = expand("{{:prev:}}/{{:target:}}.{ext}",
                        ext=BLASTIDX_SUFFIXES)
        output:
            "{:this:}/{target}.split_queries/{index}.blast7.gz"
        log:
            "{:this:}/{target}.split_queries.{index}.log"
        benchmark:
            "benchmarks/{:name:}/{:this:}/{target}.{index}.txt"
        params:
            # Database base name: first index file minus its 4-character
            # ".xxx" suffix.
            db_name = lambda wc, input: input.db[0][:-4],
            blastfmt = BLASTFMT,
            tmpdir = "{:ensuredir.tmp:}",
        resources:
            mem = "128g",
        threads:
            24
        conda:   "blast"
        # The stderr redirection must be attached to blastn itself, i.e.
        # placed before the pipe; placed after ``gzip`` it would capture only
        # gzip's stderr and BLAST diagnostics would never reach the log.
        shell:
            'case {params.task} in'
            ' N) TASK="blastn";;'
            ' Mega) TASK="megablast";;'
            'esac;'
            ''
            'tmpout=$(mktemp {params.tmpdir}/blastn_query.blast7.gz.XXXXXXXXXX);'
            'trap "{{ rm -f $tmpout; }}" EXIT;'
            ''
            'blastn'
            ' -query {input.contigs}'
            ' -db {params.db_name}'
            ' -outfmt "{params.blastfmt}"'
            ' -evalue 1e-{params.evalue_exp}'
            ' -num_threads {threads}'
            ' {params.besthit}'
            ' -task $TASK'
            ' 2>{log}'
            ''
            ' | gzip -c >$tmpout'
            ';'
            'mv $tmpout {output}'

    rule blastn_query_SPLIT: # ymp: extends blastn_query
        """
        `blastn_query` for multi-file blast indices.

        Only the ``db`` input is redefined here, using the suffix list of
        multi-file indices.
        """
        input:
            db = expand(
                "{{:prev:}}/{{:target:}}.{ext}",
                ext=BLASTSPLITIDX_SUFFIXES
            )

    rule blastn_query_V4: # ymp: extends blastn_query
        """
        `blastn_query` for V4 blast indices.

        Only the ``db`` input is redefined here, using the suffix list of
        V4 indices.
        """
        input:
            db = expand(
                "{{:prev:}}/{{:target:}}.{ext}",
                ext=BLASTIDX_SUFFIXES_V4
            )