Commit 5840a14 (verified) · Parent: e9c84ff

Update app.py
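Replaces the SHAMI-MT bidirectional MSA ↔ Syrian Arabic translation demo with a Qwen2-VL based Qari OCR invoice-extraction demo targeting Hugging Face ZeroGPU.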

Files changed (1): app.py (+99 −128)
app.py CHANGED
@@ -1,136 +1,107 @@
  import gradio as gr
  import torch
- import os
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
- import spaces
-
- # Load Hugging Face token from the environment variable
- HF_TOKEN = os.getenv("HF_TOKEN")
- if HF_TOKEN is None:
-     raise ValueError("HF_TOKEN environment variable is not set. Please set it before running the script.")
-
- # Check for GPU support and configure appropriately
  device = "cuda" if torch.cuda.is_available() else "cpu"
- zero = torch.Tensor([0]).to(device)
- print(f"Device being used: {zero.device}")
-
- # Model configurations
- MSA_TO_SYRIAN_MODEL = "Omartificial-Intelligence-Space/Shami-MT"
- SYRIAN_TO_MSA_MODEL = "Omartificial-Intelligence-Space/SHAMI-MT-2MSA"
-
- # Load models and tokenizers
- print("Loading MSA to Syrian model...")
- msa_to_syrian_tokenizer = AutoTokenizer.from_pretrained(MSA_TO_SYRIAN_MODEL)
- msa_to_syrian_model = AutoModelForSeq2SeqLM.from_pretrained(MSA_TO_SYRIAN_MODEL).to(device)
-
- print("Loading Syrian to MSA model...")
- syrian_to_msa_tokenizer = AutoTokenizer.from_pretrained(SYRIAN_TO_MSA_MODEL)
- syrian_to_msa_model = AutoModelForSeq2SeqLM.from_pretrained(SYRIAN_TO_MSA_MODEL).to(device)
-
- print("Models loaded successfully!")
-
- @spaces.GPU(duration=120)
- def translate_msa_to_syrian(text):
-     """Translate from Modern Standard Arabic to Syrian dialect"""
-     if not text.strip():
-         return ""
-
-     try:
-         input_ids = msa_to_syrian_tokenizer(text, return_tensors="pt").input_ids.to(device)
-         outputs = msa_to_syrian_model.generate(input_ids, max_length=128, num_beams=5, early_stopping=True)
-         translated_text = msa_to_syrian_tokenizer.decode(outputs[0], skip_special_tokens=True)
-         return translated_text
-     except Exception as e:
-         return f"Translation error: {str(e)}"
-
- @spaces.GPU(duration=120)
- def translate_syrian_to_msa(text):
-     """Translate from Syrian dialect to Modern Standard Arabic"""
-     if not text.strip():
-         return ""
-
-     try:
-         input_ids = syrian_to_msa_tokenizer(text, return_tensors="pt").input_ids.to(device)
-         outputs = syrian_to_msa_model.generate(input_ids, max_length=128, num_beams=5, early_stopping=True)
-         translated_text = syrian_to_msa_tokenizer.decode(outputs[0], skip_special_tokens=True)
-         return translated_text
-     except Exception as e:
-         return f"Translation error: {str(e)}"
-
- def bidirectional_translate(text, direction):
-     """Handle bidirectional translation based on user selection"""
-     if direction == "MSA → Syrian":
-         return translate_msa_to_syrian(text)
-     elif direction == "Syrian → MSA":
-         return translate_syrian_to_msa(text)
-     else:
-         return "Please select a translation direction"
-
- # Create Gradio interface
- with gr.Blocks(title="SHAMI-MT: Bidirectional Syrian Arabic Dialect MT Framework") as demo:
-
-     gr.HTML("""
-     <div style="text-align: center; margin-bottom: 2rem;">
-         <h1>🌍 SHAMI-MT: Bidirectional Arabic Translation</h1>
-         <p>Translate between Modern Standard Arabic (MSA) and Syrian Dialect</p>
-         <p><strong>Built on AraT5v2-base-1024 architecture</strong></p>
-     </div>
-     """)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             gr.HTML("""
-             <div style="background: #f8f9fa; padding: 1rem; border-radius: 8px; margin: 1rem 0;">
-                 <h3>📚 Model Information</h3>
-                 <ul>
-                     <li><strong>Model Type:</strong> Sequence-to-Sequence Translation</li>
-                     <li><strong>Base Model:</strong> UBC-NLP/AraT5v2-base-1024</li>
-                     <li><strong>Languages:</strong> Arabic (MSA ↔ Syrian Dialect)</li>
-                     <li><strong>Device:</strong> GPU/CPU Auto-detection</li>
-                 </ul>
-             </div>
-             """)
-
-         with gr.Column(scale=2):
-             direction = gr.Dropdown(
-                 choices=["MSA → Syrian", "Syrian → MSA"],
-                 value="MSA → Syrian",
-                 label="Translation Direction"
-             )
-
-             input_text = gr.Textbox(
-                 label="Input Text",
-                 placeholder="Enter Arabic text here...",
-                 lines=5
-             )
-
-             translate_btn = gr.Button("🚀 Translate", variant="primary")
-
-             output_text = gr.Textbox(
-                 label="Translation",
-                 lines=5
-             )
-
-     # Connect the interface
-     translate_btn.click(
-         fn=bidirectional_translate,
-         inputs=[input_text, direction],
-         outputs=output_text
-     )
-
-     # Add example inputs
-     gr.Examples(
-         examples=[
-             ["أنا لا أعرف إذا كان سيتمكن من الحضور اليوم أم لا.", "MSA → Syrian"],
-             ["كيف حالك؟", "MSA → Syrian"],
-             ["ما بعرف إذا رح يقدر يجي اليوم ولا لأ.", "Syrian → MSA"],
-             ["شلونك؟", "Syrian → MSA"]
-         ],
-         inputs=[input_text, direction],
-         outputs=output_text,
-         fn=bidirectional_translate
      )

- # Launch the app
  if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
+ import os
+ import re
+ import json
  import gradio as gr
+ from PIL import Image
+
  import torch
+ from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
+ import spaces  # 👈 Hugging Face ZeroGPU
+
+ MODEL_NAME = os.environ.get("MODEL_NAME", "NAMAA-Space/Qari-OCR-0.1-VL-2B-Instruct")
+ MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "1024"))
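+ # Both values can be overridden via Space environment variables, so the checkpoint
+ # and generation budget can be changed without editing the code.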
+
+ # ---- Device selection ----
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16 if device == "cuda" else torch.float32
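+ # fp16 halves weight memory on GPU; CPU inference stays in fp32, where half
+ # precision is poorly supported.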
+
+ print(f"Device being used: {device}")
+
+ # ---- Load model & processor ----
+ processor = AutoProcessor.from_pretrained(MODEL_NAME)
+
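+ # device_map="auto" lets Accelerate place the weights on the GPU;
+ # {"": "cpu"} pins the whole model to the CPU when none is available.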
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+     MODEL_NAME,
+     torch_dtype=dtype,
+     device_map="auto" if device == "cuda" else {"": "cpu"},
+ )
+
+ print("Model loaded successfully!")
+
+ def _mk_messages(image_path: str, prompt_info: str):
+     return [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": f"file://{image_path}"},
+                 {"type": "text", "text": f"""
+ You are an advanced invoice OCR system...
+ (extraction instructions same as notebook)
+ Extra hints from user: {prompt_info}
+ """.strip()},
+             ],
+         }
+     ]
+
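+ # Best-effort JSON recovery: parse the whole reply if it looks like an object,
+ # otherwise parse the first {...} block, and finally wrap the raw text.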
+ def _extract_json(text: str):
+     text = text.strip()
+     if text.startswith("{") and text.endswith("}"):
+         try:
+             return json.loads(text)
+         except Exception:
+             pass
+     m = re.search(r"\{[\s\S]*\}", text)
+     if m:
+         block = m.group(0)
+         try:
+             return json.loads(block)
+         except Exception:
+             pass
+     return {"other_text": text}
+
+ @spaces.GPU(duration=120)  # 👈 Request ZeroGPU for 2 minutes
+ def infer(image: Image.Image, prompt_info: str):
+     if image is None:
+         return "Please upload an image.", {}
+
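+     # _mk_messages references the image by file:// URI, so write the upload to disk first.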
+     tmp_path = "input_image.png"
+     image.save(tmp_path)
+
+     messages = _mk_messages(tmp_path, prompt_info)
+     chat_text = processor.apply_chat_template(messages, add_generation_prompt=True)
+
+     inputs = processor(
+         text=[chat_text],
+         images=[Image.open(tmp_path)],
+         return_tensors="pt",
      )
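+     # Move tensors to the model's device; leave non-tensor entries untouched.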
+     inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         generated_ids = model.generate(
+             **inputs,
+             max_new_tokens=MAX_NEW_TOKENS,
+             do_sample=False,
+         )
+
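+     # Keep only the tokens generated after the prompt before decoding.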
+     gen_only = generated_ids[:, inputs["input_ids"].shape[1]:]
+     text_out = processor.batch_decode(gen_only, skip_special_tokens=True)[0].strip()
+     parsed = _extract_json(text_out)
+
+     return text_out, parsed
+
+ with gr.Blocks(title="Qari OCR (ZeroGPU)") as demo:
+     gr.Markdown("# Qari OCR · ZeroGPU\nUpload an invoice image and (optionally) add extraction hints.")
+     with gr.Row():
+         with gr.Column():
+             img_in = gr.Image(type="pil", label="Invoice Image")
+             prompt_box = gr.Textbox(label="Extra hints (optional)")
+             run_btn = gr.Button("Run OCR")
+         with gr.Column():
+             txt_out = gr.Textbox(label="Raw Model Output", lines=10)
+             json_out = gr.JSON(label="Parsed JSON")
+
+     run_btn.click(infer, inputs=[img_in, prompt_box], outputs=[txt_out, json_out])

  if __name__ == "__main__":
+     demo.launch()