返回介绍

读取数据

发布于 2025-02-17 22:41:42 字数 1244 浏览 0 评论 0 收藏 0

数据读取函数↓

def data_load(train_list_path,batch_size):
    '''
    train_list_path:标注文件 txt 所在 path
    '''
    train_dir_list=[]
    train_label=[]
    with open(train_list_path,'r') as train_dirs:
        #train_dir_list.append(train_dirs.readline())
        lines=[line.strip() for line in train_dirs]
        for line in lines:
            img_path,label=line.split()
            train_dir_list.append(img_path)
            train_label.append(label)
    def reader():
        imgs=[]
        labels=[]
        img_mask=np.arange(len(train_dir_list)) #生成索引
        np.random.shuffle(img_mask) #随机打乱索引
        count=0
        for i in img_mask:
            img=cv2.imread(train_dir_list[i])
            img=cv2.resize(img,(224,224),interpolation=cv2.INTER_CUBIC)/255
            img=np.transpose(img,(2,0,1))
            imgs.append(img)
            labels.append(train_label[i])
            count+=1
            if(count%train_paramters['batch_size']==0):
                yield np.asarray(imgs).astype('float32'),np.asarray(labels).astype('int64').reshape((train_paramters['batch_size'],1))
                imgs=[]
                labels=[]
    return reader

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文