import os
import json
from PIL import Image, ImageFont, ImageDraw
from tqdm import tqdm

# names = ['loose', 'bas', 'ours']
# path = [
#     'LooseControl/results',
#     'build_a_scene/eval/evaluation_set_1/generation/images',
#     'exp/bas',
# ]

# for i in range(100):
#     image = Image.new('RGB', (512*4, 600), (255, 255, 255))
#     prompt = conditions[i]['caption']

#     idx = str(i).zfill(4)
#     cond = Image.open(f'build_a_scene/eval/evaluation_set_1/generation/renders/{idx}.png')
#     image.paste(cond, (512*0, 0))

#     image_name = os.listdir(f"LooseControl/results/{idx}")[-2][:-4]
#     for j, name in enumerate(names):
#         try:
#             depth = Image.open(f'{path[j]}/{idx}/{image_name}.png')
#         except:
#             depth = Image.open(f'{path[j]}/{idx}/{image_name}_2.png')
#         image.paste(depth, (512*(j+1), 0))

#     font = ImageFont.load_default(35)   
#     draw = ImageDraw.Draw(image)
#     draw.text((10+512*0, 512), 'cond', fill=(0, 0, 0), font=font)
#     for j,name in enumerate(names):
#         draw.text((10+512*(j+1), 512), names[j], fill=(0, 0, 0), font=font)
#     draw.text((10+512*0, 512+45), prompt, fill=(0, 0, 0), font=font)

#     image.save(f"exp/concat/{idx}.png")
    

# Build a side-by-side comparison strip for every prompt in `json_path`:
# one TILE x TILE panel per entry of `path`, a label from `names` under each
# panel, and the prompt text underneath. Output goes to
# exp/show_new/<relation>/<prompt>.png.

# Column labels and the directories their panels are read from. The two lists
# are used in lockstep (column j shows path[j] labelled names[j]), so they must
# keep the same length and order.
# NOTE(review): 'visual' labels both the eligen and the gen_30000 visualization
# columns; consider distinct labels if the strips are hard to read.
names = [
    # 'rpg',
    'eligen',
    'visual',
    'ours',
    'visual'
]
path = [
    # 'exp/rpg/samples',
    'exp/eligen/samples',
    'exp/eligen/visual',
    'exp/gen_30000/samples',
    'exp/gen_30000/visual',
]
if len(names) != len(path):
    raise ValueError(
        f"names ({len(names)}) and path ({len(path)}) must have the same length"
    )

json_path = 'highly_scored_prompts.json'
with open(json_path, 'r', encoding='utf-8') as f:
    # Each entry is read for its 'prompt' and 'relation' keys below.
    data = json.load(f)

TILE = 256          # side length of each square panel, in pixels
CANVAS_H = 300      # TILE plus a 44 px strip for the labels and the prompt text
skip_existing = False  # set True to leave already-rendered strips untouched
font = ImageFont.load_default(15)  # loop-invariant: load once, not per prompt


def _open_panel(folder, prompt):
    """Open ``{folder}/{prompt}_42.png``, falling back to ``..._42_mask.png``.

    Only a missing primary file triggers the fallback; a present-but-unreadable
    file propagates its error. Raises FileNotFoundError if neither file exists.
    """
    try:
        return Image.open(f'{folder}/{prompt}_42.png')
    except FileNotFoundError:
        return Image.open(f'{folder}/{prompt}_42_mask.png')


for d in tqdm(data):
    try:
        prompt = d['prompt']
        relation = d['relation']
        os.makedirs(f"exp/show_new/{relation}", exist_ok=True)

        save_path = f"exp/show_new/{relation}/{prompt}.png"
        if skip_existing and os.path.exists(save_path):
            continue

        image = Image.new('RGB', (TILE * len(path), CANVAS_H), (255, 255, 255))

        for j, folder in enumerate(path):
            # `with` closes the underlying file once the panel has been copied.
            with _open_panel(folder, prompt) as panel:
                image.paste(panel.resize((TILE, TILE)), (TILE * j, 0))

        draw = ImageDraw.Draw(image)
        for j, name in enumerate(names):
            draw.text((10 + TILE * j, TILE), name, fill=(0, 0, 0), font=font)
        draw.text((10, TILE + 20), prompt, fill=(0, 0, 0), font=font)

        image.save(save_path)
    except Exception as err:
        # Best-effort batch: one bad entry (missing key, missing/unreadable
        # image, unrenderable text) must not abort the run, but it is reported
        # rather than silently dropped. KeyboardInterrupt still propagates.
        tqdm.write(f"skipping {d!r}: {type(err).__name__}: {err}")