摘要:
赛题:CCF大数据与计算智能大赛(CCF BDCI)-2020-小学数学应用题自动解题
主办方:中国计算机学会 & 题拍拍
主页:https://www.datafountain.cn/competitions/467
数据集:20,000道应用题,包含1-6年级的校内题目,简单题目居多。
数据集发布方:题拍拍
数据集价值:数据来源于K12真实教育场景,用于训练模型解数学题的逻辑推理能力。
解决方案:Graph-to-Tree (GTS) Learning 是一种结合图结构(Graph)和树形解码(Tree Decoding)的深度学习框架,主要用于解决结构化预测问题。
一、赛题描述背景
应用题阅读理解
阅读理解是近年来NLP的一个常见任务,通常要求在大段文本中理解关键信息。由于很多关键信息直接来源于文本的关键句子,所以很难衡量模型本身的”理解能力“,而机器对内容的理解是衡量AI在教育领域发展的一个重要依据。应用题包含简单的文字表述,相对密集的推理和计算,是评估机器阅读理解能力的一个重要场景。同时,应用题也是K12教研的重要组成部分,如果机器能完美的理解题意,将会给AI在教育中的发展产生巨大的想象空间。
任务
该任务是为了衡量现有机器学习模型在应用题理解方面的能力,模型读入一个应用题,输出该题的结果。为了降低任务的难度,赛题选择小学数学1-6年级校内题目。 例如: 1.商店有4框苹果,每框55千克,已经卖出135千克,还剩多少千克苹果?答案: 85千克 2.玩具厂生产了960个电子玩具,每3个装一盒,每5盒装一箱,一共装了多少箱?答案: 64箱。
二、数据集内容
数据文件为csv格式。其中训练数据中每行包括题目和答案,测试数据不包括答案。official_data/train.csv中包含12,000条训练数据; official_data/test.csv中包含8,000条测试数据。
数据集使用许可协议
BY-NC-SA 4.0
https://creativecommons.org/licenses/by-nc-sa/4.0/deed.zh-hans
数据集发布方:题拍拍
三、解决方案样例
Graph-to-Tree (GTS) Learning
Graph-to-Tree (GTS) Learning 是一种结合图结构(Graph)和树形解码(Tree Decoding)的深度学习框架,主要用于解决结构化预测问题,例如数学应用题求解(Math Word Problem Solving)、程序代码生成、语义解析等任务。在本案例 graph2tree.py
中,它特指 Quantity Graph(数值关系图)到表达式树(Expression Tree) 的转换学习,核心思想是:
- 输入:将自然语言描述的数学问题建模为数值关系图Quantity Graph,用于捕获数字、实体及其关系。
- 转换:通过神经网络(Encoder + Decoder)将图结构逐步解码为 树形表达式(如
(5 + 3) × 2
)。 - 输出:生成可计算的数学表达式树(前缀、中缀或后缀形式)。
1. 核心概念解析
(1) Quantity Graph(数值关系图)
数学问题中的数字(如 5
、2
)和实体(如 苹果
、小明
)被建模为图中的节点,它们之间的关系(如 多2个
、少了3倍
)是边。 示例:
“小明有5个苹果,吃了2个,还剩几个?”
- 节点:
小明
、5
、2
、剩余
- 边:
有(5)
、吃了(2)
、求(剩余)
显式建模问题中的数量逻辑,比纯文本(Seq2Seq)更易捕捉数学关系。
(2) Expression Tree(表达式树)
数学表达式的树形表示,其中:
- 叶子节点:数字或变量(如
5
、x
)。 - 非叶子节点:操作符(如
+
、×
)。 示例:(5 - 2)
的树形表示:
-
/ \
5 2
树结构能直接映射到可计算的数学表达式,避免生成不合法的表达式(如 5 + × 3
)。
2. Graph-to-Tree (GTS) 的工作流程
在 graph2tree.py
中,具体流程如下:
步骤 1:图编码(Graph Encoding)
- 输入:自然语言问题 → 转换为
Quantity Graph
。 - 编码器:
EncoderSeq
(LSTM/GNN)将图和文本编码为隐藏状态。 encoder = EncoderSeq(input_size, embedding_size, hidden_size, n_layers)
步骤 2:树解码(Tree Decoding)
解码器逐步生成表达式树,依赖三个关键模块:
- Prediction:预测当前步骤的操作符(如
+
、-
)。 predict = Prediction(hidden_size, op_nums, input_size) - GenerateNode:生成数字或子表达式节点。 generate = GenerateNode(hidden_size, op_nums, embedding_size)
- Merge:合并子树信息(如左子树
5
和右子树2
合并为5 - 2
)。 merge = Merge(hidden_size, embedding_size)
步骤 3:训练与优化
- 损失函数:交叉熵损失(Cross-Entropy)用于操作符和节点的预测。
- 优化器:
Adam
分模块优化(Encoder、Prediction、Generate、Merge)。 optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay)
3. 为什么用 GTS 而不是 Seq2Seq?
方面 | Graph-to-Tree (GTS) | Seq2Seq |
---|---|---|
输入表示 | 显式建模数值关系图(Quantity Graph) | 纯文本序列 |
输出结构 | 生成合法的表达式树 | 可能生成非法表达式(如 5 + × ) |
可解释性 | 树结构可直观分解逻辑 | 黑箱序列 |
适用任务 | 数学应用题、代码生成、语义解析 | 机器翻译、文本摘要 |
优势:
- 精准建模数学逻辑:通过图结构捕获问题中的数量关系。
- 保证输出合法性:树解码避免无效表达式。
- 可扩展性:可结合GNN(图神经网络)增强关系推理。
4. 典型应用场景
- 数学应用题求解(如
graph2tree.py
的任务):- 输入:
"甲有5本书,乙比甲多3本,求乙的书数"
- 输出:
(+ 5 3)
(前缀表达式树)。
- 输入:
- 程序代码生成:
- 输入:
"循环打印1到10的数字"
- 输出:抽象语法树(AST)。
- 输入:
- 语义解析(Semantic Parsing):
- 输入:
"查找价格低于100元的商品"
- 输出:SQL查询树。
- 输入:
汉语言处理包HanLP
“汉语言处理包”(Han Language Processing)简称HanLP。HanLP 是一个由中国开发者何晗(hankcs)于 2014 年开发的自然语言处理库。是一个功能丰富、同时适用于轻量级和海量级的场景的多语种自然语言处理工具包,具备高准确性、高效率、使用最新的语料库、清晰地架构和可定制性等特点。HanLP 具有丰富的功能,可以进行一系列文本分析任务,比如词法分析(分词、词性标注、命名实体识别)、句法分析、文本分类/聚类、信息抽取、语义分析等。
我们使用开源工具hanlp的依存句法分析工具,对每个问题文本生成其依存句法树。
数据输入文件:
- 训练集:official_data/train.csv
- 测试集:official_data/test.csv
生成文件:
- 训练集:./group_num_processed/train_corrected_processed_eda.json
- 测试集:./group_num_processed/test_processed.json
数据样例如下:
{
"id": "0",
"doc_index_originality": "original",
"original_text": "食堂运来105千克的萝卜,运来的青菜是萝卜的3倍,运来青菜多少千克?",
"cleaned_text": "食堂运来105千克的萝卜,运来的青菜是萝卜的3倍,运来青菜多少千克?",
"segmented_text_new": "食堂 运来 105 千克 的 萝卜 , 运来 的 青菜 是 萝卜 的 3 倍 , 运来 青菜 多 少 千克 ?",
"num_list": ["105", "3"],
"group_num": [0, 1, 2, 3, 4, 5, 10, 11, 13, 14],
"group_num_segs": [
["食堂", "运来", "105", "千克", "的", "萝卜"],
["是", "萝卜", "3", "倍"]
],
"equation": "x=105*3",
"equation_crawled": "x=105*3",
"equation_crawled_for_eval": "105*3",
"ans": "315",
"ans_gold_for_eval": "315"
}
各个字段的说明如下:
- id: “0” 样本的唯一标识符。
- doc_index_originality: “original” 表示文档的原创性,这里是原始文本(”original”),可能是相对于修改或增强后的版本而言。
- original_text: “食堂运来105千克的萝卜…” 原始的数学应用题文本。
- cleaned_text: “食堂运来105千克的萝卜…” 清理后的文本,可能与原始文本相同,表示无需清理。
- segmented_text_new: “食堂 运来 105 千克 的 萝卜…” 分词后的文本,使用空格分隔词语(如”食堂”、”运来”、”105″等)。
- num_list: [“105”, “3”] 数学题目中出现的数值列表(如”105″和”3″),后续将应用于数值计算。
- group_num: [0, 1, 2, 3, 4, 5, 10, 11, 13, 14] 表示在
segmented_text_new
中哪些位置的词语属于数值相关的词语。这些索引指向的词语可能与词语”105″或”3″的依存关系相关。例如:- 0: “食堂”
- 1: “运来”
- 2: “105”
- …
- group_num_segs: [
[“食堂”, “运来”, “105”, “千克”, “的”, “萝卜”],
[“是”, “萝卜”, “3”, “倍”]
]- 从
group_num
提取的具体词语分组,每个子列表是一个数值相关的短语。例如:- 第一组:与数字”105″相关的上下文(”食堂运来105千克的萝卜”)。
- 第二组:与数字”3″相关的上下文(”是萝卜的3倍”)。
- 从
- equation: “x=105*3” 解题的方程,这里表示青菜的重量是萝卜的3倍(105 × 3)。
- equation_crawled: “x=105*3” (与
equation
相同)。 - equation_crawled_for_eval: “105*3” 用于评估的简化方程(去除了”x=”部分)。
- ans: “315” 计算得到的答案(105 × 3 = 315)。标准答案,用于模型训练。
- ans_gold_for_eval: “315” 标准答案,用于模型评估(与
ans
一致)。
安装开发包
库名称 | 版本号 |
---|---|
python | 3.12.8 |
tqdm | 4.67.1 |
treelib | 1.7.1 |
torch | 2.5.1 |
treelib 是一个轻量级的 Python 库,用于创建、操作和可视化树形结构(Tree)。它提供了简单易用的 API,适合处理层次化数据。
GTS处理流程和执行结果
主要源码入口:graph2tree.py。这段代码实现了一个名为”Graph2Tree”的模型,主要用于将数学应用题文本转换为数学表达式树。输入数据文件为:
- 训练数据:group_num_processed/train_corrected_processed_eda.json
- 测试数据:group_num_processed/test_processed.json。
输出文件为预测结果:predicted_list.pkl
执行方法:python graph2tree.py
1. 初始化和配置
- 导入必要的库和模块
- 设置随机种子保证可重复性
- 创建必要的目录(models和prepared_data)
import json
import os
import pickle
import pprint
import random
import time
from tqdm import tqdm, trange
from graph_to_tree_utils import (TRAIN_ITERATION_OVER, USE_EDA, batch_size,
beam_size, embedding_size, hidden_size,
learning_rate, n_epochs, n_layers,
weight_decay)
from src.expressions_transfer import *
from src.models import *
from src.train_and_evaluate import *
if not os.path.exists('models'):
os.system('mkdir models')
if not os.path.exists('prepared_data'):
os.system('mkdir prepared_data')
从配置文件(graph_to_tree_utils.py)中加载各种超参数(如batch_size, hidden_size等)
配置数据:graph_to_tree_utils.py
# Graph2Tree_submit/competition/data_process_group_num.py
# 中的配置:
# 用于切换处理比赛的训练、测试集
# TARGET_FILE_GROUP_NUM = 'train_corrected'
TARGET_FILE_GROUP_NUM = 'test'
# Graph2Tree_submit/competition/graph2tree.py
# 中的配置:
# 人工干预模型是否训练结束,True则确定开始预测,False则继续训练(如果训练轮数已经完成,则开始预测)
TRAIN_ITERATION_OVER = True # True
# 每个问题至多使用的增强数据数量,<=0 表示不使用增强的数据,仅使用原始数据
USE_EDA = -1
# 模型训练超参数
batch_size = 64
embedding_size = 128
hidden_size = 512
n_epochs = 20 # 80
learning_rate = 1e-3
weight_decay = 1e-5
beam_size = 5
n_layers = 2
2. 数据准备阶段
- 加载训练数据(
train_corrected_processed_eda.json
) 到变量data_train
- 加载并预处理测试数据(处理分数格式,添加括号等)
- 使用
transfer_num
函数转换数据中的数字 - 将中缀表达式转换为前缀表达式
- 将数据分为训练集和测试集
- 准备输入和输出的语言模型(
input_lang
,output_lang
) - 保存预处理后的数据到pickle文件
if (not os.path.exists('prepared_data/generate_num_ids.pkl')) and (not os.path.exists('prepared_data/input_lang.pkl')):
data_train = load_raw_data(
"group_num_processed/train_corrected_processed_eda.json", raw=False, operator_no_gt_thresh=4, use_eda=USE_EDA)
test_file_name = 'test'
if (not os.path.exists('prepared_data/generate_num_ids.pkl')) and (not os.path.exists('prepared_data/input_lang.pkl')):
data_test = []
# 路径根据实际情况调整
target_file = f'./group_num_processed/{test_file_name}_processed.json'
for doc_index, data_d in tqdm(enumerate(open(target_file, encoding='utf-8'))):
data_d = json.loads(data_d)
# 初始化找到的分数集合
fraction_num_set = set()
# 将segmented_text_new转为列表
segmented_text_new_list = data_d["segmented_text_new"].split(
' ')
# 遍历segmented_text_new_list列表,发现分数的直接加括号,并将发现的分数添加到上述集合中
for i_, item in enumerate(segmented_text_new_list):
if ('/' in item) and (item[0] in digits) and (item[-1] in digits):
fraction_num_set.add(item)
segmented_text_new_list[i_] = f'({item})'
# 将segmented_text_new重新组合成字符串
data_d["segmented_text_new"] = (
' '.join(segmented_text_new_list)).strip()
# 将发现的分数集合转成列表,并排序(倒序)
fraction_num_list = list(fraction_num_set)
fraction_num_list.sort(key=lambda x: len(x), reverse=True)
# 这里之所以不直接对发现的分数进两端加括号,是为了防止出现,如果题中同时有1/5和11/5,11/5变为(11/5)之后又变为(1(1/5))
fraction_num_list_ext = []
for item in fraction_num_list:
numerator = item.split('/')[0]
denominator = item.split('/')[1]
fraction_num_list_ext.append(
(item, f'({numerator}圀/圀{denominator})'))
for fraction_num, fraction_num_ext in fraction_num_list_ext:
data_d["cleaned_text"] = data_d["cleaned_text"].replace(
fraction_num, fraction_num_ext)
# 再把多余的占位符 圀 去掉
data_d["cleaned_text"] = data_d["cleaned_text"].replace('圀', '')
# 进行覆盖
data_d["original_text"] = data_d["cleaned_text"]
data_d["segmented_text"] = data_d["segmented_text_new"]
data_d["equation"] = ''
# 因为测试数据都不知道答案,所以此处都给-1000
data_d["ans"] = -1000
data_test.append(data_d)
print('what the first 10 items of data_test is:')
pprint.pprint(data_test[:10])
print('and its original length is:')
pprint.pprint(len(data_test))
print()
if (not os.path.exists('prepared_data/generate_num_ids.pkl')) and (not os.path.exists('prepared_data/input_lang.pkl')):
pairs, generate_nums, copy_nums = transfer_num(
data_train + data_test, raw=False)
print('what the first 10 items of pairs are:')
pprint.pprint(pairs[:10])
print('and the last item of pairs is:')
pprint.pprint(pairs[-1])
print('and its original length is:')
pprint.pprint(len(pairs))
print()
pickle.dump(generate_nums, open('prepared_data/generate_nums.pkl', 'wb'))
pickle.dump(copy_nums, open('prepared_data/copy_nums.pkl', 'wb'))
else:
generate_nums = pickle.load(open('prepared_data/generate_nums.pkl', 'rb'))
copy_nums = pickle.load(open('prepared_data/copy_nums.pkl', 'rb'))
print(f'generate_nums loaded!\n')
print(f'copy_nums loaded!\n')
运行结果展示:
generate_nums loaded!
copy_nums loaded!
what the generate_nums are:
['1',
'20%',
'2',
'10',
'3.14',
'36',
'100',
'10000',
'0.5',
'3',
'80%',
'100%',
'5',
'85%',
'100000',
'1000',
'90%',
'15',
'17',
'60',
'7',
'4',
'6',
'32',
'8',
'23',
'20',
'70',
'9',
'12',
'14',
'48',
'28',
'47',
'24',
'80',
'0.2',
'50',
'35',
'30',
'2000',
'18',
'16',
'1.5',
'0.3',
'50%',
'200',
'25',
'31',
'95%',
'3000',
'2.5',
'250',
'22',
'240',
'360',
'400',
'0.8',
'11']
and its original length is:
59
what the copy_nums are:
14
prepared data loaded!
generated num ids loaded!
3. 初始化模型
模型的定义在源码:src\models.py
- 初始化四个主要组件:
EncoderSeq
: 序列编码器Prediction
: 预测模块GenerateNode
: 节点生成模块Merge
: 合并模块
- 为每个组件创建优化器和学习率调度器
- 如果可用,将模型移到GPU
# Initialize models
encoder = EncoderSeq(input_size=input_lang.n_words, embedding_size=embedding_size, hidden_size=hidden_size,
n_layers=n_layers)
predict = Prediction(hidden_size=hidden_size, op_nums=output_lang.n_words - copy_nums - 1 - len(generate_nums),
input_size=len(generate_nums))
generate = GenerateNode(hidden_size=hidden_size, op_nums=output_lang.n_words - copy_nums - 1 - len(generate_nums),
embedding_size=embedding_size)
merge = Merge(hidden_size=hidden_size, embedding_size=embedding_size)
# the embedding layer is only for generated number embeddings, operators, and paddings
encoder_optimizer = torch.optim.Adam(
encoder.parameters(), lr=learning_rate, weight_decay=weight_decay)
predict_optimizer = torch.optim.Adam(
predict.parameters(), lr=learning_rate, weight_decay=weight_decay)
generate_optimizer = torch.optim.Adam(
generate.parameters(), lr=learning_rate, weight_decay=weight_decay)
merge_optimizer = torch.optim.Adam(
merge.parameters(), lr=learning_rate, weight_decay=weight_decay)
encoder_scheduler = torch.optim.lr_scheduler.StepLR(
encoder_optimizer, step_size=20, gamma=0.5)
predict_scheduler = torch.optim.lr_scheduler.StepLR(
predict_optimizer, step_size=20, gamma=0.5)
generate_scheduler = torch.optim.lr_scheduler.StepLR(
generate_optimizer, step_size=20, gamma=0.5)
merge_scheduler = torch.optim.lr_scheduler.StepLR(
merge_optimizer, step_size=20, gamma=0.5)
# Move models to GPU
if USE_CUDA:
encoder.cuda()
predict.cuda()
generate.cuda()
merge.cuda()
4. 训练阶段
- 检查是否有已保存的模型和训练记录
- 按epoch进行训练循环:
- 准备训练批次数据
- 使用
train_tree
函数训练模型 - 更新学习率
- 定期保存模型(每10个epoch)
- 训练完成后保存最终模型
if not train_iteration_over:
if os.path.exists('training_memo.pkl'):
training_memo = pickle.load(open('training_memo.pkl', 'rb'))
trained_epoches = training_memo['trained_epoches']
else:
trained_epoches = -1
if os.path.exists('models/encoder'):
encoder.load_state_dict(torch.load("models/encoder"))
predict.load_state_dict(torch.load("models/predict"))
generate.load_state_dict(torch.load("models/generate"))
merge.load_state_dict(torch.load("models/merge"))
print('models loaded before entering the training iteration! Now the memo is:')
print('training_memo: ', training_memo)
print()
for epoch in range(n_epochs):
if epoch <= trained_epoches:
continue
loss_total = 0
input_batches, input_lengths, output_batches, output_lengths, nums_batches, \
num_stack_batches, num_pos_batches, num_size_batches, num_value_batches, graph_batches = prepare_train_batch(
train_pairs, batch_size)
print("epoch:", epoch + 1)
start = time.time()
for idx in trange(len(input_lengths)):
try:
loss = train_tree(
input_batches[idx], input_lengths[idx], output_batches[idx], output_lengths[idx],
num_stack_batches[idx], num_size_batches[idx], generate_num_ids, encoder, predict, generate, merge,
encoder_optimizer, predict_optimizer, generate_optimizer, merge_optimizer, output_lang, num_pos_batches[idx], graph_batches[idx])
except BaseException as e:
if not isinstance(e, KeyboardInterrupt):
print(f'idx: {idx}')
print('input_batches[idx]:')
pprint.pprint(input_batches[idx])
print()
input_batches_2_words = list(
map(lambda x: list(map(lambda y: input_lang.index2word[y], x)), input_batches[idx]))
print('input_batches_2_words:')
pprint.pprint(input_batches_2_words)
print()
print('input_lengths[idx]:')
pprint.pprint(input_lengths[idx])
print()
print('output_batches[idx]:')
pprint.pprint(output_batches[idx])
print()
output_batches_2_words = list(
map(lambda x: list(map(lambda y: output_lang.index2word[y], x)), output_batches[idx]))
print('output_batches_2_words:')
pprint.pprint(output_batches_2_words)
print()
print('output_lengths[idx]:')
pprint.pprint(output_lengths[idx])
print()
raise Exception('Some error happened!\n')
loss_total += loss
encoder_scheduler.step()
predict_scheduler.step()
generate_scheduler.step()
merge_scheduler.step()
print("loss:", loss_total / len(input_lengths))
print("training time", time_since(time.time() - start))
print("--------------------------------")
if (epoch+1) % 10 == 0:
torch.save(encoder.state_dict(), "models/encoder")
torch.save(predict.state_dict(), "models/predict")
torch.save(generate.state_dict(), "models/generate")
torch.save(merge.state_dict(), "models/merge")
print(f'models saved for epoch {epoch+1}!\n')
training_memo = {
'competiton_data_source': test_file_name,
'trained_epoches': epoch,
}
if os.path.exists('training_memo.pkl'):
os.system(f'rm -rf training_memo.pkl')
pickle.dump(training_memo, open('training_memo.pkl', 'wb'))
# 训练结束后最后一次保存
torch.save(encoder.state_dict(), "models/encoder")
torch.save(predict.state_dict(), "models/predict")
torch.save(generate.state_dict(), "models/generate")
torch.save(merge.state_dict(), "models/merge")
print(
f'models saved after the training iteration over, now the epoch is {epoch+1}!\n')
training_memo = {
'competiton_data_source': test_file_name,
'trained_epoches': epoch,
}
if os.path.exists('training_memo.pkl'):
os.system(f'rm -rf training_memo.pkl')
pickle.dump(training_memo, open('training_memo.pkl', 'wb'))
else:
encoder.load_state_dict(torch.load("models/encoder"))
predict.load_state_dict(torch.load("models/predict"))
generate.load_state_dict(torch.load("models/generate"))
merge.load_state_dict(torch.load("models/merge"))
print('models loaded when training iteration is over!\n')
运行结果展示:
epoch: 1
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [03:24<00:00, 1.71s/it]
loss: 1.9905092318852742
training time 0h 3m 24s
--------------------------------
epoch: 2
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [04:09<00:00, 2.08s/it]
loss: 1.594282736380895
training time 0h 4m 9s
--------------------------------
epoch: 3
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [11:30<00:00, 5.75s/it]
loss: 1.409391882022222
training time 0h 11m 30s
--------------------------------
epoch: 4
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [04:25<00:00, 2.21s/it]
loss: 1.2609346439441045
training time 0h 4m 25s
--------------------------------
epoch: 5
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [07:37<00:00, 3.81s/it]
loss: 1.1588241264224053
training time 0h 7m 37s
--------------------------------
epoch: 6
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [07:03<00:00, 3.53s/it]
loss: 1.0730938428392014
training time 0h 7m 3s
--------------------------------
epoch: 7
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [06:06<00:00, 3.05s/it]
loss: 1.020545799533526
training time 0h 6m 6s
--------------------------------
epoch: 8
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [05:24<00:00, 2.70s/it]
loss: 0.9611828183134397
training time 0h 5m 24s
--------------------------------
epoch: 9
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [06:17<00:00, 3.15s/it]
loss: 0.9045661106705666
training time 0h 6m 17s
--------------------------------
epoch: 10
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [05:28<00:00, 2.74s/it]
loss: 0.8810180435578029
training time 0h 5m 28s
--------------------------------
models saved for epoch 10!
epoch: 11
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [05:50<00:00, 2.92s/it]
loss: 0.8478761807084083
training time 0h 5m 50s
--------------------------------
epoch: 12
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [05:40<00:00, 2.83s/it]
loss: 0.8136498818794886
training time 0h 5m 40s
--------------------------------
epoch: 13
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [09:00<00:00, 4.50s/it]
loss: 0.8206553558508555
training time 0h 9m 0s
--------------------------------
epoch: 14
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [05:24<00:00, 2.70s/it]
loss: 0.7745375340183576
training time 0h 5m 24s
--------------------------------
epoch: 15
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [06:32<00:00, 3.27s/it]
loss: 0.7527274847030639
training time 0h 6m 32s
--------------------------------
epoch: 16
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [06:45<00:00, 3.38s/it]
loss: 0.753266458461682
training time 0h 6m 45s
--------------------------------
epoch: 17
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [06:17<00:00, 3.14s/it]
loss: 0.7025209320088227
training time 0h 6m 17s
--------------------------------
epoch: 18
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [07:50<00:00, 3.92s/it]
loss: 0.6968447413295508
training time 0h 7m 50s
--------------------------------
epoch: 19
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [08:48<00:00, 4.40s/it]
loss: 0.6249610838790735
training time 0h 8m 48s
--------------------------------
epoch: 20
100%|████████████████████████████████████████████████████████████████████████████████| 120/120 [05:41<00:00, 2.84s/it]
loss: 0.647699203590552
training time 0h 5m 41s
--------------------------------
models saved for epoch 20!
models loaded when training iteration is over!
5. 预测阶段
- 加载训练好的模型
- 将模型设置为评估模式
- 对测试数据进行预测:
- 为每个测试样本构建图结构
- 使用
evaluate_tree
函数生成预测 - 将预测结果转换为前缀表达式
- 记录预测结果
- 保存预测结果到
predicted_list.pkl
predict_total = 0
predicted_list = []
# 一个特例:由于分词的原因,而造成的num_list中有分母为0的成员
div_by_0_index_list = [4715]
start = time.time()
for test_batch_index, test_batch in enumerate(tqdm(test_pairs)):
if (len(div_by_0_index_list) > 0) and (test_batch_index in div_by_0_index_list):
predicted_list.append({
'question_index': test_batch_index,
'prefix_expr': [],
})
continue
print_cond = test_batch_index < 10
# print(test_batch)
batch_graph = get_single_example_graph(
test_batch[0], test_batch[1], test_batch[7], test_batch[4], test_batch[5])
test_res = evaluate_tree(test_batch[0], test_batch[1], generate_num_ids, encoder, predict, generate,
merge, output_lang, test_batch[5], batch_graph, beam_size=beam_size)
if print_cond:
print(
f'for test_batch_index: {test_batch_index}, what the test_res is:')
pprint.pprint(test_res)
print()
_, _, prefix_expr_predicted, _ = compute_prefix_tree_result(
test_res, test_batch[2], output_lang, test_batch[4], test_batch[6])
if print_cond:
print(
f'for test_batch_index: {test_batch_index}, what the prefix_expr_predicted is:')
pprint.pprint(prefix_expr_predicted)
print()
predicted_list.append({
'question_index': test_batch_index,
'prefix_expr': prefix_expr_predicted,
})
predict_total += 1
print("predict_total", predict_total)
print("predicting time", time_since(time.time() - start))
print("------------------------------------------------------")
print('what the first 10 items of predicted_list are at last:')
pprint.pprint(predicted_list[:10])
print('and its original length is:')
pprint.pprint(len(predicted_list))
print()
print('assert 1\n')
assert len(predicted_list) == len(test_pairs)
if os.path.exists('predicted_list.pkl'):
os.system('rm -f predicted_list.pkl')
if not os.path.exists('predicted_list.pkl'):
pickle.dump(predicted_list, open('predicted_list.pkl', 'wb'))
运行结果展示如下:
测试数据集包括8,000个数学题,测试数据样例如下:
{
"id": "0",
"original_text": "91.64与7.36的和乘43.6与3.6的差,积是多少?",
"cleaned_text": "91.64与7.36的和乘43.6与3.6的差,积是多少?",
"segmented_text_new": "91.64 与 7.36 的 和 乘 43.6 与 3.6 的 差 , 积 是 多少 ?",
"num_list": ["91.64", "7.36", "43.6", "3.6"],
"group_num": [0, 1, 2, 3, 10, 0, 1, 2, 3, 5, 10, 13, 4, 5, 6, 7, 8, 9, 4, 5, 6, 7, 8, 9, 10],
"group_num_segs": [
["91.64", "与", "7.36", "的", "差"],
["91.64", "与", "7.36", "的", "乘", "差", "是"],
["和", "乘", "43.6", "与", "3.6", "的"],
["和", "乘", "43.6", "与", "3.6", "的", "差"]
]
}
采用Graph-to-Tree (GTS)方法对小学数学应用题进行逻辑推理,预测当前步骤的操作符(如 +
、-
)。输出采用为前缀表示法prefix_expr。以 [‘-‘, ‘164’, ‘*’, ‘9’, ‘2’] 为例,解析步骤:最外层运算符是 -(减法)。左操作数是 164。内层右操作符是 (乘法), 的左操作数是 9,右操作数是 2。最终运算表达式为:164 – (9 * 2)。预测模块输出如下:
models set as eval()!
0%| | 0/8000 [00:00<?, ?it/s]for test_batch_index: 0, what the test_res is:
[0, 2, 2, 65, 66, 67, 2, 67, 68]
for test_batch_index: 0, what the prefix_expr_predicted is:
['/', '*', '*', '91.64', '7.36', '43.6', '*', '43.6', '3.6']
0%| | 1/8000 [00:00<37:24, 3.56it/s]for test_batch_index: 1, what the test_res is:
[3, 65, 2, 66, 8]
for test_batch_index: 1, what the prefix_expr_predicted is:
['-', '164', '*', '9', '2']
0%| | 2/8000 [00:00<24:12, 5.50it/s]for test_batch_index: 2, what the test_res is:
[0, 1, 65, 6, 66]
for test_batch_index: 2, what the prefix_expr_predicted is:
['/', '+', '30', '1', '3']
for test_batch_index: 3, what the test_res is:
[0, 1, 65, 66, 67]
for test_batch_index: 3, what the prefix_expr_predicted is:
['/', '+', '(1/5)', '(2/15)', '35']
0%| | 4/8000 [00:00<17:28, 7.63it/s]for test_batch_index: 4, what the test_res is:
[0, 66, 65]
for test_batch_index: 4, what the prefix_expr_predicted is:
['/', '50', '7.90']
for test_batch_index: 5, what the test_res is:
[2, 1, 65, 66, 67]
for test_batch_index: 5, what the prefix_expr_predicted is:
['*', '+', '24', '19', '30']
0%| | 6/8000 [00:00<14:08, 9.42it/s]for test_batch_index: 6, what the test_res is:
[2, 2, 66, 65, 8]
for test_batch_index: 6, what the prefix_expr_predicted is:
['*', '*', '100', '3', '2']
for test_batch_index: 7, what the test_res is:
[1, 0, 68, 67, 65]
for test_batch_index: 7, what the prefix_expr_predicted is:
['+', '/', '15', '3%', '3500']
0%| | 8/8000 [00:00<13:34, 9.82it/s]for test_batch_index: 8, what the test_res is:
[3, 65, 66]
for test_batch_index: 8, what the prefix_expr_predicted is:
['-', '14.7', '5.9']
for test_batch_index: 9, what the test_res is:
[0, 1, 65, 66, 3, 67, 6]
for test_batch_index: 9, what the prefix_expr_predicted is:
['/', '+', '27', '18', '-', '2', '1']
100%|██████████████████████████████████████████████████████████████████████████████| 8000/8000 [41:09<00:00, 3.24it/s]
predict_total 7999
predicting time 0h 41m 9s
------------------------------------------------------
what the first 10 items of predicted_list are at last:
[{'prefix_expr': ['/', '*', '*', '91.64', '7.36', '43.6', '*', '43.6', '3.6'],
'question_index': 0},
{'prefix_expr': ['-', '164', '*', '9', '2'], 'question_index': 1},
{'prefix_expr': ['/', '+', '30', '1', '3'], 'question_index': 2},
{'prefix_expr': ['/', '+', '(1/5)', '(2/15)', '35'], 'question_index': 3},
{'prefix_expr': ['/', '50', '7.90'], 'question_index': 4},
{'prefix_expr': ['*', '+', '24', '19', '30'], 'question_index': 5},
{'prefix_expr': ['*', '*', '100', '3', '2'], 'question_index': 6},
{'prefix_expr': ['+', '/', '15', '3%', '3500'], 'question_index': 7},
{'prefix_expr': ['-', '14.7', '5.9'], 'question_index': 8},
{'prefix_expr': ['/', '+', '27', '18', '-', '2', '1'], 'question_index': 9}]
and its original length is:
8000
将预测结果转换为csv文件
data_process_generate_ans.py
输入文件为预测结果:predicted_list.pkl
1. 初始化和数据加载
- 导入必要的库和工具函数
- 加载模型预测结果(
predicted_list.pkl
) - 打印预测结果的基本信息(前10项、最后一项和总长度)
- 准备输出文件(
submit.csv
)
generated_path = './generated/predicted_list.pkl'
predicted_list = pickle.load(open(generated_path, 'rb'))
print('what the first 10 items of predicted_list are:')
print('and the last item of predicted_list is:')
print('and its original length is:')
pprint.pprint(len(predicted_list))
print()
if os.path.exists('./submits/submit.csv'):
os.system('rm -rf "./submits/submit.csv"')
if not os.path.exists('./submits/submit.csv'):
test_submit_f = open(
'./submits/submit.csv', 'w', encoding='utf-8')
csv_writer_submit = csv.writer(test_submit_f)
2. 数据处理主循环
对于测试集中的每一个问题:
- 数据预处理:
- 删除多余空格和特殊字符(|,$,·)
- 统一字符格式(符号、数字1和字母l的转换)
- 删除拼音音节
- 统一单位表述
- 转换中文冒号为英文冒号
- 转换英文标点为中文标点
- 表达式解析和计算:
- 使用
parse_and_eval_prefix_expr
函数解析前缀表达式 - 处理各种异常情况:
- 表达式不以操作符开头
- 表达式计算结果为None
- 解析树有多个叶子节点
- 计算结果不一致
- 除零错误
- 对于无法处理的情况,默认返回-1000
- 使用
- 答案生成:
- 使用
generate_ans_and_post_process_for_competition_format
生成最终答案 - 将结果写入CSV文件
- 使用
with open('./official_data/test.csv', 'r', encoding='utf-8') as f:
reader = csv.reader(f)
counter_all \
= counter_model_cannot_predict \
= counter_prefix_expr_not_startswith_operator \
= counter_compute_prefix_expression_func_none \
= counter_parsed_tree_with_multi_leaves \
= counter_discrepency_between_evalution \
= counter_division_by_zero \
= 0
# for row in tqdm(reader):
for doc_index, row in tqdm(enumerate(reader)):
# print(row)
if len(predicted_list[doc_index]['prefix_expr']) > 0:
question = row[1]
# 预处理的第一步,先把问题句子中多余的空格去掉,避免影响判断
# 比赛数据集特有的,再把'|','$','·'去掉
question = del_spaces(row[1]).replace(
'|', '').replace('$', '').replace('·', '')
# 再把两端多余的空格去掉
question = question.strip()
# print(question)
# 符号字符统一
question = char_unify_convertor(question)
# 将一些写错的l替换为1
question = replace_l_with_1(question)
# 将一些写错的1替换为l
question = replace_1_with_l(question)
# 改用封装后的函数,删除拼音音节
question = rm_pinyin_yinjie(question)
# 改用封装后的函数,单位表述进行统一
question = units_mention_unify(question)
# 改用封装后的函数,超过1个字符的表述替换
question = convert_some_mentions(question)
# 将表示比例的中文冒号转为英文冒号
question = convert_cn_colon_to_en(question)
# 将英文标点转为中文
question = convert_en_punct_to_cn(question)
question_with_spaces_deleted = question
try:
eval_result_py, expr_infix, add_info = parse_and_eval_prefix_expr(
predicted_list[doc_index]['prefix_expr'])
if add_info.startswith('prefix_expr_list[0] is not an operator'):
counter_prefix_expr_not_startswith_operator += 1
ans = -1000
elif add_info.startswith('the compute_prefix_expression() result is now None'):
counter_compute_prefix_expression_func_none += 1
ans = -1000
else:
ans = generate_ans_and_post_process_for_competition_format(
question_with_spaces_deleted, expr_infix)
except Exception as e:
if 'has multiple leaves' in f'{e}':
counter_parsed_tree_with_multi_leaves += 1
elif f'{e}'.startswith('discrepency found between'):
counter_discrepency_between_evalution += 1
elif 'division by zero' in f'{e}':
counter_division_by_zero += 1
ans = -1000
else:
# print('由于模型无法生成该题的前置运算符表达式,故无法用前置运算符表达式解析器求解,答案暂给-1000\n')
counter_model_cannot_predict += 1
ans = -1000
csv_writer_submit.writerow([row[0], ans])
if 'doc_index' not in dir():
counter_all += 1
if 'doc_index' in dir():
counter_all = doc_index + 1
运行结果
python data_process_generate_ans.py
what the first 10 items of predicted_list are:
and the last item of predicted_list is:
and its original length is:
8000
8000it [00:06, 1272.01it/s]
test stats: counter_all: 8000
test stats: model cannot predict: 1, percentage: 0.000125
test stats: the predicted prefix expr not started with an operator: 11, percentage: 0.001375
test stats: the provided compute_prefix_expression() on the predicted prefix expr returns None: 50, percentage: 0.00625
test stats: the parsed tree with multi leaves: 0, percentage: 0.0
test stats: discrepency found between parse_and_eval_prefix_expr() and the provided compute_prefix_expression(): 0, percentage: 0.0
test stats: division by zero found: 49, percentage: 0.006125
本样例根据输入的小学数学应用题测试集文件生成应用题答案文件,输出文件./submits/submit.csv,数据示例如下:
| Index | Answer |
| ----- | ---------- |
| 0 | 2428.09344 |
| 1 | 146 |
| 2 | 11 |
| 3 | 1/105 |
| 4 | 6 |
| 5 | 1290 |
| ... | ... |
源码开源协议
GPL-v3
作者:https://github.com/jackli777