Source code for ymp.common

"""
Collection of shared utility classes and methods
"""
import logging
import os
from collections import Iterable

import ymp

log = logging.getLogger(__name__)


[docs]class AttrDict(dict): """ AttrDict adds accessing stored keys as attributes to dict """ def __getattr__(self, attr): try: return super().__getattr__(attr) except AttributeError: try: val = self[attr] except KeyError as e: raise AttributeError(e) if isinstance(val, dict): return AttrDict(val) else: return val def __setattr__(self, attr, value): if attr.startswith("_"): super().__setattr__(attr, value) else: raise NotImplementedError()
[docs]class MkdirDict(AttrDict): "Creates directories as they are requested" def __getattr__(self, attr): dirname = super().__getattr__(attr) if not os.path.exists(dirname): log.info("Creating directory %s", dirname) os.makedirs(dirname, exist_ok=True) return dirname
[docs]def parse_number(s=""): """Basic 1k 1m 1g 1t parsing. - assumes base 2 - returns "byte" value - accepts "1kib", "1kb" or "1k" """ multiplier = 1 s = s.strip().upper().rstrip("BI") if s.endswith("K"): multiplier = 1024 if s.endswith("M"): multiplier = 1024*1024 if s.endswith("G"): multiplier = 1024*1024*1024 if s.endswith("T"): multiplier = 1024*1024*1024*1024 s = s.rstrip("KMGT") if "." in s: return float(s)*multiplier else: return int(s)*multiplier
[docs]def flatten(item): """Flatten lists without turning strings into letters""" if isinstance(item, str): yield item elif isinstance(item, Iterable): for item2 in item: yield from flatten(item2) else: yield item
[docs]def is_container(obj): """Check if object is container, considering strings not containers""" return not isinstance(obj, str) and isinstance(obj, Iterable)
[docs]def ensure_list(obj): """Wrap ``obj`` in a `list` as needed""" if obj is None: return [] if isinstance(obj, str) or not isinstance(obj, Iterable): return [obj] return list(obj)
[docs]class Cache(object): def __init__(self, root): import sqlite3 os.makedirs(os.path.join(root), exist_ok=True) db_fname = os.path.join(root, "ymp.db") log.debug("Opening database %s", db_fname) self.conn = sqlite3.connect(db_fname, check_same_thread=False) # Drop tables if the database has the wrong version number # or if the user_version has not been set (defaults to 0) version = self.conn.execute("PRAGMA user_version").fetchone()[0] if version == ymp.__numeric_version__ and version != 0: try: curs = self.conn.execute("SELECT file, time from stamps") update = any(os.path.getmtime(row[0]) > row[1] for row in curs) except FileNotFoundError: update = True del curs if update: log.error("Dropping cache: files changed") self.conn.executescript(""" DROP TABLE caches; DROP TABLE stamps; """) else: log.error("Dropping cache: version changed") update = True if update: self.conn.executescript(""" BEGIN EXCLUSIVE; DROP TABLE IF EXISTS caches; CREATE TABLE caches ( name TEXT, key TEXT, data, PRIMARY KEY (name, key) ); DROP TABLE IF EXISTS stamps; CREATE TABLE stamps ( file TEXT PRIMARY KEY, time INT ); PRAGMA user_version={}; COMMIT; """.format(ymp.__numeric_version__)) self.caches = {} self.files = {}
[docs] def close(self): self.conn.close()
[docs] def get_cache(self, name, clean=False, *args, **kwargs): if name not in self.caches: self.caches[name] = CacheDict(self, name, *args, **kwargs) return self.caches[name]
[docs] def store(self, cache, key, obj): import pickle files = ensure_list(getattr(obj, "defined_in", None)) try: stamps = [(fn, os.path.getmtime(fn)) for fn in files if fn not in self.files] self.conn.executemany( "REPLACE INTO stamps VALUES (?,?)", stamps) self.files.update(dict(stamps)) self.conn.execute(""" REPLACE INTO caches VALUES (?, ?, ?) """, [cache, key, pickle.dumps(obj)] ) except FileNotFoundError: pass
[docs] def commit(self): import sqlite3 try: self.conn.commit() except sqlite3.OperationalError as e: log.warning("Cache write failed: %s", e.what())
[docs] def load(self, cache, key): import pickle row = self.conn.execute(""" SELECT data FROM caches WHERE name=? AND key=? """, [cache, key]).fetchone() if row: return pickle.loads(row[0]) else: return None
[docs] def load_all(self, cache): import pickle rows = self.conn.execute(""" SELECT key, data FROM caches WHERE name=? """, [cache]) return ((row[0], pickle.loads(row[1])) for row in rows)
[docs]class CacheDict(AttrDict): def __init__(self, cache, name, *args, loadfunc=None, itemloadfunc=None, itemdata=None, **kwargs): self._cache = cache self._name = name self._loadfunc = loadfunc self._itemloadfunc = itemloadfunc self._itemdata = itemdata self._args = args self._kwargs = kwargs self._loading = False self._complete = False def _loaditem(self, key): cached = self._cache.load(self._name, key) if cached: super().__setitem__(key, cached) elif self._itemdata is not None: if key in self._itemdata: item = self._itemloadfunc(key, self._itemdata[key]) self._cache.store(self._name, key, item) self._cache.commit() super().__setitem__(key, item) elif self._itemloadfunc: item = self._itemloadfunc(key) self._cache.store(self._name, key, item) self._cache.commit() super().__setitem__(key, item) else: self._loadall() def _loadall(self): if self._complete: return loaded = set() for key, obj in self._cache.load_all(self._name): loaded.add(key) super().__setitem__(key, obj) if self._itemloadfunc: for key in self._itemdata: if key not in loaded: self._loaditem(key) elif self._loadfunc and not self._loading and not loaded: self._loadfunc(*self._args, **self._kwargs) self._loadfunc = None for key, item in super().items(): self._cache.store(self._name, key, item) self._cache.commit() self._complete = True def __enter__(self): self._loading = True return self def __exit__(self, a, b, c): self._loading = False def __contains__(self, key): if self._itemdata: return key in self._itemdata self._loadall() return super().__contains__(key) def __len__(self): if self._itemdata: return len(self._itemdata) self._loadall() return super().__len__() def __getitem__(self, key): if not super().__contains__(key): self._loaditem(key) return super().__getitem__(key) def __setitem__(self, key, val): super().__setitem__(key, val) def __delitem__(self, key): raise NotImplementedError() def __iter__(self): if self._itemdata: return self._itemdata.__iter__() self._loadall() return super().__iter__() def __str__(self): self._loadall() return super().__str__()
[docs] def get(self, key, default=None): if not super().__contains__(key): self._loaditem(key) return super().get(key, default)
[docs] def items(self): self._loadall() return super().items()
[docs] def keys(self): if self._itemdata: return self._itemdata.keys() return super().keys()
[docs] def values(self): self._loadall() return super().values()