实例
数据集
笔记本

笔记本

真实场景篡改图像检测挑战赛 —— 预测

暂无摘要
Dave 上传于 3 年前
标签
暂无标签
浏览2480
笔记本内容
import os
# Fetch the dataset through the featurize CLI only when the round-1 test
# archive is not already present, so re-running this cell skips the download.
if not os.path.exists("/home/featurize/data/forgery_round1_test_20220217.zip"):
    !featurize dataset download 966d7093-847d-4bd9-bdab-f03c7eb6eb77
import os
import cv2
import torch
import minetorch

import numpy as np
import pandas as pd
import seaborn as sns
import albumentations as albu
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

from tqdm import tqdm
from albumentations.pytorch import ToTensorV2
from albumentations import (Normalize, Compose)
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

配置 #

# Settings shared by the cells below.
CONFIG = dict(
    # Root folder name used when building the checkpoint path.
    DIR='Experiments',
    # Side length (pixels) every image is resized to before inference;
    # also part of the checkpoint directory name.
    img_size=224,
    # Encoder name handed to the segmentation model; also part of the
    # checkpoint directory name.
    model_name='timm-efficientnet-b0',
    # Not read by the code visible in this notebook section.
    fold=0,
    # Not read by the code visible in this notebook section.
    lr=1e-4,
    # The prediction loop only looks at element [0] of each batch, so this
    # must stay at 1.
    batch_size=1,
)

创建 PyTorch Dataset #

class AliDataset(Dataset):
    """Inference dataset that serves the images listed in a DataFrame.

    Args:
        df: DataFrame with an ``image`` column holding file names relative
            to ``test_dir``.
        test_dir: Directory that contains the image files.
        transforms: Callable invoked as ``transforms(image=rgb_array)`` that
            returns a mapping whose ``'image'`` entry supports ``.float()``.
    """

    def __init__(self, df, test_dir, transforms):
        self.df = df
        self.test_dir = test_dir
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return ``(image, file_name, original_shape)`` for row ``idx``.

        ``original_shape`` is the shape of the decoded image before any
        transform, so callers can resize predictions back to it.

        Raises:
            FileNotFoundError: If the file cannot be read/decoded as an image.
        """
        row = self.df.iloc[idx]
        fn = row.image
        path = os.path.join(self.test_dir, fn)
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f'Could not read image file: {path}')
        # OpenCV decodes to BGR; convert to RGB before applying transforms.
        img_ori = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        shape = img_ori.shape
        augmented = self.transforms(image=img_ori)
        img = augmented['image']

        return img.float(), fn, shape

    def __len__(self):
        """Return the number of rows (images) in the DataFrame."""
        return len(self.df)

def make_transforms(phase, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), tta=False):
    """Build the preprocessing pipeline applied to each test image.

    The pipeline resizes to ``CONFIG['img_size']`` square and converts the
    result to a tensor. With ``tta=True`` an always-on horizontal flip is
    placed in front of those two steps.

    Args:
        phase: Accepted for interface compatibility; not used here.
        mean: Accepted for interface compatibility; not used here.
        std: Accepted for interface compatibility; not used here.
        tta: When true, prepend ``HorizontalFlip(p=1)``.

    Returns:
        The composed transform.
    """
    # NOTE(review): no normalization is applied even though mean/std are
    # accepted -- confirm this matches the preprocessing used in training.
    flip_steps = [albu.HorizontalFlip(p=1)] if tta else []
    base_steps = [
        albu.Resize(CONFIG['img_size'], CONFIG['img_size']),
        ToTensorV2(),
    ]
    return Compose(flip_steps + base_steps)

读取数据(TTA) #

TEST = '/home/featurize/data/test/img'
# One row per entry found in the test directory.
df = pd.DataFrame(os.listdir(TEST), columns=['image'])

# Two views over the same file list: the plain images and their horizontally
# flipped counterparts for test-time augmentation.
testset = AliDataset(df, TEST, make_transforms('val'))
testset_tta = AliDataset(df, TEST, make_transforms('val', tta=True))

# Both loaders share identical options and keep the default (unshuffled)
# order, so iterating them in lockstep yields the same file from each.
_loader_options = dict(
    batch_size=CONFIG['batch_size'],
    num_workers=8,
    pin_memory=True,
)

test_loader = DataLoader(testset, **_loader_options)

test_loader_tta = DataLoader(testset_tta, **_loader_options)

加载训练以后的模型 #

# Checkpoint files (relative to the experiment directory) to load; every
# entry contributes one model to the `models` list.
ckpts = [
    'fold-0/models/latest.pth.tar'
]

models = []

for ckpt in ckpts:
    # Single-channel output with no activation layer attached; no pretrained
    # encoder weights are requested because the state dict is loaded below.
    model = smp.UnetPlusPlus(
        CONFIG['model_name'],
        classes=1,
        encoder_weights=None,
        activation=None,
    )
    model = model.cuda()

    experiment_dir = f'/home/featurize/{CONFIG["DIR"]}/{CONFIG["model_name"]}-{CONFIG["img_size"]}'
    stuff = torch.load(os.path.join(experiment_dir, ckpt))
    model.load_state_dict(stuff['state_dict'])

    # Switch to evaluation mode before the model is used for prediction.
    model.eval()
    models.append(model)

创建存放结果文件夹 #

# Create the output directory (and its parent) for the predicted masks.
# exist_ok=True keeps the cell re-runnable without "File exists" errors.
os.makedirs('/home/featurize/images/images', exist_ok=True)

计算结果 #

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator

# Predict a binary mask for every test image and write it as a PNG.
# For each model the plain and the horizontally flipped input are both run;
# the flipped prediction is flipped back, all logit maps are averaged, and
# the average is thresholded at sigmoid > 0.5.
# no_grad: inference only, so no autograd graph needs to be built.
with torch.no_grad():
    for (image, fn, shape), (image_tta, _, _) in tqdm(zip(test_loader, test_loader_tta), total=len(testset)):
        # shape holds the original (height, width, channels) of the single
        # image in this batch; predictions are resized back to it.
        restore_size = albu.Resize(shape[0].item(), shape[1].item())

        mask1 = 0
        for model in models:
            masks1 = model(image.cuda())
            masks_tta1 = model(image_tta.cuda())
            # (C, H, W) -> (H, W, C) so the array can be resized as an image.
            mask1 = mask1 + restore_size(image=masks1[0].permute(1,2,0).cpu().numpy())['image']
            # Undo the horizontal flip before accumulating the TTA prediction.
            mask1 = mask1 + restore_size(image=np.flip(masks_tta1[0].permute(1,2,0).cpu().numpy(), axis=1))['image']

        # Each model contributed two logit maps. Divide once, after the loop,
        # so every model carries equal weight in the ensemble average.
        mask1 = mask1 / (2 * len(models))

        fake_mask = ((sigmoid(mask1) > 0.5)*255.).astype(np.uint8)
        cv2.imwrite(f'/home/featurize/images/images/{fn[0].split(".")[0]}.png', fake_mask)
100%|██████████| 4000/4000 [03:11<00:00, 20.83it/s]

查看结果是否完整(4000张) #

# Count the entries in the output directory to confirm every mask was written.
len(os.listdir('/home/featurize/images/images'))
4000

计算 Mask 分布 #

MASK = '/home/featurize/images/images'
# List the directory once: re-listing it on every iteration is quadratic and
# the ordering of os.listdir is not guaranteed to be stable between calls.
mask_files = os.listdir(MASK)

# For each saved mask, the fraction of its area marked as foreground
# (pixel sum scaled by 255, divided by the pixel count).
ratios = []
for mask_name in tqdm(mask_files):
    mask_area = cv2.imread(os.path.join(MASK, mask_name), cv2.IMREAD_GRAYSCALE)
    ratio = np.sum(mask_area/255.) / (mask_area.shape[0] * mask_area.shape[1])
    ratios.append(ratio)
100%|██████████| 4000/4000 [00:22<00:00, 179.51it/s]

查看 Mask 分布 #

# Histogram of the per-mask foreground ratios computed above.
sns.histplot(ratios);

查看结果随机样本 #

import random

# Pick a random row of the test file list instead of guessing a numeric file
# name: randint(0, 4000) is inclusive on both ends and assumes integer-named
# .jpg files, so it could select an image that does not exist.
index = random.randrange(len(df))
sample_fn = df.image.iloc[index]
# Masks are saved under the part of the file name before the first dot.
sample_stem = sample_fn.split(".")[0]

# Left: the test image (converted from BGR to RGB); right: its predicted mask.
f, ax = plt.subplots(1,2)
ax[0].imshow(cv2.cvtColor(cv2.imread(os.path.join(TEST, sample_fn)), cv2.COLOR_BGR2RGB));
ax[1].imshow(cv2.imread(f'/home/featurize/images/images/{sample_stem}.png', cv2.IMREAD_UNCHANGED));
from IPython.display import clear_output
# Pack the images/ directory into /home/featurize/images.zip.
!cd /home/featurize;zip -r images.zip images
# Clear the output this cell has printed so far.
clear_output()

右键下载 images.zip 后上传结果 #

image.png

评论(0条)