train.py

import argparse
import os
from collections import OrderedDict
from glob import glob

import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import yaml
import albumentations as A
from albumentations.core.composition import Compose, OneOf
from sklearn.model_selection import train_test_split
from torch.optim import lr_scheduler
from tqdm import tqdm

import archs
import losses
from dataset import Dataset
from metrics import iou_score, dice_coef, precision_score, recall_score
from utils import AverageMeter, str2bool

ARCH_NAMES = archs.__all__
LOSS_NAMES = list(losses.__all__)  # copy so the append below doesn't mutate losses.__all__
LOSS_NAMES.append('BCEWithLogitsLoss')


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', default=None,
                        help='model name (default: <dataset>_<arch>_wDS or _woDS)')
    parser.add_argument('--epochs', default=150, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=8, type=int,
                        metavar='N', help='mini-batch size (default: 8)')

    # model
    parser.add_argument('--arch', '-a', metavar='ARCH', default='UKAN_NestedUNet',
                        choices=ARCH_NAMES,
                        help='model architecture: ' +
                             ' | '.join(ARCH_NAMES) +
                             ' (default: UKAN_NestedUNet)')
    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--input_channels', default=3, type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=512, type=int,
                        help='image width')
    parser.add_argument('--input_h', default=512, type=int,
                        help='image height')

    # loss
    parser.add_argument('--loss', default='BCEDiceLoss',
                        choices=LOSS_NAMES,
                        help='loss: ' +
                             ' | '.join(LOSS_NAMES) +
                             ' (default: BCEDiceLoss)')

    # dataset
    parser.add_argument('--dataset', default='dsb2018_96_1000',
                        help='dataset name')
    parser.add_argument('--img_ext', default='.jpg',
                        help='image file extension')
    parser.add_argument('--mask_ext', default='.png',
                        help='mask file extension')

    # optimizer
    parser.add_argument('--optimizer', default='SGD',
                        choices=['Adam', 'SGD'],
                        help='optimizer: ' +
                             ' | '.join(['Adam', 'SGD']) +
                             ' (default: SGD)')
    parser.add_argument('--lr', '--learning_rate', default=1e-3, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight_decay', default=1e-4, type=float,
                        help='weight decay')
    parser.add_argument('--nesterov', default=False, type=str2bool,
                        help='nesterov')

    # scheduler
    parser.add_argument('--scheduler', default='CosineAnnealingLR',
                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])
    parser.add_argument('--min_lr', default=1e-5, type=float,
                        help='minimum learning rate')
    parser.add_argument('--factor', default=0.1, type=float)
    parser.add_argument('--patience', default=2, type=int)
    parser.add_argument('--milestones', default='1,2', type=str)
    parser.add_argument('--gamma', default=2 / 3, type=float)
    parser.add_argument('--early_stopping', default=-1, type=int,
                        metavar='N', help='early stopping (default: -1)')

    parser.add_argument('--num_workers', default=4, type=int)

    config = parser.parse_args()

    return config
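
# Example invocation (a sketch; the flags are the ones defined above, and the
# dataset is expected under inputs/<dataset>/ as used further below):
#
#   python train.py --dataset dsb2018_96_1000 --arch UKAN_NestedUNet \
#       --epochs 150 -b 8 --optimizer SGD --lr 1e-3 --input_w 512 --input_h 512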


def train(config, train_loader, model, criterion, optimizer):
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter(),
                  'dice': AverageMeter(),
                  'precision': AverageMeter(),
                  'recall': AverageMeter()}

    model.train()

    pbar = tqdm(total=len(train_loader))
    for input, target, _ in train_loader:
        input = input.cuda()
        target = target.cuda()

        # compute output
        if config['deep_supervision']:
            outputs = model(input)
            loss = 0
            for output in outputs:
                loss += criterion(output, target)
            loss /= len(outputs)
            output = outputs[-1]
        else:
            output = model(input)
            loss = criterion(output, target)

        # compute metrics
        iou = iou_score(output, target)
        dice = dice_coef(output, target)
        precision = precision_score(output, target)
        recall = recall_score(output, target)

        # compute gradient and do optimizing step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_meters['loss'].update(loss.item(), input.size(0))
        avg_meters['iou'].update(iou, input.size(0))
        avg_meters['dice'].update(dice, input.size(0))
        avg_meters['precision'].update(precision, input.size(0))
        avg_meters['recall'].update(recall, input.size(0))

        postfix = OrderedDict([
            ('loss', avg_meters['loss'].avg),
            ('iou', avg_meters['iou'].avg),
            ('dice', avg_meters['dice'].avg),
            ('precision', avg_meters['precision'].avg),
            ('recall', avg_meters['recall'].avg),
        ])
        pbar.set_postfix(postfix)
        pbar.update(1)
    pbar.close()

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg),
                        ('dice', avg_meters['dice'].avg),
                        ('precision', avg_meters['precision'].avg),
                        ('recall', avg_meters['recall'].avg)])


def validate(config, val_loader, model, criterion):
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter(),
                  'dice': AverageMeter(),
                  'precision': AverageMeter(),
                  'recall': AverageMeter()}

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        pbar = tqdm(total=len(val_loader))
        for input, target, _ in val_loader:
            input = input.cuda()
            target = target.cuda()

            # compute output
            if config['deep_supervision']:
                outputs = model(input)
                loss = 0
                for output in outputs:
                    loss += criterion(output, target)
                loss /= len(outputs)
                output = outputs[-1]
            else:
                output = model(input)
                loss = criterion(output, target)

            # compute metrics
            iou = iou_score(output, target)
            dice = dice_coef(output, target)
            precision = precision_score(output, target)
            recall = recall_score(output, target)

            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['iou'].update(iou, input.size(0))
            avg_meters['dice'].update(dice, input.size(0))
            avg_meters['precision'].update(precision, input.size(0))
            avg_meters['recall'].update(recall, input.size(0))

            postfix = OrderedDict([
                ('loss', avg_meters['loss'].avg),
                ('iou', avg_meters['iou'].avg),
                ('dice', avg_meters['dice'].avg),
                ('precision', avg_meters['precision'].avg),
                ('recall', avg_meters['recall'].avg),
            ])
            pbar.set_postfix(postfix)
            pbar.update(1)
        pbar.close()

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg),
                        ('dice', avg_meters['dice'].avg),
                        ('precision', avg_meters['precision'].avg),
                        ('recall', avg_meters['recall'].avg)])


def main():
    config = vars(parse_args())

    if config['name'] is None:
        if config['deep_supervision']:
            config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
        else:
            config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])
    os.makedirs('models/%s' % config['name'], exist_ok=True)

    print('-' * 20)
    for key in config:
        print('%s: %s' % (key, config[key]))
    print('-' * 20)

    with open('models/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    # define loss function (criterion)
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()

    cudnn.benchmark = True

    # create model
    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])

    # Enable multi-GPU support with DataParallel
    if torch.cuda.device_count() > 1:
        print(f"=> Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)
    model = model.cuda()

    params = filter(lambda p: p.requires_grad, model.parameters())
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(
            params, lr=config['lr'], weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],
                              nesterov=config['nesterov'], weight_decay=config['weight_decay'])
    else:
        raise NotImplementedError

    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'],
                                                   verbose=1, min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')],
                                             gamma=config['gamma'])
    elif config['scheduler'] == 'ConstantLR':
        scheduler = None
    else:
        raise NotImplementedError

    # Data loading code
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

    train_transform = Compose([
        A.RandomRotate90(),
        A.HorizontalFlip(p=0.5),  # horizontal flip with 50% probability
        A.VerticalFlip(p=0.5),  # vertical flip with 50% probability
        OneOf([
            A.HueSaturationValue(),
            A.RandomBrightnessContrast(),
        ], p=1),
        A.Resize(config['input_h'], config['input_w']),
        A.Normalize(),
    ])

    val_transform = Compose([
        A.Resize(config['input_h'], config['input_w']),
        A.Normalize(),
    ])

    train_dataset = Dataset(
        img_ids=train_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=train_transform)
    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    # create the CSV log file
    df = pd.DataFrame(columns=['epoch', 'loss', 'iou', 'dice', 'precision', 'recall',
                              'val_loss', 'val_iou', 'val_dice', 'val_precision', 'val_recall'])
    df.to_csv('models/%s/log.csv' % config['name'], index=False)

    best_iou = 0
    best_epoch = 0
    for epoch in range(config['epochs']):
        print('\nEpoch [%d/%d]' % (epoch, config['epochs']))

        # train for one epoch
        train_log = train(config, train_loader, model, criterion, optimizer)
        val_log = validate(config, val_loader, model, criterion)

        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler.step()
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])

        print('loss %.4f - iou %.4f - dice %.4f - precision %.4f - recall %.4f - val_loss %.4f - val_iou %.4f - val_dice %.4f - val_precision %.4f - val_recall %.4f'
              % (train_log['loss'], train_log['iou'], train_log['dice'], train_log['precision'], train_log['recall'],
                 val_log['loss'], val_log['iou'], val_log['dice'], val_log['precision'], val_log['recall']))

        df = pd.DataFrame([[epoch, train_log['loss'], train_log['iou'], train_log['dice'], train_log['precision'], train_log['recall'],
                          val_log['loss'], val_log['iou'], val_log['dice'], val_log['precision'], val_log['recall']]],
                         columns=['epoch', 'loss', 'iou', 'dice', 'precision', 'recall',
                                 'val_loss', 'val_iou', 'val_dice', 'val_precision', 'val_recall'])
        df.to_csv('models/%s/log.csv' % config['name'], mode='a', header=False, index=False)

        if val_log['iou'] > best_iou:
            print("=> saved best model")
            best_iou = val_log['iou']
            best_epoch = epoch
            torch.save(model.state_dict(), 'models/%s/model.pth' % config['name'])

    print("=> Best IoU: %.4f at epoch %d" % (best_iou, best_epoch))


if __name__ == '__main__':
    main()

archs.py


import torch
from torch import nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from kan import KANLinear
from model.attention.CBAM import CBAMBlock  # import the CBAM attention module

__all__ = ['UKAN_NestedUNet']

class KANLayer(nn.Module):
    """
    KAN层实现,基于Kolmogorov-Arnold网络
    这是一个特殊的神经网络层,使用样条函数来增强网络的表达能力
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # grid size, controls the complexity of the spline functions
        grid_size = 3
        # order of the spline functions
        spline_order = 2
        # noise scale factor used at initialization
        scale_noise = 0.1
        # base scale factor
        scale_base = 1.0
        # spline scale factor
        scale_spline = 1.0
        # base activation function
        base_activation = nn.SiLU
        # grid epsilon, keeps grid points from getting too close together
        grid_eps = 0.02
        # grid range
        grid_range = [-1, 1]

        # first KAN linear layer: maps input features to hidden features
        self.fc1 = KANLinear(
            in_features,
            hidden_features,
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        # second KAN linear layer: maps hidden features to output features
        self.fc2 = KANLinear(
            hidden_features,
            out_features,
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        # depthwise convolution to capture spatial information
        self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True, groups=hidden_features)
        # batch normalization
        self.bn = nn.BatchNorm2d(hidden_features)
        # ReLU activation
        self.relu = nn.ReLU()
        # dropout for regularization
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        """
        前向传播函数
        x: 输入特征 [B, N, C]
        H, W: 特征图的高和宽
        """
        B, N, C = x.shape
        # 应用第一个KAN线性层
        x = self.fc1(x.reshape(B * N, C)).reshape(B, N, -1)
        # 重塑为图像格式以应用卷积
        x = x.transpose(1, 2).view(B, -1, H, W)
        # 应用深度可分离卷积、批归一化和ReLU
        x = self.relu(self.bn(self.dwconv(x)))
        # 重塑回序列格式
        x = x.flatten(2).transpose(1, 2)
        # 应用第二个KAN线性层
        x = self.fc2(x.reshape(B * N, -1)).reshape(B, N, -1)
        # 应用dropout并返回
        return self.drop(x)

class KANBlock(nn.Module):
    """
    KAN块,包含一个LayerNorm和一个KANLayer,并使用残差连接
    """
    def __init__(self, dim, drop=0., drop_path=0., norm_layer=nn.LayerNorm):
        super().__init__()
        # DropPath randomly drops the residual branch for extra regularization
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # layer normalization
        self.norm = norm_layer(dim)
        # KAN layer
        self.layer = KANLayer(in_features=dim, hidden_features=dim, drop=drop)

    def forward(self, x, H, W):
        """
        前向传播函数,实现残差连接
        """
        return x + self.drop_path(self.layer(self.norm(x), H, W))
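
# Shape sketch (not part of the model): KANBlock consumes a token sequence
# [B, N, C] plus its spatial size (H, W) with N == H * W, and the residual
# connection keeps the output shape identical to the input.
def _kan_block_shape_check():
    block = KANBlock(dim=32)
    tokens = torch.randn(2, 8 * 8, 32)  # B=2, N=64, C=32
    out = block(tokens, 8, 8)
    assert out.shape == tokens.shape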

class PatchEmbed(nn.Module):
    """
    图像块嵌入层,将图像转换为序列表示
    """
    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # height and width of the output feature map
        self.H, self.W = img_size[0] // stride, img_size[1] // stride
        # number of patches
        self.num_patches = self.H * self.W
        # convolution that extracts patch features
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        # layer normalization
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        前向传播函数
        将图像转换为序列表示
        """
        # 应用卷积
        x = self.proj(x)
        # 获取输出特征图的尺寸
        _, _, H, W = x.shape
        # 将特征图展平为序列
        x = x.flatten(2).transpose(1, 2)
        # 应用层归一化
        x = self.norm(x)
        return x, H, W
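
# Shape sketch (not part of the model): PatchEmbed turns a feature map into a
# token sequence; with a 3x3 kernel, stride 2, and padding 1 the spatial size
# is halved.
def _patch_embed_shape_check():
    pe = PatchEmbed(img_size=56, patch_size=3, stride=2, in_chans=128, embed_dim=256)
    feat = torch.randn(2, 128, 56, 56)
    tokens, H, W = pe(feat)
    assert (H, W) == (28, 28) and tokens.shape == (2, H * W, 256)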

class VGGBlock(nn.Module):
    """
    VGG块,包含两个卷积层,每个卷积层后跟批归一化和ReLU激活
    """
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # ReLU activation
        self.relu = nn.ReLU(inplace=True)
        # first convolution
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        # first batch normalization
        self.bn1 = nn.BatchNorm2d(middle_channels)
        # second convolution
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        # second batch normalization
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """
        前向传播函数
        """
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        return out

class UKAN_NestedUNet(nn.Module):
    """
    UKAN嵌套UNet模型,将KAN模块集成到UNet++架构中
    KAN模块主要在深层特征(x2_0, x2_2, x3_0, x4_0, x3_1)中发挥作用
    """
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, img_size=224, **kwargs):
        super().__init__()
        # number of filters per level
        self.nb_filter = [32, 64, 128, 256, 512]
        # whether to use deep supervision
        self.deep_supervision = deep_supervision
        # max pooling for downsampling
        self.pool = nn.MaxPool2d(2, 2)
        # bilinear upsampling
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # encoder
        # level-1 encoder
        self.conv0_0 = VGGBlock(input_channels, self.nb_filter[0], self.nb_filter[0])
        # level-2 encoder
        self.conv1_0 = VGGBlock(self.nb_filter[0], self.nb_filter[1], self.nb_filter[1])
        # level-3 encoder
        self.conv2_0 = VGGBlock(self.nb_filter[1], self.nb_filter[2], self.nb_filter[2])

        # PatchEmbed and KANBlock applied to x2_0
        self.patch_embed2 = PatchEmbed(img_size=img_size//4, patch_size=3, stride=1, in_chans=self.nb_filter[2], embed_dim=self.nb_filter[2])
        self.block2 = KANBlock(dim=self.nb_filter[2], norm_layer=nn.LayerNorm)
        
        # level-4 encoder
        self.conv3_0 = VGGBlock(self.nb_filter[2], self.nb_filter[3], self.nb_filter[3])
        # level-4 patch embedding: converts features to a sequence so KAN can be applied
        self.patch_embed3 = PatchEmbed(img_size=img_size//8, patch_size=3, stride=2, in_chans=self.nb_filter[3], embed_dim=self.nb_filter[3])
        # level-4 KAN block, strengthens the feature representation
        self.block3 = KANBlock(dim=self.nb_filter[3], norm_layer=nn.LayerNorm)
        # level-5 patch embedding
        self.patch_embed4 = PatchEmbed(img_size=img_size//16, patch_size=3, stride=2, in_chans=self.nb_filter[3], embed_dim=self.nb_filter[4])
        # level-5 KAN block
        self.block4 = KANBlock(dim=self.nb_filter[4], norm_layer=nn.LayerNorm)

        # decoder
        # level-1 decoder
        self.conv0_1 = VGGBlock(self.nb_filter[0]+self.nb_filter[1], self.nb_filter[0], self.nb_filter[0])
        # level-2 decoder
        self.conv1_1 = VGGBlock(self.nb_filter[1]+self.nb_filter[2], self.nb_filter[1], self.nb_filter[1])
        # level-3 decoder
        self.conv2_1 = VGGBlock(self.nb_filter[2]+self.nb_filter[3], self.nb_filter[2], self.nb_filter[2])
        # level-4 decoder
        self.conv3_1 = VGGBlock(self.nb_filter[3]+self.nb_filter[4], self.nb_filter[3], self.nb_filter[3])
        # KAN block for the level-4 decoder
        self.dblock3 = KANBlock(dim=self.nb_filter[3], norm_layer=nn.LayerNorm)
        # nested-skip decoders
        self.conv0_2 = VGGBlock(self.nb_filter[0]*2+self.nb_filter[1], self.nb_filter[0], self.nb_filter[0])
        self.conv1_2 = VGGBlock(self.nb_filter[1]*2+self.nb_filter[2], self.nb_filter[1], self.nb_filter[1])
        self.conv2_2 = VGGBlock(self.nb_filter[2]*2+self.nb_filter[3], self.nb_filter[2], self.nb_filter[2])
        # KAN block for the level-3 decoder
        self.dblock2 = KANBlock(dim=self.nb_filter[2], norm_layer=nn.LayerNorm)
        # deeper nested skips
        self.conv0_3 = VGGBlock(self.nb_filter[0]*3+self.nb_filter[1], self.nb_filter[0], self.nb_filter[0])
        self.conv1_3 = VGGBlock(self.nb_filter[1]*3+self.nb_filter[2], self.nb_filter[1], self.nb_filter[1])
        self.conv0_4 = VGGBlock(self.nb_filter[0]*4+self.nb_filter[1], self.nb_filter[0], self.nb_filter[0])

        # CBAM modules for the x0_1, x0_2, x0_3 and x0_4 feature maps
        self.cbam_x0_1 = CBAMBlock(channel=self.nb_filter[0], reduction=16, kernel_size=3)
        self.cbam_x0_2 = CBAMBlock(channel=self.nb_filter[0], reduction=16, kernel_size=3)
        self.cbam_x0_3 = CBAMBlock(channel=self.nb_filter[0], reduction=16, kernel_size=3)
        self.cbam_x0_4 = CBAMBlock(channel=self.nb_filter[0], reduction=16, kernel_size=3)

        # output heads
        if self.deep_supervision:
            # with deep supervision, one conv head per decoder output
            self.final1 = nn.Conv2d(self.nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(self.nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(self.nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(self.nb_filter[0], num_classes, kernel_size=1)
        else:
            # otherwise a single conv head for the final output
            self.final = nn.Conv2d(self.nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        """
        前向传播函数
        实现UNet++的前向传播,并在深层特征中应用KAN模块
        """
        # 编码器路径
        # 第一层特征
        x0_0 = self.conv0_0(input)
        # 第二层特征
        x1_0 = self.conv1_0(self.pool(x0_0))
        # 第一层解码器特征
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
        # 应用CBAM到x0_1
        x0_1 = self.cbam_x0_1(x0_1)

        # level-3 features (processed by KAN)
        x2_0_raw = self.conv2_0(self.pool(x1_0))
        # convert x2_0 to a sequence representation and apply the KAN block
        out, H, W = self.patch_embed2(x2_0_raw)
        out = self.block2(out, H, W)
        # convert the sequence representation back to spatial layout
        x2_0 = out.reshape(-1, H, W, self.nb_filter[2]).permute(0, 3, 1, 2).contiguous()

        # level-2 decoder features
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        # level-1 nested decoder features
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
        # apply CBAM to x0_2
        x0_2 = self.cbam_x0_2(x0_2)

        # level-4 features
        x3_0 = self.conv3_0(self.pool(x2_0))
        # convert x3_0 to a sequence representation and apply the KAN block
        out, H, W = self.patch_embed3(x3_0)
        out = self.block3(out, H, W)
        # convert the sequence representation back to spatial layout
        x3_0 = out.reshape(-1, H, W, self.nb_filter[3]).permute(0, 3, 1, 2).contiguous()

        # level-5 features
        # convert x3_0 to a sequence representation and apply the KAN block
        out, H, W = self.patch_embed4(x3_0)
        out = self.block4(out, H, W)
        # convert the sequence representation back to spatial layout
        x4_0 = out.reshape(-1, H, W, self.nb_filter[4]).permute(0, 3, 1, 2).contiguous()

        # level-4 decoder features
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        # convert x3_1 to a sequence representation and apply the KAN block
        _, _, H, W = x3_1.shape
        out = x3_1.flatten(2).transpose(1, 2)
        out = self.dblock3(out, H, W)
        # convert the sequence representation back to spatial layout
        x3_1 = out.reshape(-1, H, W, self.nb_filter[3]).permute(0, 3, 1, 2).contiguous()

        # size-mismatch fix: patch_embed3 downsampled by an extra factor of 2,
        # so upsample x3_1 twice (e.g. 14x14 -> 28x28 -> 56x56) to match x2_0
        x3_1_up = self.up(self.up(x3_1))
        # level-3 decoder features
        x2_1 = self.conv2_1(torch.cat([x2_0, x3_1_up], 1))

        # level-2 nested decoder features
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        # level-1 deeper nested decoder features
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
        # apply CBAM to x0_3
        x0_3 = self.cbam_x0_3(x0_3)

        # level-3 nested decoder features
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, x3_1_up], 1))  # reuse the same upsampled x3_1
        # convert x2_2 to a sequence representation and apply the KAN block
        _, _, H, W = x2_2.shape
        out = x2_2.flatten(2).transpose(1, 2)
        out = self.dblock2(out, H, W)
        # convert the sequence representation back to spatial layout
        x2_2 = out.reshape(-1, H, W, self.nb_filter[2]).permute(0, 3, 1, 2).contiguous()

        # level-2 deeper nested decoder features
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        # final decoder features
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
        # apply CBAM to x0_4
        x0_4 = self.cbam_x0_4(x0_4)

        # outputs
        if self.deep_supervision:
            # with deep supervision, return all decoder outputs
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        else:
            # otherwise return only the final output
            output = self.final(x0_4)
            return output
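
if __name__ == '__main__':
    # Smoke test (a sketch; it assumes the kan and model.attention.CBAM modules
    # are importable and runs on CPU). The spatial size must be divisible by 16
    # so the nested skip connections line up.
    model = UKAN_NestedUNet(num_classes=1, input_channels=3, deep_supervision=False)
    with torch.no_grad():
        x = torch.randn(2, 3, 224, 224)
        print(model(x).shape)  # expected: torch.Size([2, 1, 224, 224])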

kan.py

import torch
import torch.nn.functional as F
import math


class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.dim() == 2 and x.size(1) == self.in_features

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )
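
# Usage sketch: KANLinear acts like nn.Linear on 2-D inputs, and its optional
# regularization term can be added to the task loss (the 1e-4 weight below is
# illustrative, not tuned).
def _kan_linear_demo():
    layer = KANLinear(16, 8)
    x = torch.randn(4, 16)
    y = layer(x)  # shape (4, 8)
    loss = y.pow(2).mean() + 1e-4 * layer.regularization_loss()
    return loss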


class KAN(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )

dataset.py

import os

import cv2
import numpy as np
import torch
import torch.utils.data


class Dataset(torch.utils.data.Dataset):
    def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):
        """
        Args:
            img_ids (list): Image ids.
            img_dir: Image file directory.
            mask_dir: Mask file directory.
            img_ext (str): Image file extension.
            mask_ext (str): Mask file extension.
            num_classes (int): Number of classes.
            transform (Compose, optional): Compose transforms of albumentations. Defaults to None.
        
        Note:
            Make sure to put the files as the following structure:
            <dataset name>
            ├── images
            |   ├── 0a7e06.jpg
            │   ├── 0aab0a.jpg
            │   ├── 0b1761.jpg
            │   ├── ...
            |
            └── masks
                ├── 0
                |   ├── 0a7e06.png
                |   ├── 0aab0a.png
                |   ├── 0b1761.png
                |   ├── ...
                |
                ├── 1
                |   ├── 0a7e06.png
                |   ├── 0aab0a.png
                |   ├── 0b1761.png
                |   ├── ...
                ...
        """
        self.img_ids = img_ids
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_ext = img_ext
        self.mask_ext = mask_ext
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        
        img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))

        mask = []
        for i in range(self.num_classes):
            mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),
                        img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])
        # stack the per-class masks along the channel (depth) axis
        mask = np.dstack(mask)

        if self.transform is not None:
            # albumentations conveniently transforms the mask together with the image;
            # see https://github.com/albumentations-team/albumentations
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']
        
        img = img.astype('float32') / 255
        img = img.transpose(2, 0, 1)
        mask = mask.astype('float32') / 255
        mask = mask.transpose(2, 0, 1)
        
        return img, mask, {'img_id': img_id}
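
def _example_usage():
    """A minimal usage sketch, assuming the inputs/<dataset name> layout from
    the class docstring exists on disk (the id and dataset name below are
    illustrative)."""
    ds = Dataset(img_ids=['0a7e06'],
                 img_dir=os.path.join('inputs', 'dsb2018_96_1000', 'images'),
                 mask_dir=os.path.join('inputs', 'dsb2018_96_1000', 'masks'),
                 img_ext='.jpg', mask_ext='.png',
                 num_classes=1, transform=None)
    img, mask, meta = ds[0]  # img: (3, H, W) float32, mask: (1, H, W) float32
    return img, mask, meta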


losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from LovaszSoftmax.pytorch.lovasz_losses import lovasz_hinge
except ImportError:
    pass

__all__ = ['BCEDiceLoss', 'LovaszHingeLoss']


class BCEDiceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        bce = F.binary_cross_entropy_with_logits(input, target)
        smooth = 1e-5
        input = torch.sigmoid(input)
        num = target.size(0)
        input = input.view(num, -1)
        target = target.view(num, -1)
        intersection = (input * target)
        dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth)
        dice = 1 - dice.sum() / num
        return 0.5 * bce + dice
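
# Quick sanity sketch: confident logits that agree with the target should give
# a combined loss close to 0 (the values below are synthetic).
def _bce_dice_demo():
    loss_fn = BCEDiceLoss()
    target = (torch.rand(2, 1, 8, 8) > 0.5).float()
    logits = (target * 2 - 1) * 10  # strongly positive where target is 1
    return loss_fn(logits, target)  # close to 0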


class LovaszHingeLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        input = input.squeeze(1)
        target = target.squeeze(1)
        loss = lovasz_hinge(input, target, per_image=True)

        return loss

metrics.py

import numpy as np
import torch
import torch.nn.functional as F


def iou_score(output, target):
    smooth = 1e-5

    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()
    output_ = output > 0.5
    target_ = target > 0.5
    intersection = (output_ & target_).sum()
    union = (output_ | target_).sum()

    return (intersection + smooth) / (union + smooth)


def dice_coef(output, target):
    smooth = 1e-5

    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()
    output = output > 0.5
    target = target > 0.5
    intersection = (output & target).sum()

    return (2. * intersection + smooth) / \
        (output.sum() + target.sum() + smooth)


def precision_score(output, target):
    smooth = 1e-5

    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()
    output = output > 0.5
    target = target > 0.5
    
    true_positives = (output & target).sum()
    predicted_positives = output.sum()

    return (true_positives + smooth) / (predicted_positives + smooth)


def recall_score(output, target):
    smooth = 1e-5

    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()
    output = output > 0.5
    target = target > 0.5
    
    true_positives = (output & target).sum()
    actual_positives = target.sum()

    return (true_positives + smooth) / (actual_positives + smooth)
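
if __name__ == '__main__':
    # Sanity sketch: confident logits that agree with the target should score
    # close to 1 on all four metrics (the inputs below are synthetic).
    t = (torch.rand(2, 1, 8, 8) > 0.5).float()
    logits = (t * 2 - 1) * 10
    print(iou_score(logits, t), dice_coef(logits, t),
          precision_score(logits, t), recall_score(logits, t))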

utils.py

import argparse


def str2bool(v):
    # compare against string literals; v.lower() is a str, so it can never
    # equal the ints 1 or 0
    if v.lower() in ['true', '1']:
        return True
    elif v.lower() in ['false', '0']:
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
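
def _meter_demo():
    """Usage sketch: update() takes a per-batch value and the batch size, so
    avg is the sample-weighted mean across batches."""
    meter = AverageMeter()
    meter.update(0.5, n=8)  # batch of 8 with mean value 0.5
    meter.update(0.7, n=4)  # batch of 4 with mean value 0.7
    return meter.avg        # (0.5*8 + 0.7*4) / 12 ~= 0.567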

CBAM.py

import numpy as np
import torch
from torch import nn
from torch.nn import init



class ChannelAttention(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.se = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channel // reduction, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_result = self.maxpool(x)
        avg_result = self.avgpool(x)
        max_out = self.se(max_result)
        avg_out = self.se(avg_result)
        output = self.sigmoid(max_out + avg_out)
        return output

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_result, _ = torch.max(x, dim=1, keepdim=True)
        avg_result = torch.mean(x, dim=1, keepdim=True)
        result = torch.cat([max_result, avg_result], 1)
        output = self.conv(result)
        output = self.sigmoid(output)
        return output



class CBAMBlock(nn.Module):

    def __init__(self, channel=512, reduction=16, kernel_size=49):
        super().__init__()
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=kernel_size)


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        residual = x
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out + residual


if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    kernel_size = input.shape[2]
    cbam = CBAMBlock(channel=512, reduction=16, kernel_size=kernel_size)
    output = cbam(input)
    print(output.shape)

val.py

import argparse
import os
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
import albumentations as A  # import the top-level albumentations module
from albumentations.core.composition import Compose
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import archs
from dataset import Dataset
from metrics import iou_score, dice_coef, precision_score, recall_score
from utils import AverageMeter

"""
需要指定参数:--name dsb2018_96_NestedUNet_woDS
"""


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', default="dsb2018_96_1000_UKAN_NestedUNet_woDS",
                        help='model name')

    args = parser.parse_args()

    return args


def main():
    args = parse_args()

    with open('models/%s/config.yml' % args.name, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    print('-' * 20)
    for key in config.keys():
        print('%s: %s' % (key, str(config[key])))
    print('-' * 20)

    cudnn.benchmark = True

    # create model
    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])

    model = model.cuda()

    # Data loading code
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

    # the checkpoint may have been saved from a DataParallel model in train.py;
    # strip the 'module.' prefix from the keys if present
    state_dict = torch.load('models/%s/model.pth' % args.name)
    state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
                  for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.eval()

    val_transform = Compose([
        A.Resize(config['input_h'], config['input_w']),
        A.Normalize(),
    ])

    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    # running-average meters for the four metrics
    avg_iou = AverageMeter()
    avg_dice = AverageMeter()
    avg_precision = AverageMeter()
    avg_recall = AverageMeter()

    for c in range(config['num_classes']):
        os.makedirs(os.path.join('outputs', config['name'], str(c)), exist_ok=True)
    with torch.no_grad():
        for input, target, meta in tqdm(val_loader, total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()

            # compute output
            if config['deep_supervision']:
                output = model(input)[-1]
            else:
                output = model(input)

            # compute all metrics
            iou = iou_score(output, target)
            dice = dice_coef(output, target)
            precision = precision_score(output, target)
            recall = recall_score(output, target)

            # update the running averages
            avg_iou.update(iou, input.size(0))
            avg_dice.update(dice, input.size(0))
            avg_precision.update(precision, input.size(0))
            avg_recall.update(recall, input.size(0))

            output = torch.sigmoid(output).cpu().numpy()

            for i in range(len(output)):
                for c in range(config['num_classes']):
                    cv2.imwrite(os.path.join('outputs', config['name'], str(c), meta['img_id'][i] + '.jpg'),
                                (output[i, c] * 255).astype('uint8'))

    print('IoU: %.4f' % avg_iou.avg)
    print('Dice: %.4f' % avg_dice.avg)
    print('Precision: %.4f' % avg_precision.avg)
    print('Recall: %.4f' % avg_recall.avg)

    plot_examples(input, target, model, num_examples=3)

    torch.cuda.empty_cache()


def plot_examples(datax, datay, model, num_examples=6):
    fig, ax = plt.subplots(nrows=num_examples, ncols=3, figsize=(18, 4 * num_examples))
    m = datax.shape[0]
    for row_num in range(num_examples):
        image_indx = np.random.randint(m)
        # apply sigmoid so the 0.40 threshold below operates on probabilities, not raw logits
        image_arr = torch.sigmoid(model(datax[image_indx:image_indx + 1])).squeeze(0).detach().cpu().numpy()
        ax[row_num][0].imshow(np.transpose(datax[image_indx].cpu().numpy(), (1, 2, 0))[:, :, 0])
        ax[row_num][0].set_title("Original Image")
        ax[row_num][1].imshow(np.squeeze((image_arr > 0.40)[0, :, :].astype(int)))
        ax[row_num][1].set_title("Segmented Image Localization")
        ax[row_num][2].imshow(np.transpose(datay[image_indx].cpu().numpy(), (1, 2, 0))[:, :, 0])
        ax[row_num][2].set_title("Target Image")
    plt.show()


if __name__ == '__main__':
    main()