tensorflow embedding可视化的问题

发布于 2022-09-07 21:35:53 字数 4124 浏览 21 评论 0

我现在完成了利用Alexnet网络实现对自己图片集的分类,现在看到了tensorBoard里面embedding的功能感觉很炫酷,特别希望可以通过那种可视化的界面看我的网络分类时的效果。但是关于embedding实在是做不到,希望大神指点!!

在我的网络里面,先是把图片集存到了一个.npy文件里,然后在后面的训练过程中随机抓取,不是很清楚embedding应该在代码的什么地方加载,我想把我的原始数据,也就是图片存入embedding看效果,但是加载的时候一直出问题,希望大神救命!!!

代码如下:

def alexnet_main():
    loopNum = 5
    # 加载使用的训练集文件名和标签。
    files = np.load("label.npy", encoding='bytes')[()]

    #日志路径-用于embedding
    log_dir = 'model'
    metadata = os.path.join(log_dir,'metadata.tsv')
    j = 0
    with open(metadata, 'w') as metadata_file:
        for i in files:
            metadata_file.write('%d\n' % j)
            j = j+1


    # 提取文件名。
    keys = [i for i in files]


    myinput = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='input')
    mylabel = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='label')
    
    # 建立网络,keepprob为0.6。
    myoutput = alexnet(myinput, 0.6)

    # 定义训练的loss函数。
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=myoutput, labels=mylabel))

    # 定义优化器,学习率设置为0.09,学习率可以设置为其他的数值。
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.09).minimize(loss)

    # 定义准确率
    valaccuracy = tf.reduce_mean(
        tf.cast(
            tf.equal(
                tf.argmax(myoutput, 1),
                tf.argmax(mylabel, 1)),
            tf.float32))

    # tensorflow的saver,可以用于保存模型。
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    all_vars = tf.global_variables()

    #会话过程
    with tf.Session() as sess:
        sess.run(init)
        saver = tf.train.Saver(all_vars)
        saver.restore(sess, r"model/model.ckpt")#用于加载模型的函数

        # 100个epoch
        totalAcc = 0
        for loop in range(loopNum):

            # 生成并打乱训练集的顺序。
            indices = np.arange(1100)#modify
            random.shuffle(indices)

            # batch size此处定义为50。
            # 训练集一共1100张图片,前1000张用于训练,后100张用于验证集。
            for i in range(0, 0+1000, 50):
                photo = []
                label = []
                #print("1:",label)
                for j in range(0, 20):
                    photo.append(cv2.resize(cv2.imread(keys[indices[i + j]]), (224, 224))/225)
                    #print(i+j)
                    label.append(files[keys[indices[i + j]]])
                    #print("2:",label)

                #加载embedding,每次训练加载一次
                target = tf.convert_to_tensor(photo)
                embedding_var = tf.Variable(photo,'data_embedding')
                config = projector.ProjectorConfig()
                embedding = config.embeddings.add()
                embedding.tensor_name = embedding_var.name
                embedding.metadata_path = metadata
                embedding.sprite.single_image_dim.extend([28,28])
                projector.visualize_embeddings(tf.summary.FileWriter(log_dir), config)


                m = getOneHotLabel(label, depth=10)
                a, b = sess.run([optimizer, loss], feed_dict={myinput: photo, mylabel: m})

            acc = 0
            # 每次取验证集的20张图片进行验证,返回这200张图片的正确率。
            for i in range(1000, 1000+100, 20):
                photo = []
                label = []
                for j in range(i, i + 5):
                    photo.append(cv2.resize(cv2.imread(keys[indices[j]]), (224, 224))/225)
                    label.append(files[keys[indices[j]]])
                m = getOneHotLabel(label, depth=10)
                acc += sess.run(valaccuracy, feed_dict={myinput: photo, mylabel: m})
            # 输出,一共有5次验证集数据相加,所以需要除以50。
            print("Epoch ", loop, ': validation rate: ', acc/5)
            totalAcc += acc/5
        print("final ",totalAcc/loopNum)
        # 保存模型。
        saver.save(sess, "model/model.ckpt")
        to_visualise = myinput
        to_visualise = vector_to_mnist(to_visualise)
        to_visualise = invert_grayscale(to_visualise)
        sprite_image = create_sprite_image(to_visualise)
        plt.imsave(metadata,sprite_image)
        plt.imshow(sprite_image)

if __name__ == '__main__':
    alexnet_main()

其中的

alexnet是我前文定义的网络

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

那些过往 2022-09-14 21:35:53

您好,暂时回答不了您的问题,我想请问程序运行过程中报错IndexError: list index out of range该如何解决呢?十分期望您的指导,感谢

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文