Spaces:
Sleeping
Sleeping
File size: 5,698 Bytes
39f5059 cf8d232 39f5059 cf1eda9 39f5059 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import os
import tempfile
import gradio as gr
import pypdfium2 as pdfium
from PIL import Image
import torch
from transformers import AutoProcessor, VisionEncoderDecoderModel
# ========= Configuration =========
# Every setting below can be overridden through environment variables
# (for example via Space Settings -> Variables).

# Hugging Face model id. small: CC-BY-4.0; base: CC-BY-NC-4.0.
MODEL_ID = os.environ.get("MODEL_ID", "facebook/nougat-small")

# Default rendering resolution (the UI slider spans 96-288);
# higher values are sharper but slower.
DEFAULT_DPI = int(os.environ.get("DEFAULT_DPI", "144"))

# Upper bound on pages converted per request, to avoid timeouts.
MAX_PAGES = int(os.environ.get("MAX_PAGES", "20"))
# ========= Model loading =========
# Loaded once at import time; the conversion callback reuses these globals.
_has_cuda = torch.cuda.is_available()
device = "cuda" if _has_cuda else "cpu"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
# nn.Module.to moves the weights and returns the same module object.
model = model.to(device)
# ========= Helpers =========
def rasterize_pages(pdf_bytes: bytes, dpi: int = DEFAULT_DPI):
    """Render every page of a PDF to an RGB ``PIL.Image`` (one image per page).

    Args:
        pdf_bytes: Raw contents of the PDF file.
        dpi: Target resolution. PDF user space is 72 units per inch, so the
            scale passed to pypdfium2's ``PdfPage.render`` is ``dpi / 72``.

    Returns:
        A list of RGB images, in page order.

    Raises:
        ValueError: if ``dpi`` is not positive.
        Errors from pypdfium2 (presumably ``pypdfium2.PdfiumError``) when the
        bytes cannot be opened or rendered as a PDF.
    """
    # Local import keeps this block self-contained; closing() guarantees the
    # PDFium handles are released even when rendering raises midway.
    from contextlib import closing

    if dpi <= 0:
        raise ValueError(f"dpi must be positive, got {dpi!r}")

    scale = dpi / 72.0
    images = []
    # pypdfium2 opens a PDF straight from bytes, so no temporary file is needed
    # (reopening a NamedTemporaryFile by name while it is open fails on Windows).
    with closing(pdfium.PdfDocument(pdf_bytes)) as doc:
        for i in range(len(doc)):
            with closing(doc.get_page(i)) as page, closing(page.render(scale=scale)) as bitmap:
                # convert() returns an independent copy, so the image stays valid
                # after the bitmap buffer is closed at the end of this block.
                images.append(bitmap.to_pil().convert("RGB"))
    return images
def parse_pages_arg(pages_str: str, n_pages: int):
    """Parse a 1-based page selection such as ``'1-4,7'`` or ``'all'``.

    Syntax:
        * ``None``, an empty/whitespace-only string, or ``'all'``
          (case-insensitive) selects every page.
        * Comma-separated spans; the full-width comma ``，`` is also accepted.
          Empty spans (``'1,,3'``, trailing commas) are ignored.
        * ``'a-b'`` is an inclusive range clamped to ``[1, n_pages]``; either
          end may be omitted (``'5-'`` = page 5 to the end, ``'-3'`` = pages
          1-3). A range whose start exceeds its end selects nothing.
        * A single number selects that page; out-of-range numbers are ignored.

    Args:
        pages_str: The selection string.
        n_pages: Total number of pages in the document.

    Returns:
        Sorted, de-duplicated list of 0-based page indices.

    Raises:
        ValueError: if a span is not a number or a single ``a-b`` range.
    """
    normalized = pages_str.strip().lower() if pages_str else ""
    if normalized in ("", "all"):
        return list(range(n_pages))

    keep = set()
    for span in pages_str.replace("，", ",").split(","):
        span = span.strip()
        if not span:
            continue  # tolerate "1,,3" and trailing commas
        if "-" in span:
            start_s, _, end_s = span.partition("-")
            if "-" in end_s:
                raise ValueError(f"invalid page range: {span!r}")
            start = int(start_s) if start_s.strip() else 1
            end = int(end_s) if end_s.strip() else n_pages
            start = max(1, start)
            end = min(n_pages, end)
            keep.update(range(start - 1, end))
        else:
            k = int(span) - 1
            if 0 <= k < n_pages:
                keep.add(k)
    return sorted(keep)
# ========= Core inference (shared by the UI and the API) =========
def _read_pdf_bytes(pdf_file):
    """Return the raw bytes of an uploaded PDF, or ``None`` if unresolvable.

    Accepts the shapes different Gradio versions hand to callbacks: a
    file-like object, a dict with a ``"name"`` path, an object whose ``.name``
    is a path string, or a plain path. Raises ``OSError`` if a resolved path
    cannot be read.
    """
    if hasattr(pdf_file, "read"):  # file-like object
        return pdf_file.read()
    if isinstance(pdf_file, dict) and "name" in pdf_file:  # {'name': path}
        path = pdf_file["name"]
    elif hasattr(pdf_file, "name") and isinstance(pdf_file.name, str):  # object with .name path
        path = pdf_file.name
    else:
        # Last resort: treat the object itself as a filesystem path.
        path = str(pdf_file)
        if not os.path.exists(path):
            return None
    with open(path, "rb") as f:
        return f.read()


def convert_pdf(pdf_file, pages="all", dpi=DEFAULT_DPI):
    """Convert selected pages of an uploaded PDF to Nougat Markdown (``.mmd``).

    Args:
        pdf_file: Value of a Gradio ``File`` component (see ``_read_pdf_bytes``
            for the accepted shapes).
        pages: Page selection such as ``'all'`` or ``'1-4,7'`` (1-based), parsed
            by ``parse_pages_arg``. At most ``MAX_PAGES`` pages are converted;
            any further selected pages are dropped.
        dpi: Rendering resolution; ``None`` falls back to ``DEFAULT_DPI``.

    Returns:
        ``(out_path, preview)``: path of the generated ``.mmd`` file (in a fresh
        directory per call, for download) and a Markdown preview of the first
        three converted pages.

    Raises:
        gr.Error: for a missing, unreadable, empty or unparsable upload, and for
            an invalid or empty page selection.
    """
    if pdf_file is None:
        raise gr.Error("请上传 PDF 文件")

    try:
        pdf_bytes = _read_pdf_bytes(pdf_file)
    except OSError as err:
        raise gr.Error("无法读取上传的 PDF 文件(不支持的文件对象类型)") from err
    if pdf_bytes is None:
        raise gr.Error("无法读取上传的 PDF 文件(不支持的文件对象类型)")
    if not pdf_bytes:
        raise gr.Error("上传的 PDF 文件为空")

    # API callers may omit dpi; the UI slider always supplies a number.
    dpi = DEFAULT_DPI if dpi is None else int(dpi)

    # NOTE(review): this renders every page even when only a subset is
    # selected; rendering just the selected pages would save time and memory
    # on large PDFs.
    try:
        images_all = rasterize_pages(pdf_bytes, dpi=dpi)
    except pdfium.PdfiumError as err:
        raise gr.Error("无法解析 PDF 文件(文件可能已损坏或已加密)") from err

    # Page selection and per-request cap.
    try:
        idx = parse_pages_arg(pages, len(images_all))
    except ValueError as err:
        raise gr.Error(f"页码格式无效:{pages}(示例:1-4,7 或 all)") from err
    if not idx:
        raise gr.Error("页码选择为空")
    idx = idx[:MAX_PAGES]

    # Run Nougat one page at a time and decode each page to Markdown.
    md_pages = []
    for k in idx:
        inputs = processor(images=[images_all[k]], return_tensors="pt").to(device)
        ids = model.generate(**inputs, max_length=4096)
        md_pages.append(processor.batch_decode(ids, skip_special_tokens=True)[0])

    # Write into a fresh per-request directory: a single shared path in the
    # temp dir would let a later request overwrite a file another user is
    # still downloading. The download keeps the same file name.
    out_dir = tempfile.mkdtemp(prefix="nougat_")
    out_path = os.path.join(out_dir, "nougat_output.mmd")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n\n".join(md_pages))

    preview = "\n\n".join(md_pages[:3])
    return out_path, preview
# ========= Gradio app (UI + API) =========
# Components are created inside the Blocks context; creation order sets the layout.
# NOTE(review): the original indentation was lost; the inputs are reconstructed
# as one Row with the button and outputs below it — confirm against the
# intended layout.
with gr.Blocks(title="Nougat OCR → Markdown") as demo:
    gr.Markdown(
        "# Nougat:PDF → Markdown\n"
        f"**模型**:`{MODEL_ID}` (small 为 CC‑BY‑4.0;base 为 CC‑BY‑NC‑4.0)。\n"
        "上传 PDF,选择页码与 DPI,点击转换即可下载 `.mmd`。\n"
    )
    with gr.Row():
        # Inputs, in the same order as convert_pdf's parameters.
        pdf = gr.File(label="上传 PDF", file_types=[".pdf"])
        pages = gr.Textbox(value="all", label="页码(如 1-4,7 或 all)")
        dpi = gr.Slider(96, 288, value=DEFAULT_DPI, step=12, label="渲染 DPI")
    btn = gr.Button("转换", variant="primary")
    # Outputs: the downloadable .mmd file and a Markdown preview.
    out_file = gr.File(label="下载 Markdown(.mmd)")
    out_preview = gr.Markdown(label="预览(前几页)")
    # Queue requests (at most 32 waiting) so concurrent users and API calls
    # line up instead of congesting the model.
    demo.queue(max_size=32)
    # api_name="predict" exposes this handler as a named API endpoint
    # (served at /api/predict in Gradio 3.x). NOTE(review): newer Gradio
    # versions route it differently (e.g. gradio_client's "/predict") — confirm
    # for the installed version.
    btn.click(convert_pdf, inputs=[pdf, pages, dpi], outputs=[out_file, out_preview], api_name="predict")
# Script entry point. Gradio Spaces execute the app file as a script, so this
# guard runs there too; locally, running the file starts the same server.
if __name__ == "__main__":
    demo.launch()