# LAOD / laod_pipeline.py
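"""Two-stage, open-vocabulary object detection pipeline.

Gemma 3 (an instruction-tuned vision-language model) lists the objects it
sees in an image, and Grounding DINO then localizes each listed label as a
set of scored bounding boxes.
"""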
import torch
from transformers import pipeline
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import numpy as np
import draw_utils
import spaces
from huggingface_hub import login
import os
import re
# Debug aid: print the names (not the values) of the available environment
# variables to confirm the token secret is configured on the Space.
for variable_name in os.environ.keys():
    print(variable_name)

# Gemma 3 is a gated model; authenticate with the token stored as a secret.
login(token=os.environ.get('gemma_access_token'))
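
# Run both models on the GPU when one is available.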
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Vision-language model that proposes the list of objects in the image.
pipe = pipeline(
    "image-text-to-text",
    model="google/gemma-3-4b-it",
    device_map=device,
    torch_dtype=torch.bfloat16,
)
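
# Zero-shot detector that grounds the proposed labels as bounding boxes.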
model_id = "IDEA-Research/grounding-dino-tiny"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

@spaces.GPU
def laod_gdino(image):
    """Run the two-stage pipeline on a PIL image and return the annotated image."""
    # Stage 1: ask Gemma 3 for a comma-separated list of the objects it sees.
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "Just give the list of objects in the given picture, separated by commas. Do not write anything else."}],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "List the objects that you see in the given picture."},
                {"type": "image", "url": image},
            ],
        },
    ]
    output = pipe(text=messages, max_new_tokens=500)
    llm_response = output[0]["generated_text"][-1]["content"]
    print(llm_response)

    # Normalize the labels: the detector is queried with "person", so map the
    # common synonyms onto it. Word boundaries avoid corrupting words such as
    # "human" or "woman" the way plain substring replacement would, and the
    # optional trailing "s" also catches simple plurals.
    llm_response = llm_response.lower()
    for synonym in ('pedestrian', 'people', 'woman', 'man'):
        llm_response = re.sub(rf'\b{synonym}s?\b', 'person', llm_response)
    llm_labels = llm_response.replace(', ', ',').split(',')
    print(llm_labels)

    # Stage 2: ground the labels with Grounding DINO. The processor expects
    # one list of labels per image, hence the extra level of nesting.
    llm_labels = [llm_labels]
    inputs = processor(images=image, text=llm_labels, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        threshold=0.4,
        text_threshold=0.3,
        target_sizes=[image.size[::-1]],  # PIL size is (w, h); targets are (h, w)
    )

    # Draw the boxes, scores, and labels for the single input image.
    result = results[0]
    image = np.array(image)
    draw_results = [result["boxes"], result["scores"], result["labels"]]
    return draw_utils.visualize_detections(image, draw_results)
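
# A minimal local sanity check (a sketch, not part of the Space app): the
# spaces.GPU decorator is a no-op outside a Hugging Face Space, so the
# function can be called directly on a PIL image. The image path below is
# hypothetical.
if __name__ == "__main__":
    from PIL import Image

    test_image = Image.open("example.jpg").convert("RGB")  # hypothetical path
    annotated = laod_gdino(test_image)
    print("Annotated output:", type(annotated))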