# Source code for ymp.stage.base

"""
Base classes for all Stage types
"""

import logging
import os
import re

from typing import Set, Dict, Union, List, Optional

from snakemake.rules import Rule
from snakemake.workflow import Workflow

from ymp.exceptions import YmpStageError, YmpRuleError, YmpException
from ymp.string import ProductFormatter
from ymp.yaml import MultiProxy


log = logging.getLogger(__name__)  # pylint: disable=invalid-name


class BaseStage:
    """Base class for stage types

    Holds the name and docstring of a stage and supplies default
    implementations for the hooks that concrete stage types override
    (inputs, outputs, paths, targets, grouping and ID resolution).
    """

    def __init__(self, name: str) -> None:
        #: The name of the stage is a string uniquely identifying it
        #: among all stages.
        self.name = name
        #: Alternative name
        self.altname = None
        #: The docstring describing this stage. Visible via ``ymp
        #: stage list`` and in the generated sphinx documentation.
        self.docstring: Optional[str] = None

    def __str__(self) -> str:
        """Cast to string we just emit our name"""
        return self.name

    def __repr__(self) -> str:
        """Using `repr()` we emit the subclass as well as our name"""
        return f"{self.__class__.__name__}({self!s})"

    def doc(self, doc: str) -> None:
        """Add documentation to Stage

        Args:
          doc: Docstring passed to Sphinx
        """
        #: Docstring describing stage
        self.docstring = doc

    def match(self, name: str) -> bool:
        """Check if the ``name`` can refer to this stage

        As component of a `StageStack`, a stage may be identified by
        alternative names and may also be parametrized by suffix
        modifiers. Stage types supporting this behavior must override
        this function.

        The base implementation only accepts an exact match with
        ``self.name``.
        """
        return name == self.name

    def get_inputs(self) -> Set[str]:
        # pylint: disable = no-self-use
        """Returns the set of inputs required by this stage

        This function must return a copy, to ensure internal data is
        not modified. The base implementation has no inputs and returns
        a fresh empty set.
        """
        return set()

    @property
    def outputs(self) -> Union[Set[str], Dict[str, str]]:
        """Returns the set of outputs this stage is able to generate.

        May return either a `set` or a `dict` with the dictionary
        values representing redirections in the case of virtual stages
        such as `Pipeline` or `Reference`.

        The base implementation returns an empty set.
        """
        return set()

    def get_outputs(self, path: str) -> Dict[str, str]:
        """Returns a dictionary of outputs

        Args:
          path: Path that the outputs are mapped to.

        Returns:
          If `outputs` is a set, every output maps to ``path``
          unchanged. If it is a dict, everything from the last
          occurrence of ``"." + self.name`` onwards is cut from
          ``path`` and each output maps to that remainder with its
          redirection suffix appended.
        """
        outputs = self.outputs
        if isinstance(outputs, set):
            return {output: path for output in outputs}
        # Keep only what precedes the last ".<name>". Note that
        # rpartition() yields an empty head if the separator is absent.
        path, _, _ = path.rpartition("." + self.name)
        # false positive - pylint: disable=no-member
        return {
            output: path + redirect
            for output, redirect in outputs.items()
        }

    def can_provide(self, inputs: Set[str]) -> Dict[str, str]:
        """Determines which of ``inputs`` this stage can provide.

        Returns a dictionary with the keys a subset of ``inputs`` and
        the values identifying redirections. An empty string indicates
        that no redirection is to take place. Otherwise, the string is
        the suffix to be appended to the prior `StageStack`.

        The base implementation never redirects: all values are ``""``.
        """
        return {
            output: ""
            for output in inputs.intersection(self.outputs)
        }

    def get_path(self, stack: "StageStack") -> str:
        # pylint: disable = no-self-use
        """On disk location for this stage given ``stack``.

        Called by `StageStack` to determine the real path for virtual
        stages (which must override this function).

        The base implementation returns ``stack.name``.
        """
        return stack.name

    def get_all_targets(self, stack: "StageStack",
                        output_types: Optional[List[str]] = None) -> List[str]:
        """Targets to build to complete this stage given ``stack``.

        Typically, this is the StageStack's path appended with the
        stamp name.

        Args:
          stack: Supplies ``targets`` and ``path``.
          output_types: Output patterns to expand. If `None`, all
            `outputs` containing ``{sample}`` but not ``{:bin:}`` are
            used.

        Returns:
          ``stack.path`` concatenated with each output pattern, with
          ``{sample}`` filled in for every target. Results are grouped
          by output type, then by target.
        """
        if output_types is None:
            output_types = [
                output for output in self.outputs
                if "{sample}" in output and "{:bin:}" not in output
            ]
        targets = stack.targets
        path = stack.path
        return [
            path + output_type.format(sample=target)
            for output_type in output_types
            for target in targets
        ]

    def get_group(
            self,
            stack: "StageStack",
            default_groups: List[str],
    ) -> List[str]:
        """Determine output grouping for stage

        Args:
          stack: The stack for which output grouping is requested.
          default_groups: Grouping determined from stage inputs

        Returns:
          ``default_groups``, unchanged.

        Raises:
          YmpStageError: If the stage of the previous stack returns a
            true value from `modify_next_group`; the base stage does
            not support having its grouping overridden.
        """
        prev_stack = stack.prev_stack
        if prev_stack is not None:
            if prev_stack.stage.modify_next_group(prev_stack):
                raise YmpStageError(f"Cannot override {self} grouping")
        return default_groups

    def modify_next_group(self, _stack: "StageStack"):
        # pylint: disable = no-self-use
        """Grouping override for the stage following this one

        `get_group` calls this on the stage of the previous stack. The
        base implementation returns `None`, requesting no override.
        """
        return None

    def get_ids(
            self,
            stack: "ymp.stage.StageStack",
            groups: List[str],
            match_groups: Optional[List[str]] = None,
            match_value: Optional[str] = None,
    ) -> List[str]:
        # pylint: disable = no-self-use
        """Determine the target ID names for a set of active groupings

        Called from ``{:target:}`` and ``{:targets:}``. For ``{:targets:}``,
        ``groups`` is the set of active groupings for the stage
        stack. For ``{:target:}``, it's the same set for the source of
        the file type, the current grouping and the current target.

        Args:
          stack: Stack whose ``project`` resolves the IDs when none of
            the shortcuts below apply.
          groups: Set of columns the values of which should form IDs
          match_value: Limit output to rows with this value
          match_groups: ... in these groups
        """
        # empty groups means single output file, termed ALL
        if not groups:
            return ['ALL']
        # "ALL" as match value means no restriction at all
        if match_value == 'ALL':
            match_value = None
            match_groups = None
        # A value without groups to look it up in is used as is
        if not match_groups and match_value:
            return [match_value]
        # Fastpath: If groups and match groups are identical the input
        # and output IDs must be identical.
        if groups == match_groups:
            return [match_value]
        # Pass through to project
        return stack.project.do_get_ids(stack, groups, match_groups, match_value)

    def has_checkpoint(self) -> bool:
        # pylint: disable = no-self-use
        """Report whether this stage has a checkpoint

        The base implementation always returns `False`.
        """
        return False
class Activateable:
    """
    Mixin for Stages that can be filled with rules from Snakefiles.

    Used as a context manager: entering marks the stage as the single
    active one, exiting clears that mark. The active stage is stored on
    the `Activateable` class itself and is therefore shared by all
    instances.
    """

    #: Currently active stage ("entered")
    _active: Optional["BaseStage"] = None

    @staticmethod
    def get_active() -> Optional["BaseStage"]:
        """Return the currently active stage, or `None` if there is none"""
        return Activateable._active

    @staticmethod
    def set_active(stage: Optional["BaseStage"]) -> None:
        """Set the currently active stage; pass `None` to clear it"""
        Activateable._active = stage

    def __init__(self, *args, **kwargs) -> None:
        #: Names of the rules in this stage (`add_rule` stores
        #: ``rule.name``, not the rule object)
        self.rules: List[str] = []
        #: Snapshot of ``rules`` taken when the stage was last entered
        self._last_rules: List[str] = []
        super().__init__(*args, **kwargs)

    def __enter__(self) -> "Activateable":
        """Make this the active stage

        Raises:
          YmpRuleError: If a stage is already active.
        """
        if self.get_active() is not None:
            raise YmpRuleError(
                self,
                f"Failed to enter stage '{self}', "
                f"already in stage '{self.get_active()}'."
            )
        self.set_active(self)
        # Remember the rules present before this activation so that
        # add_rule() can order newly added rules relative to them.
        self._last_rules = self.rules.copy()
        return self

    def __exit__(self, *args) -> None:
        """Clear the active stage; exceptions are not suppressed"""
        self.set_active(None)

    def add_rule(self, rule: "Rule", workflow: "Workflow") -> None:
        """Register ``rule`` with this stage

        Tags the rule with this stage (``rule.ymp_stage``), records its
        name and calls ``workflow.ruleorder(rule.name, lastrule)`` for
        every rule name that was present when the stage was entered.

        Args:
          rule: The rule to add.
          workflow: Workflow on which the rule order is declared.
        """
        rule.ymp_stage = self
        self.rules.append(rule.name)
        for lastrule in self._last_rules:
            workflow.ruleorder(rule.name, lastrule)

    def check_active_stage(self, name: str) -> None:
        """Ensure that this stage is the active one

        Args:
          name: Name of the placeholder being processed; only used in
            the error message.

        Raises:
          YmpException: If no stage is active or a different stage is
            active.
        """
        active = self.get_active()
        if not active:
            raise YmpException(
                f"Use of {{:{name}:}} requires active Stage"
            )
        if active != self:
            raise YmpException(
                f"Internal error: {self} running but {active} active."
            )

    def register_inout(self, name: str, target: Set, item: str) -> None:
        """Determine stage input/output file type from prev/this filename

        Detects patterns like "PREFIX{: NAME :}/INFIX{TARGET}.EXT".
        Also checks if there is an active stage.

        Args:
          name: The NAME
          target: Set to which to add the type
          item: The filename

        Returns:
          Normalized output pattern

        Raises:
          YmpException: If this stage is not the active stage.
          YmpRuleError: If ``item`` does not contain ``{:NAME:}``, or
            has text before it, or has anything other than a single
            ``/`` between it and the target placeholder.
        """
        self.check_active_stage(name)
        # ``name`` is escaped so that it is always matched literally;
        # under re.VERBOSE even whitespace in it would otherwise be
        # dropped from the pattern. Doubled braces are literal braces
        # surviving str.format().
        match = re.fullmatch(r"""
        (?P<prefix>.*)\{{:\s*{name}\s*:\}}
        (?P<infix>/?.*?)
        (?P<target>\{{:?\s*(?:target|sample|source)(?:|\([^)]*\))\s*:?}})?
        (?P<suffix>.*)
        """.format(name=re.escape(name)), item, re.VERBOSE)
        if not match:
            raise YmpRuleError(self, f"Malformed '{{:{name}:}}' string: '{item}'")

        parts = match.groupdict()
        prefix = parts["prefix"]
        if prefix:
            raise YmpRuleError(self, f"Stage prefix '{prefix}' in '{item}' not supported")
        infix = parts["infix"]
        if infix and infix != "/":
            raise YmpRuleError(self, f"Filename prefix '{infix}' in '{item}' not supported")
        suffix = parts["suffix"]

        # Whatever target placeholder was used is normalized to {sample}
        if parts["target"]:
            normtype = "/{sample}" + suffix
        else:
            normtype = "/" + suffix

        # Only patterns without further placeholders are registered
        if "{" not in suffix:
            target.add(normtype)

        return normtype
class ConfigStage(BaseStage):
    """Base for stages created via configuration

    These Stages derive from the ``yml.yml`` and not from a rules file.
    """

    def __init__(self, name: str, cfg: 'MultiProxy') -> None:
        """Create a stage from its configuration

        Args:
          name: Name of the stage, passed on to `BaseStage`.
          cfg: The configuration object defining this stage. It is
            queried for the files and line numbers of the definition.
        """
        #: Semi-colon separated list of file names defining this Stage.
        self.filename = ';'.join(cfg.get_files())
        #: First line number reported by ``cfg.get_linenos()``, or
        #: `None` if it reports none.
        self.lineno = next(iter(cfg.get_linenos()), None)
        super().__init__(name)
        #: The configuration object defining this Stage.
        self.cfg = cfg

    @property
    def defined_in(self):
        """List of files defining this stage

        Used to invalidate caches. The value is whatever
        ``self.cfg.get_files()`` returns at the time of access.
        """
        return self.cfg.get_files()