数据准备 dicom_to_image.py


import cv2
import os
import pydicom

# dicom图像输入路径
inputdir = '/workspace/20210910/Bmodel/inputs'
# dicom图像输出路径
outdir = '/workspace/UnSupRAFT/datasets/heart/'

# 获取文件夹下文件列表
# Export every consecutive frame pair of each DICOM cine loop as lossless
# PNGs named <count>_1.png / <count>_2.png under `outdir`.
os.makedirs(outdir, exist_ok=True)  # the original relied on a pre-existing dir

files_name_list = os.listdir(inputdir)
count = 0
# Iterate over every DICOM file in the input directory.
for file_name in files_name_list:
    path = os.path.join(inputdir, file_name)
    # dcmread is the non-deprecated spelling of pydicom.read_file.
    ds = pydicom.dcmread(path)
    pixels = ds.pixel_array
    # A single-frame DICOM yields a 2-D array; skip it, since building a
    # flow pair requires at least two frames (the old code would have
    # misread shape[0] as a frame count).
    if pixels.ndim < 3:
        continue
    num_frame = pixels.shape[0]
    image1 = None
    for i in range(num_frame):
        # Fixed crop removing the ultrasound console overlay
        # (rows 70:550, cols 47:751) -- assumes every study shares this
        # screen layout; TODO confirm for other scanners.
        frame = pixels[i, 70:550, 47:751]
        if image1 is None:
            image1 = frame
        else:
            image2 = frame
            # PNG is lossless; compression level 0 trades size for speed.
            cv2.imwrite(os.path.join(outdir, str(count) + "_1.png"),
                        image1, [cv2.IMWRITE_PNG_COMPRESSION, 0])
            cv2.imwrite(os.path.join(outdir, str(count) + "_2.png"),
                        image2, [cv2.IMWRITE_PNG_COMPRESSION, 0])
            image1 = image2
            count += 1

设置数据集heart_dataset.py

# Data loading based on https://github.com/NVIDIA/flownet2-pytorch

from posixpath import join
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import tqdm
import os
import math
import random
from glob import glob
import os.path as osp

from core.utils import frame_utils
from core.utils.augmentor import FlowAugmentor, SparseFlowAugmentor
import albumentations as albu
from albumentations.pytorch import ToTensor


class FlowDataset(data.Dataset):
    """Base dataset for optical-flow training pairs.

    Subclasses populate ``image_list`` (pairs of image paths),
    ``flow_list`` (ground-truth flow paths) and ``extra_info`` (frame
    ids).  When ``is_test`` is True, ``__getitem__`` returns only the two
    images plus the frame id; otherwise it also loads (and optionally
    augments) the ground-truth flow and a validity mask.
    """

    def __init__(self, aug_params=None, sparse=False):
        # ``sparse`` selects the KITTI-style augmentor that also carries a
        # per-pixel validity mask through the augmentation.
        self.augmentor = None
        self.sparse = sparse
        if aug_params is not None:
            if sparse:
                self.augmentor = SparseFlowAugmentor(**aug_params)
            else:
                self.augmentor = FlowAugmentor(**aug_params)

        self.is_test = False
        self.init_seed = False
        self.flow_list = []
        self.image_list = []
        self.extra_info = []
        # NOTE(review): albu.Normalize defaults to max_pixel_value=255, so
        # with zero mean / unit std this scales pixels to [0, 1] before
        # ToTensor -- confirm this matches what the photometric loss expects.
        self.frames_transforms = albu.Compose([
            albu.Normalize((0., 0., 0.), (1., 1., 1.)),
            ToTensor()
        ])

    def __getitem__(self, index):
        """Return (img1, img2, extra_info) in test mode, otherwise
        (img1, img2, flow, valid, frame1, frame2)."""

        if self.is_test:
            img1 = frame_utils.read_gen(self.image_list[index][0])
            img2 = frame_utils.read_gen(self.image_list[index][1])
            # Replicate the single channel to 3 channels.
            # NOTE(review): this assumes read_gen returned a 2-D grayscale
            # image; a 3-channel input would gain a spurious axis -- confirm.
            img1=np.repeat(np.array(img1)[:, :, np.newaxis], 3, axis=2)
            img2 = np.repeat(np.array(img2)[:, :, np.newaxis], 3, axis=2)
            img1 = np.array(img1).astype(np.uint8)[..., :3]
            img2 = np.array(img2).astype(np.uint8)[..., :3]

            # HWC uint8 -> CHW float tensors.
            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, self.extra_info[index]

        if not self.init_seed:
            # Seed each DataLoader worker once with its worker id so
            # augmentations differ across workers but are reproducible.
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                torch.manual_seed(worker_info.id)
                np.random.seed(worker_info.id)
                random.seed(worker_info.id)
                self.init_seed = True

        # Wrap around so the dataset can be virtually repeated (see __rmul__).
        index = index % len(self.image_list)
        valid = None
        if self.sparse:
            # KITTI-format flow files carry an explicit validity mask.
            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
        else:
            flow = frame_utils.read_gen(self.flow_list[index])

        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])

        flow = np.array(flow).astype(np.float32)
        img1 = np.array(img1).astype(np.uint8)
        img2 = np.array(img2).astype(np.uint8)

        # Un-augmented, normalized copies of both frames (consumed by the
        # unsupervised photometric loss alongside the augmented inputs).
        frame1 = self.frames_transforms(image=img1)['image']
        frame2 = self.frames_transforms(image=img2)['image']

        # import cv2
        # cv2.imshow('image1', img1)
        # cv2.imshow('image2', img2)
        # cv2.waitKey(-1)

        # grayscale images: tile to 3 channels; otherwise drop any alpha.
        if len(img1.shape) == 2:
            img1 = np.tile(img1[..., None], (1, 1, 3))
            img2 = np.tile(img2[..., None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        if self.augmentor is not None:
            if self.sparse:
                img1, img2, flow, valid = self.augmentor(
                    img1, img2, flow, valid)
            else:
                img1, img2, flow = self.augmentor(img1, img2, flow)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()

        if valid is not None:
            valid = torch.from_numpy(valid)
        else:
            # Dense flow: mark pixels with implausibly large motion invalid.
            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)

        return img1, img2, flow, valid.float(), frame1, frame2

    def __rmul__(self, v):
        """Repeat the dataset ``v`` times in place (supports ``n * dataset``)."""
        self.flow_list = v * self.flow_list
        self.image_list = v * self.image_list
        return self

    def __len__(self):
        return len(self.image_list)


class HeartDataset(FlowDataset):
    """Unlabeled cardiac-ultrasound frame pairs exported by dicom_to_image.py.

    Pairs are matched purely by the sorted glob order of ``*_1.png`` and
    ``*_2.png``; both lists must come from the same export run.  With the
    default ``split='testing'`` the dataset runs in test mode and yields
    only (img1, img2, frame_id).
    """

    def __init__(self, aug_params=None, split='testing', root='datasets/heart'):
        super(HeartDataset, self).__init__(aug_params, sparse=True)
        if split == 'testing':
            self.is_test = True
        root = os.path.join('/workspace/UnSupRAFT', root)
        images1 = sorted(glob(osp.join(root, '*_1.png')))
        images2 = sorted(glob(osp.join(root, '*_2.png')))
        # (Removed a leftover `import sys; sys.path.append('UnSupRAFT')`
        # here -- mutating sys.path inside a dataset constructor had no
        # effect on anything in this project.)
        for img1, img2 in zip(images1, images2):
            # Basename of the first image serves as the frame id
            # (portable replacement for split('/')[-1]).
            frame_id = osp.basename(img1)
            self.extra_info += [[frame_id]]
            self.image_list += [[img1, img2]]


def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
    """Build the training DataLoader over the heart dataset.

    ``TRAIN_DS`` is accepted for interface compatibility with the original
    RAFT training script but is not consulted here.
    """
    dataset = HeartDataset()
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        pin_memory=False,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )
    print('Training with %d image pairs' % len(dataset))
    return loader

# Quick smoke test: build the dataset, load a batch, print its contents.
if __name__=='__main__':
    aug_params = {
        'crop_size': [512, 512],
        'min_scale': -0.2,
        'max_scale': 0.4,
        'do_flip': False
    }
    train_dataset = HeartDataset(aug_params)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=1,
                                   pin_memory=True,
                                   shuffle=True,
                                   num_workers=1,
                                   drop_last=True)
    print('Training with %d image pairs' % len(train_dataset))
    for batch_ndx, sample in enumerate(train_loader):
        image1 = sample[0].cuda()
        image2 = sample[1].cuda()
        frame_id = sample[2]  # renamed from `id` (shadowed the builtin)
        for x in sample:
            print(x)
        # `sample` is a plain list of collated tensors; the original
        # `sample.inp` / `sample.tgt` attribute accesses always raised
        # AttributeError.  Check pinning on the CPU tensors instead.
        print(sample[0].is_pinned())
        print(sample[1].is_pinned())
        break  # one batch is enough for a smoke test

运行文件 heart_train.py

from __future__ import print_function, division
import sys
# sys.path.append('<core_path>')
'''

python -u train.py --name raft-chairs --stage chairs --validation chairs 
--gpus 0 --num_steps 100000 --batch_size 8 --lr 0.0004 --image_size 368 496 --wdecay 0.0001

'''

import argparse
import os
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
import sys

sys.path.append('core')
from raft import RAFT
import evaluate
from core import datasets
from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter
from core.utils.unsupervised_loss import unsup_loss
import heart_dataset

try:
    from torch.cuda.amp import GradScaler
except ImportError:
    # Dummy no-op GradScaler for PyTorch < 1.6 that mirrors the real
    # interface so training code can call it unconditionally.
    class GradScaler:
        def __init__(self, enabled=False, **kwargs):
            # Accept (and ignore) the real scaler's arguments; the old
            # zero-arg __init__ made `GradScaler(enabled=...)` in train()
            # raise TypeError on the fallback path.
            pass

        def scale(self, loss):
            # No scaling without AMP: return the loss unchanged.
            return loss

        def unscale_(self, optimizer):
            pass

        def step(self, optimizer):
            # Without AMP just take a plain optimizer step.
            optimizer.step()

        def update(self):
            pass


# exclude extremly large displacements
# Exclude extremely large displacements when computing losses/metrics.
MAX_FLOW = 400  # maximum flow magnitude (px) considered valid
SUM_FREQ = 100  # average and log running metrics every SUM_FREQ steps
VAL_FREQ = 200  # default validation frequency (train() shadows this locally)


def sequence_loss(flow_preds, warped_images, img1, total_steps, gamma=0.8, max_flow=MAX_FLOW):
    """Exponentially-weighted unsupervised loss over iterative flow predictions.

    The last prediction gets weight 1 and earlier ones decay by ``gamma``
    per step.  ``total_steps`` and ``max_flow`` are retained for interface
    compatibility but are not used by the unsupervised objective.

    Returns:
        (flow_loss, metrics): the scalar loss tensor and a dict with its
        detached value under key 'loss'.
    """
    n_predictions = len(flow_preds)
    flow_loss = 0.0

    for step, (pred, warped) in enumerate(zip(flow_preds, warped_images)):
        # Later refinement iterations are weighted more heavily.
        weight = gamma ** (n_predictions - step - 1)
        flow_loss = flow_loss + weight * unsup_loss(pred, warped, img1).mean()

    metrics = {
        'loss': flow_loss.item()
    }
    return flow_loss, metrics


def count_parameters(model):
    """Return the number of trainable (requires_grad) scalars in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def fetch_optimizer(args, model):
    """Build the AdamW optimizer and its linear OneCycle LR scheduler.

    Reads ``lr``, ``wdecay``, ``epsilon`` and ``num_steps`` from ``args``.
    """
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.wdecay,
        eps=args.epsilon,
    )
    # A small tail (+100) keeps the schedule valid past the last step.
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        args.lr,
        args.num_steps + 100,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy='linear',
    )
    return optimizer, scheduler


class Logger:
    """Accumulates running training metrics and writes them to TensorBoard.

    Relies on the module-level ``SUM_FREQ`` as the averaging window.  The
    SummaryWriter is created lazily on first use.
    """

    def __init__(self, model, scheduler):
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        self.writer = None  # created lazily in _print_training_status/write_dict

    def _print_training_status(self):
        """Print SUM_FREQ-averaged metrics and flush them to TensorBoard."""
        metrics_data = [
            self.running_loss[k] / SUM_FREQ
            for k in sorted(self.running_loss.keys())
        ]
        training_str = "[{:6d}, {:10.7f}] ".format(
            self.total_steps + 1,
            self.scheduler.get_last_lr()[0])
        metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)

        # print the training status
        print(training_str + metrics_str)

        if self.writer is None:
            self.writer = SummaryWriter()

        for k in self.running_loss:
            self.writer.add_scalar(k, self.running_loss[k] / SUM_FREQ,
                                   self.total_steps)
            self.running_loss[k] = 0.0

    def push(self, metrics):
        """Accumulate one step's metrics; print/flush every SUM_FREQ steps."""
        self.total_steps += 1

        for key in metrics:
            if key not in self.running_loss:
                self.running_loss[key] = 0.0

            self.running_loss[key] += metrics[key]

        if self.total_steps % SUM_FREQ == SUM_FREQ - 1:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, results):
        """Write a dict of scalar results at the current step."""
        if self.writer is None:
            self.writer = SummaryWriter()

        for key in results:
            self.writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        # Guard against closing before anything was logged: the old code
        # raised AttributeError when training ended with ``writer`` still
        # None (fewer than SUM_FREQ steps).
        if self.writer is not None:
            self.writer.close()


def train(args):
    """Unsupervised RAFT training loop on the heart dataset.

    Builds the (DataParallel-wrapped) model, optionally restores a
    checkpoint, trains with the unsupervised sequence loss, saves periodic
    checkpoints under checkpoints/heart/, and returns the path of the
    final checkpoint.
    """

    model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
    print("Parameter Count: %d" % count_parameters(model))

    if args.restore_ckpt is not None:
        # strict=False tolerates missing/extra keys when fine-tuning.
        model.load_state_dict(torch.load(args.restore_ckpt), strict=False)

    model.cuda()
    model.train()

    if args.stage != 'chairs':
        # BatchNorm statistics are frozen for every stage after chairs.
        model.module.freeze_bn()

    train_loader = heart_dataset.fetch_dataloader(args)
    optimizer, scheduler = fetch_optimizer(args, model)

    # NOTE(review): starting at 95000 looks like a hard-coded resume offset;
    # it also shortens this run to (num_steps - 95000) steps -- confirm.
    total_steps = 95000
    scaler = GradScaler(enabled=args.mixed_precision)
    logger = Logger(model, scheduler)

    # Shadows the module-level VAL_FREQ: checkpoint every 5000 steps here.
    VAL_FREQ = 5000
    add_noise = True  # NOTE(review): unused; args.add_noise is what is consulted below

    should_keep_training = True
    while should_keep_training:

        for i_batch, data_blob in tqdm(enumerate(train_loader),
                                       total=len(train_loader)):
            optimizer.zero_grad()
            image1 = data_blob[0].cuda()
            image2 = data_blob[1].cuda()
            id = data_blob[2]
            # image1, image2 = [
            #     x for x in data_blob
            # ]

            if args.add_noise:
                # Additive Gaussian noise with a random std in [0, 5],
                # clamped back to the valid 8-bit pixel range.
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 +
                          stdv * torch.randn(*image1.shape).cuda()).clamp(
                              0.0, 255.0)
                image2 = (image2 +
                          stdv * torch.randn(*image2.shape).cuda()).clamp(
                              0.0, 255.0)
    #             def forward(self, image1, image2, flow_gt=None, frame1=None, frame2=None, \
    #  iters=12, flow_init=None, upsample=True, test_mode=False):
            flow_predictions, warped_images = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, \
             warped_images,  image1, total_steps=total_steps, gamma=args.gamma)

            # Standard AMP pattern: scale -> backward -> unscale -> clip -> step.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            scaler.step(optimizer)
            scheduler.step()
            scaler.update()

            logger.push(metrics)
            # print(total_steps, total_steps % VAL_FREQ)
            if total_steps % VAL_FREQ == VAL_FREQ - 1:
                # Periodic checkpoint; assumes checkpoints/heart/ exists.
                PATH = 'checkpoints/heart/%d_%s.pth' % (total_steps + 1,
                                                         args.name)
                torch.save(model.state_dict(), PATH)

                # results = {}
                # for val_dataset in args.validation:
                #     if val_dataset == 'chairs':
                #         results.update(evaluate.validate_chairs(model.module))
                #     elif val_dataset == 'sintel':
                #         results.update(evaluate.validate_sintel(model.module))
                #     elif val_dataset == 'kitti':
                #         results.update(evaluate.validate_kitti(model.module))

                # logger.write_dict(results)

                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

    logger.close()
    PATH = 'checkpoints/%s.pth' % args.name
    torch.save(model.state_dict(), PATH)

    return PATH


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default='raft', help="name your experiment")
    parser.add_argument('--stage',
                        default='heart',
                        help="determines which dataset to use for training")
    parser.add_argument('--restore_ckpt', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--validation', type=str, nargs='+')

    parser.add_argument('--lr', type=float, default=0.00002)
    parser.add_argument('--num_steps', type=int, default=100000)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--image_size',
                        type=int,
                        nargs='+',
                        default=[512, 512])
    parser.add_argument('--gpus', type=int, nargs='+', default=[0])
    parser.add_argument('--mixed_precision',
                        action='store_true',
                        help='use mixed precision')

    parser.add_argument('--iters', type=int, default=12)
    parser.add_argument('--wdecay', type=float, default=.00005)
    parser.add_argument('--epsilon', type=float, default=1e-8)
    parser.add_argument('--clip', type=float, default=1.0)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--gamma',
                        type=float,
                        default=0.8,
                        help='exponential weighting')
    parser.add_argument('--add_noise', action='store_true')
    args = parser.parse_args()

    # Fixed seeds for reproducibility of data order and noise.
    torch.manual_seed(1234)
    np.random.seed(1234)

    # train() saves periodic checkpoints under checkpoints/heart/, so the
    # full path must exist up front (the old code only created
    # 'checkpoints' and the first periodic torch.save then failed with
    # FileNotFoundError).
    os.makedirs('checkpoints/heart', exist_ok=True)

    train(args)

修改raft.py

def forward(self, image1, image2, flow_gt, frame1, frame2, \
	 	iters=12, flow_init=None, upsample=True, test_mode=False):

#修改为
def forward(self, image1, image2, flow_gt=None, frame1=None, frame2=None, \
     iters=12, flow_init=None, upsample=True, test_mode=False):

#预测代码修改evaluate.py

import sys
sys.path.append('core')

from PIL import Image
import argparse
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from core import datasets
from utils import flow_viz
from utils import frame_utils

from raft import RAFT
from utils.utils import InputPadder, forward_interpolate
from tqdm import tqdm
from scipy.special import softmax
import cv2
import pickle


def viz(img, flows, val_id):
    """Save a vertical strip [image; flow renderings] to outputs/<val_id>.png.

    Args:
        img: (1, 3, H, W) image tensor.
        flows: iterable of flow tensors, each batched (1, 2, H, W) or
            unbatched (2, H, W).
        val_id: identifier used for the output filename.
    """
    img = img[0].permute(1, 2, 0).cpu().numpy()

    rendered = []
    for flo in flows:
        # Flows may arrive batched or unbatched; drop the batch dim if
        # present (replaces the original bare try/except around permute).
        if flo.dim() == 4:
            flo = flo[0]
        flo = flo.permute(1, 2, 0).cpu().numpy()
        rendered.append(flow_viz.flow_to_image(flo))

    img_flo = np.concatenate([img, *rendered], axis=0)

    # RGB -> BGR channel order for OpenCV.
    out = img_flo[:, :, [2, 1, 0]]
    # cv2.imwrite fails silently if the directory is missing.
    os.makedirs('outputs', exist_ok=True)
    cv2.imwrite(f'outputs/{val_id}.png', out)
    # (Removed the dead `img /= 255.0` that ran after the file was
    # already written and whose result was never used.)


@torch.no_grad()
def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
    """ Create submission for the Sintel leaderboard """
    model.eval()
    for dstype in ['clean', 'final']:
        test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)

        # Warm-start state: the previous frame's low-res flow, forward
        # interpolated to seed the next prediction within a sequence.
        flow_prev, sequence_prev = None, None
        for test_id in tqdm(range(len(test_dataset))):
            image1, image2, (sequence, frame) = test_dataset[test_id]
            if sequence != sequence_prev:
                # New video sequence -- never warm-start across sequences.
                flow_prev = None

            # Pad to dimensions the network accepts; unpad the output below.
            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())

            flow_low, flow_pr = model(image1, image2, \
             frame1=None, frame2=None, flow_gt=None, \
             iters=iters, flow_init=flow_prev, test_mode=True)
            flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()

            if warm_start:
                flow_prev = forward_interpolate(flow_low[0])[None].cuda()

            output_dir = os.path.join(output_path, dstype, sequence)
            # Sintel frame numbering is 1-based.
            output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))

            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

            frame_utils.writeFlow(output_file, flow)
            sequence_prev = sequence


@torch.no_grad()
def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
    """ Create submission for the KITTI leaderboard """
    model.eval()
    test_dataset = datasets.KITTI(split='testing', aug_params=None)

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    for test_id in range(len(test_dataset)):
        image1, image2, (frame_id, ) = test_dataset[test_id]
        # KITTI mode pads to multiples the network accepts for this dataset.
        padder = InputPadder(image1.shape, mode='kitti')
        image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())

        _, flow_pr = model(image1, image2, iters=iters, test_mode=True)
        flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()

        # One 16-bit KITTI-format flow PNG per input frame id.
        output_filename = os.path.join(output_path, frame_id)
        frame_utils.writeFlowKITTI(output_filename, flow)


@torch.no_grad()
def validate_chairs(model, iters=24):
    """ Perform evaluation on the FlyingChairs (test) split """
    model.eval()
    epe_list = []

    val_dataset = datasets.FlyingChairs(split='validation')
    # val_dataset = datasets.FlyingChairs(split='validation')
    for val_id in tqdm(range(len(val_dataset))):
        image1, image2, flow_gt, _, frame1, frame2 = val_dataset[val_id]

        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        _, flow_pr = model(image1, image2, flow_gt, \
         frame1=frame1, frame2=frame2, \
         iters=iters, test_mode=True)

        # viz(image1, [flow_pr, flow_gt], val_id)

        # Per-pixel end-point error against the ground-truth flow.
        epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
        epe_list.append(epe.view(-1).numpy())

    # Mean EPE over all pixels of all validation pairs.
    epe = np.mean(np.concatenate(epe_list))
    print("Validation Chairs EPE: %f" % epe)
    return {'chairs': epe}


@torch.no_grad()
def validate_sintel(model, iters=32):
    """ Peform validation using the Sintel (train) split """
    # NOTE(review): this function currently runs in feature-map-dump mode:
    # it pickles fmap1 activations for up to ~900 samples and then calls
    # exit(), so the EPE metrics below it are unreachable.  Confirm whether
    # the dump-and-exit block is a debugging leftover.
    model.eval()
    results = {}
    # for dstype in ['clean', 'final']:
    feature_maps = []
    for dstype in ['clean']:
        val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
        # val_dataset = datasets.MpiSintel(split='test', dstype=dstype)
        epe_list = []

        for val_id in tqdm(range(len(val_dataset))):
            frame1, frame2 = None, None

            image1, image2, flow_gt, _, frame1, frame2 = val_dataset[val_id]
            image1 = image1[None].cuda()
            image2 = image2[None].cuda()

            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            # NOTE(review): expects the model to return a third value
            # (fmap1) in test mode -- requires a modified RAFT forward.
            _, flow_pr, fmap1 = model(image1, image2, flow_gt, \
             frame1=frame1, frame2=frame2, \
             iters=iters, test_mode=True)

            feature_maps.append(fmap1.cpu().numpy())

            flow = padder.unpad(flow_pr[0]).cpu()

            epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
            epe_list.append(epe.view(-1).numpy())
            if val_id > 900:
                break

        # Dump collected feature maps and stop -- everything below is dead
        # code while this exit() remains.
        pickle.dump(feature_maps, open( "umap_pickles/save_900.p", "wb" ))
        exit()

        epe_all = np.concatenate(epe_list)
        epe = np.mean(epe_all)
        px1 = np.mean(epe_all<1)
        px3 = np.mean(epe_all<3)
        px5 = np.mean(epe_all<5)

        print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
        results[dstype] = np.mean(epe_list)

    return results


@torch.no_grad()
def validate_kitti(model, iters=24):
    """ Peform validation using the KITTI-2015 (train) split """
    output_path = 'kitti_submission'
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    model.eval()
    val_dataset = datasets.KITTI(split='training')

    out_list, epe_list = [], []
    for val_id in range(len(val_dataset)):
        # image1, image2, flow_gt, valid_gt = val_dataset[val_id]
        image1, image2, flow_gt, valid_gt, frame1, frame2 = val_dataset[val_id]
        image1 = image1[None].cuda()
        image2 = image2[None].cuda()

        padder = InputPadder(image1.shape, mode='kitti')
        image1, image2 = padder.pad(image1, image2)

        flow_low, flow_pr = model(image1, image2, flow_gt, \
          frame1=frame1, frame2=frame2, \
          iters=iters, test_mode=True)

        flow = padder.unpad(flow_pr[0]).cpu()
        # NOTE(review): the [:, :, :-1, :-6] crop looks hand-tuned to undo
        # KITTI padding for visualization only -- confirm against the
        # InputPadder settings.
        viz(image1[:, :, :-1, :-6], [flow], val_id)
        # output_filename = os.path.join(output_path, str(val_id)+'.flo')
        # frame_utils.writeFlow(output_filename, flow)
        epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
        mag = torch.sum(flow_gt**2, dim=0).sqrt()

        epe = epe.view(-1)
        mag = mag.view(-1)
        # KITTI validity mask: only pixels with ground truth count.
        val = valid_gt.view(-1) >= 0.5

        # KITTI outlier (Fl) definition: EPE > 3px AND > 5% of gt magnitude.
        out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
        epe_list.append(epe[val].mean().item())
        out_list.append(out[val].cpu().numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    # Mean of per-image EPE means, and F1 as percentage of outlier pixels.
    epe = np.mean(epe_list)
    f1 = 100 * np.mean(out_list)

    print("Validation KITTI: %f, %f" % (epe, f1))
    return {'kitti-epe': epe, 'kitti-f1': f1}


if __name__ == '__main__':
    # Parse checkpoint/model options and run the requested validator.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model',default='/workspace/UnSupRAFT/models/raft-kitti.pth', help="restore checkpoint")
    parser.add_argument('--dataset', default='kitti' ,help="dataset for evaluation")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
    args = parser.parse_args()

    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model.cuda()
    model.eval()

    # Dispatch table replaces the if/elif chain; unknown names are a no-op,
    # matching the original behavior.
    validators = {
        'chairs': validate_chairs,
        'sintel': validate_sintel,
        'kitti': validate_kitti,
    }
    with torch.no_grad():
        run = validators.get(args.dataset)
        if run is not None:
            run(model.module)