Simplified the model by always computing batch-first
modeling_norbert.py  CHANGED  (+42 -55)
@@ -101,23 +101,6 @@ class FeedForward(nn.Module):
         return self.mlp(x)
 
 
-class MaskedSoftmax(torch.autograd.Function):
-    @staticmethod
-    def forward(self, x, mask, dim):
-        self.dim = dim
-        x.masked_fill_(mask, float('-inf'))
-        x = torch.softmax(x, self.dim)
-        x.masked_fill_(mask, 0.0)
-        self.save_for_backward(x)
-        return x
-
-    @staticmethod
-    def backward(self, grad_output):
-        output, = self.saved_tensors
-        input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
-        return input_grad, None, None
-
-
 class Attention(nn.Module):
     def __init__(self, config):
         super().__init__()
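The class removed above existed to route the backward pass of a masked softmax through softmax_backward_data. The rewritten Attention.forward in the @@ -166,48 +150,52 @@ hunk below inlines the same masking with masked_fill and F.softmax, which plain autograd can differentiate. A minimal sketch (not part of the commit; toy shapes chosen for illustration) showing both paths give the same probabilities:

# Minimal sketch (not from the commit): the removed MaskedSoftmax forward and the
# new inline masked_fill + softmax produce the same attention probabilities.
import torch
import torch.nn.functional as F

scores = torch.randn(2, 4, 5, 5)                  # [B, num_heads, T_q, T_k]
mask = torch.zeros(2, 1, 1, 5, dtype=torch.bool)  # broadcastable padding mask
mask[..., -2:] = True                             # pretend the last two keys are padding

# Old path, as in the removed class (without its custom backward):
old_probs = torch.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
old_probs = old_probs.masked_fill(mask, 0.0)

# New path, as in the rewritten Attention.forward:
new_probs = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)

print(torch.allclose(old_probs, new_probs))  # True: softmax already sends -inf scores to 0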
@@ -142,7 +125,7 @@ class Attention(nn.Module):
             - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
         position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
         position_indices = config.position_bucket_size - 1 + position_indices
-        self.register_buffer("position_indices", position_indices, persistent=False)
+        self.register_buffer("position_indices", position_indices.contiguous(), persistent=False)
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.scale = 1.0 / math.sqrt(3 * self.head_size)
@@ -155,10 +138,11 @@ class Attention(nn.Module):
         bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
         return bucket_pos
 
-    def compute_attention_scores(self, hidden_states, relative_embedding):
-        key_len, batch_size, _ = hidden_states.size()
+    def forward(self, hidden_states, attention_mask, relative_embedding):
+        batch_size, key_len, _ = hidden_states.size()
         query_len = key_len
 
+        # Recompute position_indices if sequence length exceeds the precomputed size
         if self.position_indices.size(0) < query_len:
             position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
                 - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
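For context on the position_indices buffer that the two hunks above adjust: each entry maps a query/key offset i - j to a row of the relative-embedding table, and the position_bucket_size - 1 offset shifts the bucketed values into the non-negative range that gather expects. A minimal sketch (not from the commit; toy sizes, with a clamp standing in for make_log_bucket_position):

# Minimal sketch (not from the commit): how the relative-position index matrix is
# built; the real code applies make_log_bucket_position instead of the clamp here.
import torch

seq_len, position_bucket_size = 6, 4
relative = torch.arange(seq_len).unsqueeze(1) - torch.arange(seq_len).unsqueeze(0)
# relative[i, j] = i - j, ranging over [-(seq_len - 1), seq_len - 1]

bucketed = relative.clamp(-(position_bucket_size - 1), position_bucket_size - 1)  # stand-in for log bucketing
indices = position_bucket_size - 1 + bucketed
print(indices.min().item(), indices.max().item())  # 0 and 2*position_bucket_size - 2: valid rows for gather()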
@@ -166,48 +150,52 @@ class Attention(nn.Module):
             position_indices = self.config.position_bucket_size - 1 + position_indices
             self.position_indices = position_indices.to(hidden_states.device)
 
-
-
-        query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2)
-        value = self.in_proj_v(hidden_states)  # shape: [T, B, D]
+        # Pre-LN and project query/key/value.
+        hidden_states = self.pre_layer_norm(hidden_states)  # shape: [B, T, D]
+        query, key = self.in_proj_qk(hidden_states).chunk(2, dim=-1)  # shape: [B, T, D]
+        value = self.in_proj_v(hidden_states)  # shape: [B, T, D]
 
-
-
-
+        # Reshape to [B, num_heads, T, head_size]
+        query = query.view(batch_size, query_len, self.num_heads, self.head_size).transpose(1, 2)  # shape: [B, num_heads, T_q, head_size]
+        key = key.view(batch_size, key_len, self.num_heads, self.head_size).permute(0, 2, 3, 1)  # shape: [B, num_heads, head_size, T_k]
+        value = value.view(batch_size, key_len, self.num_heads, self.head_size).transpose(1, 2)  # shape: [B, num_heads, T_k, head_size]
 
-
+        # Compute relative positional contributions
+        pos = self.in_proj_qk(self.dropout(relative_embedding))  # shape: [2*position_bucket_size - 1, 2D]
+        query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2)  # shape: [2*position_bucket_size - 1, num_heads, head_size]
+        query_pos = query_pos.transpose(0, 1)  # shape: [num_heads, 2*position_bucket_size - 1, head_size]
+        key_pos = key_pos.permute(1, 2, 0)  # shape: [num_heads, head_size, 2*position_bucket_size - 1]
 
-
-
-
-        key = key.view(batch_size, self.num_heads, query_len, self.head_size)
+        # Scale the keys
+        key = key * self.scale
+        key_pos = key_pos * self.scale
 
-
-
+        # Compute standard content-to-content attention scores
+        attention_c_to_c = torch.matmul(query, key)  # shape: [B, num_heads, T_q, T_k]
 
-
-
-
+        # Compute content-to-position and position-to-content attention scores
+        position_indices = self.position_indices[:query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)  # shape: [B, num_heads, T_q, T_k]
+        attention_c_to_p = torch.matmul(query, key_pos.unsqueeze(0))  # shape: [B, num_heads, T_q, 2*position_bucket_size - 1]
+        attention_p_to_c = torch.matmul(query_pos.unsqueeze(0), key)  # shape: [B, num_heads, 2*position_bucket_size - 1, T_k]
+        attention_c_to_p = attention_c_to_p.gather(3, position_indices)  # shape: [B, num_heads, T_q, T_k]
+        attention_p_to_c = attention_p_to_c.gather(2, position_indices)  # shape: [B, num_heads, T_q, T_k]
 
-
-        attention_scores.add_(attention_c_p)
-        attention_scores.add_(attention_p_c)
+        # Full attention score
+        attention_scores = attention_c_to_c + attention_c_to_p + attention_p_to_c  # shape: [B, num_heads, T_q, T_k]
 
-
+        # Masked softmax
+        attention_scores = attention_scores.masked_fill(attention_mask, float('-inf'))  # shape: [B, num_heads, T_q, T_k]
+        attention_probs = F.softmax(attention_scores, dim=-1)  # shape: [B, num_heads, T_q, T_k]
 
-
-        attention_probs = self.dropout(attention_probs)
-
-
-
-
-
-        return context
+        # Collect the weighted-averaged values
+        attention_probs = self.dropout(attention_probs)  # shape: [B, num_heads, T_q, T_k]
+        output = torch.matmul(attention_probs, value)  # shape: [B, num_heads, T_q, head_size]
+        output = output.transpose(1, 2).flatten(2, 3)  # shape: [B, T_q, D]
+        output = self.out_proj(output)
+        output = self.post_layer_norm(output)
+        output = self.dropout(output)
 
-
-        attention_scores, value = self.compute_attention_scores(hidden_states, relative_embedding)
-        attention_probs = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
-        return self.compute_output(attention_probs, value), attention_probs.detach()
+        return output, attention_probs.detach()
 
 
 class Embedding(nn.Module):
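A shape check for the batch-first attention path above, written as a standalone sketch (not from the commit; random tensors stand in for the module's projections and the sizes are arbitrary). Both the content-to-content matmul and the gather over relative-position indices come out as [B, num_heads, T_q, T_k], matching the shape comments in the new code:

# Minimal sketch (not from the commit): shape check for the batch-first attention scores.
import torch

B, H, T, d, K = 2, 4, 7, 8, 16       # batch, heads, sequence length, head size, 2*bucket_size - 1
query = torch.randn(B, H, T, d)      # [B, num_heads, T_q, head_size]
key = torch.randn(B, H, d, T)        # [B, num_heads, head_size, T_k]
key_pos = torch.randn(H, d, K)       # [num_heads, head_size, 2*bucket_size - 1]
position_indices = torch.randint(0, K, (T, T)).expand(B, H, T, T)

attention_c_to_c = torch.matmul(query, key)                      # [B, H, T_q, T_k]
attention_c_to_p = torch.matmul(query, key_pos.unsqueeze(0))     # [B, H, T_q, K]
attention_c_to_p = attention_c_to_p.gather(3, position_indices)  # [B, H, T_q, T_k]

print(attention_c_to_c.shape, attention_c_to_p.shape)  # both torch.Size([2, 4, 7, 7])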
@@ -290,9 +278,8 @@ class NorbertModel(NorbertPreTrainedModel):
         attention_mask = ~attention_mask.bool()
         attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
 
-        static_embeddings, relative_embedding = self.embedding(input_ids.t())
+        static_embeddings, relative_embedding = self.embedding(input_ids)
         contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)
-        contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings]
         last_layer = contextualized_embeddings[-1]
         contextualized_embeddings = [contextualized_embeddings[0]] + [
             contextualized_embeddings[i] - contextualized_embeddings[i - 1]
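The trailing context lines of the hunk above keep the first layer's output as-is and store every later layer as a difference from the previous one. A minimal sketch (not from the commit; toy tensors and placeholder names) showing that a running sum over those deltas recovers the original per-layer hidden states:

# Minimal sketch (not from the commit): a cumulative sum over the stored deltas
# reconstructs each layer's hidden states.
import torch

layers = [torch.randn(2, 5, 8) for _ in range(4)]  # toy per-layer outputs, [B, T, D]
deltas = [layers[0]] + [layers[i] - layers[i - 1] for i in range(1, len(layers))]

recovered = torch.stack(deltas).cumsum(dim=0)      # running sum over the layer axis
print(all(torch.allclose(recovered[i], layers[i], atol=1e-5) for i in range(len(layers))))  # True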