Utilize HF's "balanced" device_map + dynamically pair diffusion components to relevant execution cores

#1
by diopside - opened
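
The idea of the change: instead of pushing the whole pipeline to a single `cuda` device, let Accelerate spread the sub-models across the visible GPUs with `device_map="balanced"`, read back where each component landed from `pipe.hf_device_map`, and then pin the ControlNet (and the seed generator) to the same device as the VAE so they stay colocated. A minimal sketch of that pattern is below; the `from diffusers import ...` path, the three-GPU split, and the exact device indices are assumptions about this Space's hardware, not something `"balanced"` guarantees.

```python
import torch
from diffusers import QwenImageControlNetInpaintPipeline, QwenImageControlNetModel

# Let Accelerate shard the sub-models (transformer, text_encoder, vae, ...)
# across whatever GPUs are visible instead of loading everything onto cuda:0.
pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
    "Qwen/Qwen-Image",
    controlnet=None,                 # attach the ControlNet afterwards
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)

# hf_device_map reports the placement Accelerate chose,
# e.g. {'transformer': 0, 'text_encoder': 1, 'vae': 2} on a 3-GPU machine.
vae_device = f"cuda:{pipe.hf_device_map['vae']}"

# Load the ControlNet separately and colocate it with the VAE.
controlnet = QwenImageControlNetModel.from_pretrained(
    "InstantX/Qwen-Image-ControlNet-Inpainting",
    torch_dtype=torch.bfloat16,
).to(vae_device)
pipe.controlnet = controlnet

# Any torch.Generator passed to the pipeline should live on that same device.
generator = torch.Generator(device=vae_device).manual_seed(0)
```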
Files changed (1)
  1. app.py +34 -7
app.py CHANGED
@@ -71,12 +71,39 @@ def use_output_as_input(output_image):
 base_model = "Qwen/Qwen-Image"
 controlnet_model = "InstantX/Qwen-Image-ControlNet-Inpainting"
 
-controlnet = QwenImageControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
-
+# First create the pipeline with device_map="balanced"
 pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
-    base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
+    base_model,
+    controlnet=None,  # We'll add the controlnet later
+    torch_dtype=torch.bfloat16,
+    device_map="balanced"
 )
-pipe.to("cuda")
+
+pipe_device_map = pipe.hf_device_map
+print("Initial device map:", pipe_device_map)
+# Expected output: {'transformer': 0, 'text_encoder': 1, 'vae': 2}
+
+# Move the controlnet to the same device as the VAE (cuda:2)
+vae_device = pipe_device_map['vae']
+vae_device = f"cuda:{vae_device}"  # This is where the VAE is in the balanced config
+controlnet = QwenImageControlNetModel.from_pretrained(
+    controlnet_model,
+    torch_dtype=torch.bfloat16
+).to(vae_device)
+
+# Attach the controlnet to the pipeline
+pipe.controlnet = controlnet
+
+pipe.enable_vae_slicing()
+pipe.enable_vae_tiling()
+
+print("Controlnet device:", next(pipe.controlnet.parameters()).device)
+print("VAE device:", next(pipe.vae.parameters()).device)
+
+
+# Create a helper function to get a generator on the correct device
+def get_generator(seed):
+    return torch.Generator(device=vae_device).manual_seed(seed)
 
 
 @spaces.GPU(duration=150)
@@ -93,7 +120,7 @@ def infer(edit_images,
 
     image = edit_images["background"]
     mask = edit_images["layers"][0]
-
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
@@ -113,7 +140,7 @@ def infer(edit_images,
         width=image.size[0],
         height=image.size[1],
         true_cfg_scale=true_cfg_scale,
-        generator=torch.Generator(device="cuda").manual_seed(seed)
+        generator=get_generator(seed)
     ).images[0]
 
     return [image, result_image], seed
@@ -140,7 +167,7 @@ css = """
 
 
 with gr.Blocks(css=css, theme=gr.themes.Citrus()) as demo:
-    gr.HTML("<h1 style='text-align: center'>Qwen-Image with InstantX Inpainting ControlNet</style>")
+    gr.HTML("<h1 style='text-align: center'>Qwen-Image + InstantX Inpainting ControlNet</style>")
     gr.Markdown(
         "Inpaint images with [InstantX/Qwen-Image-ControlNet-Inpainting](https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting)"
     )
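
One quick way to sanity-check the balanced placement at runtime (not part of the diff above, just a small verification sketch) is to print how much memory each visible GPU is holding right after the pipeline loads; roughly even numbers across devices suggest the split worked:

```python
import torch

# Rough check: each visible GPU should be holding a chunk of the weights
# once device_map="balanced" has placed the components.
for i in range(torch.cuda.device_count()):
    allocated_gib = torch.cuda.memory_allocated(i) / 1024**3
    print(f"cuda:{i}: {allocated_gib:.1f} GiB allocated")
```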