Alvaro Bartolome committed on
Commit
baebe6f
·
unverified ·
1 Parent(s): 0476276

Add custom `handler.py`

Browse files
Files changed (1) hide show
  1. handler.py +59 -0
handler.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ import requests
4
+ import torch
5
+
6
+ from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
7
+ from PIL import Image
8
+
9
+
10
class EndpointHandler:
    """Custom Inference Endpoints handler for `google/paligemma-3b-mix-448`.

    Loads the PaliGemma vision-language model once at startup and serves
    batches of `{prompt, image_url}` instances, returning one generated
    string per instance.
    """

    # Hub repository used for both the weights and the processor. The
    # `model_dir` constructor argument is deliberately ignored — the model
    # is always fetched from the Hub.
    MODEL_ID = "google/paligemma-3b-mix-448"

    def __init__(
        self,
        model_dir: str = "/opt/huggingface/model",
        **kwargs: Any,
    ) -> None:
        """Load the model (bfloat16, automatic device placement) and processor.

        Args:
            model_dir: Path supplied by the Inference Endpoints runtime;
                unused here because weights come from the Hub repository.
            **kwargs: Extra runtime arguments; accepted and ignored.
        """
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(
            self.MODEL_ID,
            revision="bfloat16",
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            device_map="auto",
        ).eval()

        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        """Generate a text response for each `{prompt, image_url}` instance.

        Args:
            data: Request body. Must contain `instances`, a list of dicts
                each holding a `prompt` string and an `image_url`. An
                optional top-level `generation_kwargs` dict is forwarded
                verbatim to `model.generate`.

        Returns:
            `{"predictions": [...]}` with one decoded string per instance.

        Raises:
            ValueError: If `instances` is missing, an instance lacks a
                required key, or an image cannot be downloaded.
        """
        if "instances" not in data:
            raise ValueError(
                "The request body must contain a key `instances` with a list of instances."
            )

        # Generation settings are request-wide, so resolve them once instead
        # of re-reading `data` on every loop iteration.
        generation_kwargs = data.get(
            "generation_kwargs", {"max_new_tokens": 100, "do_sample": False}
        )

        predictions = []
        # `instance` rather than `input`: avoids shadowing the builtin.
        for instance in data["instances"]:
            if any(key not in instance for key in ("prompt", "image_url")):
                raise ValueError(
                    "The request body for each instance should contain both the `prompt` and the `image_url` key with a valid image URL."
                )

            try:
                # Stream the download and hand the raw file object to PIL;
                # the timeout keeps a dead URL from hanging the endpoint.
                image = Image.open(
                    requests.get(instance["image_url"], stream=True, timeout=30).raw  # type: ignore
                )
            except Exception as e:
                # Chain the original exception so the root cause survives
                # in the traceback.
                raise ValueError(
                    f"The provided image URL ({instance['image_url']}) cannot be downloaded (with exception {e}), make sure it's public and accessible."
                ) from e

            inputs = self.processor(
                text=instance["prompt"], images=image, return_tensors="pt"
            ).to(self.model.device)
            input_len = inputs["input_ids"].shape[-1]

            with torch.inference_mode():
                generation = self.model.generate(**inputs, **generation_kwargs)
                # Strip the prompt tokens; keep only newly generated ones.
                generation = generation[0][input_len:]
                response = self.processor.decode(generation, skip_special_tokens=True)
            predictions.append(response)
        return {"predictions": predictions}