您现在的位置是:群英 > 开发技术 > Python语言
Pytorch数据加载和处理怎样实现
Admin发表于 2022-05-17 17:08:08668 次浏览
相信很多人对“Pytorch数据加载和处理怎样实现”都不太了解,下面群英小编为你详细解释一下这个问题,希望对你有一定的帮助



目录
  • 一、下载安装包
  • 二、下载数据集
  • 三、读取数据集
  • 四、编写一个函数看看图像和landmark
  • 五、数据集类
  • 六、数据可视化
  • 七、数据变换
    • 1、Function_Rescale
    • 2、Function_RandomCrop
    • 3、Function_ToTensor
  • 八、组合转换
    • 九、迭代数据集
      • 总结

        一、下载安装包

        packages:

        • scikit-image:用于图像测IO和变换
        • pandas:方便进行csv解析

        二、下载数据集

        数据集说明:该数据集(我在这)是imagenet数据集标注为face的图片当中在dlib面部检测表现良好的图片――处理的是一个面部姿态的数据集,也就是按照入戏方式标注人脸


        数据集展示


        三、读取数据集

        #%%读取数据集
        landmarks_frame=pd.read_csv('D:/Python/Pytorch/data/faces/face_landmarks.csv')
        n=65
        img_name=landmarks_frame.iloc[n,0]
        landmarks=landmarks_frame.iloc[n,1:].values
        landmarks=landmarks.astype('float').reshape(-1,2)
        print('Image name :{}'.format(img_name))
        print('Landmarks shape :{}'.format(landmarks.shape))
        print('First 4 Landmarks:{}'.format(landmarks[:4]))
        

        运行结果

        四、编写一个函数看看图像和landmark

        #%%编写显示人脸函数
        def show_landmarks(image,landmarks):
            plt.imshow(image)
            plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker=".",c='r')
            plt.pause(0.001)
        plt.figure()
        show_landmarks(io.imread(os.path.join('D:/Python/Pytorch/data/faces/',img_name)),landmarks)
        plt.show()
        

        运行结果

        五、数据集类

        torch.utils.data.Dataset是表示数据集的抽象类,自定义数据类应继承Dataset并覆盖__len__实现len(dataset)返还数据集的尺寸。__getitem__用来获取一些索引数据:

        #%%数据集类――将数据集封装成一个类
        class FaceLandmarksDataset(Dataset):
            def __init__(self,csv_file,root_dir,transform=None):
                # csv_file(string):待注释的csv文件的路径
                # root_dir(string):包含所有图像的目录
                # transform(callabele,optional):一个样本上的可用的可选变换
                self.landmarks_frame=pd.read_csv(csv_file)
                self.root_dir=root_dir
                self.transform=transform
            def __len__(self):
                return len(self.landmarks_frame)
            def __getitem__(self, idx):
                img_name=os.path.join(self.root_dir,self.landmarks_frame.iloc[idx,0])
                image=io.imread(img_name)
                landmarks=self.landmarks_frame.iloc[idx,1:]
                landmarks=np.array([landmarks])
                landmarks=landmarks.astype('float').reshape(-1,2)
                sample={'image':image,'landmarks':landmarks}
                if self.transform:
                    sample=self.transform(sample)
                return sample    
        

        六、数据可视化

        #%%数据可视化
        # 将上面定义的类进行实例化并便利整个数据集
        face_dataset=FaceLandmarksDataset(csv_file='D:/Python/Pytorch/data/faces/face_landmarks.csv', 
                                          root_dir='D:/Python/Pytorch/data/faces/')
        fig=plt.figure()
        for i in range(len(face_dataset)) :
            sample=face_dataset[i]
            print(i,sample['image'].shape,sample['landmarks'].shape)
            ax=plt.subplot(1,4,i+1)
            plt.tight_layout()
            ax.set_title('Sample #{}'.format(i))
            ax.axis('off')
            show_landmarks(**sample)
            if i==3:
                plt.show()
                break
        

        运行结果


        七、数据变换

        由上图可以发现每张图像的尺寸大小是不同的。绝大多数神经网路都嘉定图像的尺寸相同。所以需要对图像先进行预处理。创建三个转换:

        Rescale:缩放图片

        RandomCrop:对图片进行随机裁剪

        ToTensor:把numpy格式图片转成torch格式图片(交换坐标轴)和上面同样的方式,将其写成一个类,这样就不需要在每次调用的时候川第一此参数,只需要实现__call__的方法,必要的时候使用__init__方法

        1、Function_Rescale

        # 将样本中的图像重新缩放到给定的大小
        class Rescale(object):    
            def __init__(self,output_size):
                assert isinstance(output_size,(int,tuple))
                self.output_size=output_size
            #output_size 为int或tuple,如果是元组输出与output_size匹配,
            #如果是int,匹配较小的图像边缘到output_size保持纵横比相同
            def __call__(self,sample):
                image,landmarks=sample['image'],sample['landmarks']
                h,w=image.shape[:2]
                if isinstance(self.output_size, int):#输入参数是int
                    if h>w:
                        new_h,new_w=self.output_size*h/w,self.output_size
                    else:
                        new_h,new_w=self.output_size,self.output_size*w/h
                else:#输入参数是元组
                    new_h,new_w=self.output_size
                new_h,new_w=int(new_h),int(new_w)
                img=transform.resize(image, (new_h,new_w))
                landmarks=landmarks*[new_w/w,new_h/h]
                return {'image':img,'landmarks':landmarks}
        

        2、Function_RandomCrop

        # 随机裁剪样本中的图像
        class RandomCrop(object):
            def __init__(self,output_size):
                assert isinstance(output_size, (int,tuple))
                if isinstance(output_size, int):
                    self.output_size=(output_size,output_size)
                else:
                    assert len(output_size)==2
                    self.output_size=output_size
            # 输入参数依旧表示想要裁剪后图像的尺寸,如果是元组其而包含两个元素直接复制长宽,如果是int,则裁剪为方形
            def __call__(self,sample):
                image,landmarks=sample['image'],sample['landmarks']
                h,w=image.shape[:2]
                new_h,new_w=self.output_size
                #确定图片裁剪位置
                top=np.random.randint(0,h-new_h)
                left=np.random.randint(0,w-new_w)
                image=image[top:top+new_h,left:left+new_w]
                landmarks=landmarks-[left,top]
                return {'image':image,'landmarks':landmarks}
        

        3、Function_ToTensor

        #%%
        # 将样本中的npdarray转换为Tensor
        class ToTensor(object):
            def __call__(self,sample):
                image,landmarks=sample['image'],sample['landmarks']
                image=image.transpose((2,0,1))#交换颜色轴
                #numpy的图片是:Height*Width*Color
                #torch的图片是:Color*Height*Width
                return {'image':torch.from_numpy(image),
                        'landmarks':torch.from_numpy(landmarks)}
        

        八、组合转换

        将上面编写的类应用到实例中

        Req: 把图像的短边调整为256,随机裁剪(randomcrop)为224大小的正方形。即:组合一个Rescale和RandomCrop的变换。

        #%%
        scale=Rescale(256)
        crop=RandomCrop(128)
        composed=transforms.Compose([Rescale(256),RandomCrop(224)])    
        # 在样本上应用上述变换
        fig=plt.figure() 
        sample=face_dataset[65]
        for i,tsfrm in enumerate([scale,crop,composed]):
            transformed_sample=tsfrm(sample)
            ax=plt.subplot(1,3,i+1)
            plt.tight_layout()
            ax.set_title(type(tsfrm).__name__)
            show_landmarks(**transformed_sample)
        plt.show()
        

        运行结果

        九、迭代数据集

        把这些整合起来以创建一个带有组合转换的数据集,总结一下没每次这个数据集被采样的时候:及时的从文件中读取图片,对读取的图片应用转换,由于其中一部是随机的randomcrop,数据被增强了。可以使用循环对创建的数据集执行同样的操作

        transformed_dataset=FaceLandmarksDataset(csv_file='D:/Python/Pytorch/data/faces/face_landmarks.csv',
                                                  root_dir='D:/Python/Pytorch/data/faces/',
                                                  transform=transforms.Compose([
                                                      Rescale(256),
                                                      RandomCrop(224),
                                                      ToTensor()
                                                      ]))
        for i in range(len(transformed_dataset)):
            sample=transformed_dataset[i]
            print(i,sample['image'].size(),sample['landmarks'].size())
            if i==3:
                break    
        

        运行结果


        对所有数据集简单使用for循环会牺牲很多功能――>麻烦,效率低!!改用多线程并行进行
        torch.utils.data.DataLoader可以提供上述功能的迭代器。collate_fn参数可以决定如何对数据进行批处理,绝大多数情况下默认值就OK

        transformed_dataset=FaceLandmarksDataset(csv_file='D:/Python/Pytorch/data/faces/face_landmarks.csv',
                                                  root_dir='D:/Python/Pytorch/data/faces/',
                                                  transform=transforms.Compose([
                                                      Rescale(256),
                                                      RandomCrop(224),
                                                      ToTensor()
                                                      ]))
        for i in range(len(transformed_dataset)):
            sample=transformed_dataset[i]
            print(i,sample['image'].size(),sample['landmarks'].size())
            if i==3:
                break    

        感谢各位的阅读,以上就是“Pytorch数据加载和处理怎样实现”的内容了,通过以上内容的阐述,相信大家对Pytorch数据加载和处理怎样实现已经有了进一步的了解,如果想要了解更多相关的内容,欢迎关注群英网络,群英网络将为大家推送更多相关知识点的文章。

        免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。

        相关信息推荐
        2022-05-06 18:09:17 
        摘要:本篇文章带大家了解一下微信小程序中的wxs模块,介绍一下wxs的使用,希望对大家有所帮助!
        2021-11-23 17:41:43 
        摘要:石头剪刀布的小游戏相信大家都有玩过吧,这篇文章我们来了解如何用python实现简单的石头剪刀布小游戏,实现效果及代码如下,感兴趣的朋友可以参考。
        2022-12-01 16:15:38 
        摘要:php导出mysql csv乱码问题的解决方法:1、打开相应的php文件;2、在文件头部写入BOM标识即可,代码如“fwrite($fp, chr(0xEF) . chr(0xBB) . chr(0xBF));”。
        云活动
        推荐内容
        热门关键词
        热门信息
        群英网络助力开启安全的云计算之旅
        立即注册,领取新人大礼包
        • 联系我们
        • 24小时售后:4006784567
        • 24小时TEL :0668-2555666
        • 售前咨询TEL:400-678-4567

        • 官方微信

          官方微信
        Copyright  ©  QY  Network  Company  Ltd. All  Rights  Reserved. 2003-2019  群英网络  版权所有   茂名市群英网络有限公司
        增值电信经营许可证 : B1.B2-20140078   粤ICP备09006778号
        免费拨打  400-678-4567
        免费拨打  400-678-4567 免费拨打 400-678-4567 或 0668-2555555
        微信公众号
        返回顶部
        返回顶部 返回顶部