# Hugging Face Spaces demo. (The scraped page header read "Spaces:" /
# "Runtime error" — i.e. the hosted Space was failing at the time of capture.)
import json
import os
import sys

import gradio as gr
import torch

import kelip
def load_model():
    """Build the KELIP ViT-B/32 model and move it to the best device.

    Returns:
        dict with keys 'model' (eval-mode model on CUDA if available,
        else CPU), 'preprocess_img' (image transform), and 'tokenizer'.
    """
    model, preprocess_img, tokenizer = kelip.build_model('ViT-B/32')

    # Prefer GPU when one is present; inference-only, so switch to eval mode.
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(target_device)
    model.eval()

    return {
        'model': model,
        'preprocess_img': preprocess_img,
        'tokenizer': tokenizer,
    }
def classify(img, user_text):
    """Zero-shot classify *img* against comma-separated labels in *user_text*.

    Args:
        img: PIL image from the Gradio Image input.
        user_text: comma-separated candidate labels, e.g. "cat,dog".

    Returns:
        dict mapping each label to its softmax-normalized similarity score
        (empty dict when no labels were given).
    """
    # Early return on empty/whitespace input: the original fell through to
    # tokenizer.encode([]) and topk(0), which wastes work or errors.
    if not user_text or user_text.isspace():
        return {}
    user_texts = user_text.split(',')

    # Decide the device once and use it for both inputs (the original mixed
    # `.to(device)` for the image with `.cuda()` for the text).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_img = model_dict['preprocess_img'](img).unsqueeze(0).to(device)
    input_texts = model_dict['tokenizer'].encode(user_texts).to(device)

    # Inference only: run BOTH encoders under no_grad. The original tracked
    # gradients for encode_text, building an autograd graph for nothing.
    with torch.no_grad():
        image_features = model_dict['model'].encode_image(input_img)
        text_features = model_dict['model'].encode_text(input_texts)

    # L2-normalize, then softmax over cosine similarities scaled by 100
    # (standard CLIP-style scoring).
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    values, indices = similarity[0].topk(len(user_texts))
    return {user_texts[idx]: val.item() for val, idx in zip(values, indices)}
if __name__ == '__main__':
    # NOTE: the original had `global model_dict` here — a no-op at module
    # scope, removed. `classify` reads this module-level dict.
    model_dict = load_model()

    inputs = [
        gr.inputs.Image(type="pil", label="Image"),
        gr.inputs.Textbox(lines=5, label="Caption"),
    ]
    outputs = ['label']
    title = "KELIP"

    # Surface which device the demo ended up on in the page description.
    demo_status = ("Demo is running on GPU" if torch.cuda.is_available()
                   else "Demo is running on CPU")
    description = f"Details: paper_url. {demo_status}"

    # Each example is [image path, comma-separated candidate labels].
    examples = [
        ["squid_sundae.jpg", "오징어 순대,김밥,순대,떡볶이"],
        ["seokchon_lake.jpg", "평화의문,올림픽공원,롯데월드,석촌호수"],
        ["seokchon_lake.jpg", "봄,여름,가을,겨울"],
        ["hwangchil_tree.jpg", "황칠나무 묘목,황칠나무,난,소나무 묘목,야자수"],
    ]
    article = ""

    iface = gr.Interface(
        fn=classify,
        inputs=inputs,
        outputs=outputs,
        examples=examples,
        title=title,
        description=description,
        article=article,
    )
    iface.launch()