jamesr66a commited on
Commit
9e67b52
·
1 Parent(s): 7cd3ebd

add missing warmup

Browse files
Files changed (1) hide show
  1. flumina.py +33 -15
flumina.py CHANGED
@@ -17,6 +17,20 @@ from typing import Optional, Set, Tuple
17
  from flux_pipeline import FluxPipeline
18
  from util import load_config, ModelVersion
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # Util
21
  def _aspect_ratio_to_width_height(aspect_ratio: str) -> Tuple[int, int]:
22
  """
@@ -35,22 +49,10 @@ def _aspect_ratio_to_width_height(aspect_ratio: str) -> Tuple[int, int]:
35
  f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be in w:h format, e.g. 16:9"
36
  )
37
 
38
- valid_aspect_ratios = [
39
- (1, 1),
40
- (21, 9),
41
- (16, 9),
42
- (3, 2),
43
- (5, 4),
44
- (4, 5),
45
- (2, 3),
46
- (9, 16),
47
- (9, 21),
48
- (4, 3),
49
- (3, 4),
50
- ]
51
- if (w, h) not in valid_aspect_ratios:
52
  raise ValueError(
53
- f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be one of {valid_aspect_ratios}"
54
  )
55
 
56
  # We consider megapixel not 10^6 pixels but 2^20 (1024x1024) pixels
@@ -203,8 +205,24 @@ class FluminaModule(FluminaModule):
203
  # Initialize LoRA adapters
204
  self.lora_adapters: Set[str] = set()
205
  self.active_lora_adapter: Optional[str] = None
 
 
 
206
  self._test_return_sync_response = False
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  def _error_response(self, code: int, message: str) -> Response:
209
  response_json = ErrorResponse(
210
  error=Error(message=message),
 
17
  from flux_pipeline import FluxPipeline
18
  from util import load_config, ModelVersion
19
 
20
+ _ASPECT_RATIOS = [
21
+ (1, 1),
22
+ (21, 9),
23
+ (16, 9),
24
+ (3, 2),
25
+ (5, 4),
26
+ (4, 5),
27
+ (2, 3),
28
+ (9, 16),
29
+ (9, 21),
30
+ (4, 3),
31
+ (3, 4),
32
+ ]
33
+
34
  # Util
35
  def _aspect_ratio_to_width_height(aspect_ratio: str) -> Tuple[int, int]:
36
  """
 
49
  f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be in w:h format, e.g. 16:9"
50
  )
51
 
52
+
53
+ if (w, h) not in _ASPECT_RATIOS:
 
 
 
 
 
 
 
 
 
 
 
 
54
  raise ValueError(
55
+ f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be one of {_ASPECT_RATIOS}"
56
  )
57
 
58
  # We consider megapixel not 10^6 pixels but 2^20 (1024x1024) pixels
 
205
  # Initialize LoRA adapters
206
  self.lora_adapters: Set[str] = set()
207
  self.active_lora_adapter: Optional[str] = None
208
+
209
+ self._warm_up()
210
+
211
  self._test_return_sync_response = False
212
 
213
+ def _warm_up(self):
214
+ for f, s in _ASPECT_RATIOS:
215
+ print(f"Warm-up for aspect ratio {f}:{s}")
216
+ width, height = _aspect_ratio_to_width_height(f"{f}:{s}")
217
+ self.pipeline.generate(
218
+ prompt="a quick brown fox",
219
+ height=height,
220
+ width=width,
221
+ guidance=3.5,
222
+ num_steps=1,
223
+ seed=0,
224
+ )
225
+
226
  def _error_response(self, code: int, message: str) -> Response:
227
  response_json = ErrorResponse(
228
  error=Error(message=message),