import asyncio
import atexit
import hashlib
import logging
import os
import re
import threading
from typing import List, Optional, Union
from urllib.parse import urlsplit
import aiohttp
from tqdm import tqdm
from ymp.common import ensure_list
LOG = logging.getLogger(__name__)
[docs]class FileDownloader(object):
"""Manages download of a set of URLs
Downloads happen concurrently using asynchronous network IO.
Args:
block_size: Byte size of chunks to download
timeout: Aiohttp cumulative timeout
parallel: Number of files to download in parallel
loglevel: Log level for messages sent to logging
(Errors are sent with loglevel+10)
alturls: List of regexps modifying URLs
retry: Number of times to retry download
"""
def __init__(self, block_size: int=4096, timeout: int=300, parallel: int=4,
             loglevel: int=logging.WARNING, alturls=None, retry: int=3):
    """Store download settings, parse URL rewrite rules, bind an event loop

    Args:
      block_size: Byte size of chunks to download
      timeout: Timeout value passed to each aiohttp request
      parallel: Number of files to download in parallel
      loglevel: Log level for messages sent to logging. Progress bars
                are enabled only if the module logger's effective
                level is at or below this level.
      alturls: List of rewrite rules of the form
               ``<sep>pattern<sep>replacement<sep>``. The first
               character of each rule is its separator; a separator
               inside the pattern or replacement must be escaped
               with a backslash.
      retry: Number of attempts made per candidate URL

    Raises:
      ValueError: If a rule is empty or does not consist of exactly
                  one pattern and one replacement.
    """
    self._block_size = block_size
    self._timeout = timeout
    self._parallel = parallel
    self._retry = retry

    # The leading "///" rule parses to an empty pattern and replacement,
    # i.e. the identity rewrite, so the unmodified URL is tried first.
    self._alturls = []
    for pat in ["///"] + (alturls or []):
        if not pat:
            raise ValueError("Malformed regular expression '{}'"
                             "".format(pat))
        sep = pat[0]
        inner = pat.strip(sep)
        if inner:
            # Split on separators not preceded by a backslash. The
            # separator is escaped so that characters with a special
            # meaning in regular expressions (e.g. "|") work as well.
            patsub = re.split(r"(?<=[^\\])" + re.escape(sep), inner)
            if len(patsub) != 2:
                raise ValueError("Malformed regular expression '{}'"
                                 "".format(pat))
            # An escaped separator is fine inside the pattern, but the
            # replacement template must contain the plain character.
            patsub[1] = patsub[1].replace("\\" + sep, sep)
        else:
            patsub = ["", ""]
        self._alturls.append(patsub)

    try:
        self.loop = asyncio.get_event_loop()
    except RuntimeError:
        # no loop in context (i.e. running in thread)
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    self._sem = asyncio.Semaphore(parallel)
    self._progress = LOG.getEffectiveLevel() <= loglevel
    self._loglevel = loglevel
    self._sum_bar = None
def log(self, msg: str, *args, modlvl: int=0, **kwargs) -> None:
    """Emit a message through the module logger

    The message is logged at this object's configured log level,
    shifted by ``modlvl``.

    Args:
      msg: The log message (may contain %-style placeholders)
      modlvl: Offset added to the object's default logging level
      *args: Forwarded to the logger (values for the placeholders)
      **kwargs: Forwarded to the logger unchanged
    """
    level = self._loglevel + modlvl
    LOG.log(level, msg, *args, **kwargs)
def error(self, msg: str, *args, **kwargs) -> None:
    """Emit an error message through :meth:`log`

    The message is logged 10 levels above the default level of this
    object.

    Args:
      msg: The log message (may contain %-style placeholders)
      *args: Forwarded to :meth:`log`
      **kwargs: Forwarded to :meth:`log`
    """
    error_offset = 10
    self.log(msg, *args, modlvl=error_offset, **kwargs)
async def _download(self, session: aiohttp.ClientSession,
                    url: str, dest: str, md5: Optional[str]=None) -> bool:
    """Asynchronously download a single file

    - If ``dest`` points to an existing directory, the file name
      is derived from the trailing path portion of the URL.
    - Will skip download for existing files with matching MD5
    - Each candidate URL (the original URL followed by the results of
      the configured rewrite rules) is attempted in turn. A candidate
      is retried only after a timeout; any other failure moves on to
      the next candidate.

    Args:
      session: aiohttp session object
      url: source url
      dest: destination path
      md5: optional md5 checksum to verify. A bool is passed through to
           the download step but never used to check an existing file.

    Returns:
      True if the file is in place, False if every candidate failed.
    """
    if os.path.isdir(dest):
        basename = os.path.basename(urlsplit(url).path)
        destfile = os.path.join(dest, basename)
    else:
        basename = os.path.basename(dest)
        destfile = dest

    if os.path.exists(destfile) and md5 and not isinstance(md5, bool):
        if self._check_md5(basename, destfile, md5):
            return True

    # Rules that do not match leave the URL unchanged; drop those
    # duplicates so the same address is not requested again.
    candidates = []
    for pat, rep in self._alturls:
        candidate = re.sub(pat, rep, url)
        if candidate not in candidates:
            candidates.append(candidate)

    for candidate in candidates:  # try alturls
        exc = None
        for num_try in range(self._retry):  # retry after timeout
            if exc:
                # repr() because timeout exceptions have an empty str()
                self.log("Downloading %s failed with %s. Retrying %i/%i",
                         basename, repr(exc), num_try, self._retry-1)
            try:
                if await self._download_one(session, basename, candidate,
                                            destfile, md5):
                    return True
                break  # failed without timeout: next candidate
            except (asyncio.TimeoutError, TimeoutError) as e:
                # asyncio.TimeoutError is only an alias of the builtin
                # TimeoutError from Python 3.11 on; catch both.
                exc = e
    return False
def _check_md5(self, name, fname, md5):
    """Compare the MD5 digest of a file with an expected checksum

    Args:
      name: Display name used in the log message
      fname: Path of the file to hash
      md5: Expected hex digest; surrounding whitespace is ignored

    Returns:
      True (after logging that the download is skipped) if the digest
      matches, False otherwise.
    """
    digest = hashlib.md5()
    with open(fname, 'rb') as stream:
        # read in 8 KiB chunks until read() returns an empty bytes object
        for chunk in iter(lambda: stream.read(8192), b""):
            digest.update(chunk)
    if digest.hexdigest() != md5.strip():
        return False
    self.log("Download skipped: %s (file exists, md5 verified)", name)
    return True
async def _download_one(self, session, name, url, dest, md5):
    """Download one URL to ``dest`` via a temporary ``.part`` file

    The data is written to ``dest + ".part"`` and only moved to
    ``dest`` once the transfer has finished and, if a checksum was
    given, the checksum matched. The ``.part`` file is removed on any
    failure.

    Args:
      session: aiohttp session object
      name: Display name for log messages and the progress bar
      url: URL to fetch
      dest: Destination file path
      md5: Falsy to skip hashing, a bool true value to compute and log
           the digest only, or the expected hex digest to verify
           against.

    Returns:
      True if the file is in place (downloaded, or skipped because a
      file of identical size already exists and no md5 was requested).
      False on a non-200 response or checksum mismatch.

    Raises:
      Whatever the request or file handling raises; the exception is
      re-raised after the ``.part`` file has been removed.
    """
    part = dest + ".part"
    md5_new = hashlib.md5() if md5 else None
    try:
        async with self._sem, \
                session.get(url, timeout=self._timeout) as resp:
            if resp.status != 200:
                self.log("Download failed: %s (error code %i)",
                         name, resp.status)
                self.log(" URL: '%s'", url.strip())
                return False

            size = int(resp.headers.get('content-length', 0))
            if os.path.exists(dest):
                existing_size = os.path.getsize(dest)
                if existing_size != size:
                    self.log("Overwriting: %s (size mismatch %i!=%i)",
                             name, size, existing_size)
                elif md5:
                    self.log("Overwriting: %s (md5 failed)", name)
                else:
                    self.log("Download skipped: %s (file exists)",
                             name)
                    return True

            # The summary bar only exists while a download session has
            # installed one; it starts out as None.
            sum_bar = self._sum_bar
            if sum_bar is not None:
                try:
                    sum_bar.total += size
                except (AttributeError, TypeError):
                    # bar object without a usable ``total``
                    pass

            with open(part, mode="wb") as out, \
                    tqdm(total=size,
                         unit='B', unit_scale=True, unit_divisor=1024,
                         desc=name, leave=False,
                         miniters=1, disable=not self._progress,
                         bar_format=self.make_bar_format(40, 7, rate=True)
                         ) as t:
                while True:
                    block = await resp.content.read(self._block_size)
                    if not block:
                        break
                    out.write(block)
                    if md5_new is not None:
                        md5_new.update(block)
                    t.update(len(block))
                    if sum_bar is not None:
                        sum_bar.update(len(block))

        md5_hash = md5_new.hexdigest() if md5_new is not None else None
        if md5_hash is not None and not isinstance(md5, bool) \
                and md5.strip() != md5_hash:
            # Never move data that failed verification to ``dest``.
            self.error("Download failed: %s (md5 failed)", name)
            os.unlink(part)
            return False

        # os.replace() overwrites an existing destination on every
        # platform; os.rename() fails on Windows in that case.
        os.replace(part, dest)

        if md5_hash is not None:
            if isinstance(md5, bool):
                self.log("Download complete: %s (md5=%s)", name,
                         md5_hash.strip())
            else:
                self.log("Download complete: %s (md5 verified)", name)
        return True
    except BaseException:
        # Cleanup only (the exception is re-raised): do not leave a
        # partial file behind, whatever interrupted the transfer.
        if os.path.exists(part):
            os.unlink(part)
        raise
async def _run(self, urls: List[str], dest: str,
               md5s: Optional[List[str]]=None) -> List[bool]:
    """Executes a download session

    All URLs are downloaded concurrently within one aiohttp session.
    The two summary progress bars (files and bytes) are suppressed
    when only a single file is requested.

    Args:
      urls: List of URLs
      dest: Destination path
      md5s: Optional list of md5 checksums, one per URL. A single
            string is accepted for a single URL.

    Returns:
      One bool per URL, in order of completion (not in input order).

    Raises:
      ValueError: If ``md5s`` is given but its length differs from the
                  number of URLs.
    """
    if not urls:
        return []
    if isinstance(md5s, str):
        # Iterating a bare string would pair URLs with single characters
        md5s = [md5s]
    if not md5s:
        md5s = [None]*len(urls)
    elif len(md5s) != len(urls):
        # zip() below would silently drop the surplus entries
        raise ValueError(
            "Got {} md5 checksums for {} URLs".format(len(md5s), len(urls))
        )

    single = len(urls) == 1
    # No need to show summary progress bars for just 1 file
    show_bars = self._progress and not single

    if single:
        self.log("Downloading 1 file.")
    else:
        self.log("Downloading %i files.", len(urls))

    try:
        async with aiohttp.ClientSession() as session:
            coros = [
                asyncio.ensure_future(
                    self._download(session, url, dest, md5)
                )
                for url, md5 in zip(urls, md5s)
            ]
            with tqdm(
                    asyncio.as_completed(coros), total=len(coros),
                    unit="Files", desc="Total files:",
                    disable=not show_bars, leave=False,
                    bar_format=self.make_bar_format(20, 7, eta=True)
            ) as t, tqdm(
                total=1,  # must be >0
                unit="B", desc="Total bytes:",
                unit_scale=True, unit_divisor=1024,
                disable=not show_bars, leave=False, miniters=1,
                bar_format=self.make_bar_format(20, 7, rate=True)
            ) as t2:
                # Always install a bar object (disabled for a single
                # file) so per-file downloads can update it.
                self._sum_bar = t2
                result = [await coro for coro in t]
    finally:
        self._sum_bar = None

    if single:
        self.log("Finished download.")
    else:
        self.log("Finished downloads.")
    return result
def get(self, urls: Union[str, List[str]], dest: str,
        md5s: Optional[List[str]]=None) -> bool:
    """Download a list of URLs

    Blocks until all downloads have finished by running this object's
    event loop. On ``KeyboardInterrupt`` all tasks of that loop are
    cancelled before the interrupt is re-raised.

    Args:
      urls: List of URLs
      dest: Destination folder
      md5s: List of MD5 sums to check

    Returns:
      True if every download succeeded (or there was nothing to do),
      False otherwise.
    """
    urls = ensure_list(urls)
    if not urls:
        return True  # nothing to do

    if len(urls) > 1:
        os.makedirs(dest, exist_ok=True)

    # Create the task on the loop that will run it; the ambient event
    # loop need not be the one stored on this object.
    task = self.loop.create_task(self._run(urls, dest, md5s))
    try:
        self.loop.run_until_complete(task)
    except KeyboardInterrupt:
        # asyncio.Task.all_tasks() was removed in Python 3.9; prefer
        # asyncio.all_tasks() and fall back only where it is missing.
        all_tasks = getattr(asyncio, "all_tasks", None)
        if all_tasks is None:
            pending = asyncio.Task.all_tasks(self.loop)
        else:
            pending = all_tasks(self.loop)
        if pending:
            end = asyncio.gather(*pending)
            end.cancel()
            try:
                self.loop.run_until_complete(end)
            except asyncio.CancelledError:
                pass
        raise
    return all(task.result())
class DownloadThread(object):
    """Runs a :class:`FileDownloader` on an event loop in its own thread

    The thread is started on construction and serves download requests
    submitted with :meth:`get` until :meth:`terminate` is called, which
    is also registered to run at interpreter exit.
    """

    def __init__(self):
        LOG.debug("made downloader")
        self.loop = asyncio.new_event_loop()
        # Created by main() inside the worker thread; stays None if
        # construction failed there.
        self.downloader = None
        # Set once main() is done trying to create the downloader
        self._ready = threading.Event()
        # Daemon thread: non-daemon threads are joined before atexit
        # handlers run, so a thread blocked in run_forever() would keep
        # the interpreter from ever exiting.
        self.thread = threading.Thread(target=self.main, daemon=True)
        self.thread.start()
        atexit.register(self.terminate)

    def terminate(self):
        """Ask the event loop running in the worker thread to stop"""
        self.loop.call_soon_threadsafe(self.loop.stop)

    def main(self):
        """Thread body: create the downloader and run the event loop"""
        LOG.debug("here")
        asyncio.set_event_loop(self.loop)
        try:
            self.downloader = FileDownloader()
        finally:
            # Wake up waiting callers even if construction failed
            self._ready.set()
        self.loop.run_forever()

    def get(self, url, dest, md5):
        """Schedule a download on the worker thread

        Args:
          url: URL or list of URLs
          dest: Destination path
          md5: MD5 checksum (or list of checksums) to verify, or None

        Returns:
          A ``concurrent.futures.Future`` resolving to the list of
          per-file results, or None if there was nothing to download.

        Raises:
          RuntimeError: If the worker thread failed to set up its
                        downloader.
        """
        LOG.debug("scheduling get %s", url)
        self._ready.wait()
        if self.downloader is None:
            raise RuntimeError("Download thread failed to start")

        urls = [url] if isinstance(url, str) else list(url)
        if not urls:
            return None
        md5s = [md5] if isinstance(md5, (str, bool)) else md5
        if len(urls) > 1:
            os.makedirs(dest, exist_ok=True)

        # FileDownloader.get() blocks on run_until_complete(), which
        # cannot be used while the loop is already running in the
        # worker thread; submit the download coroutine to it instead.
        return asyncio.run_coroutine_threadsafe(
            self.downloader._run(urls, dest, md5s), self.loop
        )
#DOWNLOADER = DownloadThread()
#def download(url, dest, md5=None):
# LOG.error("called download %s", url)
# DOWNLOADER.get(url, dest, md5)