实例
数据集
笔记本
笔记本

真实场景篡改图像检测挑战赛 —— 预测
暂无摘要
笔记本内容
import os
# Skip the download when the archive is already on disk.
# NOTE(review): assumes the dataset id below is what provides this zip — confirm.
if not os.path.exists("/home/featurize/data/forgery_round1_test_20220217.zip"):
    !featurize dataset download 966d7093-847d-4bd9-bdab-f03c7eb6eb77
import os
import cv2
import torch
import minetorch
import numpy as np
import pandas as pd
import seaborn as sns
import albumentations as albu
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from tqdm import tqdm
from albumentations.pytorch import ToTensorV2
from albumentations import (Normalize, Compose)
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
配置 #
# Inference settings. DIR, model_name and img_size are combined into the
# checkpoint directory path; img_size is also the resize target and batch_size
# feeds the data loaders. fold and lr are not read by the prediction code here.
CONFIG = dict(
    DIR='Experiments',
    img_size=224,
    model_name='timm-efficientnet-b0',
    fold=0,
    lr=1e-4,
    batch_size=1,
)
创建 PyTorch Dataset #
class AliDataset(Dataset):
    """Test-time dataset that loads the images listed in a DataFrame.

    Args:
        df: DataFrame with an ``image`` column holding file names.
        test_dir: Directory the file names are resolved against.
        transforms: Callable invoked as ``transforms(image=...)`` whose result
            is indexed with ``'image'`` (an albumentations-style pipeline).
    """

    def __init__(self, df, test_dir, transforms):
        self.df = df
        self.test_dir = test_dir
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return ``(image, file_name, original_shape)`` for row ``idx``.

        ``original_shape`` is the shape of the RGB image before the transforms
        run, so callers can map predictions back to the source resolution.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        fn = self.df.iloc[idx].image
        path = os.path.join(self.test_dir, fn)
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError(f'Could not read image: {path}')
        img_ori = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        shape = img_ori.shape
        # assumes the pipeline ends in a tensor conversion so .float() exists
        img = self.transforms(image=img_ori)['image']
        return img.float(), fn, shape

    def __len__(self):
        """Return the number of rows (images) in the DataFrame."""
        return len(self.df)
def make_transforms(phase,mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), tta=False):
    """Build the preprocessing pipeline applied to each test image.

    The pipeline resizes to ``CONFIG['img_size']`` square and converts to a
    tensor. With ``tta=True`` an always-on horizontal flip is placed first.

    ``phase``, ``mean`` and ``std`` are accepted but not used by this function;
    no normalisation step is added.
    """
    side = CONFIG['img_size']
    flip_steps = [albu.HorizontalFlip(p=1)] if tta else []
    return Compose([*flip_steps, albu.Resize(side, side), ToTensorV2()])
读取数据(TTA) #
# Directory holding the test images; every entry in it becomes one row of `df`.
TEST = '/home/featurize/data/test/img'
test_files = os.listdir(TEST)
df = pd.DataFrame(test_files, columns=['image'])

# The same file list is wrapped twice: once with the plain pipeline and once
# with the horizontally flipped (TTA) pipeline.
plain_transforms = make_transforms('val')
tta_transforms = make_transforms('val', tta=True)
testset = AliDataset(df, TEST, plain_transforms)
testset_tta = AliDataset(df, TEST, tta_transforms)

# Both loaders share identical settings and pass no shuffle argument.
loader_options = {
    'batch_size': CONFIG['batch_size'],
    'num_workers': 8,
    'pin_memory': True,
}
test_loader = DataLoader(testset, **loader_options)
test_loader_tta = DataLoader(testset_tta, **loader_options)
加载训练好的模型 #
# Checkpoint paths, relative to the experiment directory derived from CONFIG.
ckpts = [
    'fold-0/models/latest.pth.tar'
]


def _load_model(ckpt):
    """Build a UNet++ on the GPU, load the weights stored at ``ckpt``, set eval mode."""
    net = smp.UnetPlusPlus(
        CONFIG['model_name'],
        classes=1,
        encoder_weights=None,
        activation=None,
    ).cuda()
    experiment_dir = f'/home/featurize/{CONFIG["DIR"]}/{CONFIG["model_name"]}-{CONFIG["img_size"]}'
    checkpoint = torch.load(os.path.join(experiment_dir, ckpt))
    # The checkpoint is indexed with 'state_dict' to get the model weights.
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()
    return net


models = []
for ckpt in ckpts:
    model = _load_model(ckpt)
    models.append(model)
创建存放结果文件夹 #
# Create the nested output folder in one idempotent call: unlike a bare
# `mkdir`, this also succeeds when the folders already exist, so the cell can
# be re-run safely.
os.makedirs('/home/featurize/images/images', exist_ok=True)
计算结果 #
def sigmoid(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``.

    Only ``exp(-|x|)`` is evaluated, which is always in ``(0, 1]``, so large
    negative inputs cannot overflow the exponential.

    Args:
        x: Scalar or array-like of logits.

    Returns:
        NumPy array with the same shape as ``x`` (0-d for scalar input).
    """
    x = np.asarray(x)
    decay = np.exp(-np.abs(x))
    # For x >= 0 use 1/(1+e^-x); for x < 0 use the equivalent e^x/(1+e^x).
    return np.where(x >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
# For every test image: run each model on the image and on its horizontally
# flipped copy, resize both logit maps back to the source resolution, average
# them, threshold at 0.5 after the sigmoid and write a 0/255 PNG mask.
# Only element [0] of each batch is used, so this loop relies on batch_size 1.
# no_grad: this is inference only, so autograd graphs are never needed.
with torch.no_grad():
    for (image, fn, shape), (image_tta, _, _) in tqdm(zip(test_loader, test_loader_tta), total=len(testset)):
        # One resize back to the source height/width serves both predictions.
        restore = albu.Resize(shape[0].item(), shape[1].item())
        mask1 = 0
        for model in models:
            masks1 = model(image.cuda())
            masks_tta1 = model(image_tta.cuda())
            mask1 += restore(image=masks1[0].permute(1,2,0).cpu().numpy())['image']
            # Flip the TTA prediction back along the width axis before adding it.
            mask1 += restore(image=np.flip(masks_tta1[0].permute(1,2,0).cpu().numpy(), axis=1))['image']
        mask1 /= 2  # average of the plain and flipped predictions
        mask1 /= len(models)  # average over the ensemble
        fake_mask = ((sigmoid(mask1) > 0.5)*255.).astype(np.uint8)
        cv2.imwrite(f'/home/featurize/images/images/{fn[0].split(".")[0]}.png', fake_mask)
100%|██████████| 4000/4000 [03:11<00:00, 20.83it/s]
查看结果是否完整(4000张) #
# Count the entries in the output folder; as the last expression of the cell,
# the notebook displays its value.
output_dir = '/home/featurize/images/images'
len(os.listdir(output_dir))
4000
计算 Mask 分布 #
# Fraction of white (255) pixels in every predicted mask.
MASK = '/home/featurize/images/images'
ratios = []
# List the folder once, outside the loop: calling os.listdir per iteration
# re-scans the whole directory each time. Iterating the listing itself also
# keeps the loop bound tied to the files actually present.
mask_files = os.listdir(MASK)
for mask_name in tqdm(mask_files):
    mask_area = cv2.imread(os.path.join(MASK, mask_name), cv2.IMREAD_GRAYSCALE)
    # Masks are 0/255, so dividing by 255 turns the sum into a pixel count.
    ratio = np.sum(mask_area/255.) / (mask_area.shape[0] * mask_area.shape[1])
    ratios.append(ratio)
100%|██████████| 4000/4000 [00:22<00:00, 179.51it/s]
查看 Mask 分布 #
# Histogram of the values collected in `ratios`.
# NOTE(review): the trailing semicolon presumably suppresses the notebook's
# echo of the returned object — keep it.
sns.histplot(ratios);
查看结果随机样本 #
import random

# Show one test image next to its predicted mask.
# Sample an actual file name from `df` instead of guessing an integer id:
# random.randint(0, 4000) is inclusive on both ends and only works when files
# happen to be named "<int>.jpg".
sample_name = random.choice(df['image'].tolist())
# Same stem rule that was used when the masks were written to disk.
sample_stem = sample_name.split(".")[0]
f, ax = plt.subplots(1,2)
ax[0].imshow(cv2.cvtColor(cv2.imread(os.path.join(TEST, sample_name)), cv2.COLOR_BGR2RGB));
ax[1].imshow(cv2.imread(f'/home/featurize/images/images/{sample_stem}.png', cv2.IMREAD_UNCHANGED));
from IPython.display import clear_output
!cd /home/featurize;zip -r images.zip images
clear_output()
右键下载 images.zip 后,上传提交结果 #
评论(0条)