Yanxishe 2019 Chest X-Ray Dataset and Pneumonia Detection

Abstract:

Collection: AI Cases - CV - Medical
Competition: AI Yanxishe Chest X-Ray Pneumonia Detection Challenge
Organizer: AI Yanxishe
Homepage: https://god.yanxishe.com/13
AI problem: image classification
Dataset: Yanxishe 2019 Chest X-Ray Dataset
Dataset value: pneumonia detection
Solution: PyTorch framework, ResNet18 model

I. Competition Description

The dataset provided by the organizer contains 5,857 chest X-ray images in total: a training set of 4,099 images (normal and pneumonia images) and a test set of 1,757 images to be labeled by the trained model (0 = normal, 1 = pneumonia).

Task

Train a model that correctly identifies pneumonia in chest X-ray images.

Evaluation metric

True: the number of test images the model classifies correctly

Total: the total number of test images

Score = (True / Total) * 100

Submitting results

Submit a CSV file with two fields:

First field: the test image ID;

Second field: the predicted label, 0 = normal, 1 = pneumonia.

II. Dataset Contents

Data structure

Working directory: xray_dataset

Training images:

  • Normal: ./xray_dataset/train/NORMAL
  • Pneumonia: ./xray_dataset/train/PNEUMONIA

Test images: ./xray_dataset/test/{0}.jpeg
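
The training script in the next section indexes into a train_jpg array of image paths that the article never builds. A minimal sketch of collecting it from the directory layout above (the variable name and glob pattern follow the later code and the paths shown here; they are assumptions, not part of the original write-up):

import glob
import numpy as np

# Collect all training images; the label can later be recovered from the folder
# name ('NORMAL' -> 0, 'PNEUMONIA' -> 1).
train_jpg = np.array(glob.glob('./xray_dataset/train/*/*.jpeg'))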

Sample images: (omitted)

Dataset license

GPL

III. Sample Solution

Solution

The solution uses the PyTorch framework and a ResNet18 classifier, with TTA (test-time augmentation) at inference time (a minimal sketch of the classifier follows the overview below). ResNet18 in brief:

  • Proposed: 2015, by Kaiming He's team (the ResNet paper went on to win the CVPR 2016 Best Paper Award)
  • Core idea: the residual learning framework (shortcut/skip connections)
  • Depth: 18 layers (convolutional and fully connected layers plus skip connections)
  • Positioning: a lightweight ResNet variant, well suited to settings with limited compute
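
The training code later in this article instantiates a VisitNet model that is never defined. A plausible minimal version is a torchvision ResNet18 whose final fully connected layer is replaced by a 2-class head; whether the original solution initialized it from ImageNet weights is not stated, so that choice is an assumption:

import torch.nn as nn
import torchvision.models as models

class VisitNet(nn.Module):
    """ResNet18 backbone with a 2-class head (0 = normal, 1 = pneumonia)."""
    def __init__(self):
        super().__init__()
        # ImageNet pretraining is an assumption; drop `weights=...` to train from scratch
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        backbone.fc = nn.Linear(backbone.fc.in_features, 2)
        self.model = backbone

    def forward(self, x):
        return self.model(x)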

Install PyTorch 2.4.1

Pick the build that matches your CUDA version, for example pytorch==2.4.1:

conda create -n pytorch241-gpu python=3.10
conda activate pytorch241-gpu
conda install pytorch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 pytorch-cuda=12.1 -c pytorch -c nvidia

Installed Python and PyTorch versions:

python                    3.10.15              h4607a30_1    anaconda
pytorch                   2.4.1           py3.10_cuda12.1_cudnn9_0   pytorch
pytorch-cuda             12.1                 hde6ce7c_6   pytorch
pytorch-mutex             1.0                       cuda   pytorch
torchaudio               2.4.1                   pypi_0   pypi
torchvision               0.19.1                   pypi_0   pypi
...
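
A quick sanity check (not part of the original article) that the CUDA build is actually being used:

import torch

print(torch.__version__)          # expected: 2.4.1
print(torch.version.cuda)         # expected: 12.1
print(torch.cuda.is_available())  # True if a compatible GPU and driver are present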

Import packages

import os, sys, glob, argparse
import pandas as pd
import numpy as np
from tqdm import tqdm

import time, datetime
import pdb, traceback

import cv2
from PIL import Image
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
# from efficientnet_pytorch import EfficientNet  # not used by this ResNet18 solution; can be removed

import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset

Pipeline

1. Training function

def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter('Time', ':6.3f')
    # data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    # top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader), batch_time, losses, top1)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 2))
        losses.update(loss.item(), input.size(0))
        top1.update(acc1[0], input.size(0))
        # top5.update(acc5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 100 == 0:
            progress.pr2int(i)
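
The train() function above, and the validation step in the training loop below, rely on AverageMeter, ProgressMeter (including its pr2int method), accuracy and validate helpers that the article does not show. A minimal sketch consistent with how they are called, closely following the official PyTorch ImageNet example (the exact implementation used by the author may differ):

import torch

class AverageMeter(object):
    """Tracks the latest value and the running average of a metric."""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Prints lines such as '[100/369]  Time ...  Loss ...  Acc@1 ...'."""
    def __init__(self, num_batches, *meters):
        self.num_batches = num_batches
        self.meters = meters

    def pr2int(self, batch):
        entries = ['[{0}/{1}]'.format(batch, self.num_batches)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))


def accuracy(output, target, topk=(1,)):
    """Computes the top-k accuracy (in percent) for the given predictions."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def validate(val_loader, model, criterion):
    """Evaluates the model and returns the Acc@1 meter (used as val_acc below)."""
    top1 = AverageMeter('Acc@1', ':6.2f')
    model.eval()
    with torch.no_grad():
        for input, target in val_loader:
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            output = model(input)
            acc1, _ = accuracy(output, target, topk=(1, 2))
            top1.update(acc1[0], input.size(0))
    print(' * Acc@1 {0:.3f}'.format(top1.avg.item()))
    return top1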

2. Training

skf = KFold(n_splits=10, random_state=233, shuffle=True)  # plain KFold despite the name; labels are not stratified
for fold_idx, (train_idx, val_idx) in enumerate(skf.split(train_jpg, train_jpg)):
    # print(fold_idx, train_idx, val_idx)

    train_loader = torch.utils.data.DataLoader(
        QRDataset(train_jpg[train_idx],
                  transforms.Compose([
                      # transforms.RandomGrayscale(),
                      transforms.Resize((512, 512)),
                      # transforms.RandomAffine(10),
                      # transforms.ColorJitter(hue=.05, saturation=.05),
                      # transforms.RandomCrop((88, 88)),
                      transforms.RandomHorizontalFlip(),
                      transforms.RandomVerticalFlip(),
                      transforms.ToTensor(),
                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                  ])),
        batch_size=10, shuffle=True, num_workers=20, pin_memory=True
    )

    val_loader = torch.utils.data.DataLoader(
        QRDataset(train_jpg[val_idx],
                  transforms.Compose([
                      transforms.Resize((512, 512)),
                      # transforms.Resize((124, 124)),
                      # transforms.RandomCrop((88, 88)),
                      transforms.ToTensor(),
                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                  ])),
        batch_size=10, shuffle=False, num_workers=10, pin_memory=True
    )

    model = VisitNet().cuda()
    # model = nn.DataParallel(model).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), 0.01)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)
    best_acc = 0.0
    for epoch in range(10):
        print('Epoch: ', epoch)

        train(train_loader, model, criterion, optimizer, epoch)
        val_acc = validate(val_loader, model, criterion)
        scheduler.step()  # step the LR schedule after the epoch's optimizer updates

        # keep the checkpoint with the best validation accuracy for this fold
        if val_acc.avg.item() > best_acc:
            best_acc = val_acc.avg.item()
            torch.save(model.state_dict(), './resnet18_fold{0}.pt'.format(fold_idx))
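
The QRDataset class fed to the data loaders above is another piece the article leaves out. A minimal sketch that matches how it is called (an array of image paths plus a torchvision transform); deriving the label from the folder name follows the dataset layout described earlier and is an assumption:

import torch
from PIL import Image
from torch.utils.data.dataset import Dataset

class QRDataset(Dataset):
    """Loads one chest X-ray image and derives its label from the folder name."""
    def __init__(self, img_paths, transform=None):
        self.img_paths = img_paths
        self.transform = transform

    def __getitem__(self, index):
        path = self.img_paths[index]
        img = Image.open(path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        # 0 = normal, 1 = pneumonia; test images carry no folder label and fall back to 0
        label = 1 if 'PNEUMONIA' in path else 0
        return img, torch.tensor(label)

    def __len__(self):
        return len(self.img_paths)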

3. Execution results

python3 1_train.py
python3 2_predict.py

Training log:

Epoch:  0
[ 0/369]       Time 149.469 (149.469) Loss 7.4950e-01 (7.4950e-01)   Acc@1 40.00 ( 40.00)
[100/369]       Time 1.039 ( 2.371)   Loss 7.4575e-02 (3.0477e-01)   Acc@1 100.00 ( 85.45)
[200/369]       Time 1.316 ( 1.944)   Loss 3.6670e-01 (2.5079e-01)   Acc@1 80.00 ( 88.86)
[300/369]       Time 1.483 ( 1.762)   Loss 9.2747e-02 (2.2640e-01)   Acc@1 90.00 ( 90.50)
* Acc@1 94.146 Acc@5 100.000
Epoch: 1
[ 0/369]       Time 125.449 (125.449) Loss 8.7460e-02 (8.7460e-02)   Acc@1 100.00 (100.00)
[100/369]       Time 1.158 ( 2.387)   Loss 7.5592e-02 (1.6483e-01)   Acc@1 100.00 ( 94.16)
[200/369]       Time 1.219 ( 1.855)   Loss 1.2777e-01 (1.4759e-01)   Acc@1 100.00 ( 94.93)
[300/369]       Time 1.218 ( 1.725)   Loss 3.1686e-02 (1.5421e-01)   Acc@1 100.00 ( 94.72)
* Acc@1 97.805 Acc@5 100.000
Epoch: 2
[ 0/369]       Time 191.227 (191.227) Loss 7.4547e-02 (7.4547e-02)   Acc@1 100.00 (100.00)
[100/369]       Time 1.611 ( 3.359)   Loss 3.0406e-01 (1.2452e-01)   Acc@1 80.00 ( 95.45)
[200/369]       Time 1.484 ( 2.373)   Loss 9.4277e-03 (1.1199e-01)   Acc@1 100.00 ( 96.07)
[300/369]       Time 1.201 ( 2.051)   Loss 3.0028e-01 (1.0985e-01)   Acc@1 90.00 ( 96.28)
* Acc@1 94.878 Acc@5 100.000
Epoch: 3
[ 0/369]       Time 195.062 (195.062) Loss 4.3558e-02 (4.3558e-02)   Acc@1 100.00 (100.00)
[100/369]       Time 1.026 ( 2.901)   Loss 1.4391e-02 (7.9261e-02)   Acc@1 100.00 ( 97.23)
[200/369]       Time 1.093 ( 1.999)   Loss 1.9612e-02 (9.3397e-02)   Acc@1 100.00 ( 96.67)
[300/369]       Time 1.128 ( 1.698)   Loss 3.7451e-02 (9.0811e-02)   Acc@1 100.00 ( 96.88)
* Acc@1 97.561 Acc@5 100.000
Epoch: 4
[ 0/369]       Time 136.914 (136.914) Loss 8.3409e-01 (8.3409e-01)   Acc@1 90.00 ( 90.00)
[100/369]       Time 1.046 ( 2.305)   Loss 2.9704e-01 (8.0157e-02)   Acc@1 90.00 ( 97.52)
[200/369]       Time 1.083 ( 1.701)   Loss 9.9393e-02 (9.8687e-02)   Acc@1 90.00 ( 96.57)
[300/369]       Time 1.098 ( 1.503)   Loss 1.6535e-01 (9.4286e-02)   Acc@1 90.00 ( 96.54)
* Acc@1 97.805 Acc@5 100.000
Epoch: 5
[ 0/369]       Time 147.025 (147.025) Loss 2.6786e-02 (2.6786e-02)   Acc@1 100.00 (100.00)
[100/369]       Time 0.974 ( 2.375)   Loss 5.0860e-02 (8.8977e-02)   Acc@1 100.00 ( 96.63)
[200/369]       Time 1.043 ( 1.713)   Loss 1.0801e-01 (7.2678e-02)   Acc@1 90.00 ( 97.51)
[300/369]       Time 1.097 ( 1.494)   Loss 6.3345e-02 (7.6149e-02)   Acc@1 100.00 ( 97.51)
* Acc@1 96.585 Acc@5 100.000
Epoch: 6
[ 0/369]       Time 166.994 (166.994) Loss 1.0245e-01 (1.0245e-01)   Acc@1 100.00 (100.00)
[100/369]       Time 1.102 ( 2.657)   Loss 1.4437e-02 (5.9280e-02)   Acc@1 100.00 ( 97.82)
[200/369]       Time 1.038 ( 1.889)   Loss 9.6385e-03 (6.2406e-02)   Acc@1 100.00 ( 97.76)
...

Output: resnet18_fold0.pt ...
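
The prediction script 2_predict.py is not reproduced in the article. A hedged sketch of what TTA inference over the saved fold checkpoints could look like, reusing the QRDataset and VisitNet sketches above (the number of TTA passes, the checkpoint averaging, and the submission column names are assumptions):

import glob
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

# Test images are named {id}.jpeg, so sort them by their numeric ID
test_jpg = sorted(glob.glob('./xray_dataset/test/*.jpeg'),
                  key=lambda p: int(p.split('/')[-1].split('.')[0]))
test_jpg = np.array(test_jpg)

# Light augmentation for TTA: each pass sees a randomly flipped version of the image
tta_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
test_loader = torch.utils.data.DataLoader(
    QRDataset(test_jpg, tta_transform),
    batch_size=10, shuffle=False, num_workers=10, pin_memory=True)

test_pred = None
for ckpt in sorted(glob.glob('./resnet18_fold*.pt')):   # average all fold models
    model = VisitNet().cuda()
    model.load_state_dict(torch.load(ckpt))
    model.eval()
    for _ in range(5):                                   # 5 TTA passes per fold (assumption)
        fold_pred = []
        with torch.no_grad():
            for input, _ in test_loader:
                output = model(input.cuda(non_blocking=True))
                fold_pred.append(F.softmax(output, dim=1).cpu().numpy())
        fold_pred = np.concatenate(fold_pred)
        test_pred = fold_pred if test_pred is None else test_pred + fold_pred

# Final label = class with the highest summed probability; write the submission CSV
ids = [int(p.split('/')[-1].split('.')[0]) for p in test_jpg]
pd.DataFrame({'id': ids, 'label': test_pred.argmax(1)}).to_csv('submission.csv', index=False)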

Source code license

GPL-3.0 license

IV. Get the Case Bundle

Package size: 1.2 GB

Download: Medical Industry Vision Case Bundle
