# Fix Gradio bindings: remove gr.Request from inputs; rely on implicit
# request injection. Docs already updated earlier. (commit 12a36fb)
import csv
import itertools
import json
import os
import random
import uuid
from datetime import datetime, timedelta
from io import BytesIO
from typing import Dict, List, Optional, Tuple

import gradio as gr

try:
    from huggingface_hub import HfApi
except Exception:  # optional dependency at runtime
    HfApi = None  # type: ignore

BASE_DIR = os.path.dirname(__file__)

# Persistent local storage inside HF Spaces
PERSIST_DIR = os.environ.get("PERSIST_DIR", "/data")

# Evaluation knobs (can be overridden via env vars)
MIN_RATERS_PER_PAIR = int(os.environ.get("MIN_RATERS_PER_PAIR", 20))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 20))
RELOAD_EVERY = int(os.environ.get("RELOAD_EVERY", 5))
REPEAT_RATE = float(os.environ.get("REPEAT_RATE", 0.05))  # fraction of repeats within a batch
REPEAT_MIN_HOURS = float(os.environ.get("REPEAT_MIN_HOURS", 24))
FAST_MIN_SEC = float(os.environ.get("FAST_MIN_SEC", 2.0))
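# Example (illustrative, not from this repo): on a Space these knobs would
# typically be set as repository variables/secrets rather than edited here, e.g.
#   MIN_RATERS_PER_PAIR=30   -> retire a pair after 30 ratings
#   BATCH_SIZE=25            -> hand each session 25 pairs
#   REPEAT_RATE=0.08         -> ~2 of those 25 are spaced repeats
#   REPEAT_MIN_HOURS=48      -> only re-show pairs last rated >= 48h ago
# The values above are hypothetical; the defaults in this file apply otherwise.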
TASK_CONFIG = {
    "Scene Composition & Object Insertion": {
        "folder": "scene_composition_and_object_insertion",
        "score_fields": [
            ("physical_interaction_fidelity_score", "Physical Interaction Fidelity"),
            ("optical_effect_accuracy_score", "Optical Effect Accuracy"),
            ("semantic_functional_alignment_score", "Semantic/Functional Alignment"),
            ("overall_photorealism_score", "Overall Photorealism"),
        ],
    },
}
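# To register another task (hypothetical sketch), add an entry whose folder
# contains a results.csv plus the referenced images, e.g.:
#   TASK_CONFIG["Relighting"] = {
#       "folder": "relighting",
#       "score_fields": [("overall_photorealism_score", "Overall Photorealism")],
#   }
# The task name, folder, and fields above are assumptions for illustration only.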
def _csv_path_for_task(task_name: str, filename: str) -> str:
    folder = TASK_CONFIG[task_name]["folder"]
    return os.path.join(BASE_DIR, folder, filename)


def _persist_csv_path_for_task(task_name: str) -> str:
    folder = TASK_CONFIG[task_name]["folder"]
    return os.path.join(PERSIST_DIR, folder, "evaluation_results.csv")


def _resolve_image_path(path: str) -> str:
    return path if os.path.isabs(path) else os.path.join(BASE_DIR, path)


def _file_exists_under_base(rel_or_abs_path: str) -> bool:
    """Check if a file exists, resolving relative paths under BASE_DIR."""
    check_path = rel_or_abs_path if os.path.isabs(rel_or_abs_path) else os.path.join(BASE_DIR, rel_or_abs_path)
    return os.path.exists(check_path)
def _load_task_rows(task_name: str) -> List[Dict[str, str]]:
    csv_path = _csv_path_for_task(task_name, "results.csv")
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"Results file for task {task_name} not found: {csv_path}")
    with open(csv_path, newline="", encoding="utf-8") as csv_file:
        reader = csv.DictReader(csv_file)
        rows: List[Dict[str, str]] = []
        for row in reader:
            # Trim whitespace in all string fields to avoid path/key mismatches
            cleaned = {k.strip(): (v.strip() if isinstance(v, str) else v) for k, v in row.items()}
            rows.append(cleaned)
    return rows
def _build_image_pairs(rows: List[Dict[str, str]], task_name: str) -> List[Dict[str, str]]:
    grouped: Dict[Tuple[str, str], List[Dict[str, str]]] = {}
    for row in rows:
        key = (row["test_id"], row["org_img"])
        grouped.setdefault(key, []).append(row)
    pairs: List[Dict[str, str]] = []
    folder = TASK_CONFIG[task_name]["folder"]
    for (test_id, org_img), entries in grouped.items():
        for model_a, model_b in itertools.combinations(entries, 2):
            if model_a["model_name"] == model_b["model_name"]:
                continue
            org_path = os.path.join(folder, org_img)
            path_a = os.path.join(folder, model_a["path"])
            path_b = os.path.join(folder, model_b["path"])
            # Validate existence to avoid UI errors
            if not (_file_exists_under_base(org_path) and _file_exists_under_base(path_a) and _file_exists_under_base(path_b)):
                try:
                    print("[VisArena] Skipping invalid paths for test_id=", test_id, {
                        "org": org_path,
                        "a": path_a,
                        "b": path_b,
                    })
                except Exception:
                    pass
                continue
            pair = {
                "test_id": test_id,
                "org_img": org_path,
                "model1_name": model_a["model_name"],
                "model1_res": model_a["res"],
                "model1_path": path_a,
                "model2_name": model_b["model_name"],
                "model2_res": model_b["res"],
                "model2_path": path_b,
            }
            pairs.append(pair)

    def sort_key(item: Dict[str, str]):
        test_id = item["test_id"]
        try:
            test_id_key = int(test_id)
        except ValueError:
            test_id_key = test_id
        return (test_id_key, item["model1_name"], item["model2_name"])

    pairs.sort(key=sort_key)
    return pairs
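# Illustrative example of the pairing above: if test_id 7 has outputs from
# models X, Y, and Z for the same org_img, the combinations are (X, Y), (X, Z),
# and (Y, Z), i.e. C(3, 2) = 3 pairs; rows sharing a model_name are skipped.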
def _read_eval_counts(task_name: str) -> Dict[Tuple[str, frozenset, str], int]:
    """Global rating counts per pair key, across all annotators."""
    counts: Dict[Tuple[str, frozenset, str], int] = {}
    csv_path = _persist_csv_path_for_task(task_name)
    if not os.path.exists(csv_path):
        return counts
    try:
        with open(csv_path, newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for r in reader:
                tid = str(r.get("test_id", "")).strip()
                m1 = str(r.get("model1_name", "")).strip()
                m2 = str(r.get("model2_name", "")).strip()
                org = str(r.get("org_img", "")).strip()
                if not (tid and m1 and m2 and org):
                    continue
                key = (tid, frozenset({m1, m2}), org)
                counts[key] = counts.get(key, 0) + 1
    except Exception:
        pass
    return counts
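# Key shape example (illustrative): a row for test_id "7" comparing "X" and
# "Y" on "imgs/7.png" maps to ("7", frozenset({"X", "Y"}), "imgs/7.png"), so
# A/B display order never splits the count for the same underlying pair.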
def _read_user_done_keys(task_name: str, annotator_id: str) -> set:
    """Keys already evaluated by the given annotator.

    If the CSV has no annotator_id column (legacy rows), those rows are
    ignored for per-user filtering.
    """
    keys = set()
    if not annotator_id:
        return keys
    csv_path = _persist_csv_path_for_task(task_name)
    if not os.path.exists(csv_path):
        return keys
    try:
        with open(csv_path, newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for r in reader:
                if str(r.get("annotator_id", "")).strip() != str(annotator_id).strip():
                    continue
                tid = str(r.get("test_id", "")).strip()
                m1 = str(r.get("model1_name", "")).strip()
                m2 = str(r.get("model2_name", "")).strip()
                org = str(r.get("org_img", "")).strip()
                if tid and m1 and m2 and org:
                    keys.add((tid, frozenset({m1, m2}), org))
    except Exception:
        pass
    return keys
def _read_user_last_times(task_name: str, annotator_id: str) -> Dict[Tuple[str, frozenset, str], datetime]:
    """Return the user's most recent evaluation datetime per pair key."""
    last: Dict[Tuple[str, frozenset, str], datetime] = {}
    if not annotator_id:
        return last
    csv_path = _persist_csv_path_for_task(task_name)
    if not os.path.exists(csv_path):
        return last
    try:
        with open(csv_path, newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for r in reader:
                if str(r.get("annotator_id", "")).strip() != str(annotator_id).strip():
                    continue
                tid = str(r.get("test_id", "")).strip()
                m1 = str(r.get("model1_name", "")).strip()
                m2 = str(r.get("model2_name", "")).strip()
                org = str(r.get("org_img", "")).strip()
                dt = str(r.get("eval_date", "")).strip() or str(r.get("submit_ts", "")).strip()
                if not (tid and m1 and m2 and org and dt):
                    continue
                key = (tid, frozenset({m1, m2}), org)
                try:
                    t = datetime.fromisoformat(dt)
                except Exception:
                    continue
                if key not in last or t > last[key]:
                    last[key] = t
    except Exception:
        pass
    return last
def _schedule_round_robin_by_test_id(pairs: List[Dict[str, str]], seed: Optional[int] = None) -> List[Dict[str, str]]:
    """Interleave pairs across test_ids for balanced coverage; shuffle within each group."""
    groups: Dict[str, List[Dict[str, str]]] = {}
    for p in pairs:
        groups.setdefault(p["test_id"], []).append(p)
    rnd = random.Random(seed)
    for lst in groups.values():
        rnd.shuffle(lst)
    # Round-robin drain: take one pair per test_id per sweep until all groups are empty
    ordered: List[Dict[str, str]] = []
    while True:
        progressed = False
        for tid in sorted(groups.keys(), key=lambda x: (int(x) if x.isdigit() else x)):
            if groups[tid]:
                ordered.append(groups[tid].pop())
                progressed = True
        if not progressed:
            break
    return ordered
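# Illustrative example: groups {"1": [a, b], "2": [c]} drain to
# [one of a/b, c, the other of a/b] -- each test_id contributes one pair per
# sweep, so coverage stays balanced even if a rater stops mid-batch.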
def load_task(task_name: str, annotator_id: str = ""):
    if not task_name:
        raise gr.Error("Please select a task first.")
    rows = _load_task_rows(task_name)
    pairs_all = _build_image_pairs(rows, task_name)

    # Per-user filtering and global balancing
    def key_of(p: Dict[str, str]):
        return (p["test_id"], frozenset({p["model1_name"], p["model2_name"]}), p["org_img"])

    user_done_keys = _read_user_done_keys(task_name, annotator_id)
    user_last_times = _read_user_last_times(task_name, annotator_id)
    global_counts = _read_eval_counts(task_name)
    # Main eligible set: not yet rated by this user and still below the min-raters threshold
    pairs = [
        p for p in pairs_all
        if key_of(p) not in user_done_keys and global_counts.get(key_of(p), 0) < MIN_RATERS_PER_PAIR
    ]
    # Balanced schedule: prioritize low-count pairs; within the same count, round-robin by test_id
    seed_env = os.environ.get("SCHEDULE_SEED")
    seed = int(seed_env) if seed_env and seed_env.isdigit() else None

    def count_of(p: Dict[str, str]):
        return global_counts.get(key_of(p), 0)

    buckets: Dict[int, List[Dict[str, str]]] = {}
    for p in sorted(pairs, key=count_of):
        buckets.setdefault(count_of(p), []).append(p)
    ordered: List[Dict[str, str]] = []
    for c in sorted(buckets.keys()):
        ordered.extend(_schedule_round_robin_by_test_id(buckets[c], seed=seed))
    pairs = ordered
    # Deterministic rotation by the user's progress, so sessions don't always start from the same pairs
    try:
        elig_keys = {key_of(p) for p in pairs}
        progress = len([k for k in user_done_keys if k in elig_keys])
        if pairs:
            rot = progress % len(pairs)
            pairs = pairs[rot:] + pairs[:rot]
    except Exception:
        pass
    # Limit batch size
    main_batch = pairs[: max(0, BATCH_SIZE)]
    # Add a small proportion of spaced repeats for test-retest reliability
    repeats: List[Dict[str, str]] = []
    try:
        repeat_target = int(max(0, round(BATCH_SIZE * REPEAT_RATE)))
        if repeat_target > 0 and user_last_times:
            min_time = datetime.utcnow() - timedelta(hours=REPEAT_MIN_HOURS)
            candidates = [k for k, t in user_last_times.items() if t < min_time]

            def find_pair_from_key(k):
                tid, names, org = k
                for p in pairs_all:
                    if p["test_id"] == tid and p["org_img"] == org and frozenset({p["model1_name"], p["model2_name"]}) == names:
                        return p
                return None

            picked = 0
            used_keys = {key_of(p) for p in main_batch}
            for k in candidates:
                if picked >= repeat_target:
                    break
                p = find_pair_from_key(k)
                if not p:
                    continue
                if key_of(p) in used_keys:
                    continue
                repeats.append(p)
                used_keys.add(key_of(p))
                picked += 1
    except Exception:
        pass
    pairs = main_batch + repeats
    # Alternate A/B display order after scheduling to counteract position bias
    for idx, p in enumerate(pairs):
        p["swap"] = bool(idx % 2)  # True -> UI slot A shows model2's image, slot B shows model1's
    if not pairs:
        try:
            print("[VisArena] No pending pairs.")
            print("[VisArena] total_pairs=", len(pairs_all))
            print("[VisArena] already_done_by_user=", len(user_done_keys))
            print("[VisArena] persist_csv=", _persist_csv_path_for_task(task_name))
        except Exception:
            pass
        # Return an empty list; the UI renders an informative message instead of erroring out
        return []
    return pairs
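# Worked example under the defaults (BATCH_SIZE=20, REPEAT_RATE=0.05,
# MIN_RATERS_PER_PAIR=20): a batch is up to 20 fresh pairs plus
# round(20 * 0.05) = 1 spaced repeat (when an eligible >=24h-old pair exists),
# and any pair that already has 20 global ratings is excluded before scheduling.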
def _format_pair_header(_pair: Dict[str, str]) -> str:
    # Mask model identity in the UI; keep the header neutral
    return ""


def _build_eval_row(pair: Dict[str, str], scores: Dict[str, int]) -> Dict[str, object]:
    row = {
        "eval_date": datetime.utcnow().isoformat(),
        "test_id": pair["test_id"],
        "model1_name": pair["model1_name"],
        "model2_name": pair["model2_name"],
        "org_img": pair["org_img"],
        "model1_res": pair["model1_res"],
        "model2_res": pair["model2_res"],
        "model1_path": pair["model1_path"],
        "model2_path": pair["model2_path"],
    }
    row.update(scores)
    return row
def _local_persist_csv_path(task_name: str) -> str:
    # Same location as _persist_csv_path_for_task; kept as an alias for the callers below
    return _persist_csv_path_for_task(task_name)


def _append_local_persist_csv(task_name: str, row: Dict[str, object]) -> bool:
    csv_path = _local_persist_csv_path(task_name)
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    csv_exists = os.path.exists(csv_path)
    fieldnames = [
        "eval_date",
        "annotator_id",
        "session_id",
        "view_start_ts",
        "submit_ts",
        "duration_sec",
        "is_fast",
        "is_flat_a",
        "is_flat_b",
        "test_id",
        "model1_name",
        "model2_name",
        "org_img",
        "model1_res",
        "model2_res",
        "model1_path",
        "model2_path",
        "model1_physical_interaction_fidelity_score",
        "model1_optical_effect_accuracy_score",
        "model1_semantic_functional_alignment_score",
        "model1_overall_photorealism_score",
        "model2_physical_interaction_fidelity_score",
        "model2_optical_effect_accuracy_score",
        "model2_semantic_functional_alignment_score",
        "model2_overall_photorealism_score",
    ]
    try:
        with open(csv_path, "a", newline="", encoding="utf-8") as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
            if not csv_exists:
                writer.writeheader()
            writer.writerow(row)
        return True
    except Exception:
        return False
def _upload_eval_record_to_dataset(task_name: str, row: Dict[str, object]) -> Tuple[bool, str]:
    """Upload a single-evaluation JSONL record to a dataset repo.

    The repo is taken from the EVAL_REPO_ID env var and defaults to
    'peiranli0930/VisEval'. Returns (ok, message) for UI feedback and debugging.
    """
    if HfApi is None:
        return False, "huggingface_hub not installed"
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    repo_id = os.environ.get("EVAL_REPO_ID", "peiranli0930/VisEval")
    if not token:
        return False, "Missing write token (HF_TOKEN/HUGGINGFACEHUB_API_TOKEN)"
    if not repo_id:
        return False, "EVAL_REPO_ID is not set"
    try:
        from huggingface_hub import CommitOperationAdd

        api = HfApi(token=token)
        date_prefix = datetime.utcnow().strftime("%Y-%m-%d")
        folder = TASK_CONFIG[task_name]["folder"]
        uid = str(uuid.uuid4())
        path_in_repo = f"submissions/{folder}/{date_prefix}/{uid}.jsonl"
        payload = (json.dumps(row, ensure_ascii=False) + "\n").encode("utf-8")
        operations = [CommitOperationAdd(path_in_repo=path_in_repo, path_or_fileobj=BytesIO(payload))]
        api.create_commit(
            repo_id=repo_id,
            repo_type="dataset",
            operations=operations,
            commit_message=f"Add eval {folder} {row.get('test_id')} {uid}",
        )
        return True, f"Uploaded: {repo_id}/{path_in_repo}"
    except Exception as e:
        # Print to logs for debugging in the Space
        try:
            print("[VisArena] Upload to dataset failed:", repr(e))
        except Exception:
            pass
        return False, f"Exception: {type(e).__name__}: {e}"
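# Resulting dataset layout (illustrative): one JSONL file per submission, e.g.
#   peiranli0930/VisEval/submissions/scene_composition_and_object_insertion/2024-01-31/<uuid>.jsonl
# Appending one small file per commit sidesteps read-modify-write races between
# concurrent raters; aggregation is deferred to whoever consumes the dataset.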
def _extract_annotator_id(request: Optional[gr.Request]) -> str:
    """Best-effort extraction of a stable user identifier on HF Spaces.

    Priority: request.username, then forwarded-user headers. Returns "" when
    neither is available.
    """
    try:
        if request is None:
            return ""
        # gradio>=4.0 may set username for Spaces-authenticated users
        username = getattr(request, "username", None)
        if username:
            return str(username)
        headers = getattr(request, "headers", {}) or {}
        for k in ("x-forwarded-user", "x-user", "x-hub-user"):
            v = headers.get(k) or headers.get(k.upper())
            if v:
                return str(v)
    except Exception:
        pass
    return ""
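# Illustrative behavior (assumed, not verified for every Spaces setup): behind
# HF login, request.username is expected to be set and wins; otherwise a
# forwarded-user header is used if the proxy provides one; with neither, ""
# is returned and the UI asks the rater to log in (see on_task_change below).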
def on_task_change(task_name: str, _state_pairs: List[Dict[str, str]], request: gr.Request,
                   view_started_at: float, session_quota: int, reload_count: int, session_id: str):
    annotator_id = _extract_annotator_id(request)
    if not annotator_id:
        default_scores = [3, 3, 3, 3, 3, 3, 3, 3]
        return (
            [],
            gr.update(value=0, minimum=0, maximum=0, visible=False),
            gr.update(value=""),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
            *default_scores,
            gr.update(value="Please log in to your Hugging Face account before starting the evaluation."),
            float(datetime.utcnow().timestamp()),
            BATCH_SIZE,
            0,
            session_id or str(uuid.uuid4()),
        )
    pairs = load_task(task_name, annotator_id)
    # Defaults for A and B (8 sliders total)
    default_scores = [3, 3, 3, 3, 3, 3, 3, 3]
    if not pairs:
        return (
            [],
            gr.update(value=0, minimum=0, maximum=0, visible=False),
            gr.update(value=""),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
            *default_scores,
            gr.update(value="No pairs are pending evaluation (or the minimum-rater threshold has been reached)."),
            float(datetime.utcnow().timestamp()),
            BATCH_SIZE,
            0,
            session_id or str(uuid.uuid4()),
        )
    pair = pairs[0]
    header = _format_pair_header(pair)
    # Pick display order according to the swap flag
    a_path = pair["model2_path"] if pair.get("swap") else pair["model1_path"]
    b_path = pair["model1_path"] if pair.get("swap") else pair["model2_path"]
    max_index = max(0, len(pairs) - 1)
    return (
        pairs,
        gr.update(value=0, minimum=0, maximum=max_index, visible=(len(pairs) > 1)),
        gr.update(value=header),
        _resolve_image_path(pair["org_img"]),
        _resolve_image_path(a_path),
        _resolve_image_path(b_path),
        *default_scores,
        gr.update(value=f"This batch assigns {len(pairs)} pairs; target is {MIN_RATERS_PER_PAIR} raters per pair."),
        float(datetime.utcnow().timestamp()),
        BATCH_SIZE,
        0,
        session_id or str(uuid.uuid4()),
    )
def on_pair_navigate(index: int, pairs: List[Dict[str, str]], view_started_at: float):
    if not pairs:
        # Gracefully no-op when there are no pairs
        return (
            gr.update(value=0, minimum=0, maximum=0, visible=False),
            gr.update(value=""),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
            3, 3, 3, 3,  # A
            3, 3, 3, 3,  # B
            float(datetime.utcnow().timestamp()),
        )
    index = int(index)
    index = max(0, min(index, len(pairs) - 1))
    pair = pairs[index]
    header = _format_pair_header(pair)
    a_path = pair["model2_path"] if pair.get("swap") else pair["model1_path"]
    b_path = pair["model1_path"] if pair.get("swap") else pair["model2_path"]
    return (
        gr.update(value=index),
        gr.update(value=header),
        _resolve_image_path(pair["org_img"]),
        _resolve_image_path(a_path),
        _resolve_image_path(b_path),
        3, 3, 3, 3,  # A
        3, 3, 3, 3,  # B
        float(datetime.utcnow().timestamp()),
    )
def on_submit(
    task_name: str,
    index: int,
    pairs: List[Dict[str, str]],
    a_physical_score: int,
    a_optical_score: int,
    a_semantic_score: int,
    a_overall_score: int,
    b_physical_score: int,
    b_optical_score: int,
    b_semantic_score: int,
    b_overall_score: int,
    request: gr.Request,
    view_started_at: float,
    session_quota: int,
    reload_count: int,
    session_id: str,
):
    if not task_name:
        return (
            pairs,
            gr.update(value=0),
            gr.update(value=""),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
            3, 3, 3, 3,
            3, 3, 3, 3,
            gr.update(value="Please select a task first."),
            float(datetime.utcnow().timestamp()),
            session_quota,
            reload_count,
            session_id,
        )
    if not pairs:
        return (
            pairs,
            gr.update(value=0, minimum=0, maximum=0, visible=False),
            gr.update(value=""),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
            3, 3, 3, 3,
            3, 3, 3, 3,
            gr.update(value="No pending pairs to submit."),
            float(datetime.utcnow().timestamp()),
            session_quota,
            reload_count,
            session_id,
        )
    # Resolve the annotator id from the injected request
    annotator_id = _extract_annotator_id(request)
    # Clamp the slider value (it may arrive as a float) before indexing
    index = max(0, min(int(index), len(pairs) - 1))
    pair = pairs[index]
    score_map = {
        # Model A
        "model1_physical_interaction_fidelity_score": int(a_physical_score),
        "model1_optical_effect_accuracy_score": int(a_optical_score),
        "model1_semantic_functional_alignment_score": int(a_semantic_score),
        "model1_overall_photorealism_score": int(a_overall_score),
        # Model B
        "model2_physical_interaction_fidelity_score": int(b_physical_score),
        "model2_optical_effect_accuracy_score": int(b_optical_score),
        "model2_semantic_functional_alignment_score": int(b_semantic_score),
        "model2_overall_photorealism_score": int(b_overall_score),
    }
    # Map A/B scores to the correct model columns depending on the swap flag
    if pair.get("swap"):
        # UI A == model2, UI B == model1
        score_map = {
            "model1_physical_interaction_fidelity_score": int(b_physical_score),
            "model1_optical_effect_accuracy_score": int(b_optical_score),
            "model1_semantic_functional_alignment_score": int(b_semantic_score),
            "model1_overall_photorealism_score": int(b_overall_score),
            "model2_physical_interaction_fidelity_score": int(a_physical_score),
            "model2_optical_effect_accuracy_score": int(a_optical_score),
            "model2_semantic_functional_alignment_score": int(a_semantic_score),
            "model2_overall_photorealism_score": int(a_overall_score),
        }
    # Build the record
    row = _build_eval_row(pair, score_map)
    row["annotator_id"] = annotator_id
    # Timing + quality heuristics
    submit_ts = datetime.utcnow()
    try:
        started = datetime.utcfromtimestamp(float(view_started_at)) if view_started_at else submit_ts
    except Exception:
        started = submit_ts
    duration = max(0.0, (submit_ts - started).total_seconds())
    row["view_start_ts"] = started.isoformat()
    row["submit_ts"] = submit_ts.isoformat()
    row["duration_sec"] = round(duration, 3)
    row["is_fast"] = bool(duration < FAST_MIN_SEC)
    row["is_flat_a"] = bool(len({int(a_physical_score), int(a_optical_score), int(a_semantic_score), int(a_overall_score)}) == 1)
    row["is_flat_b"] = bool(len({int(b_physical_score), int(b_optical_score), int(b_semantic_score), int(b_overall_score)}) == 1)
    row["session_id"] = session_id or str(uuid.uuid4())
    # Idempotency: if this pair was already evaluated by the user, skip writing
    done_keys = _read_user_done_keys(task_name, annotator_id)
    eval_key = (pair["test_id"], frozenset({pair["model1_name"], pair["model2_name"]}), pair["org_img"])
    if eval_key in done_keys:
        ok_local = False
        ok_hub, hub_msg = (False, "Skipped duplicate; already evaluated.")
        info_prefix = "Skipped duplicate submission."
    else:
        ok_local = _append_local_persist_csv(task_name, row)
        # Add the key locally so filtering later in this call sees it
        if ok_local:
            done_keys.add(eval_key)
        ok_hub, hub_msg = _upload_eval_record_to_dataset(task_name, row)
        info_prefix = "Saved evaluation."

    # Recompute remaining pairs by filtering current state against done_keys
    def key_of(p: Dict[str, str]):
        return (p["test_id"], frozenset({p["model1_name"], p["model2_name"]}), p["org_img"])

    remaining_pairs = [p for p in pairs if key_of(p) not in done_keys]
    info = f"{info_prefix} Local persistence " + ("succeeded" if ok_local else "skipped/failed") + "."
    info += " Dataset upload " + ("succeeded" if ok_hub else "failed") + (f" ({hub_msg})" if hub_msg else "") + "."
    # Quota + reload bookkeeping
    session_quota = max(0, int(session_quota) - 1)
    reload_count = int(reload_count) + 1
    # Periodic reload to absorb a new results.csv and re-balance
    if reload_count >= RELOAD_EVERY:
        fresh_pairs = load_task(task_name, annotator_id)
        remaining_pairs = fresh_pairs
        reload_count = 0
    if session_quota <= 0:
        return (
            [],
            gr.update(value=0, minimum=0, maximum=0, visible=False),
            gr.update(value=""),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
            3, 3, 3, 3,
            3, 3, 3, 3,
            gr.update(value=info + f" This batch of {BATCH_SIZE} pairs is complete; please refresh the page for the next batch."),
            float(datetime.utcnow().timestamp()),
            session_quota,
            reload_count,
            row["session_id"],
        )
    if remaining_pairs:
        next_index = min(index, len(remaining_pairs) - 1)
        pair = remaining_pairs[next_index]
        header = _format_pair_header(pair)
        a_path = pair["model2_path"] if pair.get("swap") else pair["model1_path"]
        b_path = pair["model1_path"] if pair.get("swap") else pair["model2_path"]
        return (
            remaining_pairs,
            gr.update(value=next_index),
            gr.update(value=header),
            _resolve_image_path(pair["org_img"]),
            _resolve_image_path(a_path),
            _resolve_image_path(b_path),
            3, 3, 3, 3,
            3, 3, 3, 3,
            gr.update(value=info + f" Next pair ({next_index + 1}/{len(remaining_pairs)})."),
            float(datetime.utcnow().timestamp()),
            session_quota,
            reload_count,
            row["session_id"],
        )
    # No remaining pairs: clear the UI, hide the slider, and return an empty state
    return (
        [],
        gr.update(value=0, minimum=0, maximum=0, visible=False),
        gr.update(value=""),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=None),
        3, 3, 3, 3,
        3, 3, 3, 3,
        gr.update(value=info + " All pairs completed."),
        float(datetime.utcnow().timestamp()),
        session_quota,
        reload_count,
        row["session_id"],
    )
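# Swap bookkeeping example (illustrative): for a pair with swap=True, the UI
# shows model2's image in slot A and model1's in slot B, so a rater giving
# A=5 / B=2 is recorded as model2_*_score=5 and model1_*_score=2. The stored
# row always refers to the true models, never the on-screen slots.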
with gr.Blocks(title="VisArena Human Evaluation") as demo:
    gr.Markdown(
        """
# VisArena Human Evaluation

Please select a task and rate the generated images. Each score ranges from 1 (poor) to 5 (excellent).
"""
    )
    with gr.Row():
        task_selector = gr.Dropdown(
            label="Task",
            choices=list(TASK_CONFIG.keys()),
            interactive=True,
            value="Scene Composition & Object Insertion",
        )
        index_slider = gr.Slider(
            label="Pair Index",
            value=0,
            minimum=0,
            maximum=0,
            step=1,
            interactive=True,
            visible=False,
        )
    pair_state = gr.State([])
    # Hidden states for session control and metrics
    view_started_at_state = gr.State(0.0)
    session_quota_state = gr.State(BATCH_SIZE)
    reload_count_state = gr.State(0)
    session_id_state = gr.State("")
    pair_header = gr.Markdown("")
    # Layout: original on top, two outputs below, each with its own sliders
    with gr.Row():
        with gr.Column(scale=12):
            orig_image = gr.Image(type="filepath", label="Original", interactive=False)
    with gr.Row():
        with gr.Column(scale=6):
            model1_image = gr.Image(type="filepath", label="Output A", interactive=False)
            a_physical_input = gr.Slider(1, 5, value=3, step=1, label="A: Physical Interaction Fidelity")
            a_optical_input = gr.Slider(1, 5, value=3, step=1, label="A: Optical Effect Accuracy")
            a_semantic_input = gr.Slider(1, 5, value=3, step=1, label="A: Semantic/Functional Alignment")
            a_overall_input = gr.Slider(1, 5, value=3, step=1, label="A: Overall Photorealism")
        with gr.Column(scale=6):
            model2_image = gr.Image(type="filepath", label="Output B", interactive=False)
            b_physical_input = gr.Slider(1, 5, value=3, step=1, label="B: Physical Interaction Fidelity")
            b_optical_input = gr.Slider(1, 5, value=3, step=1, label="B: Optical Effect Accuracy")
            b_semantic_input = gr.Slider(1, 5, value=3, step=1, label="B: Semantic/Functional Alignment")
            b_overall_input = gr.Slider(1, 5, value=3, step=1, label="B: Overall Photorealism")
    submit_button = gr.Button("Submit Evaluation", variant="primary")
    feedback_box = gr.Markdown("")
    # Event bindings (gr.Request-typed parameters are injected by Gradio and
    # deliberately omitted from the inputs lists)
    task_selector.change(
        fn=on_task_change,
        inputs=[task_selector, pair_state, view_started_at_state, session_quota_state, reload_count_state, session_id_state],
        outputs=[
            pair_state,
            index_slider,
            pair_header,
            orig_image,
            model1_image,
            model2_image,
            a_physical_input,
            a_optical_input,
            a_semantic_input,
            a_overall_input,
            b_physical_input,
            b_optical_input,
            b_semantic_input,
            b_overall_input,
            feedback_box,
            view_started_at_state,
            session_quota_state,
            reload_count_state,
            session_id_state,
        ],
    )
    index_slider.release(
        fn=on_pair_navigate,
        inputs=[index_slider, pair_state, view_started_at_state],
        outputs=[
            index_slider,
            pair_header,
            orig_image,
            model1_image,
            model2_image,
            a_physical_input,
            a_optical_input,
            a_semantic_input,
            a_overall_input,
            b_physical_input,
            b_optical_input,
            b_semantic_input,
            b_overall_input,
            view_started_at_state,
        ],
    )
    submit_button.click(
        fn=on_submit,
        inputs=[
            task_selector,
            index_slider,
            pair_state,
            a_physical_input,
            a_optical_input,
            a_semantic_input,
            a_overall_input,
            b_physical_input,
            b_optical_input,
            b_semantic_input,
            b_overall_input,
            view_started_at_state,
            session_quota_state,
            reload_count_state,
            session_id_state,
        ],
        outputs=[
            pair_state,
            index_slider,
            pair_header,
            orig_image,
            model1_image,
            model2_image,
            a_physical_input,
            a_optical_input,
            a_semantic_input,
            a_overall_input,
            b_physical_input,
            b_optical_input,
            b_semantic_input,
            b_overall_input,
            feedback_box,
            view_started_at_state,
            session_quota_state,
            reload_count_state,
            session_id_state,
        ],
    )
    # Auto-load the default task on startup
    demo.load(
        fn=on_task_change,
        inputs=[task_selector, pair_state, view_started_at_state, session_quota_state, reload_count_state, session_id_state],
        outputs=[
            pair_state,
            index_slider,
            pair_header,
            orig_image,
            model1_image,
            model2_image,
            a_physical_input,
            a_optical_input,
            a_semantic_input,
            a_overall_input,
            b_physical_input,
            b_optical_input,
            b_semantic_input,
            b_overall_input,
            feedback_box,
            view_started_at_state,
            session_quota_state,
            reload_count_state,
            session_id_state,
        ],
    )

if __name__ == "__main__":
    demo.queue().launch()