实现文件下载功能,添加下载文件方法和相关的文件信息获取逻辑。更新主驱动类以支持文件下载,并在工具类中实现文件下载的进度条显示。

This commit is contained in:
2025-07-15 17:30:33 +08:00
parent dc5e8b4bfc
commit e5dbab6679
3 changed files with 79 additions and 4 deletions

View File

@@ -1,4 +1,5 @@
import re
import os
from typing import Dict, List, Optional, Any
from _api import API
@@ -30,6 +31,7 @@ class Driver:
dir: str = '/',
page: int = 1,
limit: int = 100,
return_parentFileId: bool = False,
):
"""获取目录下的文件列表"""
logger.info(f"Calling list_dir(dir={dir}, page={page}, limit={limit})")
@@ -43,7 +45,10 @@ class Driver:
logger.debug(f"Updated parentFileId: {parentFileId}")
files = await self._list_dir_fetch_or_cache(parentFileId=parentFileId, page=page, limit=limit)
logger.info(f"Returning file list for dir={dir}")
return files['data']['fileList']
if return_parentFileId:
return files['data']['fileList'], parentFileId
else:
return files['data']['fileList']
async def _list_dir_fetch_or_cache(
self,
@@ -121,6 +126,48 @@ class Driver:
logger.debug(f"Directory {filename} not found in fileList.")
return False
async def fetch_file(self, parentFileId: int, filename: str, lastFileId: Optional[int] = None, page: int = 1) -> Dict[str, Any]:
"""获取文件信息"""
logger.info(f"Calling fetch_file(parentFileID={parentFileId}, filename={filename})")
files = self.utils.get_cached_files(parentFileId=parentFileId, page=page)
if not files:
files = await self.api.list_files_v2(parentFileId=parentFileId, limit=100, lastFileId=lastFileId)
self.utils.cache_files(files=files, parentFileId=parentFileId, page=page)
for f in files['data']['fileList']:
if f['type'] == 0 and f['filename'] == filename:
logger.debug(f"Found file {filename} in fileList.")
return f
if lastFileId != -1: # 文件名不在文件列表中但有lastFileId继续搜索
logger.debug(f"Fetching more files for parentFileId={parentFileId} with lastFileId={lastFileId}")
await self.fetch_file(parentFileId=parentFileId, filename=filename, lastFileId=lastFileId, page=page+1)
logger.error(f"Error: {filename} not found in {files['data']['fileList']}")
return {}
async def download_file(
self,
file_path: str,
save_path: str,
progress_bar: bool = True,
):
"""下载文件"""
logger.info(f"Calling download_file(file_path={file_path}, save_path={save_path})")
dirname = os.path.dirname(file_path)
_, parentFileId = await self.list_dir(dir=dirname, return_parentFileId=True)
file = await self.fetch_file(parentFileId=parentFileId, filename=os.path.basename(file_path))
if file:
logger.debug(f"Downloading file {file_path} to {save_path}")
downlod_info = await self.api.download_file(fileId=file['fileId'])
if downlod_info.get('code') == 0:
logger.debug(f"Downloading file {file_path} to {save_path} started.")
self.utils.download_file(url=downlod_info['data']['downloadUrl'], file_path=save_path, progress_bar=progress_bar)
else:
logger.error(f"Error: {downlod_info.get('message')}")
else:
logger.error(f"Error: {file_path} not found.")
return downlod_info
@async_to_sync
async def main() -> None:
from privacy import client_id, client_secret
@@ -129,9 +176,8 @@ async def main() -> None:
logger.info("Starting main()")
driver = Driver(client_id=client_id, client_secret=client_secret)
start_time = time.time()
dirs = await driver.list_dir(dir='/nas/Documents')
utils.print_file_list(dirs)
print(utils.computing_time(start_time))
dirs = await driver.list_dir(dir='/')
if __name__ == '__main__':