aredden committed
Commit e81fa57 · 1 Parent(s): 21a2bf7

Adding device-specific configs & more input image type options + a small change to building the model spec from args

configs/config-dev-1-RTX6000ADA.json ADDED
@@ -0,0 +1,57 @@
+{
+    "version": "flux-dev",
+    "params": {
+        "in_channels": 64,
+        "vec_in_dim": 768,
+        "context_in_dim": 4096,
+        "hidden_size": 3072,
+        "mlp_ratio": 4.0,
+        "num_heads": 24,
+        "depth": 19,
+        "depth_single_blocks": 38,
+        "axes_dim": [
+            16,
+            56,
+            56
+        ],
+        "theta": 10000,
+        "qkv_bias": true,
+        "guidance_embed": true
+    },
+    "ae_params": {
+        "resolution": 256,
+        "in_channels": 3,
+        "ch": 128,
+        "out_ch": 3,
+        "ch_mult": [
+            1,
+            2,
+            4,
+            4
+        ],
+        "num_res_blocks": 2,
+        "z_channels": 16,
+        "scale_factor": 0.3611,
+        "shift_factor": 0.1159
+    },
+    "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+    "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+    "repo_id": "black-forest-labs/FLUX.1-dev",
+    "repo_flow": "flux1-dev.sft",
+    "repo_ae": "ae.sft",
+    "text_enc_max_length": 512,
+    "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+    "text_enc_device": "cuda:0",
+    "ae_device": "cuda:0",
+    "flux_device": "cuda:0",
+    "flow_dtype": "float16",
+    "ae_dtype": "bfloat16",
+    "text_enc_dtype": "bfloat16",
+    "flow_quantization_dtype": "qfloat8",
+    "text_enc_quantization_dtype": "qfloat8",
+    "compile_extras": true,
+    "compile_blocks": true,
+    "offload_text_encoder": false,
+    "offload_vae": false,
+    "offload_flow": false
+}
configs/config-dev-offload-1-4080.json ADDED
@@ -0,0 +1,58 @@
+{
+    "version": "flux-dev",
+    "params": {
+        "in_channels": 64,
+        "vec_in_dim": 768,
+        "context_in_dim": 4096,
+        "hidden_size": 3072,
+        "mlp_ratio": 4.0,
+        "num_heads": 24,
+        "depth": 19,
+        "depth_single_blocks": 38,
+        "axes_dim": [
+            16,
+            56,
+            56
+        ],
+        "theta": 10000,
+        "qkv_bias": true,
+        "guidance_embed": true
+    },
+    "ae_params": {
+        "resolution": 256,
+        "in_channels": 3,
+        "ch": 128,
+        "out_ch": 3,
+        "ch_mult": [
+            1,
+            2,
+            4,
+            4
+        ],
+        "num_res_blocks": 2,
+        "z_channels": 16,
+        "scale_factor": 0.3611,
+        "shift_factor": 0.1159
+    },
+    "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+    "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+    "repo_id": "black-forest-labs/FLUX.1-dev",
+    "repo_flow": "flux1-dev.sft",
+    "repo_ae": "ae.sft",
+    "text_enc_max_length": 512,
+    "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+    "text_enc_device": "cuda:0",
+    "ae_device": "cuda:0",
+    "flux_device": "cuda:0",
+    "flow_dtype": "float16",
+    "ae_dtype": "bfloat16",
+    "text_enc_dtype": "bfloat16",
+    "flow_quantization_dtype": "qfloat8",
+    "text_enc_quantization_dtype": "qint4",
+    "ae_quantization_dtype": "qfloat8",
+    "compile_extras": true,
+    "compile_blocks": true,
+    "offload_text_encoder": true,
+    "offload_vae": true,
+    "offload_flow": true
+}
configs/config-dev-offload-1-4090.json ADDED
@@ -0,0 +1,58 @@
+{
+    "version": "flux-dev",
+    "params": {
+        "in_channels": 64,
+        "vec_in_dim": 768,
+        "context_in_dim": 4096,
+        "hidden_size": 3072,
+        "mlp_ratio": 4.0,
+        "num_heads": 24,
+        "depth": 19,
+        "depth_single_blocks": 38,
+        "axes_dim": [
+            16,
+            56,
+            56
+        ],
+        "theta": 10000,
+        "qkv_bias": true,
+        "guidance_embed": true
+    },
+    "ae_params": {
+        "resolution": 256,
+        "in_channels": 3,
+        "ch": 128,
+        "out_ch": 3,
+        "ch_mult": [
+            1,
+            2,
+            4,
+            4
+        ],
+        "num_res_blocks": 2,
+        "z_channels": 16,
+        "scale_factor": 0.3611,
+        "shift_factor": 0.1159
+    },
+    "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+    "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+    "repo_id": "black-forest-labs/FLUX.1-dev",
+    "repo_flow": "flux1-dev.sft",
+    "repo_ae": "ae.sft",
+    "text_enc_max_length": 512,
+    "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+    "text_enc_device": "cuda:0",
+    "ae_device": "cuda:0",
+    "flux_device": "cuda:0",
+    "flow_dtype": "float16",
+    "ae_dtype": "bfloat16",
+    "text_enc_dtype": "bfloat16",
+    "flow_quantization_dtype": "qfloat8",
+    "text_enc_quantization_dtype": "qint4",
+    "ae_quantization_dtype": "qfloat8",
+    "compile_extras": true,
+    "compile_blocks": true,
+    "offload_text_encoder": true,
+    "offload_vae": true,
+    "offload_flow": false
+}
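
The three configs above share the same flux-dev model and autoencoder parameters and differ only in quantization and offload settings tuned per GPU: the RTX 6000 Ada config keeps every module resident on cuda:0 with qfloat8 for the flow and text encoder and leaves the autoencoder unquantized, the 4090 config additionally quantizes the autoencoder, uses qint4 for the text encoder and offloads the text encoder and VAE, and the 4080 config offloads the flow transformer as well. A minimal sketch for spotting those differences, assuming only that the three JSON files exist at the paths shown (it does not touch the repo's own config loader):

import json
from pathlib import Path

paths = [
    "configs/config-dev-1-RTX6000ADA.json",
    "configs/config-dev-offload-1-4080.json",
    "configs/config-dev-offload-1-4090.json",
]

# Load each config and print the keys whose values differ between the files.
configs = {Path(p).stem: json.loads(Path(p).read_text()) for p in paths}
all_keys = set().union(*(cfg.keys() for cfg in configs.values()))
for key in sorted(all_keys):
    values = {name: cfg.get(key) for name, cfg in configs.items()}
    if len({json.dumps(v, sort_keys=True) for v in values.values()}) > 1:
        print(key, values)

# Expected differing keys for these files: text_enc_quantization_dtype,
# ae_quantization_dtype, offload_text_encoder, offload_vae, offload_flow.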
flux_pipeline.py CHANGED
@@ -40,6 +40,12 @@ if TYPE_CHECKING:
 
 
 class FluxPipeline:
+    """
+    FluxPipeline is a class that provides a pipeline for generating images using the Flux model.
+    It handles input preparation, timestep generation, noise generation, device management
+    and model compilation.
+    """
+
     def __init__(
         self,
         name: str,
@@ -56,7 +62,12 @@ class FluxPipeline:
         t5_device: torch.device | str = "cuda:1",
         config: ModelSpec = None,
     ):
+        """
+        Initialize the FluxPipeline class.
 
+        This class is responsible for preparing input tensors for the Flux model, generating
+        timesteps and noise, and handling device management for model offloading.
+        """
         self.name = name
         self.device_flux = (
             flux_device
@@ -104,10 +115,10 @@ class FluxPipeline:
         if not self.config.prequantized_flow:
             print("Warmups for compile...")
             warmup_dict = dict(
-                prompt="Street photography portrait of a beautiful asian woman in traditional clothing with golden hairpin and blue eyes, wearing a red kimono with dragon patterns",
-                height=1024,
-                width=1024,
-                num_steps=30,
+                prompt="A beautiful test image used to solidify the fp8 nn.Linear input scales prior to compilation 😉",
+                height=768,
+                width=768,
+                num_steps=25,
                 guidance=3.5,
                 seed=10,
             )
@@ -138,6 +149,32 @@ class FluxPipeline:
         target_device: torch.device = torch.device("cuda:0"),
         target_dtype: torch.dtype = torch.float16,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Prepare input tensors for the Flux model.
+
+        This function processes the input image and text prompt, converting them into
+        the appropriate format and embedding representations required by the model.
+
+        Args:
+            img (torch.Tensor): Input image tensor of shape (batch_size, channels, height, width).
+            prompt (str | list[str]): Text prompt or list of prompts guiding the image generation.
+            target_device (torch.device, optional): The target device for the output tensors.
+                Defaults to torch.device("cuda:0").
+            target_dtype (torch.dtype, optional): The target data type for the output tensors.
+                Defaults to torch.float16.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
+                - img: Processed image tensor.
+                - img_ids: Image position IDs.
+                - vec: Clip text embedding vector.
+                - txt: T5 text embedding hidden states.
+                - txt_ids: Text position IDs.
+
+        Note:
+            This function handles the necessary device management for text encoder offloading
+            if enabled in the configuration.
+        """
         bs, c, h, w = img.shape
         if bs == 1 and not isinstance(prompt, str):
             bs = len(prompt)
@@ -165,8 +202,8 @@ class FluxPipeline:
 
         img_ids = img_ids[None].repeat(bs, 1, 1, 1).flatten(1, 2)
         if self.offload_text_encoder:
-            self.clip.cuda(self.device_clip)
-            self.t5.cuda(self.device_t5)
+            self.clip.to(device=self.device_clip)
+            self.t5.to(device=self.device_t5)
         vec, txt, txt_ids = get_weighted_text_embeddings_flux(
             self,
             prompt,
@@ -201,6 +238,7 @@ class FluxPipeline:
         max_shift: float = 1.15,
         shift: bool = True,
     ) -> list[float]:
+        """Generates a schedule of timesteps for the given number of steps and image sequence length."""
        # extra step for zero
        timesteps = torch.linspace(1, 0, num_steps + 1)
 
@@ -221,7 +259,8 @@ class FluxPipeline:
         generator: torch.Generator,
         dtype=None,
         device=None,
-    ):
+    ) -> torch.Tensor:
+        """Generates a latent noise tensor of the given shape and dtype on the given device."""
         if device is None:
             device = self.device_flux
         if dtype is None:
@@ -240,6 +279,7 @@ class FluxPipeline:
 
     @torch.inference_mode()
     def into_bytes(self, x: torch.Tensor) -> io.BytesIO:
+        """Converts the image tensor to bytes."""
         # bring into PIL format and save
         torch.cuda.synchronize()
         x = x.contiguous()
@@ -257,10 +297,34 @@ class FluxPipeline:
         torch.cuda.synchronize()
         im = self.img_encoder.encode_torch(im, quality=99)
         images.clear()
-        return io.BytesIO(im)
+        return im
+
+    @torch.inference_mode()
+    def load_init_image_if_needed(
+        self, init_image: torch.Tensor | str | Image.Image | np.ndarray
+    ) -> torch.Tensor:
+        """
+        Loads the initial image if it is a string, numpy array, or PIL.Image,
+        if torch.Tensor, expects it to be in the correct format and returns it as is.
+        """
+        if isinstance(init_image, str):
+            try:
+                init_image = Image.open(init_image)
+            except Exception as e:
+                init_image = Image.open(
+                    io.BytesIO(standard_b64decode(init_image.split(",")[-1]))
+                )
+            init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)
+        elif isinstance(init_image, np.ndarray):
+            init_image = torch.from_numpy(init_image).type(torch.uint8)
+        elif isinstance(init_image, Image.Image):
+            init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)
+
+        return init_image
 
     @torch.inference_mode()
     def vae_decode(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """Decodes the latent tensor to the pixel space."""
         if self.offload_vae:
             self.ae.to(self.device_ae)
             x = x.to(self.device_ae)
@@ -290,6 +354,7 @@ class FluxPipeline:
     def resize_center_crop(
         self, img: torch.Tensor, height: int, width: int
     ) -> torch.Tensor:
+        """Resizes and crops the image to the given height and width."""
         img = TF.resize(img, min(width, height))
         img = TF.center_crop(img, (height, width))
         return img
@@ -305,6 +370,11 @@ class FluxPipeline:
         generator: torch.Generator = None,
         num_images: int = 1,
     ) -> tuple[torch.Tensor, List[float]]:
+        """
+        Preprocesses the latent tensor for the given number of steps and image sequence length.
+        Also, if an initial image is provided, it is vae encoded and injected with the appropriate noise
+        given the strength and number of steps replacing the latent tensor.
+        """
         # prepare input
 
         if init_image is not None:
@@ -364,20 +434,55 @@ class FluxPipeline:
         num_steps: int = 24,
         guidance: float = 3.5,
         seed: int | None = None,
-        init_image: torch.Tensor | str | None = None,
+        init_image: torch.Tensor | str | Image.Image | np.ndarray | None = None,
         strength: float = 1.0,
         silent: bool = False,
         num_images: int = 1,
         return_seed: bool = False,
     ) -> io.BytesIO:
+        """
+        Generate images based on the given prompt and parameters.
+
+        Args:
+            prompt `(str)`: The text prompt to guide the image generation.
+
+            width `(int, optional)`: Width of the generated image. Defaults to 720.
+
+            height `(int, optional)`: Height of the generated image. Defaults to 1024.
+
+            num_steps `(int, optional)`: Number of denoising steps. Defaults to 24.
+
+            guidance `(float, optional)`: Guidance scale for text-to-image generation. Defaults to 3.5.
+
+            seed `(int | None, optional)`: Random seed for reproducibility. If None, a random seed is used. Defaults to None.
+
+            init_image `(torch.Tensor | str | Image.Image | np.ndarray | None, optional)`: Initial image for image-to-image generation. Defaults to None.
+
+                -- note: if the image's height/width do not match the height/width of the generated image, the image is resized and centered cropped to match the height/width arguments.
+
+                -- If a string is provided, it is assumed to be either a path to an image file or a base64 encoded image.
+
+                -- If a numpy array is provided, it is assumed to be an RGB numpy array of shape (height, width, 3) and dtype uint8.
+
+                -- If a PIL.Image is provided, it is assumed to be an RGB PIL.Image.
+
+                -- If a torch.Tensor is provided, it is assumed to be a torch.Tensor of shape (height, width, 3) and dtype uint8 with range [0, 255].
+
+            strength `(float, optional)`: Strength of the init_image in image-to-image generation. Defaults to 1.0.
+
+            silent `(bool, optional)`: If True, suppresses progress bar. Defaults to False.
+
+            num_images `(int, optional)`: Number of images to generate. Defaults to 1.
+
+            return_seed `(bool, optional)`: If True, returns the seed along with the generated image. Defaults to False.
+
+        Returns:
+            io.BytesIO: Generated image(s) in bytes format.
+            int: Seed used for generation (only if return_seed is True).
+        """
         num_steps = 4 if self.name == "flux-schnell" else num_steps
 
-        if isinstance(init_image, str):
-            try:
-                init_image = Image.open(init_image)
-            except Exception as e:
-                init_image = Image.open(io.BytesIO(standard_b64decode(init_image)))
-            init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)
+        init_image = self.load_init_image_if_needed(init_image)
 
         # allow for packing and conversion to latent space
         height = 16 * (height // 16)
@@ -465,8 +570,9 @@ class FluxPipeline:
         from float8_quantize import quantize_flow_transformer_and_dispatch_float8
 
         with torch.inference_mode():
-            print("flow_quantization_dtype", config.flow_quantization_dtype)
-            print("prequantized_flow?", config.prequantized_flow)
+            logger.info(
+                f"Loading as prequantized flow transformer? {config.prequantized_flow}"
+            )
 
             models = load_models_from_config(config)
             config = models.config
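
With load_init_image_if_needed in place, generate now accepts the initial image as a file path, a base64-encoded string, a NumPy array, a PIL image, or a uint8 torch tensor, and normalizes all of them to a (height, width, 3) uint8 tensor before encoding. A hedged usage sketch follows; it assumes a FluxPipeline instance has already been constructed elsewhere (the loader itself is not shown in this diff), so the calls are wrapped in a function that receives the pipeline, and reference.jpg is a placeholder path:

import io

import numpy as np
import torch
from PIL import Image


def demo_generate(pipe) -> None:
    # Exercise the init_image input types accepted by FluxPipeline.generate (sketch only).
    prompt = "a mountain lake at sunrise"

    # Plain text-to-image; returns the JPEG bytes wrapped in io.BytesIO.
    out: io.BytesIO = pipe.generate(prompt=prompt, width=768, height=768, num_steps=24)

    # Image-to-image from a string: tried as a file path first, then as base64.
    out = pipe.generate(prompt=prompt, init_image="reference.jpg", strength=0.6)

    # Image-to-image from an RGB numpy array of shape (H, W, 3), dtype uint8.
    arr = np.asarray(Image.open("reference.jpg").convert("RGB"), dtype=np.uint8)
    out = pipe.generate(prompt=prompt, init_image=arr, strength=0.6)

    # Image-to-image from a uint8 torch tensor in [0, 255], also returning the seed used.
    out, seed = pipe.generate(
        prompt=prompt,
        init_image=torch.from_numpy(arr),
        strength=0.6,
        return_seed=True,
    )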
image_encoder.py CHANGED
@@ -7,38 +7,29 @@ import torch
 class ImageEncoder:
 
     @torch.inference_mode()
-    def encode_torch(self, img: torch.Tensor, quality=90):
+    def encode_torch(self, img: torch.Tensor, quality=95):
         if img.ndim == 2:
             img = (
                 img[None]
-                .contiguous()
                 .repeat_interleave(3, dim=0)
+                .permute(1, 2, 0)
                 .contiguous()
                 .clamp(0, 255)
                 .type(torch.uint8)
             )
-            print(img.shape)
         elif img.ndim == 3:
             if img.shape[0] == 3:
-                img = img.contiguous().clamp(0, 255).type(torch.uint8)
-
+                img = img.permute(1, 2, 0).contiguous().clamp(0, 255).type(torch.uint8)
             elif img.shape[2] == 3:
-                img = img.permute(2, 0, 1).contiguous().clamp(0, 255).type(torch.uint8)
+                img = img.contiguous().clamp(0, 255).type(torch.uint8)
             else:
                 raise ValueError(f"Unsupported image shape: {img.shape}")
         else:
             raise ValueError(f"Unsupported image num dims: {img.ndim}")
 
-        img = (
-            img.permute(1, 2, 0)
-            .contiguous()
-            .to(torch.uint8)
-            .cpu()
-            .numpy()
-            .astype(np.uint8)
-        )
+        img = img.cpu().numpy().astype(np.uint8)
         im = Image.fromarray(img)
         iob = io.BytesIO()
-        im.save(iob, format="JPEG", quality=95)
+        im.save(iob, format="JPEG", quality=quality)
         iob.seek(0)
-        return iob.getvalue()
+        return iob
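
encode_torch now converts to HWC layout while the tensor is still a torch tensor (permute, clamp, uint8), makes a single hop to NumPy right before PIL, and honors the caller-supplied quality argument instead of the previously hard-coded 95. A standalone sketch of that conversion path, independent of the ImageEncoder class:

import io

import numpy as np
import torch
from PIL import Image

# A fake CHW float image in [0, 255], shaped like the output of the decode path.
chw = torch.rand(3, 64, 64) * 255

# Same layout handling as the updated encode_torch: CHW -> HWC, clamp, uint8 ...
hwc = chw.permute(1, 2, 0).contiguous().clamp(0, 255).type(torch.uint8)

# ... then one conversion to NumPy just before handing off to PIL.
arr = hwc.cpu().numpy().astype(np.uint8)
im = Image.fromarray(arr)

# Save as JPEG with the caller-supplied quality (flux_pipeline.py passes quality=99).
iob = io.BytesIO()
im.save(iob, format="JPEG", quality=99)
iob.seek(0)
print(len(iob.getvalue()), "bytes of JPEG")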
util.py CHANGED
@@ -141,7 +141,7 @@ def load_config(
             axes_dim=[16, 56, 56],
             theta=10_000,
             qkv_bias=True,
-            guidance_embed=True,
+            guidance_embed=name == ModelVersion.flux_dev,
         ),
         ae_path=ae_path,
         ae_params=AutoEncoderParams(
@@ -243,8 +243,8 @@ def load_autoencoder(config: ModelSpec) -> AutoEncoder:
     sd = load_sft(ckpt_path, device=str(config.ae_device))
     missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
     print_load_warning(missing, unexpected)
+    ae.to(device=into_device(config.ae_device), dtype=into_dtype(config.ae_dtype))
     if config.ae_quantization_dtype is not None:
-        ae.to(into_device(config.ae_device))
         from float8_quantize import recursive_swap_linears
 
         recursive_swap_linears(ae)
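
The load_config change stops hard-coding guidance_embed=True and instead derives it from the requested model version, since flux-dev uses a guidance embedder while flux-schnell does not. A tiny illustration of the resulting behavior, using a stand-in enum (ModelVersion here is a placeholder for the repo's own type):

from enum import Enum


class ModelVersion(str, Enum):  # stand-in for the repo's ModelVersion
    flux_dev = "flux-dev"
    flux_schnell = "flux-schnell"


def guidance_embed_for(name: ModelVersion) -> bool:
    # Mirrors the new default: guidance_embed=name == ModelVersion.flux_dev
    return name == ModelVersion.flux_dev


assert guidance_embed_for(ModelVersion.flux_dev) is True
assert guidance_embed_for(ModelVersion.flux_schnell) is False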