# Nougat_deploy / app.py
# Author: NeuroDong
# Commit message: Update app.py
# Commit: cf8d232
import os
import tempfile
import gradio as gr
import pypdfium2 as pdfium
from PIL import Image
import torch
from transformers import AutoProcessor, VisionEncoderDecoderModel
# ========= Configuration =========
# Each setting can be overridden through an environment variable
# (on Hugging Face Spaces: Settings -> Variables).

# Checkpoint to load. nougat-small is CC-BY-4.0; nougat-base is CC-BY-NC-4.0.
MODEL_ID = os.environ.get("MODEL_ID", "facebook/nougat-small")

# Default rendering resolution, intended range 96-288: higher is sharper but slower.
DEFAULT_DPI = int(os.environ.get("DEFAULT_DPI", "144"))

# Cap on pages converted per request, to keep a single job from timing out.
MAX_PAGES = int(os.environ.get("MAX_PAGES", "20"))
# ========= Model loading =========
# Executed once at import time; convert_pdf reuses these module-level objects.
if torch.cuda.is_available():
    device = "cuda"  # prefer the GPU whenever one is visible
else:
    device = "cpu"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
model = model.to(device)
# ========= Helpers =========
def rasterize_pages(pdf_bytes: bytes, dpi: int = DEFAULT_DPI, page_indices=None):
    """
    Render a PDF, given as raw bytes, into a list of RGB PIL images (one per page).

    Args:
        pdf_bytes: the complete PDF file contents.
        dpi: target resolution. PDF user space has 72 units per inch, so the
            scale passed to pypdfium2's ``Page.render`` is ``dpi / 72``.
        page_indices: optional iterable of 0-based page indices to render, in
            the order given. ``None`` (the default) renders every page.

    Returns:
        A list of ``PIL.Image.Image`` objects in ``RGB`` mode.

    Raises:
        Whatever pypdfium2 raises for input it cannot open
        (``pypdfium2.PdfiumError`` in current releases).
    """
    # pypdfium2 loads documents straight from a bytes buffer. The former detour
    # through NamedTemporaryFile cost a disk write per call, and a still-open
    # NamedTemporaryFile cannot be reopened by name on Windows.
    doc = pdfium.PdfDocument(pdf_bytes)
    try:
        indices = range(len(doc)) if page_indices is None else page_indices
        scale = dpi / 72.0
        images = []
        for i in indices:
            page = doc.get_page(i)
            try:
                bitmap = page.render(scale=scale)
                try:
                    # convert() returns a new image, so it stays valid after the
                    # bitmap's buffer is released below.
                    images.append(bitmap.to_pil().convert("RGB"))
                finally:
                    bitmap.close()
            finally:
                page.close()
        return images
    finally:
        # Always release the PDFium handle, even if a page fails to render.
        doc.close()
def parse_pages_arg(pages_str: str, n_pages: int):
    """
    Parse a 1-based page selection such as ``'1-4,7'`` or ``'all'``.

    Items are comma-separated and surrounding whitespace is ignored:
      * ``all`` (any case), blank, or ``None`` -> every page
      * ``N``                                 -> page N
      * ``A-B``                               -> pages A..B inclusive
      * ``A-`` / ``-B``                       -> pages A..last / 1..B

    Pages outside ``1..n_pages`` are dropped. Empty items, such as those left by
    a trailing comma, are skipped. A reversed range (``5-3``) selects nothing.

    Args:
        pages_str: the selection text. Non-string values (e.g. an int sent by an
            API caller) are converted with ``str()``.
        n_pages: number of pages in the document.

    Returns:
        Sorted, de-duplicated list of 0-based page indices.

    Raises:
        ValueError: if an item is neither a page number nor a range.
    """
    spec = "" if pages_str is None else str(pages_str).strip()
    if not spec or spec.lower() == "all":
        return list(range(n_pages))

    keep = set()
    for span in spec.split(","):
        span = span.strip()
        if not span:
            continue  # tolerate "1,2," and "1,,3"
        try:
            if "-" in span:
                start_s, _, end_s = span.partition("-")
                # A missing endpoint defaults to the first / last page.
                start = int(start_s) if start_s.strip() else 1
                end = int(end_s) if end_s.strip() else n_pages
                keep.update(range(max(1, start) - 1, min(n_pages, end)))
            else:
                k = int(span) - 1
                if 0 <= k < n_pages:
                    keep.add(k)
        except ValueError:
            raise ValueError(f"invalid page selection item: {span!r}") from None
    return sorted(keep)
# ========= Core inference (shared by the UI and the API) =========
def _read_pdf_bytes(pdf_file):
    """
    Return the raw bytes of an uploaded PDF, or ``None`` if the object is unusable.

    Different Gradio versions pass different objects to the callback. This
    accepts a file-like object, a dict ``{'name': path, ...}``, an object with a
    string ``.name`` path attribute, or anything whose ``str()`` is a path.
    """
    if hasattr(pdf_file, "read"):  # file-like object
        return pdf_file.read()
    if isinstance(pdf_file, dict) and "name" in pdf_file:
        path = pdf_file["name"]
    elif hasattr(pdf_file, "name") and isinstance(pdf_file.name, str):
        path = pdf_file.name
    else:
        # Last resort: treat the object itself as a path. Failure means "unsupported".
        try:
            with open(str(pdf_file), "rb") as f:
                return f.read()
        except (OSError, ValueError):
            return None
    with open(path, "rb") as f:
        return f.read()


def _render_selected_pages(pdf_bytes, pages, dpi):
    """
    Open the PDF once and render only the pages selected by *pages*.

    At most ``MAX_PAGES`` pages are rendered. Unselected pages are never
    rasterized, so large PDFs no longer cost memory or time for pages that
    would be thrown away.

    Raises:
        gr.Error: if the PDF cannot be opened, or the page selection is
            malformed or empty.
    """
    try:
        doc = pdfium.PdfDocument(pdf_bytes)
    except pdfium.PdfiumError as err:
        raise gr.Error("无法打开 PDF 文件(文件可能已损坏或已加密)") from err
    try:
        try:
            idx = parse_pages_arg(pages, len(doc))
        except ValueError as err:
            raise gr.Error(f"页码格式无效:{pages}") from err
        if not idx:
            raise gr.Error("页码选择为空")
        scale = dpi / 72.0  # PDF user space has 72 units per inch
        images = []
        for k in idx[:MAX_PAGES]:
            page = doc.get_page(k)
            try:
                bitmap = page.render(scale=scale)
                try:
                    # convert() returns a copy, so the image outlives the bitmap buffer.
                    images.append(bitmap.to_pil().convert("RGB"))
                finally:
                    bitmap.close()
            finally:
                page.close()
        return images
    finally:
        doc.close()


def convert_pdf(pdf_file, pages="all", dpi=DEFAULT_DPI):
    """
    Convert an uploaded PDF into Nougat Markdown (.mmd).

    Args:
        pdf_file: the Gradio File value for the uploaded PDF (any shape
            accepted by ``_read_pdf_bytes``).
        pages: ``'all'`` or a 1-based selection such as ``'1-4,7'``. At most
            ``MAX_PAGES`` pages are converted.
        dpi: rendering resolution, truncated to int.

    Returns:
        ``(out_path, preview)``: the path of the generated ``.mmd`` file (for
        download) and a Markdown preview of the first three converted pages.

    Raises:
        gr.Error: on a missing or unreadable upload, a PDF that cannot be
            opened, or an invalid or empty page selection.
    """
    if pdf_file is None:
        raise gr.Error("请上传 PDF 文件")
    pdf_bytes = _read_pdf_bytes(pdf_file)
    if pdf_bytes is None:
        raise gr.Error("无法读取上传的 PDF 文件(不支持的文件对象类型)")

    images = _render_selected_pages(pdf_bytes, pages, int(dpi))

    # Run Nougat one page at a time and decode each page to Markdown.
    md_pages = []
    for img in images:
        inputs = processor(images=[img], return_tensors="pt").to(device)
        ids = model.generate(**inputs, max_length=4096)
        md_pages.append(processor.batch_decode(ids, skip_special_tokens=True)[0])

    # Write into a fresh per-request directory. A fixed path in the shared temp
    # dir let concurrent or back-to-back requests overwrite each other's result;
    # this keeps the user-facing file name unchanged.
    out_dir = tempfile.mkdtemp(prefix="nougat_")
    out_path = os.path.join(out_dir, "nougat_output.mmd")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n\n".join(md_pages))

    preview = "\n\n".join(md_pages[:3])
    return out_path, preview
# ========= Gradio app (UI + API) =========
# Layout: a header, a row holding the three inputs (PDF, page selection, DPI),
# then the convert button and the two outputs (download file and preview).
with gr.Blocks(title="Nougat OCR → Markdown") as demo:
    gr.Markdown(
        "# Nougat:PDF → Markdown\n"
        f"**模型**:`{MODEL_ID}` (small 为 CC‑BY‑4.0;base 为 CC‑BY‑NC‑4.0)。\n"
        "上传 PDF,选择页码与 DPI,点击转换即可下载 `.mmd`。\n"
    )
    with gr.Row():
        pdf = gr.File(label="上传 PDF", file_types=[".pdf"])
        pages = gr.Textbox(value="all", label="页码(如 1-4,7 或 all)")
        # UI-side DPI range is 96-288 in steps of 12; API callers are not bound by it.
        dpi = gr.Slider(96, 288, value=DEFAULT_DPI, step=12, label="渲染 DPI")
    btn = gr.Button("转换", variant="primary")
    out_file = gr.File(label="下载 Markdown(.mmd)")
    out_preview = gr.Markdown(label="预览(前几页)")
    # Enable the request queue so concurrent UI users and API callers wait in
    # line; max_size caps how many requests may be pending at once.
    demo.queue(max_size=32)
    # Bind the click handler with api_name="predict" so the same function is
    # exposed as a named API endpoint. NOTE(review): the REST path is
    # /api/predict on Gradio 3.x; newer Gradio versions route named endpoints
    # differently — confirm against the pinned version.
    btn.click(convert_pdf, inputs=[pdf, pages, dpi], outputs=[out_file, out_preview], api_name="predict")
# Entry point when this file is executed as a script (e.g. `python app.py`).
# NOTE(review): Hugging Face Gradio Spaces also start the app by running app.py
# directly, so this launch call is likely required there too, not only for
# local debugging — confirm before removing it.
if __name__ == "__main__":
    demo.launch()