# handler.py: custom EndpointHandler for OpenCUA-7B (Hugging Face Inference Endpoints)
from typing import Dict, List, Any
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoImageProcessor,
)
import torch
from PIL import Image
import base64
import io
# Pick dtype and device: bfloat16 on GPU, float32 on CPU (half precision is
# often unsupported for generation on CPU)
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
class EndpointHandler:
    def __init__(self, path=""):
        print(f"Initializing model on device: {device}")
        print(f"Using dtype: {dtype}")
        # Load tokenizer and image processor; trust_remote_code pulls in the
        # model's custom classes, matching local AutoModel-based inference
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.image_processor = AutoImageProcessor.from_pretrained(path, trust_remote_code=True)
        # Load model with explicit device mapping
        if device == "cuda":
            self.model = AutoModel.from_pretrained(
                path,
                torch_dtype=dtype,
                trust_remote_code=True,
                device_map="auto"  # Automatically map to available GPUs
            )
        else:
            self.model = AutoModel.from_pretrained(
                path,
                torch_dtype=dtype,
                trust_remote_code=True
            )
            self.model = self.model.to(device)
        print(f"Model loaded successfully on device: {self.model.device}")
        print(f"Model dtype: {next(self.model.parameters()).dtype}")
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str` or :obj:`list`): messages in chat format or text input
            parameters (:obj:`dict`): generation parameters
        Return:
            A :obj:`list` | :obj:`dict`: will be serialized and returned
        """
print("Call inside handler")
# get inputs
inputs = data.pop("inputs", data)
parameters = data.pop("parameters", {})
print("parameters", parameters)
# Remove parameters that might cause issues
parameters.pop("details", None)
parameters.pop("stop", None)
parameters.pop("return_full_text", None)
if "do_sample" in parameters:
parameters["do_sample"] = True
# Set default generation parameters
max_new_tokens = parameters.pop("max_new_tokens", 512)
temperature = parameters.pop("temperature", 0)
        try:
            # Handle different input formats
            if isinstance(inputs, str):
                # Plain string: treat it as a simple text prompt
                input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.model.device)
                with torch.no_grad():
                    generated_ids = self.model.generate(
                        input_ids,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        **parameters
                    )
                # Keep only the newly generated tokens, dropping the prompt
                prompt_len = input_ids.shape[1]
                generated_ids = generated_ids[:, prompt_len:]
                output_text = self.tokenizer.batch_decode(
                    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )
                return [{"generated_text": output_text[0]}]
            elif isinstance(inputs, list):
                # Handle chat format with images
                messages = inputs
                # Apply chat template
                input_ids = self.tokenizer.apply_chat_template(
                    messages, tokenize=True, add_generation_prompt=True
                )
                input_text = self.tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
                print(input_text)
                input_ids = torch.tensor([input_ids]).to(self.model.device)
                # Process ALL images if present
                pixel_values_list = []
                grid_thws_list = []
                # Look for images in the messages
                for message in messages:
                    if isinstance(message.get("content"), list):
                        for content_item in message["content"]:
                            if content_item.get("type") == "image_url":
                                image_data = content_item.get("image_url", {}).get("url", "")
                                if image_data.startswith("data:image"):
                                    # Decode the base64 payload of the data URL
                                    image_data = image_data.split(",")[1]
                                    image_bytes = base64.b64decode(image_data)
                                    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
                                    # Process each image individually
                                    info = self.image_processor.preprocess(images=[image])
                                    pixel_values = torch.tensor(info['pixel_values']).to(dtype=dtype, device=self.model.device)
                                    grid_thws = torch.tensor(info['image_grid_thw']).to(self.model.device)
                                    pixel_values_list.append(pixel_values)
                                    grid_thws_list.append(grid_thws)
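                # Each grid_thw row describes one image's patch grid (temporal,
                # height, width), which the vision tower consumes alongside the
                # flattened pixel_values; this reading assumes the Qwen2-VL-style
                # image processor that ships with the model.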
                # Generate response
                if pixel_values_list and grid_thws_list:
                    # Multi-modal generation with images:
                    # concatenate all pixel_values and grid_thws for batch processing
                    all_pixel_values = torch.cat(pixel_values_list, dim=0)
                    all_grid_thws = torch.cat(grid_thws_list, dim=0)
                    print(f"Processing {len(pixel_values_list)} images")
                    print(f"pixel_values shape: {all_pixel_values.shape}")
                    print(f"grid_thws shape: {all_grid_thws.shape}")
                    print("grid_thws", all_grid_thws)
                    # Ensure all tensors are on the same device as the model
                    all_pixel_values = all_pixel_values.to(self.model.device)
                    all_grid_thws = all_grid_thws.to(self.model.device)
                    with torch.no_grad():
                        generated_ids = self.model.generate(
                            input_ids,
                            pixel_values=all_pixel_values,
                            grid_thws=all_grid_thws,
                            max_new_tokens=max_new_tokens,
                            temperature=temperature,
                            **parameters
                        )
                else:
                    # Text-only generation
                    with torch.no_grad():
                        generated_ids = self.model.generate(
                            input_ids,
                            max_new_tokens=max_new_tokens,
                            temperature=temperature,
                            **parameters
                        )
                prompt_len = input_ids.shape[1]
                generated_ids = generated_ids[:, prompt_len:]
                output_text = self.tokenizer.batch_decode(
                    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )
                print("##Model Response##", output_text)
                return [{"generated_text": output_text[0]}]
            else:
                raise ValueError(f"Unsupported input type: {type(inputs)}")
        except Exception as e:
            print(f"Error during inference: {str(e)}")
            return [{"error": str(e)}]