Deep learning: Siamese network quick start, help wanted - V2EX

  •   yangyuhan12138 Jun 19, 2023 1848 views

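    Following up on my previous post. Background: I want to build a CAPTCHA solver for the click-the-icon type of CAPTCHA, and I plan to do it in two steps. Step one uses YOLO for object detection to find the icons in the background image (big) and the icons we are asked to click (small). Step two uses a Siamese network to compare the similarity between each small and the bigs, and picks the best-matching big.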

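    Step one is done and meets expectations: it reliably crops out all the icons in the CAPTCHA, with decent accuracy. For step two I found some Siamese network code online and tried training it on my own machine. The loss is now extremely low, but accuracy on new images is poor. I know I need more training data, but I'd like to ask the experts here to check whether my network has any problems first... adding data is really hard, and we already have quite a lot of it. The original image looks like this: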

    321687153222_.pic.jpg

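    The data used to train the Siamese network looks like this. There are currently a little over 30 folders, and the contents of each folder look like this: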

    311687153180_.pic.jpg

    import matplotlib.pyplot as plt
    import numpy as np
    import random
    from PIL import Image
    import PIL.ImageOps
    import torchvision
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader, Dataset
    import torchvision.utils
    import torch
    from torch.autograd import Variable
    import torch.nn as nn
    from torch import optim
    import torch.nn.functional as F


    # Showing images
    def imshow(img, text=None):
        npimg = img.numpy()
        plt.axis("off")
        if text:
            plt.text(75, 8, text, style='italic', fontweight='bold',
                     bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        plt.show()


    # Plotting data
    def show_plot(iteration, loss):
        plt.plot(iteration, loss)
        plt.show()


    class SiameseNetworkDataset(Dataset):
        def __init__(self, imageFolderDataset, transform=None):
            self.imageFolderDataset = imageFolderDataset
            self.transform = transform

        def __getitem__(self, index):
            img0_tuple = random.choice(self.imageFolderDataset.imgs)

            # We need approximately 50% of the pairs to be from the same class
            should_get_same_class = random.randint(0, 1)
            if should_get_same_class:
                while True:
                    # Look until an image of the same class is found
                    img1_tuple = random.choice(self.imageFolderDataset.imgs)
                    if img0_tuple[1] == img1_tuple[1]:
                        break
            else:
                while True:
                    # Look until an image of a different class is found
                    img1_tuple = random.choice(self.imageFolderDataset.imgs)
                    if img0_tuple[1] != img1_tuple[1]:
                        break

            img0 = Image.open(img0_tuple[0])
            img1 = Image.open(img1_tuple[0])

            img0 = img0.convert("L")
            img1 = img1.convert("L")

            if self.transform is not None:
                img0 = self.transform(img0)
                img1 = self.transform(img1)

            # Label is 1.0 when the two images are from different classes, else 0.0
            return img0, img1, torch.from_numpy(
                np.array([int(img1_tuple[1] != img0_tuple[1])], dtype=np.float32))

        def __len__(self):
            return len(self.imageFolderDataset.imgs)


    # Create the Siamese Neural Network
    class SiameseNetwork(nn.Module):
        def __init__(self):
            super(SiameseNetwork, self).__init__()

            # Setting up the Sequential of CNN Layers
            self.cnn1 = nn.Sequential(
                nn.Conv2d(1, 96, kernel_size=11, stride=4),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(3, stride=2),
                nn.Dropout(0.5),  # added dropout

                nn.Conv2d(96, 256, kernel_size=5, stride=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, stride=2),

                nn.Conv2d(256, 384, kernel_size=3, stride=1),
                nn.ReLU(inplace=True)
            )

            # Setting up the Fully Connected Layers
            self.fc1 = nn.Sequential(
                nn.Linear(384, 1024),
                nn.ReLU(inplace=True),

                nn.Linear(1024, 256),
                nn.ReLU(inplace=True),

                nn.Linear(256, 2)
            )

        def forward_once(self, x):
            # This function will be called for both images
            # Its output is used to determine the similarity
            output = self.cnn1(x)
            output = output.view(x.size()[0], 384)
            output = self.fc1(output)
            return output

        def forward(self, input1, input2):
            # In this function we pass in both images and obtain both vectors
            # which are returned
            output1 = self.forward_once(input1)
            output2 = self.forward_once(input2)
            return output1, output2


    # Define the Contrastive Loss Function
    class ContrastiveLoss(torch.nn.Module):
        def __init__(self, margin=2.0):
            super(ContrastiveLoss, self).__init__()
            self.margin = margin

        def forward(self, output1, output2, label):
            # Calculate the euclidean distance and calculate the contrastive loss
            euclidean_distance = F.pairwise_distance(
                output1, output2, keepdim=True)

            loss_contrastive = torch.mean(
                (1 - label) * torch.pow(euclidean_distance, 2) +
                (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

            return loss_contrastive


    def build(epochs, model_path, old_model_path):
        # Load the training dataset
        # folder_dataset = datasets.ImageFolder(root="./data/faces/training/")
        folder_dataset = datasets.ImageFolder(root="/Users/yangyuhan/temp/")

        # Resize the images, apply augmentation, and transform to tensors
        transformation = transforms.Compose([
            transforms.Resize((100, 100)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
            transforms.ToTensor()
        ])

        # Build the paired dataset
        siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset,
                                                transform=transformation)

        # Load the training dataset
        train_dataloader = DataLoader(siamese_dataset,
                                      shuffle=True,
                                      num_workers=6,
                                      batch_size=64)

        # net = SiameseNetwork().cuda()
        net = SiameseNetwork()
        # Resume from the previously saved weights
        net.load_state_dict(torch.load(old_model_path))
        # Set the model in training mode
        net.train()
        criterion = ContrastiveLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.0005)

        counter = []
        loss_history = []
        iteration_number = 0

        # Iterate through the epochs
        for epoch in range(epochs):

            # Iterate over batches
            for i, (img0, img1, label) in enumerate(train_dataloader, 0):

                # Send the images and labels to CUDA
                # img0, img1, label = img0.cuda(), img1.cuda(), label.cuda()

                # Zero the gradients
                optimizer.zero_grad()

                # Pass in the two images into the network and obtain two outputs
                output1, output2 = net(img0, img1)

                # Pass the outputs of the networks and label into the loss function
                loss_contrastive = criterion(output1, output2, label)

                # Calculate the backpropagation
                loss_contrastive.backward()

                # Optimize
                optimizer.step()

                # Every 10 batches print out the loss
                if i % 10 == 0:
                    print(
                        f"Epoch number {epoch}\n Current loss {loss_contrastive.item()}\n")
                    iteration_number += 10

                    counter.append(iteration_number)
                    loss_history.append(loss_contrastive.item())

        show_plot(counter, loss_history)
        torch.save(net.state_dict(), model_path)


    if __name__ == '__main__':
        old_model_path = '/Users/yangyuhan/model/captcha_618_200.pth'
        model_path = '/Users/yangyuhan/model/captcha_618_500.pth'
        # test(model_path)
        build(200, model_path, old_model_path)
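
For reference, the quantity that `ContrastiveLoss.forward` averages over a batch can be written out as below. The symbols are notation introduced here, not names from the code: $N$ is the batch size, $f$ is `forward_once`, $y_i$ is the label (1 for a different-class pair, 0 for a same-class pair, as built in `__getitem__`), and $m$ is `margin` (2.0 by default).

```latex
\[
L = \frac{1}{N}\sum_{i=1}^{N}\Big[(1-y_i)\,d_i^{2} + y_i\,\max(0,\; m-d_i)^{2}\Big],
\qquad
d_i = \big\lVert f(x_i^{(1)}) - f(x_i^{(2)}) \big\rVert_2
\]
```

A different-class pair contributes nothing once its distance exceeds $m$, and a same-class pair contributes $d_i^2$, so a training loss near 0 only says that the training pairs are well separated. It does not by itself say anything about pairs the network has not seen.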

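    My code differs from the source code in one line: output = output.view(x.size()[0], 384). Originally it was output = output.view(output.size()[0], -1). I don't know whether this affects training. Using the original code as-is raised an error at this point; I later got an answer from GPT telling me to change it to 384.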
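
One way to check whether the hardcoded 384 changes anything is to trace the spatial size through `cnn1` by hand. The sketch below does that arithmetic in plain Python, assuming unpadded layers with the kernel sizes and strides listed in the code and the 100x100 input produced by `Resize((100, 100))`. The helper names `conv_out` and `flattened_features` are made up for this illustration.

```python
def conv_out(size: int, kernel: int, stride: int = 1) -> int:
    """Output side length of an unpadded Conv2d / MaxPool2d (floor mode)."""
    return (size - kernel) // stride + 1


def flattened_features(input_side: int = 100) -> int:
    """Trace the spatial size through cnn1 and return channels * H * W."""
    side = conv_out(input_side, kernel=11, stride=4)  # Conv2d(1, 96, 11, stride=4)
    side = conv_out(side, kernel=3, stride=2)         # MaxPool2d(3, stride=2)
    side = conv_out(side, kernel=5, stride=1)         # Conv2d(96, 256, 5)
    side = conv_out(side, kernel=2, stride=2)         # MaxPool2d(2, stride=2)
    side = conv_out(side, kernel=3, stride=1)         # Conv2d(256, 384, 3)
    return 384 * side * side                          # 384 output channels


print(flattened_features(100))  # 384
print(flattened_features(120))  # 1536
```

For a 100x100 input the feature map going into the fully connected layers is 384x1x1, so `view(batch, 384)` and `view(batch, -1)` produce the same tensor and the change should not affect training. The hardcoded value only holds while the final feature map is 1x1; with a larger input such as 120x120 the flattened size becomes 1536 and both the `view` call and `nn.Linear(384, 1024)` would need to change.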

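    Could you check whether there is anything wrong with the network structure? My training data is organized so that images of the same class are in the same folder. The current loss is already 0.004438498988747597. Have I trained it too hard? Would it work better to keep the loss around 0.1? That sort of thing... any advice would be appreciated.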
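
Rather than aiming for a particular training-loss value such as 0.1, a common way to decide when to stop is to hold out some pairs that are never trained on and stop once the loss on them stops improving. The sketch below shows only that stopping rule in plain Python. It is not part of the training code above: `best_epoch`, `patience`, and the list of per-epoch validation losses are hypothetical, and computing those losses would need a separate held-out dataset.

```python
def best_epoch(val_losses, patience=3):
    """Early stopping over a list of per-epoch validation losses.

    Returns (index of the best epoch, index of the epoch where training stops).
    Training stops once `patience` epochs in a row fail to beat the best loss.
    """
    best_idx = 0
    for idx, loss in enumerate(val_losses):
        if loss < val_losses[best_idx]:
            best_idx = idx                      # new best: remember this epoch
        elif idx - best_idx >= patience:
            return best_idx, idx                # no improvement for `patience` epochs
    return best_idx, len(val_losses) - 1        # ran to the end without stopping


# Hypothetical numbers, for illustration only
best, stopped = best_epoch([1.0, 0.6, 0.4, 0.45, 0.5, 0.6], patience=3)
print(best, stopped)  # 2 5
```

With this rule the checkpoint to keep is the one saved at the best epoch, regardless of how low the training loss has gone by then.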

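    Since my network returns a similarity measure (in effect a distance: the closer to 0, the more similar), at test time I compare each of the 3 small images against every big image and take the most similar big image for each small image as the result.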
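
The matching step described here can be sketched as follows. It is an illustration under assumptions: the embeddings are taken to be the 2-dimensional vectors returned by `forward_once`, already converted to plain Python sequences, and the function name `match_small_to_big` and the sample values are made up.

```python
import math


def match_small_to_big(small_embeddings, big_embeddings):
    """For each small-icon embedding, return (index of nearest big icon, distance)."""
    matches = []
    for small in small_embeddings:
        # Euclidean distance from this small icon to every big icon
        distances = [math.dist(small, big) for big in big_embeddings]
        best = min(range(len(distances)), key=distances.__getitem__)
        matches.append((best, distances[best]))
    return matches


# Hypothetical embeddings, for illustration only
smalls = [(0.1, 0.2), (1.5, -0.3)]
bigs = [(1.4, -0.2), (0.0, 0.25), (-1.0, 1.0)]
print([index for index, _ in match_small_to_big(smalls, bigs)])  # [1, 0]
```

Each small icon is matched independently, so two small icons can end up pointing at the same big icon. If every big icon may be clicked only once, a one-to-one assignment over the full distance matrix would be an alternative.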

    6 replies    2023-06-23 15:52:51 +08:00
    yangyuhan12138
        1
    yangyuhan12138  
    OP
       Jun 19, 2023
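    I only understand the network structure roughly, and I do mean very roughly, so I'd appreciate it if you could explain things in detail...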
    r6cb
        2
    r6cb  
       Jun 19, 2023
    A loss this small may be because the dataset is too small.
    yangyuhan12138
        3
    yangyuhan12138  
    OP
       Jun 19, 2023
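    @r6cb It started at 2.something, which is fairly large. It only got this low after I trained for many epochs. It now handles the training data almost without problems; the main issue is that it predicts poorly on new data.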
    opeth
        4
    opeth  
       Jun 19, 2023   1
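    You don't have much data, so add some measures against overfitting, such as Dropout, and strengthen the data augmentation.
    Also try a different loss.
    In addition, judging by how you preprocess the data, the crops include the background, and the icons differ in size and position. This may cause the network to fit the content of the background image.
    If possible, mask out the icon. At the very least the detection box should hug the edges of the foreground icon; that would also let your detector train better.
    Before classification it's best to add an alignment step, so that the size and position of the icon in every input image are normalized to the same parameters.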
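
The alignment suggestion could be sketched as follows: find the tight bounding box of the icon from a foreground mask, then compute the scale and offset that place it centered at a fixed size. This is only an illustration under assumptions. The mask is a hypothetical list of 0/1 rows (how to obtain it is not covered here), and `tight_bbox`, `align_params`, the 100-pixel canvas, and the 10-pixel margin are made-up names and values, with the canvas chosen to match `Resize((100, 100))` in the training code.

```python
def tight_bbox(mask):
    """Bounding box (left, top, right, bottom) of nonzero cells; right/bottom exclusive.

    Returns None when the mask has no foreground at all.
    """
    rows = [r for r, row in enumerate(mask) if any(row)]
    if not rows:
        return None
    cols = [c for c in range(len(mask[0])) if any(row[c] for row in mask)]
    return cols[0], rows[0], cols[-1] + 1, rows[-1] + 1


def align_params(bbox, canvas=100, margin=10):
    """Scale, resized (width, height), and top-left offset that center the box on a square canvas."""
    left, top, right, bottom = bbox
    width, height = right - left, bottom - top
    # Fit the longer side into the canvas minus the margin on both sides
    scale = (canvas - 2 * margin) / max(width, height)
    new_w, new_h = round(width * scale), round(height * scale)
    offset = ((canvas - new_w) // 2, (canvas - new_h) // 2)
    return scale, (new_w, new_h), offset


# Hypothetical 6x8 mask with a 4x2 icon
mask = [
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 1, 1, 1, 1, 0, 0],
    [0, 0, 1, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
]
box = tight_bbox(mask)
print(box)                # (2, 1, 6, 3)
print(align_params(box))  # (20.0, (80, 40), (10, 30))
```

The actual crop, resize, and paste onto a blank canvas would then be done with an image library using these numbers, so that every icon reaches the network at the same size and position.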
    hubahuba
        5
    hubahuba  
       Jun 20, 2023
    @opeth Thanks, I learned a few things.
    yangyuhan12138
        6
    yangyuhan12138  
    OP
       Jun 23, 2023
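    @opeth Thank you for the patient reply, but there are still parts I don't know how to do, since I've never worked in this area before. I've decided to put this on hold for now. I was hoping to find something as out-of-the-box as YOLO, but after looking around there doesn't seem to be one, so I'll leave it for now and come back to this once I've studied deep learning systematically.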