EndlessSora committed on
Commit
7f2756e
·
1 Parent(s): dc8acb8

improve memory usage for zero GPUs

Browse files
pipelines/pipeline_flux_infusenet.py CHANGED
@@ -359,6 +359,11 @@ class FluxInfuseNetPipeline(FluxControlNetPipeline):
359
  lora_scale=lora_scale,
360
  )
361
 
 
 
 
 
 
362
  # 3. Prepare control image
363
  num_channels_latents = self.transformer.config.in_channels // 4
364
  if isinstance(self.controlnet, FluxControlNetModel):
@@ -492,11 +497,6 @@ class FluxInfuseNetPipeline(FluxControlNetPipeline):
492
  ]
493
  controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
494
 
495
- # CPU offload T5, move back controlnet to GPU
496
- self.text_encoder_2.cpu()
497
- torch.cuda.empty_cache()
498
- self.controlnet.to(device)
499
-
500
  # 7. Denoising loop
501
  with self.progress_bar(total=num_inference_steps) as progress_bar:
502
  for i, t in enumerate(timesteps):
 
359
  lora_scale=lora_scale,
360
  )
361
 
362
+ # CPU offload T5, move back controlnet to GPU
363
+ self.text_encoder_2.cpu()
364
+ torch.cuda.empty_cache()
365
+ self.controlnet.to(device)
366
+
367
  # 3. Prepare control image
368
  num_channels_latents = self.transformer.config.in_channels // 4
369
  if isinstance(self.controlnet, FluxControlNetModel):
 
497
  ]
498
  controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
499
 
 
 
 
 
 
500
  # 7. Denoising loop
501
  with self.progress_bar(total=num_inference_steps) as progress_bar:
502
  for i, t in enumerate(timesteps):
pipelines/pipeline_infu_flux.py CHANGED
@@ -12,6 +12,7 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
15
  import math
16
  import os
17
  import random
@@ -199,9 +200,9 @@ class InfUFluxPipeline:
199
  ipm_state_dict = torch.load(image_proj_model_path, map_location="cpu")
200
  image_proj_model.load_state_dict(ipm_state_dict['image_proj'])
201
  del ipm_state_dict
202
- image_proj_model.to('cuda', torch.bfloat16)
203
- image_proj_model.eval()
204
- self.image_proj_model_aes = image_proj_model
205
 
206
  image_proj_model = Resampler(
207
  dim=1280,
@@ -217,9 +218,9 @@ class InfUFluxPipeline:
217
  ipm_state_dict = torch.load(image_proj_model_path, map_location="cpu")
218
  image_proj_model.load_state_dict(ipm_state_dict['image_proj'])
219
  del ipm_state_dict
220
- image_proj_model.to('cpu', torch.bfloat16)
221
- image_proj_model.eval()
222
  self.image_proj_model_sim = image_proj_model
 
 
223
 
224
  self.image_proj_model = self.image_proj_model_aes
225
 
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import copy
16
  import math
17
  import os
18
  import random
 
200
  ipm_state_dict = torch.load(image_proj_model_path, map_location="cpu")
201
  image_proj_model.load_state_dict(ipm_state_dict['image_proj'])
202
  del ipm_state_dict
203
+ self.image_proj_model_aes = copy.deepcopy(image_proj_model)
204
+ self.image_proj_model_aes.to('cuda', torch.bfloat16)
205
+ self.image_proj_model_aes.eval()
206
 
207
  image_proj_model = Resampler(
208
  dim=1280,
 
218
  ipm_state_dict = torch.load(image_proj_model_path, map_location="cpu")
219
  image_proj_model.load_state_dict(ipm_state_dict['image_proj'])
220
  del ipm_state_dict
 
 
221
  self.image_proj_model_sim = image_proj_model
222
+ self.image_proj_model_sim.to('cpu', torch.bfloat16)
223
+ self.image_proj_model_sim.eval()
224
 
225
  self.image_proj_model = self.image_proj_model_aes
226