kr-cen commited on
Commit
aa391b3
·
verified ·
1 Parent(s): 4c42ed2

Delete transformer/modeling_qwen_image.py

Browse files
Files changed (1) hide show
  1. transformer/modeling_qwen_image.py +0 -935
transformer/modeling_qwen_image.py DELETED
@@ -1,935 +0,0 @@
1
- # Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import inspect
16
- import math
17
- from typing import Any, Callable, Dict, List, Optional, Union
18
-
19
- from PIL import Image
20
- import numpy as np
21
- import torch
22
- from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
23
-
24
- from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
25
- from diffusers.loaders import QwenImageLoraLoaderMixin
26
- from diffusers.models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
27
- from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
28
- from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
29
- from diffusers.utils.torch_utils import randn_tensor
30
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
- from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput
32
-
33
- from transformers import PretrainedConfig
34
-
35
class QwenImageConfig(PretrainedConfig):
    """Configuration for the Qwen-Image transformer (``QwenImageTransform2DModel``-style).

    Holds the architecture hyper-parameters so the transformer can be
    instantiated from a JSON config through the ``transformers`` config
    machinery.
    """

    # Must match the "model_type" field in the config JSON.
    model_type = "qwen_image_transformer"

    def __init__(
        self,
        attention_head_dim=128,
        num_attention_heads=24,
        num_layers=60,
        in_channels=64,
        out_channels=16,
        patch_size=2,
        joint_attention_dim=3584,
        axes_dims_rope=None,
        guidance_embeds=False,
        **kwargs,
    ):
        """
        Args:
            attention_head_dim (`int`): Dimension of each attention head.
            num_attention_heads (`int`): Number of attention heads per block.
            num_layers (`int`): Number of transformer blocks.
            in_channels (`int`): Channels of the (packed) latent input.
            out_channels (`int`): Channels of the latent output.
            patch_size (`int`): Latent patch size used for packing.
            joint_attention_dim (`int`): Dimension of the text-encoder hidden states.
            axes_dims_rope (`list`, *optional*): Per-axis RoPE dimensions;
                defaults to ``[16, 56, 56]``.
            guidance_embeds (`bool`): Whether the model uses guidance embeddings.
        """
        self.attention_head_dim = attention_head_dim
        self.num_attention_heads = num_attention_heads
        self.num_layers = num_layers
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.joint_attention_dim = joint_attention_dim
        # Default resolved here instead of a mutable default argument, so the
        # list is never shared between instances.
        self.axes_dims_rope = [16, 56, 56] if axes_dims_rope is None else axes_dims_rope
        self.guidance_embeds = guidance_embeds
        super().__init__(**kwargs)
61
-
62
-
63
if is_torch_xla_available():
    # NOTE(review): `xm` is presumably used for `xm.mark_step()` inside the
    # denoising loop — confirm; the call site is not visible in this chunk.
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


# Module-level logger, following the diffusers convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Usage example injected into `__call__`'s docstring via
# `@replace_example_docstring(EXAMPLE_DOC_STRING)`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from PIL import Image
        >>> from diffusers import QwenImageEditPipeline
        >>> from diffusers.utils import load_image

        >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")
        >>> image = load_image(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
        ... ).convert("RGB")
        >>> prompt = (
        ...     "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
        ... )
        >>> # Depending on the variant being used, the pipeline call will slightly vary.
        >>> # Refer to the pipeline documentation for more details.
        >>> image = pipe(image, prompt, num_inference_steps=50).images[0]
        >>> image.save("qwenimage_edit.png")
        ```
"""
95
-
96
-
97
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the flow-matching shift parameter ``mu``.

    Maps ``image_seq_len`` onto the line passing through
    ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``, so longer
    image token sequences receive a larger timestep shift.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
109
-
110
-
111
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call ``scheduler.set_timesteps`` and return the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps`, or `sigmas` drives the
    schedule; extra `kwargs` are forwarded to `set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps; ignored when a custom schedule is given.
        device (`str` or `torch.device`, *optional*):
            Device the timesteps should be moved to; `None` leaves them in place.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule overriding the scheduler's own spacing.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule overriding the scheduler's own spacing.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the number of
        inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Default path: the scheduler derives its own schedule.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom-schedule path: verify the scheduler accepts the override kwarg.
    if timesteps is not None:
        schedule_kind = "timestep"
        override = {"timesteps": timesteps}
    else:
        schedule_kind = "sigmas"
        override = {"sigmas": sigmas}

    supported = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if next(iter(override)) not in supported:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {schedule_kind} schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(device=device, **override, **kwargs)
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
169
-
170
-
171
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull a latent tensor out of a VAE ``encode`` result.

    Supports outputs exposing a ``latent_dist`` (sampled or argmax/mode) or a
    plain ``latents`` attribute; anything else raises ``AttributeError``.
    """
    if hasattr(encoder_output, "latent_dist"):
        dist = encoder_output.latent_dist
        if sample_mode == "sample":
            return dist.sample(generator)
        if sample_mode == "argmax":
            return dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
183
-
184
-
185
def calculate_dimensions(target_area, ratio):
    """Return ``(width, height, None)`` with roughly ``target_area`` pixels.

    The pair has aspect ratio ``ratio`` (width / height) and both sides are
    rounded to the nearest multiple of 32. The trailing ``None`` is kept for
    interface compatibility with callers expecting a 3-tuple.
    """
    exact_width = math.sqrt(target_area * ratio)
    exact_height = exact_width / ratio

    snapped_width = round(exact_width / 32) * 32
    snapped_height = round(exact_height / 32) * 32

    return snapped_width, snapped_height, None
193
-
194
-
195
def resize_to_multiple_of(image, multiple_of=32):
    """Resize ``image`` so both sides are the nearest multiple of ``multiple_of``.

    Uses the default PIL resampling filter; the original image object is not
    modified, a resized copy is returned.
    """
    current_width, current_height = image.size
    snapped_width = round(current_width / multiple_of) * multiple_of
    snapped_height = round(current_height / multiple_of) * multiple_of
    return image.resize((snapped_width, snapped_height))
203
-
204
-
205
-
206
class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
    r"""
    The Qwen-Image-Edit pipeline for image editing.

    Args:
        transformer ([`QwenImageTransformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLQwenImage`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
            [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), used as the
            multimodal text encoder for prompt + conditioning-image embeddings.
        tokenizer (`Qwen2Tokenizer`):
            Tokenizer matching the text encoder.
        processor (`Qwen2VLProcessor`):
            Processor that prepares combined text/image inputs for the text encoder.
    """

    # Submodule order used for sequential CPU offload.
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    # Tensor names that `callback_on_step_end_tensor_inputs` may reference.
    _callback_tensor_inputs = ["latents", "prompt_embeds"]
227
-
228
    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKLQwenImage,
        text_encoder: Qwen2_5_VLForConditionalGeneration,
        tokenizer: Qwen2Tokenizer,
        processor: Qwen2VLProcessor,
        transformer: QwenImageTransformer2DModel,
    ):
        """Register submodules and derive latent/image-processing constants.

        `register_modules` must run first: the scale-factor computations below
        read `self.vae`, which it installs.
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            processor=processor,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Spatial compression of the VAE (`temperal_downsample` is the actual
        # attribute name in diffusers, typo included); 8 is the fallback.
        self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
        self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
        # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.vl_processor = processor
        self.tokenizer_max_length = 1024

        self.system_message = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
        self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
        # NOTE(review): 64 is presumably the token length of the template
        # prefix that gets dropped from the encoder output — confirm it matches
        # the template above under the tokenizer in use.
        self.prompt_template_encode_start_idx = 64
        self.default_sample_size = 128
258
-
259
- # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
260
- def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
261
- bool_mask = mask.bool()
262
- valid_lengths = bool_mask.sum(dim=1)
263
- selected = hidden_states[bool_mask]
264
- split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
265
-
266
- return split_result
267
-
268
    def _get_qwen_prompt_embeds(
        self,
        prompts: Union[str, List[str]] = None,
        images: List[List[Image.Image]] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Encode prompt(s) plus conditioning image(s) with the Qwen2.5-VL encoder.

        Builds one chat-template string per prompt (system message, the
        prompt's images, then the user text), runs the VL text encoder, drops
        the template-prefix tokens, and right-pads all sequences to the batch
        maximum.

        Args:
            prompts: A single prompt or a list of prompts (one per batch item).
            images: One list of PIL images per prompt; a bare image or a flat
                image list is promoted to the nested per-prompt form.
            device: Device for the returned embeddings.
            dtype: Dtype for the returned embeddings.

        Returns:
            Tuple `(prompt_embeds, encoder_attention_mask)`.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompts = [prompts] if isinstance(prompts, str) else prompts

        # Normalize `images` to one list of images per prompt.
        if isinstance(images, Image.Image):
            images = [[images], ]
        elif isinstance(images[0], Image.Image):
            images = [images, ]
        assert len(prompts) == len(images)

        texts = []

        for prompt, image_list in zip(prompts, images):
            messages = [
                {
                    "role": "system",
                    "content": self.system_message,},
                {
                    "role": "user",
                    "content": [{"type": "image", "image": image} for image in image_list]
                    + [{"type": "text", "text": prompt}, ],
                },
            ]

            # Apply chat template
            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            texts.append(text)

        # Process inputs
        model_inputs = self.processor(
            text=texts,
            images=images,
            do_resize=False,  # already resized
            padding=True,
            return_tensors="pt"
        ).to(self.device)
        # NOTE(review): inputs go to `self.device` while the `device` argument
        # is only applied to the returned embeddings — confirm this is intended
        # when the two differ.

        # Number of template-prefix tokens to strip from every sequence.
        drop_idx = self.prompt_template_encode_start_idx

        outputs = self.text_encoder(
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            pixel_values=model_inputs.pixel_values,
            image_grid_thw=model_inputs.image_grid_thw,
            output_hidden_states=True,
        )

        # Last-layer hidden states, split per sample on the attention mask,
        # with the template prefix removed.
        hidden_states = outputs.hidden_states[-1]
        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
        max_seq_len = max([e.size(0) for e in split_hidden_states])
        # Right-pad each sequence (and its mask) to the longest in the batch.
        prompt_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
        )
        encoder_attention_mask = torch.stack(
            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
        )

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        return prompt_embeds, encoder_attention_mask
351
-
352
- def encode_prompt(
353
- self,
354
- prompt: Union[str, List[str]],
355
- images: List[Image.Image] = None,
356
- device: Optional[torch.device] = None,
357
- num_images_per_prompt: int = 1,
358
- prompt_embeds: Optional[torch.Tensor] = None,
359
- prompt_embeds_mask: Optional[torch.Tensor] = None,
360
- max_sequence_length: int = 1024,):
361
- r"""
362
-
363
- Args:
364
- prompt (`str` or `List[str]`, *optional*):
365
- prompt to be encoded
366
- images (`List[Image.Image]`, *optional*):
367
- images to be encoded
368
- device: (`torch.device`):
369
- torch device
370
- num_images_per_prompt (`int`):
371
- number of images that should be generated per prompt
372
- prompt_embeds (`torch.Tensor`, *optional*):
373
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
374
- provided, text embeddings will be generated from `prompt` input argument.
375
- """
376
- device = device or self._execution_device
377
-
378
- prompt = [prompt] if isinstance(prompt, str) else prompt
379
- batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
380
-
381
- if prompt_embeds is None:
382
- prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, images, device)
383
-
384
- _, seq_len, _ = prompt_embeds.shape
385
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
386
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
387
- prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
388
- prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
389
-
390
- return prompt_embeds, prompt_embeds_mask
391
-
392
- def check_inputs(
393
- self,
394
- prompt,
395
- height,
396
- width,
397
- negative_prompt=None,
398
- prompt_embeds=None,
399
- negative_prompt_embeds=None,
400
- prompt_embeds_mask=None,
401
- negative_prompt_embeds_mask=None,
402
- callback_on_step_end_tensor_inputs=None,
403
- max_sequence_length=None,):
404
- if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
405
- logger.warning(
406
- f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
407
- )
408
-
409
- if callback_on_step_end_tensor_inputs is not None and not all(
410
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
411
- ):
412
- raise ValueError(
413
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
414
- )
415
-
416
- if prompt is not None and prompt_embeds is not None:
417
- raise ValueError(
418
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
419
- " only forward one of the two."
420
- )
421
- elif prompt is None and prompt_embeds is None:
422
- raise ValueError(
423
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
424
- )
425
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
426
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
427
-
428
- if negative_prompt is not None and negative_prompt_embeds is not None:
429
- raise ValueError(
430
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
431
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
432
- )
433
-
434
- if prompt_embeds is not None and prompt_embeds_mask is None:
435
- raise ValueError(
436
- "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
437
- )
438
- if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
439
- raise ValueError(
440
- "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
441
- )
442
-
443
- if max_sequence_length is not None and max_sequence_length > 1024:
444
- raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
445
-
446
- @staticmethod
447
- # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
448
- def _pack_latents(latents, batch_size, num_channels_latents, height, width):
449
- latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
450
- latents = latents.permute(0, 2, 4, 1, 3, 5)
451
- latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
452
-
453
- return latents
454
-
455
- @staticmethod
456
- # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
457
- def _unpack_latents(latents, height, width, vae_scale_factor):
458
- batch_size, num_patches, channels = latents.shape
459
-
460
- # VAE applies 8x compression on images but we must also account for packing which requires
461
- # latent height and width to be divisible by 2.
462
- height = 2 * (int(height) // (vae_scale_factor * 2))
463
- width = 2 * (int(width) // (vae_scale_factor * 2))
464
-
465
- latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
466
- latents = latents.permute(0, 3, 1, 4, 2, 5)
467
-
468
- latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
469
-
470
- return latents
471
-
472
- def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
473
- if isinstance(generator, list):
474
- image_latents = [
475
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
476
- for i in range(image.shape[0])
477
- ]
478
- image_latents = torch.cat(image_latents, dim=0)
479
- else:
480
- image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
481
- latents_mean = (
482
- torch.tensor(self.vae.config.latents_mean)
483
- .view(1, self.latent_channels, 1, 1, 1)
484
- .to(image_latents.device, image_latents.dtype)
485
- )
486
- latents_std = (
487
- torch.tensor(self.vae.config.latents_std)
488
- .view(1, self.latent_channels, 1, 1, 1)
489
- .to(image_latents.device, image_latents.dtype)
490
- )
491
- image_latents = (image_latents - latents_mean) / latents_std
492
-
493
- return image_latents
494
-
495
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()
523
-
524
    def prepare_latents(
        self,
        images,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        """Create the initial noise latents and the packed image latents.

        Args:
            images: Iterable of image tensors (or pre-encoded latents — a
                tensor whose channel dim already equals `self.latent_channels`
                is used as-is).
            batch_size: Effective batch size (prompts x images per prompt).
            num_channels_latents: Channel count used for packing.
            height / width: Target pixel dimensions for the noise latents.
            dtype / device: Placement for all produced tensors.
            generator: A `torch.Generator` or list of them (one per sample).
            latents: Optional pre-made noise latents; assumed to be already
                packed (they are only moved to `device`/`dtype`).

        Returns:
            Tuple `(latents, image_latents_list)` of packed noise latents and
            one packed latent tensor per conditioning image.
        """
        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))

        # Extra singleton dim: the video VAE layout is (B, C, T, H, W) with T=1.
        shape = (batch_size, 1, num_channels_latents, height, width)

        image_latents_list = []
        for image in images:
            image = image.to(device=device, dtype=dtype)
            # Encode unless the input already looks like latents.
            if image.shape[1] != self.latent_channels:
                image_latents = self._encode_vae_image(image=image, generator=generator)
            else:
                image_latents = image
            if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
                # expand init_latents for batch_size
                additional_image_per_prompt = batch_size // image_latents.shape[0]
                image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
            elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
                )
            else:
                image_latents = torch.cat([image_latents], dim=0)

            # NOTE(review): assumes a 5-D (B, C, T, H, W) latent here — the
            # spatial dims are taken from axes 3 and 4; confirm for the
            # pre-encoded-latents path.
            image_latent_height, image_latent_width = image_latents.shape[3:]
            image_latents = self._pack_latents(
                image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
            )
            image_latents_list.append(image_latents)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
        else:
            # Caller-provided latents are not re-packed — they must already be
            # in packed token form.
            latents = latents.to(device=device, dtype=dtype)

        return latents, image_latents_list
579
-
580
    @property
    def guidance_scale(self):
        # Guidance scale stashed by `__call__` (distilled-guidance value, not true CFG).
        return self._guidance_scale

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors, set by `__call__`.
        return self._attention_kwargs

    @property
    def num_timesteps(self):
        # Number of denoising steps of the current/last run.
        return self._num_timesteps

    @property
    def current_timestep(self):
        # Timestep being processed; None outside the denoising loop.
        return self._current_timestep

    @property
    def interrupt(self):
        # Set to True to stop the denoising loop early.
        return self._interrupt
599
-
600
- @torch.no_grad()
601
- @replace_example_docstring(EXAMPLE_DOC_STRING)
602
- def __call__(
603
- self,
604
- images: List[PipelineImageInput] = None,
605
- prompt: Union[str, List[str]] = None,
606
- negative_prompt: Union[str, List[str]] = None,
607
- true_cfg_scale: float = 4.0,
608
- height: Optional[int] = None,
609
- width: Optional[int] = None,
610
- num_inference_steps: int = 50,
611
- sigmas: Optional[List[float]] = None,
612
- guidance_scale: float = 1.0,
613
- num_images_per_prompt: int = 1,
614
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
615
- latents: Optional[torch.Tensor] = None,
616
- prompt_embeds: Optional[torch.Tensor] = None,
617
- prompt_embeds_mask: Optional[torch.Tensor] = None,
618
- negative_prompt_embeds: Optional[torch.Tensor] = None,
619
- negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
620
- output_type: Optional[str] = "pil",
621
- return_dict: bool = True,
622
- attention_kwargs: Optional[Dict[str, Any]] = None,
623
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
624
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
625
- max_sequence_length: int = 512,
626
- ):
627
- r"""
628
- Function invoked when calling the pipeline for generation.
629
-
630
- Args:
631
- prompt (`str` or `List[str]`, *optional*):
632
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
633
- instead.
634
- negative_prompt (`str` or `List[str]`, *optional*):
635
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
636
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
637
- not greater than `1`).
638
- true_cfg_scale (`float`, *optional*, defaults to 1.0):
639
- When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
640
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
641
- The height in pixels of the generated image. This is set to 1024 by default for the best results.
642
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
643
- The width in pixels of the generated image. This is set to 1024 by default for the best results.
644
- num_inference_steps (`int`, *optional*, defaults to 50):
645
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
646
- expense of slower inference.
647
- sigmas (`List[float]`, *optional*):
648
- Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
649
- their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
650
- will be used.
651
- guidance_scale (`float`, *optional*, defaults to 3.5):
652
- Guidance scale as defined in [Classifier-Free Diffusion
653
- Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
654
- of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
655
- `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
656
- the text `prompt`, usually at the expense of lower image quality.
657
-
658
- This parameter in the pipeline is there to support future guidance-distilled models when they come up.
659
- Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
660
- please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
661
- enable classifier-free guidance computations.
662
- num_images_per_prompt (`int`, *optional*, defaults to 1):
663
- The number of images to generate per prompt.
664
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
665
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
666
- to make generation deterministic.
667
- latents (`torch.Tensor`, *optional*):
668
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
669
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
670
- tensor will be generated by sampling using the supplied random `generator`.
671
- prompt_embeds (`torch.Tensor`, *optional*):
672
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
673
- provided, text embeddings will be generated from `prompt` input argument.
674
- negative_prompt_embeds (`torch.Tensor`, *optional*):
675
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
676
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
677
- argument.
678
- output_type (`str`, *optional*, defaults to `"pil"`):
679
- The output format of the generate image. Choose between
680
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
681
- return_dict (`bool`, *optional*, defaults to `True`):
682
- Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
683
- attention_kwargs (`dict`, *optional*):
684
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
685
- `self.processor` in
686
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
687
- callback_on_step_end (`Callable`, *optional*):
688
- A function that calls at the end of each denoising steps during the inference. The function is called
689
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
690
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
691
- `callback_on_step_end_tensor_inputs`.
692
- callback_on_step_end_tensor_inputs (`List`, *optional*):
693
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
694
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
695
- `._callback_tensor_inputs` attribute of your pipeline class.
696
- max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
697
-
698
- Examples:
699
-
700
- Returns:
701
- [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
702
- [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
703
- returning a tuple, the first element is a list with the generated images.
704
- """
705
- if not isinstance(images, (list, tuple)):
706
- images = [images]
707
-
708
- # prepare multiple images
709
- total_number_of_pixels = sum([math.prod(image.size) for image in images])
710
- ratio = 1024 / total_number_of_pixels ** 0.5
711
- images = [image.resize(size=(round(image.width*ratio), round(image.height*ratio))) for image in images]
712
- images = [resize_to_multiple_of(image=image, multiple_of=32) for image in images]
713
-
714
- if height is None or width is None:
715
- width, height = images[0].size
716
-
717
- multiple_of = self.vae_scale_factor * 2
718
- width = width // multiple_of * multiple_of
719
- height = height // multiple_of * multiple_of
720
-
721
- # 1. Check inputs. Raise error if not correct
722
- self.check_inputs(
723
- prompt,
724
- height,
725
- width,
726
- negative_prompt=negative_prompt,
727
- prompt_embeds=prompt_embeds,
728
- negative_prompt_embeds=negative_prompt_embeds,
729
- prompt_embeds_mask=prompt_embeds_mask,
730
- negative_prompt_embeds_mask=negative_prompt_embeds_mask,
731
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
732
- max_sequence_length=max_sequence_length,
733
- )
734
-
735
- self._guidance_scale = guidance_scale
736
- self._attention_kwargs = attention_kwargs
737
- self._current_timestep = None
738
- self._interrupt = False
739
-
740
- # 2. Define call parameters
741
- if prompt is not None and isinstance(prompt, str):
742
- batch_size = 1
743
- elif prompt is not None and isinstance(prompt, list):
744
- batch_size = len(prompt)
745
- else:
746
- batch_size = prompt_embeds.shape[0]
747
-
748
- device = self._execution_device
749
- # 3. Preprocess image
750
-
751
- prompt_images = [image.resize((round(image.width * 28 / 32), round(image.height * 28 / 32))) for image in images]
752
- images = [self.image_processor.preprocess(image, image.height, image.width).unsqueeze(2) for image in images]
753
- # import pdb; pdb.set_trace()
754
-
755
- has_neg_prompt = negative_prompt is not None or (
756
- negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
757
- )
758
- do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
759
- prompt_embeds, prompt_embeds_mask = self.encode_prompt(
760
- images=prompt_images,
761
- prompt=prompt,
762
- prompt_embeds=prompt_embeds,
763
- prompt_embeds_mask=prompt_embeds_mask,
764
- device=device,
765
- num_images_per_prompt=num_images_per_prompt,
766
- max_sequence_length=max_sequence_length,
767
- )
768
- if do_true_cfg:
769
- negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
770
- images=prompt_images,
771
- prompt=negative_prompt,
772
- prompt_embeds=negative_prompt_embeds,
773
- prompt_embeds_mask=negative_prompt_embeds_mask,
774
- device=device,
775
- num_images_per_prompt=num_images_per_prompt,
776
- max_sequence_length=max_sequence_length,
777
- )
778
-
779
- # 4. Prepare latent variables
780
- num_channels_latents = self.transformer.config.in_channels // 4
781
- latents, image_latents = self.prepare_latents(
782
- images,
783
- batch_size * num_images_per_prompt,
784
- num_channels_latents,
785
- height,
786
- width,
787
- prompt_embeds.dtype,
788
- device,
789
- generator,
790
- latents,
791
- )
792
- img_shapes = [
793
- [
794
- (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
795
- ] +
796
- [
797
- (1, image.shape[-2] // self.vae_scale_factor // 2, image.shape[-1] // self.vae_scale_factor // 2)
798
- for image in images
799
- ]
800
- ] * batch_size
801
-
802
- # import pdb; pdb.set_trace()
803
-
804
- # 5. Prepare timesteps
805
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
806
- image_seq_len = latents.shape[1]
807
- mu = calculate_shift(
808
- image_seq_len,
809
- self.scheduler.config.get("base_image_seq_len", 256),
810
- self.scheduler.config.get("max_image_seq_len", 4096),
811
- self.scheduler.config.get("base_shift", 0.5),
812
- self.scheduler.config.get("max_shift", 1.15),
813
- )
814
- timesteps, num_inference_steps = retrieve_timesteps(
815
- self.scheduler,
816
- num_inference_steps,
817
- device,
818
- sigmas=sigmas,
819
- mu=mu,
820
- )
821
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
822
- self._num_timesteps = len(timesteps)
823
-
824
- # handle guidance
825
- if self.transformer.config.guidance_embeds:
826
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
827
- guidance = guidance.expand(latents.shape[0])
828
- else:
829
- guidance = None
830
-
831
- if self.attention_kwargs is None:
832
- self._attention_kwargs = {}
833
-
834
- txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
835
- negative_txt_seq_lens = (
836
- negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
837
- )
838
-
839
- # 6. Denoising loop
840
- self.scheduler.set_begin_index(0)
841
- with self.progress_bar(total=num_inference_steps) as progress_bar:
842
- for i, t in enumerate(timesteps):
843
- if self.interrupt:
844
- continue
845
-
846
- self._current_timestep = t
847
-
848
- latent_model_input = torch.cat([latents] + image_latents, dim=1)
849
-
850
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
851
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
852
- with self.transformer.cache_context("cond"):
853
- noise_pred = self.transformer(
854
- hidden_states=latent_model_input,
855
- timestep=timestep / 1000,
856
- guidance=guidance,
857
- encoder_hidden_states_mask=prompt_embeds_mask,
858
- encoder_hidden_states=prompt_embeds,
859
- img_shapes=img_shapes,
860
- txt_seq_lens=txt_seq_lens,
861
- attention_kwargs=self.attention_kwargs,
862
- return_dict=False,
863
- )[0]
864
- noise_pred = noise_pred[:, : latents.size(1)]
865
-
866
- if do_true_cfg:
867
- with self.transformer.cache_context("uncond"):
868
- neg_noise_pred = self.transformer(
869
- hidden_states=latent_model_input,
870
- timestep=timestep / 1000,
871
- guidance=guidance,
872
- encoder_hidden_states_mask=negative_prompt_embeds_mask,
873
- encoder_hidden_states=negative_prompt_embeds,
874
- img_shapes=img_shapes,
875
- txt_seq_lens=negative_txt_seq_lens,
876
- attention_kwargs=self.attention_kwargs,
877
- return_dict=False,
878
- )[0]
879
- neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
880
- comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
881
-
882
- cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
883
- noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
884
- noise_pred = comb_pred * (cond_norm / noise_norm)
885
-
886
- # compute the previous noisy sample x_t -> x_t-1
887
- latents_dtype = latents.dtype
888
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
889
-
890
- if latents.dtype != latents_dtype:
891
- if torch.backends.mps.is_available():
892
- # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
893
- latents = latents.to(latents_dtype)
894
-
895
- if callback_on_step_end is not None:
896
- callback_kwargs = {}
897
- for k in callback_on_step_end_tensor_inputs:
898
- callback_kwargs[k] = locals()[k]
899
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
900
-
901
- latents = callback_outputs.pop("latents", latents)
902
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
903
-
904
- # call the callback, if provided
905
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
906
- progress_bar.update()
907
-
908
- if XLA_AVAILABLE:
909
- xm.mark_step()
910
-
911
- self._current_timestep = None
912
- if output_type == "latent":
913
- image = latents
914
- else:
915
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
916
- latents = latents.to(self.vae.dtype)
917
- latents_mean = (
918
- torch.tensor(self.vae.config.latents_mean)
919
- .view(1, self.vae.config.z_dim, 1, 1, 1)
920
- .to(latents.device, latents.dtype)
921
- )
922
- latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
923
- latents.device, latents.dtype
924
- )
925
- latents = latents / latents_std + latents_mean
926
- image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
927
- image = self.image_processor.postprocess(image, output_type=output_type)
928
-
929
- # Offload all models
930
- self.maybe_free_model_hooks()
931
-
932
- if not return_dict:
933
- return (image,)
934
-
935
- return QwenImagePipelineOutput(images=image)