"""Gradio app (UI + API) that converts PDF pages to Markdown (.mmd) with a Nougat model."""

import os
import tempfile

import gradio as gr
import pypdfium2 as pdfium
import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel

# ========= Configuration =========
# These values can be overridden via Space Settings -> Variables (environment variables).
MODEL_ID = os.getenv("MODEL_ID", "facebook/nougat-small")  # small: CC-BY-4.0; base: CC-BY-NC-4.0
DEFAULT_DPI = int(os.getenv("DEFAULT_DPI", "144"))  # 96-288; higher is sharper but slower
MAX_PAGES = int(os.getenv("MAX_PAGES", "20"))  # cap on pages per request, to avoid timeouts

# Bounds of the DPI slider; also enforced for API callers, who bypass the slider.
MIN_DPI = 96
MAX_DPI = 288

# ========= Model loading =========
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID).to(device)

# Never ask the decoder to generate past its positional-embedding table; the
# limit depends on the checkpoint, so read it from the config when available.
_decoder_config = getattr(model.config, "decoder", None)
MAX_GENERATION_LENGTH = min(
    4096, getattr(_decoder_config, "max_position_embeddings", None) or 4096
)

# Map full-width / typographic separators to the ASCII ones the parser expects.
_PAGE_SPEC_NORMALIZATION = str.maketrans({",": ",", "-": "-", "–": "-", "—": "-"})


# ========= Helpers =========
def rasterize_pages(pdf_bytes: bytes, dpi: int = DEFAULT_DPI, page_indices=None):
    """Render PDF pages to a list of RGB ``PIL.Image`` objects.

    Args:
        pdf_bytes: Raw content of the PDF file.
        dpi: Target resolution. PDF user space is 72 units per inch, so the
            render scale is ``dpi / 72``.
        page_indices: Optional iterable of 0-based page indices to render, in
            the order given. ``None`` (the default) renders every page.

    Returns:
        One image per requested page.
    """
    # The document is opened straight from memory: no temporary file is needed,
    # which also avoids reopening a NamedTemporaryFile by name (not portable).
    doc = pdfium.PdfDocument(pdf_bytes)
    try:
        wanted = range(len(doc)) if page_indices is None else page_indices
        scale = dpi / 72.0
        images = []
        for index in wanted:
            page = doc.get_page(index)
            try:
                bitmap = page.render(scale=scale)
                try:
                    # Convert before closing the bitmap so the image never
                    # refers to a released buffer.
                    images.append(bitmap.to_pil().convert("RGB"))
                finally:
                    bitmap.close()
            finally:
                page.close()
        return images
    finally:
        doc.close()


def parse_pages_arg(pages_str: str, n_pages: int):
    """Parse a page selection such as ``'1-4,7'`` or ``'all'``.

    Page numbers in the input are 1-based. Ranges are inclusive and clipped to
    the document; an open-ended range (``'5-'`` or ``'-3'``) extends to the
    last / first page. Out-of-range single pages and empty tokens (e.g. a
    trailing comma) are ignored.

    Args:
        pages_str: The selection string; empty, ``None`` or ``'all'`` selects
            every page.
        n_pages: Number of pages in the document.

    Returns:
        Sorted list of unique 0-based page indices.

    Raises:
        ValueError: If a token is not a number or a ``start-end`` range.
    """
    if not pages_str or pages_str.strip().lower() == "all":
        return list(range(n_pages))

    selected = set()
    for token in pages_str.translate(_PAGE_SPEC_NORMALIZATION).split(","):
        token = token.strip()
        if not token:
            continue
        if "-" in token:
            start_text, _, end_text = token.partition("-")
            start_text, end_text = start_text.strip(), end_text.strip()
            if not start_text and not end_text:
                raise ValueError(f"invalid page range: {token!r}")
            try:
                first = int(start_text) if start_text else 1
                last = int(end_text) if end_text else n_pages
            except ValueError:
                raise ValueError(f"invalid page range: {token!r}") from None
            selected.update(range(max(1, first) - 1, min(n_pages, last)))
        else:
            try:
                page_no = int(token)
            except ValueError:
                raise ValueError(f"invalid page number: {token!r}") from None
            if 1 <= page_no <= n_pages:
                selected.add(page_no - 1)
    return sorted(selected)


def _read_file_if_exists(path):
    """Return the bytes of *path*, or ``None`` if it is not a readable file."""
    if not isinstance(path, (str, os.PathLike)) or not os.path.isfile(path):
        return None
    try:
        with open(path, "rb") as f:
            return f.read()
    except OSError:
        # Unreadable path: report "no data" so the caller can try another strategy.
        return None


def _read_pdf_bytes(pdf_file):
    """Extract the uploaded PDF's bytes from whatever object Gradio supplied.

    Handles raw bytes, a dict carrying a ``name`` (or ``path``) entry, a plain
    path, an object with a ``.name`` path, and a readable file-like object.
    Returns ``None`` when none of these yields data.
    """
    if isinstance(pdf_file, (bytes, bytearray)):
        return bytes(pdf_file)

    if isinstance(pdf_file, dict):
        candidate = pdf_file.get("name") or pdf_file.get("path")
    elif isinstance(pdf_file, (str, os.PathLike)):
        candidate = pdf_file
    else:
        candidate = getattr(pdf_file, "name", None)

    # Prefer the on-disk path over .read().
    # NOTE(review): some Gradio versions appear to return a wrapper whose own
    # handle does not hold the upload while .name points at it - confirm
    # against the Gradio version in use.
    data = _read_file_if_exists(candidate)
    if data is not None:
        return data

    if hasattr(pdf_file, "read"):
        data = pdf_file.read()
        if isinstance(data, (bytes, bytearray)):
            return bytes(data)

    # Last resort: treat the object's string form as a path.
    return _read_file_if_exists(str(pdf_file))


def _count_pages(pdf_bytes):
    """Return the number of pages in the PDF given as bytes."""
    doc = pdfium.PdfDocument(pdf_bytes)
    try:
        return len(doc)
    finally:
        doc.close()


# ========= Core inference function (shared by the UI and the API) =========
def convert_pdf(pdf_file, pages="all", dpi=DEFAULT_DPI):
    """Convert the selected pages of an uploaded PDF to Markdown.

    Args:
        pdf_file: The upload as delivered by the Gradio File component.
        pages: ``'all'`` or a selection such as ``'1-4,7'`` (1-based).
        dpi: Rendering DPI; clamped to ``MIN_DPI``..``MAX_DPI``.

    Returns:
        A tuple ``(out_path, preview)``: the path of the generated ``.mmd``
        file (for download) and a Markdown preview of the first three
        converted pages.

    Raises:
        gr.Error: If no file was uploaded, the file cannot be read or parsed,
            or the page selection is invalid or empty.
    """
    if pdf_file is None:
        raise gr.Error("请上传 PDF 文件")

    pdf_bytes = _read_pdf_bytes(pdf_file)
    if not pdf_bytes:
        raise gr.Error("无法读取上传的 PDF 文件(不支持的文件对象类型)")

    try:
        n_pages = _count_pages(pdf_bytes)
    except pdfium.PdfiumError as exc:
        raise gr.Error(f"无法解析 PDF 文件:{exc}") from exc

    # Resolve and cap the selection BEFORE rendering, so a huge document only
    # costs as much as the pages that are actually converted.
    try:
        idx = parse_pages_arg(pages, n_pages)
    except ValueError as exc:
        raise gr.Error(f"页码格式无效:{exc}") from exc
    if not idx:
        raise gr.Error("页码选择为空")
    idx = idx[:MAX_PAGES]

    # API callers bypass the slider, so enforce its bounds here as well.
    render_dpi = max(MIN_DPI, min(MAX_DPI, int(dpi)))
    images = rasterize_pages(pdf_bytes, dpi=render_dpi, page_indices=idx)

    # Run the model one page at a time to keep peak memory bounded.
    md_pages = []
    for img in images:
        inputs = processor(images=[img], return_tensors="pt").to(device)
        ids = model.generate(**inputs, max_length=MAX_GENERATION_LENGTH)
        md_pages.append(processor.batch_decode(ids, skip_special_tokens=True)[0])

    # A unique file per request: a fixed name would let concurrent requests
    # overwrite each other's download.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".mmd", prefix="nougat_output_", delete=False, encoding="utf-8"
    ) as out:
        out.write("\n\n".join(md_pages))
        out_path = out.name

    preview = "\n\n".join(md_pages[:3])
    return out_path, preview


# ========= Gradio app (UI + API) =========
with gr.Blocks(title="Nougat OCR → Markdown") as demo:
    gr.Markdown(
        "# Nougat:PDF → Markdown\n"
        f"**模型**:`{MODEL_ID}` (small 为 CC‑BY‑4.0;base 为 CC‑BY‑NC‑4.0)。\n"
        "上传 PDF,选择页码与 DPI,点击转换即可下载 `.mmd`。\n"
    )
    with gr.Row():
        pdf = gr.File(label="上传 PDF", file_types=[".pdf"])
        pages = gr.Textbox(value="all", label="页码(如 1-4,7 或 all)")
        dpi = gr.Slider(MIN_DPI, MAX_DPI, value=DEFAULT_DPI, step=12, label="渲染 DPI")
    btn = gr.Button("转换", variant="primary")
    out_file = gr.File(label="下载 Markdown(.mmd)")
    out_preview = gr.Markdown(label="预览(前几页)")

    # Key point: api_name gives the click handler a named endpoint ("predict")
    # so it can be invoked programmatically as well as from the UI.
    btn.click(
        convert_pdf,
        inputs=[pdf, pages, dpi],
        outputs=[out_file, out_preview],
        api_name="predict",
    )

# Queueing avoids congestion under concurrent use and lets API calls wait in line.
demo.queue(max_size=32)

# For local debugging only; not needed when running inside Spaces.
if __name__ == "__main__":
    demo.launch()