"""
Implements the "Stage"
At it's most basic, a "Stage" is a set of Snakemake rules that share an output folder.
"""
import copy
import logging
import re
from typing import Dict, List, Set
from ymp.snakemake import WorkflowObject, RemoveValue
from ymp.stage.base import BaseStage, Activateable
from ymp.stage.params import Parametrizable
from ymp.exceptions import YmpRuleError, YmpException, YmpStageError
# Module-level logger for this file.
log = logging.getLogger(__name__) # pylint: disable=invalid-name
class Stage(WorkflowObject, Parametrizable, Activateable, BaseStage):
    """
    Creates a new stage

    While entered using ``with``, several stage specific variables
    are expanded within rules:

    * ``{:this:}`` -- The current stage directory
    * ``{:that:}`` -- The alternate output stage directory
    * ``{:prev:}`` -- The previous stage's directory
    """

    def __init__(
            self,
            name: str,
            altname: str = None,
            env: str = None,
            doc: str = None,
    ) -> None:
        """
        Args:
            name: Name of this stage
            altname: Alternate name of this stage (used for stages with
                multiple output variants, e.g. filter_x and remove_x).
            doc: See `doc()`
            env: See `env()`
        """
        super().__init__(name)
        #: Alternative stage name (deprecated)
        self.altname: str = altname
        #: Checkpoints in this stage: maps each checkpoint rule name to the
        #: output suffixes that rule registered through ``{:bin:}``
        self.checkpoints: Dict[str, Set[str]] = {}
        #: Contains override stage inputs (set via `require()`)
        self.requires = None
        # Inputs required by stage (collected from ``{:prev:}``)
        self._inputs: Set[str] = set()
        # Outputs provided by stage (collected from ``{:this:}``/``{:bin:}``)
        self._outputs: Set[str] = set()
        # Regex matching self; compiled lazily by `match()`
        self._regex = None
        self.register()
        self.doc(doc or "")
        self.env(env)

    def env(self, name: str) -> None:
        """Add package specifications to Stage environment

        Note:
            This sets the environment for all rules within the stage,
            which leads to errors with Snakemake rule types
            not supporting conda environments

        Args:
            name: Environment name or filename

        Example:
            Ways of using an environment defined with ``Env``
            (Snakefile syntax)::

                Env("blast", packages="blast =2.7*")

                with Stage("test") as S:
                    S.env("blast")
                    rule testing:
                        ...

                with Stage("test", env="blast") as S:
                    rule testing:
                        ...

                with Stage("test") as S:
                    rule testing:
                        conda: "blast"
                        ...
        """
        self.conda_env = name

    def __str__(self):
        if self.altname:
            return "|".join((self.name, self.altname))
        return self.name

    def __repr__(self):
        return (f"{self.__class__.__name__} {self!s} "
                f"({self.filename}:{self.lineno})")

    def require(self, **kwargs):
        """Override inferred stage inputs

        In theory, this should not be needed. But it's simpler for now.
        """
        self.requires = kwargs

    @property
    def outputs(self) -> Set[str]:
        """Output suffixes registered for this stage."""
        return self._outputs

    def wc2path(self, wc):
        """Format this stage's directory pattern with the wildcard values *wc*.

        Parameter constraints (``p.constraint``) are removed from the
        pattern before it is passed to ``str.format``.
        """
        wildcards = self._wildcards(self.name)
        for p in self.params:
            wildcards = wildcards.replace(p.constraint, "")
        return wildcards.format(**wc)

    def match(self, name):
        """Check whether *name* refers to this stage.

        *name* must fully match the stage name (or `altname`) followed by
        the regexes of all stage parameters. The compiled pattern is cached
        on first use.
        """
        if not self._regex:
            if self.altname:
                sname = "(" + "|".join((self.name, self.altname)) + ")"
            else:
                sname = self.name
            pat = sname + "".join(p.regex for p in self.params)
            self._regex = re.compile(pat)
        return self._regex.fullmatch(name) is not None

    def _wildcards(self, name, kwargs=None):
        """Build the pattern ``{_YMP_DIR}<name><param patterns>``.

        Parameter constraints are shown unless the expansion happens in an
        ``input`` or ``message`` field (or no *kwargs* were given).
        """
        show_constraint = kwargs and kwargs.get('field') not in ('input', 'message')
        return "".join(["{_YMP_DIR}", name] +
                       [p.pattern(show_constraint) for p in self.params])

    def prev(self, _args, kwargs) -> None:
        """Gathers {:prev:} calls from rules

        Here, input requirements for each stage are collected.
        """
        self.register_inout("prev", self._inputs, kwargs['item'])

    def this(self, args=None, kwargs=None):
        """Replaces {:this:} in rules

        Also gathers output capabilities of each stage: items of ``output``
        fields that do not contain ``{:bin:}`` are registered as outputs.
        """
        item = kwargs['item']
        if kwargs.get('field') == 'output' and "{:bin:}" not in item:
            self.register_inout("this", self._outputs, item)
        return self._wildcards(self.name, kwargs=kwargs)

    def that(self, _args=None, kwargs=None):
        """
        Alternate directory of current stage

        Used for splitting stages

        Raises:
            YmpException: if the stage has no `altname`.
        """
        self.check_active_stage("that")
        if not self.altname:
            raise YmpException(
                "Use of {:that:} requires with altname"
            )
        return self._wildcards(self.altname, kwargs=kwargs)

    def bin(self, _args=None, kwargs=None):
        """
        Dynamic ID for splitting stages

        Registers the item (with ``.{:bin:}`` removed) as an output of this
        stage and records it under the checkpoint rule's name, then raises
        ``RemoveValue`` instead of returning a replacement string.

        Raises:
            YmpStageError: if the rule is not a checkpoint.
        """
        rule = kwargs['rule']
        if not rule.is_checkpoint:
            raise YmpStageError("Only checkpoints may use '{:bin:}'")
        item = kwargs['item']
        norm_item = item.replace(".{:bin:}", "")
        norm_suffix = self.register_inout("this", self._outputs, norm_item)
        self.checkpoints.setdefault(rule.name, set()).add(norm_suffix)
        raise RemoveValue()

    def has_checkpoint(self) -> bool:
        """True if any rule in this stage used ``{:bin:}``."""
        return bool(self.checkpoints)

    def get_all_targets(self, stack):
        """Return the files that make up the complete output of *stack*.

        If the stage produces ``/all_targets.stamp``, only that file is
        returned. Otherwise the base implementation is used, excluding
        outputs registered by checkpoints.
        """
        if "/all_targets.stamp" in self.outputs:
            return [stack.name + "/all_targets.stamp"]
        outputs = None
        if self.has_checkpoint():
            # Exclude outputs registered by checkpoint rules via {:bin:}
            checkpoint_outputs = set().union(*self.checkpoints.values())
            outputs = {output for output in self.outputs
                       if output not in checkpoint_outputs}
        return super().get_all_targets(stack, outputs)

    def get_group(
            self,
            stack: "StageStack",
            default_groups: List[str],
    ) -> List[str]:
        """Determine the grouping for *stack*.

        If the previous stage provides override groups (via
        ``modify_next_group``), those replace *default_groups*. A
        ``"__bin__"`` entry among them is replaced by the bin stacks
        (entries of the same type as *stack*) found in *default_groups*.
        If this stage has a checkpoint, *stack* itself is appended.

        A new list is always returned; neither *default_groups* nor the
        override list is modified.
        """
        # Are we instructed by previous stack to change grouping?
        override_groups = None
        if stack.prev_stack is not None:
            prev_stage = stack.prev_stack.stage
            override_groups = prev_stage.modify_next_group(stack.prev_stack)

        if override_groups is None:
            # If not, just use (a copy of) the default groups
            groups = list(default_groups)
        else:
            # Otherwise, use (a copy of) the override groups,
            groups = list(override_groups)
            # Replace "__bin__" with bins in effect
            if "__bin__" in override_groups:
                groups = [g for g in groups if g != "__bin__"]
                # FIXME:
                # Should we just use the latest bin? What if we have multiple?
                groups += [
                    g for g in default_groups if isinstance(g, type(stack))
                ]

        # If we are a checkpoint ourselves, add self.
        if self.has_checkpoint():
            groups.append(stack)
        return groups

    def get_checkpoint_ids(self, stack, mygroup, target):
        """Return the bin IDs produced by this stage's checkpoint.

        For each ID of *stack* (restricted by *mygroup* and *target*, see
        `get_ids`), the checkpoint job is fetched from Snakemake's
        ``checkpoints`` object and its ``bins`` output file is read; each
        non-empty line is one bin ID.

        Returns:
            Sorted list of unique bin IDs (sorted so that the order is
            reproducible between runs).

        Raises:
            RuntimeError: if the stage has no checkpoint or more than one.
            YmpStageError: if ``stack.path`` does not match this stage's
                output pattern.
        """
        if not self.checkpoints:
            raise RuntimeError("No checkpoint in stage")
        if len(self.checkpoints) > 1:
            raise RuntimeError("Multiple checkpoints not implemented")
        from snakemake.workflow import checkpoints
        from snakemake.io import regex
        path_match = re.match(
            regex(self._wildcards(self.name, {'field': 'output'})),
            stack.path
        )
        if path_match is None:
            raise YmpStageError(
                f"Stack path '{stack.path}' does not match stage {self}"
            )
        wildcards = path_match.groupdict()
        checkpoint_name = next(iter(self.checkpoints))
        checkpoint = getattr(checkpoints, checkpoint_name)
        mytargets = self.get_ids(stack,
                                 [g for g in stack.group if g != stack],
                                 mygroup, target)
        bins = set()
        for mytarget in mytargets:
            wildcards['target'] = mytarget
            job = checkpoint.get(**wildcards)
            with open(job.output.bins, "r") as fd:
                # Skip blank lines; they would otherwise become "" bin IDs
                bins.update(stripped for stripped in (line.strip() for line in fd)
                            if stripped)
        return sorted(bins)

    def get_ids(self, stack, groups, mygroups=None, target=None):
        """Return the IDs for *stack*, expanding bins from checkpoint stages.

        Entries of *groups* of the same type as *stack* are bins. They are
        removed before the base implementation is called. Afterwards, every
        ID is suffixed with ``__<binid>`` for each bin ID produced by that
        stack's checkpoint. When *target* is given, its ``__``-separated
        parts that line up with bin entries in *mygroups* pin the bin to
        that ID, and the remaining parts are passed on as the target. If no
        non-bin groups remain, a leading ``ALL__`` is stripped from IDs.

        Neither *groups* nor *mygroups* is modified.

        Raises:
            RuntimeError: if *target* is given without *mygroups*.
        """
        stack_type = type(stack)
        # Split bin stacks off the regular groups, building new lists so
        # the caller's *groups* is not modified.
        bins = [group for group in groups if isinstance(group, stack_type)]
        groups = [group for group in groups if not isinstance(group, stack_type)]
        if mygroups is not None:
            mygroups = list(mygroups)
        mybins = {}

        if mygroups is None and target is not None:
            raise RuntimeError("Mygroups none but target not?")

        if target is not None:
            # If we are getting IDs for {:target:} of subsequent stage,
            # find all generated ids from binning stage: the target part
            # belonging to each bin group selects that bin's ID.
            target_parts = []
            for group, tgt in zip(list(mygroups), target.split("__")):
                if isinstance(group, stack_type):
                    mygroups.remove(group)
                    mybins[group] = tgt
                else:
                    target_parts.append(tgt)
            if target_parts:
                mygroups = mygroups[:len(target_parts)]
                target = "__".join(target_parts)
            else:
                target = None
                mygroups = None

        # Pass to standard implementation
        ids = super().get_ids(stack, groups, mygroups, target)

        # Multiply IDs by the bin IDs of each binning stack
        for bin_stack in bins:
            ids = [
                "__".join((prefix, binid))
                for prefix in ids
                for binid in bin_stack.stage.get_checkpoint_ids(bin_stack, groups, prefix)
                if bin_stack not in mybins or binid == mybins[bin_stack]
            ]

        if not groups:
            ids = [id_[len("ALL__"):] if id_.startswith("ALL__") else id_
                   for id_ in ids]
        return ids