Spaces:
Sleeping
Sleeping
| import os | |
| import tempfile | |
| import gradio as gr | |
| import pypdfium2 as pdfium | |
| from PIL import Image | |
| import torch | |
| from transformers import AutoProcessor, VisionEncoderDecoderModel | |
# ========= Configuration =========
# Every setting below can be overridden through environment variables
# (on Hugging Face Spaces: Settings -> Variables).


def _env_int(name, default):
    """Return environment variable *name* parsed as an int, or *default* when unset."""
    return int(os.environ.get(name, default))


# Model checkpoint. Licences differ: nougat-small is CC-BY-4.0,
# nougat-base is CC-BY-NC-4.0.
MODEL_ID = os.environ.get("MODEL_ID", "facebook/nougat-small")

# Default render resolution (the UI offers 96-288). Higher DPI gives sharper
# page images at the cost of slower rendering and inference.
DEFAULT_DPI = _env_int("DEFAULT_DPI", "144")

# Cap on how many pages a single request converts, to avoid timeouts.
MAX_PAGES = _env_int("MAX_PAGES", "20")
# ========= Model loading =========
# The processor and model are loaded once, at import time, and reused by
# every request. Run on the GPU when PyTorch reports one, else on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
model = model.to(device)
# ========= Helpers =========
def rasterize_pages(pdf_bytes: bytes, dpi: int = DEFAULT_DPI):
    """Render every page of a PDF to an RGB PIL image (one image per page).

    Args:
        pdf_bytes: Raw bytes of the PDF document.
        dpi: Target resolution. PDF user space has 72 units per inch, so the
            pypdfium2 render scale is ``dpi / 72``.

    Returns:
        A list of RGB ``PIL.Image.Image`` objects, in page order.

    Raises:
        ValueError: If ``dpi`` is not positive.
    """
    if dpi <= 0:
        raise ValueError(f"dpi must be positive, got {dpi!r}")
    scale = dpi / 72.0
    # pypdfium2 accepts the bytes directly, so no temporary file is needed.
    # (Reopening a still-open NamedTemporaryFile by name fails on Windows.)
    doc = pdfium.PdfDocument(pdf_bytes)
    try:
        images = []
        for i in range(len(doc)):
            page = doc.get_page(i)
            bitmap = page.render(scale=scale)
            # convert() copies the pixels, so the bitmap can be released afterwards.
            images.append(bitmap.to_pil().convert("RGB"))
            bitmap.close()
            page.close()
        return images
    finally:
        # Always release the native document, even if rendering a page fails.
        doc.close()
def parse_pages_arg(pages_str: str, n_pages: int):
    """Parse a 1-based page selection such as ``"1-4,7"`` into 0-based indices.

    Accepted items (comma-separated; surrounding whitespace ignored):
      * ``"all"`` (any case), ``""`` or ``None`` -> every page
      * ``"7"``   -> page 7
      * ``"2-5"`` -> pages 2 through 5 inclusive
      * ``"3-"``  -> page 3 through the last page
      * ``"-4"``  -> page 1 through page 4
    Empty items (e.g. from a trailing comma) are skipped. Single pages outside
    ``1..n_pages`` are dropped, and ranges are clipped to that interval.

    Args:
        pages_str: The selection string entered by the user.
        n_pages: Number of pages in the document.

    Returns:
        Sorted, de-duplicated list of 0-based page indices (possibly empty).

    Raises:
        ValueError: If an item is neither an integer nor an integer range.
    """
    if not pages_str or pages_str.strip().lower() == "all":
        return list(range(n_pages))

    keep = set()
    for span in pages_str.split(","):
        span = span.strip()
        if not span:
            continue  # tolerate "1,2," and "1,,3"
        head, dash, tail = span.partition("-")
        head, tail = head.strip(), tail.strip()
        try:
            if not dash:
                k = int(span) - 1
                if 0 <= k < n_pages:
                    keep.add(k)
                continue
            if not head and not tail:
                raise ValueError(span)  # a lone "-" selects nothing meaningful
            start = int(head) if head else 1
            end = int(tail) if tail else n_pages
        except ValueError:
            raise ValueError(f"invalid page selection item: {span!r}") from None
        # Clip to 1..n_pages, then convert to a 0-based half-open range.
        keep.update(range(max(1, start) - 1, min(n_pages, end)))
    return sorted(keep)
# ========= Core conversion (shared by the UI and the API) =========
def _read_pdf_bytes(pdf_file):
    """Return the raw bytes of an uploaded PDF, or ``None`` if they cannot be obtained.

    Different Gradio versions hand uploads over in different shapes: a path
    (``str`` / ``os.PathLike``), a dict with a ``"name"`` (or ``"path"``) entry,
    a temp-file wrapper whose ``.name`` is a path, raw ``bytes``, or a generic
    file-like object. A real path is preferred over ``.read()`` because a
    wrapper's read position is not guaranteed to be at the start of the file.
    """
    if isinstance(pdf_file, (bytes, bytearray)):
        return bytes(pdf_file)
    if isinstance(pdf_file, dict):
        candidates = (pdf_file.get("name"), pdf_file.get("path"))
    else:
        # The object itself first (str / PathLike paths), then its .name attribute.
        candidates = (pdf_file, getattr(pdf_file, "name", None))
    for candidate in candidates:
        if isinstance(candidate, (str, os.PathLike)) and os.path.isfile(candidate):
            with open(candidate, "rb") as f:
                return f.read()
    read = getattr(pdf_file, "read", None)
    if callable(read):
        data = read()
        return data if isinstance(data, bytes) else None
    return None


def _render_page(doc, index, scale):
    """Render page *index* (0-based) of an open pypdfium2 document to an RGB PIL image."""
    page = doc.get_page(index)
    try:
        bitmap = page.render(scale=scale)
        try:
            # convert() copies the pixels, so the bitmap can be released afterwards.
            return bitmap.to_pil().convert("RGB")
        finally:
            bitmap.close()
    finally:
        page.close()


def _image_to_markdown(img):
    """Run the Nougat model on one page image and return the decoded text."""
    inputs = processor(images=[img], return_tensors="pt").to(device)
    ids = model.generate(**inputs, max_length=4096)
    return processor.batch_decode(ids, skip_special_tokens=True)[0]


def convert_pdf(pdf_file, pages="all", dpi=DEFAULT_DPI):
    """Convert selected pages of an uploaded PDF to Nougat markup.

    Args:
        pdf_file: The Gradio File value (any of the shapes handled by
            ``_read_pdf_bytes``).
        pages: ``"all"`` or a selection such as ``"1-4,7"`` (1-based).
        dpi: Render resolution; a falsy value falls back to ``DEFAULT_DPI``.

    Returns:
        ``(out_path, preview)``. ``out_path`` is a freshly created ``.mmd`` file
        for download. ``preview`` is the text of the first three converted pages.
        Only the first ``MAX_PAGES`` selected pages are converted.

    Raises:
        gr.Error: No file, an unreadable, empty or unparsable PDF, an invalid
            or empty page selection.
    """
    if pdf_file is None:
        raise gr.Error("请上传 PDF 文件")
    pdf_bytes = _read_pdf_bytes(pdf_file)
    if pdf_bytes is None:
        raise gr.Error("无法读取上传的 PDF 文件(不支持的文件对象类型)")
    if not pdf_bytes:
        raise gr.Error("上传的 PDF 文件为空")

    scale = int(dpi or DEFAULT_DPI) / 72.0  # PDF user space is 72 units per inch
    try:
        doc = pdfium.PdfDocument(pdf_bytes)
    except pdfium.PdfiumError as err:
        raise gr.Error("无法解析 PDF 文件(可能已损坏或受密码保护)") from err

    try:
        try:
            idx = parse_pages_arg(pages, len(doc))
        except ValueError as err:
            raise gr.Error(f"页码格式无效:{pages}(示例:1-4,7 或 all)") from err
        if not idx:
            raise gr.Error("页码选择为空")
        idx = idx[:MAX_PAGES]
        # Render only the selected pages, one at a time, and OCR each immediately.
        # Rendering the whole PDF first would defeat the MAX_PAGES timeout guard
        # and keep every page image in memory at once.
        md_pages = [_image_to_markdown(_render_page(doc, k, scale)) for k in idx]
    finally:
        doc.close()

    # A unique file per request, so successive or concurrent jobs never
    # overwrite each other's download. delete=False keeps it for Gradio to serve.
    with tempfile.NamedTemporaryFile(
        "w", encoding="utf-8", prefix="nougat_output_", suffix=".mmd", delete=False
    ) as f:
        f.write("\n\n".join(md_pages))
        out_path = f.name

    preview = "\n\n".join(md_pages[:3])
    return out_path, preview
# ========= Gradio app (UI + API) =========
# Header text shown at the top of the page.
_INTRO_MD = (
    "# Nougat:PDF → Markdown\n"
    f"**模型**:`{MODEL_ID}` (small 为 CC‑BY‑4.0;base 为 CC‑BY‑NC‑4.0)。\n"
    "上传 PDF,选择页码与 DPI,点击转换即可下载 `.mmd`。\n"
)

with gr.Blocks(title="Nougat OCR → Markdown") as demo:
    gr.Markdown(_INTRO_MD)

    # Inputs: the PDF, the page selection and the render resolution.
    with gr.Row():
        pdf = gr.File(label="上传 PDF", file_types=[".pdf"])
        pages = gr.Textbox(label="页码(如 1-4,7 或 all)", value="all")
        dpi = gr.Slider(
            minimum=96,
            maximum=288,
            step=12,
            value=DEFAULT_DPI,
            label="渲染 DPI",
        )

    btn = gr.Button("转换", variant="primary")

    # Outputs: the downloadable .mmd file and a rendered preview.
    out_file = gr.File(label="下载 Markdown(.mmd)")
    out_preview = gr.Markdown(label="预览(前几页)")

    # Queue incoming jobs so concurrent users wait their turn; max_size caps
    # how many jobs may be waiting at once.
    demo.queue(max_size=32)

    # api_name="predict" also exposes this event as a named API endpoint. The
    # exact REST route depends on the Gradio version (e.g. /api/predict on 3.x).
    btn.click(
        fn=convert_pdf,
        inputs=[pdf, pages, dpi],
        outputs=[out_file, out_preview],
        api_name="predict",
    )

# Start the server when the file is executed as a script.
# NOTE(review): the original comment said this is only needed locally, but
# Gradio Spaces typically start the app by running this file as a script;
# confirm before removing the call.
if __name__ == "__main__":
    demo.launch()