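Summary
Collection: AI Cases - Computer Vision - Animals
Competition: iFLYTEK 2024 Endangered Large Animal Species Recognition Challenge
AI problem: image classification and recognition
Dataset: an image dataset of endangered large animals, with training images for 9 species (45 images per species) and 135 test images.
Dataset publisher: Beijing Forestry University
Dataset value: mining image data of endangered large animals and applying suitable algorithms to predict the species.
Solution: EfficientNet-B1 model, timm framework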
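I. Competition Description
Global biodiversity keeps declining, and the sharp fall in populations of endangered large animals in particular has become an urgent global conservation issue. Endangered animals such as elephants, rhinoceroses, and big cats face many threats, including habitat loss and poaching, which endanger not only biodiversity but also ecological balance. Accurate and timely monitoring and protection are therefore essential. Traditional monitoring methods, however, are slow and inefficient and cannot keep up with this need. Using artificial intelligence to build a system that automatically identifies endangered large animal species from image data would greatly improve the efficiency and accuracy of conservation work.
This challenge aims to apply AI to make the monitoring and protection of endangered large animals more efficient. The organizers provide animal images collected from multiple sources, such as camera traps and drones, together with the geographic location of the protected area where each was taken. Participants must use these data to build models that accurately identify endangered large animal species. The key difficulty is distinguishing similar species under varying lighting, viewing angles, and backgrounds.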
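II. Dataset Description
Data Overview
The dataset covers 9 animal species with 45 training images each, 405 training images in total. The test set contains 135 images.
The training images are stored in the following 9 directories, one per species:
Directory | Images |
---|---|
badger | 45 |
bear | 45 |
chimpanzee | 45 |
hare | 45 |
leopard | 45 |
moth | 45 |
panda | 45 |
shark | 45 |
tiger | 45 |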
Dataset License
BY-NC-SA 4.0
https://creativecommons.org/licenses/by-nc-sa/4.0/deed.zh-hans
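III. Sample Solution
1. Deep Learning Model: EfficientNet
Source code: iFLYTEK-2024-Endangered-Animals.ipynb
The source code implements an endangered large animal species recognition system that uses the deep learning model EfficientNet to classify 9 endangered animal species.
The core innovation of EfficientNet is compound scaling, which systematically balances three dimensions of the network:
- Depth: the number of layers
- Width: the number of channels per layer
- Resolution: the input image size
Conventional approaches usually scale only one of these dimensions, whereas EfficientNet scales all three together, driven by a single compound coefficient.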
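For reference, the EfficientNet paper ties the three dimensions to one compound coefficient φ: depth scales as α^φ, width as β^φ, and resolution as γ^φ, subject to α·β²·γ² ≈ 2, so each unit increase of φ roughly doubles FLOPs. The paper reports α = 1.2, β = 1.1, γ = 1.15 from a grid search on B0. The sketch below only evaluates these formulas; `compound_scaling` is an illustrative name, and the published B1 to B7 coefficients (for example depth 1.1 / width 1.0 for B1) were rounded and adjusted, so they do not equal α^φ and β^φ exactly.

```python
# Base coefficients reported in the EfficientNet paper (found by grid search on B0)
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15


def compound_scaling(phi, base_resolution=224):
    """Depth/width multipliers, input resolution and approximate FLOPs factor for a given phi."""
    depth_mult = ALPHA ** phi   # d = alpha^phi
    width_mult = BETA ** phi    # w = beta^phi
    res_mult = GAMMA ** phi     # r = gamma^phi
    resolution = round(base_resolution * res_mult)
    # Convolution FLOPs grow roughly with d * w^2 * r^2
    flops_factor = depth_mult * width_mult ** 2 * res_mult ** 2
    return depth_mult, width_mult, resolution, flops_factor


# The constraint alpha * beta^2 * gamma^2 ~ 2 means each +1 in phi roughly doubles FLOPs
print(f"alpha*beta^2*gamma^2 = {ALPHA * BETA ** 2 * GAMMA ** 2:.2f}")  # 1.92
for phi in range(4):
    d, w, r, f = compound_scaling(phi)
    print(f"phi={phi}: depth x{d:.2f}, width x{w:.2f}, resolution {r}, FLOPs x{f:.2f}")
```

Scaling all three factors from one φ is what keeps the network balanced: increasing resolution without adding depth (receptive field) or width (channels to capture finer patterns) gives diminishing returns.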
2. B1 Model Parameters
Parameter | Value |
---|---|
Input resolution | 240×240 |
Parameters | 7.8M |
FLOPs | 0.7B |
ImageNet top-1 accuracy | 79.1% |
Depth coefficient | 1.1 |
Width coefficient | 1.0 |
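To see what the depth and width coefficients mean concretely, the sketch below applies them to the EfficientNet-B0 stage layout (7 MBConv stages with 1, 2, 2, 3, 3, 4, 1 repeated blocks and 16, 24, 40, 80, 112, 192, 320 output channels). The rounding rules (ceiling for block repeats; channels rounded to a multiple of 8 without dropping more than 10%) follow the reference TensorFlow implementation as we understand it; `round_repeats`, `round_filters`, and `scaled_stages` are written here for illustration and are not timm functions.

```python
import math

# EfficientNet-B0 stage layout (assumed from the reference implementation)
B0_REPEATS = [1, 2, 2, 3, 3, 4, 1]
B0_CHANNELS = [16, 24, 40, 80, 112, 192, 320]


def round_repeats(repeats, depth_coef):
    # Depth scaling: multiply the number of blocks per stage and round up
    return int(math.ceil(depth_coef * repeats))


def round_filters(filters, width_coef, divisor=8):
    # Width scaling: multiply channels, round to a multiple of `divisor`,
    # and never round down by more than 10%
    filters *= width_coef
    new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:
        new_filters += divisor
    return int(new_filters)


def scaled_stages(depth_coef, width_coef):
    return [(round_filters(c, width_coef), round_repeats(r, depth_coef))
            for c, r in zip(B0_CHANNELS, B0_REPEATS)]


b1 = scaled_stages(depth_coef=1.1, width_coef=1.0)  # coefficients from the B1 table
for i, (channels, blocks) in enumerate(b1, start=1):
    print(f"stage {i}: {blocks} blocks, {channels} channels")
print("total blocks:", sum(blocks for _, blocks in b1))  # 23 (B0 has 16)
```

With a width coefficient of 1.0 the channel counts stay the same as B0, so B1 is essentially a deeper B0 (23 instead of 16 MBConv blocks) evaluated at a higher input resolution.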
3. Runtime Environment
See the article "Installing and Verifying the Deep Learning Framework PyTorch".
Library | Version |
---|---|
python | 3.12.8 |
pandas | 1.5.3 |
numpy | 1.26.4 |
torch | 2.6.0+cu126 |
torchvision | 0.21.0+cu126 |
timm | 1.0.15 |
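timm (PyTorch Image Models) is a PyTorch library dedicated to image models, maintained by Ross Wightman. It bundles a large collection of pretrained computer vision models (including EfficientNet and Vision Transformer) and provides a unified interface, training scripts, and utilities. It can be installed in any of the following ways:

```shell
conda install -c conda-forge timm
# or
pip install timm
# or install the latest version from source
pip install git+https://github.com/rwightman/pytorch-image-models.git
```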
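After installing, one way to compare the local environment with the versions in the table above is to query the installed distributions. The helper below, `installed_versions`, is a hypothetical convenience function that uses only the standard library (`importlib.metadata`, `platform`); it reports `None` for a package that is not installed instead of raising.

```python
import platform
from importlib import metadata

# Versions listed in the runtime-environment table
EXPECTED = {
    "pandas": "1.5.3",
    "numpy": "1.26.4",
    "torch": "2.6.0+cu126",
    "torchvision": "0.21.0+cu126",
    "timm": "1.0.15",
}


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print("python", platform.python_version(), "(table: 3.12.8)")
    for name, version in installed_versions(EXPECTED).items():
        status = "OK" if version == EXPECTED[name] else "differs"
        print(f"{name:12s} installed={version} expected={EXPECTED[name]} -> {status}")
```

A version that "differs" is not necessarily a problem; the table simply records the environment the notebook was run in.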
IV. Source Code Structure
1. Importing Packages

```python
import os
import time
import glob

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import notebook

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.dataset import Dataset
import timm

torch.manual_seed(0)  # fix the PyTorch random seed
# cuDNN autotuning (benchmark=True, deterministic=False) speeds up training,
# but results are then not bit-for-bit reproducible even with a fixed seed
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
```
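2. Data Preparation
- Label definition: the 9 endangered animal classes are
  ['badger', 'bear', 'chimpanzee', 'hare', 'leopard', 'moth', 'panda', 'shark', 'tiger']
- Data loading:
  - Training data: loaded from ./data/train/, with one subdirectory per class; the subdirectory name is the label
  - Test data: loaded from ./data/test/
- Shuffling: np.random.shuffle randomly shuffles the list of training paths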
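```python
# The 9 endangered animal classes; the index in this list is the class id
label_list = ['badger', 'bear', 'chimpanzee', 'hare', 'leopard', 'moth', 'panda', 'shark', 'tiger']

# Training image paths: ./data/train/<class name>/<image>.jpg
train_path = sorted(glob.glob('./data/train/*/*.jpg'))  # sort so the order does not depend on the file system
np.random.seed(0)  # torch.manual_seed does not seed NumPy; fix it so the train/val split is reproducible
np.random.shuffle(train_path)
# The parent directory name is the label; os.path handles the separators glob returns on Windows and Linux
train_label = [label_list.index(os.path.basename(os.path.dirname(x))) for x in train_path]

# Test image paths
test_path = sorted(glob.glob('./data/test/*.jpg'))
```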
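The notebook holds out the last 50 shuffled paths for validation (train_path[-50:] in the data loaders), so the number of validation images per class varies with the shuffle. If a balanced split is preferred, a stratified holdout is one option. The sketch below is an alternative we are suggesting, not part of the original notebook; `stratified_split` and `per_class` are hypothetical names, and the toy paths only mimic the dataset's shape (9 classes × 45 images).

```python
import random
from collections import defaultdict


def stratified_split(paths, labels, per_class, seed=0):
    """Hold out `per_class` items of every class for validation; return (train, val) lists of (path, label)."""
    by_class = defaultdict(list)
    for path, label in zip(paths, labels):
        by_class[label].append(path)

    rng = random.Random(seed)  # local RNG so the split is reproducible
    train, val = [], []
    for label in sorted(by_class):
        items = sorted(by_class[label])
        rng.shuffle(items)
        val += [(p, label) for p in items[:per_class]]
        train += [(p, label) for p in items[per_class:]]
    rng.shuffle(train)  # mix classes in the training list
    return train, val


# Toy example with the dataset's shape: 9 classes x 45 images
paths = [f"class{c}/{i}.jpg" for c in range(9) for i in range(45)]
labels = [c for c in range(9) for _ in range(45)]
train, val = stratified_split(paths, labels, per_class=5)
print(len(train), len(val))  # 360 45
```

To use it with the notebook's loaders, unpack the pairs, e.g. `tr_paths, tr_labels = map(list, zip(*train))`, and pass them to XFDataset in place of the [:-50] / [-50:] slices.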
3. Utility Classes
AverageMeter: tracks the latest value and the running, sample-weighted average of a training metric (loss or accuracy)
```python
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
```
ProgressMeter: prints training progress (batch counter followed by the meters)
```python
class ProgressMeter(object):
    def __init__(self, num_batches, *meters):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = ""

    def pr2int(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
```
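Because AverageMeter.update(val, n) weights each batch mean by its batch size n, the epoch average equals the true per-sample average even when the last batch is smaller. The accuracies below are made up purely to illustrate the difference from naively averaging batch means; the batch sizes correspond to 355 training images (405 minus the 50 held out) in batches of 32.

```python
# Made-up per-batch accuracies (%): 11 full batches of 32 images and a final batch of 3
batch_acc = [50.0] * 11 + [0.0]
batch_size = [32] * 11 + [3]

# What AverageMeter computes: sum(val * n) / sum(n)
weighted = sum(a * n for a, n in zip(batch_acc, batch_size)) / sum(batch_size)
# Naive mean of batch means, which over-weights the small last batch
unweighted = sum(batch_acc) / len(batch_acc)

print(f"weighted={weighted:.2f}  unweighted={unweighted:.2f}")  # weighted=49.58  unweighted=45.83
```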
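4. Dataset Class (XFDataset)
- Loads each image as RGB and applies the given transforms (data augmentation and normalization)
- The same class serves the training, validation, and test sets; only the transforms and labels passed in differ (the test set uses dummy labels of 0)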
```python
class XFDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        self.transform = transform  # may be None (no transform)

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        # return the label as an int64 tensor, the dtype CrossEntropyLoss expects
        # (np.array(label) is int32 on Windows with NumPy 1.x)
        return img, torch.tensor(self.img_label[index], dtype=torch.long)

    def __len__(self):
        return len(self.img_path)
```
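5. Model Construction
- Create an EfficientNet-B1 model with the timm library
- Pass num_classes=9 so that timm builds the classifier head with 9 outputs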
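```python
model = timm.create_model('efficientnet_b1',
                          pretrained=False,  # no pretrained weights: the model is trained from scratch
                          num_classes=9)     # 9-class classification head
model.cuda()  # move the model to the GPU
```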
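6. Training Procedure
Training data augmentation:
- Resize to 256×256
- Random horizontal flip
- Random vertical flip
- Color jitter (brightness and hue adjustment)
- Conversion to a tensor and normalization with the ImageNet mean and standard deviation
The validation (and later test) pipelines use only resizing and normalization.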
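The last step in the list above, ToTensor followed by Normalize, scales 8-bit pixels to [0, 1], moves the channel axis first, and standardizes each channel with the ImageNet statistics. The NumPy sketch below reproduces that arithmetic (plus a horizontal flip) on a dummy image so the numbers can be checked without PyTorch; `to_normalized_chw` and `hflip` are illustrative helpers, not torchvision functions.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(img_hwc_uint8):
    x = img_hwc_uint8.astype(np.float32) / 255.0  # ToTensor: uint8 [0, 255] -> float [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD        # Normalize: per-channel (x - mean) / std
    return x.transpose(2, 0, 1)                   # HWC -> CHW, the layout ToTensor returns


def hflip(img_hwc):
    return img_hwc[:, ::-1, :]  # what RandomHorizontalFlip does when the flip is drawn


img = np.zeros((256, 256, 3), dtype=np.uint8)  # dummy black 256x256 RGB image
out = to_normalized_chw(img)
print(out.shape)     # (3, 256, 256)
print(out[:, 0, 0])  # -mean/std per channel, about [-2.118, -2.036, -1.804]
```

Normalizing with the ImageNet statistics is the usual convention for timm/torchvision models; here it mainly centres the inputs, since the model is trained from scratch.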
```python
# Data loaders: all but the last 50 shuffled images for training, the last 50 for validation
train_loader = torch.utils.data.DataLoader(
    XFDataset(train_path[:-50], train_label[:-50],
              transforms.Compose([
                  transforms.Resize((256, 256)),
                  transforms.RandomHorizontalFlip(),
                  transforms.RandomVerticalFlip(),
                  transforms.ColorJitter(brightness=.5, hue=.3),
                  transforms.ToTensor(),
                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
              ])
    ), batch_size=32, shuffle=True, num_workers=0, pin_memory=False
)

val_loader = torch.utils.data.DataLoader(
    XFDataset(train_path[-50:], train_label[-50:],
              transforms.Compose([
                  transforms.Resize((256, 256)),
                  transforms.ToTensor(),
                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
              ])
    ), batch_size=32, shuffle=False, num_workers=0, pin_memory=False
)
```
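These loader settings determine the batch counters that appear in the training log: the training loader gets 405 − 50 = 355 images and the validation loader 50. The plain-Python arithmetic below reproduces the `[ 0/12]` and `2/2` counters; since DataLoader's drop_last defaults to False, the final training batch holds only the remaining 3 images, and with progress printed every 100 batches only batch 0 is printed per epoch.

```python
import math

n_images, n_val, batch_size = 405, 50, 32  # values from the dataset and the DataLoader calls
n_train = n_images - n_val                  # train_path[:-50]
n_test = 135

train_batches = math.ceil(n_train / batch_size)  # drop_last=False keeps the partial batch
val_batches = math.ceil(n_val / batch_size)
test_batches = math.ceil(n_test / batch_size)
last_train_batch = n_train - (train_batches - 1) * batch_size

print(n_train, train_batches, val_batches, test_batches, last_train_batch)  # 355 12 2 5 3
printed = [i for i in range(train_batches) if i % 100 == 0]  # batches that train() prints
print(printed)  # [0]
```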
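Training:
- Loss function: cross-entropy loss, nn.CrossEntropyLoss()
- Optimizer: AdamW with learning rate 1e-4
- Learning-rate schedule: StepLR, multiplying the learning rate by 0.85 (a 15% decrease) every 4 epochs
- Training loop: 20 epochs
- Time, loss, and accuracy are recorded for every epoch
- Progress is printed every 100 batches
- After each epoch the model is evaluated on the validation set
- The best model (by validation accuracy) is saved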
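```python
# Training configuration
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)  # small learning rate
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)  # lr *= 0.85 every 4 epochs

# Training loop (train() and validate() are defined below and must be run first)
best_acc = 0.0  # best validation accuracy so far; must exist before the first comparison
for epoch in range(20):
    print('Epoch: ', epoch)
    train(train_loader, model, criterion, optimizer, epoch)
    scheduler.step()
    val_acc = validate(val_loader, model, criterion)
    acc = float(val_acc.avg)  # works whether avg is a Python float or a 0-dim tensor
    if acc > best_acc:
        best_acc = round(acc, 2)
        torch.save(model.state_dict(), f'./model_{best_acc}.pt')
```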
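With StepLR(step_size=4, gamma=0.85) and scheduler.step() called once per epoch, the learning rate used in epoch e is 1e-4 · 0.85^⌊e/4⌋. The closed form below lists the rates for the 20 epochs; `step_lr` is a hypothetical helper that mirrors the schedule, not a PyTorch function.

```python
def step_lr(epoch, base_lr=1e-4, step_size=4, gamma=0.85):
    """Learning rate in effect during `epoch` under StepLR stepped once per epoch."""
    return base_lr * gamma ** (epoch // step_size)


lrs = [step_lr(e) for e in range(20)]
for e in range(0, 20, 4):
    print(f"epochs {e}-{e + 3}: lr = {lrs[e]:.3e}")
```

The learning rate therefore takes five distinct values; the last four epochs run at 1e-4 · 0.85⁴ ≈ 5.2e-5, about half the initial rate.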
Training function train():
```python
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(train_loader), batch_time, losses, top1)

    # switch to train mode (dropout on, batch norm updates its statistics)
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True).long()  # CrossEntropyLoss expects int64 class indices

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure top-1 accuracy (percent) and record loss
        losses.update(loss.item(), input.size(0))
        acc = (output.argmax(1) == target).float().mean() * 100
        top1.update(acc, input.size(0))

        # compute gradients and take an optimizer (AdamW) step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 100 == 0:
            progress.pr2int(i)
```
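Validation function validate():
- Computes the loss and accuracy on the validation set
- Uses model.eval() (dropout off, batch norm in inference mode) and torch.no_grad() (no gradients computed), so the model parameters are not changed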
```python
def validate(val_loader, model, criterion):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(val_loader), batch_time, losses, top1)

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in notebook.tqdm(enumerate(val_loader), total=len(val_loader)):
            input = input.cuda()
            target = target.cuda().long()  # CrossEntropyLoss expects int64 class indices

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc = (output.argmax(1) == target).float().mean() * 100
            losses.update(loss.item(), input.size(0))
            top1.update(acc, input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return top1
```
7. Prediction and Submission
Prediction:
```python
# Test data loader (labels are dummies; only the images are used)
test_loader = torch.utils.data.DataLoader(
    XFDataset(test_path, [0] * len(test_path),
              transforms.Compose([
                  transforms.Resize((256, 256)),
                  transforms.ToTensor(),
                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
              ])
    ), batch_size=32, shuffle=False, num_workers=0, pin_memory=False
)

# Generate predictions (uses the weights from the last epoch currently in `model`)
val_label = pd.DataFrame()
val_label['uuid'] = [os.path.basename(x) for x in test_path]  # file name only, on Windows and Linux
val_label['y_pred'] = predict(test_loader, model, 1).argmax(1)  # tta=1: a single pass

# Map class ids to names and save the submission file
val_label['label'] = val_label['y_pred'].apply(lambda x: label_list[x])
val_label[['uuid', 'label']].to_csv('submit.csv', index=False)
```
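Prediction function predict():
- Optional test-time augmentation (TTA): runs tta passes over the test set (10 by default) and sums the softmax probabilities of all passes
- Returns an array of shape (number of test images, 9) whose argmax is the predicted class; because the test transforms above are deterministic, every pass produces the same output, which is why the call above uses tta=1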
```python
def predict(test_loader, model, tta=10):
    # switch to evaluate mode
    model.eval()

    test_pred_tta = None
    for _ in range(tta):
        test_pred = []
        with torch.no_grad():
            for i, (input, target) in notebook.tqdm(enumerate(test_loader), total=len(test_loader)):
                input = input.cuda()  # labels are dummies and are not needed

                # compute output and convert logits to class probabilities
                output = model(input)
                output = F.softmax(output, dim=1)
                test_pred.append(output.cpu().numpy())

        test_pred = np.vstack(test_pred)
        # accumulate (sum) the probabilities over TTA passes
        if test_pred_tta is None:
            test_pred_tta = test_pred
        else:
            test_pred_tta += test_pred
    return test_pred_tta
```
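For TTA to change anything, each pass must see a different view of the images, for example the original plus a horizontal flip; in the notebook that flip could be applied to the NCHW input batch with torch.flip(input, dims=[3]) before calling the model. The NumPy sketch below shows only the averaging step. The linear "model" `toy_logits`, its random weights `W`, and `predict_tta` are stand-ins invented for illustration, not the EfficientNet model.

```python
import numpy as np


def softmax(logits, axis=1):
    z = logits - logits.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
W = rng.normal(size=(3 * 8 * 8, 9))  # toy stand-in for a trained classifier with 9 classes


def toy_logits(images):  # images: (N, 3, 8, 8), NCHW like the PyTorch batches
    return images.reshape(len(images), -1) @ W


def predict_tta(images):
    views = [images, images[:, :, :, ::-1]]  # original + horizontal flip along the width axis
    probs = [softmax(toy_logits(v)) for v in views]
    return np.mean(probs, axis=0)  # average class probabilities over the views


images = rng.normal(size=(4, 3, 8, 8))
p = predict_tta(images)
print(p.shape, p.argmax(axis=1))
```

Averaging instead of summing does not change the argmax, so either works for classification; the sum returned by predict() is simply tta times the mean.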
8. Sample Run
The run produces the following output:
Epoch: 0
[ 0/12] Time 64.895 (64.895) Loss 3.0313e+00 (3.0313e+00) Acc@1 6.25 ( 6.25)
100%
2/2 [00:02<00:00, 1.15s/it]
* Acc@1 10.000
Epoch: 1
[ 0/12] Time 2.303 ( 2.303) Loss 2.3945e+00 (2.3945e+00) Acc@1 12.50 ( 12.50)
100%
2/2 [00:03<00:00, 1.59s/it]
* Acc@1 12.000
Epoch: 2
[ 0/12] Time 5.129 ( 5.129) Loss 2.4899e+00 (2.4899e+00) Acc@1 9.38 ( 9.38)
100%
2/2 [00:01<00:00, 1.13it/s]
* Acc@1 14.000
Epoch: 3
[ 0/12] Time 3.143 ( 3.143) Loss 2.6533e+00 (2.6533e+00) Acc@1 18.75 ( 18.75)
100%
2/2 [00:02<00:00, 1.04it/s]
* Acc@1 14.000
Epoch: 4
[ 0/12] Time 3.462 ( 3.462) Loss 2.4884e+00 (2.4884e+00) Acc@1 15.62 ( 15.62)
100%
2/2 [00:01<00:00, 1.11it/s]
* Acc@1 14.000
Epoch: 5
[ 0/12] Time 3.107 ( 3.107) Loss 2.4229e+00 (2.4229e+00) Acc@1 9.38 ( 9.38)
100%
2/2 [00:01<00:00, 1.15it/s]
* Acc@1 16.000
Epoch: 6
[ 0/12] Time 3.001 ( 3.001) Loss 2.3651e+00 (2.3651e+00) Acc@1 6.25 ( 6.25)
100%
2/2 [00:02<00:00, 1.40s/it]
* Acc@1 18.000
Epoch: 7
[ 0/12] Time 3.505 ( 3.505) Loss 2.0442e+00 (2.0442e+00) Acc@1 15.62 ( 15.62)
100%
2/2 [00:02<00:00, 1.06it/s]
* Acc@1 14.000
Epoch: 8
[ 0/12] Time 2.827 ( 2.827) Loss 1.9845e+00 (1.9845e+00) Acc@1 15.62 ( 15.62)
100%
2/2 [00:02<00:00, 1.06it/s]
* Acc@1 20.000
Epoch: 9
[ 0/12] Time 3.837 ( 3.837) Loss 2.2321e+00 (2.2321e+00) Acc@1 15.62 ( 15.62)
100%
2/2 [00:01<00:00, 1.15it/s]
* Acc@1 20.000
Epoch: 10
[ 0/12] Time 2.900 ( 2.900) Loss 2.1750e+00 (2.1750e+00) Acc@1 18.75 ( 18.75)
100%
2/2 [00:02<00:00, 1.07s/it]
* Acc@1 20.000
Epoch: 11
[ 0/12] Time 3.021 ( 3.021) Loss 1.9955e+00 (1.9955e+00) Acc@1 28.12 ( 28.12)
100%
2/2 [00:03<00:00, 1.56s/it]
* Acc@1 14.000
Epoch: 12
[ 0/12] Time 2.698 ( 2.698) Loss 2.0443e+00 (2.0443e+00) Acc@1 18.75 ( 18.75)
100%
2/2 [00:02<00:00, 1.21s/it]
* Acc@1 18.000
Epoch: 13
[ 0/12] Time 4.560 ( 4.560) Loss 2.0945e+00 (2.0945e+00) Acc@1 15.62 ( 15.62)
100%
2/2 [00:02<00:00, 1.13s/it]
* Acc@1 22.000
Epoch: 14
[ 0/12] Time 3.307 ( 3.307) Loss 2.1820e+00 (2.1820e+00) Acc@1 15.62 ( 15.62)
100%
2/2 [00:02<00:00, 1.00it/s]
* Acc@1 26.000
Epoch: 15
[ 0/12] Time 2.934 ( 2.934) Loss 2.1960e+00 (2.1960e+00) Acc@1 18.75 ( 18.75)
100%
2/2 [00:02<00:00, 1.20s/it]
* Acc@1 20.000
Epoch: 16
[ 0/12] Time 3.537 ( 3.537) Loss 2.5063e+00 (2.5063e+00) Acc@1 3.12 ( 3.12)
100%
2/2 [00:02<00:00, 1.02s/it]
* Acc@1 28.000
Epoch: 17
[ 0/12] Time 3.439 ( 3.439) Loss 2.4042e+00 (2.4042e+00) Acc@1 6.25 ( 6.25)
100%
2/2 [00:02<00:00, 1.04s/it]
* Acc@1 26.000
Epoch: 18
[ 0/12] Time 3.201 ( 3.201) Loss 2.0144e+00 (2.0144e+00) Acc@1 18.75 ( 18.75)
100%
2/2 [00:03<00:00, 1.49s/it]
* Acc@1 12.000
Epoch: 19
[ 0/12] Time 3.895 ( 3.895) Loss 2.0476e+00 (2.0476e+00) Acc@1 18.75 ( 18.75)
100%
2/2 [00:02<00:00, 1.06s/it]
* Acc@1 28.000
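Part of the resulting submit.csv from the original run on Windows is shown below (the uuid values still carry the test\ directory prefix):
uuid (image name) | label (predicted species) |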
---|---|
test\00981a3b70.jpg | panda |
test\033c8c8787.jpg | moth |
test\0373ac21ad.jpg | panda |
test\056fd9df3e.jpg | bear |
test\0798f04ce0.jpg | badger |
test\10c84dc56f.jpg | hare |
test\11dfaa6649.jpg | tiger |
test\148c987a2a.jpg | leopard |
test\15df975139.jpg | tiger |
test\173fedc818.jpg | chimpanzee |
… | … |
Source Code License
GPL-v3