diff --git a/_main.py b/_main.py index b60eda1..8c940fd 100644 --- a/_main.py +++ b/_main.py @@ -1,13 +1,15 @@ +import re from typing import Dict, List, Optional, Any from _api import API from _logger import logger -from _utils import async_to_sync +from _utils import async_to_sync, Utils class Driver: def __init__(self, client_id: str, client_secret: str, base_url: str = "https://open-api.123pan.com"): + self.utils = Utils() self.api = API(client_id=client_id, client_secret=client_secret, base_url=base_url) self.api.check_access_token() logger.info("Driver initialized.") @@ -33,16 +35,40 @@ class Driver: logger.info(f"Calling list_dir(dir={dir}, page={page}, limit={limit})") parentFileId = 0 dir_list = dir.split('/') - files = await self.api.list_files_v2(parentFileId=parentFileId, limit=limit) + files = await self._list_dir_fetch_or_cache(parentFileId=parentFileId, page=page, limit=limit) for i in dir_list: logger.debug(f"Processing dir segment: '{i}' with parentFileId={parentFileId}") if i: parentFileId = await self._list_dir_fetch_parentFileId(parentFileId, files, i, limit) logger.debug(f"Updated parentFileId: {parentFileId}") - files = await self.api.list_files_v2(parentFileId=parentFileId, limit=limit) + files = await self._list_dir_fetch_or_cache(parentFileId=parentFileId, page=page, limit=limit) logger.info(f"Returning file list for dir={dir}") return files['data']['fileList'] - + + async def _list_dir_fetch_or_cache( + self, + parentFileId: int, + page: int = 1, + limit: int = 100, + lastFileId: Optional[int] = None, + ) -> Dict[str, Any]: + """从缓存中获取目录下的文件列表,如果缓存中没有,则从API获取""" + files_list = [] + pages = self.utils.computing_page(page=page, limit=limit) + for p in pages: + logger.debug(f"_list_dir_fetch_or_cache(parentFileId={parentFileId}, page={p}.") + files = self.utils.get_cached_files(parentFileId=parentFileId, page=p) + if not files: + logger.debug(f"No cached files found for parentFileId={parentFileId}, fetching from API.") + files = await self.api.list_files_v2(parentFileId=parentFileId, limit=limit, lastFileId=lastFileId) + self.utils.cache_files(files=files, parentFileId=parentFileId, page=p) + return files + else: + logger.debug(f"Cached files found for parentFileId={parentFileId}, returning from cache.") + files = self.utils.get_cached_files(parentFileId=parentFileId, page=p) + files_list.append(files) + return self.utils.merge_files(files_list) + async def _list_dir_fetch_parentFileId( self, parentFileId: int, @@ -52,21 +78,21 @@ class Driver: lastFileId: Optional[int] = None, ) -> int: logger.info(f"_list_dir_fetch_parentFileId(parentFileId={parentFileId}, filename={filename}, limit={limit}, lastFileId={lastFileId})") - if await self._list_dir_in_files(files, filename) and lastFileId != -1: + if await self._list_dir_in_files(files, filename) and lastFileId != -1: # 文件名在文件列表中,直接返回parentFileId logger.debug(f"Found {filename} in files, getting parentFileId.") return await self._list_dir_get_parentFileId(files, filename) - elif lastFileId != -1: + elif lastFileId != -1: # 文件名不在文件列表中,但有lastFileId,继续搜索 logger.debug(f"Fetching more files for parentFileId={parentFileId} with lastFileId={lastFileId}") files = await self.api.list_files_v2(parentFileId=parentFileId, limit=limit, lastFileId=lastFileId) - return await self._list_dir_fetch_parentFileId(parentFileId, files, filename, limit, files['lastFileId']) - else: + return await self._list_dir_fetch_parentFileId(parentFileId, files, filename, limit, files['data']['lastFileId']) + else: # 文件名不在文件列表中,且lastFileId等于-1,说明文件列表已经遍历完毕,没有找到返回0 if await self._list_dir_in_files(files, filename): logger.debug(f"Found {filename} in files after lastFileId exhausted.") return await self._list_dir_get_parentFileId(files, filename) else: logger.error(f"Error: {filename} not found in {files['data']['fileList']}") return 0 - + async def _list_dir_get_parentFileId( self, files: Dict[str, Any], @@ -98,11 +124,14 @@ class Driver: @async_to_sync async def main() -> None: from privacy import client_id, client_secret + from _utils import Utils, time + utils = Utils() logger.info("Starting main()") driver = Driver(client_id=client_id, client_secret=client_secret) - dirs = await driver.list_dir(dir='/') - for dir in dirs: - print(dir) + start_time = time.time() + dirs = await driver.list_dir(dir='/nas/Documents') + utils.print_file_list(dirs) + print(utils.computing_time(start_time)) if __name__ == '__main__': diff --git a/_utils.py b/_utils.py index cfc317d..67dc797 100644 --- a/_utils.py +++ b/_utils.py @@ -1,5 +1,185 @@ +import time import asyncio from functools import wraps +from typing import Any, Dict + +from rich.console import Console +from rich.table import Table +from cachetools import cached, TTLCache + + +class Utils: + + def __init__(self): + self.console = Console() + self.files_cache = TTLCache(maxsize=1000, ttl=600) # 10分钟 + + def format_file_size(self, size_bytes: int, decimal_places: int = 1) -> str: + """ + 格式化文件大小显示 + + Args: + size_bytes: 文件大小(字节) + decimal_places: 小数位数,默认1位,可选1或2位 + + Returns: + 格式化后的文件大小字符串,如 "1.2 KB", "1.23 MB" + """ + if size_bytes == 0: + return "0 B" + + # 定义单位 + units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'] + unit_index = 0 + + # 计算合适的单位 + size = float(size_bytes) + while size >= 1024.0 and unit_index < len(units) - 1: + size /= 1024.0 + unit_index += 1 + + # 格式化显示 + if unit_index == 0: # B单位,不显示小数 + return f"{int(size)} {units[unit_index]}" + else: + # 确保小数位数不超过2位 + decimal_places = min(decimal_places, 2) + format_str = f"{{:.{decimal_places}f}} {units[unit_index]}" + return format_str.format(size) + + def print_file_list(self, files): + """ + 打印文件列表 + + Args: + files: 文件列表 + """ + table = Table(title="文件列表") + table.add_column("类型", style="cyan", width=10) + table.add_column("名称", style="magenta") + table.add_column("大小", style="green", width=15) + table.add_column("修改时间", style="yellow", width=20) + + for file in files: + type = self.print_file_type(file) + name = file['filename'] + size = self.format_file_size(file.get('size', 0)) + modified = file.get('updateAt', '') + + table.add_row(type, name, size, modified) + + self.console.print(table) + + def print_file_type(self, file: Dict[str, Any]) -> str: + """ + 格式化文件类型显示 + + Args: + file: 文件信息字典 + + Returns: + 格式化后的文件类型字符串,如 "文件夹", "音频", "视频", "图片", "未知" + """ + if file['type'] == 1: + return "文件夹" + else: + if file['category'] == 1: + return "音频" + elif file['category'] == 2: + return "视频" + elif file['category'] == 3: + return "图片" + else: + return "未知" + + def computing_page(self, page: int, limit: int) -> list: + """ + 计算分页信息 + + Args: + page: 当前页码 + limit: 每页显示数量 + + Returns: + 包含所有分页的列表,如 [1, 2, 3] + """ + pages = [] + for i in range(page * limit - limit, page * limit, 100): + page_i = i // 100 + 1 + pages.append(page_i) + return pages + + def merge_files(self, files_list: list[Dict[str, Any]]) -> Dict[str, Any]: + """ + 合并文件列表 + + Args: + files_list: 文件列表响应json数据列表 + + Returns: + 合并后的文件列表响应json数据 + """ + merged_files = { + 'code': 0, + 'message': 'ok', + 'data': {'lastFileId': 0,'fileList': []}, + 'x-traceID': '' + } + for files in files_list: + merged_files['code'] = files['code'] + merged_files['message'] = files['message'] + merged_files['data']['lastFileId'] = files['data']['lastFileId'] + merged_files['data']['fileList'] += files['data']['fileList'] + merged_files['x-traceID'] = files['x-traceID'] + return merged_files + + + def cache_files(self, files: Dict[str, Any], parentFileId: int, page: int = 1): + """ + 缓存文件列表 + + Args: + files: 文件列表 + parentFileId: 父目录ID + page: 页码 + """ + cache_key = f"files:{parentFileId}:{page}" + self.files_cache[cache_key] = { + 'files': files, + 'timestamp': time.time() + } + + def get_cached_files(self, parentFileId: int, page: int = 1) -> Dict[str, Any]: + """获取缓存的文件列表""" + cache_key = f"files:{parentFileId}:{page}" + if self.files_cache.get(cache_key): + return self.files_cache[cache_key]['files'] + else: + return {} + + def cache_limit(self, maxsize=1000, ttl=300): + """ + 设置缓存文件大小限制 + + Args: + maxsize: 缓存文件大小限制 + ttl: 缓存过期时间 + """ + self.files_cache = TTLCache(maxsize=maxsize, ttl=ttl) + + def computing_time(self, start_time: float) -> str: + """ + 计算耗时 + + Args: + start_time: 开始时间 + + Returns: + 格式化后的耗时字符串,如 "0.123 s" + """ + end_time = time.time() + elapsed_time = end_time - start_time + return f"{elapsed_time:.3f} s" def async_to_sync(func):