diff --git a/_api.py b/_api.py
index 7d61976..7ccf50e 100644
--- a/_api.py
+++ b/_api.py
@@ -336,6 +336,11 @@ class API:
         }
         return await self._make_request("POST", "upload/v2/file/single/create", json=data, data=file)
 
+    async def download_file(self, fileId: int) -> Dict[str, Any]:
+        """Request the download info for the file with the given ``fileId``."""
+        params = {"fileId": fileId}
+        return await self._make_request("GET", "api/v1/file/download_info", params=params)
+
     async def create_offline_downlod(self, url: str, dirID: int, fileName: Optional[str] = None, callBackUrl: Optional[str] = None) -> Dict[str, Any]:
         """创建离线下载任务"""
         data: Dict[str, Any] = {
diff --git a/_main.py b/_main.py
index 8c940fd..a5881a0 100644
--- a/_main.py
+++ b/_main.py
@@ -1,4 +1,5 @@
 import re
+import os
 from typing import Dict, List, Optional, Any
 
 from _api import API
@@ -30,6 +31,7 @@ class Driver:
         dir: str = '/',
         page: int = 1,
         limit: int = 100,
+        return_parentFileId: bool = False,
     ):
         """获取目录下的文件列表"""
         logger.info(f"Calling list_dir(dir={dir}, page={page}, limit={limit})")
@@ -43,7 +45,10 @@ class Driver:
             logger.debug(f"Updated parentFileId: {parentFileId}")
         files = await self._list_dir_fetch_or_cache(parentFileId=parentFileId, page=page, limit=limit)
         logger.info(f"Returning file list for dir={dir}")
-        return files['data']['fileList']
+        if return_parentFileId:
+            return files['data']['fileList'], parentFileId
+        else:
+            return files['data']['fileList']
 
     async def _list_dir_fetch_or_cache(
         self,
@@ -121,6 +126,78 @@ class Driver:
         logger.debug(f"Directory {filename} not found in fileList.")
         return False
 
+    async def fetch_file(self, parentFileId: int, filename: str, lastFileId: Optional[int] = None, page: int = 1) -> Dict[str, Any]:
+        """Find a file entry by name inside a directory, paging as needed.
+
+        Each page is read from the cache when present, otherwise fetched with
+        ``list_files_v2`` and stored in the cache under ``page``.
+
+        Args:
+            parentFileId: ID of the directory to search.
+            filename: Exact name to match; only entries with ``type == 0`` count.
+            lastFileId: Cursor handed to ``list_files_v2`` for the first fetch.
+            page: Cache page number paired with ``lastFileId``.
+
+        Returns:
+            The matching entry, or ``{}`` when no page contains it.
+        """
+        logger.info(f"Calling fetch_file(parentFileID={parentFileId}, filename={filename})")
+        while True:
+            files = self.utils.get_cached_files(parentFileId=parentFileId, page=page)
+            if not files:
+                files = await self.api.list_files_v2(parentFileId=parentFileId, limit=100, lastFileId=lastFileId)
+                self.utils.cache_files(files=files, parentFileId=parentFileId, page=page)
+            data = files.get('data') or {}
+            for f in data.get('fileList') or []:
+                if f['type'] == 0 and f['filename'] == filename:
+                    logger.debug(f"Found file {filename} in fileList.")
+                    return f
+            # NOTE(review): assumes the listing response exposes the next-page
+            # cursor as data['lastFileId'], with -1 meaning "last page" - confirm.
+            next_file_id = data.get('lastFileId', -1)
+            if next_file_id in (None, -1) or next_file_id == lastFileId:
+                break  # no further page, or the cursor did not advance
+            logger.debug(f"Fetching more files for parentFileId={parentFileId} with lastFileId={next_file_id}")
+            lastFileId, page = next_file_id, page + 1
+        logger.error(f"Error: {filename} not found under parentFileId={parentFileId}")
+        return {}
+
+    async def download_file(
+        self,
+        file_path: str,
+        save_path: str,
+        progress_bar: bool = True,
+    ) -> Dict[str, Any]:
+        """Download a remote file to a local path.
+
+        Args:
+            file_path: Remote path of the file; its directory part is resolved with ``list_dir``.
+            save_path: Local path the content is written to.
+            progress_bar: Whether to show a progress bar while downloading.
+
+        Returns:
+            The download-info response from the API, or ``{}`` when the file
+            was not found.
+        """
+        logger.info(f"Calling download_file(file_path={file_path}, save_path={save_path})")
+        # A bare file name has an empty dirname; fall back to list_dir's default root.
+        dirname = os.path.dirname(file_path) or '/'
+        _, parentFileId = await self.list_dir(dir=dirname, return_parentFileId=True)
+        file = await self.fetch_file(parentFileId=parentFileId, filename=os.path.basename(file_path))
+        if not file:
+            logger.error(f"Error: {file_path} not found.")
+            return {}
+        logger.debug(f"Downloading file {file_path} to {save_path}")
+        download_info = await self.api.download_file(fileId=file['fileId'])
+        if download_info.get('code') == 0:
+            logger.debug(f"Downloading file {file_path} to {save_path} started.")
+            # NOTE(review): Utils.download_file is synchronous, so this coroutine does not yield while the transfer runs.
+            self.utils.download_file(url=download_info['data']['downloadUrl'], file_path=save_path, progress_bar=progress_bar)
+        else:
+            logger.error(f"Error: {download_info.get('message')}")
+        return download_info
+
+
 @async_to_sync
 async def main() -> None:
     from privacy import client_id, client_secret
@@ -129,9 +206,7 @@ async def main() -> None:
     logger.info("Starting main()")
     driver = Driver(client_id=client_id, client_secret=client_secret)
     start_time = time.time()
-    dirs = await driver.list_dir(dir='/nas/Documents')
-    utils.print_file_list(dirs)
-    print(utils.computing_time(start_time))
+    dirs = await driver.list_dir(dir='/')
 
 
 if __name__ == '__main__':
diff --git a/_utils.py b/_utils.py
index 67dc797..f3dc78d 100644
--- a/_utils.py
+++ b/_utils.py
@@ -3,6 +3,8 @@ import asyncio
 from functools import wraps
 from typing import Any, Dict
 
+import httpx
+from tqdm import tqdm
 from rich.console import Console
 from rich.table import Table
 from cachetools import cached, TTLCache
@@ -180,6 +182,28 @@ class Utils:
         end_time = time.time()
         elapsed_time = end_time - start_time
         return f"{elapsed_time:.3f} s"
+
+    def download_file(self, url: str, file_path: str, progress_bar: bool = True) -> None:
+        """Stream the content at a URL into a local file.
+
+        Args:
+            url: URL of the file to download.
+            file_path: Local path the content is written to.
+            progress_bar: Whether to show a tqdm progress bar.
+
+        Raises:
+            httpx.HTTPStatusError: If ``raise_for_status`` rejects the response.
+        """
+        # NOTE(review): httpx only follows redirects when follow_redirects=True is passed - confirm the URL never redirects.
+        with httpx.stream("GET", url) as response:
+            response.raise_for_status()
+            # A missing or zero Content-Length becomes None so tqdm shows a plain counter.
+            total_size = int(response.headers.get("Content-Length", 0)) or None
+            # disable= makes the bar a no-op; the context manager closes it even when a write fails.
+            with open(file_path, "wb") as file, tqdm(total=total_size, unit="iB", unit_scale=True, disable=not progress_bar) as progress:
+                for data in response.iter_bytes():
+                    file.write(data)
+                    progress.update(len(data))
 
 
 def async_to_sync(func):