"""
Implements the StageStack
"""
import logging
import copy
import re
from typing import List
import ymp
from ymp.stage.stage import Stage
from ymp.stage.groupby import GroupBy
from ymp.exceptions import YmpStageError
from ymp.snakemake import ExpandLateException
from snakemake.exceptions import IncompleteCheckpointException # type: ignore
log = logging.getLogger(__name__) # pylint: disable=invalid-name
def norm_wildcards(pattern):
    """Normalize target/source wildcards in *pattern* to their sample form.

    The following rewrites are applied, in this order:

    - ``{:target:}`` and ``{:target(...):}`` (optional whitespace inside
      the ``{: :}`` markers) become ``{sample}``
    - ``{target}`` and ``{source}`` become ``{sample}``
    - ``{:targets:}`` and ``{:sources:}`` become ``{:samples:}``

    Args:
      pattern: String containing wildcard placeholders.

    Returns:
      The pattern with all placeholders listed above replaced.
    """
    # The argument group is matched non-greedily so that each
    # "{:target(...):}" placeholder is replaced on its own. A greedy ".*"
    # would span from the first "(" to the last ")" in the string and
    # swallow everything between two placeholders.
    pattern = re.sub(r"\{:\s*target(\(.*?\))?\s*:\}", "{sample}", pattern)
    for pat in ("{target}", "{source}", "{:target:}"):
        pattern = pattern.replace(pat, "{sample}")
    for pat in ("{:targets:}", "{:sources:}"):
        pattern = pattern.replace(pat, "{:samples:}")
    return pattern
def find_stage(name):
    """Look up the object that handles the stage called *name*.

    Candidates are tried in a fixed order; the first hit is returned:

    1. names starting with ``group_`` yield a new `GroupBy`
    2. names starting with ``ref_`` are looked up (without the prefix)
       among the configured references
    3. names of configured projects
    4. the first registered stage whose ``match()`` accepts the name
    5. the first configured pipeline whose ``match()`` accepts the name

    Raises:
      YmpStageError: if a ``ref_`` name refers to no configured reference,
        or if nothing at all matches *name*.
    """
    config = ymp.get_config()
    known_stages = Stage.get_registry()

    if name.startswith("group_"):
        return GroupBy(name)

    if name.startswith("ref_"):
        ref_key = name[len("ref_"):]
        if ref_key not in config.ref:
            raise YmpStageError(f"Unknown reference '{ref_key}'")
        return config.ref[ref_key]

    if name in config.projects:
        return config.projects[name]

    hit = next(
        (cand for cand in known_stages.values() if cand.match(name)),
        None,
    )
    if hit is not None:
        return hit

    hit = next(
        (cand for cand in config.pipelines.values() if cand.match(name)),
        None,
    )
    if hit is not None:
        return hit

    raise YmpStageError(f"Unknown stage '{name}'")
class StageStack:
    """The "head" of a processing chain - a stack of stages

    A stack is described by a ``.``-separated path of stage names. The
    first name must be a configured project; the last name is the stage
    at the top of the stack. Use `instance` for cached construction.
    """

    #: Stacks already handed out by `instance`; `show_info` is called
    #: only the first time a stack is returned.
    used_stacks = set()

    #: Set to true to enable additional Stack debug logging
    debug = False

    @classmethod
    def instance(cls, path):
        """
        Cached access to StageStack

        The stack is fetched from (or created through) the cache obtained
        from the configuration. The first time a stack is returned from
        here, its `show_info` is called.

        Args:
          path: Stage path

        Returns:
          The `StageStack` for ``path``.
        """
        cfg = ymp.get_config()
        cache = cfg.cache.get_cache(cls.__name__, itemloadfunc=StageStack)
        res = cache[path]
        if res not in cls.used_stacks:
            cls.used_stacks.add(res)
            res.show_info()
        return res

    def __str__(self):
        return self.path

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name}, {self.stage})"

    def __init__(self, path):
        """Build the stack for the ``.``-separated stage ``path``.

        Raises:
          YmpStageError: if a stage name cannot be resolved, if the first
            name is not a configured project, or if inputs required by
            the top stage cannot be found in the stack.
        """
        #: Name of stack, aka is its full path
        self.name = path
        #: Names of stages on stack
        self.stage_names = path.split(".")
        #: Stages on stack
        self.stages = [find_stage(name)
                       for name in self.stage_names]
        #: Top Stage
        self.stage = self.stages[-1]
        #: Top Stage Name
        self.stage_name = self.stage_names[-1]
        #: Stage below top stage or None if first in stack
        self.prev_stage = self.stages[-2] if len(self.stages) > 1 else None
        #: Stack without the top stage or None if first in stack
        self.prev_stack = None
        if len(self.stages) > 1:
            self.prev_stack = self.instance(".".join(self.stage_names[:-1]))

        cfg = ymp.get_config()
        #: Project on which stack operates
        #: This is needed for grouping variables currently.
        self.project = cfg.projects.get(self.stage_names[0])
        if not self.project:
            raise YmpStageError(f"No project for stage stack {path} found")

        #: Mapping of each input type required by the stage of this stack
        #: to the prefix stack providing it.
        self.prevs = self.resolve_prevs()

        # Gather all previous groups; dict.fromkeys() removes duplicates
        # while keeping first-seen order.
        groups = list(dict.fromkeys(
            group
            for stack in reversed(list(self.prevs.values()))
            for group in stack.group
        ))
        project_groups, other_groups = self.project.minimize_variables(groups)

        #: Grouping in effect for this StageStack. And empty list groups into
        #: one pseudo target, 'ALL'.
        self.group: List[str] = \
            self.stage.get_group(self, project_groups + other_groups)

    def show_info(self):
        """Log the stack's grouping and which stack provides each input."""
        def ellip(text: str) -> str:
            # Keep the tail of long names: "..." plus the last 37 chars.
            if len(text) < 40:
                return text
            return "..."+text[-37:]

        # Invert self.prevs: providing stack (as string) -> input types
        prevmap = dict()
        for typ, stack in self.prevs.items():
            prevmap.setdefault(str(stack), []).append(typ)

        log.info(
            "Stage stack '%s' (output by %s%s)",
            self.name,
            ", ".join(ellip(str(g)) for g in self.group) or "*ALL*",
            " + bins" if self.stage.has_checkpoint() else ""
        )
        for stack, typ in prevmap.items():
            ftypes = ", ".join(typ).replace("/{sample}", "*")
            # Show only the last stage name unless that name does not
            # occur exactly once in this stack's own stage names.
            title = stack.split(".")[-1]
            if self.stage_names.count(title) != 1:
                title = stack
            log.info(" input from %s: %s", title, ftypes)

    def resolve_prevs(self):
        """Map each input type of the top stage to the stack providing it.

        Raises:
          YmpStageError: if inputs remain after searching the stack.
        """
        inputs = self.stage.get_inputs()
        prevs = self._do_resolve_prevs(self.stage, inputs, exclude_self=True)
        # NOTE(review): relies on ``inputs`` being emptied as inputs get
        # satisfied (presumably by ``satisfy_inputs``) -- confirm.
        if inputs:
            raise YmpStageError(self._format_missing_input_error(inputs))
        return prevs

    def _format_missing_input_error(self, inputs):
        """Build the error text listing stages that provide ``inputs``."""
        registry = Stage.get_registry()
        # Can't find the right types, try to present useful error message:
        words = []
        for item in inputs:
            words.extend((item, "--"))
            # Ask for the whole input type; ``set(item)`` would split the
            # string into single characters, which no stage provides.
            words.extend([name for name, stage in registry.items()
                          if stage.can_provide({item})])
            words.append('\n')
        text = ' '.join(words)
        missing = " ".join(inputs)
        return (
            f"\n"
            f"File type(s) '{missing}' required by '{self.stage}'\n"
            f"not found in '{self.name}'. Stages providing missing\n"
            f"file types:\n"
            f"{text}\n"
        )

    def _do_resolve_prevs(self, stage, inputs, exclude_self):
        """Walk down the stack, collecting the stacks that satisfy ``inputs``.

        Args:
          stage: Stage whose inputs are to be satisfied.
          inputs: Input types still to be found.
          exclude_self: If true, the search starts below the top stage.

        Returns:
          Dictionary mapping input type to the providing `StageStack`.
        """
        stage_names = copy.copy(self.stage_names)
        if exclude_self:
            stage_names.pop()

        prevs = {}
        while stage_names and inputs:
            path = ".".join(stage_names)
            prev_stack = self.instance(path)
            prev_stage = find_stage(stage_names.pop())
            provides = stage.satisfy_inputs(prev_stage, inputs)
            for typ, ppath in provides.items():
                if ppath:
                    # Truthy value: ask the providing stage for the path
                    # of the stack that actually holds this type.
                    npath = prev_stage.get_path(prev_stack, typ)
                    prevs[typ] = self.instance(npath)
                else:
                    prevs[typ] = prev_stack
        return prevs

    def complete(self, incomplete):
        """List names starting with ``incomplete`` that may extend the stack.

        Candidates are ``group_<variable>`` for each project variable and
        ``ALL``, ``ref_<name>`` for each configured reference, and the
        name / altname of every registered stage for which the extended
        stack can be instantiated.
        """
        registry = Stage.get_registry()
        cfg = ymp.get_config()
        result = []
        groups = ("group_" + name for name in self.project.variables + ['ALL'])
        result += (opt for opt in groups if opt.startswith(incomplete))
        refs = ("ref_" + name for name in cfg.ref)
        result += (opt for opt in refs if opt.startswith(incomplete))
        for stage in registry.values():
            for name in (stage.name, stage.altname):
                if name and name.startswith(incomplete):
                    try:
                        self.instance(".".join((self.path, name)))
                        result.append(name)
                    except YmpStageError:
                        # Stage does not fit on this stack, skip it.
                        pass
        return result

    @property
    def path(self):
        """On disk location of files provided by this stack"""
        path = self.stage.get_path(self)
        # Follow the path through further stacks until it no longer
        # changes or no stack can be built for it.
        while True:
            try:
                stack = self.instance(path)
            except YmpStageError:
                return path
            newpath = stack.stage.get_path(stack)
            if path == newpath:
                return newpath
            path = newpath

    def all_targets(self):
        """Return what the top stage's ``get_all_targets`` yields here."""
        return self.stage.get_all_targets(self)

    @property
    def defined_in(self):
        return None

    def prev(self, _args=None, kwargs=None) -> "StageStack":
        """
        Stack providing the input that follows ``{:prev:}``

        The text after ``{:prev:}`` in ``kwargs['item']`` is normalized
        with `norm_wildcards` and looked up in ``self.prevs``.

        Raises:
          ExpandLateException: if ``kwargs`` is empty or has no ``wc``.
        """
        if not kwargs or "wc" not in kwargs:
            raise ExpandLateException()

        _, _, suffix = kwargs['item'].partition("{:prev:}")
        suffix = norm_wildcards(suffix)

        return self.prevs[suffix]

    def get_ids(self, select_cols, where_cols=None, where_vals=None):
        """Fetch IDs from the top stage; logs the query if `debug` is set."""
        if not self.debug:
            return self.stage.get_ids(self, select_cols, where_cols, where_vals)

        log.warning(" select %s", select_cols)
        log.warning(" where %s == %s", repr(where_cols), where_vals)
        try:
            ids = self.stage.get_ids(self, select_cols, where_cols, where_vals)
        except IncompleteCheckpointException as exc:
            log.warning(" ===> checkpoint deferred (%s)", exc.targetfile)
            raise
        log.warning(" ===> %s", repr(ids))
        return ids

    @property
    def targets(self):
        """
        Determines the IDs to be built by this Stage Stack
        (replaces "{:targets:}").
        """
        if self.debug:
            log.error("output ids for %s", self)
            log.warning(" select %s", repr(self.group))
        if self in self.group:
            # Work on a copy so self.group itself stays untouched.
            group = self.group.copy()
            group.remove(self)
        else:
            group = self.group
        return self.get_ids(group)

    def target(self, args, kwargs):
        """
        Determines the IDs for a given input data type and output ID
        (replaces "{:target:}").

        Raises:
          ExpandLateException: via `prev`, if wildcards are not available.
          YmpStageError: if the providing stack returns an empty list.
        """
        # Find stage stack from which input should be requested.
        # (not sure why the below causes a false positive in pylint)
        prev_stack = self.prev(args, kwargs)  # pylint: disable=not-callable
        # Find name of current output target
        cur_target = kwargs['wc'].target

        if self.debug:
            rulename = getattr(kwargs.get('rule'), 'name', 'N/A')
            log.error("input ids for %s", self)
            log.warning(" rule %s", rulename)
            log.warning(" from stack %s", prev_stack)

        cols = self.group
        vals = cur_target
        # An empty grouping with target 'ALL' means no filter at all.
        if cols == [] and vals == 'ALL':
            cols = vals = None

        ids = prev_stack.get_ids(prev_stack.group, cols, vals)

        if ids == []:
            rulename = getattr(kwargs.get('rule'), 'name', 'N/A')
            raise YmpStageError(
                self._format_missing_ids_error(prev_stack, rulename, cur_target)
            )
        return ids

    def _format_missing_ids_error(self, prev_stack, rulename, cur_target):
        """Build the error text for `target` finding no input IDs.

        Uses ``cur_target`` and ``self.group`` rather than the filter
        values, which are ``None`` for the ungrouped 'ALL' case and cannot
        be joined.
        """
        own_groups = ",".join(str(col) for col in self.group)
        prev_groups = ",".join(str(col) for col in prev_stack.group)
        return (
            f"Internal Error: Failed to find inputs\n\n"
            f"Context:\n"
            f" In stack '{self}' rule '{rulename}'\n"
            f" Building '{cur_target}' (grouped on '{own_groups}')\n"
            f" Seeking input from '{prev_stack}' "
            f"(grouped on '{prev_groups}')"
            f"\n"
        )