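"""Standalone inference script for the LEGION artifact-analysis model.

Loads a LEGION checkpoint, then for every image under --image_root produces an
artifact explanation, a binary localization mask, and a real/fake detection
label, writing the results to --save_root.

Example invocation (the script filename and paths below are illustrative; use
the actual name of this file in your checkout):

    python legion_inference.py \
        --model_path /path/to/legion_checkpoint \
        --image_root /path/to/images \
        --save_root /path/to/output
"""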
import argparse
import os
import re

import bleach
import cv2
import jsonlines
import numpy as np
import torch
from loguru import logger
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, CLIPImageProcessor, PreTrainedTokenizer

from eval.utils import grounding_image_ecoder_preprocess
from model.Legion import LegionForCls
from model.llava import conversation as conversation_lib
from model.llava.mm_utils import tokenizer_image_token
from model.SAM.utils.transforms import ResizeLongestSide
from tools.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX


def parse_args():
    parser = argparse.ArgumentParser(description="LEGION Inference")

    parser.add_argument("--model_path", required=True, help="Directory of the LEGION checkpoint.")
    parser.add_argument("--image_size", default=1024, type=int, help="Input image size.")
    parser.add_argument("--model_max_length", default=512, type=int, help="Maximum input sequence length.")

    parser.add_argument("--image_root", required=True, help="Directory containing the images to run inference on.")
    parser.add_argument("--save_root", required=True, help="Directory in which to store the inference results.")

    args = parser.parse_args()
    return args


class LEGION:
    """A simple wrapper for LEGION model loading and inference.

    Args:
        model_path (str): Path to the model checkpoint.
        image_size (int): Size of the input images.
        model_max_length (int): Maximum length of the model input sequence.
    """

    INSTRUCTION = (
        "Please provide a detailed analysis of artifacts in this photo, considering "
        "physical artifacts (e.g., optical display issues, violations of physical laws, "
        "and spatial/perspective errors), structural artifacts (e.g., deformed objects, asymmetry, or distorted text), "
        "and distortion artifacts (e.g., color/texture distortion, noise/blur, artistic style errors, and material misrepresentation). "
        "Output with interleaved segmentation masks for the corresponding parts of the answer."
    )

    def __init__(self, model_path: str, image_size: int = 1024, model_max_length: int = 512):
        self.model_path = model_path
        self.image_size = image_size
        self.model_max_length = model_max_length

        # Tokenizer: the [SEG] token index is passed to the model so it can emit segmentation masks.
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            cache_dir=None,
            model_max_length=self.model_max_length,
            padding_side="right",
            use_fast=False,
        )
        self.tokenizer.pad_token = self.tokenizer.unk_token
        seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
        logger.info("Tokenizer loaded successfully.")

        # Model: load in bfloat16 and move it (including the vision tower) to GPU.
        self.model: LegionForCls = LegionForCls.from_pretrained(
            self.model_path,
            low_cpu_mem_usage=True,
            seg_token_idx=seg_token_idx,
            torch_dtype=torch.bfloat16,
        )

        self.model.config.eos_token_id = self.tokenizer.eos_token_id
        self.model.config.bos_token_id = self.tokenizer.bos_token_id
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

        self.model.get_model().initialize_vision_modules(self.model.get_model().config)
        vision_tower = self.model.get_model().get_vision_tower()
        vision_tower.to(dtype=torch.bfloat16)

        self.model = self.model.bfloat16().cuda()
        vision_tower.to(device="cuda")
        self.model.eval()
        logger.info("Model loaded successfully.")

        # Preprocessing: CLIP processor for the global encoder, SAM-style resize for the grounding encoder.
        self.image_processor = CLIPImageProcessor.from_pretrained(self.model.config.vision_tower)
        self.transform = ResizeLongestSide(self.image_size)
        logger.info("Image processor initialized successfully.")

    @torch.inference_mode()
    def _infer(self, raw_image: np.ndarray):
        """Run the model on a single image.

        Args:
            raw_image (np.ndarray): The input image as a BGR numpy array (as read by cv2).

        Returns:
            tuple: The explanation string, the predicted masks, and the classification result.
        """
        # Sanitize the instruction and undo bleach's HTML escaping.
        instructions = bleach.clean(LEGION.INSTRUCTION)
        instructions = instructions.replace('&lt;', '<').replace('&gt;', '>')

        # Build the LLaVA-style conversation prompt with the image placeholder tokens.
        conv = conversation_lib.conv_templates["llava_v1"].copy()
        conv.messages = []
        prompt = f"The {DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN} provides an overview of the picture.\n" + instructions
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        # Global (CLIP) encoder input.
        image_np = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]
        image_clip = self.image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0].unsqueeze(0).cuda()
        image_clip = image_clip.bfloat16()

        # Grounding (SAM) encoder input.
        image = self.transform.apply_image(image_np)
        resize_list = [image.shape[:2]]
        image = grounding_image_ecoder_preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous()).unsqueeze(0).cuda()
        image = image.bfloat16()

        # Tokenize the prompt, expanding the image token.
        input_ids = tokenizer_image_token(prompt, self.tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        # Generate the explanation text and the segmentation masks.
        output_ids, pred_masks = self.model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_tokens_new=512,
            bboxes=None
        )
        output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]

        # Decode and clean the generated text.
        text_output = self.tokenizer.decode(output_ids, skip_special_tokens=False)
        text_output = text_output.replace("\n", "").replace("  ", " ")
        text_output = text_output.split("ASSISTANT: ")[-1]
        cleaned_str = re.sub(r'<.*?>', '', text_output)
        cleaned_str = cleaned_str.replace('[SEG]', '')
        cleaned_str = ' '.join(cleaned_str.split()).strip("'")
        cleaned_str = cleaned_str.strip()

        # Real/fake classification head.
        logits = self.model(global_enc_images=image_clip, inference_cls=True)['logits'].cpu()
        _, pred_cls = torch.max(logits, dim=1)
        pred_cls = int(pred_cls)
        return cleaned_str, pred_masks, pred_cls

    @torch.inference_mode()
    def infer(self, image_path: str):
        """Run inference on a single image file.

        Args:
            image_path (str): Path to the input image.

        Returns:
            dict: A dictionary containing the explanation, the localization mask
                (as a PIL image), and the detection result ("real" or "fake").
        """
        raw_image = cv2.imread(image_path)
        explanation, localization, detection = self._infer(raw_image.astype(np.uint8))

        # Collapse the predicted masks into a single binary localization mask.
        localization = localization[0].cpu()
        binary_localization = localization > 0
        binary_localization = torch.any(binary_localization, dim=0).int()
        localization = (binary_localization.numpy() * 255).astype(np.uint8)
        localization = Image.fromarray(localization, mode="L")

        # Map the classification index to a label.
        detection = "real" if detection == 1 else "fake"

        return {
            "explanation": explanation,
            "localization": localization,
            "detection": detection
        }


def main(args):
    # Collect the images to process.
    suffixes = [".jpg", ".jpeg", ".png"]
    image_paths = sorted(os.listdir(args.image_root))
    image_paths = [p for p in image_paths if os.path.splitext(p)[-1].lower() in suffixes]
    logger.info(f"Found {len(image_paths)} images for inference.")

    legion = LEGION(args.model_path, args.image_size, args.model_max_length)

    # Prepare the output directories and the explanations file.
    os.makedirs(args.save_root, exist_ok=True)
    localization_save_dir = os.path.join(args.save_root, "localization")
    os.makedirs(localization_save_dir, exist_ok=True)
    explanation_save_path = os.path.join(args.save_root, "explanations.jsonl")

    # Resume from a previous run by skipping images already recorded in the JSONL file.
    num_processed_images = 0
    if os.path.exists(explanation_save_path):
        num_processed_images = len(list(jsonlines.open(explanation_save_path)))
        logger.info(f"Resuming from {num_processed_images} processed images.")
        image_paths = image_paths[num_processed_images:]

    with jsonlines.open(explanation_save_path, mode="a", flush=True) as writer:
        for image_path in tqdm(image_paths):
            image_name = os.path.splitext(image_path)[0]
            full_image_path = os.path.join(args.image_root, image_path)
            result = legion.infer(full_image_path)

            # Save the localization mask and record its path in place of the PIL image.
            this_localization_save_path = os.path.join(localization_save_dir, f"{image_name}_mask.png")
            result["localization"].save(this_localization_save_path)
            result["localization"] = this_localization_save_path

            result["image_path"] = full_image_path

            writer.write(result)


if __name__ == "__main__":
    args = parse_args()
    main(args)