from typing import Any, Dict, List

import base64
import io

import torch
from PIL import Image
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoImageProcessor,
)

# Pick dtype and device: bfloat16 on GPU, float32 on CPU (half precision is
# unsupported or very slow for many ops on CPU).
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"


class EndpointHandler:
    def __init__(self, path: str = ""):
        print(f"Initializing model on device: {device}")
        print(f"Using dtype: {dtype}")

        # Load the tokenizer, image processor, and model (AutoModel, as in local inference).
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.image_processor = AutoImageProcessor.from_pretrained(path, trust_remote_code=True)

        # Load the model with explicit device mapping.
        if device == "cuda":
            self.model = AutoModel.from_pretrained(
                path,
                torch_dtype=dtype,
                trust_remote_code=True,
                device_map="auto",  # Automatically shard across available GPUs
            )
        else:
            self.model = AutoModel.from_pretrained(
                path,
                torch_dtype=dtype,
                trust_remote_code=True,
            )
            self.model = self.model.to(device)

        print(f"Model loaded successfully on device: {self.model.device}")
        print(f"Model dtype: {next(self.model.parameters()).dtype}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str` or :obj:`list`): a plain text prompt or messages in chat format
            parameters (:obj:`dict`): generation parameters
        Return:
            A :obj:`list` of :obj:`dict` that will be serialized and returned
        """
        print("Call inside handler")

        # Get inputs and generation parameters.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        print("parameters", parameters)

        # Remove client-side (TGI-style) parameters that model.generate() does not accept.
        parameters.pop("details", None)
        parameters.pop("stop", None)
        parameters.pop("return_full_text", None)
        if "do_sample" in parameters:
            # Coerce to a proper boolean instead of forwarding arbitrary values
            # (the original code forced this to True even when False was requested).
            parameters["do_sample"] = bool(parameters["do_sample"])

        # Set default generation parameters (temperature 0 corresponds to greedy decoding).
        max_new_tokens = parameters.pop("max_new_tokens", 512)
        temperature = parameters.pop("temperature", 0)

        try:
            # Handle different input formats.
            if isinstance(inputs, str):
                # A plain string is treated as a simple text prompt.
                input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.model.device)

                with torch.no_grad():
                    generated_ids = self.model.generate(
                        input_ids,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        **parameters,
                    )

                # Strip the prompt tokens so only the completion is returned.
                prompt_len = input_ids.shape[1]
                generated_ids = generated_ids[:, prompt_len:]
                output_text = self.tokenizer.batch_decode(
                    generated_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )
                return [{"generated_text": output_text[0]}]

            elif isinstance(inputs, list):
                # Chat format, possibly with images.
                messages = inputs

                # Apply the chat template.
                input_ids = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=True,
                    add_generation_prompt=True,
                )
                input_text = self.tokenizer.decode(
                    input_ids,
                    skip_special_tokens=False,
                    clean_up_tokenization_spaces=False,
                )
                print(input_text)
                input_ids = torch.tensor([input_ids]).to(self.model.device)

                # Process ALL images if present.
                pixel_values_list = []
                grid_thws_list = []

                # Look for base64-encoded images in the messages.
                for message in messages:
                    if isinstance(message.get("content"), list):
                        for content_item in message["content"]:
                            if content_item.get("type") == "image_url":
                                image_data = content_item.get("image_url", {}).get("url", "")
                                if image_data.startswith("data:image"):
                                    # Decode the base64 image payload.
                                    image_data = image_data.split(",")[1]
                                    image_bytes = base64.b64decode(image_data)
                                    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

                                    # Preprocess each image individually.
                                    info = self.image_processor.preprocess(images=[image])
                                    pixel_values = torch.tensor(info["pixel_values"]).to(
                                        dtype=dtype, device=self.model.device
                                    )
                                    grid_thws = torch.tensor(info["image_grid_thw"]).to(self.model.device)
                                    pixel_values_list.append(pixel_values)
                                    grid_thws_list.append(grid_thws)

                # Generate the response.
                if pixel_values_list and grid_thws_list:
                    # Multimodal generation with images:
                    # concatenate all pixel_values and grid_thws for batch processing.
                    all_pixel_values = torch.cat(pixel_values_list, dim=0)
                    all_grid_thws = torch.cat(grid_thws_list, dim=0)

                    print(f"Processing {len(pixel_values_list)} images")
                    print(f"pixel_values shape: {all_pixel_values.shape}")
                    print(f"grid_thws shape: {all_grid_thws.shape}")
                    print("grid_thws", all_grid_thws)

                    # Ensure all tensors are on the same device as the model.
                    all_pixel_values = all_pixel_values.to(self.model.device)
                    all_grid_thws = all_grid_thws.to(self.model.device)

                    with torch.no_grad():
                        generated_ids = self.model.generate(
                            input_ids,
                            pixel_values=all_pixel_values,
                            grid_thws=all_grid_thws,
                            max_new_tokens=max_new_tokens,
                            temperature=temperature,
                            **parameters,
                        )
                else:
                    # Text-only generation.
                    with torch.no_grad():
                        generated_ids = self.model.generate(
                            input_ids,
                            max_new_tokens=max_new_tokens,
                            temperature=temperature,
                            **parameters,
                        )

                # Strip the prompt tokens so only the completion is returned.
                prompt_len = input_ids.shape[1]
                generated_ids = generated_ids[:, prompt_len:]
                output_text = self.tokenizer.batch_decode(
                    generated_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )
                print("##Model Response##", output_text)
                return [{"generated_text": output_text[0]}]

            else:
                raise ValueError(f"Unsupported input type: {type(inputs)}")

        except Exception as e:
            print(f"Error during inference: {str(e)}")
            return [{"error": str(e)}]
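

# A minimal local smoke test for the handler, as a sketch of how the endpoint is called.
# The model path "./model" and the generated red test image are placeholders (assumptions,
# not part of the deployed handler); the request bodies mirror the {"inputs": ..., "parameters": ...}
# payloads and OpenAI-style chat messages the handler above expects.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical local checkout of the model repo

    # Plain text prompt.
    print(handler({
        "inputs": "Describe the capital of France in one sentence.",
        "parameters": {"max_new_tokens": 64},
    }))

    # Chat-format request carrying one base64-encoded image as a data URL.
    buffer = io.BytesIO()
    Image.new("RGB", (64, 64), color="red").save(buffer, format="PNG")
    image_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
                {"type": "text", "text": "What color is this image?"},
            ],
        }
    ]
    print(handler({"inputs": messages, "parameters": {"max_new_tokens": 64}}))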