数据准备 dicom_to_image.py
:
import cv2
import os
import pydicom
# dicom图像输入路径
inputdir = '/workspace/20210910/Bmodel/inputs'
# dicom图像输出路径
outdir = '/workspace/UnSupRAFT/datasets/heart/'
# 获取文件夹下文件列表
files_name_list=os.listdir(inputdir)
count=0
# 遍历所有文件
for file_name in files_name_list:
path=os.path.join(inputdir,file_name)
ds=pydicom.read_file(path)
# 获取该文件的帧数
num_frame=ds.pixel_array.shape[0]
# 逐帧保存为PNG无损图像
for i in range(num_frame):
# if not os.path.exists(os.path.join(outdir,file_name)):
# os.makedirs(os.path.join(outdir,file_name))
if i==0:
image1 =ds.pixel_array[i, 70:550, 47:751]
else:
image2 = ds.pixel_array[i, 70:550, 47:751]
cv2.imwrite(os.path.join(outdir, str(
count)+"_1.png"), image1, [cv2.IMWRITE_PNG_COMPRESSION, 0])
cv2.imwrite(os.path.join(outdir, str(
count)+"_2.png"), image2, [cv2.IMWRITE_PNG_COMPRESSION, 0])
image1=image2
count+=1
设置数据集heart_dataset.py
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
from posixpath import join
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import tqdm
import os
import math
import random
from glob import glob
import os.path as osp
from core.utils import frame_utils
from core.utils.augmentor import FlowAugmentor, SparseFlowAugmentor
import albumentations as albu
from albumentations.pytorch import ToTensor
class FlowDataset(data.Dataset):
def __init__(self, aug_params=None, sparse=False):
self.augmentor = None
self.sparse = sparse
if aug_params is not None:
if sparse:
self.augmentor = SparseFlowAugmentor(**aug_params)
else:
self.augmentor = FlowAugmentor(**aug_params)
self.is_test = False
self.init_seed = False
self.flow_list = []
self.image_list = []
self.extra_info = []
self.frames_transforms = albu.Compose([
albu.Normalize((0., 0., 0.), (1., 1., 1.)),
ToTensor()
])
def __getitem__(self, index):
if self.is_test:
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
# self
img1=np.repeat(np.array(img1)[:, :, np.newaxis], 3, axis=2)
img2 = np.repeat(np.array(img2)[:, :, np.newaxis], 3, axis=2)
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, self.extra_info[index]
if not self.init_seed:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
torch.manual_seed(worker_info.id)
np.random.seed(worker_info.id)
random.seed(worker_info.id)
self.init_seed = True
index = index % len(self.image_list)
valid = None
if self.sparse:
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
else:
flow = frame_utils.read_gen(self.flow_list[index])
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
flow = np.array(flow).astype(np.float32)
img1 = np.array(img1).astype(np.uint8)
img2 = np.array(img2).astype(np.uint8)
frame1 = self.frames_transforms(image=img1)['image']
frame2 = self.frames_transforms(image=img2)['image']
# import cv2
# cv2.imshow('image1', img1)
# cv2.imshow('image2', img2)
# cv2.waitKey(-1)
# grayscale images
if len(img1.shape) == 2:
img1 = np.tile(img1[..., None], (1, 1, 3))
img2 = np.tile(img2[..., None], (1, 1, 3))
else:
img1 = img1[..., :3]
img2 = img2[..., :3]
if self.augmentor is not None:
if self.sparse:
img1, img2, flow, valid = self.augmentor(
img1, img2, flow, valid)
else:
img1, img2, flow = self.augmentor(img1, img2, flow)
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
if valid is not None:
valid = torch.from_numpy(valid)
else:
valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
return img1, img2, flow, valid.float(), frame1, frame2
def __rmul__(self, v):
self.flow_list = v * self.flow_list
self.image_list = v * self.image_list
return self
def __len__(self):
return len(self.image_list)
class HeartDataset(FlowDataset):
def __init__(self, aug_params=None, split='testing', root='datasets/heart'):
super(HeartDataset, self).__init__(aug_params, sparse=True)
if split == 'testing':
self.is_test = True
root = os.path.join('/workspace/UnSupRAFT', root)
images1 = sorted(glob(osp.join(root, '*_1.png')))
images2 = sorted(glob(osp.join(root, '*_2.png')))
import sys
sys.path.append('UnSupRAFT')
for img1, img2 in zip(images1, images2):
frame_id = img1.split('/')[-1]
self.extra_info += [[frame_id]]
self.image_list += [[img1, img2]]
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
""" Create the data loader for the corresponding trainign set """
train_dataset = HeartDataset()
train_loader = data.DataLoader(train_dataset,
batch_size=args.batch_size,
pin_memory=False,
shuffle=True,
num_workers=4,
drop_last=True)
print('Training with %d image pairs' % len(train_dataset))
return train_loader
# 测试数据集
if __name__=='__main__':
aug_params = {
'crop_size': [512, 512],
'min_scale': -0.2,
'max_scale': 0.4,
'do_flip': False
}
train_dataset = HeartDataset(aug_params)
train_loader = data.DataLoader(train_dataset,
batch_size=1,
pin_memory=True,
shuffle=True,
num_workers=1,
drop_last=True)
print('Training with %d image pairs' % len(train_dataset))
for batch_ndx, sample in enumerate(train_loader):
image1=sample[0].cuda()
image2=sample[1].cuda()
id=sample[2]
for x in sample:
print(x)
print(sample.inp.is_pinned())
print(sample.tgt.is_pinned())
运行文件hear_train.py
from __future__ import print_function, division
import sys
# sys.path.append('<core_path>')
'''
python -u train.py --name raft-chairs --stage chairs --validation chairs
--gpus 0 --num_steps 100000 --batch_size 8 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
'''
import argparse
import os
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import sys
sys.path.append('core')
from raft import RAFT
import evaluate
from core import datasets
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from core.utils.unsupervised_loss import unsup_loss
import heart_dataset
try:
from torch.cuda.amp import GradScaler
except:
# dummy GradScaler for PyTorch < 1.6
class GradScaler:
def __init__(self):
pass
def scale(self, loss):
return loss
def unscale_(self, optimizer):
pass
def step(self, optimizer):
optimizer.step()
def update(self):
pass
# exclude extremly large displacements
MAX_FLOW = 400
SUM_FREQ = 100
VAL_FREQ = 200
def sequence_loss(flow_preds, warped_images, img1, total_steps, gamma=0.8, max_flow=MAX_FLOW):
""" Loss function defined over sequence of flow predictions """
n_predictions = len(flow_preds)
flow_loss = 0.0
# exlude invalid pixels and extremely large diplacements
# mag = torch.sum(flow_gt**2, dim=1).sqrt()
for i in range(n_predictions):
i_weight = gamma**(n_predictions - i - 1)
# if total_steps < 0:
# i_loss = (flow_preds[i] - flow_gt).abs()
# else:
i_loss = unsup_loss(flow_preds[i], warped_images[i], img1)
flow_loss += i_weight * (i_loss).mean()
# epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
# epe = epe.view(-1)
metrics = {
'loss': flow_loss.item()
}
return flow_loss, metrics
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def fetch_optimizer(args, model):
""" Create the optimizer and learning rate scheduler """
optimizer = optim.AdamW(model.parameters(),
lr=args.lr,
weight_decay=args.wdecay,
eps=args.epsilon)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
args.lr,
args.num_steps + 100,
pct_start=0.05,
cycle_momentum=False,
anneal_strategy='linear')
return optimizer, scheduler
class Logger:
def __init__(self, model, scheduler):
self.model = model
self.scheduler = scheduler
self.total_steps = 0
self.running_loss = {}
self.writer = None
def _print_training_status(self):
metrics_data = [
self.running_loss[k] / SUM_FREQ
for k in sorted(self.running_loss.keys())
]
training_str = "[{:6d}, {:10.7f}] ".format(
self.total_steps + 1,
self.scheduler.get_last_lr()[0])
metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
# print the training status
print(training_str + metrics_str)
if self.writer is None:
self.writer = SummaryWriter()
for k in self.running_loss:
self.writer.add_scalar(k, self.running_loss[k] / SUM_FREQ,
self.total_steps)
self.running_loss[k] = 0.0
def push(self, metrics):
self.total_steps += 1
for key in metrics:
if key not in self.running_loss:
self.running_loss[key] = 0.0
self.running_loss[key] += metrics[key]
if self.total_steps % SUM_FREQ == SUM_FREQ - 1:
self._print_training_status()
self.running_loss = {}
def write_dict(self, results):
if self.writer is None:
self.writer = SummaryWriter()
for key in results:
self.writer.add_scalar(key, results[key], self.total_steps)
def close(self):
self.writer.close()
def train(args):
model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
print("Parameter Count: %d" % count_parameters(model))
if args.restore_ckpt is not None:
model.load_state_dict(torch.load(args.restore_ckpt), strict=False)
model.cuda()
model.train()
if args.stage != 'chairs':
model.module.freeze_bn()
train_loader = heart_dataset.fetch_dataloader(args)
optimizer, scheduler = fetch_optimizer(args, model)
total_steps = 95000
scaler = GradScaler(enabled=args.mixed_precision)
logger = Logger(model, scheduler)
VAL_FREQ = 5000
add_noise = True
should_keep_training = True
while should_keep_training:
for i_batch, data_blob in tqdm(enumerate(train_loader),
total=len(train_loader)):
optimizer.zero_grad()
image1 = data_blob[0].cuda()
image2 = data_blob[1].cuda()
id = data_blob[2]
# image1, image2 = [
# x for x in data_blob
# ]
if args.add_noise:
stdv = np.random.uniform(0.0, 5.0)
image1 = (image1 +
stdv * torch.randn(*image1.shape).cuda()).clamp(
0.0, 255.0)
image2 = (image2 +
stdv * torch.randn(*image2.shape).cuda()).clamp(
0.0, 255.0)
# def forward(self, image1, image2, flow_gt=None, frame1=None, frame2=None, \
# iters=12, flow_init=None, upsample=True, test_mode=False):
flow_predictions, warped_images = model(image1, image2, iters=args.iters)
loss, metrics = sequence_loss(flow_predictions, \
warped_images, image1, total_steps=total_steps, gamma=args.gamma)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
scaler.step(optimizer)
scheduler.step()
scaler.update()
logger.push(metrics)
# print(total_steps, total_steps % VAL_FREQ)
if total_steps % VAL_FREQ == VAL_FREQ - 1:
PATH = 'checkpoints/heart/%d_%s.pth' % (total_steps + 1,
args.name)
torch.save(model.state_dict(), PATH)
# results = {}
# for val_dataset in args.validation:
# if val_dataset == 'chairs':
# results.update(evaluate.validate_chairs(model.module))
# elif val_dataset == 'sintel':
# results.update(evaluate.validate_sintel(model.module))
# elif val_dataset == 'kitti':
# results.update(evaluate.validate_kitti(model.module))
# logger.write_dict(results)
model.train()
if args.stage != 'chairs':
model.module.freeze_bn()
total_steps += 1
if total_steps > args.num_steps:
should_keep_training = False
break
logger.close()
PATH = 'checkpoints/%s.pth' % args.name
torch.save(model.state_dict(), PATH)
return PATH
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--name', default='raft', help="name your experiment")
parser.add_argument('--stage',
default='heart',
help="determines which dataset to use for training")
parser.add_argument('--restore_ckpt', help="restore checkpoint")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--validation', type=str, nargs='+')
parser.add_argument('--lr', type=float, default=0.00002)
parser.add_argument('--num_steps', type=int, default=100000)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--image_size',
type=int,
nargs='+',
default=[512, 512])
parser.add_argument('--gpus', type=int, nargs='+', default=[0])
parser.add_argument('--mixed_precision',
action='store_true',
help='use mixed precision')
parser.add_argument('--iters', type=int, default=12)
parser.add_argument('--wdecay', type=float, default=.00005)
parser.add_argument('--epsilon', type=float, default=1e-8)
parser.add_argument('--clip', type=float, default=1.0)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--gamma',
type=float,
default=0.8,
help='exponential weighting')
parser.add_argument('--add_noise', action='store_true')
args = parser.parse_args()
torch.manual_seed(1234)
np.random.seed(1234)
if not os.path.isdir('checkpoints'):
os.mkdir('checkpoints')
train(args)
修改raft.py
def forward(self, image1, image2, flow_gt, frame1, frame2, \
iters=12, flow_init=None, upsample=True, test_mode=False):
#修改为
def forward(self, image1, image2, flow_gt=None, frame1=None, frame2=None, \
iters=12, flow_init=None, upsample=True, test_mode=False):
#预测代码修改evaluate.py
import sys
sys.path.append('core')
from PIL import Image
import argparse
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from core import datasets
from utils import flow_viz
from utils import frame_utils
from raft import RAFT
from utils.utils import InputPadder, forward_interpolate
from tqdm import tqdm
from scipy.special import softmax
import cv2
import pickle
def viz(img, flows, val_id):
img = img[0].permute(1,2,0).cpu().numpy()
new_flows = []
for flo in flows:
try:
flo = flo[0].permute(1,2,0).cpu().numpy()
except:
flo = flo.permute(1,2,0).cpu().numpy()
flo = flow_viz.flow_to_image(flo)
new_flows.append(flo)
img_flo = np.concatenate([img, *new_flows], axis=0)
img = img_flo[:, :, [2,1,0]]#/255.0
cv2.imwrite(f'outputs/{val_id}.png', img)
img /= 255.0
# cv2.imshow('image', img)
# cv2.waitKey()
# cv2.destroyAllWindows()
@torch.no_grad()
def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
""" Create submission for the Sintel leaderboard """
model.eval()
for dstype in ['clean', 'final']:
test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)
flow_prev, sequence_prev = None, None
for test_id in tqdm(range(len(test_dataset))):
image1, image2, (sequence, frame) = test_dataset[test_id]
if sequence != sequence_prev:
flow_prev = None
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
flow_low, flow_pr = model(image1, image2, \
frame1=None, frame2=None, flow_gt=None, \
iters=iters, flow_init=flow_prev, test_mode=True)
flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
if warm_start:
flow_prev = forward_interpolate(flow_low[0])[None].cuda()
output_dir = os.path.join(output_path, dstype, sequence)
output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
frame_utils.writeFlow(output_file, flow)
sequence_prev = sequence
@torch.no_grad()
def create_kitti_submission(model, iters=24, output_path='kitti_submission'):
""" Create submission for the Sintel leaderboard """
model.eval()
test_dataset = datasets.KITTI(split='testing', aug_params=None)
if not os.path.exists(output_path):
os.makedirs(output_path)
for test_id in range(len(test_dataset)):
image1, image2, (frame_id, ) = test_dataset[test_id]
padder = InputPadder(image1.shape, mode='kitti')
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
_, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()
output_filename = os.path.join(output_path, frame_id)
frame_utils.writeFlowKITTI(output_filename, flow)
@torch.no_grad()
def validate_chairs(model, iters=24):
""" Perform evaluation on the FlyingChairs (test) split """
model.eval()
epe_list = []
val_dataset = datasets.FlyingChairs(split='validation')
# val_dataset = datasets.FlyingChairs(split='validation')
for val_id in tqdm(range(len(val_dataset))):
image1, image2, flow_gt, _, frame1, frame2 = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
_, flow_pr = model(image1, image2, flow_gt, \
frame1=frame1, frame2=frame2, \
iters=iters, test_mode=True)
# viz(image1, [flow_pr, flow_gt], val_id)
epe = torch.sum((flow_pr[0].cpu() - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())
epe = np.mean(np.concatenate(epe_list))
print("Validation Chairs EPE: %f" % epe)
return {'chairs': epe}
@torch.no_grad()
def validate_sintel(model, iters=32):
""" Peform validation using the Sintel (train) split """
model.eval()
results = {}
# for dstype in ['clean', 'final']:
feature_maps = []
for dstype in ['clean']:
val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
# val_dataset = datasets.MpiSintel(split='test', dstype=dstype)
epe_list = []
for val_id in tqdm(range(len(val_dataset))):
frame1, frame2 = None, None
image1, image2, flow_gt, _, frame1, frame2 = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
_, flow_pr, fmap1 = model(image1, image2, flow_gt, \
frame1=frame1, frame2=frame2, \
iters=iters, test_mode=True)
feature_maps.append(fmap1.cpu().numpy())
flow = padder.unpad(flow_pr[0]).cpu()
epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())
if val_id > 900:
break
pickle.dump(feature_maps, open( "umap_pickles/save_900.p", "wb" ))
exit()
epe_all = np.concatenate(epe_list)
epe = np.mean(epe_all)
px1 = np.mean(epe_all<1)
px3 = np.mean(epe_all<3)
px5 = np.mean(epe_all<5)
print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
results[dstype] = np.mean(epe_list)
return results
@torch.no_grad()
def validate_kitti(model, iters=24):
""" Peform validation using the KITTI-2015 (train) split """
output_path = 'kitti_submission'
if not os.path.exists(output_path):
os.makedirs(output_path)
model.eval()
val_dataset = datasets.KITTI(split='training')
out_list, epe_list = [], []
for val_id in range(len(val_dataset)):
# image1, image2, flow_gt, valid_gt = val_dataset[val_id]
image1, image2, flow_gt, valid_gt, frame1, frame2 = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape, mode='kitti')
image1, image2 = padder.pad(image1, image2)
flow_low, flow_pr = model(image1, image2, flow_gt, \
frame1=frame1, frame2=frame2, \
iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).cpu()
viz(image1[:, :, :-1, :-6], [flow], val_id)
# output_filename = os.path.join(output_path, str(val_id)+'.flo')
# frame_utils.writeFlow(output_filename, flow)
epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
mag = torch.sum(flow_gt**2, dim=0).sqrt()
epe = epe.view(-1)
mag = mag.view(-1)
val = valid_gt.view(-1) >= 0.5
out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
epe_list.append(epe[val].mean().item())
out_list.append(out[val].cpu().numpy())
epe_list = np.array(epe_list)
out_list = np.concatenate(out_list)
epe = np.mean(epe_list)
f1 = 100 * np.mean(out_list)
print("Validation KITTI: %f, %f" % (epe, f1))
return {'kitti-epe': epe, 'kitti-f1': f1}
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model',default='/workspace/UnSupRAFT/models/raft-kitti.pth', help="restore checkpoint")
parser.add_argument('--dataset', default='kitti' ,help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
args = parser.parse_args()
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model.cuda()
model.eval()
# create_sintel_submission(model.module, warm_start=True)
# create_kitti_submission(model.module)
with torch.no_grad():
if args.dataset == 'chairs':
validate_chairs(model.module)
elif args.dataset == 'sintel':
validate_sintel(model.module)
elif args.dataset == 'kitti':
validate_kitti(model.module)