# UniPortrait / src/util.py
# Junjie96's picture
# Update src/util.py
# 4d77ce5 verified
# raw
# history blame contribute delete
# No virus
# 2.18 kB
import concurrent.futures
import io
import os
import time
import oss2
import requests
from PIL import Image
from .log import logger
# OSS (object storage) client configuration, read from environment variables
# once at import time. os.getenv returns None for any variable that is unset.
access_key_id = os.getenv("ACCESS_KEY_ID")
access_key_secret = os.getenv("ACCESS_KEY_SECRET")
bucket_name = os.getenv("BUCKET_NAME")
endpoint = os.getenv("ENDPOINT")
# Shared bucket handle; upload_np_2_oss uses it to upload objects and sign URLs.
bucket = oss2.Bucket(oss2.Auth(access_key_id, access_key_secret), endpoint, bucket_name)
# Key prefix prepended (with "/") to every uploaded object name.
oss_path = os.getenv("OSS_PATH")
def resize(image, short_side_length=512):
    """Resize *image* so that its shorter side equals ``short_side_length``.

    The aspect ratio is kept: the longer side is scaled by the same factor and
    rounded down to a whole pixel.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height))`` method, e.g. a PIL image.
        short_side_length: Target length in pixels for the shorter side.

    Returns:
        The result of ``image.resize`` for the computed size.

    Raises:
        ZeroDivisionError: If the image has a zero-length side.
    """
    width, height = image.size
    short_side = min(width, height)
    # Floor-divide the exact product instead of multiplying by a float ratio:
    # 49 * (512 / 49) evaluates to 511.99999999999994, which int() would
    # truncate to 511 and leave the short side one pixel too small.
    new_width = int(width * short_side_length // short_side)
    new_height = int(height * short_side_length // short_side)
    return image.resize((new_width, new_height))
def download_img_pil(index, img_url, timeout=30):
    """Download one image and return it paired with its position index.

    Args:
        index: Caller-supplied position; returned unchanged so results that
            complete out of order can be put back in place.
        img_url: URL of the image to fetch.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot block the calling thread indefinitely.

    Returns:
        ``(index, image)`` when the server answers with HTTP 200, otherwise
        ``None`` after logging an error.
    """
    # The context manager closes the streamed response on every path, so the
    # connection is released even when the body of an error reply is not read.
    with requests.get(img_url, stream=True, timeout=timeout) as r:
        if r.status_code != 200:
            logger.error(f"Fail to download: {img_url}")
            return None
        # r.content reads the whole body into memory before the response closes.
        img = Image.open(io.BytesIO(r.content))
    return (index, img)
def download_images(img_urls, batch_size):
    """Download several images concurrently, keeping the order of the URLs.

    Args:
        img_urls: Iterable of image URLs; the position of each URL is the slot
            its image is stored in.
        batch_size: Length of the returned list.

    Returns:
        A list of ``batch_size`` entries. Slot ``i`` holds the image downloaded
        from the ``i``-th URL, or ``None`` when there was no such URL or its
        download failed.
    """
    imgs_pil = [None] * batch_size
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        to_do = [
            executor.submit(download_img_pil, i, url)
            for i, url in enumerate(img_urls)
        ]
        for future in concurrent.futures.as_completed(to_do):
            ret = future.result()
            if ret is None:
                # download_img_pil returns None (after logging) on a failed
                # download; keep the slot as None instead of failing to unpack.
                continue
            index, img_pil = ret
            imgs_pil[index] = img_pil
    return imgs_pil
def upload_np_2_oss(input_image, name="cache.jpg"):
    """Encode an array as an image, upload it to the bucket and return a URL.

    Args:
        input_image: Array handed to ``Image.fromarray``.
        name: Object name; must end in ``.png`` or ``.jpg`` (any letter case).

    Returns:
        A signed GET URL for the uploaded object, valid for 24 hours.
    """
    assert name.lower().endswith((".png", ".jpg")), name

    # A lowercase ".png" suffix is rewritten to ".jpg", so such names are
    # uploaded as JPEG.
    # NOTE(review): this check is case-sensitive while the format check below
    # is not, so only names like "x.PNG" are still encoded as PNG -- confirm
    # whether that is intended.
    if name.endswith(".png"):
        name = f"{name[:-4]}.jpg"

    if name.lower().endswith(".png"):
        save_options = {"format": "PNG"}
    else:
        save_options = {"format": "JPEG", "quality": 95}

    buffer = io.BytesIO()
    Image.fromarray(input_image).save(buffer, **save_options)
    payload = buffer.getvalue()

    started_at = time.perf_counter()
    object_key = oss_path + "/" + name
    bucket.put_object(object_key, payload)
    # 60 * 60 * 24 seconds: the signed link expires after one day.
    signed_url = bucket.sign_url('GET', object_key, 60 * 60 * 24)
    logger.info(f"upload cost: {time.perf_counter() - started_at} s.")
    return signed_url