Adding device-specific configs & more input image type options + a small change to building the model spec from args
- configs/config-dev-1-RTX6000ADA.json +57 -0
- configs/config-dev-offload-1-4080.json +58 -0
- configs/config-dev-offload-1-4090.json +58 -0
- flux_pipeline.py +123 -17
- image_encoder.py +7 -16
- util.py +2 -2
configs/config-dev-1-RTX6000ADA.json
ADDED
@@ -0,0 +1,57 @@
+{
+  "version": "flux-dev",
+  "params": {
+    "in_channels": 64,
+    "vec_in_dim": 768,
+    "context_in_dim": 4096,
+    "hidden_size": 3072,
+    "mlp_ratio": 4.0,
+    "num_heads": 24,
+    "depth": 19,
+    "depth_single_blocks": 38,
+    "axes_dim": [
+      16,
+      56,
+      56
+    ],
+    "theta": 10000,
+    "qkv_bias": true,
+    "guidance_embed": true
+  },
+  "ae_params": {
+    "resolution": 256,
+    "in_channels": 3,
+    "ch": 128,
+    "out_ch": 3,
+    "ch_mult": [
+      1,
+      2,
+      4,
+      4
+    ],
+    "num_res_blocks": 2,
+    "z_channels": 16,
+    "scale_factor": 0.3611,
+    "shift_factor": 0.1159
+  },
+  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+  "repo_id": "black-forest-labs/FLUX.1-dev",
+  "repo_flow": "flux1-dev.sft",
+  "repo_ae": "ae.sft",
+  "text_enc_max_length": 512,
+  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+  "text_enc_device": "cuda:0",
+  "ae_device": "cuda:0",
+  "flux_device": "cuda:0",
+  "flow_dtype": "float16",
+  "ae_dtype": "bfloat16",
+  "text_enc_dtype": "bfloat16",
+  "flow_quantization_dtype": "qfloat8",
+  "text_enc_quantization_dtype": "qfloat8",
+  "compile_extras": true,
+  "compile_blocks": true,
+  "offload_text_encoder": false,
+  "offload_vae": false,
+  "offload_flow": false
+}
configs/config-dev-offload-1-4080.json
ADDED
@@ -0,0 +1,58 @@
+{
+  "version": "flux-dev",
+  "params": {
+    "in_channels": 64,
+    "vec_in_dim": 768,
+    "context_in_dim": 4096,
+    "hidden_size": 3072,
+    "mlp_ratio": 4.0,
+    "num_heads": 24,
+    "depth": 19,
+    "depth_single_blocks": 38,
+    "axes_dim": [
+      16,
+      56,
+      56
+    ],
+    "theta": 10000,
+    "qkv_bias": true,
+    "guidance_embed": true
+  },
+  "ae_params": {
+    "resolution": 256,
+    "in_channels": 3,
+    "ch": 128,
+    "out_ch": 3,
+    "ch_mult": [
+      1,
+      2,
+      4,
+      4
+    ],
+    "num_res_blocks": 2,
+    "z_channels": 16,
+    "scale_factor": 0.3611,
+    "shift_factor": 0.1159
+  },
+  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+  "repo_id": "black-forest-labs/FLUX.1-dev",
+  "repo_flow": "flux1-dev.sft",
+  "repo_ae": "ae.sft",
+  "text_enc_max_length": 512,
+  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+  "text_enc_device": "cuda:0",
+  "ae_device": "cuda:0",
+  "flux_device": "cuda:0",
+  "flow_dtype": "float16",
+  "ae_dtype": "bfloat16",
+  "text_enc_dtype": "bfloat16",
+  "flow_quantization_dtype": "qfloat8",
+  "text_enc_quantization_dtype": "qint4",
+  "ae_quantization_dtype": "qfloat8",
+  "compile_extras": true,
+  "compile_blocks": true,
+  "offload_text_encoder": true,
+  "offload_vae": true,
+  "offload_flow": true
+}
configs/config-dev-offload-1-4090.json
ADDED
@@ -0,0 +1,58 @@
+{
+  "version": "flux-dev",
+  "params": {
+    "in_channels": 64,
+    "vec_in_dim": 768,
+    "context_in_dim": 4096,
+    "hidden_size": 3072,
+    "mlp_ratio": 4.0,
+    "num_heads": 24,
+    "depth": 19,
+    "depth_single_blocks": 38,
+    "axes_dim": [
+      16,
+      56,
+      56
+    ],
+    "theta": 10000,
+    "qkv_bias": true,
+    "guidance_embed": true
+  },
+  "ae_params": {
+    "resolution": 256,
+    "in_channels": 3,
+    "ch": 128,
+    "out_ch": 3,
+    "ch_mult": [
+      1,
+      2,
+      4,
+      4
+    ],
+    "num_res_blocks": 2,
+    "z_channels": 16,
+    "scale_factor": 0.3611,
+    "shift_factor": 0.1159
+  },
+  "ckpt_path": "/big/generator-ui/flux-testing/flux/model-dir/flux1-dev.sft",
+  "ae_path": "/big/generator-ui/flux-testing/flux/model-dir/ae.sft",
+  "repo_id": "black-forest-labs/FLUX.1-dev",
+  "repo_flow": "flux1-dev.sft",
+  "repo_ae": "ae.sft",
+  "text_enc_max_length": 512,
+  "text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
+  "text_enc_device": "cuda:0",
+  "ae_device": "cuda:0",
+  "flux_device": "cuda:0",
+  "flow_dtype": "float16",
+  "ae_dtype": "bfloat16",
+  "text_enc_dtype": "bfloat16",
+  "flow_quantization_dtype": "qfloat8",
+  "text_enc_quantization_dtype": "qint4",
+  "ae_quantization_dtype": "qfloat8",
+  "compile_extras": true,
+  "compile_blocks": true,
+  "offload_text_encoder": true,
+  "offload_vae": true,
+  "offload_flow": false
+}
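The three new configs share identical model and autoencoder parameters; they differ only in quantization and offload settings (the RTX 6000 Ada keeps everything resident on one GPU, the 4080 offloads the text encoder, VAE and flow model, and the 4090 offloads only the text encoder and VAE). A minimal sketch for comparing them, assuming the config files above are present locally (plain json is used here rather than any project helper):

import json

paths = [
    "configs/config-dev-1-RTX6000ADA.json",
    "configs/config-dev-offload-1-4080.json",
    "configs/config-dev-offload-1-4090.json",
]
for path in paths:
    with open(path) as f:
        cfg = json.load(f)
    # print only the fields that actually vary between the device-specific configs
    print(
        path,
        cfg.get("text_enc_quantization_dtype"),
        cfg.get("ae_quantization_dtype"),
        cfg["offload_text_encoder"],
        cfg["offload_vae"],
        cfg["offload_flow"],
    )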
flux_pipeline.py
CHANGED
@@ -40,6 +40,12 @@ if TYPE_CHECKING:


class FluxPipeline:
+    """
+    FluxPipeline is a class that provides a pipeline for generating images using the Flux model.
+    It handles input preparation, timestep generation, noise generation, device management
+    and model compilation.
+    """
+
    def __init__(
        self,
        name: str,
@@ -56,7 +62,12 @@ class FluxPipeline:
        t5_device: torch.device | str = "cuda:1",
        config: ModelSpec = None,
    ):
+        """
+        Initialize the FluxPipeline class.
+
+        This class is responsible for preparing input tensors for the Flux model, generating
+        timesteps and noise, and handling device management for model offloading.
+        """
        self.name = name
        self.device_flux = (
            flux_device
@@ -104,10 +115,10 @@ class FluxPipeline:
        if not self.config.prequantized_flow:
            print("Warmups for compile...")
            warmup_dict = dict(
-                prompt="
-                height=
-                width=
-                num_steps=
+                prompt="A beautiful test image used to solidify the fp8 nn.Linear input scales prior to compilation 😉",
+                height=768,
+                width=768,
+                num_steps=25,
                guidance=3.5,
                seed=10,
            )
@@ -138,6 +149,32 @@ class FluxPipeline:
        target_device: torch.device = torch.device("cuda:0"),
        target_dtype: torch.dtype = torch.float16,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Prepare input tensors for the Flux model.
+
+        This function processes the input image and text prompt, converting them into
+        the appropriate format and embedding representations required by the model.
+
+        Args:
+            img (torch.Tensor): Input image tensor of shape (batch_size, channels, height, width).
+            prompt (str | list[str]): Text prompt or list of prompts guiding the image generation.
+            target_device (torch.device, optional): The target device for the output tensors.
+                Defaults to torch.device("cuda:0").
+            target_dtype (torch.dtype, optional): The target data type for the output tensors.
+                Defaults to torch.float16.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
+                - img: Processed image tensor.
+                - img_ids: Image position IDs.
+                - vec: Clip text embedding vector.
+                - txt: T5 text embedding hidden states.
+                - txt_ids: Text position IDs.
+
+        Note:
+            This function handles the necessary device management for text encoder offloading
+            if enabled in the configuration.
+        """
        bs, c, h, w = img.shape
        if bs == 1 and not isinstance(prompt, str):
            bs = len(prompt)
@@ -165,8 +202,8 @@ class FluxPipeline:

        img_ids = img_ids[None].repeat(bs, 1, 1, 1).flatten(1, 2)
        if self.offload_text_encoder:
-            self.clip.
-            self.t5.
+            self.clip.to(device=self.device_clip)
+            self.t5.to(device=self.device_t5)
        vec, txt, txt_ids = get_weighted_text_embeddings_flux(
            self,
            prompt,
@@ -201,6 +238,7 @@ class FluxPipeline:
        max_shift: float = 1.15,
        shift: bool = True,
    ) -> list[float]:
+        """Generates a schedule of timesteps for the given number of steps and image sequence length."""
        # extra step for zero
        timesteps = torch.linspace(1, 0, num_steps + 1)

@@ -221,7 +259,8 @@ class FluxPipeline:
        generator: torch.Generator,
        dtype=None,
        device=None,
-    ):
+    ) -> torch.Tensor:
+        """Generates a latent noise tensor of the given shape and dtype on the given device."""
        if device is None:
            device = self.device_flux
        if dtype is None:
@@ -240,6 +279,7 @@ class FluxPipeline:

    @torch.inference_mode()
    def into_bytes(self, x: torch.Tensor) -> io.BytesIO:
+        """Converts the image tensor to bytes."""
        # bring into PIL format and save
        torch.cuda.synchronize()
        x = x.contiguous()
@@ -257,10 +297,34 @@ class FluxPipeline:
        torch.cuda.synchronize()
        im = self.img_encoder.encode_torch(im, quality=99)
        images.clear()
-        return
+        return im
+
+    @torch.inference_mode()
+    def load_init_image_if_needed(
+        self, init_image: torch.Tensor | str | Image.Image | np.ndarray
+    ) -> torch.Tensor:
+        """
+        Loads the initial image if it is a string, numpy array, or PIL.Image,
+        if torch.Tensor, expects it to be in the correct format and returns it as is.
+        """
+        if isinstance(init_image, str):
+            try:
+                init_image = Image.open(init_image)
+            except Exception as e:
+                init_image = Image.open(
+                    io.BytesIO(standard_b64decode(init_image.split(",")[-1]))
+                )
+            init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)
+        elif isinstance(init_image, np.ndarray):
+            init_image = torch.from_numpy(init_image).type(torch.uint8)
+        elif isinstance(init_image, Image.Image):
+            init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)
+
+        return init_image

    @torch.inference_mode()
    def vae_decode(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """Decodes the latent tensor to the pixel space."""
        if self.offload_vae:
            self.ae.to(self.device_ae)
            x = x.to(self.device_ae)
@@ -290,6 +354,7 @@ class FluxPipeline:
    def resize_center_crop(
        self, img: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
+        """Resizes and crops the image to the given height and width."""
        img = TF.resize(img, min(width, height))
        img = TF.center_crop(img, (height, width))
        return img
@@ -305,6 +370,11 @@ class FluxPipeline:
        generator: torch.Generator = None,
        num_images: int = 1,
    ) -> tuple[torch.Tensor, List[float]]:
+        """
+        Preprocesses the latent tensor for the given number of steps and image sequence length.
+        Also, if an initial image is provided, it is vae encoded and injected with the appropriate noise
+        given the strength and number of steps replacing the latent tensor.
+        """
        # prepare input

        if init_image is not None:
@@ -364,20 +434,55 @@ class FluxPipeline:
        num_steps: int = 24,
        guidance: float = 3.5,
        seed: int | None = None,
-        init_image: torch.Tensor | str | None = None,
+        init_image: torch.Tensor | str | Image.Image | np.ndarray | None = None,
        strength: float = 1.0,
        silent: bool = False,
        num_images: int = 1,
        return_seed: bool = False,
    ) -> io.BytesIO:
+        """
+        Generate images based on the given prompt and parameters.
+
+        Args:
+            prompt `(str)`: The text prompt to guide the image generation.
+
+            width `(int, optional)`: Width of the generated image. Defaults to 720.
+
+            height `(int, optional)`: Height of the generated image. Defaults to 1024.
+
+            num_steps `(int, optional)`: Number of denoising steps. Defaults to 24.
+
+            guidance `(float, optional)`: Guidance scale for text-to-image generation. Defaults to 3.5.
+
+            seed `(int | None, optional)`: Random seed for reproducibility. If None, a random seed is used. Defaults to None.
+
+            init_image `(torch.Tensor | str | Image.Image | np.ndarray | None, optional)`: Initial image for image-to-image generation. Defaults to None.
+
+            -- note: if the image's height/width do not match the height/width of the generated image, the image is resized and center cropped to match the height/width arguments.
+
+            -- If a string is provided, it is assumed to be either a path to an image file or a base64 encoded image.
+
+            -- If a numpy array is provided, it is assumed to be an RGB numpy array of shape (height, width, 3) and dtype uint8.
+
+            -- If a PIL.Image is provided, it is assumed to be an RGB PIL.Image.
+
+            -- If a torch.Tensor is provided, it is assumed to be a torch.Tensor of shape (height, width, 3) and dtype uint8 with range [0, 255].
+
+            strength `(float, optional)`: Strength of the init_image in image-to-image generation. Defaults to 1.0.
+
+            silent `(bool, optional)`: If True, suppresses progress bar. Defaults to False.
+
+            num_images `(int, optional)`: Number of images to generate. Defaults to 1.
+
+            return_seed `(bool, optional)`: If True, returns the seed along with the generated image. Defaults to False.
+
+        Returns:
+            io.BytesIO: Generated image(s) in bytes format.
+            int: Seed used for generation (only if return_seed is True).
+        """
        num_steps = 4 if self.name == "flux-schnell" else num_steps

-
-        try:
-            init_image = Image.open(init_image)
-        except Exception as e:
-            init_image = Image.open(io.BytesIO(standard_b64decode(init_image)))
-        init_image = torch.from_numpy(np.array(init_image)).type(torch.uint8)
+        init_image = self.load_init_image_if_needed(init_image)

        # allow for packing and conversion to latent space
        height = 16 * (height // 16)
@@ -465,8 +570,9 @@ class FluxPipeline:
        from float8_quantize import quantize_flow_transformer_and_dispatch_float8

        with torch.inference_mode():
-
-
+            logger.info(
+                f"Loading as prequantized flow transformer? {config.prequantized_flow}"
+            )

            models = load_models_from_config(config)
            config = models.config
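With load_init_image_if_needed in place, generate accepts an init image as a file path, a base64-encoded string, a numpy array, a PIL.Image, or an HxWx3 uint8 torch.Tensor. A small illustrative sketch of the new call patterns; "pipe" is assumed to be an already-constructed FluxPipeline and the file path and prompt are placeholders, not part of this commit:

import numpy as np
import torch
from PIL import Image

def img2img_examples(pipe):
    # "pipe" is assumed to be a constructed FluxPipeline; "inputs/cat.jpg" is a stand-in path.
    rgb = np.zeros((1024, 720, 3), dtype=np.uint8)  # HxWx3 uint8 RGB array
    results = [
        pipe.generate(prompt="a cat", init_image="inputs/cat.jpg", strength=0.6),       # file path
        pipe.generate(prompt="a cat", init_image=rgb, strength=0.6),                    # numpy array
        pipe.generate(prompt="a cat", init_image=Image.fromarray(rgb), strength=0.6),   # PIL image
        pipe.generate(prompt="a cat", init_image=torch.from_numpy(rgb), strength=0.6),  # uint8 tensor
    ]
    return results  # each entry is an io.BytesIO holding a JPEG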
image_encoder.py
CHANGED
@@ -7,38 +7,29 @@ import torch
class ImageEncoder:

    @torch.inference_mode()
-    def encode_torch(self, img: torch.Tensor, quality=
+    def encode_torch(self, img: torch.Tensor, quality=95):
        if img.ndim == 2:
            img = (
                img[None]
-                .contiguous()
                .repeat_interleave(3, dim=0)
+                .permute(1, 2, 0)
                .contiguous()
                .clamp(0, 255)
                .type(torch.uint8)
            )
-            print(img.shape)
        elif img.ndim == 3:
            if img.shape[0] == 3:
-                img = img.contiguous().clamp(0, 255).type(torch.uint8)
-
+                img = img.permute(1, 2, 0).contiguous().clamp(0, 255).type(torch.uint8)
            elif img.shape[2] == 3:
-                img = img.
+                img = img.contiguous().clamp(0, 255).type(torch.uint8)
            else:
                raise ValueError(f"Unsupported image shape: {img.shape}")
        else:
            raise ValueError(f"Unsupported image num dims: {img.ndim}")

-        img = (
-            img.permute(1, 2, 0)
-            .contiguous()
-            .to(torch.uint8)
-            .cpu()
-            .numpy()
-            .astype(np.uint8)
-        )
+        img = img.cpu().numpy().astype(np.uint8)
        im = Image.fromarray(img)
        iob = io.BytesIO()
-        im.save(iob, format="JPEG", quality=
+        im.save(iob, format="JPEG", quality=quality)
        iob.seek(0)
-        return iob
+        return iob
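After this change, encode_torch normalizes channels-first, channels-last, and 2D grayscale tensors to an HxWx3 uint8 layout before JPEG encoding, and the quality argument is now actually forwarded to PIL. A hedged sketch of the accepted layouts, assuming ImageEncoder takes no constructor arguments:

import torch
from image_encoder import ImageEncoder

enc = ImageEncoder()  # assumed to need no constructor arguments
chw = torch.rand(3, 256, 256) * 255                 # channels-first float tensor
hwc = torch.randint(0, 256, (256, 256, 3))          # channels-last integer tensor
gray = torch.randint(0, 256, (256, 256))            # 2D grayscale tensor
jpegs = [enc.encode_torch(t, quality=90) for t in (chw, hwc, gray)]  # io.BytesIO JPEGs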
util.py
CHANGED
@@ -141,7 +141,7 @@ def load_config(
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
-            guidance_embed=
+            guidance_embed=name == ModelVersion.flux_dev,
        ),
        ae_path=ae_path,
        ae_params=AutoEncoderParams(
@@ -243,8 +243,8 @@ def load_autoencoder(config: ModelSpec) -> AutoEncoder:
    sd = load_sft(ckpt_path, device=str(config.ae_device))
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
    print_load_warning(missing, unexpected)
+    ae.to(device=into_device(config.ae_device), dtype=into_dtype(config.ae_dtype))
    if config.ae_quantization_dtype is not None:
-        ae.to(into_device(config.ae_device))
        from float8_quantize import recursive_swap_linears

        recursive_swap_linears(ae)