import argparse
import html
import os
import re

import bleach
import cv2
import jsonlines
import numpy as np
import torch
from loguru import logger
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, CLIPImageProcessor, PreTrainedTokenizer

from eval.utils import grounding_image_ecoder_preprocess
from model.Legion import LegionForCls
from model.llava import conversation as conversation_lib
from model.llava.mm_utils import tokenizer_image_token
from model.SAM.utils.transforms import ResizeLongestSide
from tools.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX

# Lower-case file extensions that `main` treats as input images.
IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png"})

# Matches any `<...>` markup tag left in the decoded model output.
_TAG_PATTERN = re.compile(r"<.*?>")


class ImageLoadError(ValueError):
    """Raised when an input image cannot be read from disk."""


def parse_args():
    """Parse the command-line arguments for batch LEGION inference."""
    parser = argparse.ArgumentParser(description="LEGION Inference")

    # model related
    parser.add_argument("--model_path", required=True, help="The directory to your legion ckpt")
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)

    # data related
    parser.add_argument("--image_root", required=True, help="The directory containing images to run inference.")
    parser.add_argument("--save_root", required=True, help="The directory to store the inference result.")

    return parser.parse_args()


class LEGION:
    """A simple wrapper for LEGION model loading and inference.

    Args:
        model_path (str): Path to the model checkpoint.
        image_size (int): Size of the input images.
        model_max_length (int): Maximum length of the model input sequence.
    """

    INSTRUCTION = (
        "Please provide a detailed analysis of artifacts in this photo, considering "
        "physical artifacts (e.g., optical display issues, violations of physical laws, "
        "and spatial/perspective errors), structural artifacts (e.g., deformed objects, asymmetry, or distorted text), "
        "and distortion artifacts (e.g., color/texture distortion, noise/blur, artistic style errors, and material misrepresentation). "
        "Output with interleaved segmentation masks for the corresponding parts of the answer."
    )

    def __init__(self, model_path: str, image_size: int = 1024, model_max_length: int = 512):
        self.model_path = model_path
        self.image_size = image_size
        self.model_max_length = model_max_length

        # load tokenizer
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            cache_dir=None,
            model_max_length=self.model_max_length,
            padding_side="right",
            use_fast=False,
        )
        self.tokenizer.pad_token = self.tokenizer.unk_token
        seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
        logger.info("Tokenizer loaded successfully.")

        # load model
        self.model: LegionForCls = LegionForCls.from_pretrained(
            self.model_path,
            low_cpu_mem_usage=True,
            seg_token_idx=seg_token_idx,
            torch_dtype=torch.bfloat16,
        )
        # update model config
        self.model.config.eos_token_id = self.tokenizer.eos_token_id
        self.model.config.bos_token_id = self.tokenizer.bos_token_id
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        # init global image encoder (CLIP)
        self.model.get_model().initialize_vision_modules(self.model.get_model().config)
        vision_tower = self.model.get_model().get_vision_tower()
        vision_tower.to(dtype=torch.bfloat16)
        # transfer the model to GPU
        self.model = self.model.bfloat16().cuda()
        vision_tower.to(device="cuda")
        self.model.eval()
        logger.info("Model loaded successfully.")

        # init image processor for global image encoder (CLIP)
        self.image_processor = CLIPImageProcessor.from_pretrained(self.model.config.vision_tower)
        self.transform = ResizeLongestSide(self.image_size)
        logger.info("Image processor initialized successfully.")

    def _build_prompt(self) -> str:
        """Return the llava_v1 conversation prompt carrying the image token and INSTRUCTION."""
        # bleach.clean HTML-escapes markup characters; html.unescape undoes that
        # escaping so the instruction text reaches the tokenizer unchanged.
        instructions = html.unescape(bleach.clean(LEGION.INSTRUCTION))

        conv = conversation_lib.conv_templates["llava_v1"].copy()
        conv.messages = []
        prompt = (
            f"The {DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN} provides an overview of the picture.\n"
            + instructions
        )
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")  # empty assistant turn for the model to complete
        return conv.get_prompt()

    def _preprocess_image(self, raw_image: np.ndarray):
        """Prepare the CLIP input and the grounding-encoder input from a BGR image.

        Returns:
            tuple: (image_clip, image, resize_list, original_size_list), where both
            images are bfloat16 CUDA tensors with a leading batch dimension of 1.
        """
        # OpenCV images are BGR; both encoders are fed RGB.
        image_np = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]

        # global image encoder (CLIP)
        image_clip = self.image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0]
        image_clip = image_clip.unsqueeze(0).cuda().bfloat16()

        # grounding image encoder: resize longest side, then HWC -> CHW
        resized = self.transform.apply_image(image_np)
        resize_list = [resized.shape[:2]]
        image = grounding_image_ecoder_preprocess(torch.from_numpy(resized).permute(2, 0, 1).contiguous())
        image = image.unsqueeze(0).cuda().bfloat16()

        return image_clip, image, resize_list, original_size_list

    def _decode_explanation(self, output_ids: torch.Tensor) -> str:
        """Decode generated ids of the first sequence into a clean explanation string."""
        token_ids = output_ids[0]
        # Remove image placeholder ids before decoding (presumably not real
        # vocabulary entries, as in LLaVA -- TODO confirm).
        token_ids = token_ids[token_ids != IMAGE_TOKEN_INDEX]
        text = self.tokenizer.decode(token_ids, skip_special_tokens=False)

        # Collapse all whitespace (newlines included) to single spaces so words on
        # either side of a line break are not glued together.
        text = " ".join(text.split())
        # keep only the assistant turn
        text = text.split("ASSISTANT: ")[-1]
        # remove markup tags and [SEG] tokens
        text = _TAG_PATTERN.sub("", text)
        text = text.replace("[SEG]", "")
        # normalise spaces introduced by the removals, then strip quotes/whitespace
        return " ".join(text.split()).strip("'").strip()

    @torch.inference_mode()
    def _infer(self, raw_image: np.ndarray):
        """Run inference on a single image.

        Args:
            raw_image (np.ndarray): The input image in OpenCV (BGR) numpy format.

        Returns:
            tuple: (explanation string, predicted masks as returned by
            `model.evaluate`, predicted class index as an int).
        """
        prompt = self._build_prompt()
        image_clip, image, resize_list, original_size_list = self._preprocess_image(raw_image)

        # prepare inputs for inference
        input_ids = tokenizer_image_token(prompt, self.tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        # generate explanation + masks
        output_ids, pred_masks = self.model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_tokens_new=512,
            bboxes=None,  # No box/region is input in GCG task
        )
        explanation = self._decode_explanation(output_ids)

        # infer detection head
        logits = self.model(global_enc_images=image_clip, inference_cls=True)["logits"].cpu()
        _, pred_cls = torch.max(logits, dim=1)

        return explanation, pred_masks, int(pred_cls)

    @torch.inference_mode()
    def infer(self, image_path: str) -> dict:
        """Run inference on a single image file.

        Args:
            image_path (str): Path to the input image.

        Returns:
            dict: "explanation" (str), "localization" (a PIL "L"-mode image with
            255 where any predicted mask is > 0, else 0) and "detection"
            ("real" when the detection head predicts class 1, otherwise "fake").

        Raises:
            ImageLoadError: If the image cannot be read.
        """
        raw_image = cv2.imread(image_path)
        if raw_image is None:
            # cv2.imread reports missing/unreadable files by returning None, not raising.
            raise ImageLoadError(f"Failed to read image: {image_path}")

        explanation, pred_masks, pred_cls = self._infer(raw_image.astype(np.uint8))

        # Union of all predicted masks for the (single) image in the batch.
        union = torch.any(pred_masks[0].cpu() > 0, dim=0)
        # A 2-D uint8 array maps to Pillow mode "L" automatically.
        localization = Image.fromarray(union.numpy().astype(np.uint8) * 255)

        return {
            "explanation": explanation,
            "localization": localization,
            "detection": "real" if pred_cls == 1 else "fake",
        }


def _load_processed_names(jsonl_path: str) -> set:
    """Return the basenames of images already recorded in `jsonl_path` (empty if absent)."""
    if not os.path.exists(jsonl_path):
        return set()
    with jsonlines.open(jsonl_path) as reader:
        # Tolerate a truncated/blank trailing line from an interrupted previous run.
        return {
            os.path.basename(record["image_path"])
            for record in reader.iter(skip_empty=True, skip_invalid=True)
            if isinstance(record, dict) and "image_path" in record
        }


def main(args):
    """Run LEGION over every image in `args.image_root`, resuming previous output."""
    # get images
    image_paths = [
        name
        for name in sorted(os.listdir(args.image_root))
        if os.path.splitext(name)[-1].lower() in IMAGE_SUFFIXES
        and os.path.isfile(os.path.join(args.image_root, name))
    ]
    logger.info(f"Found {len(image_paths)} images for inference.")

    # init legion
    legion = LEGION(args.model_path, args.image_size, args.model_max_length)

    # check save root
    os.makedirs(args.save_root, exist_ok=True)
    localization_save_dir = os.path.join(args.save_root, "localization")
    os.makedirs(localization_save_dir, exist_ok=True)
    explanation_save_path = os.path.join(args.save_root, "explanations.jsonl")

    # prepare resume: skip by name (not position) so adding/removing images between
    # runs cannot cause the wrong files to be skipped.
    processed = _load_processed_names(explanation_save_path)
    if processed:
        logger.info(f"Resuming from {len(processed)} processed images.")
        image_paths = [name for name in image_paths if name not in processed]

    # run inference
    with jsonlines.open(explanation_save_path, mode="a", flush=True) as writer:
        for image_path in tqdm(image_paths):
            image_name = os.path.splitext(image_path)[0]
            full_image_path = os.path.join(args.image_root, image_path)
            try:
                result = legion.infer(full_image_path)
            except ImageLoadError as err:
                logger.warning(f"Skipping {image_path}: {err}")
                continue

            # save localization
            # NOTE(review): images differing only by extension (a.jpg / a.png)
            # share the same mask file name and overwrite each other.
            this_localization_save_path = os.path.join(localization_save_dir, f"{image_name}_mask.png")
            result["localization"].save(this_localization_save_path)
            result["localization"] = this_localization_save_path
            # add original image path
            result["image_path"] = full_image_path
            # write to jsonl
            writer.write(result)


if __name__ == "__main__":
    args = parse_args()
    main(args)