Snakefile "ymp/rules/otu.rules"

"""
Rules for making OTU tables

Requires called and classified genes from the previous stage.

Currently needs sequences already extracted.

WIP: Should we allow grouped processing, i.e. multiple OTU tables? Or do we
     always run this on all available data?


Inputs:
 - gene regions
 - clusters of gene regions (dim1)
 - coverages

Directories:

...assembly.{mapper}.cov/{genescan}.{query}.{gene}.csv
...assembly.{genescan}/{query}.{gene}.csv
...assembly.{genescan}/{query}.{gene}.NR.clstr.csv


"""


localrules: make_otu_table
rule make_otu_table:
    message:
        "Preparing OTU table {output.table}"
    input:
        clust="{dir}.{genefind}/{query}.{gene}.NR.clstr.csv",
        cov="{dir}.{mapper}.cov/{genefind}.{query}.{gene}.csv"
    output:
        table="{dir}.{mapper}.cov.otu/{genefind}.{query}.{gene}.table.csv",
        counts="{dir}.{mapper}.cov.otu/{genefind}.{query}.{gene}.counts.csv"
    threads:
        1
    run:
        # Build two centroid-by-sample tables from the clustering result and
        # the per-region coverage file: summed q23 coverage (output.table)
        # and the number of contributing regions (output.counts).
        import csv
        from collections import defaultdict

        map_seq_to_clust = {}  # cluster member accession -> centroid accession
        clusters = set()       # centroid accessions (one per cluster / OTU)
        samples = set()        # sample names seen in the coverage file
        with open(input.clust, "r", newline="") as cluster_f:
            # sacc: source accession == id of centroid
            # qacc: query accession == id of cluster member
            for row in csv.DictReader(cluster_f):
                map_seq_to_clust[row['qacc']] = row['sacc']
                clusters.add(row['sacc'])

        # Table skeletons: centroid -> (sample -> value); defaultdict(float)
        # makes += work without initializing each cell.
        coverages = {centroid: defaultdict(float) for centroid in clusters}
        coverage_counts = {centroid: defaultdict(float) for centroid in clusters}

        lost = 0   # coverage rows that match no cluster member
        found = 0  # coverage rows assigned to a centroid
        with open(input.cov, "r", newline="") as coverage_f:
            # target == aggregation column value
            # sacc == accession within target
            # start, end == range within target accession
            # avg, max, med, min, q23, std, sum == coverage values
            for row in csv.DictReader(coverage_f):
                seq_name = "{target}.contigs.{sacc}.{start}.{end}".format(**row)
                if seq_name in map_seq_to_clust:
                    samples.add(row['source'])
                    centroid = map_seq_to_clust[seq_name]
                    coverages[centroid][row['source']] += float(row['q23'])
                    coverage_counts[centroid][row['source']] += 1
                    found += 1
                else:
                    lost += 1

        # Previously these counters were computed but never reported.
        print("make_otu_table: matched {} coverage rows, skipped {}".format(
            found, lost))

        # BUGFIX: sort the sample columns -- set iteration order is not
        # stable between Python processes, so the output column order was
        # nondeterministic across runs.
        fieldnames = ["centroid"] + sorted(samples)

        # Cells for samples absent from a centroid stay empty (DictWriter
        # restval); otu_to_qiime_txt later fills them with "0.0".
        with open(output.table, "w", newline="") as out_f:
            table_writer = csv.DictWriter(out_f, fieldnames=fieldnames)
            table_writer.writeheader()
            for centroid in coverages:
                coverages[centroid]['centroid'] = centroid
                table_writer.writerow(coverages[centroid])

        # BUGFIX: iterate coverage_counts (was: coverages) -- same key set
        # today, but iterating the dict actually being written avoids
        # surprises if the two ever diverge.
        with open(output.counts, "w", newline="") as out_f:
            count_writer = csv.DictWriter(out_f, fieldnames=fieldnames)
            count_writer.writeheader()
            for centroid in coverage_counts:
                coverage_counts[centroid]['centroid'] = centroid
                count_writer.writerow(coverage_counts[centroid])


localrules: otu_to_qiime_txt
rule otu_to_qiime_txt:
    message: "Converting OTU table to Qiime TXT Format"
    input:  "{dir}.otu/{file}.table.csv"
    output: "{dir}.otu/{file}.table.txt"
    threads: 1
    run:
        # Translate the comma-separated OTU table into Qiime's tab-separated
        # "classic" format: rename the leading header column to "#OTU ID"
        # and fill empty cells with "0.0".
        import csv
        with open(input[0], "r") as csv_in, \
             open(output[0], "w") as txt_out:
            reader = csv.reader(csv_in)
            writer = csv.writer(txt_out, delimiter="\t")
            first = next(reader)
            first[0] = "#OTU ID"
            writer.writerow(first)
            for record in reader:
                writer.writerow([cell or "0.0" for cell in record])

localrules: otu_to_biom
# Convert the tab-separated ("classic" Qiime) OTU table into HDF5 BIOM
# format using the `biom convert` CLI from the biom-format package,
# provided via the conda environment declared below.
rule otu_to_biom:
    message: "Converting OTU table to Biom Format"
    input:   "{dir}.otu/{file}.table.txt"
    output:  "{dir}.otu/{file}.table.biom"
    threads: 1
    conda:   "biom_format.yml"
    shell:   "biom convert"
             " --input-fp {input}"
             " --output-fp {output}"
             " --to-hdf5"
             " --table-type 'OTU table'"


localrules: blast7_coverage_per_otu
rule blast7_coverage_per_otu:
    input:
        cov="{dir}.blast/{sample}.{: targets :}.{query}.cov.csv",
        clust="{dir}.blast/{query}.{gene}.NR.csv"
    output:
        otu="{dir}.blast/{sample}.{target}.{query}.{gene}.cova.csv"
    run:
        # Map per-region q23 coverage values onto cluster representatives
        # and write them to the output CSV.
        import csv
        clust_map = {}  # NOTE(review): maps sacc -> qacc here, the reverse
                        # of make_otu_table above -- confirm intended direction
        cov_map = {}    # cluster id -> q23 coverage (last match wins)
        with open(input.clust, "r", newline="") as clust_f:
            for row in csv.DictReader(clust_f):
                clust_map[row['sacc']] = row['qacc']
        with open(input.cov, "r", newline="") as cov_f:
            for row in csv.DictReader(cov_f):
                ref = ".".join([wildcards.sample,
                                row['sacc'], row['start'], row['end']])
                try:
                    cov_map[clust_map[ref]] = row['q23']
                # BUGFIX: catch only KeyError (region not in any cluster)
                # instead of a bare except that hid every failure.
                except KeyError:
                    pass
        # BUGFIX: the result was previously only print()ed and the declared
        # output file never created, so Snakemake always failed this rule.
        # Column layout assumed from the data collected -- TODO confirm
        # against downstream consumers.
        with open(output.otu, "w", newline="") as out_f:
            writer = csv.writer(out_f)
            writer.writerow(["centroid", "q23"])
            for centroid in sorted(cov_map):
                writer.writerow([centroid, cov_map[centroid]])