0

0

科大讯飞-学术论文分类挑战赛:ERNIE 准确率0.79

P粉084495128

P粉084495128

发布时间:2025-07-25 10:25:11

|

908人浏览过

|

来源于php中文网

原创

随着人工智能技术不断发展,每周都有非常多的论文公开发布。现如今对论文进行分类逐渐成为非常现实的问题,这也是研究人员和研究机构每天都面临的问题。现在希望选手能构建一个论文分类模型。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

科大讯飞-学术论文分类挑战赛:ernie 准确率0.79 - php中文网

赛事任务

本次赛题希望参赛选手利用论文信息:论文id、标题、摘要,划分论文具体类别。

赛题样例(使用\t分隔):

paperid:9821title:Calculation of prompt diphoton production cross sections at Tevatron and LHC energies

abstract:A fully differential calculation in perturbative quantum chromodynamics is presented for the production of massive photon pairs at hadron colliders. All next-to-leading order perturbative contributions from quark-antiquark, gluon-(anti)quark, and gluon-gluon subprocesses are included, as well as all-orders resummation of initial-state gluon radiation valid at next-to-next-to-leading logarithmic accuracy.

categories:hep-ph

数据说明

训练数据和测试集以csv文件给出,其中:

  • 训练集5W篇论文。其中每篇论文都包含论文id、标题、摘要和类别四个字段。

    PictoGraphic
    PictoGraphic

    AI驱动的矢量插图库和插图生成平台

    下载
  • 测试集1W篇论文。其中每篇论文都包含论文id、标题、摘要,不包含论文类别字段。

评估指标

本次竞赛的评价标准采用准确率指标,最高分为1。

计算方法参考https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html, 评估代码参考:

from sklearn.metrics import accuracy_scorey_pred = [0, 2, 1, 3]y_true = [0, 1, 2, 3]
In [1]
!pip install paddle-ernie > log.log
In [2]
import numpy as npimport paddle as P# 导入ernie模型from ernie.tokenizing_ernie import ErnieTokenizerfrom ernie.modeling_ernie import ErnieModel

model = ErnieModel.from_pretrained('ernie-1.0')    # Try to get pretrained model from server, make sure you have network connectionmodel.eval()
tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')

ids, _ = tokenizer.encode('hello world')
ids = P.to_tensor(np.expand_dims(ids, 0))  # insert extra `batch` dimensionpooled, encoded = model(ids)                 # eager executionprint(pooled.numpy())
In [9]
import sysimport numpy as npimport pandas as pdfrom sklearn.metrics import f1_scoreimport paddle as Pfrom ernie.tokenizing_ernie import ErnieTokenizerfrom ernie.modeling_ernie import ErnieModelForSequenceClassification
In [10]
train_df = pd.read_csv('train.csv', sep='\t')
train_df['title'] = train_df['title'] + ' ' + train_df['abstract']

train_df = train_df.sample(frac=1.0)
train_df.head()
In [11]
train_df.shape
In [12]
train_df['categories'].nunique()
In [13]
train_df['categories'], lbl_list = pd.factorize(train_df['categories'])
In [14]
# 模型超参数BATCH=32MAX_SEQLEN=300LR=5e-5EPOCH=10# 定义ernie分类模型ernie = ErnieModelForSequenceClassification.from_pretrained('ernie-2.0-en', num_labels=39)
optimizer = P.optimizer.Adam(LR,parameters=ernie.parameters())
tokenizer = ErnieTokenizer.from_pretrained('ernie-2.0-en')
In [15]
train_df.iterrows()
In [16]
# 对数据集进行转换,主要操作为文本编码def make_data(df):
    data = []    for i, row in enumerate(df.iterrows()):
        text, label = row[1].title, row[1].categories
        text_id, _ = tokenizer.encode(text) # ErnieTokenizer 会自动添加ERNIE所需要的特殊token,如[CLS], [SEP]
        text_id = text_id[:MAX_SEQLEN]
        text_id = np.pad(text_id, [0, MAX_SEQLEN-len(text_id)], mode='constant')
        data.append((text_id, label))    return data

train_data = make_data(train_df.iloc[:-5000])
val_data = make_data(train_df.iloc[-5000:])
In [ ]
# 获取batch数据def get_batch_data(data, i):
    d = data[i*BATCH: (i + 1) * BATCH]
    feature, label = zip(*d)
    feature = np.stack(feature)  # 将BATCH行样本整合在一个numpy.array中
    label = np.stack(list(label))
    feature = P.to_tensor(feature) # 使用to_variable将numpy.array转换为paddle tensor
    label = P.to_tensor(label)    return feature, label
In [12]
EPOCH=1# 模型训练for i in range(EPOCH):
    np.random.shuffle(train_data) # 每个epoch都shuffle数据以获得最佳训练效果;
    ernie.train()    for j in range(len(train_data) // BATCH):
        feature, label = get_batch_data(train_data, j)
        loss, _ = ernie(feature, labels=label) 
        loss.backward()
        optimizer.minimize(loss)
        ernie.clear_gradients()        if j % 50 == 0:            print('Train %d: loss %.5f' % (j, loss.numpy()))        
        # 模型验证
        if j % 100 == 0:
            all_pred, all_label = [], []            with P.no_grad():
                ernie.eval()                for j in range(len(val_data) // BATCH):
                    feature, label = get_batch_data(val_data, j)
                    loss, logits = ernie(feature, labels=label)

                    all_pred.extend(logits.argmax(-1).numpy())
                    all_label.extend(label.numpy())
                ernie.train()
            acc = (np.array(all_label) == np.array(all_pred)).astype(np.float32).mean()            print('Val acc %.5f' % acc)
In [13]
test_df = pd.read_csv('test.csv', sep='\t')
test_df['title'] = test_df['title'] + ' ' + test_df['abstract']
test_df['categories'] = 0test_data = make_data(test_df.iloc[:])
In [20]
all_pred, all_label = [], []# 模型预测with P.no_grad():
    ernie.eval()    for j in range(len(test_data) // BATCH+1):
        feature, label = get_batch_data(test_data, j)
        loss, logits = ernie(feature, labels=label)

        all_pred.extend(logits.argmax(-1).numpy())
        all_label.extend(label.numpy())
In [21]
pd.DataFrame({    'paperid': test_df['paperid'],    'categories': lbl_list[all_pred]
}).to_csv('submit.csv', index=None)

相关专题

更多
html版权符号
html版权符号

html版权符号是“©”,可以在html源文件中直接输入或者从word中复制粘贴过来,php中文网还为大家带来html的相关下载资源、相关课程以及相关文章等内容,供大家免费下载使用。

598

2023.06.14

html在线编辑器
html在线编辑器

html在线编辑器是用于在线编辑的工具,编辑的内容是基于HTML的文档。它经常被应用于留言板留言、论坛发贴、Blog编写日志或等需要用户输入普通HTML的地方,是Web应用的常用模块之一。php中文网为大家带来了html在线编辑器的相关教程、以及相关文章等内容,供大家免费下载使用。

641

2023.06.21

html网页制作
html网页制作

html网页制作是指使用超文本标记语言来设计和创建网页的过程,html是一种标记语言,它使用标记来描述文档结构和语义,并定义了网页中的各种元素和内容的呈现方式。本专题为大家提供html网页制作的相关的文章、下载、课程内容,供大家免费下载体验。

462

2023.07.31

html空格
html空格

html空格是一种用于在网页中添加间隔和对齐文本的特殊字符,被用于在网页中插入额外的空间,以改变元素之间的排列和对齐方式。本专题为大家提供html空格的相关的文章、下载、课程内容,供大家免费下载体验。

243

2023.08.01

html是什么
html是什么

HTML是一种标准标记语言,用于创建和呈现网页的结构和内容,是互联网发展的基石,为网页开发提供了丰富的功能和灵活性。本专题为大家提供html相关的各种文章、以及下载和课程。

2865

2023.08.11

html字体大小怎么设置
html字体大小怎么设置

在网页设计中,字体大小的选择是至关重要的。合理的字体大小不仅可以提升网页的可读性,还能够影响用户对网页整体布局的感知。php中文网将介绍一些常用的方法和技巧,帮助您在HTML中设置合适的字体大小。

501

2023.08.11

html转txt
html转txt

html转txt的方法有使用文本编辑器、使用在线转换工具和使用Python编程。本专题为大家提供html转txt相关的文章、下载、课程内容,供大家免费下载体验。

307

2023.08.31

html文本框代码怎么写
html文本框代码怎么写

html文本框代码:1、单行文本框【<input type="text" style="height:..;width:..;" />】;2、多行文本框【textarea style=";height:;"></textare】。

419

2023.09.01

php源码安装教程大全
php源码安装教程大全

本专题整合了php源码安装教程,阅读专题下面的文章了解更多详细内容。

150

2025.12.31

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Git 教程
Git 教程

共21课时 | 2.4万人学习

Git版本控制工具
Git版本控制工具

共8课时 | 1.5万人学习

Git中文开发手册
Git中文开发手册

共0课时 | 0人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号