实现文件下载功能,添加下载文件方法和相关的文件信息获取逻辑。更新主驱动类以支持文件下载,并在工具类中实现文件下载的进度条显示。
This commit is contained in:
5
_api.py
5
_api.py
@@ -336,6 +336,11 @@ class API:
|
|||||||
}
|
}
|
||||||
return await self._make_request("POST", "upload/v2/file/single/create", json=data, data=file)
|
return await self._make_request("POST", "upload/v2/file/single/create", json=data, data=file)
|
||||||
|
|
||||||
|
async def download_file(self, fileId: int) -> Dict[str, Any]:
|
||||||
|
"""下载文件"""
|
||||||
|
params = {"fileId": fileId}
|
||||||
|
return await self._make_request("GET", "api/v1/file/download_info", params=params)
|
||||||
|
|
||||||
async def create_offline_downlod(self, url: str, dirID: int, fileName: Optional[str] = None, callBackUrl: Optional[str] = None) -> Dict[str, Any]:
|
async def create_offline_downlod(self, url: str, dirID: int, fileName: Optional[str] = None, callBackUrl: Optional[str] = None) -> Dict[str, Any]:
|
||||||
"""创建离线下载任务"""
|
"""创建离线下载任务"""
|
||||||
data: Dict[str, Any] = {
|
data: Dict[str, Any] = {
|
||||||
|
|||||||
54
_main.py
54
_main.py
@@ -1,4 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
|
import os
|
||||||
from typing import Dict, List, Optional, Any
|
from typing import Dict, List, Optional, Any
|
||||||
|
|
||||||
from _api import API
|
from _api import API
|
||||||
@@ -30,6 +31,7 @@ class Driver:
|
|||||||
dir: str = '/',
|
dir: str = '/',
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
|
return_parentFileId: bool = False,
|
||||||
):
|
):
|
||||||
"""获取目录下的文件列表"""
|
"""获取目录下的文件列表"""
|
||||||
logger.info(f"Calling list_dir(dir={dir}, page={page}, limit={limit})")
|
logger.info(f"Calling list_dir(dir={dir}, page={page}, limit={limit})")
|
||||||
@@ -43,7 +45,10 @@ class Driver:
|
|||||||
logger.debug(f"Updated parentFileId: {parentFileId}")
|
logger.debug(f"Updated parentFileId: {parentFileId}")
|
||||||
files = await self._list_dir_fetch_or_cache(parentFileId=parentFileId, page=page, limit=limit)
|
files = await self._list_dir_fetch_or_cache(parentFileId=parentFileId, page=page, limit=limit)
|
||||||
logger.info(f"Returning file list for dir={dir}")
|
logger.info(f"Returning file list for dir={dir}")
|
||||||
return files['data']['fileList']
|
if return_parentFileId:
|
||||||
|
return files['data']['fileList'], parentFileId
|
||||||
|
else:
|
||||||
|
return files['data']['fileList']
|
||||||
|
|
||||||
async def _list_dir_fetch_or_cache(
|
async def _list_dir_fetch_or_cache(
|
||||||
self,
|
self,
|
||||||
@@ -121,6 +126,48 @@ class Driver:
|
|||||||
logger.debug(f"Directory {filename} not found in fileList.")
|
logger.debug(f"Directory {filename} not found in fileList.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def fetch_file(self, parentFileId: int, filename: str, lastFileId: Optional[int] = None, page: int = 1) -> Dict[str, Any]:
|
||||||
|
"""获取文件信息"""
|
||||||
|
logger.info(f"Calling fetch_file(parentFileID={parentFileId}, filename={filename})")
|
||||||
|
files = self.utils.get_cached_files(parentFileId=parentFileId, page=page)
|
||||||
|
if not files:
|
||||||
|
files = await self.api.list_files_v2(parentFileId=parentFileId, limit=100, lastFileId=lastFileId)
|
||||||
|
self.utils.cache_files(files=files, parentFileId=parentFileId, page=page)
|
||||||
|
for f in files['data']['fileList']:
|
||||||
|
if f['type'] == 0 and f['filename'] == filename:
|
||||||
|
logger.debug(f"Found file {filename} in fileList.")
|
||||||
|
return f
|
||||||
|
if lastFileId != -1: # 文件名不在文件列表中,但有lastFileId,继续搜索
|
||||||
|
logger.debug(f"Fetching more files for parentFileId={parentFileId} with lastFileId={lastFileId}")
|
||||||
|
await self.fetch_file(parentFileId=parentFileId, filename=filename, lastFileId=lastFileId, page=page+1)
|
||||||
|
logger.error(f"Error: {filename} not found in {files['data']['fileList']}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def download_file(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
save_path: str,
|
||||||
|
progress_bar: bool = True,
|
||||||
|
):
|
||||||
|
"""下载文件"""
|
||||||
|
logger.info(f"Calling download_file(file_path={file_path}, save_path={save_path})")
|
||||||
|
dirname = os.path.dirname(file_path)
|
||||||
|
_, parentFileId = await self.list_dir(dir=dirname, return_parentFileId=True)
|
||||||
|
file = await self.fetch_file(parentFileId=parentFileId, filename=os.path.basename(file_path))
|
||||||
|
if file:
|
||||||
|
logger.debug(f"Downloading file {file_path} to {save_path}")
|
||||||
|
downlod_info = await self.api.download_file(fileId=file['fileId'])
|
||||||
|
if downlod_info.get('code') == 0:
|
||||||
|
logger.debug(f"Downloading file {file_path} to {save_path} started.")
|
||||||
|
self.utils.download_file(url=downlod_info['data']['downloadUrl'], file_path=save_path, progress_bar=progress_bar)
|
||||||
|
else:
|
||||||
|
logger.error(f"Error: {downlod_info.get('message')}")
|
||||||
|
else:
|
||||||
|
logger.error(f"Error: {file_path} not found.")
|
||||||
|
return downlod_info
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@async_to_sync
|
@async_to_sync
|
||||||
async def main() -> None:
|
async def main() -> None:
|
||||||
from privacy import client_id, client_secret
|
from privacy import client_id, client_secret
|
||||||
@@ -129,9 +176,8 @@ async def main() -> None:
|
|||||||
logger.info("Starting main()")
|
logger.info("Starting main()")
|
||||||
driver = Driver(client_id=client_id, client_secret=client_secret)
|
driver = Driver(client_id=client_id, client_secret=client_secret)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
dirs = await driver.list_dir(dir='/nas/Documents')
|
dirs = await driver.list_dir(dir='/')
|
||||||
utils.print_file_list(dirs)
|
|
||||||
print(utils.computing_time(start_time))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
24
_utils.py
24
_utils.py
@@ -3,6 +3,8 @@ import asyncio
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from tqdm import tqdm
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from cachetools import cached, TTLCache
|
from cachetools import cached, TTLCache
|
||||||
@@ -180,6 +182,28 @@ class Utils:
|
|||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
elapsed_time = end_time - start_time
|
elapsed_time = end_time - start_time
|
||||||
return f"{elapsed_time:.3f} s"
|
return f"{elapsed_time:.3f} s"
|
||||||
|
|
||||||
|
def download_file(self, url: str, file_path: str, progress_bar: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
下载文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: 文件URL
|
||||||
|
file_path: 保存路径
|
||||||
|
progress_bar: 是否显示进度条
|
||||||
|
"""
|
||||||
|
with httpx.stream("GET", url) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
total_size = int(response.headers.get("Content-Length", 0))
|
||||||
|
if progress_bar:
|
||||||
|
progress = tqdm(total=total_size, unit="iB", unit_scale=True)
|
||||||
|
with open(file_path, "wb") as file:
|
||||||
|
for data in response.iter_bytes():
|
||||||
|
file.write(data)
|
||||||
|
if progress_bar:
|
||||||
|
progress.update(len(data))
|
||||||
|
if progress_bar:
|
||||||
|
progress.close()
|
||||||
|
|
||||||
|
|
||||||
def async_to_sync(func):
|
def async_to_sync(func):
|
||||||
|
|||||||
Reference in New Issue
Block a user