Spaces:
Running
Running
import logging
import os
import subprocess
import sys
from pathlib import Path
from typing import List, Optional

# Make sure modelscope is installed/upgraded *before* it is imported below.
# ``sys.executable -m pip`` installs into the interpreter that runs this app
# (a bare ``pip`` on PATH may belong to a different Python), and no shell is
# involved. check=False keeps this best-effort, like the original
# ``os.system`` call whose exit status was ignored.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "modelscope", "-U"],
    check=False,
)

import gradio as gr
from huggingface_hub import HfApi
from modelscope.hub.api import HubApi

# Log INFO and above to stderr with a timestamp and level prefix.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
class HFToMSConverter:
    """Copy the files of a HuggingFace repository into a ModelScope repository.

    Each file is downloaded into ``local_dir`` (with ``cache_dir`` as the
    HuggingFace download cache) and moved into a temporary staging folder.
    The staging folder is then pushed to ModelScope with a single
    ``upload_folder`` call and removed afterwards.
    """

    def __init__(self, config: dict):
        """Create both hub clients and the local working directories.

        Args:
            config: Must contain ``hf_token`` and ``ms_token``. Optional keys:
                ``cache_dir`` (default ``"hf2ms_cache"``), ``local_dir``
                (default ``"hf2ms_local"``) and ``repo_type`` (default
                ``"dataset"``), the repository type used by methods that are
                not given one explicitly.

        Raises:
            KeyError: If ``hf_token`` or ``ms_token`` is missing.
        """
        self.config = config
        self.cache_dir = config.get('cache_dir', "hf2ms_cache")
        self.local_dir = config.get('local_dir', "hf2ms_local")
        self.repo_type = config.get('repo_type') or "dataset"
        self.hf_api = HfApi(token=config['hf_token'])
        self.ms_api = HubApi()
        self.ms_api.login(config['ms_token'])
        for dir_path in (self.local_dir, self.cache_dir):
            # parents=True so nested paths supplied via config also work.
            Path(dir_path).mkdir(parents=True, exist_ok=True)

    def get_hf_files(self, repo_id: str, repo_type: str = "dataset") -> List[str]:
        """Return the list of file paths in a HuggingFace repository."""
        return self.hf_api.list_repo_files(repo_id=repo_id, repo_type=repo_type)

    def download_file(self, repo_id: str, filename: str,
                      repo_type: Optional[str] = None) -> Optional[str]:
        """Download one file from HuggingFace into ``local_dir``.

        Args:
            repo_id: HuggingFace repository id.
            filename: Path of the file inside the repository.
            repo_type: HuggingFace repository type; falls back to
                ``self.repo_type`` when not given.

        Returns:
            The local path of the file, or ``None`` if the download failed.
            A file already present at that path is reused, not re-downloaded.
        """
        repo_type = repo_type or self.repo_type
        save_path = Path(self.local_dir) / filename
        if save_path.exists():
            # Reuse a leftover file (e.g. from an interrupted run). Returning
            # None here would make every retry of the migration abort.
            logger.warning(f"文件已存在: {filename}")
            return str(save_path)
        try:
            self.hf_api.hf_hub_download(
                repo_id=repo_id,
                repo_type=repo_type,
                filename=filename,
                cache_dir=self.cache_dir,
                local_dir=self.local_dir,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            logger.error(f"下载失败 {filename}: {e}")
            return None
        logger.info(f"成功下载文件: {filename}")
        return str(save_path)

    def handle_file_operation(self, operation_type: str, *args) -> bool:
        """Run a file operation, turning any exception into ``False``.

        Supported operations:
            ``"move"``: ``args = (src: Path, dst: Path)``. Moves ``src`` to
                ``dst``, creating ``dst``'s parent directories first.
            ``"push"``: ``args = (ms_repo_id, folder[, repo_type])``. Uploads
                ``folder`` to the ModelScope repository; ``repo_type`` falls
                back to ``self.repo_type``.

        Returns:
            ``True`` on success; ``False`` if the operation raised or
            ``operation_type`` is not recognised.
        """
        import shutil

        try:
            if operation_type == "move":
                src, dst = args
                dst.parent.mkdir(parents=True, exist_ok=True)
                # shutil.move copies+deletes when src and dst live on
                # different filesystems, where Path.rename would fail.
                shutil.move(str(src), str(dst))
                logger.info(f"移动文件成功: {src.name}")
            elif operation_type == "push":
                ms_repo_id, clone_dir, *rest = args
                repo_type = (rest[0] if rest else None) or self.repo_type
                logger.info(f"开始推送文件夹: {clone_dir}")
                self.ms_api.upload_folder(
                    repo_id=f"{ms_repo_id}",
                    folder_path=str(clone_dir),  # the API expects a str path
                    commit_message='upload dataset folder',
                    repo_type=repo_type,
                )
                logger.info(f"推送文件夹成功: {clone_dir}")
            else:
                logger.error(f"不支持的操作类型: {operation_type}")
                return False
            return True
        except Exception as e:
            logger.error(f"{operation_type}操作失败: {e}")
            return False

    def process_files(self, hf_repo: str, ms_repo: str, files: List[str],
                      repo_type: Optional[str] = None) -> bool:
        """Download, stage and push ``files`` from ``hf_repo`` to ``ms_repo``.

        Files are staged in a fresh temporary directory under the current
        working directory. That directory is always removed afterwards.

        Args:
            hf_repo: Source HuggingFace repository id.
            ms_repo: Target ModelScope repository id (``owner/name``).
            files: Paths (inside ``hf_repo``) of the files to migrate.
            repo_type: Repository type used on both hubs; falls back to
                ``self.repo_type``.

        Returns:
            ``True`` if every file was staged and the push succeeded.
        """
        import shutil
        import tempfile

        repo_type = repo_type or self.repo_type
        if not ms_repo or not ms_repo.split("/")[-1]:
            logger.error(f"无效的ModelScope仓库: {ms_repo!r}")
            return False
        # A unique staging directory. Deriving it from the repo name (as
        # ``cwd / name``) resolved to the cwd itself for names like "owner/",
        # which the cleanup below would then delete.
        clone_dir = Path(tempfile.mkdtemp(prefix="hf2ms_upload_",
                                          dir=os.path.abspath('.')))
        try:
            for filename in files:
                # Short-circuit: never try to move a file whose download failed.
                if not (self.download_file(hf_repo, filename, repo_type)
                        and self.move_file(filename, str(clone_dir))):
                    return False
            # Push the whole staging folder in one upload.
            return self.push_to_ms(ms_repo, clone_dir, repo_type)
        except Exception as e:
            logger.error(f"处理文件失败: {e}")
            return False
        finally:
            shutil.rmtree(clone_dir, ignore_errors=True)

    def move_file(self, filename: str, clone_dir: str) -> bool:
        """Move ``filename`` from ``local_dir`` to the same relative path in ``clone_dir``."""
        return self.handle_file_operation(
            "move",
            Path(self.local_dir) / filename,
            Path(clone_dir) / filename,
        )

    def push_to_ms(self, ms_repo_id: str, clone_dir: str,
                   repo_type: Optional[str] = None) -> bool:
        """Upload ``clone_dir`` to the ModelScope repository ``ms_repo_id``."""
        return self.handle_file_operation(
            "push", ms_repo_id, clone_dir, repo_type or self.repo_type
        )
def create_ui() -> gr.Blocks:
    """Build the Gradio interface for the HuggingFace -> ModelScope migration.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # HuggingFace to ModelScope 数据迁移工具
            请确保您拥有相应仓库的权限。
            """
        )
        with gr.Row():
            # type="password" keeps the access tokens masked in the browser.
            hf_token = gr.Textbox(label="HuggingFace Token", type="password")
            ms_token = gr.Textbox(label="ModelScope Token", type="password")
        with gr.Row():
            repo_type = gr.Textbox(label="仓库类型", value="dataset")
            hf_repo = gr.Textbox(label="HuggingFace仓库")
            ms_repo = gr.Textbox(label="ModelScope仓库")
        with gr.Row():
            submit = gr.Button("开始迁移", variant="primary")
            clear = gr.Button("清除")
        status = gr.Textbox(label="状态", interactive=False)

        def handle_submit(hf_token, ms_token, repo_type, hf_repo, ms_repo):
            """Run one migration and return a status message for the UI."""
            hf_token = (hf_token or "").strip()
            ms_token = (ms_token or "").strip()
            repo_type = (repo_type or "").strip() or "dataset"
            hf_repo = (hf_repo or "").strip()
            ms_repo = (ms_repo or "").strip()
            if not all([hf_token, ms_token, hf_repo, ms_repo]):
                return "请填写 Token 以及 HuggingFace / ModelScope 仓库。"
            config = {
                'hf_token': hf_token,
                'ms_token': ms_token,
                # Also forward the selected repo type; extra keys are harmless.
                'repo_type': repo_type,
            }
            try:
                converter = HFToMSConverter(config)
                files = converter.get_hf_files(hf_repo, repo_type)
                ok = converter.process_files(hf_repo, ms_repo, files)
            except Exception as e:
                # Report failures in the UI instead of an opaque Gradio error.
                logger.exception("迁移失败")
                return f"迁移失败: {e}"
            if ok:
                return f"迁移成功: {len(files)} 个文件已推送到 {ms_repo}"
            return "迁移失败，请查看日志。"

        submit.click(
            handle_submit,
            inputs=[hf_token, ms_token, repo_type, hf_repo, ms_repo],
            outputs=status,
        )
        # Reset every field; the repo type goes back to its default value.
        clear.click(
            lambda: ("", "", "dataset", "", "", ""),
            outputs=[hf_token, ms_token, repo_type, hf_repo, ms_repo, status],
        )
    return demo
if __name__ == "__main__":
    # Build the interface, then serve it with a one-slot request queue and a
    # single worker thread (presumably so that only one migration touches the
    # shared local/cache directories at a time).
    demo = create_ui()
    demo.queue(max_size=1)
    launch_options = {"share": False, "max_threads": 1}
    demo.launch(**launch_options)