|
""" |
|
Utilities for working with the local dataset cache. |
|
This file is adapted from the huggingface transformers library at |
|
https://github.com/huggingface/transformers, which in turn is adapted from the AllenNLP |
|
library at https://github.com/allenai/allennlp |
|
Copyright by the AllenNLP authors. |
|
Note - this file goes to effort to support Python 2, but the rest of this repository does not. |
|
""" |
|
from __future__ import (absolute_import, division, print_function, unicode_literals) |
|
|
|
import typing |
|
import sys |
|
import json |
|
import logging |
|
import os |
|
import tempfile |
|
import fnmatch |
|
from io import open |
|
|
|
import boto3 |
|
import requests |
|
from botocore.exceptions import ClientError |
|
from tqdm import tqdm |
|
|
|
from contextlib import contextmanager |
|
from functools import partial, wraps |
|
from hashlib import sha256 |
|
|
|
from filelock import FileLock |
|
|
|
|
|
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)


# Resolve the torch cache root. torch's own helper is used when torch can be
# imported; otherwise fall back to $TORCH_HOME, then $XDG_CACHE_HOME/torch,
# then ~/.cache/torch (with "~" expanded).
try:
    from torch.hub import _get_torch_home
    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.environ.get(
            'TORCH_HOME',
            os.path.join(os.environ.get('XDG_CACHE_HOME', '~/.cache'), 'torch'),
        )
    )

# Default directory used by this module for cached files.
default_cache_path = os.path.join(torch_cache_home, 'protein_models')
|
|
|
try: |
|
from urllib.parse import urlparse |
|
except ImportError: |
|
from urlparse import urlparse |
|
|
|
# Cache directory resolution order: $PROTEIN_MODELS_CACHE, then
# $PYTORCH_PRETRAINED_BERT_CACHE, then `default_cache_path`. The value is a
# `Path` when pathlib is usable and a plain string otherwise.
try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BERT_CACHE: typing.Union[str, Path] = Path(
        os.environ.get(
            'PROTEIN_MODELS_CACHE',
            os.environ.get('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path),
        )
    )
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BERT_CACHE = os.environ.get(
        'PROTEIN_MODELS_CACHE',
        os.environ.get('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path),
    )

# Name used by the rest of this module; the older name above is kept as an alias.
PROTEIN_MODELS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE

# Re-binds the same module logger (logging.getLogger returns one object per name).
logger = logging.getLogger(__name__)
|
|
|
|
|
def get_cache():
    """Return the current module-level ``PROTEIN_MODELS_CACHE`` value."""
    cache_location = PROTEIN_MODELS_CACHE
    return cache_location
|
|
|
|
|
def get_etag(url):
    """
    Look up the ETag for `url`.

    ``s3://`` urls are delegated to `s3_etag`. Any other url is probed with an
    HTTP HEAD request (following redirects); the ``ETag`` response header is
    used only when the status code is 200.

    Returns:
        The ETag, or ``None`` when the HEAD request raises an
        ``EnvironmentError``, answers with a non-200 status, or carries no
        ``ETag`` header.
    """
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        etag = None
        try:
            head_response = requests.head(url, allow_redirects=True)
            if head_response.status_code == 200:
                etag = head_response.headers.get("ETag")
        except EnvironmentError:
            etag = None

    # On Python 2 the header value is a byte string; normalise it to text.
    running_py2 = sys.version_info[0] == 2
    if running_py2 and etag is not None:
        etag = etag.decode('utf-8')

    return etag
|
|
|
|
|
def url_to_filename(url, etag=None):
    """
    Build a repeatable, hashed cache filename for `url`.

    The name is the SHA-256 hex digest of the UTF-8 encoded url. When a
    non-empty `etag` is supplied, a period followed by the SHA-256 hex digest
    of the UTF-8 encoded etag is appended.
    """
    digests = [sha256(url.encode('utf-8')).hexdigest()]

    # A missing or empty etag contributes nothing to the name.
    if etag:
        digests.append(sha256(etag.encode('utf-8')).hexdigest())

    return '.'.join(digests)
|
|
|
|
|
def filename_to_url(filename, cache_dir=None):
    """
    Look up the url and etag recorded for a cached `filename`.

    The metadata is read from the JSON file stored next to the cached file
    (same path with ``.json`` appended).

    Args:
        filename: name of the cached file inside `cache_dir`.
        cache_dir: cache directory; defaults to ``PROTEIN_MODELS_CACHE``.

    Returns:
        ``(url, etag)``; the etag may be ``None``.

    Raises:
        EnvironmentError: if the cached file or its metadata file is missing.
    """
    if cache_dir is None:
        cache_dir = PROTEIN_MODELS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    meta_path = cache_path + '.json'

    # Both the payload and its metadata must be present, checked in that order.
    for required_path in (cache_path, meta_path):
        if not os.path.exists(required_path):
            raise EnvironmentError("file {} not found".format(required_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)

    return metadata['url'], metadata['etag']
|
|
|
|
|
def cached_path(url_or_filename, force_download=False, cache_dir=None):
    """
    Resolve `url_or_filename` to a path on the local filesystem.

    An ``http``, ``https`` or ``s3`` url is fetched through `get_from_cache`
    and the path of the cached copy is returned. Anything else is treated as
    a local path and returned unchanged, provided it exists.

    Args:
        url_or_filename: a url or a local path (``str`` or ``Path``).
        force_download: if True, re-download the file even if it is
            already cached in the cache dir.
        cache_dir: specify a cache directory to save the file to
            (overrides the default cache dir).

    Raises:
        EnvironmentError: the argument has no url scheme and no such local
            file exists.
        ValueError: the argument has an unsupported scheme and is not an
            existing local path.
    """
    if cache_dir is None:
        cache_dir = PROTEIN_MODELS_CACHE
    if sys.version_info[0] == 3:
        # Normalise pathlib objects to plain strings.
        if isinstance(url_or_filename, Path):
            url_or_filename = str(url_or_filename)
        if isinstance(cache_dir, Path):
            cache_dir = str(cache_dir)

    scheme = urlparse(url_or_filename).scheme

    if scheme in ('http', 'https', 's3'):
        # Remote resource: serve it from the cache (downloading if necessary).
        return get_from_cache(url_or_filename, cache_dir, force_download)

    if os.path.exists(url_or_filename):
        # Existing local file: hand it back untouched.
        return url_or_filename

    if scheme == '':
        # Looks like a plain path, but nothing is there.
        raise EnvironmentError("file {} not found".format(url_or_filename))

    # Has a scheme we do not know how to fetch.
    raise ValueError("unable to parse {} as a URL or as a local path".format(
        url_or_filename))
|
|
|
|
|
def split_s3_path(url):
    """
    Split a full s3 url into ``(bucket_name, s3_path)``.

    The bucket is the url's network location and the path is the url's path
    component with one leading ``/`` removed.

    Raises:
        ValueError: if the url has no bucket or no path component.
    """
    parts = urlparse(url)
    bucket_name, s3_path = parts.netloc, parts.path

    if not (bucket_name and s3_path):
        raise ValueError("bad s3 path {}".format(url))

    # Strip only the single separator slash; any further slashes belong to the key.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]

    return bucket_name, s3_path
|
|
|
|
|
def s3_request(func):
    """
    Decorator for s3 request functions that produces more helpful error
    messages.

    The wrapped function must take the s3 url as its first positional
    argument. A ``ClientError`` whose error code indicates a missing object is
    re-raised as ``EnvironmentError("file <url> not found")``; every other
    ``ClientError`` propagates unchanged.
    """
    # Error codes are strings and are not always numeric ("404" vs.
    # "NoSuchKey", "AccessDenied", ...), so they are compared as text instead
    # of being passed through int(), which would raise ValueError and hide the
    # original ClientError.
    not_found_codes = ("404", "NoSuchKey")

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            error_code = str(exc.response.get("Error", {}).get("Code", ""))
            if error_code in not_found_codes:
                raise EnvironmentError("file {} not found".format(url))
            raise

    return wrapper
|
|
|
|
|
@s3_request
def s3_etag(url):
    """Return the ``e_tag`` of the S3 object that `url` points to."""
    s3 = boto3.resource("s3")
    bucket, key = split_s3_path(url)
    return s3.Object(bucket, key).e_tag
|
|
|
|
|
@s3_request
def s3_get(url, temp_file):
    """Download the S3 object at `url` into the open file object `temp_file`."""
    s3 = boto3.resource("s3")
    bucket, key = split_s3_path(url)
    bucket_handle = s3.Bucket(bucket)
    bucket_handle.download_fileobj(key, temp_file)
|
|
|
|
|
def http_get(url, temp_file):
    """
    Stream the resource at `url` into `temp_file`, showing a progress bar.

    Args:
        url: http(s) location to download.
        temp_file: binary file object, opened for writing, that receives the
            response body.

    Raises:
        requests.HTTPError: if the server answers with an error status. This
            keeps an error page from being written out (and later cached) as
            if it were the requested file.
    """
    req = requests.get(url, stream=True)
    try:
        req.raise_for_status()

        # The progress bar total is only known when the server reports a length.
        content_length = req.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None

        progress = tqdm(unit="B", total=total)
        try:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # skip empty keep-alive chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
        finally:
            # Close the bar even when the download fails midway.
            progress.close()
    finally:
        # A streamed response holds its connection until it is closed.
        req.close()
|
|
|
|
|
def get_from_cache(url, cache_dir=None, force_download=False, resume_download=False):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    The cache filename is derived from the url and its ETag (see
    `url_to_filename`), and a ``<cache_path>.json`` file recording the url and
    etag is written next to every downloaded file.

    Args:
        url: ``http(s)://`` or ``s3://`` location of the resource.
        cache_dir: directory holding the cached files; defaults to
            ``PROTEIN_MODELS_CACHE`` and is created if it does not exist.
        force_download: if True, download again even if a cached copy exists.
            Ignored when no ETag can be obtained and the un-tagged cache entry
            for `url` already exists.
        resume_download: if True, download into ``<cache_path>.incomplete``
            instead of an anonymous temporary file. The downloaders used here
            always start at the first byte, so any stale partial content in
            that file is discarded rather than appended to.

    Returns:
        Path of the cached file.
    """
    if cache_dir is None:
        cache_dir = PROTEIN_MODELS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
        cache_dir = str(cache_dir)

    # Create the directory without a check-then-create race between processes.
    try:
        os.makedirs(cache_dir)
    except OSError:
        if not os.path.isdir(cache_dir):
            raise

    # The ETag (if any) becomes part of the cache filename.
    etag = get_etag(url)
    filename = url_to_filename(url, etag)
    cache_path = os.path.join(cache_dir, filename)

    if etag is None:
        # No ETag (e.g. offline): reuse whatever is already cached.
        if os.path.exists(cache_path):
            return cache_path

        # Otherwise fall back to an entry stored under some earlier ETag.
        # Bookkeeping files share the same prefix and must never be returned
        # as the cached artifact.
        bookkeeping_suffixes = ('.json', '.lock', '.incomplete')
        matching_files = [
            name
            for name in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
            if not name.endswith(bookkeeping_suffixes)
        ]
        if matching_files:
            cache_path = os.path.join(cache_dir, matching_files[-1])

    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Only one process at a time may download a given cache entry.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # Another process may have finished the download while we waited.
        if os.path.exists(cache_path) and not force_download:
            return cache_path

        # s3:// urls cannot be fetched over plain HTTP.
        download = s3_get if url.startswith("s3://") else http_get

        if resume_download:
            # "wb" truncates: appending a from-the-start download to old
            # partial content would corrupt the file.
            temp_file = open(cache_path + ".incomplete", "wb")
        else:
            temp_file = tempfile.NamedTemporaryFile(dir=cache_dir, delete=False)
        temp_path = temp_file.name

        try:
            with temp_file:
                logger.info("%s not in cache or force_download=True, download to %s",
                            url, temp_path)
                download(url, temp_file)

            # Move into place only after the file is closed (fully flushed),
            # so readers never observe a partially written cache entry.
            logger.info("storing %s in cache at %s", url, cache_path)
            os.replace(temp_path, cache_path)
        except Exception:
            # Best-effort cleanup so failed downloads do not pile up in the
            # cache directory; the original error is what matters.
            try:
                os.remove(temp_path)
            except OSError:
                pass
            raise

        logger.info("creating metadata file for %s", cache_path)
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
|
|