Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import logging | |
| import os | |
| import sys | |
| import PIL.Image | |
| import torch | |
| from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline, PNDMPipeline, | |
| PNDMScheduler) | |
# Hugging Face Hub access token, forwarded as ``use_auth_token`` when the
# pipelines are loaded.  Read with ``.get`` so that a missing variable no
# longer aborts the import with a bare ``KeyError``: the value is then ``None``
# and the Hub is accessed without an explicit token.
# NOTE(review): if the checkpoint repo is private, loading will still fail
# later with an authentication error from the Hub client -- confirm that this
# is the preferred failure point.
HF_TOKEN = os.environ.get('HF_TOKEN')
# Record layout: "[timestamp] logger-name LEVEL: message".
formatter = logging.Formatter(
    fmt='[%(asctime)s] %(name)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Handler that writes INFO-and-above records to stdout.
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)

# Module-level logger; records are not passed on to ancestor loggers.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.propagate = False
# ``logging.getLogger`` returns the same object every time this module is
# executed (e.g. on ``importlib.reload``).  Attaching the handler
# unconditionally would therefore stack handlers and print each record
# several times, so only attach when the logger has none yet.
if not logger.handlers:
    logger.addHandler(stream_handler)
class Model:
    """Image generator built on ``diffusers`` pipelines.

    Checkpoints are loaded from Hub repositories named
    ``hysts/diffusers-anime-faces-<model_name>``.  The instance keeps exactly
    one pipeline in memory (``self.pipeline``) together with the model name
    and scheduler type it was built for, and swaps it on demand.
    """

    # Checkpoint name suffixes accepted as ``model_name``.
    MODEL_NAMES = [
        'ddpm-128-exp000',
    ]

    def __init__(self, device: str | torch.device):
        """Load every known checkpoint once, then build the default pipeline.

        Args:
            device: Device on which sampling is run; passed to
                ``torch.device``.
        """
        self.device = torch.device(device)
        self._download_all_models()
        self.model_name = self.MODEL_NAMES[0]
        self.scheduler_type = 'DDIM'
        self.pipeline = self._load_pipeline(self.model_name,
                                            self.scheduler_type)

    def _load_pipeline(
        self,
        model_name: str,
        scheduler_type: str,
    ) -> DDIMPipeline | DDPMPipeline | PNDMPipeline:
        """Build the pipeline for ``model_name`` using ``scheduler_type``.

        Args:
            model_name: Checkpoint name suffix (see ``MODEL_NAMES``).
            scheduler_type: One of ``'DDPM'``, ``'DDIM'`` or ``'PNDM'``.

        Returns:
            The loaded pipeline.  For DDIM and PNDM its scheduler is replaced
            by a fresh scheduler of the matching class.

        Raises:
            ValueError: If ``scheduler_type`` is not one of the three
                supported values.
        """
        repo_id = f'hysts/diffusers-anime-faces-{model_name}'

        if scheduler_type == 'DDPM':
            # The DDPM pipeline is used exactly as stored in the checkpoint.
            return DDPMPipeline.from_pretrained(repo_id,
                                                use_auth_token=HF_TOKEN)

        if scheduler_type == 'DDIM':
            pipeline_cls, scheduler_cls = DDIMPipeline, DDIMScheduler
        elif scheduler_type == 'PNDM':
            pipeline_cls, scheduler_cls = PNDMPipeline, PNDMScheduler
        else:
            raise ValueError(f'Unknown scheduler type: {scheduler_type!r}')

        pipeline = pipeline_cls.from_pretrained(repo_id,
                                                use_auth_token=HF_TOKEN)
        # Re-create the scheduler as ``scheduler_cls`` from the config stored
        # with the checkpoint, keeping only the entries that class accepts.
        # NOTE(review): presumably the checkpoint ships a scheduler of a
        # different class than the requested one -- confirm.
        config, _ = scheduler_cls.extract_init_dict(
            dict(pipeline.scheduler.config))
        pipeline.scheduler = scheduler_cls(**config)
        return pipeline

    def set_pipeline(self, model_name: str, scheduler_type: str) -> None:
        """Switch the active pipeline if the requested combination differs.

        Nothing is reloaded when ``model_name`` and ``scheduler_type`` match
        the currently active pipeline.

        Args:
            model_name: Checkpoint name suffix (see ``MODEL_NAMES``).
            scheduler_type: One of ``'DDPM'``, ``'DDIM'`` or ``'PNDM'``.

        Raises:
            ValueError: If ``scheduler_type`` is not supported.  In that case
                (or on any other loading error) the previously active
                pipeline and its recorded settings are left untouched.
        """
        logger.info('--- set_pipeline ---')
        logger.info(f'{model_name=}, {scheduler_type=}')

        unchanged = (model_name == self.model_name
                     and scheduler_type == self.scheduler_type)
        if unchanged:
            logger.info('Skipping')
            logger.info('--- done ---')
            return

        # Load before touching any attribute: if loading raises, the recorded
        # name/scheduler must keep describing ``self.pipeline``; otherwise a
        # retry with the same arguments would hit the "Skipping" branch and
        # silently keep the stale pipeline.
        pipeline = self._load_pipeline(model_name, scheduler_type)
        self.model_name = model_name
        self.scheduler_type = scheduler_type
        self.pipeline = pipeline
        logger.info('--- done ---')

    def _download_all_models(self):
        """Load each checkpoint in ``MODEL_NAMES`` once and discard the result.

        NOTE(review): presumably this warms the local download cache so that
        later pipeline switches do not hit the network -- confirm.
        """
        for name in self.MODEL_NAMES:
            self._load_pipeline(name, 'DDPM')

    def generate(self, seed: int, num_steps: int) -> PIL.Image.Image:
        """Sample one image from the active pipeline.

        Args:
            seed: Value passed to ``torch.manual_seed`` before sampling.
            num_steps: Number of inference steps; used for DDIM and PNDM and
                ignored for DDPM.

        Returns:
            The first element of the pipeline output's ``'sample'`` entry.

        Raises:
            ValueError: If ``self.scheduler_type`` is not supported.
        """
        logger.info('--- generate ---')
        logger.info(f'{seed=}, {num_steps=}')

        torch.manual_seed(seed)
        # NOTE(review): ``torch_device=`` and the ``['sample']`` output key are
        # kept exactly as in the original code; they depend on the installed
        # diffusers version -- confirm before upgrading the dependency.
        if self.scheduler_type == 'DDPM':
            res = self.pipeline(batch_size=1,
                                torch_device=self.device)['sample'][0]
        elif self.scheduler_type in ('DDIM', 'PNDM'):
            res = self.pipeline(batch_size=1,
                                torch_device=self.device,
                                num_inference_steps=num_steps)['sample'][0]
        else:
            raise ValueError(
                f'Unknown scheduler type: {self.scheduler_type!r}')

        logger.info('--- done ---')
        return res

    def run(
        self,
        model_name: str,
        scheduler_type: str,
        num_steps: int,
        seed: int,
    ) -> PIL.Image.Image:
        """Select the requested pipeline and generate one image with it.

        Args:
            model_name: Checkpoint name suffix (see ``MODEL_NAMES``).
            scheduler_type: One of ``'DDPM'``, ``'DDIM'`` or ``'PNDM'``.
            num_steps: Requested number of inference steps.
            seed: Random seed for sampling.

        Returns:
            The generated image.
        """
        self.set_pipeline(model_name, scheduler_type)
        if scheduler_type == 'PNDM':
            # For PNDM the step count is clamped to the range [4, 100].
            num_steps = max(4, min(num_steps, 100))
        return self.generate(seed, num_steps)