ctheodoris committed
Commit 933ca80 · 1 Parent(s): ec19834

update with 12L and 20L i4096 gc95M models, multitask and quantization code

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. .gitattributes +1 -1
  2. MANIFEST.in +3 -3
  3. config.json +9 -8
  4. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json +24 -0
  5. fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin +3 -0
  6. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/config.json +0 -0
  7. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/optimizer.pt +0 -0
  8. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin +0 -0
  9. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/rng_state.pth +0 -0
  10. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/scheduler.pt +0 -0
  11. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/trainer_state.json +0 -0
  12. fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/training_args.bin +0 -0
  13. geneformer/__init__.py +10 -5
  14. geneformer/classifier.py +74 -16
  15. geneformer/classifier_utils.py +117 -5
  16. geneformer/collator_for_classification.py +15 -19
  17. geneformer/emb_extractor.py +20 -13
  18. geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl +3 -0
  19. geneformer/{gene_name_id_dict.pkl → gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl} +0 -0
  20. geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl +3 -0
  21. geneformer/gene_median_dictionary.pkl +0 -0
  22. geneformer/in_silico_perturber.py +733 -143
  23. geneformer/in_silico_perturber_stats.py +22 -6
  24. geneformer/mtl/__init__.py +0 -0
  25. geneformer/mtl/collators.py +66 -0
  26. geneformer/mtl/data.py +116 -0
  27. geneformer/mtl/eval_utils.py +81 -0
  28. geneformer/mtl/imports.py +46 -0
  29. geneformer/mtl/model.py +84 -0
  30. geneformer/mtl/optuna_utils.py +21 -0
  31. geneformer/mtl/train.py +242 -0
  32. geneformer/mtl/train_utils.py +126 -0
  33. geneformer/mtl/utils.py +106 -0
  34. geneformer/mtl_classifier.py +338 -0
  35. geneformer/perturber_utils.py +168 -16
  36. geneformer/pretrainer.py +0 -13
  37. geneformer/token_dictionary.pkl +0 -0
  38. geneformer/token_dictionary_gc95M.pkl +0 -0
  39. generation_config.json +5 -0
  40. {geneformer-12L-30M → gf-12L-30M-i2048}/config.json +0 -0
  41. {geneformer-12L-30M → gf-12L-30M-i2048}/pytorch_model.bin +0 -0
  42. {geneformer-12L-30M → gf-12L-30M-i2048}/training_args.bin +0 -0
  43. gf-12L-95M-i4096/config.json +24 -0
  44. gf-12L-95M-i4096/generation_config.json +5 -0
  45. gf-12L-95M-i4096/model.safetensors +3 -0
  46. gf-12L-95M-i4096/training_args.bin +3 -0
  47. gf-12L-95M-i4096_CLcancer/config.json +25 -0
  48. gf-12L-95M-i4096_CLcancer/generation_config.json +5 -0
  49. gf-12L-95M-i4096_CLcancer/model.safetensors +3 -0
  50. gf-12L-95M-i4096_CLcancer/training_args.bin +3 -0
.gitattributes CHANGED
@@ -26,4 +26,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
29
- model.safetensors filter=lfs diff=lfs merge=lfs -text
 
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
28
  *tfevents* filter=lfs diff=lfs merge=lfs -text
29
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
MANIFEST.in CHANGED
@@ -1,3 +1,3 @@
1
- include geneformer/gene_median_dictionary.pkl
2
- include geneformer/token_dictionary.pkl
3
- include geneformer/gene_name_id_dict.pkl
 
1
+ include geneformer/gene_median_dictionary_95m.pkl
2
+ include geneformer/token_dictionary_95m.pkl
3
+ include geneformer/gene_name_id_dict_95m.pkl
config.json CHANGED
@@ -3,21 +3,22 @@
3
  "BertForMaskedLM"
4
  ],
5
  "attention_probs_dropout_prob": 0.02,
6
- "gradient_checkpointing": false,
7
  "hidden_act": "relu",
8
  "hidden_dropout_prob": 0.02,
9
- "hidden_size": 256,
10
  "initializer_range": 0.02,
11
- "intermediate_size": 512,
12
  "layer_norm_eps": 1e-12,
13
- "max_position_embeddings": 2048,
14
  "model_type": "bert",
15
- "num_attention_heads": 4,
16
- "num_hidden_layers": 6,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
- "transformers_version": "4.6.0",
 
20
  "type_vocab_size": 2,
21
  "use_cache": true,
22
- "vocab_size": 25426
23
  }
 
3
  "BertForMaskedLM"
4
  ],
5
  "attention_probs_dropout_prob": 0.02,
6
+ "classifier_dropout": null,
7
  "hidden_act": "relu",
8
  "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 512,
10
  "initializer_range": 0.02,
11
+ "intermediate_size": 1024,
12
  "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 4096,
14
  "model_type": "bert",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 12,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.37.1",
21
  "type_vocab_size": 2,
22
  "use_cache": true,
23
+ "vocab_size": 20275
24
  }
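A minimal sketch (not part of the commit) of checking the updated default configuration with Hugging Face transformers; the repository id "ctheodoris/Geneformer" is assumed here.

```python
from transformers import AutoConfig

# Inspect the updated default architecture from the config.json above:
# 12 layers, hidden size 512, 4096-token input size, 20,275-token gc95M vocabulary.
config = AutoConfig.from_pretrained("ctheodoris/Geneformer")
print(config.num_hidden_layers, config.hidden_size)       # 12 512
print(config.max_position_embeddings, config.vocab_size)  # 4096 20275
```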
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.02,
6
+ "classifier_dropout": null,
7
+ "hidden_act": "relu",
8
+ "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 512,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 1024,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 4096,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.37.2",
21
+ "type_vocab_size": 2,
22
+ "use_cache": true,
23
+ "vocab_size": 20275
24
+ }
fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4
3
+ size 152363342
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/config.json RENAMED
File without changes
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/optimizer.pt RENAMED
File without changes
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/pytorch_model.bin RENAMED
File without changes
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/rng_state.pth RENAMED
File without changes
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/scheduler.pt RENAMED
File without changes
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/trainer_state.json RENAMED
File without changes
fine_tuned_models/{geneformer-6L-30M_CellClassifier_cardiomyopathies_220224 → gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224}/training_args.bin RENAMED
File without changes
geneformer/__init__.py CHANGED
@@ -1,10 +1,12 @@
1
  # ruff: noqa: F401
2
  from pathlib import Path
 
 
3
 
4
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
5
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
6
- ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
7
- ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict.pkl"
8
 
9
  from . import (
10
  collator_for_classification,
@@ -25,4 +27,7 @@ from .pretrainer import GeneformerPretrainer
25
  from .tokenizer import TranscriptomeTokenizer
26
 
27
  from . import classifier # noqa # isort:skip
28
- from .classifier import Classifier # noqa # isort:skip
 
 
 
 
1
  # ruff: noqa: F401
2
  from pathlib import Path
3
+ import warnings
4
+ warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
5
 
6
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
7
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
8
+ ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
9
+ ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
10
 
11
  from . import (
12
  collator_for_classification,
 
27
  from .tokenizer import TranscriptomeTokenizer
28
 
29
  from . import classifier # noqa # isort:skip
30
+ from .classifier import Classifier # noqa # isort:skip
31
+
32
+ from . import mtl_classifier # noqa # isort:skip
33
+ from .mtl_classifier import MTLClassifier # noqa # isort:skip
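A minimal sketch of the package-level imports after this change: the dictionary constants now point to the gc95M files and MTLClassifier is exposed alongside Classifier (MTLClassifier's constructor arguments live in geneformer/mtl_classifier.py and are not shown here).

```python
from geneformer import Classifier, MTLClassifier, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE

# The package-level constants now resolve to the 95M-corpus dictionaries.
print(GENE_MEDIAN_FILE.name)       # gene_median_dictionary_gc95M.pkl
print(TOKEN_DICTIONARY_FILE.name)  # token_dictionary_gc95M.pkl
```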
geneformer/classifier.py CHANGED
@@ -72,6 +72,7 @@ logger = logging.getLogger(__name__)
72
  class Classifier:
73
  valid_option_dict = {
74
  "classifier": {"cell", "gene"},
 
75
  "cell_state_dict": {None, dict},
76
  "gene_class_dict": {None, dict},
77
  "filter_data": {None, dict},
@@ -93,6 +94,7 @@ class Classifier:
93
  def __init__(
94
  self,
95
  classifier=None,
 
96
  cell_state_dict=None,
97
  gene_class_dict=None,
98
  filter_data=None,
@@ -118,6 +120,13 @@ class Classifier:
118
 
119
  classifier : {"cell", "gene"}
120
  | Whether to fine-tune a cell state or gene classifier.
 
 
 
 
 
 
 
121
  cell_state_dict : None, dict
122
  | Cell states to fine-tune model to distinguish.
123
  | Two-item dictionary with keys: state_key and states
@@ -191,6 +200,7 @@ class Classifier:
191
  self.model_type = "CellClassifier"
192
  elif self.classifier == "gene":
193
  self.model_type = "GeneClassifier"
 
194
  self.cell_state_dict = cell_state_dict
195
  self.gene_class_dict = gene_class_dict
196
  self.filter_data = filter_data
@@ -256,7 +266,7 @@ class Classifier:
256
  f"Genes to classify {missing_genes} are not in token dictionary."
257
  )
258
  self.gene_class_dict = {
259
- k: set([self.gene_token_dict.get(gene) for gene in v])
260
  for k, v in self.gene_class_dict.items()
261
  }
262
  empty_classes = []
@@ -403,6 +413,15 @@ class Classifier:
403
  "Column name 'labels' must be reserved for class IDs. Please rename column."
404
  )
405
  raise
 
 
 
 
 
 
 
 
 
406
 
407
  if self.classifier == "cell":
408
  # remove cell states representing < rare_threshold of cells
@@ -505,6 +524,7 @@ class Classifier:
505
  output_directory,
506
  output_prefix,
507
  save_eval_output=True,
 
508
  ):
509
  """
510
  Train cell state or gene classifier using all data.
@@ -525,13 +545,20 @@ class Classifier:
525
  save_eval_output : bool
526
  | Whether to save cross-fold eval output
527
  | Saves as pickle file of dictionary of eval metrics
528
-
 
 
 
529
  **Output**
530
 
531
  Returns trainer after fine-tuning with all data.
532
 
533
  """
534
 
 
 
 
 
535
  ##### Load data and prepare output directory #####
536
  # load numerical id to class dictionary (id:class)
537
  with open(id_class_dict_file, "rb") as f:
@@ -563,7 +590,7 @@ class Classifier:
563
  )
564
  assert len(targets) == len(labels)
565
  data = cu.prep_gene_classifier_all_data(
566
- data, targets, labels, self.max_ncells, self.nproc
567
  )
568
 
569
  trainer = self.train_classifier(
@@ -582,12 +609,15 @@ class Classifier:
582
  split_id_dict=None,
583
  attr_to_split=None,
584
  attr_to_balance=None,
 
585
  max_trials=100,
586
  pval_threshold=0.1,
587
  save_eval_output=True,
588
  predict_eval=True,
589
  predict_trainer=False,
590
  n_hyperopt_trials=0,
 
 
591
  ):
592
  """
593
  (Cross-)validate cell state or gene classifier.
@@ -622,6 +652,9 @@ class Classifier:
622
  attr_to_balance : None, list
623
  | List of attribute keys on which to balance data while splitting on attr_to_split
624
  | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
 
 
 
625
  max_trials : None, int
626
  | Maximum number of trials of random splitting to try to achieve balanced other attribute
627
  | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
@@ -640,11 +673,17 @@ class Classifier:
640
  n_hyperopt_trials : int
641
  | Number of trials to run for hyperparameter optimization
642
  | If 0, will not optimize hyperparameters
 
 
643
  """
644
  if self.num_crossval_splits == 0:
645
  logger.error("num_crossval_splits must be 1 or 5 to validate.")
646
  raise
647
-
 
 
 
 
648
  # ensure number of genes in each class is > 5 if validating model
649
  if self.classifier == "gene":
650
  insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
@@ -725,7 +764,7 @@ class Classifier:
725
  else:
726
  # 5-fold cross-validate
727
  num_cells = len(data)
728
- fifth_cells = num_cells * 0.2
729
  num_eval = min((self.eval_size * num_cells), fifth_cells)
730
  start = i * fifth_cells
731
  end = start + num_eval
@@ -804,8 +843,19 @@ class Classifier:
804
  self.max_ncells,
805
  iteration_num,
806
  self.nproc,
 
807
  )
808
-
 
 
 
 
 
 
 
 
 
 
809
  if self.oos_test_size > 0:
810
  test_data = cu.prep_gene_classifier_split(
811
  data,
@@ -817,7 +867,14 @@ class Classifier:
817
  iteration_num,
818
  self.nproc,
819
  )
820
-
 
 
 
 
 
 
 
821
  if n_hyperopt_trials == 0:
822
  trainer = self.train_classifier(
823
  model_directory,
@@ -966,7 +1023,7 @@ class Classifier:
966
  subprocess.call(f"mkdir {output_directory}", shell=True)
967
 
968
  ##### Load model and training args #####
969
- model = pu.load_model(self.model_type, num_classes, model_directory, "train")
970
  def_training_args, def_freeze_layers = cu.get_default_train_args(
971
  model, self.classifier, train_data, output_directory
972
  )
@@ -990,14 +1047,14 @@ class Classifier:
990
  ##### Fine-tune the model #####
991
  # define the data collator
992
  if self.classifier == "cell":
993
- data_collator = DataCollatorForCellClassification()
994
  elif self.classifier == "gene":
995
- data_collator = DataCollatorForGeneClassification()
996
 
997
  # define function to initiate model
998
  def model_init():
999
  model = pu.load_model(
1000
- self.model_type, num_classes, model_directory, "train"
1001
  )
1002
 
1003
  if self.freeze_layers is not None:
@@ -1009,7 +1066,8 @@ class Classifier:
1009
  for param in module.parameters():
1010
  param.requires_grad = False
1011
 
1012
- model = model.to("cuda:0")
 
1013
  return model
1014
 
1015
  # create the trainer
@@ -1122,7 +1180,7 @@ class Classifier:
1122
  subprocess.call(f"mkdir {output_directory}", shell=True)
1123
 
1124
  ##### Load model and training args #####
1125
- model = pu.load_model(self.model_type, num_classes, model_directory, "train")
1126
 
1127
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1128
  model, self.classifier, train_data, output_directory
@@ -1152,9 +1210,9 @@ class Classifier:
1152
  ##### Fine-tune the model #####
1153
  # define the data collator
1154
  if self.classifier == "cell":
1155
- data_collator = DataCollatorForCellClassification()
1156
  elif self.classifier == "gene":
1157
- data_collator = DataCollatorForGeneClassification()
1158
 
1159
  # create the trainer
1160
  trainer = Trainer(
@@ -1276,7 +1334,7 @@ class Classifier:
1276
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1277
 
1278
  # load previously fine-tuned model
1279
- model = pu.load_model(self.model_type, num_classes, model_directory, "eval")
1280
 
1281
  # evaluate the model
1282
  result = self.evaluate_model(
 
72
  class Classifier:
73
  valid_option_dict = {
74
  "classifier": {"cell", "gene"},
75
+ "quantize": {bool, dict},
76
  "cell_state_dict": {None, dict},
77
  "gene_class_dict": {None, dict},
78
  "filter_data": {None, dict},
 
94
  def __init__(
95
  self,
96
  classifier=None,
97
+ quantize=False,
98
  cell_state_dict=None,
99
  gene_class_dict=None,
100
  filter_data=None,
 
120
 
121
  classifier : {"cell", "gene"}
122
  | Whether to fine-tune a cell state or gene classifier.
123
+ quantize : bool, dict
124
+ | Whether to fine-tune a quantized model.
125
+ | If True and no config provided, will use default.
126
+ | Will use custom config if provided.
127
+ | Configs should be provided as dictionary of BitsAndBytesConfig (transformers) and LoraConfig (peft).
128
+ | For example: {"bnb_config": BitsAndBytesConfig(...),
129
+ | "peft_config": LoraConfig(...)}
130
  cell_state_dict : None, dict
131
  | Cell states to fine-tune model to distinguish.
132
  | Two-item dictionary with keys: state_key and states
 
200
  self.model_type = "CellClassifier"
201
  elif self.classifier == "gene":
202
  self.model_type = "GeneClassifier"
203
+ self.quantize = quantize
204
  self.cell_state_dict = cell_state_dict
205
  self.gene_class_dict = gene_class_dict
206
  self.filter_data = filter_data
 
266
  f"Genes to classify {missing_genes} are not in token dictionary."
267
  )
268
  self.gene_class_dict = {
269
+ k: list(set([self.gene_token_dict.get(gene) for gene in v]))
270
  for k, v in self.gene_class_dict.items()
271
  }
272
  empty_classes = []
 
413
  "Column name 'labels' must be reserved for class IDs. Please rename column."
414
  )
415
  raise
416
+
417
+ if (attr_to_split is not None) and (attr_to_balance is None):
418
+ logger.error(
419
+ "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
420
+ )
421
+ raise
422
+
423
+ if not isinstance(attr_to_balance, list):
424
+ attr_to_balance = [attr_to_balance]
425
 
426
  if self.classifier == "cell":
427
  # remove cell states representing < rare_threshold of cells
 
524
  output_directory,
525
  output_prefix,
526
  save_eval_output=True,
527
+ gene_balance=False,
528
  ):
529
  """
530
  Train cell state or gene classifier using all data.
 
545
  save_eval_output : bool
546
  | Whether to save cross-fold eval output
547
  | Saves as pickle file of dictionary of eval metrics
548
+ gene_balance : None, bool
549
+ | Whether to automatically balance genes in training set.
550
+ | Only available for binary gene classifications.
551
+
552
  **Output**
553
 
554
  Returns trainer after fine-tuning with all data.
555
 
556
  """
557
 
558
+ if (gene_balance is True) and (len(self.gene_class_dict.values())!=2):
559
+ logger.error("Automatically balancing gene sets for training is only available for binary gene classifications.")
560
+ raise
561
+
562
  ##### Load data and prepare output directory #####
563
  # load numerical id to class dictionary (id:class)
564
  with open(id_class_dict_file, "rb") as f:
 
590
  )
591
  assert len(targets) == len(labels)
592
  data = cu.prep_gene_classifier_all_data(
593
+ data, targets, labels, self.max_ncells, self.nproc, gene_balance
594
  )
595
 
596
  trainer = self.train_classifier(
 
609
  split_id_dict=None,
610
  attr_to_split=None,
611
  attr_to_balance=None,
612
+ gene_balance=False,
613
  max_trials=100,
614
  pval_threshold=0.1,
615
  save_eval_output=True,
616
  predict_eval=True,
617
  predict_trainer=False,
618
  n_hyperopt_trials=0,
619
+ save_gene_split_datasets=True,
620
+ debug_gene_split_datasets=False,
621
  ):
622
  """
623
  (Cross-)validate cell state or gene classifier.
 
652
  attr_to_balance : None, list
653
  | List of attribute keys on which to balance data while splitting on attr_to_split
654
  | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
655
+ gene_balance : None, bool
656
+ | Whether to automatically balance genes in training set.
657
+ | Only available for binary gene classifications.
658
  max_trials : None, int
659
  | Maximum number of trials of random splitting to try to achieve balanced other attribute
660
  | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
 
673
  n_hyperopt_trials : int
674
  | Number of trials to run for hyperparameter optimization
675
  | If 0, will not optimize hyperparameters
676
+ save_gene_split_datasets : bool
677
+ | Whether or not to save train, valid, and test gene-labeled datasets
678
  """
679
  if self.num_crossval_splits == 0:
680
  logger.error("num_crossval_splits must be 1 or 5 to validate.")
681
  raise
682
+
683
+ if (gene_balance is True) and (len(self.gene_class_dict.values())!=2):
684
+ logger.error("Automatically balancing gene sets for training is only available for binary gene classifications.")
685
+ raise
686
+
687
  # ensure number of genes in each class is > 5 if validating model
688
  if self.classifier == "gene":
689
  insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
 
764
  else:
765
  # 5-fold cross-validate
766
  num_cells = len(data)
767
+ fifth_cells = int(np.floor(num_cells * 0.2))
768
  num_eval = min((self.eval_size * num_cells), fifth_cells)
769
  start = i * fifth_cells
770
  end = start + num_eval
 
843
  self.max_ncells,
844
  iteration_num,
845
  self.nproc,
846
+ gene_balance,
847
  )
848
+
849
+ if save_gene_split_datasets is True:
850
+ for split_name in ["train", "valid"]:
851
+ labeled_dataset_output_path = (
852
+ Path(output_dir) / f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
853
+ ).with_suffix(".dataset")
854
+ if split_name == "train":
855
+ train_data.save_to_disk(str(labeled_dataset_output_path))
856
+ elif split_name == "valid":
857
+ eval_data.save_to_disk(str(labeled_dataset_output_path))
858
+
859
  if self.oos_test_size > 0:
860
  test_data = cu.prep_gene_classifier_split(
861
  data,
 
867
  iteration_num,
868
  self.nproc,
869
  )
870
+ if save_gene_split_datasets is True:
871
+ test_labeled_dataset_output_path = (
872
+ Path(output_dir) / f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
873
+ ).with_suffix(".dataset")
874
+ test_data.save_to_disk(str(test_labeled_dataset_output_path))
875
+ if debug_gene_split_datasets is True:
876
+ logger.error("Exiting after saving gene split datasets given debug_gene_split_datasets = True.")
877
+ raise
878
  if n_hyperopt_trials == 0:
879
  trainer = self.train_classifier(
880
  model_directory,
 
1023
  subprocess.call(f"mkdir {output_directory}", shell=True)
1024
 
1025
  ##### Load model and training args #####
1026
+ model = pu.load_model(self.model_type, num_classes, model_directory, "train", quantize=self.quantize)
1027
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1028
  model, self.classifier, train_data, output_directory
1029
  )
 
1047
  ##### Fine-tune the model #####
1048
  # define the data collator
1049
  if self.classifier == "cell":
1050
+ data_collator = DataCollatorForCellClassification(token_dictionary=self.token_dictionary)
1051
  elif self.classifier == "gene":
1052
+ data_collator = DataCollatorForGeneClassification(token_dictionary=self.token_dictionary)
1053
 
1054
  # define function to initiate model
1055
  def model_init():
1056
  model = pu.load_model(
1057
+ self.model_type, num_classes, model_directory, "train", quantize=self.quantize
1058
  )
1059
 
1060
  if self.freeze_layers is not None:
 
1066
  for param in module.parameters():
1067
  param.requires_grad = False
1068
 
1069
+ if self.quantize is False:
1070
+ model = model.to("cuda:0")
1071
  return model
1072
 
1073
  # create the trainer
 
1180
  subprocess.call(f"mkdir {output_directory}", shell=True)
1181
 
1182
  ##### Load model and training args #####
1183
+ model = pu.load_model(self.model_type, num_classes, model_directory, "train", quantize=self.quantize)
1184
 
1185
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1186
  model, self.classifier, train_data, output_directory
 
1210
  ##### Fine-tune the model #####
1211
  # define the data collator
1212
  if self.classifier == "cell":
1213
+ data_collator = DataCollatorForCellClassification(token_dictionary=self.token_dictionary)
1214
  elif self.classifier == "gene":
1215
+ data_collator = DataCollatorForGeneClassification(token_dictionary=self.token_dictionary)
1216
 
1217
  # create the trainer
1218
  trainer = Trainer(
 
1334
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1335
 
1336
  # load previously fine-tuned model
1337
+ model = pu.load_model(self.model_type, num_classes, model_directory, "eval", quantize=self.quantize)
1338
 
1339
  # evaluate the model
1340
  result = self.evaluate_model(
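A hedged sketch of the new quantize option documented above: quantize=True falls back to a default configuration, while a dict supplies custom BitsAndBytesConfig (transformers) and LoraConfig (peft) objects. The specific config values and other Classifier arguments below are illustrative, not package defaults.

```python
from transformers import BitsAndBytesConfig
from peft import LoraConfig
from geneformer import Classifier

# Custom quantization/LoRA settings passed through the new `quantize` argument.
quantize_config = {
    "bnb_config": BitsAndBytesConfig(load_in_8bit=True),
    "peft_config": LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05, bias="none"),
}

cc = Classifier(
    classifier="cell",
    quantize=quantize_config,  # or quantize=True to use the built-in defaults
    cell_state_dict={"state_key": "disease", "states": "all"},
    max_ncells=10000,
    nproc=4,
)
```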
geneformer/classifier_utils.py CHANGED
@@ -137,21 +137,22 @@ def label_gene_classes(example, class_id_dict, gene_class_dict):
137
 
138
 
139
  def prep_gene_classifier_train_eval_split(
140
- data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
141
  ):
142
  # generate cross-validation splits
143
  train_data = prep_gene_classifier_split(
144
- data, targets, labels, train_index, "train", max_ncells, iteration_num, num_proc
145
  )
146
  eval_data = prep_gene_classifier_split(
147
- data, targets, labels, eval_index, "eval", max_ncells, iteration_num, num_proc
148
  )
149
  return train_data, eval_data
150
 
151
 
152
  def prep_gene_classifier_split(
153
- data, targets, labels, index, subset_name, max_ncells, iteration_num, num_proc
154
  ):
 
155
  # generate cross-validation splits
156
  targets = np.array(targets)
157
  labels = np.array(labels)
@@ -172,6 +173,10 @@ def prep_gene_classifier_split(
172
  f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
173
  )
174
 
 
 
 
 
175
  # subsample to max_ncells
176
  subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
177
 
@@ -187,7 +192,7 @@ def prep_gene_classifier_split(
187
  return subset_data
188
 
189
 
190
- def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
191
  targets = np.array(targets)
192
  labels = np.array(labels)
193
  label_dict_train = dict(zip(targets, labels))
@@ -205,6 +210,9 @@ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
205
  f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
206
  )
207
 
 
 
 
208
  # subsample to max_ncells
209
  train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
210
 
@@ -220,6 +228,110 @@ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
220
  return train_data
221
 
222
223
  def balance_attr_splits(
224
  data,
225
  attr_to_split,
 
137
 
138
 
139
  def prep_gene_classifier_train_eval_split(
140
+ data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc, balance=False
141
  ):
142
  # generate cross-validation splits
143
  train_data = prep_gene_classifier_split(
144
+ data, targets, labels, train_index, "train", max_ncells, iteration_num, num_proc, balance
145
  )
146
  eval_data = prep_gene_classifier_split(
147
+ data, targets, labels, eval_index, "eval", max_ncells, iteration_num, num_proc, balance
148
  )
149
  return train_data, eval_data
150
 
151
 
152
  def prep_gene_classifier_split(
153
+ data, targets, labels, index, subset_name, max_ncells, iteration_num, num_proc, balance=False
154
  ):
155
+
156
  # generate cross-validation splits
157
  targets = np.array(targets)
158
  labels = np.array(labels)
 
173
  f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
174
  )
175
 
176
+ # balance gene subsets if train
177
+ if (subset_name == "train") and (balance is True):
178
+ subset_data, label_dict_subset = balance_gene_split(subset_data, label_dict_subset, num_proc)
179
+
180
  # subsample to max_ncells
181
  subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
182
 
 
192
  return subset_data
193
 
194
 
195
+ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc, balance=False):
196
  targets = np.array(targets)
197
  labels = np.array(labels)
198
  label_dict_train = dict(zip(targets, labels))
 
210
  f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
211
  )
212
 
213
+ if balance is True:
214
+ train_data, label_dict_train = balance_gene_split(train_data, label_dict_train, num_proc)
215
+
216
  # subsample to max_ncells
217
  train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
218
 
 
228
  return train_data
229
 
230
 
231
+ def balance_gene_split(subset_data, label_dict_subset, num_proc):
232
+ # count occurrence of genes in each label category
233
+ label0_counts, label1_counts = count_genes_for_balancing(subset_data, label_dict_subset, num_proc)
234
+ label_ratio_0to1 = label0_counts/label1_counts
235
+
236
+ if 8/10 <= label_ratio_0to1 <= 10/8:
237
+ # gene sets already balanced
238
+ logger.info(
239
+ "Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
240
+ )
241
+ return subset_data, label_dict_subset
242
+ else:
243
+ label_ratio_0to1_orig = label_ratio_0to1+0
244
+ label_dict_subset_orig = label_dict_subset.copy()
245
+ # balance gene sets
246
+ max_ntrials = 25
247
+ boost = 1
248
+ if label_ratio_0to1 > 10/8:
249
+ # downsample label 0
250
+ for i in range(max_ntrials):
251
+ label0 = 0
252
+ label0_genes = [k for k,v in label_dict_subset.items() if v == label0]
253
+ label0_ngenes = len(label0_genes)
254
+ label0_nremove = max(1,int(np.floor(label0_ngenes - label0_ngenes/(label_ratio_0to1*boost))))
255
+ random.seed(i)
256
+ label0_remove_genes = random.sample(label0_genes, label0_nremove)
257
+ label_dict_subset_new = {k:v for k,v in label_dict_subset.items() if k not in label0_remove_genes}
258
+ label0_counts, label1_counts = count_genes_for_balancing(subset_data, label_dict_subset_new, num_proc)
259
+ label_ratio_0to1 = label0_counts/label1_counts
260
+ if 8/10 <= label_ratio_0to1 <= 10/8:
261
+ # if gene sets now balanced, return new filtered data and new label_dict_subset
262
+ return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
263
+ elif label_ratio_0to1 > 10/8:
264
+ boost = boost*1.1
265
+ elif label_ratio_0to1 < 8/10:
266
+ boost = boost*0.9
267
+ else:
268
+ # downsample label 1
269
+ for i in range(max_ntrials):
270
+ label1 = 1
271
+ label1_genes = [k for k,v in label_dict_subset.items() if v == label1]
272
+ label1_ngenes = len(label1_genes)
273
+ label1_nremove = max(1,int(np.floor(label1_ngenes - label1_ngenes/((1/label_ratio_0to1)*boost))))
274
+ random.seed(i)
275
+ label1_remove_genes = random.sample(label1_genes, label1_nremove)
276
+ label_dict_subset_new = {k:v for k,v in label_dict_subset.items() if k not in label1_remove_genes}
277
+ label0_counts, label1_counts = count_genes_for_balancing(subset_data, label_dict_subset_new, num_proc)
278
+ label_ratio_0to1 = label0_counts/label1_counts
279
+ if 8/10 <= label_ratio_0to1 <= 10/8:
280
+ # if gene sets now balanced, return new filtered data and new label_dict_subset
281
+ return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
282
+ elif label_ratio_0to1 < 8/10:
283
+ boost = boost*1.1
284
+ elif label_ratio_0to1 > 10/8:
285
+ boost = boost*0.9
286
+
287
+ assert i+1 == max_ntrials
288
+ if (label_ratio_0to1 <= label_ratio_0to1_orig < 8/10) or (10/8 > label_ratio_0to1_orig >= label_ratio_0to1):
289
+ label_ratio_0to1 = label_ratio_0to1_orig
290
+ label_dict_subset_new = label_dict_subset_orig
291
+ logger.warning(
292
+ f"Gene sets were not able to be balanced within 0.8-1.25 fold after {max_ntrials} trials. Imbalance level: {label_ratio_0to1}\n"
293
+ )
294
+ return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
295
+
296
+
297
+ def count_genes_for_balancing(subset_data, label_dict_subset, num_proc):
298
+ def count_targets(example):
299
+ labels = [
300
+ label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
301
+ ]
302
+ counter_labels = Counter(labels)
303
+ # get count of labels 0 or 1, or if absent, return 0
304
+ example["labels_counts"] = [counter_labels.get(0,0),counter_labels.get(1,0)]
305
+ return example
306
+
307
+ subset_data = subset_data.map(count_targets, num_proc=num_proc)
308
+
309
+ label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
310
+ label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
311
+
312
+ subset_data = subset_data.remove_columns("labels_counts")
313
+
314
+ return label0_counts, label1_counts
315
+
316
+
317
+ def filter_data_balanced_genes(subset_data, label_dict_subset, num_proc):
318
+ # function to filter by whether contains labels
319
+ def if_contains_subset_label(example):
320
+ a = list(label_dict_subset.keys())
321
+ b = example["input_ids"]
322
+ return not set(a).isdisjoint(b)
323
+
324
+ # filter dataset for examples containing classes for this split
325
+ logger.info("Filtering data for balanced genes")
326
+ subset_data_len_orig = len(subset_data)
327
+ subset_data = subset_data.filter(if_contains_subset_label, num_proc=num_proc)
328
+ logger.info(
329
+ f"Filtered {round((1-len(subset_data)/subset_data_len_orig)*100)}%; {len(subset_data)} remain\n"
330
+ )
331
+
332
+ return subset_data, label_dict_subset
333
+
334
+
335
  def balance_attr_splits(
336
  data,
337
  attr_to_split,
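For reference, the balancing criterion used by balance_gene_split above: the two gene classes are treated as balanced when the label-0 to label-1 token-count ratio falls within 0.8-1.25 fold. A standalone restatement of that check (not part of the module):

```python
def is_balanced(label0_counts: int, label1_counts: int) -> bool:
    # Same acceptance window as balance_gene_split: 8/10 <= ratio <= 10/8.
    ratio = label0_counts / label1_counts
    return 8 / 10 <= ratio <= 10 / 8

print(is_balanced(900, 1000))   # True: ratio 0.9 needs no downsampling
print(is_balanced(2000, 1000))  # False: label 0 would be randomly downsampled
```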
geneformer/collator_for_classification.py CHANGED
@@ -18,12 +18,6 @@ from transformers import (
18
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
19
  from transformers.utils.generic import _is_tensorflow, _is_torch
20
 
21
- from . import TOKEN_DICTIONARY_FILE
22
-
23
- # load token dictionary (Ensembl IDs:token)
24
- with open(TOKEN_DICTIONARY_FILE, "rb") as f:
25
- token_dictionary = pickle.load(f)
26
-
27
  EncodedInput = List[int]
28
  logger = logging.get_logger(__name__)
29
  VERY_LARGE_INTEGER = int(
@@ -85,16 +79,18 @@ class TensorType(ExplicitEnum):
85
 
86
 
87
  class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
88
- mask_token = "<mask>"
89
- mask_token_id = token_dictionary.get("<mask>")
90
- pad_token = "<pad>"
91
- pad_token_id = token_dictionary.get("<pad>")
92
- padding_side = "right"
93
- all_special_ids = [
94
- token_dictionary.get("<mask>"),
95
- token_dictionary.get("<pad>")
96
- ]
97
- model_input_names = ["input_ids"]
 
 
98
 
99
  def _get_padding_truncation_strategies(
100
  self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
@@ -550,8 +546,7 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
550
  label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
551
  The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
552
  """
553
-
554
- tokenizer = PrecollatorForGeneAndCellClassification()
555
  class_type = "gene"
556
  padding: Union[bool, str, PaddingStrategy] = True
557
  max_length: Optional[int] = None
@@ -559,8 +554,9 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
559
  label_pad_token_id: int = -100
560
 
561
  def __init__(self, *args, **kwargs) -> None:
 
562
  super().__init__(
563
- tokenizer=self.tokenizer,
564
  padding=self.padding,
565
  max_length=self.max_length,
566
  pad_to_multiple_of=self.pad_to_multiple_of,
 
18
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
19
  from transformers.utils.generic import _is_tensorflow, _is_torch
20
 
 
 
 
 
 
 
21
  EncodedInput = List[int]
22
  logger = logging.get_logger(__name__)
23
  VERY_LARGE_INTEGER = int(
 
79
 
80
 
81
  class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
82
+ def __init__(self, *args, **kwargs) -> None:
83
+ super().__init__(mask_token="<mask>", pad_token="<pad>")
84
+
85
+ self.token_dictionary = kwargs.get("token_dictionary")
86
+ self.padding_side = "right"
87
+ self.model_input_names = ["input_ids"]
88
+ self.mask_token_id = self.token_dictionary.get("<mask>")
89
+ self.pad_token_id = self.token_dictionary.get("<pad>")
90
+ self.all_special_ids = [
91
+ self.token_dictionary.get("<mask>"),
92
+ self.token_dictionary.get("<pad>")
93
+ ]
94
 
95
  def _get_padding_truncation_strategies(
96
  self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
 
546
  label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
547
  The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
548
  """
549
+
 
550
  class_type = "gene"
551
  padding: Union[bool, str, PaddingStrategy] = True
552
  max_length: Optional[int] = None
 
554
  label_pad_token_id: int = -100
555
 
556
  def __init__(self, *args, **kwargs) -> None:
557
+ self.token_dictionary = kwargs.pop("token_dictionary")
558
  super().__init__(
559
+ tokenizer=PrecollatorForGeneAndCellClassification(token_dictionary=self.token_dictionary),
560
  padding=self.padding,
561
  max_length=self.max_length,
562
  pad_to_multiple_of=self.pad_to_multiple_of,
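A short sketch of the collator change above: the token dictionary is no longer loaded at import time, so callers now pass it explicitly (the Classifier does this internally via its token_dictionary attribute). Paths are illustrative.

```python
import pickle
from geneformer import TOKEN_DICTIONARY_FILE
from geneformer.collator_for_classification import (
    DataCollatorForCellClassification,
    DataCollatorForGeneClassification,
)

# Load the Ensembl ID -> token mapping and hand it to the collators explicitly.
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
    token_dictionary = pickle.load(f)

cell_collator = DataCollatorForCellClassification(token_dictionary=token_dictionary)
gene_collator = DataCollatorForGeneClassification(token_dictionary=token_dictionary)
```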
geneformer/emb_extractor.py CHANGED
@@ -286,12 +286,20 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
286
  sc.tl.umap(adata, random_state=seed)
287
  sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
288
  sns.set_style("white")
289
- default_kwargs_dict = {"palette": "Set2", "size": 200}
290
  if kwargs_dict is not None:
291
  default_kwargs_dict.update(kwargs_dict)
292
 
293
- with plt.rc_context():
294
- sc.pl.umap(adata, color=label, **default_kwargs_dict)
 
 
 
 
 
 
 
 
295
  plt.savefig(output_file, bbox_inches="tight")
296
 
297
 
@@ -470,7 +478,6 @@ class EmbExtractor:
470
  ... emb_mode="cell",
471
  ... filter_data={"cell_type":["cardiomyocyte"]},
472
  ... max_ncells=1000,
473
- ... max_ncells_to_plot=1000,
474
  ... emb_layer=-1,
475
  ... emb_label=["disease", "cell_type"],
476
  ... labels_to_plot=["disease", "cell_type"])
@@ -783,15 +790,15 @@ class EmbExtractor:
783
  logger.error("Plotting UMAP requires 'labels_to_plot'. ")
784
  raise
785
 
786
- if max_ncells_to_plot > self.max_ncells:
787
- max_ncells_to_plot = self.max_ncells
788
- logger.warning(
789
- "max_ncells_to_plot must be <= max_ncells. "
790
- f"Changing max_ncells_to_plot to {self.max_ncells}."
791
- )
792
-
793
- if (max_ncells_to_plot is not None) and (max_ncells_to_plot < self.max_ncells):
794
- embs = embs.sample(max_ncells_to_plot, axis=0)
795
 
796
  if self.emb_label is None:
797
  label_len = 0
 
286
  sc.tl.umap(adata, random_state=seed)
287
  sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
288
  sns.set_style("white")
289
+ default_kwargs_dict = {"size": 200}
290
  if kwargs_dict is not None:
291
  default_kwargs_dict.update(kwargs_dict)
292
 
293
+ cats = set(embs_df[label])
294
+
295
+ with plt.rc_context():
296
+ ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
297
+ ax.legend(markerscale=2,
298
+ frameon=False,
299
+ loc="center left",
300
+ bbox_to_anchor=(1, 0.5),
301
+ ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3))
302
+ plt.show()
303
  plt.savefig(output_file, bbox_inches="tight")
304
 
305
 
 
478
  ... emb_mode="cell",
479
  ... filter_data={"cell_type":["cardiomyocyte"]},
480
  ... max_ncells=1000,
 
481
  ... emb_layer=-1,
482
  ... emb_label=["disease", "cell_type"],
483
  ... labels_to_plot=["disease", "cell_type"])
 
790
  logger.error("Plotting UMAP requires 'labels_to_plot'. ")
791
  raise
792
 
793
+ if max_ncells_to_plot is not None:
794
+ if max_ncells_to_plot > self.max_ncells:
795
+ max_ncells_to_plot = self.max_ncells
796
+ logger.warning(
797
+ "max_ncells_to_plot must be <= max_ncells. "
798
+ f"Changing max_ncells_to_plot to {self.max_ncells}."
799
+ )
800
+ elif max_ncells_to_plot < self.max_ncells:
801
+ embs = embs.sample(max_ncells_to_plot, axis=0)
802
 
803
  if self.emb_label is None:
804
  label_len = 0
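Two behaviors added above can be restated compactly: the UMAP legend column count now scales with the number of label categories, and max_ncells_to_plot is only enforced when it is not None. An illustrative restatement of the legend rule:

```python
def legend_ncol(n_categories: int) -> int:
    # Mirrors the ncol rule in plot_umap: 1 column up to 14 categories,
    # 2 columns up to 30, otherwise 3.
    return 1 if n_categories <= 14 else 2 if n_categories <= 30 else 3

print(legend_ncol(10), legend_ncol(20), legend_ncol(40))  # 1 2 3
```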
geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3b589bb5ec75040d05fc44dd6bf0184cf87f3c362cf158d196a6ed3b7fe5f39
3
+ size 940965
geneformer/{gene_name_id_dict.pkl → gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl} RENAMED
File without changes
geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab9dc40973fa5224d77b793e2fd114cacf3d08423ed9c4c49caf0ba9c7f218f1
3
+ size 788424
geneformer/gene_median_dictionary.pkl DELETED
Binary file (941 kB)
 
geneformer/in_silico_perturber.py CHANGED
@@ -63,7 +63,7 @@ class InSilicoPerturber:
63
  "anchor_gene": {None, str},
64
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
65
  "num_classes": {int},
66
- "emb_mode": {"cell", "cell_and_gene"},
67
  "cell_emb_style": {"mean_pool"},
68
  "filter_data": {None, dict},
69
  "cell_states_to_model": {None, dict},
@@ -71,6 +71,7 @@ class InSilicoPerturber:
71
  "max_ncells": {None, int},
72
  "cell_inds_to_perturb": {"all", dict},
73
  "emb_layer": {-1, 0},
 
74
  "forward_batch_size": {int},
75
  "nproc": {int},
76
  }
@@ -94,7 +95,8 @@ class InSilicoPerturber:
94
  emb_layer=-1,
95
  forward_batch_size=100,
96
  nproc=4,
97
- token_dictionary_file=TOKEN_DICTIONARY_FILE,
 
98
  ):
99
  """
100
  Initialize in silico perturber.
@@ -129,16 +131,16 @@ class InSilicoPerturber:
129
  | ENSEMBL ID of gene to use as anchor in combination perturbations.
130
  | For example, if combos=1 and anchor_gene="ENSG00000148400":
131
  | anchor gene will be perturbed in combination with each other gene.
132
- model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
133
- | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
134
  num_classes : int
135
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
136
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
137
- emb_mode : {"cell", "cell_and_gene"}
138
- | Whether to output impact of perturbation on cell and/or gene embeddings.
139
  | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
140
  cell_emb_style : "mean_pool"
141
- | Method for summarizing cell embeddings.
142
  | Currently only option is mean pooling of gene embeddings for given cell.
143
  filter_data : None, dict
144
  | Default is to use all input data for in silico perturbation study.
@@ -183,6 +185,8 @@ class InSilicoPerturber:
183
  | Number of CPU processes to use.
184
  token_dictionary_file : Path
185
  | Path to pickle file containing token dictionary (Ensembl ID:token).
 
 
186
  """
187
  try:
188
  set_start_method("spawn")
@@ -219,15 +223,31 @@ class InSilicoPerturber:
219
  self.emb_layer = emb_layer
220
  self.forward_batch_size = forward_batch_size
221
  self.nproc = nproc
 
 
222
 
223
  self.validate_options()
224
 
225
  # load token dictionary (Ensembl IDs:token)
 
 
226
  with open(token_dictionary_file, "rb") as f:
227
  self.gene_token_dict = pickle.load(f)
228
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
229
 
230
  self.pad_token_id = self.gene_token_dict.get("<pad>")
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  if self.anchor_gene is None:
233
  self.anchor_token = None
@@ -285,7 +305,7 @@ class InSilicoPerturber:
285
  continue
286
  valid_type = False
287
  for option in valid_options:
288
- if (option in [bool, int, list, dict]) and isinstance(
289
  attr_value, option
290
  ):
291
  valid_type = True
@@ -426,22 +446,46 @@ class InSilicoPerturber:
426
  self.max_len = pu.get_model_input_size(model)
427
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
428
 
429
-
430
  ### filter input data ###
431
  # general filtering of input data based on filter_data argument
432
  filtered_input_data = pu.load_and_filter(
433
  self.filter_data, self.nproc, input_data_file
434
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
436
 
437
  if self.perturb_group is True:
438
- self.isp_perturb_set(
439
- model, filtered_input_data, layer_to_quant, output_path_prefix
440
- )
 
 
 
 
 
441
  else:
442
- self.isp_perturb_all(
443
- model, filtered_input_data, layer_to_quant, output_path_prefix
444
- )
 
 
 
 
 
445
 
446
  def apply_additional_filters(self, filtered_input_data):
447
  # additional filtering of input data dependent on isp mode
@@ -486,6 +530,7 @@ class InSilicoPerturber:
486
  layer_to_quant: int,
487
  output_path_prefix: str,
488
  ):
 
489
  def make_group_perturbation_batch(example):
490
  example_input_ids = example["input_ids"]
491
  example["tokens_to_perturb"] = self.tokens_to_perturb
@@ -504,7 +549,7 @@ class InSilicoPerturber:
504
  if self.perturb_type == "delete":
505
  example = pu.delete_indices(example)
506
  elif self.perturb_type == "overexpress":
507
- example = pu.overexpress_tokens(example, self.max_len)
508
  example["n_overflow"] = pu.calc_n_overflow(
509
  self.max_len,
510
  example["length"],
@@ -678,8 +723,6 @@ class InSilicoPerturber:
678
  cos_sims_dict = self.update_perturbation_dictionary(
679
  cos_sims_dict,
680
  cos_sims_data,
681
- filtered_input_data,
682
- indices_to_perturb,
683
  gene_list,
684
  )
685
  else:
@@ -688,8 +731,6 @@ class InSilicoPerturber:
688
  cos_sims_dict[state] = self.update_perturbation_dictionary(
689
  cos_sims_dict[state],
690
  cos_sims_data[state],
691
- filtered_input_data,
692
- indices_to_perturb,
693
  gene_list,
694
  )
695
  del minibatch
@@ -711,6 +752,256 @@ class InSilicoPerturber:
711
  f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
712
  )
713
714
  def isp_perturb_all(
715
  self,
716
  model,
@@ -729,8 +1020,10 @@ class InSilicoPerturber:
729
 
730
  if self.emb_mode == "cell_and_gene":
731
  stored_gene_embs_dict = defaultdict(list)
732
- for i in trange(len(filtered_input_data)):
733
- example_cell = filtered_input_data.select([i])
 
 
734
  full_original_emb = get_embs(
735
  model,
736
  example_cell,
@@ -738,18 +1031,33 @@ class InSilicoPerturber:
738
  layer_to_quant,
739
  self.pad_token_id,
740
  self.forward_batch_size,
741
- token_gene_dict=self.token_gene_dict,
742
  summary_stat=None,
743
  silent=True,
744
  )
745
-
 
 
 
 
 
746
  # gene_list is used to assign cos sims back to genes
747
- # need to remove the anchor gene
748
  gene_list = example_cell["input_ids"][0][:]
 
749
  if self.anchor_token is not None:
750
  for token in self.anchor_token:
751
  gene_list.remove(token)
752
-
 
 
 
 
 
 
 
 
 
 
753
  perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
754
  example_cell,
755
  self.perturb_type,
@@ -759,148 +1067,430 @@ class InSilicoPerturber:
759
  self.nproc,
760
  )
761
 
762
- full_perturbation_emb = get_embs(
763
- model,
764
- perturbation_batch,
765
- "gene",
766
- layer_to_quant,
767
- self.pad_token_id,
768
- self.forward_batch_size,
769
- token_gene_dict=self.token_gene_dict,
770
- summary_stat=None,
771
- silent=True,
772
- )
773
-
774
- num_inds_perturbed = 1 + self.combos
775
- # need to remove overexpressed gene to quantify cosine shifts
776
- if self.perturb_type == "overexpress":
777
- perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
778
- gene_list = gene_list[
779
- num_inds_perturbed:
780
- ] # index 0 is not overexpressed
781
-
782
- elif self.perturb_type == "delete":
783
- perturbation_emb = full_perturbation_emb
784
 
785
- original_batch = pu.make_comparison_batch(
786
- full_original_emb, indices_to_perturb, perturb_group=False
787
- )
788
-
789
- if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
790
- gene_cos_sims = pu.quant_cos_sims(
791
- perturbation_emb,
792
- original_batch,
793
- self.cell_states_to_model,
794
- self.state_embs_dict,
795
- emb_mode="gene",
796
- )
797
- if self.cell_states_to_model is not None:
798
- original_cell_emb = pu.compute_nonpadded_cell_embedding(
799
- full_original_emb, "mean_pool"
800
- )
801
- perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
802
- full_perturbation_emb, "mean_pool"
803
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804
 
805
- cell_cos_sims = pu.quant_cos_sims(
806
- perturbation_cell_emb,
807
- original_cell_emb,
808
- self.cell_states_to_model,
809
- self.state_embs_dict,
810
- emb_mode="cell",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
811
  )
812
 
813
- if self.emb_mode == "cell_and_gene":
814
- # remove perturbed index for gene list
815
- perturbed_gene_dict = {
816
- gene: gene_list[:i] + gene_list[i + 1 :]
817
- for i, gene in enumerate(gene_list)
 
 
818
  }
819
 
820
- for perturbation_i, perturbed_gene in enumerate(gene_list):
821
- for gene_j, affected_gene in enumerate(
822
- perturbed_gene_dict[perturbed_gene]
823
- ):
824
- try:
825
- stored_gene_embs_dict[
826
- (perturbed_gene, affected_gene)
827
- ].append(gene_cos_sims[perturbation_i, gene_j].item())
828
- except KeyError:
829
- stored_gene_embs_dict[
830
- (perturbed_gene, affected_gene)
831
- ] = gene_cos_sims[perturbation_i, gene_j].item()
832
 
833
- if self.cell_states_to_model is None:
834
- cos_sims_data = torch.mean(gene_cos_sims, dim=1)
835
- cos_sims_dict = self.update_perturbation_dictionary(
836
- cos_sims_dict,
837
- cos_sims_data,
838
- filtered_input_data,
839
- indices_to_perturb,
840
- gene_list,
841
- )
842
- else:
843
- cos_sims_data = cell_cos_sims
844
- for state in cos_sims_dict.keys():
845
- cos_sims_dict[state] = self.update_perturbation_dictionary(
846
- cos_sims_dict[state],
847
- cos_sims_data[state],
848
- filtered_input_data,
849
- indices_to_perturb,
850
- gene_list,
851
- )
 
 
 
852
 
853
- # save dict to disk every 100 cells
854
- if i % 100 == 0:
855
- pu.write_perturbation_dictionary(
856
- cos_sims_dict,
857
- f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
858
- )
859
- if self.emb_mode == "cell_and_gene":
860
- pu.write_perturbation_dictionary(
861
- stored_gene_embs_dict,
862
- f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
863
- )
864
 
865
- # reset and clear memory every 1000 cells
866
- if i % 1000 == 0:
867
- pickle_batch += 1
868
- if self.cell_states_to_model is None:
869
- cos_sims_dict = defaultdict(list)
870
- else:
871
- cos_sims_dict = {
872
- state: defaultdict(list)
873
- for state in pu.get_possible_states(self.cell_states_to_model)
874
- }
875
 
876
- if self.emb_mode == "cell_and_gene":
877
- stored_gene_embs_dict = defaultdict(list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
878
 
879
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
880
 
881
- pu.write_perturbation_dictionary(
882
- cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
883
- )
 
 
 
 
 
884
 
885
- if self.emb_mode == "cell_and_gene":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886
  pu.write_perturbation_dictionary(
887
- stored_gene_embs_dict,
888
- f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
889
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
890
 
 
891
  def update_perturbation_dictionary(
892
  self,
893
  cos_sims_dict: defaultdict,
894
  cos_sims_data: torch.Tensor,
895
- filtered_input_data: Dataset,
896
- indices_to_perturb: List[List[int]],
897
  gene_list=None,
898
  ):
899
  if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
900
  logger.error(
901
  f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
902
- cos_sims_data.shape[0] = {cos_sims_data.shape[0]}.\n \
903
- len(gene_list) = {len(gene_list)}."
904
  )
905
  raise
906
 
@@ -924,4 +1514,4 @@ class InSilicoPerturber:
924
  for i, cos in enumerate(cos_sims_data.tolist()):
925
  cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
926
 
927
- return cos_sims_dict
 
63
  "anchor_gene": {None, str},
64
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
65
  "num_classes": {int},
66
+ "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
67
  "cell_emb_style": {"mean_pool"},
68
  "filter_data": {None, dict},
69
  "cell_states_to_model": {None, dict},
 
71
  "max_ncells": {None, int},
72
  "cell_inds_to_perturb": {"all", dict},
73
  "emb_layer": {-1, 0},
74
+ "token_dictionary_file" : {None, str},
75
  "forward_batch_size": {int},
76
  "nproc": {int},
77
  }
 
95
  emb_layer=-1,
96
  forward_batch_size=100,
97
  nproc=4,
98
+ token_dictionary_file=None,
99
+ clear_mem_ncells=1000,
100
  ):
101
  """
102
  Initialize in silico perturber.
 
131
  | ENSEMBL ID of gene to use as anchor in combination perturbations.
132
  | For example, if combos=1 and anchor_gene="ENSG00000148400":
133
  | anchor gene will be perturbed in combination with each other gene.
134
+ model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"}
135
+ | Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
136
  num_classes : int
137
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
138
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
139
+ emb_mode : {"cls", "cell", "cls_and_gene","cell_and_gene"}
140
+ | Whether to output impact of perturbation on CLS token, cell, and/or gene embeddings.
141
  | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
142
  cell_emb_style : "mean_pool"
143
+ | Method for summarizing cell embeddings if not using CLS token.
144
  | Currently only option is mean pooling of gene embeddings for given cell.
145
  filter_data : None, dict
146
  | Default is to use all input data for in silico perturbation study.
 
185
  | Number of CPU processes to use.
186
  token_dictionary_file : Path
187
  | Path to pickle file containing token dictionary (Ensembl ID:token).
188
+ clear_mem_ncells : int
189
+ | Clear memory every n cells.
190
  """
191
  try:
192
  set_start_method("spawn")
 
223
  self.emb_layer = emb_layer
224
  self.forward_batch_size = forward_batch_size
225
  self.nproc = nproc
226
+ self.token_dictionary_file = token_dictionary_file
227
+ self.clear_mem_ncells = clear_mem_ncells
228
 
229
  self.validate_options()
230
 
231
  # load token dictionary (Ensembl IDs:token)
232
+ if self.token_dictionary_file is None:
233
+ token_dictionary_file = TOKEN_DICTIONARY_FILE
234
  with open(token_dictionary_file, "rb") as f:
235
  self.gene_token_dict = pickle.load(f)
236
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
237
 
238
  self.pad_token_id = self.gene_token_dict.get("<pad>")
239
+ self.cls_token_id = self.gene_token_dict.get("<cls>")
240
+ self.eos_token_id = self.gene_token_dict.get("<eos>")
241
+
242
+
243
+ # Identify if special token is present in the token dictionary
244
+ if (self.cls_token_id is not None) and (self.eos_token_id is not None):
245
+ self.special_token = True
246
+ else:
247
+ if "cls" in self.emb_mode:
248
+ logger.error(f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary.")
249
+ raise
250
+ self.special_token = False
251
 
252
  if self.anchor_gene is None:
253
  self.anchor_token = None
 
305
  continue
306
  valid_type = False
307
  for option in valid_options:
308
+ if (option in [bool, int, list, dict, str]) and isinstance(
309
  attr_value, option
310
  ):
311
  valid_type = True
 
446
  self.max_len = pu.get_model_input_size(model)
447
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
448
 
 
449
  ### filter input data ###
450
  # general filtering of input data based on filter_data argument
451
  filtered_input_data = pu.load_and_filter(
452
  self.filter_data, self.nproc, input_data_file
453
  )
454
+
455
+ # Ensure emb_mode includes cls if the first token of the filtered input data is the <cls> token
456
+ if self.special_token:
457
+ if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and ("cls" not in self.emb_mode):
458
+ logger.error(
459
+ "Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
460
+ )
461
+ raise
462
+ if ("cls" in self.emb_mode):
463
+ if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (filtered_input_data["input_ids"][0][-1] != self.eos_token_id):
464
+ logger.error(
465
+ "Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
466
+ )
467
+ raise
468
+
469
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
470
 
471
  if self.perturb_group is True:
472
+ if (self.special_token) and ("cls" in self.emb_mode):
473
+ self.isp_perturb_set_special(
474
+ model, filtered_input_data, layer_to_quant, output_path_prefix
475
+ )
476
+ else:
477
+ self.isp_perturb_set(
478
+ model, filtered_input_data, layer_to_quant, output_path_prefix
479
+ )
480
  else:
481
+ if (self.special_token) and ("cls" in self.emb_mode):
482
+ self.isp_perturb_all_special(
483
+ model, filtered_input_data, layer_to_quant, output_path_prefix
484
+ )
485
+ else:
486
+ self.isp_perturb_all(
487
+ model, filtered_input_data, layer_to_quant, output_path_prefix
488
+ )
489
 
490
  def apply_additional_filters(self, filtered_input_data):
491
  # additional filtering of input data dependent on isp mode
 
530
  layer_to_quant: int,
531
  output_path_prefix: str,
532
  ):
533
+
534
  def make_group_perturbation_batch(example):
535
  example_input_ids = example["input_ids"]
536
  example["tokens_to_perturb"] = self.tokens_to_perturb
 
549
  if self.perturb_type == "delete":
550
  example = pu.delete_indices(example)
551
  elif self.perturb_type == "overexpress":
552
+ example = pu.overexpress_tokens(example, self.max_len, self.special_token)
553
  example["n_overflow"] = pu.calc_n_overflow(
554
  self.max_len,
555
  example["length"],
 
723
  cos_sims_dict = self.update_perturbation_dictionary(
724
  cos_sims_dict,
725
  cos_sims_data,
 
 
726
  gene_list,
727
  )
728
  else:
 
731
  cos_sims_dict[state] = self.update_perturbation_dictionary(
732
  cos_sims_dict[state],
733
  cos_sims_data[state],
 
 
734
  gene_list,
735
  )
736
  del minibatch
 
752
  f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
753
  )
754
 
755
+
756
+ def isp_perturb_set_special(
757
+ self,
758
+ model,
759
+ filtered_input_data: Dataset,
760
+ layer_to_quant: int,
761
+ output_path_prefix: str,
762
+ ):
763
+
764
+ def make_group_perturbation_batch(example):
765
+ example_input_ids = example["input_ids"]
766
+ example["tokens_to_perturb"] = self.tokens_to_perturb
767
+ indices_to_perturb = [
768
+ example_input_ids.index(token) if token in example_input_ids else None
769
+ for token in self.tokens_to_perturb
770
+ ]
771
+ indices_to_perturb = [
772
+ item for item in indices_to_perturb if item is not None
773
+ ]
774
+ if len(indices_to_perturb) > 0:
775
+ example["perturb_index"] = indices_to_perturb
776
+ else:
777
+ # -100 indicates tokens to overexpress are not present in rank value encoding
778
+ example["perturb_index"] = [-100]
779
+ if self.perturb_type == "delete":
780
+ example = pu.delete_indices(example)
781
+ elif self.perturb_type == "overexpress":
782
+ example = pu.overexpress_tokens(example, self.max_len, self.special_token)
783
+ example["n_overflow"] = pu.calc_n_overflow(
784
+ self.max_len,
785
+ example["length"],
786
+ self.tokens_to_perturb,
787
+ indices_to_perturb,
788
+ )
789
+ return example
790
+
791
+ total_batch_length = len(filtered_input_data)
792
+ if self.cell_states_to_model is None:
793
+ cos_sims_dict = defaultdict(list)
794
+ else:
795
+ cos_sims_dict = {
796
+ state: defaultdict(list)
797
+ for state in pu.get_possible_states(self.cell_states_to_model)
798
+ }
799
+
800
+ perturbed_data = filtered_input_data.map(
801
+ make_group_perturbation_batch, num_proc=self.nproc
802
+ )
803
+
804
+ if self.perturb_type == "overexpress":
805
+ filtered_input_data = filtered_input_data.add_column(
806
+ "n_overflow", perturbed_data["n_overflow"]
807
+ )
808
+ filtered_input_data = filtered_input_data.map(
809
+ pu.truncate_by_n_overflow_special, num_proc=self.nproc
810
+ )
811
+
812
+ if self.emb_mode == "cls_and_gene":
813
+ stored_gene_embs_dict = defaultdict(list)
814
+
815
+ # iterate through batches
816
+ for i in trange(0, total_batch_length, self.forward_batch_size):
817
+ max_range = min(i + self.forward_batch_size, total_batch_length)
818
+ inds_select = [i for i in range(i, max_range)]
819
+
820
+ minibatch = filtered_input_data.select(inds_select)
821
+ perturbation_batch = perturbed_data.select(inds_select)
822
+
823
+ ##### CLS Embedding Mode #####
824
+ if self.emb_mode == "cls":
825
+ indices_to_perturb = perturbation_batch["perturb_index"]
826
+
827
+ original_cls_emb = get_embs(
828
+ model,
829
+ minibatch,
830
+ "cls",
831
+ layer_to_quant,
832
+ self.pad_token_id,
833
+ self.forward_batch_size,
834
+ token_gene_dict=self.token_gene_dict,
835
+ summary_stat=None,
836
+ silent=True,
837
+ )
838
+
839
+ perturbation_cls_emb = get_embs(
840
+ model,
841
+ perturbation_batch,
842
+ "cls",
843
+ layer_to_quant,
844
+ self.pad_token_id,
845
+ self.forward_batch_size,
846
+ token_gene_dict=self.token_gene_dict,
847
+ summary_stat=None,
848
+ silent=True,
849
+ )
850
+
851
+ # Calculate the cosine similarities
852
+ cls_cos_sims = pu.quant_cos_sims(
853
+ perturbation_cls_emb,
854
+ original_cls_emb,
855
+ self.cell_states_to_model,
856
+ self.state_embs_dict,
857
+ emb_mode="cell")
858
+
859
+ # Update perturbation dictionary
860
+ if self.cell_states_to_model is None:
861
+ cos_sims_dict = self.update_perturbation_dictionary(
862
+ cos_sims_dict,
863
+ cls_cos_sims,
864
+ gene_list = None,
865
+ )
866
+ else:
867
+ for state in cos_sims_dict.keys():
868
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
869
+ cos_sims_dict[state],
870
+ cls_cos_sims[state],
871
+ gene_list = None,
872
+ )
873
+
874
+ ##### CLS and Gene Embedding Mode #####
875
+ elif self.emb_mode == "cls_and_gene":
876
+ full_original_emb = get_embs(
877
+ model,
878
+ minibatch,
879
+ "gene",
880
+ layer_to_quant,
881
+ self.pad_token_id,
882
+ self.forward_batch_size,
883
+ self.token_gene_dict,
884
+ summary_stat=None,
885
+ silent=True,
886
+ )
887
+ indices_to_perturb = perturbation_batch["perturb_index"]
888
+ # remove indices that were perturbed
889
+ original_emb = pu.remove_perturbed_indices_set(
890
+ full_original_emb,
891
+ self.perturb_type,
892
+ indices_to_perturb,
893
+ self.tokens_to_perturb,
894
+ minibatch["length"],
895
+ )
896
+ full_perturbation_emb = get_embs(
897
+ model,
898
+ perturbation_batch,
899
+ "gene",
900
+ layer_to_quant,
901
+ self.pad_token_id,
902
+ self.forward_batch_size,
903
+ self.token_gene_dict,
904
+ summary_stat=None,
905
+ silent=True,
906
+ )
907
+
908
+ # remove special tokens and padding
909
+ original_emb = original_emb[:, 1:-1, :]
910
+ if self.perturb_type == "overexpress":
911
+ perturbation_emb = full_perturbation_emb[:,1+len(self.tokens_to_perturb):-1,:]
912
+ elif self.perturb_type == "delete":
913
+ perturbation_emb = full_perturbation_emb[:,1:max(perturbation_batch["length"])-1,:]
914
+
915
+ n_perturbation_genes = perturbation_emb.size()[1]
916
+
917
+ gene_cos_sims = pu.quant_cos_sims(
918
+ perturbation_emb,
919
+ original_emb,
920
+ self.cell_states_to_model,
921
+ self.state_embs_dict,
922
+ emb_mode="gene",
923
+ )
924
+
925
+ # get cls emb
926
+ original_cls_emb = full_original_emb[:,0,:]
927
+ perturbation_cls_emb = full_perturbation_emb[:,0,:]
928
+
929
+ cls_cos_sims = pu.quant_cos_sims(
930
+ perturbation_cls_emb,
931
+ original_cls_emb,
932
+ self.cell_states_to_model,
933
+ self.state_embs_dict,
934
+ emb_mode="cell",
935
+ )
936
+
937
+ # get cosine similarities in gene embeddings
938
+ # since getting gene embeddings, need gene names
939
+
940
+ gene_list = minibatch["input_ids"]
941
+ # need to truncate gene_list
942
+ genes_to_exclude = self.tokens_to_perturb + [self.cls_token_id, self.eos_token_id]
943
+ gene_list = [
944
+ [g for g in genes if g not in genes_to_exclude][
945
+ :n_perturbation_genes
946
+ ]
947
+ for genes in gene_list
948
+ ]
949
+
950
+ for cell_i, genes in enumerate(gene_list):
951
+ for gene_j, affected_gene in enumerate(genes):
952
+ if len(self.genes_to_perturb) > 1:
953
+ tokens_to_perturb = tuple(self.tokens_to_perturb)
954
+ else:
955
+ tokens_to_perturb = self.tokens_to_perturb[0]
956
+
957
+ # fill in the gene cosine similarities
958
+ try:
959
+ stored_gene_embs_dict[
960
+ (tokens_to_perturb, affected_gene)
961
+ ].append(gene_cos_sims[cell_i, gene_j].item())
962
+ except KeyError:
963
+ stored_gene_embs_dict[
964
+ (tokens_to_perturb, affected_gene)
965
+ ] = gene_cos_sims[cell_i, gene_j].item()
966
+
967
+ if self.cell_states_to_model is None:
968
+ cos_sims_dict = self.update_perturbation_dictionary(
969
+ cos_sims_dict,
970
+ cls_cos_sims,
971
+ gene_list = None,
972
+ )
973
+ else:
974
+ for state in cos_sims_dict.keys():
975
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
976
+ cos_sims_dict[state],
977
+ cls_cos_sims[state],
978
+ gene_list = None,
979
+ )
980
+ del full_original_emb
981
+ del original_emb
982
+ del full_perturbation_emb
983
+ del perturbation_emb
984
+ del gene_cos_sims
985
+
986
+ del original_cls_emb
987
+ del perturbation_cls_emb
988
+ del cls_cos_sims
989
+ del minibatch
990
+ del perturbation_batch
991
+
992
+ torch.cuda.empty_cache()
993
+
994
+ pu.write_perturbation_dictionary(
995
+ cos_sims_dict,
996
+ f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
997
+ )
998
+
999
+ if self.emb_mode == "cls_and_gene":
1000
+ pu.write_perturbation_dictionary(
1001
+ stored_gene_embs_dict,
1002
+ f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
1003
+ )
1004
+
1005
  def isp_perturb_all(
1006
  self,
1007
  model,
 
1020
 
1021
  if self.emb_mode == "cell_and_gene":
1022
  stored_gene_embs_dict = defaultdict(list)
1023
+
1024
+ num_inds_perturbed = 1 + self.combos
1025
+ for h in trange(len(filtered_input_data)):
1026
+ example_cell = filtered_input_data.select([h])
1027
  full_original_emb = get_embs(
1028
  model,
1029
  example_cell,
 
1031
  layer_to_quant,
1032
  self.pad_token_id,
1033
  self.forward_batch_size,
1034
+ self.token_gene_dict,
1035
  summary_stat=None,
1036
  silent=True,
1037
  )
1038
+
1039
+ if self.cell_states_to_model is not None:
1040
+ original_cell_emb = pu.compute_nonpadded_cell_embedding(
1041
+ full_original_emb, "mean_pool"
1042
+ )
1043
+
1044
  # gene_list is used to assign cos sims back to genes
 
1045
  gene_list = example_cell["input_ids"][0][:]
1046
+ # need to remove the anchor gene
1047
  if self.anchor_token is not None:
1048
  for token in self.anchor_token:
1049
  gene_list.remove(token)
1050
+ # index 0 is not overexpressed so remove
1051
+ if self.perturb_type == "overexpress":
1052
+ gene_list = gene_list[
1053
+ num_inds_perturbed:
1054
+ ]
1055
+ # remove perturbed index for gene list dict
1056
+ perturbed_gene_dict = {
1057
+ gene: gene_list[:i] + gene_list[i + 1 :]
1058
+ for i, gene in enumerate(gene_list)
1059
+ }
1060
+
1061
  perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
1062
  example_cell,
1063
  self.perturb_type,
 
1067
  self.nproc,
1068
  )
1069
 
1070
+ ispall_total_batch_length = len(perturbation_batch)
1071
+ for i in trange(0, ispall_total_batch_length, self.forward_batch_size, leave=False):
1072
+ ispall_max_range = min(i + self.forward_batch_size, ispall_total_batch_length)
1073
+ perturbation_minibatch = perturbation_batch.select([i for i in range(i, ispall_max_range)])
1074
+ indices_to_perturb_mini = indices_to_perturb[i : ispall_max_range]
1075
+ gene_list_mini = gene_list[i : ispall_max_range] # only perturbed genes from this minibatch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1076
 
1077
+ full_perturbation_emb = get_embs(
1078
+ model,
1079
+ perturbation_minibatch,
1080
+ "gene",
1081
+ layer_to_quant,
1082
+ self.pad_token_id,
1083
+ self.forward_batch_size,
1084
+ self.token_gene_dict,
1085
+ summary_stat=None,
1086
+ silent=True,
 
 
1087
  )
1088
+
1089
+ del perturbation_minibatch
1090
+
1091
+ # need to remove overexpressed gene to quantify cosine shifts
1092
+ if self.perturb_type == "overexpress":
1093
+ perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
1094
+
1095
+ elif self.perturb_type == "delete":
1096
+ perturbation_emb = full_perturbation_emb
1097
+
1098
+
1099
+ if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
1100
+ original_emb_minibatch = pu.make_comparison_batch(
1101
+ full_original_emb, indices_to_perturb_mini, perturb_group=False
1102
+ )
1103
+ gene_cos_sims = pu.quant_cos_sims(
1104
+ perturbation_emb,
1105
+ original_emb_minibatch,
1106
+ self.cell_states_to_model,
1107
+ self.state_embs_dict,
1108
+ emb_mode="gene",
1109
+ )
1110
+ del original_emb_minibatch
1111
+
1112
+ if self.cell_states_to_model is not None:
1113
+ perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
1114
+ full_perturbation_emb, "mean_pool"
1115
+ )
1116
+
1117
+ cell_cos_sims = pu.quant_cos_sims(
1118
+ perturbation_cell_emb,
1119
+ original_cell_emb,
1120
+ self.cell_states_to_model,
1121
+ self.state_embs_dict,
1122
+ emb_mode="cell",
1123
+ )
1124
+ del perturbation_cell_emb
1125
+
1126
+ if self.emb_mode == "cell_and_gene":
1127
 
1128
+ for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1129
+ for gene_j, affected_gene in enumerate(
1130
+ perturbed_gene_dict[perturbed_gene]
1131
+ ):
1132
+ try:
1133
+ stored_gene_embs_dict[
1134
+ (perturbed_gene, affected_gene)
1135
+ ].append(gene_cos_sims[perturbation_i, gene_j].item())
1136
+ except KeyError:
1137
+ stored_gene_embs_dict[
1138
+ (perturbed_gene, affected_gene)
1139
+ ] = gene_cos_sims[perturbation_i, gene_j].item()
1140
+
1141
+ del full_perturbation_emb
1142
+
1143
+ if self.cell_states_to_model is None:
1144
+ cos_sims_data = torch.mean(gene_cos_sims, dim=1)
1145
+ cos_sims_dict = self.update_perturbation_dictionary(
1146
+ cos_sims_dict,
1147
+ cos_sims_data,
1148
+ gene_list_mini,
1149
+ )
1150
+ else:
1151
+ cos_sims_data = cell_cos_sims
1152
+ for state in cos_sims_dict.keys():
1153
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1154
+ cos_sims_dict[state],
1155
+ cos_sims_data[state],
1156
+ gene_list_mini,
1157
+ )
1158
+
1159
+ # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1160
+ if i % max(1, self.clear_mem_ncells/10) == 0:
1161
+ pu.write_perturbation_dictionary(
1162
+ cos_sims_dict,
1163
+ f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1164
+ )
1165
+ if self.emb_mode == "cell_and_gene":
1166
+ pu.write_perturbation_dictionary(
1167
+ stored_gene_embs_dict,
1168
+ f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1169
+ )
1170
+
1171
+ # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1172
+ if i % self.clear_mem_ncells == 0:
1173
+ pickle_batch += 1
1174
+ if self.cell_states_to_model is None:
1175
+ cos_sims_dict = defaultdict(list)
1176
+ else:
1177
+ cos_sims_dict = {
1178
+ state: defaultdict(list)
1179
+ for state in pu.get_possible_states(self.cell_states_to_model)
1180
+ }
1181
+
1182
+ if self.emb_mode == "cell_and_gene":
1183
+ stored_gene_embs_dict = defaultdict(list)
1184
+
1185
+ torch.cuda.empty_cache()
1186
+
1187
+ pu.write_perturbation_dictionary(
1188
+ cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}"
1189
+ )
1190
+
1191
+ if self.emb_mode == "cell_and_gene":
1192
+ pu.write_perturbation_dictionary(
1193
+ stored_gene_embs_dict,
1194
+ f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1195
  )
1196
 
1197
+ pickle_batch = -1
1198
+ if self.cell_states_to_model is None:
1199
+ cos_sims_dict = defaultdict(list)
1200
+ else:
1201
+ cos_sims_dict = {
1202
+ state: defaultdict(list)
1203
+ for state in pu.get_possible_states(self.cell_states_to_model)
1204
  }
1205
 
1206
+ if self.emb_mode == "cell_and_gene":
1207
+ stored_gene_embs_dict = defaultdict(list)
 
 
 
1208
 
1209
+ # clear memory between cells
1210
+ del perturbation_batch
1211
+ del full_original_emb
1212
+ if self.cell_states_to_model is not None:
1213
+ del original_cell_emb
1214
+ torch.cuda.empty_cache()
1215
+
1216
+ def isp_perturb_all_special(
1217
+ self,
1218
+ model,
1219
+ filtered_input_data: Dataset,
1220
+ layer_to_quant: int,
1221
+ output_path_prefix: str,
1222
+ ):
1223
+ pickle_batch = -1
1224
+ if self.cell_states_to_model is None:
1225
+ cos_sims_dict = defaultdict(list)
1226
+ else:
1227
+ cos_sims_dict = {
1228
+ state: defaultdict(list)
1229
+ for state in pu.get_possible_states(self.cell_states_to_model)
1230
+ }
1231
 
1232
+ if self.emb_mode == "cls_and_gene":
1233
+ stored_gene_embs_dict = defaultdict(list)
 
 
1234
 
1235
+ num_inds_perturbed = 1 + self.combos
1236
+ for h in trange(len(filtered_input_data)):
1237
+ example_cell = filtered_input_data.select([h])
 
 
1238
 
1239
+ # get original example cell cls and/or gene embs for comparison
1240
+ if self.emb_mode == "cls":
1241
+ original_cls_emb = get_embs(
1242
+ model,
1243
+ example_cell,
1244
+ "cls",
1245
+ layer_to_quant,
1246
+ self.pad_token_id,
1247
+ self.forward_batch_size,
1248
+ self.token_gene_dict,
1249
+ summary_stat=None,
1250
+ silent=True,
1251
+ )
1252
+ elif self.emb_mode == "cls_and_gene":
1253
+ full_original_emb = get_embs(
1254
+ model,
1255
+ example_cell,
1256
+ "gene",
1257
+ layer_to_quant,
1258
+ self.pad_token_id,
1259
+ self.forward_batch_size,
1260
+ self.token_gene_dict,
1261
+ summary_stat=None,
1262
+ silent=True,
1263
+ )
1264
+ original_cls_emb = full_original_emb[:,0,:].clone().detach()
1265
+
1266
+ # gene_list is used to assign cos sims back to genes
1267
+ gene_list = example_cell["input_ids"][0][:]
1268
 
1269
+ # need to remove special tokens
1270
+ for token in [self.cls_token_id, self.eos_token_id]:
1271
+ gene_list.remove(token)
1272
+ # need to remove the anchor gene
1273
+ if self.anchor_token is not None:
1274
+ for token in self.anchor_token:
1275
+ gene_list.remove(token)
1276
+ # index 0 is not overexpressed so remove
1277
+ if self.perturb_type == "overexpress":
1278
+ gene_list = gene_list[
1279
+ num_inds_perturbed:
1280
+ ]
1281
+ # remove perturbed index for gene list dict
1282
+ perturbed_gene_dict = {
1283
+ gene: gene_list[:i] + gene_list[i + 1 :]
1284
+ for i, gene in enumerate(gene_list)
1285
+ }
1286
 
1287
+ perturbation_batch, indices_to_perturb = pu.make_perturbation_batch_special(
1288
+ example_cell,
1289
+ self.perturb_type,
1290
+ self.tokens_to_perturb,
1291
+ self.anchor_token,
1292
+ self.combos,
1293
+ self.nproc,
1294
+ )
1295
 
1296
+ ispall_total_batch_length = len(perturbation_batch)
1297
+ for i in trange(0, ispall_total_batch_length, self.forward_batch_size, leave=False):
1298
+ ispall_max_range = min(i + self.forward_batch_size, ispall_total_batch_length)
1299
+ perturbation_minibatch = perturbation_batch.select([i for i in range(i, ispall_max_range)])
1300
+ indices_to_perturb_mini = indices_to_perturb[i : ispall_max_range]
1301
+ gene_list_mini = gene_list[i : ispall_max_range] # only perturbed genes from this minibatch
1302
+
1303
+ ##### CLS Embedding Mode #####
1304
+ if self.emb_mode == "cls":
1305
+ # Extract cls embeddings from perturbed cells
1306
+ perturbation_cls_emb = get_embs(
1307
+ model,
1308
+ perturbation_minibatch,
1309
+ "cls",
1310
+ layer_to_quant,
1311
+ self.pad_token_id,
1312
+ self.forward_batch_size,
1313
+ self.token_gene_dict,
1314
+ summary_stat=None,
1315
+ silent=True,
1316
+ )
1317
+
1318
+ # Calculate cosine similarities
1319
+ cls_cos_sims = pu.quant_cos_sims(
1320
+ perturbation_cls_emb,
1321
+ original_cls_emb,
1322
+ self.cell_states_to_model,
1323
+ self.state_embs_dict,
1324
+ emb_mode="cell",
1325
+ )
1326
+
1327
+ if self.cell_states_to_model is None:
1328
+ cos_sims_dict = self.update_perturbation_dictionary(
1329
+ cos_sims_dict,
1330
+ cls_cos_sims,
1331
+ gene_list_mini,
1332
+ )
1333
+ else:
1334
+
1335
+ for state in cos_sims_dict.keys():
1336
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1337
+ cos_sims_dict[state],
1338
+ cls_cos_sims[state],
1339
+ gene_list_mini,
1340
+ )
1341
+
1342
+ del perturbation_minibatch
1343
+ del perturbation_cls_emb
1344
+ del cls_cos_sims
1345
+
1346
+ ##### CLS and Gene Embedding Mode #####
1347
+ elif self.emb_mode == "cls_and_gene":
1348
+ full_perturbation_emb = get_embs(
1349
+ model,
1350
+ perturbation_minibatch,
1351
+ "gene",
1352
+ layer_to_quant,
1353
+ self.pad_token_id,
1354
+ self.forward_batch_size,
1355
+ self.token_gene_dict,
1356
+ summary_stat=None,
1357
+ silent=True,
1358
+ )
1359
+
1360
+ # need to remove overexpressed gene and cls/eos to quantify cosine shifts
1361
+ if self.perturb_type == "overexpress":
1362
+ perturbation_emb = full_perturbation_emb[:, 1+num_inds_perturbed:-1, :].clone().detach()
1363
+ elif self.perturb_type == "delete":
1364
+ perturbation_emb = full_perturbation_emb[:, 1:-1, :].clone().detach()
1365
+
1366
+ original_emb_minibatch = pu.make_comparison_batch(
1367
+ full_original_emb, indices_to_perturb_mini, perturb_group=False
1368
+ )
1369
+
1370
+ original_emb_minibatch = original_emb_minibatch[:, 1:-1, :].clone().detach()
1371
+ gene_cos_sims = pu.quant_cos_sims(
1372
+ perturbation_emb,
1373
+ original_emb_minibatch,
1374
+ self.cell_states_to_model,
1375
+ self.state_embs_dict,
1376
+ emb_mode="gene",
1377
+ )
1378
+
1379
+ for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1380
+ for gene_j, affected_gene in enumerate(
1381
+ perturbed_gene_dict[perturbed_gene]
1382
+ ):
1383
+ try:
1384
+ stored_gene_embs_dict[
1385
+ (perturbed_gene, affected_gene)
1386
+ ].append(gene_cos_sims[perturbation_i, gene_j].item())
1387
+ except KeyError:
1388
+ stored_gene_embs_dict[
1389
+ (perturbed_gene, affected_gene)
1390
+ ] = gene_cos_sims[perturbation_i, gene_j].item()
1391
+
1392
+ # get cls emb
1393
+ perturbation_cls_emb = full_perturbation_emb[:,0,:].clone().detach()
1394
+
1395
+ cls_cos_sims = pu.quant_cos_sims(
1396
+ perturbation_cls_emb,
1397
+ original_cls_emb,
1398
+ self.cell_states_to_model,
1399
+ self.state_embs_dict,
1400
+ emb_mode="cell",
1401
+ )
1402
+
1403
+ if self.cell_states_to_model is None:
1404
+ cos_sims_dict = self.update_perturbation_dictionary(
1405
+ cos_sims_dict,
1406
+ cls_cos_sims,
1407
+ gene_list_mini,
1408
+ )
1409
+ else:
1410
+ for state in cos_sims_dict.keys():
1411
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1412
+ cos_sims_dict[state],
1413
+ cls_cos_sims[state],
1414
+ gene_list_mini,
1415
+ )
1416
+
1417
+ del perturbation_minibatch
1418
+ del original_emb_minibatch
1419
+ del full_perturbation_emb
1420
+ del perturbation_emb
1421
+ del perturbation_cls_emb
1422
+ del cls_cos_sims
1423
+ del gene_cos_sims
1424
+
1425
+ # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1426
+ if i % max(1,self.clear_mem_ncells/10) == 0:
1427
+ pu.write_perturbation_dictionary(
1428
+ cos_sims_dict,
1429
+ f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1430
+ )
1431
+ if self.emb_mode == "cls_and_gene":
1432
+ pu.write_perturbation_dictionary(
1433
+ stored_gene_embs_dict,
1434
+ f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1435
+ )
1436
+
1437
+ # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1438
+ if i % self.clear_mem_ncells == 0:
1439
+ pickle_batch += 1
1440
+ if self.cell_states_to_model is None:
1441
+ cos_sims_dict = defaultdict(list)
1442
+ else:
1443
+ cos_sims_dict = {
1444
+ state: defaultdict(list)
1445
+ for state in pu.get_possible_states(self.cell_states_to_model)
1446
+ }
1447
+
1448
+ if self.emb_mode == "cls_and_gene":
1449
+ stored_gene_embs_dict = defaultdict(list)
1450
+
1451
+ torch.cuda.empty_cache()
1452
+
1453
  pu.write_perturbation_dictionary(
1454
+ cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}"
 
1455
  )
1456
+
1457
+ if self.emb_mode == "cls_and_gene":
1458
+ pu.write_perturbation_dictionary(
1459
+ stored_gene_embs_dict,
1460
+ f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1461
+ )
1462
+
1463
+ pickle_batch = -1
1464
+ if self.cell_states_to_model is None:
1465
+ cos_sims_dict = defaultdict(list)
1466
+ else:
1467
+ cos_sims_dict = {
1468
+ state: defaultdict(list)
1469
+ for state in pu.get_possible_states(self.cell_states_to_model)
1470
+ }
1471
+
1472
+ if self.emb_mode == "cls_and_gene":
1473
+ stored_gene_embs_dict = defaultdict(list)
1474
+
1475
+ # clear memory between cells
1476
+ del perturbation_batch
1477
+ del original_cls_emb
1478
+ if self.emb_mode == "cls_and_gene":
1479
+ del full_original_emb
1480
+ torch.cuda.empty_cache()
1481
 
1482
+
1483
  def update_perturbation_dictionary(
1484
  self,
1485
  cos_sims_dict: defaultdict,
1486
  cos_sims_data: torch.Tensor,
 
 
1487
  gene_list=None,
1488
  ):
1489
  if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
1490
  logger.error(
1491
  f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
1492
+ {cos_sims_data.shape[0]=}.\n \
1493
+ {len(gene_list)=}."
1494
  )
1495
  raise
1496
 
 
1514
  for i, cos in enumerate(cos_sims_data.tolist()):
1515
  cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
1516
 
1517
+ return cos_sims_dict
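
For orientation, the snippet below is a minimal usage sketch (not part of the diff) showing how the options documented above fit together when constructing the perturber. The filter values and file paths are placeholders, and running the perturbation itself is assumed to go through the class's existing entry point, which is not shown in this hunk.

from geneformer import InSilicoPerturber

isp = InSilicoPerturber(
    perturb_type="delete",            # or "overexpress"
    genes_to_perturb="all",
    model_type="Pretrained",          # new options include "MTLCellClassifier" and "MTLCellClassifier-Quantized"
    num_classes=0,
    emb_mode="cls",                   # new: measure perturbation impact on the <cls> embedding
    filter_data={"cell_type": ["Cardiomyocyte1"]},   # placeholder filter
    max_ncells=1000,
    emb_layer=-1,
    forward_batch_size=100,
    nproc=4,
    token_dictionary_file=None,       # new: None falls back to the packaged TOKEN_DICTIONARY_FILE
    clear_mem_ncells=1000,            # new: write intermediate dictionaries and clear memory every 1000 cells
)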
geneformer/in_silico_perturber_stats.py CHANGED
@@ -114,6 +114,7 @@ def read_dictionaries(
114
  state_dict[state_value][key] += new_dict[key]
115
  except KeyError:
116
  state_dict[state_value][key] = new_dict[key]
 
117
  if not file_found:
118
  logger.error(
119
  "No raw data for processing found within provided directory. "
@@ -237,13 +238,16 @@ def find(variable, x):
237
 
238
 
239
  def isp_aggregate_gene_shifts(
240
- cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict
241
  ):
242
  cos_shift_data = dict()
243
  for i in trange(cos_sims_df.shape[0]):
244
  token = cos_sims_df["Gene"][i]
245
  for dict_i in dict_list:
246
- affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
 
247
  for key in affected_pairs:
248
  if key in cos_shift_data.keys():
249
  cos_shift_data[key] += dict_i.get(key, [])
@@ -256,11 +260,11 @@ def isp_aggregate_gene_shifts(
256
  cos_sims_full_df = pd.DataFrame()
257
  cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
258
  cos_sims_full_df["Gene_name"] = [
259
- cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"][0]
260
  for k, v in cos_data_mean.items()
261
  ]
262
  cos_sims_full_df["Ensembl_ID"] = [
263
- cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"][0]
264
  for k, v in cos_data_mean.items()
265
  ]
266
 
@@ -690,7 +694,7 @@ class InSilicoPerturberStats:
690
  | Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
691
  | Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
692
  combos : {0,1,2}
693
- | Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
694
  anchor_gene : None, str
695
  | ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
696
  | For example, if combos=1 and anchor_gene="ENSG00000136574":
@@ -1014,7 +1018,7 @@ class InSilicoPerturberStats:
1014
  },
1015
  index=[i for i in range(len(gene_list))],
1016
  )
1017
-
1018
  if self.mode == "goal_state_shift":
1019
  cos_sims_df = isp_stats_to_goal_state(
1020
  cos_sims_df_initial,
@@ -1045,11 +1049,23 @@ class InSilicoPerturberStats:
1045
  cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list, self.genes_perturbed)
1046
 
1047
  elif self.mode == "aggregate_gene_shifts":
 
 
 
1048
  cos_sims_df = isp_aggregate_gene_shifts(
1049
  cos_sims_df_initial,
1050
  dict_list,
1051
  self.gene_token_id_dict,
1052
  self.gene_id_name_dict,
 
1053
  )
1054
 
1055
  # save perturbation stats to output_path
 
114
  state_dict[state_value][key] += new_dict[key]
115
  except KeyError:
116
  state_dict[state_value][key] = new_dict[key]
117
+
118
  if not file_found:
119
  logger.error(
120
  "No raw data for processing found within provided directory. "
 
238
 
239
 
240
  def isp_aggregate_gene_shifts(
241
+ cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict, token_dtype
242
  ):
243
  cos_shift_data = dict()
244
  for i in trange(cos_sims_df.shape[0]):
245
  token = cos_sims_df["Gene"][i]
246
  for dict_i in dict_list:
247
+ if token_dtype == "nontuple":
248
+ affected_pairs = [k for k, v in dict_i.items() if k[0] == token]
249
+ else:
250
+ affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
251
  for key in affected_pairs:
252
  if key in cos_shift_data.keys():
253
  cos_shift_data[key] += dict_i.get(key, [])
 
260
  cos_sims_full_df = pd.DataFrame()
261
  cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
262
  cos_sims_full_df["Gene_name"] = [
263
+ cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"].item()
264
  for k, v in cos_data_mean.items()
265
  ]
266
  cos_sims_full_df["Ensembl_ID"] = [
267
+ cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"].item()
268
  for k, v in cos_data_mean.items()
269
  ]
270
 
 
694
  | Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
695
  | Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
696
  combos : {0,1,2}
697
+ | Whether genes perturbed in the isp experiment were perturbed individually (0), in pairs (1), or in triplets (2).
698
  anchor_gene : None, str
699
  | ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
700
  | For example, if combos=1 and anchor_gene="ENSG00000136574":
 
1018
  },
1019
  index=[i for i in range(len(gene_list))],
1020
  )
1021
+
1022
  if self.mode == "goal_state_shift":
1023
  cos_sims_df = isp_stats_to_goal_state(
1024
  cos_sims_df_initial,
 
1049
  cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list, self.genes_perturbed)
1050
 
1051
  elif self.mode == "aggregate_gene_shifts":
1052
+ if (self.genes_perturbed == "all") and (self.combos == 0):
1053
+ tuple_types = [True if isinstance(genes, tuple) else False for genes in gene_list]
1054
+ if all(tuple_types):
1055
+ token_dtype = "tuple"
1056
+ elif not any(tuple_types):
1057
+ token_dtype = "nontuple"
1058
+ else:
1059
+ token_dtype = "mix"
1060
+ else:
1061
+ token_dtype = "mix"
1062
+
1063
  cos_sims_df = isp_aggregate_gene_shifts(
1064
  cos_sims_df_initial,
1065
  dict_list,
1066
  self.gene_token_id_dict,
1067
  self.gene_id_name_dict,
1068
+ token_dtype
1069
  )
1070
 
1071
  # save perturbation stats to output_path
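
To make the new token_dtype branch concrete, here is a small standalone sketch (toy token IDs, not part of the module) of the two key layouts isp_aggregate_gene_shifts now distinguishes: single-gene perturbations keyed by a plain perturbed token, which can be matched exactly, and grouped perturbations keyed by a tuple of perturbed tokens, which still need the recursive find() match.

from collections import defaultdict

single_gene_dict = defaultdict(list)           # keys: (perturbed_token, affected_token)
single_gene_dict[(5141, 10255)] += [0.91, 0.89]

grouped_dict = defaultdict(list)               # keys: ((perturbed_token, ...), affected_token)
grouped_dict[((5141, 871), 10255)] += [0.88]

token = 5141
# "nontuple" layout: exact match on the first key element
print([k for k in single_gene_dict if k[0] == token])
# tuple/"mix" layout: membership test, approximating what find() does recursively
print([k for k in grouped_dict if isinstance(k[0], tuple) and token in k[0]])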
geneformer/mtl/__init__.py ADDED
File without changes
geneformer/mtl/collators.py ADDED
@@ -0,0 +1,66 @@
 
1
+ #imports
2
+ import torch
3
+
4
+ from ..collator_for_classification import DataCollatorForGeneClassification
5
+
6
+ """
7
+ Geneformer collator for multi-task cell classification.
8
+ """
9
+
10
+ class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
11
+ class_type = "cell"
12
+
13
+ def __init__(self, *args, **kwargs) -> None:
14
+ super().__init__(*args, **kwargs)
15
+
16
+ def _prepare_batch(self, features):
17
+ # Process inputs as usual
18
+ batch = self.tokenizer.pad(
19
+ features,
20
+ class_type=self.class_type,
21
+ padding=self.padding,
22
+ max_length=self.max_length,
23
+ pad_to_multiple_of=self.pad_to_multiple_of,
24
+ return_tensors="pt",
25
+ )
26
+
27
+ # Check if labels are present
28
+ if "label" in features[0]:
29
+ # Initialize labels dictionary for all tasks
30
+ labels = {task: [] for task in features[0]["label"].keys()}
31
+
32
+ # Populate labels for each task
33
+ for feature in features:
34
+ for task, label in feature["label"].items():
35
+ labels[task].append(label)
36
+
37
+ # Convert label lists to tensors, handling dictionaries appropriately
38
+ for task in labels:
39
+ if isinstance(labels[task][0], (list, torch.Tensor)):
40
+ dtype = torch.long
41
+ labels[task] = torch.tensor(labels[task], dtype=dtype)
42
+ elif isinstance(labels[task][0], dict):
43
+ # Handle dict specifically if needed
44
+ pass # Resolve nested data structure
45
+
46
+ # Update the batch to include task-specific labels
47
+ batch["labels"] = labels
48
+ else:
49
+ # If no labels are present, create empty labels for all tasks
50
+ batch["labels"] = {task: torch.tensor([], dtype=torch.long) for task in features[0]["input_ids"].keys()}
51
+
52
+ return batch
53
+
54
+ def __call__(self, features):
55
+ batch = self._prepare_batch(features)
56
+
57
+ for k, v in batch.items():
58
+ if torch.is_tensor(v):
59
+ batch[k] = v.clone().detach()
60
+ elif isinstance(v, dict):
61
+ # Assuming nested structure needs conversion
62
+ batch[k] = {task: torch.tensor(labels, dtype=torch.int64) for task, labels in v.items()}
63
+ else:
64
+ batch[k] = torch.tensor(v, dtype=torch.int64)
65
+
66
+ return batch
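
For orientation, a minimal sketch (invented token IDs and labels) of the per-example format this collator receives from the multitask data pipeline in mtl/data.py: tokenized input_ids plus one integer label index per task. The collator pads the inputs, builds the attention mask, and stacks the per-task labels into a batch["labels"] dictionary keyed by task name.

import torch

features = [
    {
        "input_ids": torch.tensor([9, 4127, 16021, 233], dtype=torch.long),  # invented token IDs
        "cell_id": 0,
        "label": {"task1": 2, "task2": 0},
    },
    {
        "input_ids": torch.tensor([9, 771, 5203], dtype=torch.long),
        "cell_id": 1,
        "label": {"task1": 1, "task2": 1},
    },
]
# After collation (sketch): batch["input_ids"] and batch["attention_mask"] are padded tensors,
# and batch["labels"] == {"task1": tensor([2, 1]), "task2": tensor([0, 1])}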
geneformer/mtl/data.py ADDED
@@ -0,0 +1,116 @@
 
1
+ from .imports import *
2
+ import os
3
+ from .collators import DataCollatorForMultitaskCellClassification
4
+
5
+ def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
6
+ try:
7
+ dataset = load_from_disk(dataset_path)
8
+
9
+ task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
10
+ task_to_column = dict(zip(task_names, config["task_columns"]))
11
+ config["task_names"] = task_names
12
+
13
+ if not is_test:
14
+ available_columns = set(dataset.column_names)
15
+ for column in task_to_column.values():
16
+ if column not in available_columns:
17
+ raise KeyError(f"Column {column} not found in the dataset. Available columns: {list(available_columns)}")
18
+
19
+ label_mappings = {}
20
+ task_label_mappings = {}
21
+ cell_id_mapping = {}
22
+ num_labels_list = []
23
+
24
+ # Load or create task label mappings
25
+ if not is_test:
26
+ for task, column in task_to_column.items():
27
+ unique_values = sorted(set(dataset[column])) # Ensure consistency
28
+ label_mappings[column] = {label: idx for idx, label in enumerate(unique_values)}
29
+ task_label_mappings[task] = label_mappings[column]
30
+ num_labels_list.append(len(unique_values))
31
+
32
+ # Print the mappings for each task with dataset type prefix
33
+ for task, mapping in task_label_mappings.items():
34
+ print(f"{dataset_type.capitalize()} mapping for {task}: {mapping}") # sanity check, for train/validation splits
35
+
36
+ # Save the task label mappings as a pickle file
37
+ with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
38
+ pickle.dump(task_label_mappings, f)
39
+ else:
40
+ # Load task label mappings from pickle file for test data
41
+ with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
42
+ task_label_mappings = pickle.load(f)
43
+
44
+ # Infer num_labels_list from task_label_mappings
45
+ for task, mapping in task_label_mappings.items():
46
+ num_labels_list.append(len(mapping))
47
+
48
+ # Store unique cell IDs in a separate dictionary
49
+ for idx, record in enumerate(dataset):
50
+ cell_id = record.get('unique_cell_id', idx)
51
+ cell_id_mapping[idx] = cell_id
52
+
53
+ # Transform records to the desired format
54
+ transformed_dataset = []
55
+ for idx, record in enumerate(dataset):
56
+ transformed_record = {}
57
+ transformed_record['input_ids'] = torch.tensor(record['input_ids'], dtype=torch.long)
58
+
59
+ # Use index-based cell ID for internal tracking
60
+ transformed_record['cell_id'] = idx
61
+
62
+ if not is_test:
63
+ # Prepare labels
64
+ label_dict = {}
65
+ for task, column in task_to_column.items():
66
+ label_value = record[column]
67
+ label_index = task_label_mappings[task][label_value]
68
+ label_dict[task] = label_index
69
+ transformed_record['label'] = label_dict
70
+ else:
71
+ # Create dummy labels for test data
72
+ label_dict = {task: -1 for task in config["task_names"]}
73
+ transformed_record['label'] = label_dict
74
+
75
+ transformed_dataset.append(transformed_record)
76
+
77
+ return transformed_dataset, cell_id_mapping, num_labels_list
78
+ except KeyError as e:
79
+ print(f"Missing configuration or dataset key: {e}")
80
+ except Exception as e:
81
+ print(f"An error occurred while loading or preprocessing data: {e}")
82
+ return None, None, None
83
+
84
+ def preload_and_process_data(config):
85
+ # Load and preprocess data once
86
+ train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(config["train_path"], config, dataset_type="train")
87
+ val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
88
+ return train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list
89
+
90
+ def get_data_loader(preprocessed_dataset, batch_size):
91
+ nproc = os.cpu_count() ### I/O operations
92
+
93
+ data_collator = DataCollatorForMultitaskCellClassification()
94
+
95
+ loader = DataLoader(preprocessed_dataset, batch_size=batch_size, shuffle=True,
96
+ collate_fn=data_collator, num_workers=nproc, pin_memory=True)
97
+ return loader
98
+ def preload_data(config):
99
+ # Preprocessing the data before the Optuna trials start
100
+ train_loader = get_data_loader("train", config)
101
+ val_loader = get_data_loader("val", config)
102
+ return train_loader, val_loader
103
+
104
+ def load_and_preprocess_test_data(config):
105
+ """
106
+ Load and preprocess test data, treating it as unlabeled.
107
+ """
108
+ return load_and_preprocess_data(config["test_path"], config, is_test=True)
109
+
110
+ def prepare_test_loader(config):
111
+ """
112
+ Prepare DataLoader for the test dataset.
113
+ """
114
+ test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
115
+ test_loader = get_data_loader(test_dataset, config['batch_size'])
116
+ return test_loader, cell_id_mapping, num_labels_list
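
A minimal sketch of the configuration dictionary these helpers read; only keys referenced in this file are shown, all paths are placeholders, and the commented call assumes datasets saved with Hugging Face datasets' save_to_disk.

config = {
    "train_path": "/path/to/train_dataset",          # placeholder
    "val_path": "/path/to/val_dataset",              # placeholder
    "test_path": "/path/to/test_dataset",            # placeholder
    "task_columns": ["cell_type", "disease_state"],  # placeholder label columns, one per task
    "results_dir": "/path/to/results",               # task_label_mappings.pkl is written/read here
    "batch_size": 16,
}

# train_set, train_ids, val_set, val_ids, num_labels_list = preload_and_process_data(config)
# train_loader = get_data_loader(train_set, config["batch_size"])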
geneformer/mtl/eval_utils.py ADDED
@@ -0,0 +1,81 @@
 
1
+ from .imports import *
2
+ import pandas as pd
3
+ from .data import prepare_test_loader
4
+ from .model import GeneformerMultiTask
5
+
6
+ def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
7
+ task_pred_labels = {task_name: [] for task_name in config["task_names"]}
8
+ task_pred_probs = {task_name: [] for task_name in config["task_names"]}
9
+ cell_ids = []
10
+
11
+ # Load task label mappings from pickle file
12
+ with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
13
+ task_label_mappings = pickle.load(f)
14
+
15
+ model.eval()
16
+ with torch.no_grad():
17
+ for batch in test_loader:
18
+ input_ids = batch['input_ids'].to(device)
19
+ attention_mask = batch['attention_mask'].to(device)
20
+ _, logits, _ = model(input_ids, attention_mask)
21
+ for sample_idx in range(len(batch['input_ids'])):
22
+ cell_id = cell_id_mapping[batch['cell_id'][sample_idx].item()]
23
+ cell_ids.append(cell_id)
24
+ for i, task_name in enumerate(config["task_names"]):
25
+ pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
26
+ pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
27
+ task_pred_labels[task_name].append(pred_label)
28
+ task_pred_probs[task_name].append(pred_prob)
29
+
30
+ # Save test predictions with cell IDs and probabilities to CSV
31
+ test_results_dir = config["results_dir"]
32
+ os.makedirs(test_results_dir, exist_ok=True)
33
+ test_preds_file = os.path.join(test_results_dir, "test_preds.csv")
34
+
35
+ rows = []
36
+ for sample_idx in range(len(cell_ids)):
37
+ row = {'Cell ID': cell_ids[sample_idx]}
38
+ for task_name in config["task_names"]:
39
+ row[f'{task_name} Prediction'] = task_pred_labels[task_name][sample_idx]
40
+ row[f'{task_name} Probabilities'] = ','.join(map(str, task_pred_probs[task_name][sample_idx]))
41
+ rows.append(row)
42
+
43
+ df = pd.DataFrame(rows)
44
+ df.to_csv(test_preds_file, index=False)
45
+ print(f"Test predictions saved to {test_preds_file}")
46
+
47
+ def load_and_evaluate_test_model(config):
48
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+ test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
50
+ model_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
51
+ hyperparams_path = os.path.join(model_directory, "hyperparameters.json")
52
+
53
+ # Load the saved best hyperparameters
54
+ with open(hyperparams_path, 'r') as f:
55
+ best_hyperparams = json.load(f)
56
+
57
+ # Extract the task weights if present, otherwise set to None
58
+ task_weights = best_hyperparams.get("task_weights", None)
59
+ normalized_task_weights = task_weights if task_weights else []
60
+
61
+ # Print the loaded hyperparameters
62
+ print("Loaded hyperparameters:")
63
+ for param, value in best_hyperparams.items():
64
+ if param == "task_weights":
65
+ print(f"normalized_task_weights: {value}")
66
+ else:
67
+ print(f"{param}: {value}")
68
+
69
+ best_model_path = os.path.join(model_directory, "pytorch_model.bin")
70
+ best_model = GeneformerMultiTask(
71
+ config["pretrained_path"],
72
+ num_labels_list,
73
+ dropout_rate=best_hyperparams["dropout_rate"],
74
+ use_task_weights=config["use_task_weights"],
75
+ task_weights=normalized_task_weights
76
+ )
77
+ best_model.load_state_dict(torch.load(best_model_path))
78
+ best_model.to(device)
79
+
80
+ evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
81
+ print("Evaluation completed.")
geneformer/mtl/imports.py ADDED
@@ -0,0 +1,46 @@
 
1
+ import numpy as np
2
+ import pickle
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import DataLoader
8
+
9
+ from itertools import chain
10
+ import warnings
11
+ from enum import Enum
12
+ from typing import Dict, List, Optional, Union
13
+ import sys
14
+ import os
15
+ import json
16
+ import gc
17
+ import functools
18
+ import pandas as pd
19
+
20
+ from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, roc_curve
21
+ from sklearn.preprocessing import LabelEncoder
22
+ from sklearn.model_selection import train_test_split
23
+
24
+ import optuna
25
+
26
+ from transformers import (
27
+ BertConfig,
28
+ BertModel,
29
+ AdamW,
30
+ get_linear_schedule_with_warmup,
31
+ get_cosine_schedule_with_warmup,
32
+ DataCollatorForTokenClassification,
33
+ SpecialTokensMixin,
34
+ BatchEncoding,
35
+ get_scheduler,
36
+ )
37
+ from transformers.utils import logging, to_py_obj
38
+
39
+ from datasets import load_from_disk
40
+
41
+ # local modules
42
+ from .data import preload_and_process_data, get_data_loader
43
+ from .model import GeneformerMultiTask
44
+ from .utils import save_model
45
+ from .optuna_utils import create_optuna_study
46
+ from .collators import DataCollatorForMultitaskCellClassification
geneformer/mtl/model.py ADDED
@@ -0,0 +1,84 @@
 
1
+ from transformers import BertModel, BertConfig
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ class AttentionPool(nn.Module):
6
+ """Attention-based pooling layer."""
7
+ def __init__(self, hidden_size):
8
+ super(AttentionPool, self).__init__()
9
+ self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
10
+ nn.init.xavier_uniform_(self.attention_weights) # https://pytorch.org/docs/stable/nn.init.html
11
+
12
+ def forward(self, hidden_states):
13
+ attention_scores = torch.matmul(hidden_states, self.attention_weights)
14
+ attention_scores = torch.softmax(attention_scores, dim=1)
15
+ pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
16
+ return pooled_output
17
+
18
+ class GeneformerMultiTask(nn.Module):
19
+ def __init__(self, pretrained_path, num_labels_list, dropout_rate=0.1, use_task_weights=False, task_weights=None, max_layers_to_freeze=0, use_attention_pooling=False):
20
+ super(GeneformerMultiTask, self).__init__()
21
+ self.config = BertConfig.from_pretrained(pretrained_path)
22
+ self.bert = BertModel(self.config)
23
+ self.num_labels_list = num_labels_list
24
+ self.use_task_weights = use_task_weights
25
+ self.dropout = nn.Dropout(dropout_rate)
26
+ self.use_attention_pooling = use_attention_pooling
27
+
28
+ if use_task_weights and (task_weights is None or len(task_weights) != len(num_labels_list)):
29
+ raise ValueError("Task weights must be defined and match the number of tasks when 'use_task_weights' is True.")
30
+ self.task_weights = task_weights if use_task_weights else [1.0] * len(num_labels_list)
31
+
32
+ # Freeze the specified initial layers
33
+ for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
34
+ for param in layer.parameters():
35
+ param.requires_grad = False
36
+
37
+ self.attention_pool = AttentionPool(self.config.hidden_size) if use_attention_pooling else None
38
+
39
+ self.classification_heads = nn.ModuleList([
40
+ nn.Linear(self.config.hidden_size, num_labels) for num_labels in num_labels_list
41
+ ])
42
+ # initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
43
+ for head in self.classification_heads:
44
+ nn.init.xavier_uniform_(head.weight)
45
+ nn.init.zeros_(head.bias)
46
+
47
+ def forward(self, input_ids, attention_mask, labels=None):
48
+ try:
49
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
50
+ except Exception as e:
51
+ raise RuntimeError(f"Error during BERT forward pass: {e}")
52
+
53
+ sequence_output = outputs.last_hidden_state
54
+
55
+ try:
56
+ pooled_output = self.attention_pool(sequence_output) if self.use_attention_pooling else sequence_output[:, 0, :]
57
+ pooled_output = self.dropout(pooled_output)
58
+ except Exception as e:
59
+ raise RuntimeError(f"Error during pooling and dropout: {e}")
60
+
61
+ total_loss = 0
62
+ logits = []
63
+ losses = []
64
+
65
+ for task_id, (head, num_labels) in enumerate(zip(self.classification_heads, self.num_labels_list)):
66
+ try:
67
+ task_logits = head(pooled_output)
68
+ except Exception as e:
69
+ raise RuntimeError(f"Error during forward pass of classification head {task_id}: {e}")
70
+
71
+ logits.append(task_logits)
72
+
73
+ if labels is not None:
74
+ try:
75
+ loss_fct = nn.CrossEntropyLoss()
76
+ task_loss = loss_fct(task_logits.view(-1, num_labels), labels[task_id].view(-1))
77
+ if self.use_task_weights:
78
+ task_loss *= self.task_weights[task_id]
79
+ total_loss += task_loss
80
+ losses.append(task_loss.item())
81
+ except Exception as e:
82
+ raise RuntimeError(f"Error during loss computation for task {task_id}: {e}")
83
+
84
+ return total_loss, logits, losses if labels is not None else logits
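
A minimal smoke-test sketch of the model interface; the checkpoint directory is a placeholder that only needs to hold a Hugging Face config.json for the pretrained Geneformer, the token IDs are random, and the forward call is unpacked the same way eval_utils.py does it.

import torch
from geneformer.mtl.model import GeneformerMultiTask

model = GeneformerMultiTask(
    "/path/to/pretrained_geneformer",   # placeholder checkpoint directory
    num_labels_list=[3, 2],             # e.g. a 3-class task and a 2-class task
    dropout_rate=0.1,
    use_task_weights=False,
)
input_ids = torch.randint(0, model.config.vocab_size, (4, 128))  # random tokens, batch of 4
attention_mask = torch.ones_like(input_ids)
_, logits, _ = model(input_ids, attention_mask)                   # no labels -> loss slot unused
print([tuple(l.shape) for l in logits])                           # [(4, 3), (4, 2)]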
geneformer/mtl/optuna_utils.py ADDED
@@ -0,0 +1,21 @@
 
1
+ import optuna
2
+ from optuna.integration import TensorBoardCallback
3
+
4
+ def save_trial_callback(study, trial, trials_result_path):
5
+ with open(trials_result_path, "a") as f:
6
+ f.write(f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n")
7
+
8
+ def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
9
+ study = optuna.create_study(direction="maximize")
10
+
11
+ # init TensorBoard callback
12
+ tensorboard_callback = TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro")
13
+
14
+ # callback and TensorBoard callback
15
+ callbacks = [
16
+ lambda study, trial: save_trial_callback(study, trial, trials_result_path),
17
+ tensorboard_callback
18
+ ]
19
+
20
+ study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
21
+ return study
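
A minimal sketch with a toy objective standing in for the multitask objective defined in train.py; it assumes optuna and tensorboard are installed, and both output paths are placeholders (the trial log file is appended to by save_trial_callback, the directory feeds the TensorBoard callback).

from geneformer.mtl.optuna_utils import create_optuna_study

def toy_objective(trial):
    x = trial.suggest_float("x", -2.0, 2.0)
    return 1.0 - x ** 2                 # maximized at x = 0

study = create_optuna_study(
    toy_objective,
    n_trials=10,
    trials_result_path="trials.txt",
    tensorboard_log_dir="optuna_tb_logs",
)
print(study.best_params)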
geneformer/mtl/train.py ADDED
@@ -0,0 +1,242 @@
 
1
+ from .imports import *
2
+ from .data import preload_and_process_data, get_data_loader
3
+ from .model import GeneformerMultiTask
4
+ from .utils import calculate_task_specific_metrics
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ import pandas as pd
7
+ import os
8
+ from tqdm import tqdm
9
+ import random
10
+ import numpy as np
11
+ import torch
12
+
13
+
14
+ def set_seed(seed):
15
+ random.seed(seed)
16
+ np.random.seed(seed)
17
+ torch.manual_seed(seed)
18
+ torch.cuda.manual_seed_all(seed)
19
+ torch.backends.cudnn.deterministic = True
20
+ torch.backends.cudnn.benchmark = False
21
+
22
+ def initialize_wandb(config):
23
+ if config.get("use_wandb", False):
24
+ import wandb
25
+ wandb.init(project=config["wandb_project"], config=config)
26
+ print("Weights & Biases (wandb) initialized and will be used for logging.")
27
+ else:
28
+ print("Weights & Biases (wandb) is not enabled. Logging will use other methods.")
29
+
30
+ def create_model(config, num_labels_list, device):
31
+ model = GeneformerMultiTask(
32
+ config["pretrained_path"],
33
+ num_labels_list,
34
+ dropout_rate=config["dropout_rate"],
35
+ use_task_weights=config["use_task_weights"],
36
+ task_weights=config["task_weights"],
37
+ max_layers_to_freeze=config["max_layers_to_freeze"],
38
+ use_attention_pooling=config["use_attention_pooling"]
39
+ )
40
+ if config["use_data_parallel"]:
41
+ model = nn.DataParallel(model)
42
+ return model.to(device)
43
+
44
+ def setup_optimizer_and_scheduler(model, config, total_steps):
45
+ optimizer = AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
46
+ warmup_steps = int(config["warmup_ratio"] * total_steps)
47
+
48
+ if config["lr_scheduler_type"] == "linear":
49
+ scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
50
+ elif config["lr_scheduler_type"] == "cosine":
51
+ scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, num_cycles=0.5)
52
+
53
+ return optimizer, scheduler
54
+
55
+ def train_epoch(model, train_loader, optimizer, scheduler, device, config, writer, epoch):
56
+ model.train()
57
+ progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
58
+ for batch_idx, batch in enumerate(progress_bar):
59
+ optimizer.zero_grad()
60
+ input_ids = batch['input_ids'].to(device)
61
+ attention_mask = batch['attention_mask'].to(device)
62
+ labels = [batch['labels'][task_name].to(device) for task_name in config["task_names"]]
63
+
64
+ loss, _, _ = model(input_ids, attention_mask, labels)
65
+ loss.backward()
66
+
67
+ if config["gradient_clipping"]:
68
+ torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
69
+
70
+ optimizer.step()
71
+ scheduler.step()
72
+
73
+ writer.add_scalar('Training Loss', loss.item(), epoch * len(train_loader) + batch_idx)
74
+ if config.get("use_wandb", False):
75
+ wandb.log({'Training Loss': loss.item()})
76
+
77
+ # Update progress bar
78
+ progress_bar.set_postfix({'loss': f"{loss.item():.4f}"})
79
+
80
+ return loss.item() # Return the last batch loss
81
+
82
+ def validate_model(model, val_loader, device, config):
83
+ model.eval()
84
+ val_loss = 0.0
85
+ task_true_labels = {task_name: [] for task_name in config["task_names"]}
86
+ task_pred_labels = {task_name: [] for task_name in config["task_names"]}
87
+ task_pred_probs = {task_name: [] for task_name in config["task_names"]}
88
+
89
+ with torch.no_grad():
90
+ for batch in val_loader:
91
+ input_ids = batch['input_ids'].to(device)
92
+ attention_mask = batch['attention_mask'].to(device)
93
+ labels = [batch['labels'][task_name].to(device) for task_name in config["task_names"]]
94
+ loss, logits, _ = model(input_ids, attention_mask, labels)
95
+ val_loss += loss.item()
96
+
97
+ for sample_idx in range(len(batch['input_ids'])):
98
+ for i, task_name in enumerate(config["task_names"]):
99
+ true_label = batch['labels'][task_name][sample_idx].item()
100
+ pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
101
+ pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
102
+ task_true_labels[task_name].append(true_label)
103
+ task_pred_labels[task_name].append(pred_label)
104
+ task_pred_probs[task_name].append(pred_prob)
105
+
106
+ val_loss /= len(val_loader)
107
+ return val_loss, task_true_labels, task_pred_labels, task_pred_probs
108
+
109
+ def log_metrics(task_metrics, val_loss, config, writer, epochs):
110
+ for task_name, metrics in task_metrics.items():
111
+ print(f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}")
112
+ if config.get("use_wandb", False):
113
+ import wandb
114
+ wandb.log({
115
+ f'{task_name} Validation F1 Macro': metrics['f1'],
116
+ f'{task_name} Validation Accuracy': metrics['accuracy']
117
+ })
118
+
119
+ writer.add_scalar('Validation Loss', val_loss, epochs)
120
+ for task_name, metrics in task_metrics.items():
121
+ writer.add_scalar(f'{task_name} - Validation F1 Macro', metrics['f1'], epochs)
122
+ writer.add_scalar(f'{task_name} - Validation Accuracy', metrics['accuracy'], epochs)
123
+
124
+ def save_validation_predictions(val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config, trial_number=None):
125
+ if trial_number is not None:
126
+ trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
127
+ os.makedirs(trial_results_dir, exist_ok=True)
128
+ val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
129
+ else:
130
+ val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
131
+
132
+ rows = []
133
+ for sample_idx in range(len(val_cell_id_mapping)):
134
+ row = {'Cell ID': val_cell_id_mapping[sample_idx]}
135
+ for task_name in config["task_names"]:
136
+ row[f'{task_name} True'] = task_true_labels[task_name][sample_idx]
137
+ row[f'{task_name} Pred'] = task_pred_labels[task_name][sample_idx]
138
+ row[f'{task_name} Probabilities'] = ','.join(map(str, task_pred_probs[task_name][sample_idx]))
139
+ rows.append(row)
140
+
141
+ df = pd.DataFrame(rows)
142
+ df.to_csv(val_preds_file, index=False)
143
+ print(f"Validation predictions saved to {val_preds_file}")
144
+
145
+
146
+ def train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list):
147
+ set_seed(config["seed"])
148
+ initialize_wandb(config)
149
+
150
+ model = create_model(config, num_labels_list, device)
151
+ total_steps = len(train_loader) * config["epochs"]
152
+ optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
153
+
154
+ log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
155
+ writer = SummaryWriter(log_dir=log_dir)
156
+
157
+ epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
158
+ for epoch in epoch_progress:
159
+ last_loss = train_epoch(model, train_loader, optimizer, scheduler, device, config, writer, epoch)
160
+ epoch_progress.set_postfix({'last_loss': f"{last_loss:.4f}"})
161
+
162
+ val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(model, val_loader, device, config)
163
+ task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
164
+
165
+ log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
166
+ writer.close()
167
+
168
+ save_validation_predictions(val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config)
169
+
170
+ if config.get("use_wandb", False):
171
+ import wandb
172
+ wandb.finish()
173
+
174
+ print(f"\nFinal Validation Loss: {val_loss:.4f}")
175
+ return val_loss, model # Return both the validation loss and the trained model
176
+
177
+ def objective(trial, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, config, device):
178
+ set_seed(config["seed"]) # Set the seed before each trial
179
+ initialize_wandb(config)
180
+
181
+ # Hyperparameters
182
+ config["learning_rate"] = trial.suggest_float("learning_rate", config["hyperparameters"]["learning_rate"]["low"], config["hyperparameters"]["learning_rate"]["high"], log=config["hyperparameters"]["learning_rate"]["log"])
183
+ config["warmup_ratio"] = trial.suggest_float("warmup_ratio", config["hyperparameters"]["warmup_ratio"]["low"], config["hyperparameters"]["warmup_ratio"]["high"])
184
+ config["weight_decay"] = trial.suggest_float("weight_decay", config["hyperparameters"]["weight_decay"]["low"], config["hyperparameters"]["weight_decay"]["high"])
185
+ config["dropout_rate"] = trial.suggest_float("dropout_rate", config["hyperparameters"]["dropout_rate"]["low"], config["hyperparameters"]["dropout_rate"]["high"])
186
+ config["lr_scheduler_type"] = trial.suggest_categorical("lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"])
187
+ config["use_attention_pooling"] = trial.suggest_categorical("use_attention_pooling", [True, False])
188
+
189
+ if config["use_task_weights"]:
190
+ config["task_weights"] = [trial.suggest_float(f"task_weight_{i}", config["hyperparameters"]["task_weights"]["low"], config["hyperparameters"]["task_weights"]["high"]) for i in range(len(num_labels_list))]
191
+ weight_sum = sum(config["task_weights"])
192
+ config["task_weights"] = [weight / weight_sum for weight in config["task_weights"]]
193
+ else:
194
+ config["task_weights"] = None
195
+
196
+ # Fix for max_layers_to_freeze
197
+ if isinstance(config["max_layers_to_freeze"], dict):
198
+ config["max_layers_to_freeze"] = trial.suggest_int("max_layers_to_freeze", config["max_layers_to_freeze"]["min"], config["max_layers_to_freeze"]["max"])
199
+ elif isinstance(config["max_layers_to_freeze"], int):
200
+ # If it's already an int, we don't need to suggest it
201
+ pass
202
+ else:
203
+ raise ValueError("Invalid type for max_layers_to_freeze. Expected dict or int.")
204
+
205
+ model = create_model(config, num_labels_list, device)
206
+ total_steps = len(train_loader) * config["epochs"]
207
+ optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
208
+
209
+ log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
210
+ writer = SummaryWriter(log_dir=log_dir)
211
+
212
+ for epoch in range(config["epochs"]):
213
+ train_epoch(model, train_loader, optimizer, scheduler, device, config, writer, epoch)
214
+
215
+ val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(model, val_loader, device, config)
216
+ task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
217
+
218
+ log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
219
+ writer.close()
220
+
221
+ save_validation_predictions(val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config, trial.number)
222
+
223
+ trial.set_user_attr("model_state_dict", model.state_dict())
224
+ trial.set_user_attr("task_weights", config["task_weights"])
225
+
226
+ trial.report(val_loss, config["epochs"])
227
+
228
+ if trial.should_prune():
229
+ raise optuna.TrialPruned()
230
+
231
+ if config.get("use_wandb", False):
232
+ import wandb
233
+ wandb.log({
234
+ "trial_number": trial.number,
235
+ "val_loss": val_loss,
236
+ **{f"{task_name}_f1": metrics['f1'] for task_name, metrics in task_metrics.items()},
237
+ **{f"{task_name}_accuracy": metrics['accuracy'] for task_name, metrics in task_metrics.items()},
238
+ **{k: v for k, v in config.items() if k in ["learning_rate", "warmup_ratio", "weight_decay", "dropout_rate", "lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze"]}
239
+ })
240
+ wandb.finish()
241
+
242
+ return val_loss
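The objective above follows the usual Optuna pattern: suggest hyperparameters, train, report the metric, and raise TrialPruned when the pruner asks. A self-contained toy version of that pattern (illustration only, not the Geneformer objective itself):

```python
import optuna

def toy_objective(trial):
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    # Stand-in for training: pretend the loss depends only on the learning rate
    val_loss = abs(lr - 1e-4) * 1e3
    trial.report(val_loss, step=1)
    if trial.should_prune():
        raise optuna.TrialPruned()
    return val_loss

study = optuna.create_study(direction="minimize")
study.optimize(toy_objective, n_trials=5)
print(study.best_trial.params)
```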
geneformer/mtl/train_utils.py ADDED
@@ -0,0 +1,126 @@
1
+ from .imports import *
2
+ from .data import preload_and_process_data, get_data_loader
3
+ from .train import objective, train_model
4
+ from .model import GeneformerMultiTask
5
+ from .utils import save_model
6
+ import random
7
+
8
+ def set_seed(seed):
9
+ random.seed(seed)
10
+ np.random.seed(seed)
11
+ torch.manual_seed(seed)
12
+ torch.cuda.manual_seed_all(seed)
13
+ torch.backends.cudnn.deterministic = True
14
+ torch.backends.cudnn.benchmark = False
15
+
16
+ def run_manual_tuning(config):
17
+ # Set seed for reproducibility
18
+ set_seed(config["seed"])
19
+
20
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list = preload_and_process_data(config)
22
+ train_loader = get_data_loader(train_dataset, config['batch_size'])
23
+ val_loader = get_data_loader(val_dataset, config['batch_size'])
24
+
25
+ # Print the manual hyperparameters being used
26
+ print("\nManual hyperparameters being used:")
27
+ for key, value in config["manual_hyperparameters"].items():
28
+ print(f"{key}: {value}")
29
+ print() # Add an empty line for better readability
30
+
31
+ # Use the manual hyperparameters
32
+ for key, value in config["manual_hyperparameters"].items():
33
+ config[key] = value
34
+
35
+ # Train the model
36
+ val_loss, trained_model = train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list)
37
+
38
+ print(f"\nValidation loss with manual hyperparameters: {val_loss}")
39
+
40
+ # Save the trained model
41
+ model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
42
+ save_model(trained_model, model_save_directory)
43
+
44
+ # Save the hyperparameters
45
+ hyperparams_to_save = {
46
+ **config["manual_hyperparameters"],
47
+ "dropout_rate": config["dropout_rate"],
48
+ "use_task_weights": config["use_task_weights"],
49
+ "task_weights": config["task_weights"],
50
+ "max_layers_to_freeze": config["max_layers_to_freeze"],
51
+ "use_attention_pooling": config["use_attention_pooling"]
52
+ }
53
+ hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
54
+ with open(hyperparams_path, 'w') as f:
55
+ json.dump(hyperparams_to_save, f)
56
+ print(f"Manual hyperparameters saved to {hyperparams_path}")
57
+
58
+ return val_loss
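run_manual_tuning reads its settings from the config dictionary that MTLClassifier (added later in this commit as geneformer/mtl_classifier.py) assembles. A hedged sketch of the keys it touches directly, with placeholder paths; the data-loading keys consumed inside preload_and_process_data are not shown exhaustively:

```python
config = {
    "seed": 42,
    "batch_size": 4,
    "epochs": 1,
    "pretrained_path": "/path/to/pretrained/model",   # placeholder
    "train_path": "/path/to/train/dataset",           # placeholder
    "val_path": "/path/to/val/dataset",               # placeholder
    "model_save_path": "/results/directory",          # placeholder
    "results_dir": "/results/directory",              # placeholder
    "tensorboard_log_dir": "/results/tblogdir",       # placeholder
    "use_wandb": False,
    "use_task_weights": True,
    "use_manual_hyperparameters": True,
    "manual_hyperparameters": {
        "learning_rate": 0.001,
        "warmup_ratio": 0.01,
        "weight_decay": 0.1,
        "dropout_rate": 0.1,
        "lr_scheduler_type": "cosine",
        "use_attention_pooling": False,
        "task_weights": [1, 1],
        "max_layers_to_freeze": 2,
    },
}
# val_loss = run_manual_tuning(config)
```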
59
+
60
+ def run_optuna_study(config):
61
+ # Set seed for reproducibility
62
+ set_seed(config["seed"])
63
+
64
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
+ train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list = preload_and_process_data(config)
66
+ train_loader = get_data_loader(train_dataset, config['batch_size'])
67
+ val_loader = get_data_loader(val_dataset, config['batch_size'])
68
+
69
+ if config["use_manual_hyperparameters"]:
70
+ train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list)
71
+ else:
72
+ objective_with_config_and_data = functools.partial(
73
+ objective,
74
+ train_loader=train_loader,
75
+ val_loader=val_loader,
76
+ train_cell_id_mapping=train_cell_id_mapping,
77
+ val_cell_id_mapping=val_cell_id_mapping,
78
+ num_labels_list=num_labels_list,
79
+ config=config,
80
+ device=device
81
+ )
82
+
83
+ study = optuna.create_study(
84
+ direction='minimize', # Minimize validation loss
85
+ study_name=config["study_name"],
86
+ #storage=config["storage"],
87
+ load_if_exists=True
88
+ )
89
+
90
+ study.optimize(
91
+ objective_with_config_and_data,
92
+ n_trials=config["n_trials"]
93
+ )
94
+
95
+ # After finding the best trial
96
+ best_params = study.best_trial.params
97
+ best_task_weights = study.best_trial.user_attrs["task_weights"]
98
+ print("Saving the best model and its hyperparameters...")
99
+
100
+ # Rebuild the best model with the best trial's hyperparameters
101
+ best_model = GeneformerMultiTask(
102
+ config["pretrained_path"],
103
+ num_labels_list,
104
+ dropout_rate=best_params["dropout_rate"],
105
+ use_task_weights=config["use_task_weights"],
106
+ task_weights=best_task_weights
107
+ )
108
+
109
+ # Get the best model state dictionary
110
+ best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
111
+
112
+ # Remove the "module." prefix from the state dictionary keys if present
113
+ best_model_state_dict = {k.replace("module.", ""): v for k, v in best_model_state_dict.items()}
114
+
115
+ # Load the modified state dictionary into the model, skipping unexpected keys
116
+ best_model.load_state_dict(best_model_state_dict, strict=False)
117
+
118
+ model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
119
+ save_model(best_model, model_save_directory)
120
+
121
+ # Additionally, save the best hyperparameters and task weights
122
+ hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
123
+
124
+ with open(hyperparams_path, 'w') as f:
125
+ json.dump({**best_params, "task_weights": best_task_weights}, f)
126
+ print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
geneformer/mtl/utils.py ADDED
@@ -0,0 +1,106 @@
1
+ from .imports import *
2
+ from sklearn.metrics import f1_score, accuracy_score
3
+ from sklearn.preprocessing import LabelEncoder
4
+ from transformers import BertModel, BertConfig, AutoConfig
5
+ import os
6
+ import shutil
7
+
8
+ def save_model(model, model_save_directory):
9
+ if not os.path.exists(model_save_directory):
10
+ os.makedirs(model_save_directory)
11
+
12
+ # Get the state dict
13
+ if isinstance(model, nn.DataParallel):
14
+ model_state_dict = model.module.state_dict() # Use model.module to access the underlying model
15
+ else:
16
+ model_state_dict = model.state_dict()
17
+
18
+ # Remove the "module." prefix from the keys if present
19
+ model_state_dict = {k.replace("module.", ""): v for k, v in model_state_dict.items()}
20
+
21
+ model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
22
+ torch.save(model_state_dict, model_save_path)
23
+
24
+ # Save the model configuration
25
+ if isinstance(model, nn.DataParallel):
26
+ model.module.config.to_json_file(os.path.join(model_save_directory, "config.json"))
27
+ else:
28
+ model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
29
+
30
+ print(f"Model and configuration saved to {model_save_directory}")
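save_model writes a plain pytorch_model.bin plus config.json; reloading the weights into a GeneformerMultiTask instance can be sketched as below. The constructor arguments mirror how run_optuna_study builds the model earlier in this commit, and the import path, paths, and label counts are assumptions/placeholders:

```python
import os
import torch
from geneformer.mtl.model import GeneformerMultiTask  # import path assumed from this package layout

model_dir = "/results/directory/GeneformerMultiTask"  # placeholder
num_labels_list = [3, 2]                               # hypothetical label counts per task

model = GeneformerMultiTask("/path/to/pretrained/model", num_labels_list, dropout_rate=0.1)
state_dict = torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model.eval()
```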
31
+
32
+ def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
33
+ task_metrics = {}
34
+ for task_name in task_true_labels.keys():
35
+ true_labels = task_true_labels[task_name]
36
+ pred_labels = task_pred_labels[task_name]
37
+ f1 = f1_score(true_labels, pred_labels, average='macro')
38
+ accuracy = accuracy_score(true_labels, pred_labels)
39
+ task_metrics[task_name] = {'f1': f1, 'accuracy': accuracy}
40
+ return task_metrics
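A quick toy check of the metric helper above, assuming it is importable from geneformer.mtl.utils:

```python
from geneformer.mtl.utils import calculate_task_specific_metrics

task_true = {"task1": [0, 1, 1, 0], "task2": [2, 2, 0, 1]}
task_pred = {"task1": [0, 1, 0, 0], "task2": [2, 1, 0, 1]}

for task, m in calculate_task_specific_metrics(task_true, task_pred).items():
    print(f"{task}: F1 macro = {m['f1']:.3f}, accuracy = {m['accuracy']:.3f}")
```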
41
+
42
+ def calculate_combined_f1(combined_labels, combined_preds):
43
+ # Initialize the LabelEncoder
44
+ le = LabelEncoder()
45
+
46
+ # Fit and transform combined labels and predictions to numerical values
47
+ le.fit(combined_labels + combined_preds)
48
+ encoded_true_labels = le.transform(combined_labels)
49
+ encoded_pred_labels = le.transform(combined_preds)
50
+
51
+ # Print out the mapping for sanity check
52
+ print("\nLabel Encoder Mapping:")
53
+ for index, class_label in enumerate(le.classes_):
54
+ print(f"'{class_label}': {index}")
55
+
56
+ # Calculate accuracy
57
+ accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
58
+
59
+ # Calculate F1 Macro score
60
+ f1 = f1_score(encoded_true_labels, encoded_pred_labels, average='macro')
61
+
62
+ return f1, accuracy
63
+
64
+ def save_model_without_heads(original_model_save_directory):
65
+ # Create a new directory for the model without heads
66
+ new_model_save_directory = original_model_save_directory + "_No_Heads"
67
+ if not os.path.exists(new_model_save_directory):
68
+ os.makedirs(new_model_save_directory)
69
+
70
+ # Load the model state dictionary
71
+ model_state_dict = torch.load(os.path.join(original_model_save_directory, "pytorch_model.bin"))
72
+
73
+ # Initialize a new BERT model without the classification heads
74
+ config = BertConfig.from_pretrained(os.path.join(original_model_save_directory, "config.json"))
75
+ model_without_heads = BertModel(config)
76
+
77
+ # Filter the state dict to exclude classification heads
78
+ model_without_heads_state_dict = {k: v for k, v in model_state_dict.items() if not k.startswith("classification_heads")}
79
+
80
+ # Load the filtered state dict into the model
81
+ model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
82
+
83
+ # Save the model without heads
84
+ model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
85
+ torch.save(model_without_heads.state_dict(), model_save_path)
86
+
87
+ # Copy the configuration file
88
+ shutil.copy(os.path.join(original_model_save_directory, "config.json"), new_model_save_directory)
89
+
90
+ print(f"Model without classification heads saved to {new_model_save_directory}")
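Because save_model_without_heads writes a standard pytorch_model.bin and config.json, the stripped-down encoder can be reloaded as a plain BertModel; a sketch, assuming the directory produced above exists:

```python
from transformers import BertModel

headless_dir = "/results/directory/GeneformerMultiTask_No_Heads"  # placeholder
encoder = BertModel.from_pretrained(headless_dir)
print(encoder.config.num_hidden_layers, "encoder layers loaded")
```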
91
+
92
+
93
+ def get_layer_freeze_range(pretrained_path):
94
+ """
95
+ Dynamically determines the number of layers to freeze based on the model depth from its configuration.
96
+ Args:
97
+ pretrained_path (str): Path to the pretrained model directory or model identifier.
98
+ Returns:
99
+ dict: A dictionary with 'min' and 'max' keys indicating the range of layers to freeze.
100
+ """
101
+ if pretrained_path:
102
+ config = AutoConfig.from_pretrained(pretrained_path)
103
+ total_layers = config.num_hidden_layers
104
+ return {"min": 0, "max": total_layers - 1}
105
+ else:
106
+ return {"min": 0, "max": 0}
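For example, a hypothetical 12-layer pretrained model would yield {"min": 0, "max": 11}, which objective() then feeds to trial.suggest_int; a brief sketch with an assumed import path and placeholder model path:

```python
from geneformer.mtl.utils import get_layer_freeze_range  # import path assumed

freeze_range = get_layer_freeze_range("/path/to/pretrained/model")  # placeholder path
print(freeze_range)  # e.g. {"min": 0, "max": 11} for a 12-layer model
# In objective(): trial.suggest_int("max_layers_to_freeze", freeze_range["min"], freeze_range["max"])
```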
geneformer/mtl_classifier.py ADDED
@@ -0,0 +1,338 @@
1
+ """
2
+ Geneformer multi-task cell classifier.
3
+
4
+ **Input data:**
5
+
6
+ | Single-cell transcriptomes as Geneformer rank value encodings with cell state labels for each task in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py). Must contain "unique_cell_id" column for logging.
7
+
8
+ **Usage:**
9
+
10
+ .. code-block :: python
11
+
12
+ >>> from geneformer import MTLClassifier
13
+ >>> mc = MTLClassifier(task_columns = ["task1", "task2"],
14
+ ... study_name = "mtl",
15
+ ... pretrained_path = "/path/pretrained/model",
16
+ ... train_path = "/path/train/set",
17
+ ... val_path = "/path/eval/set",
18
+ ... test_path = "/path/test/set",
19
+ ... model_save_path = "/results/directory/save_path",
20
+ ... trials_result_path = "/results/directory/results.txt",
21
+ ... results_dir = "/results/directory",
22
+ ... tensorboard_log_dir = "/results/tblogdir",
23
+ ... hyperparameters = hyperparameters)
24
+ >>> mc.run_optuna_study()
25
+ >>> mc.load_and_evaluate_test_model()
26
+ >>> mc.save_model_without_heads()
27
+ """
28
+
29
+ import logging
30
+ import os
31
+ from .mtl import train_utils
32
+ from .mtl import utils
33
+ from .mtl import eval_utils
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ class MTLClassifier:
39
+ valid_option_dict = {
40
+ "task_columns": {list},
41
+ "train_path": {None, str},
42
+ "val_path": {None, str},
43
+ "test_path": {None, str},
44
+ "pretrained_path": {None, str},
45
+ "model_save_path": {None, str},
46
+ "results_dir": {None, str},
47
+ "batch_size": {None, int},
48
+ "n_trials": {None, int},
49
+ "study_name": {None, str},
50
+ "max_layers_to_freeze": {None, dict},
51
+ "epochs": {None, int},
52
+ "tensorboard_log_dir": {None, str},
53
+ "use_data_parallel": {None, bool},
54
+ "use_attention_pooling": {None, bool},
55
+ "use_task_weights": {None, bool},
56
+ "hyperparameters": {None, dict},
57
+ "manual_hyperparameters": {None, dict},
58
+ "use_manual_hyperparameters": {None, bool},
59
+ "use_wandb": {None, bool},
60
+ "wandb_project": {None, str},
61
+ "gradient_clipping": {None, bool},
62
+ "max_grad_norm": {None, int, float},
63
+ "seed": {None, int},
64
+ "trials_result_path": {None, str},
65
+ }
66
+
67
+ def __init__(
68
+ self,
69
+ task_columns=None,
70
+ train_path=None,
71
+ val_path=None,
72
+ test_path=None,
73
+ pretrained_path=None,
74
+ model_save_path=None,
75
+ results_dir=None,
76
+ trials_result_path=None,
77
+ batch_size=4,
78
+ n_trials=15,
79
+ study_name="mtl",
80
+ max_layers_to_freeze=None,
81
+ epochs=1,
82
+ tensorboard_log_dir="/results/tblogdir",
83
+ use_data_parallel=False,
84
+ use_attention_pooling=True,
85
+ use_task_weights=True,
86
+ hyperparameters=None, # Default is None
87
+ manual_hyperparameters=None, # Default is None
88
+ use_manual_hyperparameters=False, # Default is False
89
+ use_wandb=False,
90
+ wandb_project=None,
91
+ gradient_clipping=False,
92
+ max_grad_norm=None,
93
+ seed=42 # Default seed value
94
+ ):
95
+
96
+ """
97
+ Initialize Geneformer multi-task classifier.
98
+ **Parameters:**
99
+ task_columns : list
100
+ | List of tasks for cell state classification
101
+ | Input data columns are labeled with corresponding task names
102
+ study_name : None, str
103
+ | Study name for labeling output files
104
+ pretrained_path : None, str
105
+ | Path to pretrained model
106
+ train_path : None, str
107
+ | Path to training dataset with task columns and "unique_cell_id" column
108
+ val_path : None, str
109
+ | Path to validation dataset with task columns and "unique_cell_id" column
110
+ test_path : None, str
111
+ | Path to test dataset with task columns and "unique_cell_id" column
112
+ model_save_path : None, str
113
+ | Path to directory to save output model (either full model or model without heads)
114
+ trials_result_path : None, str
115
+ | Path to directory to save hyperparameter tuning trial results
116
+ results_dir : None, str
117
+ | Path to directory to save results
118
+ tensorboard_log_dir : None, str
119
+ | Path to directory for Tensorboard logging results
120
+ use_data_parallel : None, bool
121
+ | Whether to use data parallelization
122
+ use_attention_pooling : None, bool
123
+ | Whether to use attention pooling
124
+ use_task_weights : None, bool
125
+ | Whether to use task weights
126
+ batch_size : None, int
127
+ | Batch size to use
128
+ n_trials : None, int
129
+ | Number of trials for hyperparameter tuning
130
+ epochs : None, int
131
+ | Number of epochs for training
132
+ max_layers_to_freeze : None, dict
133
+ | Dictionary with keys "min" and "max" giving the range (int) of initial layers to freeze during fine-tuning
134
+ | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
135
+ hyperparameters : None, dict
136
+ | Dictionary of tuning ranges for each hyperparameter (numeric "low"/"high" bounds or categorical "choices")
137
+ | For example:
138
+ | {"learning_rate": {"type": "float", "low": 1e-5, "high": 1e-3, "log": True}, "task_weights": {...}, ...}
139
+ manual_hyperparameters : None, dict
140
+ | Dictionary of manually set value for each hyperparameter
141
+ | For example:
142
+ | {"learning_rate": 0.001, "task_weights": [1, 1], ...}
143
+ use_manual_hyperparameters : None, bool
144
+ | Whether to use manually set hyperparameters
145
+ use_wandb : None, bool
146
+ | Whether to use Weights & Biases for logging
147
+ wandb_project : None, str
148
+ | Weights & Biases project name
149
+ gradient_clipping : None, bool
150
+ | Whether to use gradient clipping
151
+ max_grad_norm : None, int, float
152
+ | Maximum norm for gradient clipping
153
+ seed : None, int
154
+ | Random seed
155
+ """
156
+
157
+ self.task_columns = task_columns
158
+ self.train_path = train_path
159
+ self.val_path = val_path
160
+ self.test_path = test_path
161
+ self.pretrained_path = pretrained_path
162
+ self.model_save_path = model_save_path
163
+ self.results_dir = results_dir
164
+ self.trials_result_path = trials_result_path
165
+ self.batch_size = batch_size
166
+ self.n_trials = n_trials
167
+ self.study_name = study_name
168
+
169
+ if max_layers_to_freeze is None:
170
+ # Dynamically determine the range of layers to freeze
171
+ layer_freeze_range = utils.get_layer_freeze_range(pretrained_path)
172
+ self.max_layers_to_freeze = {"min": 1, "max": layer_freeze_range['max']}
173
+ else:
174
+ self.max_layers_to_freeze = max_layers_to_freeze
175
+
176
+ self.epochs = epochs
177
+ self.tensorboard_log_dir = tensorboard_log_dir
178
+ self.use_data_parallel = use_data_parallel
179
+ self.use_attention_pooling = use_attention_pooling
180
+ self.use_task_weights = use_task_weights
181
+ self.hyperparameters = hyperparameters if hyperparameters is not None else {
182
+ "learning_rate": {
183
+ "type": "float",
184
+ "low": 1e-5,
185
+ "high": 1e-3,
186
+ "log": True
187
+ },
188
+ "warmup_ratio": {
189
+ "type": "float",
190
+ "low": 0.005,
191
+ "high": 0.01
192
+ },
193
+ "weight_decay": {
194
+ "type": "float",
195
+ "low": 0.01,
196
+ "high": 0.1
197
+ },
198
+ "dropout_rate": {
199
+ "type": "float",
200
+ "low": 0.0,
201
+ "high": 0.7
202
+ },
203
+ "lr_scheduler_type": {
204
+ "type": "categorical",
205
+ "choices": ["cosine"]
206
+ },
207
+ "task_weights": {
208
+ "type": "float",
209
+ "low": 0.1,
210
+ "high": 2.0
211
+ }
212
+ }
213
+ self.manual_hyperparameters = manual_hyperparameters if manual_hyperparameters is not None else {
214
+ "learning_rate": 0.001,
215
+ "warmup_ratio": 0.01,
216
+ "weight_decay": 0.1,
217
+ "dropout_rate": 0.1,
218
+ "lr_scheduler_type": "cosine",
219
+ "use_attention_pooling": False,
220
+ "task_weights": [1, 1],
221
+ "max_layers_to_freeze": 2
222
+ }
223
+ self.use_manual_hyperparameters = use_manual_hyperparameters
224
+ self.use_wandb = use_wandb
225
+ self.wandb_project = wandb_project
226
+ self.gradient_clipping = gradient_clipping
227
+ self.max_grad_norm = max_grad_norm
228
+ self.seed = seed
229
+
230
+ if self.use_manual_hyperparameters:
231
+ logger.warning(
232
+ "Hyperparameter tuning is highly recommended for optimal results."
233
+ )
234
+
235
+ self.validate_options()
236
+
237
+ # set up output directories
238
+ if self.results_dir is not None:
239
+ self.trials_result_path = f"{self.results_dir}/results.txt".replace("//","/")
240
+
241
+ for output_dir in [self.model_save_path, self.results_dir]:
242
+ if not os.path.exists(output_dir):
243
+ os.makedirs(output_dir)
244
+
245
+ self.config = {key: value for key, value in self.__dict__.items() if key in self.valid_option_dict}
246
+
247
+ def validate_options(self):
248
+ # confirm arguments are within valid options and compatible with each other
249
+ for attr_name, valid_options in self.valid_option_dict.items():
250
+ attr_value = self.__dict__[attr_name]
251
+ if not isinstance(attr_value, (list, dict)):
252
+ if attr_value in valid_options:
253
+ continue
254
+ valid_type = False
255
+ for option in valid_options:
256
+ if (option in [int, float, list, dict, bool, str]) and isinstance(
257
+ attr_value, option
258
+ ):
259
+ valid_type = True
260
+ break
261
+ if valid_type:
262
+ continue
263
+ logger.error(
264
+ f"Invalid option for {attr_name}. "
265
+ f"Valid options for {attr_name}: {valid_options}"
266
+ )
267
+ raise ValueError(f"Invalid option for {attr_name}. Valid options for {attr_name}: {valid_options}")
268
+
269
+ def run_manual_tuning(self):
270
+ """
271
+ Manual hyperparameter tuning and multi-task fine-tuning of pretrained model.
272
+ """
273
+ required_variable_names = ["train_path", "val_path", "pretrained_path", "model_save_path", "results_dir"]
274
+ required_variables = [self.train_path, self.val_path, self.pretrained_path, self.model_save_path, self.results_dir]
275
+ req_var_dict = dict(zip(required_variable_names, required_variables))
276
+ self.validate_additional_options(req_var_dict)
277
+
278
+ if not self.use_manual_hyperparameters:
279
+ raise ValueError("Manual hyperparameters are not enabled. Set use_manual_hyperparameters to True.")
280
+
281
+ # Ensure manual_hyperparameters are set in the config
282
+ self.config["manual_hyperparameters"] = self.manual_hyperparameters
283
+ self.config["use_manual_hyperparameters"] = True
284
+
285
+ train_utils.run_manual_tuning(self.config)
286
+
287
+ def validate_additional_options(self, req_var_dict):
288
+ missing_variable = False
289
+ for variable_name, variable in req_var_dict.items():
290
+ if variable is None:
291
+ logger.warning(
292
+ f"Please provide value to MTLClassifier for required variable {variable_name}"
293
+ )
294
+ missing_variable = True
295
+ if missing_variable is True:
296
+ raise ValueError("Missing required variables for MTLClassifier")
297
+
298
+ def run_optuna_study(
299
+ self,
300
+ ):
301
+ """
302
+ Hyperparameter optimization and/or multi-task fine-tuning of pretrained model.
303
+ """
304
+
305
+ required_variable_names = ["train_path", "val_path", "pretrained_path", "model_save_path", "results_dir"]
306
+ required_variables = [self.train_path, self.val_path, self.pretrained_path, self.model_save_path, self.results_dir]
307
+ req_var_dict = dict(zip(required_variable_names, required_variables))
308
+ self.validate_additional_options(req_var_dict)
309
+
310
+ train_utils.run_optuna_study(self.config)
311
+
312
+ def load_and_evaluate_test_model(
313
+ self,
314
+ ):
315
+ """
316
+ Loads previously fine-tuned multi-task model and evaluates on test data.
317
+ """
318
+
319
+ required_variable_names = ["test_path", "model_save_path", "results_dir"]
320
+ required_variables = [self.test_path, self.model_save_path, self.results_dir]
321
+ req_var_dict = dict(zip(required_variable_names, required_variables))
322
+ self.validate_additional_options(req_var_dict)
323
+
324
+ eval_utils.load_and_evaluate_test_model(self.config)
325
+
326
+ def save_model_without_heads(
327
+ self,
328
+ ):
329
+ """
330
+ Save previously fine-tuned multi-task model without classification heads.
331
+ """
332
+
333
+ required_variable_names = ["model_save_path"]
334
+ required_variables = [self.model_save_path]
335
+ req_var_dict = dict(zip(required_variable_names, required_variables))
336
+ self.validate_additional_options(req_var_dict)
337
+
338
+ utils.save_model_without_heads(os.path.join(self.model_save_path, "GeneformerMultiTask"))
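Putting the class to use for the manual-tuning path, a hedged sketch mirroring the docstring example above (paths and task label columns are placeholders):

```python
from geneformer import MTLClassifier

mc = MTLClassifier(
    task_columns=["task1", "task2"],              # placeholder task label columns
    study_name="mtl_manual",
    pretrained_path="/path/to/pretrained/model",
    train_path="/path/to/train/dataset",
    val_path="/path/to/val/dataset",
    model_save_path="/results/directory",
    results_dir="/results/directory",
    tensorboard_log_dir="/results/tblogdir",
    use_manual_hyperparameters=True,
    manual_hyperparameters={
        "learning_rate": 0.001,
        "warmup_ratio": 0.01,
        "weight_decay": 0.1,
        "dropout_rate": 0.1,
        "lr_scheduler_type": "cosine",
        "use_attention_pooling": False,
        "task_weights": [1, 1],
        "max_layers_to_freeze": 2,
    },
)
mc.run_manual_tuning()
```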
geneformer/perturber_utils.py CHANGED
@@ -12,13 +12,17 @@ import pandas as pd
12
  import seaborn as sns
13
  import torch
14
  from datasets import Dataset, load_from_disk
 
15
  from transformers import (
16
  BertForMaskedLM,
17
  BertForSequenceClassification,
18
  BertForTokenClassification,
 
19
  )
20
 
21
- from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE, ENSEMBL_DICTIONARY_FILE
 
 
22
 
23
 
24
  logger = logging.getLogger(__name__)
@@ -111,17 +115,49 @@ def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
111
 
112
 
113
  # load model to GPU
114
- def load_model(model_type, num_classes, model_directory, mode):
 
 
 
 
115
  if mode == "eval":
116
  output_hidden_states = True
117
  elif mode == "train":
118
  output_hidden_states = False
119
 
120
  if model_type == "Pretrained":
121
  model = BertForMaskedLM.from_pretrained(
122
  model_directory,
123
  output_hidden_states=output_hidden_states,
124
  output_attentions=False,
 
125
  )
126
  elif model_type == "GeneClassifier":
127
  model = BertForTokenClassification.from_pretrained(
@@ -129,6 +165,7 @@ def load_model(model_type, num_classes, model_directory, mode):
129
  num_labels=num_classes,
130
  output_hidden_states=output_hidden_states,
131
  output_attentions=False,
 
132
  )
133
  elif model_type == "CellClassifier":
134
  model = BertForSequenceClassification.from_pretrained(
@@ -136,11 +173,24 @@ def load_model(model_type, num_classes, model_directory, mode):
136
  num_labels=num_classes,
137
  output_hidden_states=output_hidden_states,
138
  output_attentions=False,
 
 
 
 
 
 
 
 
 
139
  )
140
  # if eval mode, put the model in eval mode for fwd pass
141
  if mode == "eval":
142
  model.eval()
143
- model = model.to("cuda")
 
 
 
 
144
  return model
145
 
146
 
@@ -222,27 +272,47 @@ def overexpress_indices(example):
222
  indices = example["perturb_index"]
223
  if any(isinstance(el, list) for el in indices):
224
  indices = flatten_list(indices)
225
- for index in sorted(indices, reverse=True):
226
- example["input_ids"].insert(0, example["input_ids"].pop(index))
227
-
 
228
  example["length"] = len(example["input_ids"])
229
  return example
230
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
233
- def overexpress_tokens(example, max_len):
234
  # -100 indicates tokens to overexpress are not present in rank value encoding
235
  if example["perturb_index"] != [-100]:
236
  example = delete_indices(example)
237
- [
238
- example["input_ids"].insert(0, token)
239
- for token in example["tokens_to_perturb"][::-1]
240
- ]
 
 
 
 
 
 
241
 
242
  # truncate to max input size, must also truncate original emb to be comparable
243
  if len(example["input_ids"]) > max_len:
244
- example["input_ids"] = example["input_ids"][0:max_len]
245
-
 
 
246
  example["length"] = len(example["input_ids"])
247
  return example
248
 
@@ -259,6 +329,13 @@ def truncate_by_n_overflow(example):
259
  example["length"] = len(example["input_ids"])
260
  return example
261
 
 
 
 
 
 
 
 
262
 
263
  def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
264
  # indices_to_remove is list of indices to remove
@@ -392,7 +469,81 @@ def make_perturbation_batch(
392
  return perturbation_dataset, indices_to_perturb
393
 
394
 
395
- # perturbed cell emb removing the activated/overexpressed/inhibited gene emb
396
  # so that only non-perturbed gene embeddings are compared to each other
397
  # in original or perturbed context
398
  def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
@@ -589,9 +740,10 @@ def quant_cos_sims(
589
  cos = torch.nn.CosineSimilarity(dim=1)
590
 
591
  # if emb_mode == "gene", can only calculate gene cos sims
592
- # against original cell anyways
593
  if cell_states_to_model is None or emb_mode == "gene":
594
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
 
595
  elif cell_states_to_model is not None and emb_mode == "cell":
596
  possible_states = get_possible_states(cell_states_to_model)
597
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
@@ -758,4 +910,4 @@ class GeneIdHandler:
758
  return self.ens_to_symbol(self.token_to_ens(token))
759
 
760
  def symbol_to_token(self, symbol):
761
- return self.ens_to_token(self.symbol_to_ens(symbol))
 
12
  import seaborn as sns
13
  import torch
14
  from datasets import Dataset, load_from_disk
15
+ from peft import LoraConfig, get_peft_model
16
  from transformers import (
17
  BertForMaskedLM,
18
  BertForSequenceClassification,
19
  BertForTokenClassification,
20
+ BitsAndBytesConfig,
21
  )
22
 
23
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
24
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
25
+ ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
26
 
27
 
28
  logger = logging.getLogger(__name__)
 
115
 
116
 
117
  # load model to GPU
118
+ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
119
+ if model_type == "MTLCellClassifier-Quantized":
120
+ model_type = "MTLCellClassifier"
121
+ quantize = True
122
+
123
  if mode == "eval":
124
  output_hidden_states = True
125
  elif mode == "train":
126
  output_hidden_states = False
127
 
128
+ if quantize is True:
129
+ if model_type == "MTLCellClassifier":
130
+ quantize = {
131
+ "peft_config": None,
132
+ "bnb_config": BitsAndBytesConfig(
133
+ load_in_8bit=True,
134
+ )
135
+ }
136
+ else:
137
+ quantize = {
138
+ "peft_config": LoraConfig(
139
+ lora_alpha=128,
140
+ lora_dropout=0.1,
141
+ r=64,
142
+ bias="none",
143
+ task_type="TokenClassification",
144
+ ),
145
+ "bnb_config": BitsAndBytesConfig(
146
+ load_in_4bit=True,
147
+ bnb_4bit_use_double_quant=True,
148
+ bnb_4bit_quant_type="nf4",
149
+ bnb_4bit_compute_dtype=torch.bfloat16
150
+ )
151
+ }
152
+ elif quantize is False:
153
+ quantize = {"bnb_config": None}
154
+
155
  if model_type == "Pretrained":
156
  model = BertForMaskedLM.from_pretrained(
157
  model_directory,
158
  output_hidden_states=output_hidden_states,
159
  output_attentions=False,
160
+ quantization_config=quantize["bnb_config"],
161
  )
162
  elif model_type == "GeneClassifier":
163
  model = BertForTokenClassification.from_pretrained(
 
165
  num_labels=num_classes,
166
  output_hidden_states=output_hidden_states,
167
  output_attentions=False,
168
+ quantization_config=quantize["bnb_config"],
169
  )
170
  elif model_type == "CellClassifier":
171
  model = BertForSequenceClassification.from_pretrained(
 
173
  num_labels=num_classes,
174
  output_hidden_states=output_hidden_states,
175
  output_attentions=False,
176
+ quantization_config=quantize["bnb_config"],
177
+ )
178
+ elif model_type == "MTLCellClassifier":
179
+ model = BertForMaskedLM.from_pretrained(
180
+ model_directory,
181
+ num_labels=num_classes,
182
+ output_hidden_states=output_hidden_states,
183
+ output_attentions=False,
184
+ quantization_config=quantize["bnb_config"],
185
  )
186
  # if eval mode, put the model in eval mode for fwd pass
187
  if mode == "eval":
188
  model.eval()
189
+ if (quantize is False) or (quantize == {'bnb_config': None}) or (model_type == "MTLCellClassifier"):
190
+ model = model.to("cuda")
191
+ else:
192
+ model.enable_input_require_grads()
193
+ model = get_peft_model(model, quantize["peft_config"])
194
  return model
195
 
196
 
 
272
  indices = example["perturb_index"]
273
  if any(isinstance(el, list) for el in indices):
274
  indices = flatten_list(indices)
275
+ insert_pos = 0
276
+ for index in sorted(indices, reverse=False):
277
+ example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
278
+ insert_pos += 1
279
  example["length"] = len(example["input_ids"])
280
  return example
281
 
282
+ # if CLS token present, move to 1st rather than 0th position
283
+ def overexpress_indices_special(example):
284
+ indices = example["perturb_index"]
285
+ if any(isinstance(el, list) for el in indices):
286
+ indices = flatten_list(indices)
287
+ insert_pos = 1 # Insert starting after CLS token
288
+ for index in sorted(indices, reverse=False):
289
+ example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
290
+ insert_pos += 1
291
+ example["length"] = len(example["input_ids"])
292
+ return example
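To make the reordering concrete, a toy walk-through of the same pop/insert pattern in plain Python (token ids are made up; position 0 stands in for a CLS token, so insertion starts at 1 as in overexpress_indices_special):

```python
example = {"input_ids": [101, 5, 9, 42, 7, 13], "perturb_index": [3, 5]}

insert_pos = 1  # just after the CLS token
for index in sorted(example["perturb_index"], reverse=False):
    example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
    insert_pos += 1

print(example["input_ids"])  # [101, 42, 13, 5, 9, 7] -- perturbed genes moved to the front, order preserved
```

Processing indices in ascending order with an advancing insert position works because moving an earlier element toward the front leaves the positions of later target indices unchanged.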
293
 
294
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
295
+ def overexpress_tokens(example, max_len, special_token):
296
  # -100 indicates tokens to overexpress are not present in rank value encoding
297
  if example["perturb_index"] != [-100]:
298
  example = delete_indices(example)
299
+ if special_token:
300
+ [
301
+ example["input_ids"].insert(1, token)
302
+ for token in example["tokens_to_perturb"][::-1]
303
+ ]
304
+ else:
305
+ [
306
+ example["input_ids"].insert(0, token)
307
+ for token in example["tokens_to_perturb"][::-1]
308
+ ]
309
 
310
  # truncate to max input size, must also truncate original emb to be comparable
311
  if len(example["input_ids"]) > max_len:
312
+ if special_token:
313
+ example["input_ids"] = example["input_ids"][0:max_len-1]+[example["input_ids"][-1]]
314
+ else:
315
+ example["input_ids"] = example["input_ids"][0:max_len]
316
  example["length"] = len(example["input_ids"])
317
  return example
318
 
 
329
  example["length"] = len(example["input_ids"])
330
  return example
331
 
332
+ def truncate_by_n_overflow_special(example):
333
+ if example["n_overflow"] > 0:
334
+ new_max_len = example["length"] - example["n_overflow"]
335
+ example["input_ids"] = example["input_ids"][0:new_max_len-1]+[example["input_ids"][-1]]
336
+ example["length"] = len(example["input_ids"])
337
+ return example
338
+
339
 
340
  def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
341
  # indices_to_remove is list of indices to remove
 
469
  return perturbation_dataset, indices_to_perturb
470
 
471
 
472
+ def make_perturbation_batch_special(
473
+ example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
474
+ ) -> tuple[Dataset, List[int]]:
475
+ if combo_lvl == 0 and tokens_to_perturb == "all":
476
+ if perturb_type in ["overexpress", "activate"]:
477
+ range_start = 1
478
+ elif perturb_type in ["delete", "inhibit"]:
479
+ range_start = 0
480
+ range_start += 1 # Starting after the CLS token
481
+ indices_to_perturb = [
482
+ [i] for i in range(range_start, example_cell["length"][0]-1) # And excluding the EOS token
483
+ ]
484
+
485
+ # elif combo_lvl > 0 and anchor_token is None:
486
+ ## to implement
487
+ elif combo_lvl > 0 and (anchor_token is not None):
488
+ example_input_ids = example_cell["input_ids"][0]
489
+ anchor_index = example_input_ids.index(anchor_token[0])
490
+ indices_to_perturb = [
491
+ sorted([anchor_index, i]) if i != anchor_index else None
492
+ for i in range(1, example_cell["length"][0]-1) # Exclude CLS and EOS tokens
493
+ ]
494
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
495
+ else:
496
+ example_input_ids = example_cell["input_ids"][0]
497
+ indices_to_perturb = [
498
+ [example_input_ids.index(token)] if token in example_input_ids else None
499
+ for token in tokens_to_perturb
500
+ ]
501
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
502
+
503
+ # create all permutations of combo_lvl of modifiers from tokens_to_perturb
504
+ if combo_lvl > 0 and (anchor_token is None):
505
+ if tokens_to_perturb != "all":
506
+ if len(tokens_to_perturb) == combo_lvl + 1:
507
+ indices_to_perturb = [
508
+ list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
509
+ ]
510
+ else:
511
+ all_indices = [[i] for i in range(1, example_cell["length"][0]-1)] # Exclude CLS and EOS tokens
512
+ all_indices = [
513
+ index for index in all_indices if index not in indices_to_perturb
514
+ ]
515
+ indices_to_perturb = [
516
+ [[j for i in indices_to_perturb for j in i], x] for x in all_indices
517
+ ]
518
+
519
+ length = len(indices_to_perturb)
520
+ perturbation_dataset = Dataset.from_dict(
521
+ {
522
+ "input_ids": example_cell["input_ids"] * length,
523
+ "perturb_index": indices_to_perturb,
524
+ }
525
+ )
526
+
527
+ if length < 400:
528
+ num_proc_i = 1
529
+ else:
530
+ num_proc_i = num_proc
531
+
532
+ if perturb_type == "delete":
533
+ perturbation_dataset = perturbation_dataset.map(
534
+ delete_indices, num_proc=num_proc_i
535
+ )
536
+ elif perturb_type == "overexpress":
537
+ perturbation_dataset = perturbation_dataset.map(
538
+ overexpress_indices_special, num_proc=num_proc_i
539
+ )
540
+
541
+ perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
542
+
543
+ return perturbation_dataset, indices_to_perturb
544
+
545
+
546
+ # original cell emb removing the activated/overexpressed/inhibited gene emb
547
  # so that only non-perturbed gene embeddings are compared to each other
548
  # in original or perturbed context
549
  def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
 
740
  cos = torch.nn.CosineSimilarity(dim=1)
741
 
742
  # if emb_mode == "gene", can only calculate gene cos sims
743
+ # against original cell
744
  if cell_states_to_model is None or emb_mode == "gene":
745
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
746
+
747
  elif cell_states_to_model is not None and emb_mode == "cell":
748
  possible_states = get_possible_states(cell_states_to_model)
749
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
 
910
  return self.ens_to_symbol(self.token_to_ens(token))
911
 
912
  def symbol_to_token(self, symbol):
913
+ return self.ens_to_token(self.symbol_to_ens(symbol))
geneformer/pretrainer.py CHANGED
@@ -32,8 +32,6 @@ from transformers.training_args import ParallelMode
32
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
33
  from transformers.utils.generic import _is_tensorflow, _is_torch
34
 
35
- from . import TOKEN_DICTIONARY_FILE
36
-
37
  logger = logging.get_logger(__name__)
38
  EncodedInput = List[int]
39
  VERY_LARGE_INTEGER = int(
@@ -52,9 +50,6 @@ _is_torch_generator_available = False
52
  if version.parse(torch.__version__) >= version.parse("1.6"):
53
  _is_torch_generator_available = True
54
 
55
- with open(TOKEN_DICTIONARY_FILE, "rb") as f:
56
- token_dictionary = pickle.load(f)
57
-
58
 
59
  class ExplicitEnum(Enum):
60
  """
@@ -109,15 +104,7 @@ class GeneformerPreCollator(SpecialTokensMixin):
109
  super().__init__(mask_token="<mask>", pad_token="<pad>")
110
 
111
  self.token_dictionary = kwargs.get("token_dictionary")
112
- # self.mask_token = "<mask>"
113
- # self.mask_token_id = self.token_dictionary.get("<mask>")
114
- # self.pad_token = "<pad>"
115
- # self.pad_token_id = self.token_dictionary.get("<pad>")
116
  self.padding_side = "right"
117
- # self.all_special_ids = [
118
- # self.token_dictionary.get("<mask>"),
119
- # self.token_dictionary.get("<pad>"),
120
- # ]
121
  self.model_input_names = ["input_ids"]
122
 
123
  def convert_ids_to_tokens(self, value):
 
32
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
33
  from transformers.utils.generic import _is_tensorflow, _is_torch
34
 
 
 
35
  logger = logging.get_logger(__name__)
36
  EncodedInput = List[int]
37
  VERY_LARGE_INTEGER = int(
 
50
  if version.parse(torch.__version__) >= version.parse("1.6"):
51
  _is_torch_generator_available = True
52
 
 
 
 
53
 
54
  class ExplicitEnum(Enum):
55
  """
 
104
  super().__init__(mask_token="<mask>", pad_token="<pad>")
105
 
106
  self.token_dictionary = kwargs.get("token_dictionary")
 
 
 
 
107
  self.padding_side = "right"
 
 
 
 
108
  self.model_input_names = ["input_ids"]
109
 
110
  def convert_ids_to_tokens(self, value):
geneformer/token_dictionary.pkl DELETED
Binary file (788 kB)
 
geneformer/token_dictionary_gc95M.pkl CHANGED
Binary files a/geneformer/token_dictionary_gc95M.pkl and b/geneformer/token_dictionary_gc95M.pkl differ
 
generation_config.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "pad_token_id": 0,
4
+ "transformers_version": "4.37.1"
5
+ }
{geneformer-12L-30M → gf-12L-30M-i2048}/config.json RENAMED
File without changes
{geneformer-12L-30M → gf-12L-30M-i2048}/pytorch_model.bin RENAMED
File without changes
{geneformer-12L-30M → gf-12L-30M-i2048}/training_args.bin RENAMED
File without changes
gf-12L-95M-i4096/config.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.02,
6
+ "classifier_dropout": null,
7
+ "hidden_act": "relu",
8
+ "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 512,
10
+ "initializer_range": 0.02,
11
+ "intermediate_size": 1024,
12
+ "layer_norm_eps": 1e-12,
13
+ "max_position_embeddings": 4096,
14
+ "model_type": "bert",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 12,
17
+ "pad_token_id": 0,
18
+ "position_embedding_type": "absolute",
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.37.1",
21
+ "type_vocab_size": 2,
22
+ "use_cache": true,
23
+ "vocab_size": 20275
24
+ }
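This config describes a 12-layer, 512-dimensional BERT-style masked language model with a 4096-token input size and a 20,275-token vocabulary. A hedged loading sketch, assuming a local checkout with this subfolder layout (the path is a placeholder):

```python
from transformers import BertConfig, BertForMaskedLM

model_dir = "/path/to/Geneformer/gf-12L-95M-i4096"  # placeholder local path
config = BertConfig.from_pretrained(model_dir)
model = BertForMaskedLM.from_pretrained(model_dir)
print(config.num_hidden_layers, config.max_position_embeddings)  # 12 4096
```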
gf-12L-95M-i4096/generation_config.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "pad_token_id": 0,
4
+ "transformers_version": "4.37.1"
5
+ }
gf-12L-95M-i4096/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c
3
+ size 152012980
gf-12L-95M-i4096/training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d
3
+ size 4920
gf-12L-95M-i4096_CLcancer/config.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "_name_or_path": "/gladstone/theodoris/lab/pretrained_models/encoder/240402_194213_geneformer_94M_L12_emb512_SL4096_E3_B4_LR0.0005_LScosine_WU5000_Oadamw_DS8/models",
3
+ "architectures": [
4
+ "BertForMaskedLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.02,
7
+ "classifier_dropout": null,
8
+ "hidden_act": "relu",
9
+ "hidden_dropout_prob": 0.02,
10
+ "hidden_size": 512,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1024,
13
+ "layer_norm_eps": 1e-12,
14
+ "max_position_embeddings": 4096,
15
+ "model_type": "bert",
16
+ "num_attention_heads": 8,
17
+ "num_hidden_layers": 12,
18
+ "pad_token_id": 0,
19
+ "position_embedding_type": "absolute",
20
+ "torch_dtype": "float32",
21
+ "transformers_version": "4.37.1",
22
+ "type_vocab_size": 2,
23
+ "use_cache": true,
24
+ "vocab_size": 20275
25
+ }
gf-12L-95M-i4096_CLcancer/generation_config.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "pad_token_id": 0,
4
+ "transformers_version": "4.37.1"
5
+ }
gf-12L-95M-i4096_CLcancer/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2451adeed240c165634fea60ccba17063da8a2843ea9fcdcc0ce185720bf0dc2
3
+ size 152012980
gf-12L-95M-i4096_CLcancer/training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37074f3ea62a6ba0a312c38526c20c2dccbb068a2c7ee8c7c73b435dd90ab7b1
3
+ size 5048