File size: 5,698 Bytes
39f5059
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf8d232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39f5059
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf1eda9
39f5059
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import tempfile
import gradio as gr
import pypdfium2 as pdfium
from PIL import Image
import torch
from transformers import AutoProcessor, VisionEncoderDecoderModel

# ========= Configuration =========
# These environment variables can be set via Space Settings -> Variables.
MODEL_ID = os.getenv("MODEL_ID", "facebook/nougat-small")  # small: CC-BY-4.0; base: CC-BY-NC-4.0
DEFAULT_DPI = int(os.getenv("DEFAULT_DPI", "144"))         # 96~288; higher is sharper but slower
MAX_PAGES   = int(os.getenv("MAX_PAGES", "20"))            # cap on pages handled per request, to avoid timeouts

# ========= Model loading =========
# Runs once at import time; the processor and model are shared by every request.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID).to(device)

# ========= 工具函数 =========
def rasterize_pages(pdf_bytes: bytes, dpi: int = DEFAULT_DPI):
    """Render every page of a PDF into a list of RGB PIL images.

    Args:
        pdf_bytes: Raw contents of the PDF document.
        dpi: Target resolution. Each page is rendered with ``scale=dpi / 72``.

    Returns:
        A list with one ``PIL.Image`` (mode ``"RGB"``) per page, in document
        order.
    """
    from contextlib import closing

    scale = dpi / 72.0  # loop-invariant, so compute it once
    images = []
    # The document is opened straight from memory instead of via a temporary
    # file: no disk round trip, and no reliance on reopening a still-open
    # NamedTemporaryFile by name (which does not work on Windows).
    # `closing` guarantees each handle is released even if rendering fails.
    with closing(pdfium.PdfDocument(pdf_bytes)) as doc:
        for i in range(len(doc)):
            with closing(doc.get_page(i)) as page, \
                    closing(page.render(scale=scale)) as bitmap:
                # Convert while the bitmap is still open, as the original did.
                images.append(bitmap.to_pil().convert("RGB"))
    return images

def parse_pages_arg(pages_str: str, n_pages: int):
    """Parse a page-selection string into sorted, unique 0-based indices.

    The string is a comma-separated list of 1-based page numbers and
    inclusive ranges, e.g. ``'1-4,7'``. An empty/``None`` value or ``'all'``
    (case-insensitive, surrounding whitespace ignored) selects every page.

    Args:
        pages_str: The selection string.
        n_pages: Total number of pages in the document.

    Returns:
        Sorted list of distinct 0-based page indices. Pages outside
        ``1..n_pages`` are dropped; ranges are clipped to that interval.

    Raises:
        ValueError: if a token is neither an integer nor an ``a-b`` range of
            integers.
    """
    if not pages_str or pages_str.strip().lower() == "all":
        return list(range(n_pages))

    keep = set()
    for raw in pages_str.split(","):
        span = raw.strip()
        if not span:
            # Tolerate stray commas / whitespace such as '1,2,' or '1,,3'.
            continue
        first, dash, last = span.partition("-")
        try:
            start = int(first)
            stop = int(last) if dash else start
        except ValueError:
            raise ValueError(f"invalid page specification: {span!r}") from None
        if start > stop:
            # Accept reversed ranges ('4-2') instead of silently selecting nothing.
            start, stop = stop, start
        # Clip to the document and convert 1-based inclusive -> 0-based half-open.
        keep.update(range(max(1, start) - 1, min(n_pages, stop)))
    return sorted(keep)

# ========= 核心推理函数(UI 与 API 共用) =========
def _read_pdf_bytes(pdf_file):
    """Return the uploaded PDF's raw bytes, or None if the object is unsupported.

    Handles the shapes different Gradio versions may hand over: a file-like
    object, a dict of the form ``{'name': path}``, an object with a string
    ``.name`` path, or (as a last resort) anything whose ``str()`` is an
    existing filesystem path.
    """
    if hasattr(pdf_file, "read"):  # file-like object
        return pdf_file.read()

    if isinstance(pdf_file, dict) and "name" in pdf_file:
        path = pdf_file["name"]
    elif isinstance(getattr(pdf_file, "name", None), str):
        path = pdf_file.name
    else:
        # Fallback: treat the object itself as a path. This branch is
        # best-effort, so I/O failures yield None rather than propagating.
        path = str(pdf_file)
        if not os.path.exists(path):
            return None
        try:
            with open(path, "rb") as f:
                return f.read()
        except OSError:
            return None

    with open(path, "rb") as f:
        return f.read()


def convert_pdf(pdf_file, pages="all", dpi=DEFAULT_DPI):
    """Convert an uploaded PDF to Nougat Markdown (shared by the UI and the API).

    Args:
        pdf_file: The uploaded PDF as delivered by the Gradio File component.
        pages: ``'all'`` or a selection such as ``'1-4,7'`` (1-based).
        dpi: Rendering resolution passed to ``rasterize_pages``.

    Returns:
        A tuple ``(out_path, preview)``: the path of the generated ``.mmd``
        file (for download) and the Markdown of the first three converted
        pages.

    Raises:
        gr.Error: if no file was uploaded, the upload cannot be read, the
            page specification is malformed, or it selects no pages.
    """
    if pdf_file is None:
        raise gr.Error("请上传 PDF 文件")

    pdf_bytes = _read_pdf_bytes(pdf_file)
    if pdf_bytes is None:
        raise gr.Error("无法读取上传的 PDF 文件(不支持的文件对象类型)")

    # Render the PDF pages to images.
    images_all = rasterize_pages(pdf_bytes, dpi=int(dpi))

    # Page selection and limit. A malformed specification is reported to the
    # user as a Gradio error instead of surfacing as a raw ValueError.
    try:
        idx = parse_pages_arg(pages, len(images_all))
    except ValueError as exc:
        raise gr.Error(f"页码格式无效:{pages!r}") from exc
    if not idx:
        raise gr.Error("页码选择为空")
    idx = idx[:MAX_PAGES]

    # Never request more tokens than the decoder has positions for.
    # NOTE(review): presumably some checkpoints have a positional limit below
    # 4096 — confirm against the model config. Falls back to 4096 when the
    # config does not expose the value.
    decoder_cfg = getattr(model.config, "decoder", None)
    max_len = min(4096, getattr(decoder_cfg, "max_position_embeddings", None) or 4096)

    # Run the Nougat model page by page.
    md_pages = []
    for k in idx:
        inputs = processor(images=[images_all[k]], return_tensors="pt").to(device)
        ids = model.generate(**inputs, max_length=max_len)
        md_pages.append(processor.batch_decode(ids, skip_special_tokens=True)[0])

    # Write to a uniquely named temporary .mmd file. A fixed file name would
    # let concurrent requests overwrite each other's output.
    fd, out_path = tempfile.mkstemp(prefix="nougat_", suffix=".mmd")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write("\n\n".join(md_pages))

    # Preview: the first few pages only.
    preview = "\n\n".join(md_pages[:3])
    return out_path, preview

# ========= Gradio app (UI + API) =========
# Builds the page: a header, one row of inputs (file, page selection, DPI),
# a convert button, and the two outputs produced by convert_pdf.
with gr.Blocks(title="Nougat OCR → Markdown") as demo:
    gr.Markdown(
        "# Nougat:PDF → Markdown\n"
        f"**模型**:`{MODEL_ID}` (small 为 CC‑BY‑4.0;base 为 CC‑BY‑NC‑4.0)。\n"
        "上传 PDF,选择页码与 DPI,点击转换即可下载 `.mmd`。\n"
    )
    with gr.Row():
        pdf = gr.File(label="上传 PDF", file_types=[".pdf"])
        pages = gr.Textbox(value="all", label="页码(如 1-4,7 或 all)")
        dpi = gr.Slider(96, 288, value=DEFAULT_DPI, step=12, label="渲染 DPI")
    btn = gr.Button("转换", variant="primary")
    out_file = gr.File(label="下载 Markdown(.mmd)")
    out_preview = gr.Markdown(label="预览(前几页)")

    # Queueing avoids congestion under concurrent requests; it also lets API
    # calls be queued asynchronously. At most 32 requests wait in the queue.
    demo.queue(max_size=32)

    # Key step: bind the click event with an api_name so it can be called over
    # REST as well. The original note maps this to /api/predict.
    # NOTE(review): the exact route depends on the Gradio version — confirm.
    btn.click(convert_pdf, inputs=[pdf, pages, dpi], outputs=[out_file, out_preview], api_name="predict")

# (For local debugging; not needed when running in Spaces.)
if __name__ == "__main__":
    demo.launch()