您当前的位置:首页 > 电脑百科 > 程序开发 > 编程百科

在少样本学习中,用SetFit进行文本分类

时间:2023-11-28 12:07:09  来源:51CTO  作者:

译者 | 陈峻

在本文中,我将向您介绍“少样本(Few-shot)学习”的相关概念,并重点讨论被广泛应用于文本分类的SetFit方法。

在少样本学习中,用SetFit进行文本分类

传统的机器学习(ML)

在监督(Supervised)机器学习中,大量数据集被用于模型训练,以便磨练模型能够做出精确预测的能力。在完成训练过程之后,我们便可以利用测试数据,来获得模型的预测结果。然而,这种传统的监督学习方法存在着一个显著缺点:它需要大量无差错的训练数据集。但是并非所有领域都能够提供此类无差错数据集。因此,“少样本学习”的概念应运而生。

在深入研究Sentence Transformer fine-tuning(SetFit)之前,我们有必要简要地回顾一下自然语言处理(Natural Language Processing,NLP)的一个重要方面,也就是:“少样本学习”。

少样本学习

少样本学习是指:使用有限的训练数据集,来训练模型。模型可以从这些被称为支持集的小集合中获取知识。此类学习旨在教会少样本模型,辨别出训练数据中的相同与相异之处。例如,我们并非要指示模型将所给图像分类为猫或狗,而是指示它掌握各种动物之间的共性和区别。可见,这种方法侧重于理解输入数据中的相似点和不同点。因此,它通常也被称为元学习(meta-learning)、或是从学习到学习(learning-to-learn)。

值得一提的是,少样本学习的支持集,也被称为k向(k-way)n样本(n-shot)学习。其中“k”代表支持集里的类别数。例如,在二分类(binary classification)中,k 等于 2。而“n”表示支持集中每个类别的可用样本数。例如,如果正分类有10个数据点,而负分类也有10个数据点,那么 n就等于10。总之,这个支持集可以被描述为双向10样本学习。

既然我们已经对少样本学习有了基本的了解,下面让我们通过使用SetFit进行快速学习,并在实际应用中对电商数据集进行文本分类。

SetFit架构

由Hugging Face和英特尔实验室的团队联合开发的SetFit,是一款用于少样本照片分类的开源工具。你可以在项目库链接--https://Github.com/huggingface/setfit?ref=hackernoon.com中,找到关于SetFit的全面信息。

就输出而言,SetFit仅用到了客户评论(Customer Reviews,CR)情感分析数据集里、每个类别的八个标注示例。其结果就能够与由三千个示例组成的完整训练集上,经调优的RoBERTa Large的结果相同。值得强调的是,就体积而言,经微优的RoBERTa模型比SetFit模型大三倍。下图展示的是SetFit架构:

在少样本学习中,用SetFit进行文本分类

图片来源:https://www.sbert.NET/docs/trAIning/overview.html?ref=hackernoon.com

用SetFit实现快速学习

SetFit的训练速度非常快,效率也极高。与GPT-3和T-FEW等大模型相比,其性能极具竞争力。请参见下图:

SetFit与T-Few 3B模型的比较

在少样本学习中,用SetFit进行文本分类

如下图所示,SetFit在少样本学习方面的表现优于RoBERTa。

在少样本学习中,用SetFit进行文本分类

SetFit与RoBERT的比较,图片来源:https://huggingface.co/blog/setfit?ref=hackernoon.com

数据集

下面,我们将用到由四个不同类别组成的独特电商数据集,它们分别是:书籍、服装与配件、电子产品、以及家居用品。该数据集的主要目的是将来自电商网站的产品描述归类到指定的标签下。

为了便于采用少样本的训练方法,我们将从四个类别中各选择八个样本,从而得到总共32个训练样本。而其余样本则将留作测试之用。简言之,我们在此使用的支持集是4向8样本学习。下图展示的是自定义电商数据集的示例:

在少样本学习中,用SetFit进行文本分类 自定义电商数据集样本

我们采用名为“all-mpnet-base-v2”的Sentence Transformers预训练模型,将文本数据转换为各种向量嵌入。该模型可以为输入文本,生成维度为768的向量嵌入。

如下命令所示,我们将通过在conda环境(是一个开源的软件包管理系统和环境管理系统)中安装所需的软件包,来开始SetFit的实施。

复制

!pip3 install SetFit

!pip3 install sklearn

!pip3 install transformers

!pip3 install sentence-transformers

安装完软件包后,我们便可以通过如下代码加载数据集了。

复制

from datasets import load_dataset

dataset = load_dataset('csv', data_files={

"train": 'E_Commerce_Dataset_Train.csv',

"test": 'E_Commerce_Dataset_Test.csv'

})

我们来参照下图,看看训练样本和测试样本数。

在少样本学习中,用SetFit进行文本分类 训练和测试数据

我们使用sklearn软件包中的LabelEncoder,将文本标签转换为编码标签。

复制

from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()

通过LabelEncoder,我们将对训练和测试数据集进行编码,并将编码后的标签添加到数据集的“标签”列中。请参见如下代码:

复制

Encoded_Product = le.fit_transform(dataset["train"]['Label'])

dataset["train"] = dataset["train"].remove_columns("Label").add_column("Label", Encoded_Product).cast(dataset["train"].features)

Encoded_Product = le.fit_transform(dataset["test"]['Label'])

dataset["test"] = dataset["test"].remove_columns("Label").add_column("Label", Encoded_Product).cast(dataset["test"].features)

下面,我们将初始化SetFit模型和句子转换器(sentence-transformers)模型。

复制

from setfit import SetFitModel, SetFitTrainer

from sentence_transformers.losses import CosineSimilarityLoss

model_id = "sentence-transformers/all-mpnet-base-v2"

model = SetFitModel.from_pretrained(model_id)

trainer = SetFitTrainer(

model=model,

train_dataset=dataset["train"],

eval_dataset=dataset["test"],

loss_class=CosineSimilarityLoss,

metric="accuracy",

batch_size=64,

num_iteratinotallow=20,

num_epochs=2,

column_mApping={"Text": "text", "Label": "label"}

初始化完成两个模型后,我们现在便可以调用训练程序了。

复制

trainer.train()

在完成了2个训练轮数(epoch)后,我们将在eval_dataset上,对训练好的模型进行评估。

复制

trainer.evaluate()

经测试,我们的训练模型的最高准确率为87.5%。虽然87.5%的准确率并不算高,但是毕竟我们的模型只用了32个样本进行训练。也就是说,考虑到数据集规模的有限性,在测试数据集上取得87.5%的准确率,实际上是相当可观的。

此外,SetFit还能够将训练好的模型,保存到本地存储器中,以便后续从磁盘加载,用于将来的预测。

复制

trainer.model._save_pretrained(save_directory="SetFit_ECommerce_Output/")

model=SetFitModel.from_pretrained("SetFit_ECommerce_Output/", local_files_notallow=True)

如下代码展示了根据新的数据进行的预测结果:

复制

input = ["Campus Sutra Men's Sports Jersey T-Shirt Cool-Gear: Our Proprietary Moisture Management technology. Helps to absorb and evaporate sweat quickly. Keeps you Cool & Dry. Ultra-Fresh: Fabrics treated with Ultra-Fresh Antimicrobial Technology. Ultra-Fresh is a trademark of (TRA) Inc, Ontario, Canada. Keeps you odour free."]

output = model(input)

可见,其预测输出为1,而标签的LabelEncoded值为“服装与配件”。由于传统的AI模型需要大量的训练资源(包括时间和数据),才能有稳定水准的输出。而我们的模型与之相比,既准确又高效。

至此,相信您已经基本掌握了“少样本学习”的概念,以及如何使用SetFit来进行文本分类等应用。当然,为了获得更深刻的理解,我强烈建议您选择一个实际场景,创建一个数据集,编写对应的代码,并将该过程延展到零样本学习、以及单样本学习上。

译者介绍

陈峻(Julian Chen),51CTO社区编辑,具有十多年的IT项目实施经验,善于对内外部资源与风险实施管控,专注传播网络与信息安全知识与经验。

原文标题:Mastering Few-Shot Learning with SetFit for Text Classification,作者:Shyam Ganesh S)



Tags:SetFit   点击:()  评论:()
声明:本站部分内容及图片来自互联网,转载是出于传递更多信息之目的,内容观点仅代表作者本人,不构成投资建议。投资者据此操作,风险自担。如有任何标注错误或版权侵犯请与我们联系,我们将及时更正、删除。
▌相关推荐
在少样本学习中,用SetFit进行文本分类
译者 | 陈峻在本文中,我将向您介绍“少样本(Few-shot)学习”的相关概念,并重点讨论被广泛应用于文本分类的SetFit方法。传统的机器学习(ML)在监督(Supervised)机器学习中,大量数据集...【详细内容】
2023-11-28  Search: SetFit  点击:(175)  评论:(0)  加入收藏
▌简易百科推荐
即将过时的 5 种软件开发技能!
作者 | Eran Yahav编译 | 言征出品 | 51CTO技术栈(微信号:blog51cto) 时至今日,AI编码工具已经进化到足够强大了吗?这未必好回答,但从2023 年 Stack Overflow 上的调查数据来看,44%...【详细内容】
2024-04-03    51CTO  Tags:软件开发   点击:(5)  评论:(0)  加入收藏
跳转链接代码怎么写?
在网页开发中,跳转链接是一项常见的功能。然而,对于非技术人员来说,编写跳转链接代码可能会显得有些困难。不用担心!我们可以借助外链平台来简化操作,即使没有编程经验,也能轻松实...【详细内容】
2024-03-27  蓝色天纪    Tags:跳转链接   点击:(12)  评论:(0)  加入收藏
中台亡了,问题到底出在哪里?
曾几何时,中台一度被当做“变革灵药”,嫁接在“前台作战单元”和“后台资源部门”之间,实现企业各业务线的“打通”和全域业务能力集成,提高开发和服务效率。但在中台如火如荼之...【详细内容】
2024-03-27  dbaplus社群    Tags:中台   点击:(8)  评论:(0)  加入收藏
员工写了个比删库更可怕的Bug!
想必大家都听说过删库跑路吧,我之前一直把它当一个段子来看。可万万没想到,就在昨天,我们公司的某位员工,竟然写了一个比删库更可怕的 Bug!给大家分享一下(不是公开处刑),希望朋友们...【详细内容】
2024-03-26  dbaplus社群    Tags:Bug   点击:(5)  评论:(0)  加入收藏
我们一起聊聊什么是正向代理和反向代理
从字面意思上看,代理就是代替处理的意思,一个对象有能力代替另一个对象处理某一件事。代理,这个词在我们的日常生活中也不陌生,比如在购物、旅游等场景中,我们经常会委托别人代替...【详细内容】
2024-03-26  萤火架构  微信公众号  Tags:正向代理   点击:(10)  评论:(0)  加入收藏
看一遍就理解:IO模型详解
前言大家好,我是程序员田螺。今天我们一起来学习IO模型。在本文开始前呢,先问问大家几个问题哈~什么是IO呢?什么是阻塞非阻塞IO?什么是同步异步IO?什么是IO多路复用?select/epoll...【详细内容】
2024-03-26  捡田螺的小男孩  微信公众号  Tags:IO模型   点击:(8)  评论:(0)  加入收藏
为什么都说 HashMap 是线程不安全的?
做Java开发的人,应该都用过 HashMap 这种集合。今天就和大家来聊聊,为什么 HashMap 是线程不安全的。1.HashMap 数据结构简单来说,HashMap 基于哈希表实现。它使用键的哈希码来...【详细内容】
2024-03-22  Java技术指北  微信公众号  Tags:HashMap   点击:(11)  评论:(0)  加入收藏
如何从头开始编写LoRA代码,这有一份教程
选自 lightning.ai作者:Sebastian Raschka机器之心编译编辑:陈萍作者表示:在各种有效的 LLM 微调方法中,LoRA 仍然是他的首选。LoRA(Low-Rank Adaptation)作为一种用于微调 LLM(大...【详细内容】
2024-03-21  机器之心Pro    Tags:LoRA   点击:(12)  评论:(0)  加入收藏
这样搭建日志中心,传统的ELK就扔了吧!
最近客户有个新需求,就是想查看网站的访问情况。由于网站没有做google的统计和百度的统计,所以访问情况,只能通过日志查看,通过脚本的形式给客户导出也不太实际,给客户写个简单的...【详细内容】
2024-03-20  dbaplus社群    Tags:日志   点击:(4)  评论:(0)  加入收藏
Kubernetes 究竟有没有 LTS?
从一个有趣的问题引出很多人都在关注的 Kubernetes LTS 的问题。有趣的问题2019 年,一个名为 apiserver LoopbackClient Server cert expired after 1 year[1] 的 issue 中提...【详细内容】
2024-03-15  云原生散修  微信公众号  Tags:Kubernetes   点击:(5)  评论:(0)  加入收藏
相关文章
    无相关信息
站内最新
站内热门
站内头条