Commit 22e9ef6
remove parfive dependency, introduce request based resumable download (demo single file)
weixingjian committed Mar 8, 2023
1 parent 2eb620a commit 22e9ef6
Showing 3 changed files with 90 additions and 20 deletions.
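The core of the change is resuming interrupted downloads with an HTTP Range request instead of parfive's parallel fetcher. Below is a minimal standalone sketch of that mechanism; the URL and target path are hypothetical, and the 206 status check is an extra safeguard that the commit's own code does not perform:

import os
import requests

def resume_download(url: str, target: str, chunk_size: int = 1024) -> None:
    # total size advertised by the server; stream=True avoids fetching the body here
    total = int(requests.get(url, stream=True).headers["content-length"])

    # resume from the size of any partially written local file
    start = os.path.getsize(target) if os.path.exists(target) else 0
    if start >= total:
        return  # already complete

    # ask for the remaining bytes only; the header value must contain no spaces
    resp = requests.get(url, headers={"Range": f"bytes={start}-"}, stream=True)
    # 206 means the server honored the range; a plain 200 means it resent the whole file
    mode = "ab" if resp.status_code == 206 else "wb"
    with open(target, mode) as f:
        for chunk in resp.iter_content(chunk_size=chunk_size):
            f.write(chunk)

# hypothetical usage:
# resume_download("https://example.com/data/train.tar.gz", "train.tar.gz")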
103 changes: 86 additions & 17 deletions opendatalab/cli/get.py
@@ -13,6 +13,7 @@

import click
import parfive
import requests
from tqdm import tqdm

from opendatalab.cli.policy import private_policy_url, service_agreement_url
@@ -30,7 +31,52 @@ def handler(dwCtrlType):
import win32api
win32api.SetConsoleCtrlHandler(handler, True)

@exception_handler
def download_from_url(url: str, pth: str, file_name: str):
    """Perform a resumable download for a single object.

    Args:
        url (str): single download url
        pth (str): local download path
        file_name (str): file name (may contain a relative path)
    """
    response = requests.get(url, stream=True)

    # get the total file size from the server
    file_size = int(response.headers['content-length'])

    target = os.path.join(pth, file_name)
    if os.path.exists(target):
        # an existing file means an earlier download was interrupted; resume from its size
        first_byte = os.path.getsize(target)
    else:
        # a new file: start from byte zero
        first_byte = 0

    # compare the local size against the server size
    if first_byte >= file_size:
        click.secho('Download Complete')
        return file_size  # already complete; nothing left to fetch

    # request only the remaining bytes (the Range value must not contain spaces)
    header = {"Range": f"bytes={first_byte}-"}

    pbar = tqdm(total=file_size,
                initial=first_byte,
                unit='B',
                unit_scale=True,
                desc='Downloading Progress:')

    req = requests.get(url, headers=header, stream=True)

    with open(target, 'ab') as f:
        for chunk in req.iter_content(chunk_size=1024):
            if chunk:
                f.write(chunk)
                pbar.update(len(chunk))  # count actual bytes, not the nominal chunk size
    pbar.close()
    return file_size


@exception_handler
def implement_get(obj: ContextInfo, name: str, conn = 5):
"""
@@ -39,22 +85,28 @@ def implement_get(obj: ContextInfo, name: str, conn = 5):
        obj (ContextInfo):
        name (str):
        thread (int):
        limit_speed (int):
        compressed (bool):
    Returns:
    """
    ds_split = name.split("/")
    if len(ds_split) > 1:
        dataset_name = ds_split[0]
        sub_dir = "/".join(ds_split[1:])
    else:
        dataset_name = name
        sub_dir = ""

    # print(name, ds_split, dataset_name)

    client = obj.get_client()
    data_info = client.get_api().get_info(dataset_name)
    info_dataset_name = data_info['name']
    info_dataset_id = data_info['id']

    dataset_res_dict = client.get_api().get_dataset_files(dataset_name=info_dataset_name)
    dataset_res_dict = client.get_api().get_dataset_files(dataset_name=info_dataset_name,
                                                          prefix=sub_dir)
    # print(dataset_res_dict)

    # obj list construct
    obj_info_list = []
    for info in dataset_res_dict['list']:
@@ -63,15 +115,25 @@ def implement_get(obj: ContextInfo, name: str, conn = 5):
        curr_dict['size'] = info['size']
        curr_dict['name'] = info['path']
        obj_info_list.append(curr_dict)

    # if not sub_dir:
    print(obj_info_list, sub_dir)
    download_urls_list = client.get_api().get_dataset_download_urls(
        dataset_id=info_dataset_id,
        dataset_list=obj_info_list)
    # print(obj_info_list)
    print('____________________________________________________-')

    url_list = []
    item_list = []
    for item in download_urls_list:
        url_list.append(item['url'])
        item_list.append(item['name'])

    print(url_list[0], item_list[0])

    local_dir = Path.cwd().joinpath(info_dataset_name)

@@ -95,22 +157,29 @@ def implement_get(obj: ContextInfo, name: str, conn = 5):
        click.secho('bye~')
        sys.exit(1)

    downloader = parfive.Downloader(max_conn=conn,
                                    max_splits=5,
                                    progress=True)

    ########################################################################
    size = download_from_url(url_list[0], pth=local_dir, file_name=item_list[0])
    ########################################################################
    print(size)

    for idx, url in enumerate(url_list):
        downloader.enqueue_file(url, path=local_dir, filename=item_list[idx])

    # downloader = parfive.Downloader(max_conn=conn,
    #                                 max_splits=5,
    #                                 progress=True)

    # for idx, url in enumerate(url_list):
    #     downloader.enqueue_file(url, path=local_dir, filename=item_list[idx])

    results = downloader.download()
    # results = downloader.download()

    for i in results:
        click.echo(i)
    # for i in results:
    #     click.echo(i)

    err_str = ''
    for err in results.errors:
        err_str += f"{err.url} \t {err.exception}\n"
    if not err_str:
        print(f"{info_dataset_name}, download completed!")
    else:
        sys.exit(err_str)
    # err_str = ''
    # for err in results.errors:
    #     err_str += f"{err.url} \t {err.exception}\n"
    # if not err_str:
    #     print(f"{info_dataset_name}, download completed!")
    # else:
    #     sys.exit(err_str)
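For reference, a sketch of calling the new helper on its own; the URL and file name here are hypothetical placeholders (a real signed URL would come from get_dataset_download_urls):

from pathlib import Path

url = "https://example.com/signed/MyDataset/file.tar.gz"  # hypothetical signed URL
local_dir = Path.cwd().joinpath("MyDataset")
local_dir.mkdir(exist_ok=True)

# writes ./MyDataset/file.tar.gz, resuming if a partial file already exists
size = download_from_url(url, pth=str(local_dir), file_name="file.tar.gz")
print(f"fetched {size} bytes")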
2 changes: 1 addition & 1 deletion opendatalab/cli/ls.py
@@ -36,7 +36,7 @@ def implement_ls(obj: ContextInfo, dataset: str):
    info_dataset_name = client.get_api().get_info(dataset_name)['name']
    dataset_instance = client.get_dataset(dataset_name=info_dataset_name)

    dataset_res_dict = client.get_api().get_dataset_files(dataset_name=info_dataset_name)
    dataset_res_dict = client.get_api().get_dataset_files(dataset_name=info_dataset_name, prefix=sub_dir)

    # generate output info dict
    object_info_dict = {}
5 changes: 3 additions & 2 deletions opendatalab/client/api.py
@@ -20,7 +20,7 @@ def __init__(self, host, token, odl_cookie):
        self.token = token
        self.odl_cookie = odl_cookie

    def get_dataset_files(self, dataset_name: str):
    def get_dataset_files(self, dataset_name: str, prefix: str):
        """HTTPS request to retrieve dataset files
        Args:
            dataset (str): dataset name
@@ -36,7 +36,8 @@ def get_dataset_files(self, dataset_name:str, prefix:str):
"User-Agent": UUID,
"accept" : "application/json"
}
data = {"recursive": True}
data = {"recursive": True,
"prefix":prefix}
resp = requests.get(
url = f"{self.host}/api/datasets/{dataset_name}/files",
params = data,
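The new prefix argument simply rides along as a query parameter, so the server can filter the listing to one sub-directory. The resulting request is equivalent to something like the following sketch (host, dataset name, and prefix values are hypothetical, and the auth headers are omitted):

import requests

resp = requests.get(
    url="https://opendatalab.com/api/datasets/MyDataset/files",
    params={"recursive": True, "prefix": "images/train"},
)
print(resp.json())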
