sasha HF Staff commited on
Commit
90bd6cb
·
1 Parent(s): c0151d4
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -19,13 +19,13 @@ ds = dataset["train"]
19
 
20
 
21
  def query(image, top_k=4):
22
- inputs = feature_extractor(image, return_tensors="pt")
23
  model_output = model(**inputs)
24
  embedding = model_output.pooler_output.detach()
25
  results = index.query(embedding, k=top_k)
26
  inx = results[0][0].tolist()
27
  logits = results[1][0].tolist()
28
- images = ds.select(inx)["image"].convert("RGB")
29
  captions = ds.select(inx)["label"]
30
  images_with_captions = [(i, c) for i, c in zip(images, captions)]
31
  labels_with_probs = dict(zip(captions, logits))
 
19
 
20
 
21
  def query(image, top_k=4):
22
+ inputs = feature_extractor(image.convert("RGB"), return_tensors="pt")
23
  model_output = model(**inputs)
24
  embedding = model_output.pooler_output.detach()
25
  results = index.query(embedding, k=top_k)
26
  inx = results[0][0].tolist()
27
  logits = results[1][0].tolist()
28
+ images = ds.select(inx)["image"]
29
  captions = ds.select(inx)["label"]
30
  images_with_captions = [(i, c) for i, c in zip(images, captions)]
31
  labels_with_probs = dict(zip(captions, logits))