实现文件下载功能,添加下载文件方法和相关的文件信息获取逻辑。更新主驱动类以支持文件下载,并在工具类中实现文件下载的进度条显示。

This commit is contained in:
2025-07-15 17:30:33 +08:00
parent dc5e8b4bfc
commit e5dbab6679
3 changed files with 79 additions and 4 deletions

View File

@@ -336,6 +336,11 @@ class API:
} }
return await self._make_request("POST", "upload/v2/file/single/create", json=data, data=file) return await self._make_request("POST", "upload/v2/file/single/create", json=data, data=file)
async def download_file(self, fileId: int) -> Dict[str, Any]:
"""下载文件"""
params = {"fileId": fileId}
return await self._make_request("GET", "api/v1/file/download_info", params=params)
async def create_offline_downlod(self, url: str, dirID: int, fileName: Optional[str] = None, callBackUrl: Optional[str] = None) -> Dict[str, Any]: async def create_offline_downlod(self, url: str, dirID: int, fileName: Optional[str] = None, callBackUrl: Optional[str] = None) -> Dict[str, Any]:
"""创建离线下载任务""" """创建离线下载任务"""
data: Dict[str, Any] = { data: Dict[str, Any] = {

View File

@@ -1,4 +1,5 @@
import re import re
import os
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
from _api import API from _api import API
@@ -30,6 +31,7 @@ class Driver:
dir: str = '/', dir: str = '/',
page: int = 1, page: int = 1,
limit: int = 100, limit: int = 100,
return_parentFileId: bool = False,
): ):
"""获取目录下的文件列表""" """获取目录下的文件列表"""
logger.info(f"Calling list_dir(dir={dir}, page={page}, limit={limit})") logger.info(f"Calling list_dir(dir={dir}, page={page}, limit={limit})")
@@ -43,7 +45,10 @@ class Driver:
logger.debug(f"Updated parentFileId: {parentFileId}") logger.debug(f"Updated parentFileId: {parentFileId}")
files = await self._list_dir_fetch_or_cache(parentFileId=parentFileId, page=page, limit=limit) files = await self._list_dir_fetch_or_cache(parentFileId=parentFileId, page=page, limit=limit)
logger.info(f"Returning file list for dir={dir}") logger.info(f"Returning file list for dir={dir}")
return files['data']['fileList'] if return_parentFileId:
return files['data']['fileList'], parentFileId
else:
return files['data']['fileList']
async def _list_dir_fetch_or_cache( async def _list_dir_fetch_or_cache(
self, self,
@@ -121,6 +126,48 @@ class Driver:
logger.debug(f"Directory {filename} not found in fileList.") logger.debug(f"Directory {filename} not found in fileList.")
return False return False
async def fetch_file(self, parentFileId: int, filename: str, lastFileId: Optional[int] = None, page: int = 1) -> Dict[str, Any]:
"""获取文件信息"""
logger.info(f"Calling fetch_file(parentFileID={parentFileId}, filename={filename})")
files = self.utils.get_cached_files(parentFileId=parentFileId, page=page)
if not files:
files = await self.api.list_files_v2(parentFileId=parentFileId, limit=100, lastFileId=lastFileId)
self.utils.cache_files(files=files, parentFileId=parentFileId, page=page)
for f in files['data']['fileList']:
if f['type'] == 0 and f['filename'] == filename:
logger.debug(f"Found file {filename} in fileList.")
return f
if lastFileId != -1: # 文件名不在文件列表中但有lastFileId继续搜索
logger.debug(f"Fetching more files for parentFileId={parentFileId} with lastFileId={lastFileId}")
await self.fetch_file(parentFileId=parentFileId, filename=filename, lastFileId=lastFileId, page=page+1)
logger.error(f"Error: {filename} not found in {files['data']['fileList']}")
return {}
async def download_file(
self,
file_path: str,
save_path: str,
progress_bar: bool = True,
):
"""下载文件"""
logger.info(f"Calling download_file(file_path={file_path}, save_path={save_path})")
dirname = os.path.dirname(file_path)
_, parentFileId = await self.list_dir(dir=dirname, return_parentFileId=True)
file = await self.fetch_file(parentFileId=parentFileId, filename=os.path.basename(file_path))
if file:
logger.debug(f"Downloading file {file_path} to {save_path}")
downlod_info = await self.api.download_file(fileId=file['fileId'])
if downlod_info.get('code') == 0:
logger.debug(f"Downloading file {file_path} to {save_path} started.")
self.utils.download_file(url=downlod_info['data']['downloadUrl'], file_path=save_path, progress_bar=progress_bar)
else:
logger.error(f"Error: {downlod_info.get('message')}")
else:
logger.error(f"Error: {file_path} not found.")
return downlod_info
@async_to_sync @async_to_sync
async def main() -> None: async def main() -> None:
from privacy import client_id, client_secret from privacy import client_id, client_secret
@@ -129,9 +176,8 @@ async def main() -> None:
logger.info("Starting main()") logger.info("Starting main()")
driver = Driver(client_id=client_id, client_secret=client_secret) driver = Driver(client_id=client_id, client_secret=client_secret)
start_time = time.time() start_time = time.time()
dirs = await driver.list_dir(dir='/nas/Documents') dirs = await driver.list_dir(dir='/')
utils.print_file_list(dirs)
print(utils.computing_time(start_time))
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -3,6 +3,8 @@ import asyncio
from functools import wraps from functools import wraps
from typing import Any, Dict from typing import Any, Dict
import httpx
from tqdm import tqdm
from rich.console import Console from rich.console import Console
from rich.table import Table from rich.table import Table
from cachetools import cached, TTLCache from cachetools import cached, TTLCache
@@ -180,6 +182,28 @@ class Utils:
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
return f"{elapsed_time:.3f} s" return f"{elapsed_time:.3f} s"
def download_file(self, url: str, file_path: str, progress_bar: bool = True) -> None:
"""
下载文件
Args:
url: 文件URL
file_path: 保存路径
progress_bar: 是否显示进度条
"""
with httpx.stream("GET", url) as response:
response.raise_for_status()
total_size = int(response.headers.get("Content-Length", 0))
if progress_bar:
progress = tqdm(total=total_size, unit="iB", unit_scale=True)
with open(file_path, "wb") as file:
for data in response.iter_bytes():
file.write(data)
if progress_bar:
progress.update(len(data))
if progress_bar:
progress.close()
def async_to_sync(func): def async_to_sync(func):