Spaces:
Runtime error
| import gradio | |
| from transformers import ViltProcessor, ViltForQuestionAnswering | |
| from PIL import Image | |
# Hub repository id of the ViLT checkpoint fine-tuned for visual question
# answering. The bare name "vilt-b32-finetuned-vqa" only resolves when a local
# directory of that name exists; the namespaced id resolves on the Hub.
MODEL_ID = "dandelin/vilt-b32-finetuned-vqa"
processor = ViltProcessor.from_pretrained(MODEL_ID)
model = ViltForQuestionAnswering.from_pretrained(MODEL_ID)
def predict_answer(image, question):
    """Answer a free-form question about an image and list the top answers.

    Args:
        image: The uploaded picture as an array supporting ``.astype`` (as
            delivered by the Gradio image input), or ``None`` when the user
            submitted without an image.
        question: The question text to ask about the image.

    Returns:
        A list of exactly five strings, one per answer textbox, each formatted
        as ``"<answer>: <probability>"`` in descending probability order.
        Slots without an answer are ``""``, so the number of returned values
        always matches the five output components.
    """
    num_answers = 5
    if image is None:
        # Nothing to ask about: clear every output box instead of crashing
        # on ``None.astype``.
        return [""] * num_answers

    import torch  # local import; torch is already required by the model

    # Let Pillow infer the mode and convert, so grayscale/RGBA uploads are
    # converted to RGB rather than misinterpreted by a forced 'RGB' mode.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    encoding = processor(pil_image, question, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**encoding).logits
    probs = logits.softmax(dim=-1)[0]

    # Guard against label sets smaller than the number of answer slots.
    top_k = min(num_answers, probs.numel())
    top_probs, top_indices = probs.topk(top_k)

    answer_list = []
    for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
        if prob > 0.0:
            answer = model.config.id2label[idx]
            answer_list.append(f"{answer}: {prob:.2%}")
    # Pad so Gradio always receives one value per output component.
    answer_list += [""] * (num_answers - len(answer_list))
    return answer_list
# Gradio UI: an image plus a free-text question in, five answer textboxes out.
inputs = [
    gradio.components.Image(label="Image"),
    gradio.components.Textbox(label="Question", placeholder="Enter your question here."),
]
# Five output textboxes labelled "Answer 1" .. "Answer 5", filled in order
# from the list that predict_answer returns.
outputs = [gradio.components.Textbox(label=f"Answer {rank}") for rank in range(1, 6)]
title = "Visual Question Answering (vilt-b32-finetuned-vqa)"

demo = gradio.Interface(
    fn=predict_answer,
    inputs=inputs,
    outputs=outputs,
    title=title,
    # NOTE(review): newer Gradio releases rename allow_flagging to
    # flagging_mode — confirm against the gradio version this Space pins.
    allow_flagging="never",
    # Hide the page footer element.
    css="footer{display:none !important}",
)
demo.launch()