摘要:
合集:AI案例-NLP-传媒业
赛题:中文对话文本匹配挑战赛
主办方:科大讯飞、Datawhale
主页:https://challenge.xfyun.cn/topic/info?type=text-match&ch=j4XWs7V
AI问题:语义相似度识别
数据集:对话文本语义匹配数据集
数据集价值:根据问题识别出正确的待匹配文本,给定两个问题,判定该问题对语义是否匹配。
解决方案:bert-base-chinese模型
一、赛题描述
背景
文本匹配任务在自然语言处理中是非常重要的基础任务之一,在问答系统、智能对话等诸多应用场景起到关键性的作用,但中文对话中的文本匹配仍然存在很多难点。
任务
根据问题识别出正确的待匹配文本,给定两个问题Q,判定该问题对语义是否匹配。
类型 | 句子1 | 句子2 | 标签 (label) |
---|---|---|---|
相似文本 | 看图猜一电影名 | 看图猜电影 | 1 |
不相似文本 | 无线路由器怎么无线上网 | 无线上网卡和无线路由器怎么用 | 0 |
二、数据集描述
数据说明
赛题数据由训练集 (train.csv) 和测试集 (test.csv) 组成。
数据集版权许可协议
BY-NC-SA 4.0
https://creativecommons.org/licenses/by-nc-sa/4.0/deed.zh-hans
三、解决方案样例
工作原理介绍
工作流程
- 输入处理:将两个问题拼接为
[CLS]问题1[SEP]问题2[SEP]
的格式 - 特征提取:BERT模型提取问题的上下文相关表示
- 匹配判断:基于
[CLS]
位置的表示进行二分类(匹配/不匹配) - 损失计算:使用交叉熵损失进行优化
CLS-Classification Token
在BERT模型中,[CLS]
和[SEP]
是两种特殊的分词符号(special tokens),它们在模型处理文本时起到关键作用。
CLS含义:全称为”Classification”,位于输入序列的开头(第一个位置)。
设计目的:在预训练和微调阶段,[CLS]位置的隐藏层输出(即该位置的向量表示)被设计用来汇总整个序列的语义信息。对于分类任务(如文本匹配、情感分析等),模型会基于[CLS]的向量表示进行分类决策。
为什么能用于分类:在BERT的预训练过程中,[CLS]被强制学习全局语义信息。例如:在下一句预测(NSP)任务中,[CLS]用于判断两个句子是否连续。在微调阶段,[CLS]的向量会通过一个额外的分类层(通常是一个全连接层)输出分类结果。技术上,[CLS]的向量是模型对输入序列的”摘要”表示,包含了整个输入的综合信息。
SEP-Separator Token
SEP含义:全称为”Separator”,用于分隔不同的部分(如句子A和句子B)。
作用:在单句输入中,[SEP]标记句子的结束。在句子对任务(如文本匹配、问答)中,[SEP]分隔两个句子,帮助模型区分上下文边界。例如:[CLS] + 句子A + [SEP] + 句子B + [SEP]。
BERT-base-Chinese
bert-base-chinese
是谷歌发布的 中文预训练BERT模型,基于Transformer架构,专门针对中文文本进行优化。- 预训练:在大规模中文语料(如维基百科、新闻、书籍)上训练,学习语言的通用表示。
- 模型规模:
- Base版:12层Transformer,768隐藏维度,12个注意力头,约110M参数。
- 无
bert-large-chinese
官方版本,但社区有类似扩展模型。
- 核心特点:
- 支持中文分词(基于字或词,实际以字为单位处理)。
- 双向上下文建模(通过Masked Language Model任务)。
- 适用于多种下游任务(分类、NER、问答等)。
可能的改进
可能的改进方向为:使用更适合文本匹配的模型架构,如BertForSequenceClassification
。BertForSequenceClassification 是 Hugging Face Transformers 库中提供的一个预定义模型类,专门用于基于 BERT 的文本分类任务。它将 BERT 模型的强大语义理解能力与一个简单的分类层结合,适用于情感分析、垃圾邮件检测、意图识别等需要将文本分为不同类别的场景。
运行环境
外部库名称 | 版本号 |
---|---|
python | 3.12.8 |
sklearn-crfsuite | 0.5.0 |
transformers | 4.49.0 |
torch | 2.6.0+cu126 |
部署大语言模型在当前目录下:bert-base-chinese。
源码逻辑
源码:NLP_Chinese_Dialogue_Text_Matching_Challenge.ipynb。实现了一个基于BERT的中文问题对语义匹配系统。
1 数据准备
train_df = pd.read_csv('./data/train.csv', sep='\t', header=None, nrows=None)
test_df = pd.read_csv('./data/test.csv', sep='\t', header=None)
- 训练数据格式:
问题1\t问题2\t标签(0/1)
- 测试数据格式:
问题1\t问题2
- 数据分布:匹配(29039) vs 不匹配(20961)
数据被划分为训练集和验证集(9:1比例),使用分层抽样保持类别分布一致。
2 数据编码
tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')
train_encoding = tokenizer(list(q1_train), list(q2_train), truncation=True, padding=True, max_length=32)
- 使用BERT的分词器将中文问题对转换为模型可接受的输入格式
- 参数说明:
truncation=True
:超过最大长度的文本会被截断padding=True
:不足最大长度的文本会被填充max_length=32
:设置最大序列长度为32
3 数据集类
class QuoraDataset(Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
item['labels'] = torch.tensor(int(self.labels[idx]))
return item
- 自定义PyTorch数据集类,将编码后的数据和标签封装为适合模型训练的格式
- 返回包含
input_ids
,attention_mask
和labels
的字典
4 模型架构
model = BertForNextSentencePrediction.from_pretrained('./bert-base-chinese')
- 使用
BertForNextSentencePrediction
模型,该模型专门用于判断两个句子是否连续 - 对于语义匹配任务,可以理解为判断两个问题是否语义相关
5 训练过程
训练循环的主要步骤:
- 前向传播:将输入数据传入模型,计算损失
- 反向传播:计算梯度
- 梯度裁剪:防止梯度爆炸
- 参数更新:使用AdamW优化器更新模型参数
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs[0]
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optim.step()
6 验证评估
def flat_accuracy(preds, labels):
pred_flat = np.argmax(preds, axis=1).flatten()
labels_flat = labels.flatten()
return np.sum(pred_flat == labels_flat) / len(labels_flat)
- 计算准确率作为主要评估指标
- 模型在验证集上的准确率约为89-90%
7 预测输出
def prediciton():
model.eval()
prediction_list = []
for batch in test_dataloader:
with torch.no_grad():
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
logits = outputs[1]
pred_flat = np.argmax(logits.detach().cpu().numpy(), axis=1).flatten()
prediction_list += list(pred_flat)
return prediction_list
- 使用训练好的模型对测试数据进行预测
- 保存预测结果为CSV文件
关键点说明
- 学习率:设置为2e-5,这是BERT微调的典型学习率
- 批量大小:16,适合大多数GPU内存
- 序列长度:32,对于中文问题通常足够
- 训练轮数:3个epoch,观察到损失持续下降
运行结果
---------------- Training - Epoch: 0 ----------------
Epoth: 0, iter_num: 100, loss: 0.5201, progress: 3.55%
Epoth: 0, iter_num: 200, loss: 0.3454, progress: 7.11%
Epoth: 0, iter_num: 300, loss: 0.1842, progress: 10.66%
Epoth: 0, iter_num: 400, loss: 0.1717, progress: 14.22%
Epoth: 0, iter_num: 500, loss: 0.3616, progress: 17.77%
...
输出数据样例:submit.csv
1
1
1
1
1
1
1
1
0
1
0
0
...
源码开源协议
GPL-v3