diff --git a/opendatalab/cli/cmd.py b/opendatalab/cli/cmd.py
index 9126af3..d7b92ea 100644
--- a/opendatalab/cli/cmd.py
+++ b/opendatalab/cli/cmd.py
@@ -136,24 +136,29 @@ def info(obj: ContextInfo, name):
 @command(synopsis=("$ odl get dataset_name # get dataset files into local",))
 @click.argument("name", nargs=1)
 @click.option(
-    "--conn",
-    "-c",
-    default=5,
-    help="The number of parallel download slots",
-    show_default=True,
+    "--dest",
+    "-d",
+    default='',
+    help="Local directory to store the dataset",
+    show_default=True,
+)
+@click.option(
+    "--workers",
+    "-w",
+    default=8,
+    help="Number of parallel download workers",
+    show_default=True,
 )
 @click.pass_obj
-def get(obj: ContextInfo, name, conn = 5):
+def get(obj: ContextInfo, name, dest, workers):
     """Get(Download) dataset files into local path.\f
     Args:
         obj (ContextInfo): context info\f
         name (str): dataset name\f
-        conn (int): The number of parallel download slots\f
+        dest (str): local directory to store the dataset\f
+        workers (int): number of parallel download workers\f
     """
-
+
     from opendatalab.cli.get import implement_get
-    implement_get(obj, name, conn)
-
-
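+    # Example: `odl get dataset_name -d ./data -w 16` downloads the dataset
+    # files into ./data/<dataset_name> using 16 download workers per file.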
+    implement_get(obj, name, dest, workers)
 
 if __name__ == "__main__":
     cli()
diff --git a/opendatalab/cli/get.py b/opendatalab/cli/get.py
index 31a525b..add0756 100644
--- a/opendatalab/cli/get.py
+++ b/opendatalab/cli/get.py
@@ -2,22 +2,17 @@
 #
 # Copyright 2022 Shanghai AI Lab. Licensed under MIT License.
 #
-import logging
 import os
 import sys
-import threading
-import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import List
 
 import click
-import parfive
-import requests
 from tqdm import tqdm
 
 from opendatalab.cli.policy import private_policy_url, service_agreement_url
 from opendatalab.cli.utility import ContextInfo, exception_handler
+from opendatalab.client import downloader
 from opendatalab.exception import OdlDataNotExistsError
 
 
@@ -30,55 +25,10 @@ def handler(dwCtrlType):
 if sys.platform == "win32":
     import win32api
     win32api.SetConsoleCtrlHandler(handler, True)
-
-@exception_handler
-def download_from_url(url:str, pth: str, file_name:str):
-    """This function perform a resumable download for a single object
-
-    Args:
-        url (str): single download url
-        pth(str): local download path
-        file_name (str): file name(may contain relative path)
-    """
-    response = requests.get(url, stream = True)
-
-    # get total file size
-    file_size = int(response.headers['content-length'])
-
-    target = os.path.join(pth, file_name)
-    # indicate a file-downloaing not complete
-    if os.path.exists(target):
-        first_byte = os.path.getsize(target)
-    else:
-        # indicate a new file
-        first_byte = 0
-
-    # check actual size and server size
-    if first_byte >= file_size:
-        click.secho('Download Complete')
-        sys.exit(1)
-
-    header = {"Range": f"bytes = {first_byte}-{file_size}"}
-
-    pbar = tqdm(total=file_size,
-                initial= first_byte,
-                unit = 'B',
-                unit_scale= True,
-                desc = 'Downloading Progress:')
-
-    req = requests.get(url, headers= header, stream=True)
-
-    with(open(target, 'ab')) as f:
-        for chunk in req.iter_content(chunk_size=1024):
-            if chunk:
-                f.write(chunk)
-                pbar.update(1024)
-    pbar.close()
-    return file_size
 
 @exception_handler
-def implement_get(obj: ContextInfo, name: str, conn = 5):
+def implement_get(obj: ContextInfo, name: str, destination: str, num_workers: int):
     """ implementation for getting dataset files
 
     Args:
@@ -96,7 +46,7 @@
         dataset_name = name
         sub_dir = ""
 
-    # print(name, ds_split ,dataset_name)
+    # print(name, ds_split, dataset_name, sub_dir)
     client = obj.get_client()
     data_info = client.get_api().get_info(dataset_name)
@@ -113,16 +63,16 @@ def implement_get(obj: ContextInfo, name: str, conn = 5):
         curr_dict = {}
         if not info['isDir']:
             curr_dict['size'] = info['size']
-            curr_dict['name'] = info['path']
+            # keep the sub-directory prefix so nested files resolve correctly
+            curr_dict['name'] = os.path.join(sub_dir, info['path'])
             obj_info_list.append(curr_dict)
     # if not sub_dir:
-    print(obj_info_list, sub_dir)
+    # print(obj_info_list, sub_dir)
 
     download_urls_list = client.get_api().get_dataset_download_urls(
         dataset_id=info_dataset_id, dataset_list=obj_info_list)
     # print(obj_info_list)
-    print('____________________________________________________-')
 
     url_list = []
 
@@ -131,11 +81,9 @@
         url_list.append(item['url'])
         item_list.append(item['name'])
 
-    print(url_list[0], item_list[0])
+    # print(url_list[0], item_list[0])
 
-
-
-    local_dir = Path.cwd().joinpath(info_dataset_name)
+    local_dir = os.path.join(destination, info_dataset_name)
 
     download_data = client.get_api().get_download_record(info_dataset_name)
     has_download = download_data['hasDownload']
@@ -146,7 +94,7 @@
             f"\n[Warning]: Before downloading, please agree above content."):
             client.get_api().submit_download_record(info_dataset_name, download_data)
         else:
-            click.secho('bye~')
+            click.secho('See you next time~!')
             sys.exit(1)
 
     if click.confirm(f"Download files into local directory: {local_dir} ?", default=True):
@@ -154,32 +102,28 @@
             Path(local_dir).mkdir(parents=True)
             print(f"create local dir: {local_dir}")
     else:
-        click.secho('bye~')
+        click.secho('See you next time~!')
         sys.exit(1)
-    
+    # print(url_list[0], item_list[0])
     ########################################################################
-    size = download_from_url(url_list[0], pth=local_dir, file_name = item_list[0])
-    ########################################################################
-    print(size)
-
-
-    # downloader = parfive.Downloader(max_conn = conn,
-    #                                 max_splits= 5,
-    #                                 progress= True)
-
-    # for idx, url in enumerate(url_list):
-    #     downloader.enqueue_file(url, path = local_dir, filename=item_list[idx])
-
-    # results = downloader.download()
-
-    # for i in results:
-    #     click.echo(i)
-
-    # err_str = ''
-    # for err in results.errors:
-    #     err_str += f"{err.url} \t {err.exception}\n"
-    # if not err_str:
-    #     print(f"{info_dataset_name}, download completed!")
-    # else:
-    #     sys.exit(err_str)
\ No newline at end of file
+
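+    # Each item name may include a relative path (e.g. "images/train.zip");
+    # the prefix is split off so files land in matching sub-directories, and
+    # files that already exist locally are skipped.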
+    with tqdm(total=len(url_list)) as pbar:
+        for idx in range(len(url_list)):
+            if len(item_list[idx].split('/')) == 1:
+                filename = item_list[idx]
+                prefix = ''
+            else:
+                filename = item_list[idx].split('/')[-1]
+                prefix = os.path.dirname(item_list[idx])
+
+            click.echo(f"Downloading file {idx + 1} of {len(url_list)}\n")
+            if os.path.exists(os.path.join(local_dir, prefix, filename)):
+                click.echo('Target already exists, skipping to the next file!')
+            else:
+                downloader.Downloader(url=url_list[idx],
+                                      filename=item_list[idx],
+                                      download_dir=local_dir,
+                                      blocks_num=num_workers).start()
+            pbar.update(1)
+    click.echo("\nDownload complete!")
diff --git a/opendatalab/client/downloader.py b/opendatalab/client/downloader.py
new file mode 100644
index 0000000..adf2498
--- /dev/null
+++ b/opendatalab/client/downloader.py
@@ -0,0 +1,350 @@
+# -*- coding: utf-8 -*-
+import glob
+import os
+import sys
+import threading
+import time
+
+import requests
+
+
+class Worker:
+    """Fetch one byte range of the remote file into a local cache file."""
+
+    def __init__(self, name: str, url: str, range_start, range_end, cache_dir, finish_callback):
+        self.name = name
+        self.url = url
+        self.cache_filename = os.path.join(cache_dir, name + ".odl")
+        self.range_start = range_start   # fixed lower bound of the slice
+        self.range_end = range_end       # fixed upper bound (inclusive)
+        self.range_cursor = range_start  # advances as bytes are written
+        self.finish_callback = finish_callback
+        self.terminate_flag = False
+        self.FINISH_TYPE = ""  # one of DONE / HELP / RETIRE
+
+    def __run(self):
+        chunk_size = 1 * 1024
+        # HTTP range headers take the form "bytes=<start>-<end>" (inclusive)
+        header = {
+            'Range': f'bytes={self.range_cursor}-{self.range_end}',
+        }
+        req = requests.get(self.url, stream=True, headers=header)
+
+        if 200 <= req.status_code <= 299:
+            with open(self.cache_filename, "wb") as cache:
+                for chunk in req.iter_content(chunk_size=chunk_size):
+                    if self.terminate_flag:
+                        break
+                    cache.write(chunk)
+                    self.range_cursor += len(chunk)
+        if not self.terminate_flag:
+            self.FINISH_TYPE = "DONE"
+        req.close()
+        self.finish_callback(self)
+
+    def start(self):
+        threading.Thread(target=self.__run).start()
+
+    def help(self):
+        """Stop early so the remaining range can be re-assigned."""
+        self.FINISH_TYPE = "HELP"
+        self.terminate_flag = True
+
+    def retire(self):
+        self.FINISH_TYPE = "RETIRE"
+        self.terminate_flag = True
+
+    def __lt__(self, another):
+        return self.range_start < another.range_start
+
+    def get_progress(self):
+        """Report this worker's progress."""
+        _progress = {
+            "cursor": self.range_cursor,
+            "start": self.range_start,
+            "end": self.range_end
+        }
+        return _progress
+
+
+class Downloader:
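+    """Multi-threaded, resumable HTTP range downloader.
+
+    The remote file is split into byte ranges; each range is fetched by a
+    Worker thread into its own cache file, and the caches are stitched
+    together once every byte is present, so an interrupted download can
+    resume from what is already on disk. start() blocks until the final
+    file is assembled.
+
+    Illustrative usage (URL and paths are placeholders):
+
+        Downloader(url="https://example.com/train.zip",
+                   filename="images/train.zip",
+                   download_dir="./data/my_dataset",
+                   blocks_num=8).start()
+    """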
+
+    def __init__(self, url: str, filename: str, download_dir: str, blocks_num: int = 8):
+        assert 0 <= blocks_num <= 32
+        # filename may carry a relative path, e.g. "images/train.zip"
+        if len(filename.split('/')) == 1:
+            self.filename = filename
+            self.prefix = ''
+        else:
+            self.filename = filename.split('/')[-1]
+            self.prefix = os.path.dirname(filename)
+        self.url = url
+        self.download_dir = download_dir
+        # nothing to do if the target file already exists
+        self.finished = False
+        if os.path.exists(os.path.join(self.download_dir, self.prefix, self.filename)):
+            self.__whistleblower('target already exists\n')
+            self.finished = True
+            return
+
+        self.blocks_num = blocks_num
+        self.file_size = self.__get_size()
+
+        # make download dir
+        if not os.path.exists(self.download_dir):
+            os.makedirs(self.download_dir)
+
+        # make cache dir (os.path.join skips an empty prefix)
+        self.cache_dir = os.path.join(self.download_dir, self.prefix, '.cache/')
+        if not os.path.exists(self.cache_dir):
+            os.makedirs(self.cache_dir)
+
+        self.start_since = time.time()  # download start timestamp
+        self.workers = []               # active Worker objects
+        self.LOG = self.__get_log_from_cache()  # byte ranges still to fetch
+        self.__done = threading.Event()
+        self.__main_thread_done = threading.Event()
+        self.__download_record = []
+        # background thread reporting progress and restarting slow downloads
+        threading.Thread(target=self.__supervise, daemon=True).start()
+
+    def __get_size(self):
+        try:
+            req = requests.head(self.url)
+            content_length = req.headers["Content-Length"]
+            req.close()
+            return int(content_length)
+        except Exception as err:
+            self.__whistleblower(f"[Error] {err}")
+            return 0
+
+    def __get_readable_size(self, size):
+        units = ["B", "KB", "MB", "GB", "TB", "PB"]
+        unit_index = 0
+        K = 1024.0
+        while size >= K:
+            size = size / K
+            unit_index += 1
+        return f"{size:.1f} {units[unit_index]}"
+
+    def __get_cache_filenames(self):
+        # cache slices are named <cache_dir>/<filename>.<range_start>.odl
+        return glob.glob(os.path.join(self.cache_dir, f"{self.filename}.*.odl"))
+
+    def __get_ranges_from_cache(self):
+        ranges = []
+        for filename in self.__get_cache_filenames():
+            size = os.path.getsize(filename)
+            if size > 0:
+                cache_start = int(filename.split(".")[-2])
+                cache_end = cache_start + size - 1
+                ranges.append((cache_start, cache_end))
+        ranges.sort(key=lambda x: x[0])
+        return ranges
+
+    def __get_log_from_cache(self):
+        """Compute the byte ranges not yet covered by the cache files."""
+        ranges = self.__get_ranges_from_cache()
+        LOG = []
+        if len(ranges) == 0:
+            LOG.append((0, self.file_size - 1))
+        else:
+            for i, (start, end) in enumerate(ranges):
+                if i == 0:
+                    if start > 0:
+                        LOG.append((0, start - 1))
+                next_start = self.file_size if i == len(ranges) - 1 else ranges[i + 1][0]
+                if end < next_start - 1:
+                    LOG.append((end + 1, next_start - 1))
+        return LOG
+
+    def __increase_ranges_slice(self, ranges: list, minimum_size=1024 * 1024):
+        """Split the largest pending range in half to create more work."""
+        assert len(ranges) > 0
+        block_size = [end - start + 1 for start, end in ranges]
+        index_of_max = block_size.index(max(block_size))
+        start, end = ranges[index_of_max]
+        halfsize = block_size[index_of_max] // 2
+        if halfsize >= minimum_size:
+            new_ranges = [x for i, x in enumerate(ranges) if i != index_of_max]
+            new_ranges.append((start, start + halfsize))
+            new_ranges.append((start + halfsize + 1, end))
+        else:
+            new_ranges = ranges
+        return new_ranges
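+
+    # Worked example: [(0, 4194303)] (a single 4 MiB block) splits into
+    # (0, 2097152) and (2097153, 4194303); a block whose half would fall
+    # below minimum_size (1 MiB) is left intact.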
+    def __ask_for_work(self, worker_num: int):
+        """Hand out up to worker_num byte ranges, updating self.LOG."""
+        assert worker_num > 0
+        task = []
+        LOG_num = len(self.LOG)
+        # nothing queued: ask the busiest worker to split off half its range
+        if LOG_num == 0:
+            self.__share_the_burden()
+            return []
+        # enough ranges queued: hand them out directly
+        if LOG_num >= worker_num:
+            for _ in range(worker_num):
+                task.append(self.LOG.pop(0))
+        # fewer ranges than workers: split the largest ones first
+        else:
+            slice_num = worker_num - LOG_num
+            task = self.LOG
+            self.LOG = []
+            for _ in range(slice_num):
+                task = self.__increase_ranges_slice(task)
+            task.sort(key=lambda x: x[0])
+        return task
+
+    def __share_the_burden(self, minimum_size=1024 * 1024):
+        """Find the busiest worker and ask it to hand over its tail."""
+        max_size = 0
+        max_size_name = ""
+        for w in self.workers:
+            p = w.get_progress()
+            size = p["end"] - p["cursor"] + 1
+            if size > max_size:
+                max_size = size
+                max_size_name = w.name
+        if max_size >= minimum_size:
+            for w in self.workers:
+                if w.name == max_size_name:
+                    w.help()
+                    break
+
+    def __give_back_work(self, worker: Worker):
+        """Reclaim the unfinished part of a worker's range."""
+        progress = worker.get_progress()
+        cursor = progress["cursor"]
+        end = progress["end"]
+        if cursor <= end:
+            self.LOG.append((cursor, end))
+        self.LOG.sort(key=lambda x: x[0])
+
+    def __give_me_a_worker(self, start, end):
+        worker = Worker(name=f"{self.filename}.{start}",
+                        url=self.url, range_start=start, range_end=end, cache_dir=self.cache_dir,
+                        finish_callback=self.__on_worker_finish,
+                        )
+        return worker
+
+    def __whip(self, worker: Worker):
+        """Register a worker and set it running."""
+        self.workers.append(worker)
+        self.workers.sort()
+        worker.start()
+
+    def __on_worker_finish(self, worker: Worker):
+        assert worker.FINISH_TYPE != ""
+        self.workers.remove(worker)
+        # stopped early to share its range: re-queue the tail, start two workers
+        if worker.FINISH_TYPE == "HELP":
+            self.__give_back_work(worker)
+            self.workaholic(2)
+        # finished its range: pick up the next one
+        elif worker.FINISH_TYPE == "DONE":
+            self.workaholic(1)
+        elif worker.FINISH_TYPE == "RETIRE":
+            self.__give_back_work(worker)
+        # every byte is cached: stitch the slices together
+        if self.workers == [] and self.__get_log_from_cache() == []:
+            self.__sew()
+
+    def start(self):
+        if self.finished:
+            return
+        # hand out the initial ranges
+        for start, end in self.__ask_for_work(self.blocks_num):
+            worker = self.__give_me_a_worker(start, end)
+            self.__whip(worker)
+        # a fully cached but unstitched download leaves no work to hand out
+        if self.workers == [] and self.__get_log_from_cache() == []:
+            self.__sew()
+        # wait till done
+        self.__main_thread_done.wait()
+
+    def stop(self):
+        """Retire all workers and rebuild the to-do list from the cache."""
+        for w in self.workers:
+            w.retire()
+        while len(self.workers) != 0:
+            time.sleep(0.5)
+        self.LOG = self.__get_log_from_cache()
+
+    def workaholic(self, n=1):
+        """No work, no life: ask for n ranges and start workers on them."""
+        for s, e in self.__ask_for_work(n):
+            worker = self.__give_me_a_worker(s, e)
+            self.__whip(worker)
+
+    def restart(self):
+        self.stop()
+        # worker assembly again!
+        for start, end in self.__ask_for_work(self.blocks_num):
+            worker = self.__give_me_a_worker(start, end)
+            self.__whip(worker)
+
+    def __supervise(self):
+        """Report download status and restart when throughput degrades."""
+        REFRESH_INTERVAL = 1
+        LAG_COUNT = 5  # sliding-window length for the speed estimate
+        WAIT_TIMES_BEFORE_RESTART = 30
+        SPEED_DEGRADATION_PERCENTAGE = 0.3
+        self.__download_record = []
+        maxspeed = 0
+        wait_times = WAIT_TIMES_BEFORE_RESTART
+        while not self.__done.is_set():
+            dwn_size = sum(os.path.getsize(cachefile) for cachefile in self.__get_cache_filenames())
+            self.__download_record.append({"timestamp": time.time(), "size": dwn_size})
+            if len(self.__download_record) > LAG_COUNT:
+                self.__download_record.pop(0)
+            s = self.__download_record[-1]["size"] - self.__download_record[0]["size"]
+            t = self.__download_record[-1]["timestamp"] - self.__download_record[0]["timestamp"]
+            if t != 0:
+                speed = s / t
+                readable_speed = self.__get_readable_size(speed)
+                percentage = self.__download_record[-1]["size"] / self.file_size * 100
+                status_msg = f"\r[Download Info] {percentage:.2f} % | {readable_speed}/s | workers: {len(self.workers)} | {(time.time() - self.start_since):.0f}s"
+                self.__whistleblower(status_msg)
+                maxspeed = max(maxspeed, speed)
+                EPSILON = 1e-5
+                # patience exhausted
+                time_over = wait_times < 0
+                not_finished = not self.__done.is_set()
+                # speed fell well below its observed peak while under 1 MB/s
+                speed_drops_significantly = (maxspeed - speed + EPSILON) / (maxspeed + EPSILON) > SPEED_DEGRADATION_PERCENTAGE
+                speed_under_threshold = speed < 1024 * 1024
+                scene_1 = speed_drops_significantly and speed_under_threshold
+                # or the download crawls below 100 KB/s outright
+                scene_2 = speed < 100 * 1024
+                if time_over and not_finished and (scene_1 or scene_2):
+                    self.__whistleblower("\r[info] speed degradation, restarting...")
+                    self.restart()
+                    maxspeed = 0
+                    wait_times = WAIT_TIMES_BEFORE_RESTART
+                else:
+                    wait_times -= 1
+            time.sleep(REFRESH_INTERVAL)
+
+    def __sew(self):
+        """Stitch the cache slices into the final file and clean up."""
+        self.__done.set()
+        chunk_size = 10 * 1024 * 1024
+        with open(os.path.join(self.download_dir, self.prefix, self.filename), "wb") as f:
+            for start, _ in self.__get_ranges_from_cache():
+                cache_filename = os.path.join(self.cache_dir, f"{self.filename}.{start}.odl")
+                with open(cache_filename, "rb") as cache_file:
+                    data = cache_file.read(chunk_size)
+                    while data:
+                        f.write(data)
+                        f.flush()
+                        data = cache_file.read(chunk_size)
+        self.clear()
+        self.__whistleblower("\r")
+        self.__main_thread_done.set()
+
+    def __whistleblower(self, saying: str):
+        sys.stdout.write(saying)
+        sys.stdout.flush()
+
+    def clear(self):
+        """Remove all cache slices for this file."""
+        for filename in self.__get_cache_filenames():
+            os.remove(filename)