import torch import torch.nn as nn import torch.nn.functional as F from torch.optim import SGD, lr_scheduler torch.backends.cudnn.benchmark = False # You can set it to True if you experience performance gains torch.backends.cudnn.deterministic = False from src.loss_functions.losses import AsymmetricLoss, ASLSingleLabel device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") import torch.nn.functional as F class MyCNN(nn.Module): def __init__(self, num_classes=12, dropout_prob=0.2, in_channels=3): super(MyCNN, self).__init__() self.upsample = nn.Upsample(scale_factor=2, mode='nearest') self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, padding=1) self.global_avg_pooling = nn.AdaptiveAvgPool2d(1) self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1) self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1) self.fc1 = nn.Linear(64 *3* 3, 1024) self.fc2 = nn.Linear(1024, 256) self.fc3 = nn.Linear(256, num_classes) # Dropout layers self.dropout1 = nn.Dropout(p=dropout_prob) self.dropout2 = nn.Dropout(p=dropout_prob) def forward(self, x_input): # Apply convolutional and pooling layers # x= self.upsample(x_input) x = F.leaky_relu(self.conv1(x_input)) x = F.max_pool2d(x, 2) x = F.leaky_relu(self.conv2(x)) x = F.max_pool2d(x, 2) x = F.leaky_relu(self.conv3(x)) x = F.max_pool2d(x, 2) # Flatten the output for the fully connected layers x = x.view(x.size(0), -1) x = self.dropout1(x) x = F.leaky_relu(self.fc1(x)) x = self.dropout2(x) x = F.leaky_relu(self.fc2(x)) # Apply fully connected layers x = self.dropout2(x) x = self.fc3(x) return x # Rest of the code remains unchanged # Initialize the model cell_attribute_model = MyCNN(num_classes=12, dropout_prob=0.5, in_channels=256).to(device) cell_attribute_model.train() # Set the model in training mode # Initialize optimizer, criterion, and scheduler optimizer_cell_model = torch.optim.SGD(cell_attribute_model.parameters(), lr=0.01, weight_decay=0.01) step_size = 5 gamma = 0.1 scheduler_cell_model = lr_scheduler.StepLR(optimizer_cell_model, step_size=step_size, gamma=gamma) #criterion = nn.CrossEntropyLoss() criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.08, disable_torch_grad_focal_loss=True) # criterion = ASLSingleLabel() # /num_classes = 2 #criterion = nn.BCEWithLogitsLoss() # Binary Cross-Entropy Loss def cell_training(cell_attribute_model_main,cell_datas, labels): obj_batch_size = len(cell_datas) # Set the model in training mode #optimizer_cell_model.zero_grad() # Filter out instances with label=2 and their corresponding cell_datas # Filter out rows where any element in the row (excluding the first column) is equal to 2 valid_indices = [i for i, row in enumerate(labels[:,1:]) if not torch.any(row[1:] == 2).item()] if not valid_indices: # print("No valid instances, skipping training.") object_batch_loss = torch.tensor(0.0, requires_grad=True, device=device) # Initialize as a torch.Tensor return object_batch_loss filtered_cell_datas = [cell_datas[i] for i in valid_indices] filtered_labels = labels[:,1:][valid_indices] # Assuming each element in filtered_cell_datas is a tensor of shape (in_channels, height, width) cell_images = torch.stack(filtered_cell_datas).to(device) cell_datas_batch = cell_images.squeeze(1) filtered_labels = filtered_labels.to(device) # Initialize the model with the dynamically determined in_channels # in_channels = filtered_cell_datas[0].size(1) # Assuming the first element in filtered_cell_datas defines in_channels # cell_attribute_model_main.conv1.in_channels = in_channels # Forward pass outputs_my = cell_attribute_model_main(cell_datas_batch.float()) outputs_my = outputs_my.view(len(valid_indices), -1) # Process labels to create target_tensor # label_att = filtered_labels[:, 5].float() # Assuming label[5] contains 0 or 1 # target_tensor = label_att.view(-1, 1) # Compute the loss num_classes = 2 one_hot_encoded_tensors = [] # Perform one-hot encoding for each column separately for i in range(filtered_labels.size(1)): # Extract the current column column_values = filtered_labels[:, i].long() # Generate one-hot encoded tensor for the current column one_hot_encoded_col = torch.eye(num_classes, device=filtered_labels.device)[column_values] # Reshape to match the original shape one_hot_encoded_col = one_hot_encoded_col.unsqueeze(1) one_hot_encoded_tensors.append(one_hot_encoded_col) # Concatenate the one-hot encoded tensors along the second dimension (axis=1) one_hot_encoded_result = torch.cat(one_hot_encoded_tensors, dim=1) outputs_my = outputs_my.view(outputs_my.size(0), 6,2) object_batch_loss = criterion(outputs_my, one_hot_encoded_result) # Check if the loss contains NaN if torch.isnan(object_batch_loss): object_batch_loss= 0 # If NaN, trigger a breakpoint to inspect variables breakpoint() #torch.use_deterministic_algorithms(False, warn_only=True) # Backward pass and optimization object_batch_loss = object_batch_loss/len(filtered_labels) # object_batch_loss.backward(retain_graph=True) # optimizer_cell_model.step() #scheduler_cell_model.step() # Explicitly release tensors #del cell_images, target_tensor #torch.cuda.empty_cache() return object_batch_loss