davda54 committed
Commit ff22f29 · verified · 1 Parent(s): e6ab847

Simplified the model by always computing batch-first
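In practice this means every tensor inside the model keeps the batch-first [batch, sequence, hidden] layout instead of the time-first [sequence, batch, hidden] layout used before, so the transposes at the model boundary disappear. A minimal sketch of the two layouts (variable names are illustrative, not from the file):

import torch

batch_size, seq_len, hidden_size = 2, 5, 16

# batch-first layout used after this commit
hidden_batch_first = torch.randn(batch_size, seq_len, hidden_size)   # [B, T, D]

# time-first layout used before; converting between the two needs a transpose
hidden_time_first = hidden_batch_first.transpose(0, 1)               # [T, B, D]
assert hidden_time_first.shape == (seq_len, batch_size, hidden_size)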

Files changed (1)
  1. modeling_norbert.py +42 -55
modeling_norbert.py CHANGED
@@ -101,23 +101,6 @@ class FeedForward(nn.Module):
         return self.mlp(x)
 
 
-class MaskedSoftmax(torch.autograd.Function):
-    @staticmethod
-    def forward(self, x, mask, dim):
-        self.dim = dim
-        x.masked_fill_(mask, float('-inf'))
-        x = torch.softmax(x, self.dim)
-        x.masked_fill_(mask, 0.0)
-        self.save_for_backward(x)
-        return x
-
-    @staticmethod
-    def backward(self, grad_output):
-        output, = self.saved_tensors
-        input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
-        return input_grad, None, None
-
-
 class Attention(nn.Module):
     def __init__(self, config):
         super().__init__()
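The custom MaskedSoftmax autograd Function removed above (it hand-wired its backward pass through the softmax_backward_data helper) is replaced later in the diff by a plain masked_fill followed by a standard softmax, letting autograd derive the gradient. A minimal sketch of that equivalent computation, with toy shapes and illustrative names:

import torch
import torch.nn.functional as F

scores = torch.randn(2, 4, 5, 5, requires_grad=True)   # [B, num_heads, T_q, T_k]
mask = torch.zeros(2, 1, 1, 5, dtype=torch.bool)        # True marks padded key positions
mask[..., -1] = True

masked = scores.masked_fill(mask, float('-inf'))         # out-of-place, so autograd tracks it
probs = F.softmax(masked, dim=-1)                        # padded keys get exactly zero weight
probs.sum().backward()                                    # no hand-written backward needed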
@@ -142,7 +125,7 @@ class Attention(nn.Module):
             - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
         position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
         position_indices = config.position_bucket_size - 1 + position_indices
-        self.register_buffer("position_indices", position_indices, persistent=False)
+        self.register_buffer("position_indices", position_indices.contiguous(), persistent=False)
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.scale = 1.0 / math.sqrt(3 * self.head_size)
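position_indices is registered with persistent=False, so it is rebuilt in __init__ rather than stored in checkpoints; the new line additionally calls .contiguous() on it before registration. A small sketch of how a non-persistent buffer behaves (the module below is made up for illustration):

import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # non-persistent: follows .to()/.cuda() like any buffer, but is left out of state_dict()
        self.register_buffer("position_indices", torch.arange(4), persistent=False)

module = WithBuffer()
assert "position_indices" not in module.state_dict()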
@@ -155,10 +138,11 @@ class Attention(nn.Module):
         bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
         return bucket_pos
 
-    def compute_attention_scores(self, hidden_states, relative_embedding):
-        key_len, batch_size, _ = hidden_states.size()
+    def forward(self, hidden_states, attention_mask, relative_embedding):
+        batch_size, key_len, _ = hidden_states.size()
         query_len = key_len
 
+        # Recompute position_indices if sequence length exceeds the precomputed size
         if self.position_indices.size(0) < query_len:
             position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
                 - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
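The added comment documents the fallback above: when an input is longer than the precomputed table, the relative-distance indices are rebuilt on the fly for the longer length and assigned back to the buffer. A toy sketch of that lazy-resize pattern (names and sizes are illustrative):

import torch

max_len, query_len = 4, 6
position_indices = torch.arange(max_len).unsqueeze(1) - torch.arange(max_len).unsqueeze(0)

if position_indices.size(0) < query_len:
    # rebuild the relative-distance table at the larger size, as in the hunk above
    position_indices = torch.arange(query_len).unsqueeze(1) - torch.arange(query_len).unsqueeze(0)

assert position_indices.shape == (query_len, query_len)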
@@ -166,48 +150,52 @@ class Attention(nn.Module):
             position_indices = self.config.position_bucket_size - 1 + position_indices
             self.position_indices = position_indices.to(hidden_states.device)
 
-        hidden_states = self.pre_layer_norm(hidden_states)
-
-        query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2)  # shape: [T, B, D]
-        value = self.in_proj_v(hidden_states)  # shape: [T, B, D]
+        # Pre-LN and project query/key/value.
+        hidden_states = self.pre_layer_norm(hidden_states)  # shape: [B, T, D]
+        query, key = self.in_proj_qk(hidden_states).chunk(2, dim=-1)  # shape: [B, T, D]
+        value = self.in_proj_v(hidden_states)  # shape: [B, T, D]
 
-        query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
-        key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
-        value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
+        # Reshape to [B, num_heads, T, head_size]
+        query = query.view(batch_size, query_len, self.num_heads, self.head_size).transpose(1, 2)  # shape: [B, num_heads, T_q, head_size]
+        key = key.view(batch_size, key_len, self.num_heads, self.head_size).permute(0, 2, 3, 1)  # shape: [B, num_heads, head_size, T_k]
+        value = value.view(batch_size, key_len, self.num_heads, self.head_size).transpose(1, 2)  # shape: [B, num_heads, T_k, head_size]
 
-        attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
+        # Compute relative positional contributions
+        pos = self.in_proj_qk(self.dropout(relative_embedding))  # shape: [2*position_bucket_size - 1, 2D]
+        query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2)  # shape: [2*position_bucket_size - 1, num_heads, head_size]
+        query_pos = query_pos.transpose(0, 1)  # shape: [num_heads, 2*position_bucket_size - 1, head_size]
+        key_pos = key_pos.permute(1, 2, 0)  # shape: [num_heads, head_size, 2*position_bucket_size - 1]
 
-        pos = self.in_proj_qk(self.dropout(relative_embedding))  # shape: [2T-1, 2D]
-        query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2)
-        query = query.view(batch_size, self.num_heads, query_len, self.head_size)
-        key = key.view(batch_size, self.num_heads, query_len, self.head_size)
+        # Scale the keys
+        key = key * self.scale
+        key_pos = key_pos * self.scale
 
-        attention_c_p = torch.einsum("bhqd,khd->bhqk", query, key_pos.squeeze(1) * self.scale)
-        attention_p_c = torch.einsum("bhkd,qhd->bhqk", key * self.scale, query_pos.squeeze(1))
+        # Compute standard content-to-content attention scores
+        attention_c_to_c = torch.matmul(query, key)  # shape: [B, num_heads, T_q, T_k]
 
-        position_indices = self.position_indices[:query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)
-        attention_c_p = attention_c_p.gather(3, position_indices)
-        attention_p_c = attention_p_c.gather(2, position_indices)
+        # Compute content-to-position and position-to-content attention scores
+        position_indices = self.position_indices[:query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)  # shape: [B, num_heads, T_q, T_k]
+        attention_c_to_p = torch.matmul(query, key_pos.unsqueeze(0))  # shape: [B, num_heads, T_q, 2*position_bucket_size - 1]
+        attention_p_to_c = torch.matmul(query_pos.unsqueeze(0), key)  # shape: [B, num_heads, 2*position_bucket_size - 1, T_k]
+        attention_c_to_p = attention_c_to_p.gather(3, position_indices)  # shape: [B, num_heads, T_q, T_k]
+        attention_p_to_c = attention_p_to_c.gather(2, position_indices)  # shape: [B, num_heads, T_q, T_k]
 
-        attention_scores = attention_scores.view(batch_size, self.num_heads, query_len, key_len)
-        attention_scores.add_(attention_c_p)
-        attention_scores.add_(attention_p_c)
+        # Full attention score
+        attention_scores = attention_c_to_c + attention_c_to_p + attention_p_to_c  # shape: [B, num_heads, T_q, T_k]
 
-        return attention_scores, value
+        # Masked softmax
+        attention_scores = attention_scores.masked_fill(attention_mask, float('-inf'))  # shape: [B, num_heads, T_q, T_k]
+        attention_probs = F.softmax(attention_scores, dim=-1)  # shape: [B, num_heads, T_q, T_k]
 
-    def compute_output(self, attention_probs, value):
-        attention_probs = self.dropout(attention_probs)
-        context = torch.bmm(attention_probs.flatten(0, 1), value)  # shape: [B*H, Q, D]
-        context = context.transpose(0, 1).reshape(context.size(1), -1, self.hidden_size)  # shape: [Q, B, H*D]
-        context = self.out_proj(context)
-        context = self.post_layer_norm(context)
-        context = self.dropout(context)
-        return context
+        # Collect the weighted-averaged values
+        attention_probs = self.dropout(attention_probs)  # shape: [B, num_heads, T_q, T_k]
+        output = torch.matmul(attention_probs, value)  # shape: [B, num_heads, T_q, head_size]
+        output = output.transpose(1, 2).flatten(2, 3)  # shape: [B, T_q, D]
+        output = self.out_proj(output)
+        output = self.post_layer_norm(output)
+        output = self.dropout(output)
 
-    def forward(self, hidden_states, attention_mask, relative_embedding):
-        attention_scores, value = self.compute_attention_scores(hidden_states, relative_embedding)
-        attention_probs = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
-        return self.compute_output(attention_probs, value), attention_probs.detach()
+        return output, attention_probs.detach()
 
 
 class Embedding(nn.Module):
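The new forward keeps the DeBERTa-style disentangled attention but computes it batch-first: content-to-content scores come from a single matmul, while the content-to-position and position-to-content terms are scored against every relative-position embedding and then gathered with the bucketed position_indices. A toy sketch of the gather step under simplified assumptions (a plain clamp stands in for make_log_bucket_position, and all names are illustrative):

import torch

B, H, T, head_size = 2, 4, 5, 8
bucket_size = 3
num_buckets = 2 * bucket_size - 1                          # matches 2*position_bucket_size - 1

query = torch.randn(B, H, T, head_size)
key_pos = torch.randn(H, head_size, num_buckets)           # relative-position keys

# score every query against every relative-position bucket
c_to_p_all = torch.matmul(query, key_pos.unsqueeze(0))     # [B, H, T, num_buckets]

# bucket index for each (query, key) pair; the real code uses log-bucketed distances
relative = torch.arange(T).unsqueeze(1) - torch.arange(T).unsqueeze(0)
indices = relative.clamp(-(bucket_size - 1), bucket_size - 1) + bucket_size - 1
indices = indices.expand(B, H, -1, -1)                     # [B, H, T, T]

c_to_p = c_to_p_all.gather(3, indices)                     # [B, H, T, T] content-to-position scores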
@@ -290,9 +278,8 @@ class NorbertModel(NorbertPreTrainedModel):
         attention_mask = ~attention_mask.bool()
         attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
 
-        static_embeddings, relative_embedding = self.embedding(input_ids.t())
+        static_embeddings, relative_embedding = self.embedding(input_ids)
         contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)
-        contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings]
         last_layer = contextualized_embeddings[-1]
         contextualized_embeddings = [contextualized_embeddings[0]] + [
             contextualized_embeddings[i] - contextualized_embeddings[i - 1]
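With the embedding and transformer now batch-first as well, NorbertModel passes input_ids straight through, so the input_ids.t() call and the per-layer transposes are gone. A short sketch of the call pattern implied by the hunk above (shapes only, names illustrative):

import torch

batch_size, seq_len = 2, 6
input_ids = torch.randint(0, 1000, (batch_size, seq_len))           # [B, T], used as-is
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)  # True = real token

# as in the hunk above: invert the mask and broadcast it over heads and query positions
extended_mask = (~attention_mask).unsqueeze(1).unsqueeze(2)          # [B, 1, 1, T]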
 