摘要:
合集:AI案例-CV-农业
数据集:五种花卉数据集
数据集价值:支持花卉分类训练。
解决方案:PyTorch框架, Densenet、Efficientnet、ResNet模型
一、问题描述
五种花卉数据集(Five Flower Species Dataset)包含了五种花卉的图像,分别是洋甘菊(Chamomile)、郁金香(Tulip)、玫瑰(Rose)、向日葵(Sunflower)和蒲公英(Dandelion)。每个类别大约有800张照片。照片分辨率不高,大约为320×240像素。照片没有缩小到单一尺寸,它们的比例不同!该数据集通过搜索引擎收集。该数据集主要用于花卉分类任务,常用于训练和测试图像分类算法。该数据集由 Alexandru Măceșanu 提供,最早公开发布时,是在其学术工作中用于图像分类任务。数据集首次发布的时间为 2016年,在同年的一些机器学习和计算机视觉的研究工作中被广泛使用。
这个数据集是计算机视觉领域常用的图像分类基准数据集之一,尤其是在需要进行物体识别、花卉识别等领域的研究和应用中,起到了很大的作用。
二、数据集内容
数据结构
五种花卉分类数据集存储路径为data\flower_photos:
daisy
dandelion
rose
sunflower
tulip
数据集图片样例:

三、解决方案
- 使用VGGNet、GoogLeNet、ResNet、DenseNet、EfficientNet算法,基于五种花卉数据集中80%的数据训练识别模型,并对剩下20%的数据集进行测试 ;
- 使用不同的评价指标(如accuracy、precision、recall等)对各种算法进行评价;
- 鼓励基于现有的算法提出改进,进一步提高算法的性能。
源码文件
model.py: 是模型文件
train.py: 是调用模型训练的文件
predict.py: 是调用模型进行预测的文件
class_indices.json: 是训练数据集对应的标签文件
安装
conda create -n huggingface python=3.10
conda activate huggingface
pip3 install -r requestments.txt
requestments.txt
torch
torchvision
torchaudio
tqdm
matplotlib
tensorboard
VGGNet(2014)
- 核心思想:通过堆叠小卷积核(3×3)替代大卷积核(如5×5或7×7),减少参数量的同时增加网络深度,提升特征提取能力。
- 结构特点:
- 由多个卷积层+ReLU+最大池化的重复块组成,最后接全连接层。
- 常见变体:VGG16(13卷积层+3全连接)、VGG19(16卷积层+3全连接)。
- 优点:结构简单,特征提取能力强。
- 缺点:参数量大(全连接层占90%以上),计算成本高。
- 性能:ImageNet Top-1准确率约71.5%(VGG16)。
- 应用场景:早期图像分类、特征提取(现多被更高效模型替代)。
GoogLeNet(Inception v1, 2014)
- 核心思想:提出Inception模块,通过并行多尺度卷积(1×1、3×3、5×5)和池化,融合不同感受野的特征,减少计算量。
- 结构特点:
- Inception模块:使用1×1卷积降维(减少通道数),再拼接不同分支的输出。
- 引入辅助分类器(训练时辅助梯度回传,测试时移除)。
- 无全连接层,改用全局平均池化降低参数量。
- 优点:参数效率高,适合深层网络。
- 缺点:结构复杂,手动设计模块难度大。
- 性能:ImageNet Top-1准确率约69.8%(Inception v1)。
- 演进版本:Inception v2/v3(BN优化)、Inception v4(结合ResNet)。
ResNet(2016)
- 核心思想:提出残差连接(Residual Connection),解决深层网络梯度消失问题,允许训练数百层的网络。
- 结构特点:
- 残差块:输入通过跨层连接(Shortcut)与卷积层输出相加(
F(x) + x
)。 - 分阶段堆叠残差块,每阶段降采样(步长=2的卷积)。
- 变体:ResNet18/34(基础残差块)、ResNet50/101/152(瓶颈结构:1×1→3×3→1×1)。
- 残差块:输入通过跨层连接(Shortcut)与卷积层输出相加(
- 优点:训练极深网络(如ResNet152),性能显著提升。
- 性能:ImageNet Top-1准确率约76.15%(ResNet50)。
- 应用场景:图像分类、检测(如Faster R-CNN)、分割等。
DenseNet(2017)
DenseNet121(Dense Convolutional Network with 121 layers)是一种经典的密集连接卷积神经网络,由Gao Huang等人在2017年的论文《Densely Connected Convolutional Networks》中提出。它是DenseNet系列中的一个代表性模型,以其独特的“密集连接”(Dense Connection)机制和高效的参数利用著称。DenseNet121在资源受限的场景下(如边缘设备)表现优异,适合需要平衡精度和计算成本的场景。如需更高精度,可考虑DenseNet169或DenseNet201。
- 核心思想:密集连接(Dense Connection),每层接收前面所有层的输出作为输入(通道拼接),最大化特征复用。
- 结构特点:
- Dense Block:层间直接连接(
concat
),每层输出通道数固定(增长率k)。 - 过渡层:1×1卷积压缩通道 + 平均池化降采样。
- 变体:DenseNet121/169/201(层数不同)。
- Dense Block:层间直接连接(
- 优点:参数高效,缓解梯度消失,特征重用能力强。
- 缺点:密集拼接导致显存占用较高。
- 性能:ImageNet Top-1准确率约74.65%(DenseNet121)。
- 应用场景:医学图像分析、小样本任务。
EfficientNet(2019)
- 核心思想:通过复合缩放(Compound Scaling)统一调整深度、宽度和分辨率,实现精度与效率的平衡。
- 结构特点:
- 基础模型EfficientNet-B0使用MobileNetV2的逆残差块(MBConv) + SE(Squeeze-Excitation)注意力机制。
- 缩放系数(φ)按规则扩展(如B0→B7)。
- 优点:同等计算资源下精度最高,适合移动端部署。
- 性能:
- EfficientNet-B0:77.1% Top-1(仅5.3M参数)。
- EfficientNet-B7:84.3% Top-1(66M参数)。
- 应用场景:移动设备、实时推理(如手机拍照分类)。
对比总结
模型 | 核心创新 | 参数量 | ImageNet Top-1 | 特点 |
---|---|---|---|---|
VGG16 | 小卷积核堆叠 | 138M | ~71.5% | 简单但参数量大 |
GoogLeNet | Inception多尺度并行 | 6.8M | ~69.8% | 参数高效,结构复杂 |
ResNet50 | 残差连接 | 25.6M | ~76.15% | 解决梯度消失,支持极深网络 |
DenseNet121 | 密集连接 | 7.98M | ~74.65% | 特征复用强,显存占用高 |
EfficientNet-B0 | 复合缩放 + MBConv | 5.3M | ~77.1% | 精度/效率最优,适合移动端 |
演进趋势
- 深度增加:VGG(19层)→ ResNet(152层)→ EfficientNet(B7)。
- 结构优化:从手工设计(Inception)到自动化搜索(EfficientNet)。
- 效率提升:参数量减少(VGG→EfficientNet),更适合边缘计算。
根据任务需求选择模型:
- 高精度:ResNet/DenseNet/EfficientNet大模型。
- 轻量化:EfficientNet/MobileNet。
- 平衡型:ResNet50/DenseNet121。
四、处理过程
以下以DenseNet121模型为例子介绍训练和预测过程。
数据预处理
- 训练数据:随机裁剪、水平翻转、归一化(ImageNet均值/方差)
- 验证数据:中心裁剪、相同归一化
def main(args):
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(args)
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
tb_writer = SummaryWriter()
if os.path.exists("./weights") is False:
os.makedirs("./weights")
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val": transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
# 实例化训练数据集
train_dataset = MyDataSet(images_path=train_images_path,
images_class=train_images_label,
transform=data_transform["train"])
# 实例化验证数据集
val_dataset = MyDataSet(images_path=val_images_path,
images_class=val_images_label,
transform=data_transform["val"])
batch_size = args.batch_size
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
print('Using {} dataloader workers every process'.format(nw))
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=True,
num_workers=nw,
collate_fn=train_dataset.collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=batch_size,
shuffle=False,
pin_memory=True,
num_workers=nw,
collate_fn=val_dataset.collate_fn)
# 如果存在预训练权重则载入
model = densenet121(num_classes=args.num_classes).to(device)
if args.weights != "":
if os.path.exists(args.weights):
load_state_dict(model, args.weights)
else:
raise FileNotFoundError("not found weights file: {}".format(args.weights))
# 是否冻结权重
if args.freeze_layers:
for name, para in model.named_parameters():
# 除最后的全连接层外,其他权重全部冻结
if "classifier" not in name:
para.requires_grad_(False)
pg = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4, nesterov=True)
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
模型初始化
- 使用DenseNet121架构
- 支持加载预训练权重
- 可冻结特征提取层(只训练最后的分类层)
model = densenet121(num_classes=args.num_classes).to(device)
if args.weights != "":
load_state_dict(model, args.weights)
训练配置
- 使用带动量的SGD优化器
- 采用余弦退火学习率调度
- 包含L2正则化(weight_decay)
optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4, nesterov=True)
lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
训练流程
for epoch in range(args.epochs):
# train
mean_loss = train_one_epoch(model=model,
optimizer=optimizer,
data_loader=train_loader,
device=device,
epoch=epoch)
scheduler.step()
# validate
acc = evaluate(model=model,
data_loader=val_loader,
device=device)
print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
tags = ["loss", "accuracy", "learning_rate"]
tb_writer.add_scalar(tags[0], mean_loss, epoch)
tb_writer.add_scalar(tags[1], acc, epoch)
tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)
torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))
源码开源协议
MIT License
Copyright (c) 2024 Dollarkillerx
五、运行
预权重文件
预权重文件存储在工程子目录weights下:
./weights/densenet121.pth
训练
在进入子目录后,修改数据集路径,输入以下命令。
cd Densenet #以Densenet为例
python train.py
输出:
Namespace(num_classes=5, epochs=30, batch_size=4, lr=0.001, lrf=0.1, data_path='../data/flower_photos', weights='../weights/densenet121.pth', freeze_layers=False, device='cpu')
Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/
4317 images were found in the dataset.
3457 images for training.
860 images for validation.
Using 4 dataloader workers every process
D:\Dev\dev-Python\DataSet\Classification-of-flowers\Densenet\model.py:244: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(weights_path)
successfully load pretrain-weights.
[epoch 0] mean loss 1.12: 21%|███████████▎ | 181/865 [04:00<14:20, 1.26s/it]
预测
python predict.py