# ==========================================================
# 1. IMPORTS
# ==========================================================
import gc
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# ==========================================================
# 2. MODEL LOADING
# ==========================================================
print("Server starting up, loading CheXagent model...")
gc.collect()

# Load the model and tokenizer ONCE, globally, so every request reuses them
model_name = "StanfordAIMI/CheXagent-2-3b"

if torch.cuda.is_available():
    device = "cuda"
    # bfloat16 requires compute capability >= 8 (Ampere or newer); otherwise fall back to float16
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
    device = "cpu"
    dtype = torch.float32  # half precision is poorly supported on CPU

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=dtype,
    trust_remote_code=True,
)
model = model.eval()
print("✅ CheXagent model is hot-loaded and ready!")

# ==========================================================
# 3. DEFINE THE INFERENCE FUNCTION
# ==========================================================
def generate_report(image_input, question_input):
    """
    Take an image and a question, run inference, and return the generated text.

    - image_input: a NumPy array (Gradio delivers the uploaded image in this form)
    - question_input: a string
    """
    print(f"\n[INFO] Received request with question: {question_input}")

    if image_input is None:
        return "Please upload a chest X-ray image."

    # --- 1. Save the image to a temporary file path ---
    # The model's chat template expects an image file path, so the NumPy array
    # is written out as a PNG first.
    temp_image_path = "/tmp/temp_image_for_processing.png"
    pil_image = Image.fromarray(image_input)
    pil_image.save(temp_image_path)

    # --- 2. Run inference ---
    prompt = question_input
    query = tokenizer.from_list_format([{"image": temp_image_path}, {"text": prompt}])
    conv = [
        {"from": "system", "value": "You are a helpful assistant. Support your answer with an explanation."},
        {"from": "human", "value": query},
    ]
    inputs = tokenizer.apply_chat_template(conv, add_generation_prompt=True, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            input_ids=inputs,
            do_sample=False,
            num_beams=1,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=512,
        )[0]
    # Decode only the newly generated tokens, dropping the prompt and the trailing EOS token
    response = tokenizer.decode(out[inputs.size(1):-1], skip_special_tokens=True)

    # Clean up the temp file
    if os.path.exists(temp_image_path):
        os.remove(temp_image_path)

    print(f"[INFO] Sending response: {response[:50]}...")
    return response

# ==========================================================
# 4. LAUNCH THE GRADIO APP
# ==========================================================
print("Launching Gradio Interface...")

iface = gr.Interface(
    fn=generate_report,
    inputs=[
        gr.Image(type="numpy", label="Upload Chest X-ray"),
        gr.Textbox(label="Question", value="Provide step-by-step clinical reasoning as structured Findings,"),
    ],
    outputs=gr.Textbox(label="Generated Report"),
    title="ChestX-Reasoner",
    description="Enter a query and upload an X-ray to get AI-generated reasoning.",
    # This setting creates the permanent API endpoint (/predict)
    api_name="predict",
)

# This launches the app on the Hugging Face Space
iface.launch()
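
# ==========================================================
# 5. (OPTIONAL) CALLING THE ENDPOINT FROM A CLIENT
# ==========================================================
# A minimal sketch of how a remote caller could reach the /predict endpoint
# using the official gradio_client package. The Space ID and image path below
# are placeholders (assumptions), not values from this project.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("your-username/your-space-name")   # hypothetical Space ID
#   result = client.predict(
#       handle_file("path/to/chest_xray.png"),         # image input (NumPy on the server side)
#       "Provide step-by-step clinical reasoning.",    # question input
#       api_name="/predict",
#   )
#   print(result)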