PyTorch 关系抽取
复现论文:Relation Classification via Multi-Level Attention CNNs
源码: https://github.com/SStarLib/myACnn
[TOC]
一、论文简介
简介:
《Relation Classification via Multi-Level Attention CNNs》这篇论文是清华刘知远老师的团队出的一篇文章,这篇文章通过基于两种attention机制的CNN来进行关系抽取任务
motivation:
句子中有些词,对实体关系分类起着重要作用。例如
“Fizzy [drinks] and meat cause heart disease and [diabetes].”
这里面的cause 对实体关系分类就有很重要的作用。通过attention机制学习到句子中比较重要的词
- 通过 input-attention 找到句子中对于entity1和entity2中比较重要的部分
- 通过 attention-pooling 找到特征图中对于构建relation embedding中重要的部分
二、模型构建
数据预处理
数据预处理应该是比较花时间的一部分。我这里做的也不好。不过很多论文使用这个数据集,可以找到别人已经处理好的数据。
构建模型
需要构建的模块大概分为:
- 输入表示
- input attention
- convolution
- attention-based pooling
- 损失函数
1. 输入表示
文本数据向量化
通过数据集构建vocab, 所谓的vocab 就是一个存储 word-index 的字典。在vocab中需要实现的功能有“增、通过token 查 index, 通过 index查token”, vocab(完整的代码点这里)
1
2
3
4
5
6
7
8
9
10
11
12
13
14class Vocabulary(object):
"""Class to process text and extract vocabulary for mapping"""
def __init__(self, token_to_idx=None):
"""
:param token_to_idx(dict): a pre_existing map of tokens to indices
"""
if token_to_idx is None:
token_to_idx = {}
self._token_to_idx = token_to_idx
self._idx_to_token = {idx: token
for token, idx in self._token_to_idx.items()}
...........将输入的句子 向量化 vectorizer
将句子中的文本数值化,生成一个索引列表。
构建数据集生成器 dataset
该类实现了 torch.utils.data.Dataset 方法,方便使用DataLoader载入数据。
模型的表示层
使用pytorch的自带的embedding函数。本论文中需要构建四个表示层。word embedding, pos1 embedding, pos2 embedding, relation embedding。
拼接词向量:
原论文中公式如上,将连续三个词沿embedding 维度拼接,目的尽可能保留序列信息。实现代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15def createWin(self, input_vec):
"""
[b_s, s_l, e_s+2*pe_s], k=win_size
:param input_vec: [b_s, s_l, e_s+2*pe_s]
:return:shape [b_s, s_l, (e_s+2*pe_s)*k]
"""
n = self.win_size
result = input_vec
input_len = input_vec.shape[1]
for i in range(1,n):
input_temp = input_vec.narrow(1,i, input_len-i) # 长度应该是 input_len - i
ten = torch.zeros((input_vec.shape[0],i, input_vec.shape[2]),dtype=torch.float)
input_temp = torch.cat((input_temp, ten),dim=1)
result=torch.cat((result,input_temp), dim=2)
return resultinput attention
这里是第一次的attention,目的是学习出每句话的部分相对于实体的重要性程度
实现代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23class Atten_input(nn.Module):
def __init__(self, config):
super(Atten_input, self).__init__()
self.config = config
def forward(self, entity1, entity2, word, z):
"""
b_s:batch_size, s_l:length of sentence,e_s:embedding size
e(n)_l: length of entity n n={1, 2}, pe_s: position embedding size
:param entity1: [b_s, e1_l, e_s]
:param entity2: [b_s, e2_l, e_s]
:param word:shape: [b_s, s_l, e_s]
:param z: [b_s, s_l, e_s+2*pe_s]
:return:
"""
# mean or sum ???
# 此处出错,应该用 似乎用 w_d 点乘 e
# 此处出错,mean()会减去一个维度, mean(dim=2)
A1 = torch.bmm(word, entity1.permute(0,2,1)).mean(dim=2) # [b_s, s_l, 1]
A2 = torch.bmm(word, entity2.permute(0,2,1)).mean(dim=2) # [b_s, s_l, 1]
alpha1 = F.softmax(A1, dim=1).unsqueeze(dim=2) # [b_s, s_l, 1]
alpha2 = F.softmax(A2, dim=1).unsqueeze(dim=2) # [b_s, s_l, 1]
R = torch.mul(z, (alpha1+alpha2)/2) # [b_s, s_l, e_s+2*pe_s]
return R卷积层
跟图像卷积的目的一样,学习出local feature map。同时feed 进下一层的attention
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19class ConvNet(nn.Module):
def __init__(self, config):
super(ConvNet, self).__init__()
self.config =config
self.in_height = self.config.kernel_size
self.in_width = \
(self.config.word_embed_size + 2 * self.config.pos_embed_size)*self.config.win_size
self.stride = (1, 1)
self.padding = (1, 0)
self.kernel_size = (self.in_height,self.in_width)
self.cnn = nn.Conv2d(1, 1000, self.kernel_size,self.stride,self.padding)
def forward(self, R):
"""
d_c =hidden_size= 1000
:param R: shape: [b_s, s_l, e_s+2*pe_s]
:return: R_star shape: [b_s,d_c, s_l]
"""
R_star = torch.tanh(self.cnn(R.unsqueeze(dim=1)).squeeze(-1))
return R_starattention pool
目的是为了在Max pooling 之前 学习 这些local feature map 相对relation 之间的重要性程度,然后使用max pooling。
![image-20200314152613820](/Users/wei/Library/Application Support/typora-user-images/image-20200314152613820.png)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21class Attn_pool(nn.Module):
def __init__(self,config):
super(Attn_pool, self).__init__()
self.config = config
self.in_dim = self.config.hidden_size
self.out_dim = 1
self.U = nn.Parameter(torch.randn(self.in_dim, self.out_dim))
self.kernel_size = self.config.max_sen_len
self.max_pool = nn.MaxPool1d(self.kernel_size, 1)
def forward(self, R_star, W_L):
"""
:param R_star: [b_s,d_c, s_l]
:param W_L: [b_s,1,rel_emb ]
:return:
"""
RU = torch.matmul(R_star.permute(0,2,1), self.U) # [b_s, s_l,1]
G = torch.matmul(RU,W_L) # [b_s, s_l,rel_emb]
AP = F.softmax(G, dim=1)
RA = torch.mul(R_star, AP.transpose(2, 1)) #[b_s,d_c, s_l]
wo = self.max_pool(RA) #[b_s,d_c, 1]
return woloss func
这篇 paper 里的单独设计了loss 函数,主要使用曼哈顿距离:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32class DistanceLoss(nn.Module):
def __init__(self, margin,rel_emb_size, rel_vocab_size,all_y, padding_idx=0):
super(DistanceLoss, self).__init__()
self._margin = margin
self.all_y = torch.from_numpy(np.array(all_y))
self.relemb = nn.Embedding(embedding_dim=rel_emb_size,
num_embeddings=rel_vocab_size,
padding_idx=padding_idx)
self.relemb.weight.requires_grad=False
def forward(self, wo, rel_weight, in_y):
"""
:param wo:[b_s,d_c, 1]
:param rel_weight:
:param in_y: [b_s, 1]
:return:
"""
self.relemb.weight=nn.Parameter(rel_weight)
wo_norm = F.normalize(wo.squeeze()) #[b_s,d_c]
wo_norm_tile = wo_norm.unsqueeze(1).repeat(1, self.all_y.shape[0], 1) # [b_s, num_rel, d_c]
rel_emb = self.relemb(in_y).squeeze(dim=1) # [b_s, rel_emb]
all_y_emb = self.relemb(self.all_y) # [b_s, num_rel, rel_emb]
y_dist=torch.norm(wo_norm - rel_emb, 2, 1) # [b_s, rel_emb]
# 求最大的错分距离
# Mask in_y
all_dist = torch.norm(wo_norm_tile - all_y_emb, 2, 2) # [b_s, num_rel, rel_emb]
one_hot_y = torch.zeros(in_y.shape[0],self.all_y.shape[0]).scatter_(1, in_y, 1)
masking_y = torch.mul(one_hot_y, 10000)
_t_dist = torch.min(torch.add(all_dist, masking_y), 1)[0]
loss = torch.mean(self._margin + y_dist - _t_dist)
return loss这里的loss 参考了别人的代码,其实代码我也没有完全理解。
三、 训练例程
Pytorch 的训练方式比较套路化。代码:https://github.com/SStarLib/myACnn/blob/master/myMain.py
四、 结尾
可能因为我分割的数据集有问题,效果出奇的好,因为我的测试集是从训练集分出来的,而且shuffle过,很大程度上数据在一个分布上。该任务的原有数据集我还没处理过,不过以后会接着复现该数据集的论文。可以下次一起处理。
另外,有没有人知道mac 下maven 明明已经下载了包,但是idea里还显示报错是什么情况,网上的各种方法都试过了,就是无效。