I have been working on an implementation of the T5 architecture in PyTorch. I am having some issues properly implementing the cross-attention layers and the decoder.
If anyone who is familiar with the architecture could provide any advice, it would be greatly appreciated.
I am also sometimes receiving this error:
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)
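As far as I can tell, this error means a float-valued tensor is being passed as the indices of an nn.Embedding layer, which only accepts integer (Long/Int) token ids. A minimal, self-contained example that reproduces it (the sizes here are just for illustration):

import torch
from torch import nn

emb = nn.Embedding(512, 768)               # vocab size 512, embedding dim 768
ids = torch.randint(0, 512, (1, 1024))     # integer token ids, dtype torch.long

out = emb(ids)                             # fine: indices are Long
# out = emb(ids.float())                   # raises the RuntimeError above:
#                                          # embedding indices must be Long/Int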
Thank you!
Code for T5 in PyTorch:
import torch
from torch import nn
import torch.nn.functional as F
import math
from einops import rearrange
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# residual wrapper
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
# pre-normalization wrapper
# they use layernorm without bias
class T5LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = T5LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
# feedforward layer
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.ReLU(),
nn.Dropout(dropout), # optional dropout
nn.Linear(inner_dim, dim)
)
def forward(self, x):
return self.net(x)
# T5 relative positional bias
class T5RelativePositionBias(nn.Module):
def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 12):
super().__init__()
self.scale = scale
self.causal = causal
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
    @staticmethod
def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
ret = 0
n = -relative_position
if not causal:
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
else:
n = torch.max(n, torch.zeros_like(n))
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
def forward(self, qk_dots):
i, j, device = *qk_dots.shape[-2:], qk_dots.device
q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
k_pos = torch.arange(j, dtype = torch.long, device = device)
rel_pos = k_pos[None, :] - q_pos[:, None]
rp_bucket = self._relative_position_bucket(
rel_pos,
causal = self.causal,
num_buckets = self.num_buckets,
max_distance = self.max_distance
)
values = self.relative_attention_bias(rp_bucket)
bias = rearrange(values, 'i j h -> h i j')
return qk_dots + (bias * self.scale)
# T5 Self Attention
class T5SelfAttention(nn.Module):
def __init__(
self,
*,
dim,
heads = 12,
dim_head = 64,
causal = False,
dropout = 0.
):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.causal = causal
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_k = nn.Linear(dim, inner_dim, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
self.relative_position_bias = T5RelativePositionBias(
scale = dim_head ** -0.5,
causal = causal,
heads = heads
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
q = q * self.scale
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
sim = self.relative_position_bias(sim)
# mask
mask_value = -torch.finfo(sim.dtype).max
if mask is not None:
sim = sim.masked_fill_(~mask, mask_value)
if self.causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, mask_value)
# attention
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
# merge heads
out = rearrange(out, 'b h n d -> b n (h d)')
# combine heads and linear output
return self.to_out(out)
# T5 Cross Attention
class T5CrossAttention(nn.Module):
def __init__(
self,
*,
dim,
context_dim = None,
heads = 12,
dim_head = 64,
dropout = 0.
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_k = nn.Linear(context_dim, inner_dim, bias = False)
self.to_v = nn.Linear(context_dim, inner_dim, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
self.relative_position_bias = T5RelativePositionBias(
scale = dim_head ** -0.5,
causal = False,
heads = heads
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context, mask = None, context_mask = None):
b, n, _, h = *x.shape, self.heads
kv_input = default(context, x)
q, k, v = self.to_q(x), self.to_k(kv_input), self.to_v(kv_input)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
q = q * self.scale
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
sim = self.relative_position_bias(sim)
# mask
mask_value = -torch.finfo(sim.dtype).max
if mask is not None:
sim = sim.masked_fill_(~mask, mask_value)
if context_mask is not None:
sim = sim.masked_fill_(~context_mask[:, None, :], mask_value)
# attention
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
# merge heads
out = rearrange(out, 'b h n d -> b n (h d)')
# combine heads and linear output
return self.to_out(out)
# T5 Encoder
class T5Encoder(nn.Module):
def __init__(
self,
*,
dim,
num_tokens,
#max_seq_len,
depth,
heads = 12,
dim_head = 64,
causal = False,
mlp_mult = 4,
dropout = 0.
):
super().__init__()
self.token_emb = nn.Embedding(num_tokens, dim)
#self.pos_emb = nn.Embedding(max_seq_len, dim)
self.layer = nn.ModuleList([])
for _ in range(depth):
self.layer.append(nn.ModuleList([
Residual(PreNorm(dim, T5SelfAttention(dim = dim, heads = heads, dim_head = dim_head, causal = causal, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim = dim, mult = mlp_mult, dropout = dropout))),
]))
self.final_norm = T5LayerNorm(dim)
def forward(self, x, mask = None):
x = self.token_emb(x)
#x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
for attn, mlp in self.layer:
x = attn(x, mask = mask)
x = mlp(x)
x = self.final_norm(x)
return x
# T5 Decoder
class T5Decoder(nn.Module):
def __init__(
self,
*,
dim,
num_tokens,
#max_seq_len,
depth,
heads = 12,
dim_head = 64,
causal = True,
mlp_mult = 4,
dropout = 0.
):
super().__init__()
self.token_emb = nn.Embedding(num_tokens, dim)
#self.pos_emb = nn.Embedding(max_seq_len, dim)
self.layer = nn.ModuleList([])
for _ in range(depth):
self.layer.append(nn.ModuleList([
Residual(PreNorm(dim, T5SelfAttention(dim = dim, heads = heads, dim_head = dim_head, causal = causal, dropout = dropout))),
Residual(PreNorm(dim, T5CrossAttention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim = dim, mult = mlp_mult, dropout = dropout))),
]))
self.final_norm = T5LayerNorm(dim)
def forward(self, x, context, mask = None, context_mask = None):
x = self.token_emb(x)
#x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
for attn, cross_attn, mlp in self.layer:
x = attn(x, mask = mask)
x = cross_attn(x, context = context, mask = mask, context_mask = context_mask)
x = mlp(x)
x = self.final_norm(x)
return x
# T5
class T5(nn.Module):
def __init__(
self,
*,
dim,
#max_seq_len,
enc_num_tokens,
enc_depth,
enc_heads,
enc_dim_head,
enc_mlp_mult,
dec_num_tokens,
dec_depth,
dec_heads,
dec_dim_head,
dec_mlp_mult,
dropout = 0.,
tie_token_emb = True
):
super().__init__()
self.embedding = nn.Embedding(enc_num_tokens, dim)
#self.pos_emb = nn.Embedding(max_seq_len, dim)
self.encoder = T5Encoder(
dim = dim,
#max_seq_len = max_seq_len,
num_tokens = enc_num_tokens,
depth = enc_depth,
heads = enc_heads,
dim_head = enc_dim_head,
mlp_mult = enc_mlp_mult,
dropout = dropout
)
self.decoder = T5Decoder(
dim = dim,
#max_seq_len= max_seq_len,
num_tokens = dec_num_tokens,
depth = dec_depth,
heads = dec_heads,
dim_head = dec_dim_head,
mlp_mult = dec_mlp_mult,
dropout = dropout
)
self.to_logits = nn.Linear(dim, dec_num_tokens)
# tie weights
if tie_token_emb:
self.encoder.token_emb.weight = self.decoder.token_emb.weight
def forward(self, src, tgt, mask = None, context_mask = None):
x = self.embedding(src)
#x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
x = self.encoder(src, mask = mask)
x = self.decoder(tgt, x, mask = mask, context_mask = context_mask)
x = self.to_logits(x)
return x
if __name__ == '__main__':
from opendelta import Visualization
model = T5(
dim = 768,
#max_seq_len = 1024,
enc_num_tokens = 512,
enc_depth = 6,
enc_heads = 12,
enc_dim_head = 64,
enc_mlp_mult = 4,
dec_num_tokens = 512,
dec_depth = 6,
dec_heads = 12,
dec_dim_head = 64,
dec_mlp_mult = 4,
dropout = 0.,
tie_token_emb = True
)
src = torch.randint(0, 512, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 512, (1, 1024))
loss = model(src, tgt, mask = src_mask)
Visualization(model).structure_graph()
print(loss.shape) #torch.Size([1, 1024, 512])
Working implementation of T5 in PyTorch (the key change from the code above is that the relative position bias is dropped from the cross-attention layer; see the commented-out lines in T5CrossAttention):
import torch
from torch import nn
import torch.nn.functional as F
import math
from einops import rearrange
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# residual wrapper
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
# pre-normalization wrapper
# they use layernorm without bias
class T5LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = T5LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
# feedforward layer
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.ReLU(),
nn.Dropout(dropout), # optional dropout
nn.Linear(inner_dim, dim)
)
def forward(self, x):
return self.net(x)
# T5 relative positional bias
class T5RelativePositionBias(nn.Module):
def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 12):
super().__init__()
self.scale = scale
self.causal = causal
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
    @staticmethod
def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
ret = 0
n = -relative_position
if not causal:
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
else:
n = torch.max(n, torch.zeros_like(n))
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
def forward(self, qk_dots):
i, j, device = *qk_dots.shape[-2:], qk_dots.device
q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
k_pos = torch.arange(j, dtype = torch.long, device = device)
rel_pos = k_pos[None, :] - q_pos[:, None]
rp_bucket = self._relative_position_bucket(
rel_pos,
causal = self.causal,
num_buckets = self.num_buckets,
max_distance = self.max_distance
)
values = self.relative_attention_bias(rp_bucket)
bias = rearrange(values, 'i j h -> h i j')
return qk_dots + (bias * self.scale)
# T5 Self Attention
class T5SelfAttention(nn.Module):
def __init__(
self,
*,
dim,
heads = 12,
dim_head = 64,
causal = False,
dropout = 0.
):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.causal = causal
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_k = nn.Linear(dim, inner_dim, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
self.relative_position_bias = T5RelativePositionBias(
scale = dim_head ** -0.5,
causal = causal,
heads = heads
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
q = q * self.scale
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
sim = self.relative_position_bias(sim)
# mask
mask_value = -torch.finfo(sim.dtype).max
if mask is not None:
sim = sim.masked_fill_(~mask, mask_value)
if self.causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, mask_value)
# attention
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
# merge heads
out = rearrange(out, 'b h n d -> b n (h d)')
# combine heads and linear output
return self.to_out(out)
# T5 Cross Attention
class T5CrossAttention(nn.Module):
def __init__(
self,
*,
dim,
context_dim = None,
heads = 12,
dim_head = 64,
dropout = 0.
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_k = nn.Linear(context_dim, inner_dim, bias = False)
self.to_v = nn.Linear(context_dim, inner_dim, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
# self.relative_position_bias = T5RelativePositionBias(
# scale = dim_head ** -0.5,
# causal = False,
# heads = heads
# )
self.dropout = nn.Dropout(dropout)
def forward(self, x, context, mask = None, context_mask = None):
b, n, _, h = *x.shape, self.heads
kv_input = default(context, x)
q, k, v = self.to_q(x), self.to_k(kv_input), self.to_v(kv_input)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
q = q * self.scale
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
#sim = self.relative_position_bias(sim)
# mask
mask_value = -torch.finfo(sim.dtype).max
if mask is not None:
sim = sim.masked_fill_(~mask, mask_value)
if context_mask is not None:
sim = sim.masked_fill_(~context_mask[:, None, :], mask_value)
# attention
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
# merge heads
out = rearrange(out, 'b h n d -> b n (h d)')
# combine heads and linear output
return self.to_out(out)
# T5 Encoder
class T5Encoder(nn.Module):
def __init__(
self,
*,
dim,
num_tokens,
#max_seq_len,
depth,
heads = 12,
dim_head = 64,
causal = False,
mlp_mult = 4,
dropout = 0.
):
super().__init__()
self.token_emb = nn.Embedding(num_tokens, dim)
#self.pos_emb = nn.Embedding(max_seq_len, dim)
self.layer = nn.ModuleList([])
for _ in range(depth):
self.layer.append(nn.ModuleList([
Residual(PreNorm(dim, T5SelfAttention(dim = dim, heads = heads, dim_head = dim_head, causal = causal, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim = dim, mult = mlp_mult, dropout = dropout))),
]))
self.final_norm = T5LayerNorm(dim)
def forward(self, x, mask = None):
x = self.token_emb(x)
#x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
for attn, mlp in self.layer:
x = attn(x, mask = mask)
x = mlp(x)
x = self.final_norm(x)
return x
# T5 Decoder
class T5Decoder(nn.Module):
def __init__(
self,
*,
dim,
num_tokens,
#max_seq_len,
depth,
heads = 12,
dim_head = 64,
causal = True,
mlp_mult = 4,
dropout = 0.
):
super().__init__()
self.token_emb = nn.Embedding(num_tokens, dim)
#self.pos_emb = nn.Embedding(max_seq_len, dim)
self.layer = nn.ModuleList([])
for _ in range(depth):
self.layer.append(nn.ModuleList([
Residual(PreNorm(dim, T5SelfAttention(dim = dim, heads = heads, dim_head = dim_head, causal = causal, dropout = dropout))),
Residual(PreNorm(dim, T5CrossAttention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim = dim, mult = mlp_mult, dropout = dropout))),
]))
self.final_norm = T5LayerNorm(dim)
def forward(self, x, context, mask = None, context_mask = None):
x = self.token_emb(x)
#x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
for attn, cross_attn, mlp in self.layer:
x = attn(x, mask = mask)
x = cross_attn(x, context = context, mask = mask, context_mask = context_mask)
x = mlp(x)
x = self.final_norm(x)
return x
# T5
class T5(nn.Module):
def __init__(
self,
*,
dim,
#max_seq_len,
enc_num_tokens,
enc_depth,
enc_heads,
enc_dim_head,
enc_mlp_mult,
dec_num_tokens,
dec_depth,
dec_heads,
dec_dim_head,
dec_mlp_mult,
dropout = 0.,
tie_token_emb = True
):
super().__init__()
self.embedding = nn.Embedding(enc_num_tokens, dim)
#self.pos_emb = nn.Embedding(max_seq_len, dim)
self.encoder = T5Encoder(
dim = dim,
#max_seq_len = max_seq_len,
num_tokens = enc_num_tokens,
depth = enc_depth,
heads = enc_heads,
dim_head = enc_dim_head,
mlp_mult = enc_mlp_mult,
dropout = dropout
)
self.decoder = T5Decoder(
dim = dim,
#max_seq_len= max_seq_len,
num_tokens = dec_num_tokens,
depth = dec_depth,
heads = dec_heads,
dim_head = dec_dim_head,
mlp_mult = dec_mlp_mult,
dropout = dropout
)
self.to_logits = nn.Linear(dim, dec_num_tokens)
# tie weights
if tie_token_emb:
self.encoder.token_emb.weight = self.decoder.token_emb.weight
def forward(self, src, tgt, mask = None, context_mask = None):
x = self.embedding(src)
#x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
x = self.encoder(src, mask = mask)
x = self.decoder(tgt, x, mask = mask, context_mask = context_mask)
x = self.to_logits(x)
return x
if __name__ == '__main__':
model = T5(
dim = 768,
#max_seq_len = 1024,
enc_num_tokens = 512,
enc_depth = 6,
enc_heads = 12,
enc_dim_head = 64,
enc_mlp_mult = 4,
dec_num_tokens = 512,
dec_depth = 6,
dec_heads = 12,
dec_dim_head = 64,
dec_mlp_mult = 4,
dropout = 0.,
tie_token_emb = True
)
src = torch.randint(0, 512, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 512, (1, 1024))
loss = model(src, tgt, mask = src_mask)
print(loss.shape) #torch.Size([1, 1024, 512])
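One note on the test block above: the value assigned to loss is really the raw logits with shape (batch, tgt_len, dec_num_tokens), not a scalar loss. A minimal, hypothetical sketch of turning those logits into a cross-entropy training loss, reusing model, src, tgt and src_mask from the __main__ block (in practice the decoder input and labels would normally also be shifted by one position for teacher forcing):

    logits = model(src, tgt, mask = src_mask)         # (1, 1024, 512)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),         # (batch * seq, vocab)
        tgt.reshape(-1)                               # (batch * seq,) token ids
    )
    loss.backward()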
I'm following the guide to Transformers and the Colab project https://colab.research.google.com/drive/1XBP0Zh8K4g_n0A2p1UlGFf3dij0EX_Kt, but when I run the cell with the line multi_head = build_model() I get the error below.
This is the output from the console:
NameError                                 Traceback (most recent call last)
----> 1 multi_head = build_model()

5 frames
     40         self.dropout = Dropout(attn_dropout)
     41     def __call__(self, q, k, v, mask):
---> 42         attn = Lambda(lambda x:K.batch_dot(x[0],x[1],axes=[2,2])/self.temper)([q, k])
     43         if mask is not None:
     44             mmask = Lambda(lambda x:(-1e+10)*(1-x))(mask)

NameError: name 'K' is not defined
The failing cell runs right after the model architecture code, which is what the error refers to.
Can you see where this K should be defined?
import random, os, sys
import numpy as np
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.callbacks import *
from tensorflow.keras.initializers import *
import tensorflow as tf
from tensorflow.python.keras.layers import Layer
try:
from dataloader import TokenList, pad_to_longest
# for transformer
except: pass
embed_size = 60
class LayerNormalization(Layer):
def __init__(self, eps=1e-6, **kwargs):
self.eps = eps
super(LayerNormalization, self).__init__(**kwargs)
def build(self, input_shape):
self.gamma = self.add_weight(name='gamma', shape=input_shape[-1:],
initializer=Ones(), trainable=True)
self.beta = self.add_weight(name='beta', shape=input_shape[-1:],
initializer=Zeros(), trainable=True)
super(LayerNormalization, self).build(input_shape)
def call(self, x):
mean = K.mean(x, axis=-1, keepdims=True)
std = K.std(x, axis=-1, keepdims=True)
return self.gamma * (x - mean) / (std + self.eps) + self.beta
def compute_output_shape(self, input_shape):
return input_shape
class ScaledDotProductAttention():
def __init__(self, d_model, attn_dropout=0.1):
self.temper = np.sqrt(d_model)
self.dropout = Dropout(attn_dropout)
def __call__(self, q, k, v, mask):
attn = Lambda(lambda x:K.batch_dot(x[0],x[1],axes=[2,2])/self.temper)([q, k])
if mask is not None:
mmask = Lambda(lambda x:(-1e+10)*(1-x))(mask)
attn = Add()([attn, mmask])
attn = Activation('softmax')(attn)
attn = self.dropout(attn)
output = Lambda(lambda x:K.batch_dot(x[0], x[1]))([attn, v])
return output, attn
class MultiHeadAttention():
    # mode 0 - big matrices, faster; mode 1 - clearer implementation
def __init__(self, n_head, d_model, d_k, d_v, dropout, mode=0, use_norm=True):
self.mode = mode
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.dropout = dropout
if mode == 0:
self.qs_layer = Dense(n_head*d_k, use_bias=False)
self.ks_layer = Dense(n_head*d_k, use_bias=False)
self.vs_layer = Dense(n_head*d_v, use_bias=False)
elif mode == 1:
self.qs_layers = []
self.ks_layers = []
self.vs_layers = []
for _ in range(n_head):
self.qs_layers.append(TimeDistributed(Dense(d_k, use_bias=False)))
self.ks_layers.append(TimeDistributed(Dense(d_k, use_bias=False)))
self.vs_layers.append(TimeDistributed(Dense(d_v, use_bias=False)))
self.attention = ScaledDotProductAttention(d_model)
self.layer_norm = LayerNormalization() if use_norm else None
self.w_o = TimeDistributed(Dense(d_model))
def __call__(self, q, k, v, mask=None):
d_k, d_v = self.d_k, self.d_v
n_head = self.n_head
if self.mode == 0:
qs = self.qs_layer(q) # [batch_size, len_q, n_head*d_k]
ks = self.ks_layer(k)
vs = self.vs_layer(v)
def reshape1(x):
s = tf.shape(x) # [batch_size, len_q, n_head * d_k]
x = tf.reshape(x, [s[0], s[1], n_head, d_k])
x = tf.transpose(x, [2, 0, 1, 3])
x = tf.reshape(x, [-1, s[1], d_k]) # [n_head * batch_size, len_q, d_k]
return x
qs = Lambda(reshape1)(qs)
ks = Lambda(reshape1)(ks)
vs = Lambda(reshape1)(vs)
if mask is not None:
mask = Lambda(lambda x:K.repeat_elements(x, n_head, 0))(mask)
head, attn = self.attention(qs, ks, vs, mask=mask)
def reshape2(x):
s = tf.shape(x) # [n_head * batch_size, len_v, d_v]
x = tf.reshape(x, [n_head, -1, s[1], s[2]])
x = tf.transpose(x, [1, 2, 0, 3])
x = tf.reshape(x, [-1, s[1], n_head*d_v]) # [batch_size, len_v, n_head * d_v]
return x
head = Lambda(reshape2)(head)
elif self.mode == 1:
heads = []; attns = []
for i in range(n_head):
qs = self.qs_layers[i](q)
ks = self.ks_layers[i](k)
vs = self.vs_layers[i](v)
head, attn = self.attention(qs, ks, vs, mask)
heads.append(head); attns.append(attn)
head = Concatenate()(heads) if n_head > 1 else heads[0]
attn = Concatenate()(attns) if n_head > 1 else attns[0]
outputs = self.w_o(head)
outputs = Dropout(self.dropout)(outputs)
if not self.layer_norm: return outputs, attn
# outputs = Add()([outputs, q]) # sl: fix
return self.layer_norm(outputs), attn
class PositionwiseFeedForward():
def __init__(self, d_hid, d_inner_hid, dropout=0.1):
self.w_1 = Conv1D(d_inner_hid, 1, activation='relu')
self.w_2 = Conv1D(d_hid, 1)
self.layer_norm = LayerNormalization()
self.dropout = Dropout(dropout)
def __call__(self, x):
output = self.w_1(x)
output = self.w_2(output)
output = self.dropout(output)
output = Add()([output, x])
return self.layer_norm(output)
class EncoderLayer():
def __init__(self, d_model, d_inner_hid, n_head, d_k, d_v, dropout=0.1):
self.self_att_layer = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn_layer = PositionwiseFeedForward(d_model, d_inner_hid, dropout=dropout)
def __call__(self, enc_input, mask=None):
output, slf_attn = self.self_att_layer(enc_input, enc_input, enc_input, mask=mask)
output = self.pos_ffn_layer(output)
return output, slf_attn
def GetPosEncodingMatrix(max_len, d_emb):
pos_enc = np.array([
[pos / np.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)]
if pos != 0 else np.zeros(d_emb)
for pos in range(max_len)
])
pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2]) # dim 2i
pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2]) # dim 2i+1
return pos_enc
def GetPadMask(q, k):
ones = K.expand_dims(K.ones_like(q, 'float32'), -1)
mask = K.cast(K.expand_dims(K.not_equal(k, 0), 1), 'float32')
mask = K.batch_dot(ones, mask, axes=[2,1])
return mask
def GetSubMask(s):
len_s = tf.shape(s)[1]
bs = tf.shape(s)[:1]
mask = K.cumsum(tf.eye(len_s, batch_shape=bs), 1)
return mask
class Transformer():
def __init__(self, len_limit, embedding_matrix, d_model=embed_size, \
d_inner_hid=512, n_head=10, d_k=64, d_v=64, layers=2, dropout=0.1, \
share_word_emb=False, **kwargs):
self.name = 'Transformer'
self.len_limit = len_limit
self.src_loc_info = False # True # sl: fix later
self.d_model = d_model
self.decode_model = None
d_emb = d_model
pos_emb = Embedding(len_limit, d_emb, trainable=False, \
weights=[GetPosEncodingMatrix(len_limit, d_emb)])
i_word_emb = Embedding(max_features, d_emb, weights=[embedding_matrix]) # Add Kaggle provided embedding here
self.encoder = Encoder(d_model, d_inner_hid, n_head, d_k, d_v, layers, dropout, \
word_emb=i_word_emb, pos_emb=pos_emb)
def get_pos_seq(self, x):
mask = K.cast(K.not_equal(x, 0), 'int32')
pos = K.cumsum(K.ones_like(x, 'int32'), 1)
return pos * mask
def compile(self, active_layers=999):
src_seq_input = Input(shape=(None, ))
x = Embedding(max_features, embed_size, weights=[embedding_matrix])(src_seq_input)
# LSTM before attention layers
x = Bidirectional(LSTM(128, return_sequences=True))(x)
x = Bidirectional(LSTM(64, return_sequences=True))(x)
x, slf_attn = MultiHeadAttention(n_head=3, d_model=300, d_k=64, d_v=64, dropout=0.1)(x, x, x)
avg_pool = GlobalAveragePooling1D()(x)
max_pool = GlobalMaxPooling1D()(x)
conc = concatenate([avg_pool, max_pool])
conc = Dense(64, activation="relu")(conc)
x = Dense(1, activation="sigmoid")(conc)
self.model = Model(inputs=src_seq_input, outputs=x)
self.model.compile(optimizer = 'adam', loss = 'mean_squared_error', metrics=['accuracy'])
If you look at where K is being used, you will see calls such as:
K.expand_dims
K.cumsum
K.batch_dot
These are Keras backend functions. The code is missing an import of the Keras backend, which is conventionally abbreviated as K: from keras import backend as K (or, to stay consistent with the tensorflow.keras imports used above, from tensorflow.keras import backend as K).
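A minimal sketch of the fix, assuming the rest of the imports in the posted script stay as they are; the sanity check at the end is only illustrative:

from tensorflow.keras import backend as K   # defines the K used by K.batch_dot, K.expand_dims, K.cumsum, ...

# quick check that the backend functions referenced in the model code resolve
x = K.ones((2, 3, 4))
print(K.batch_dot(x, x, axes=[2, 2]).shape)  # (2, 3, 3)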