类型规范和“to_representation_for_type”中的值之间的元素数量不匹配。类型规范有 2 个元素,值有 5 个元素

发布于 2025-01-17 06:00:47 字数 1748 浏览 5 评论 0原文

我使用tensorflow fedprox来实现联邦学习。(tff.learning.algorithms.build_unweighted_fed_prox)

def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=preprocessed_example_dataset.element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

iterative_process = tff.learning.algorithms.build_unweighted_fed_prox(
    model_fn, 0.001,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.001),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
)

import nest_asyncio
nest_asyncio.apply()

state = iterative_process.initialize()

for round in range(3, 11):
    state = iterative_process.next(state.state, federated_train_data)
    print('round {:2d}, metrics={}'.format(round, state.metrics))

,训练结果是:

第3轮,'sparse_categorical_accuracy'= 0.6435834

第4轮,'sparse_categorical_accuracy'= 0.6955319

第5轮, 'sparse_categorical_accuracy'= 0.74295634

第 6 轮,'sparse_categorical_accuracy'= 0.78176934

第 7 轮,'sparse_categorical_accuracy'= 0.80838746

第 8 轮,'sparse_categorical_accuracy'= 0.8300672

第 9 轮, 'sparse_categorical_accuracy'= 0.8486338

round 10, 'sparse_categorical_accuracy', 0.86639416


但是当我想根据测试数据评估我的模型时,出现错误:

evaluation = tff.learning.build_federated_evaluation(model_fn)
test_metrics = evaluation(state.state, federated_test_data)

TypeError: Mismatched number of elements between type spec and value in `to_representation_for_type`. Type spec has 2 elements, value has 5.

如何修复它?

I use tensorflow fedprox to implement federated learning.(tff.learning.algorithms.build_unweighted_fed_prox)

def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=preprocessed_example_dataset.element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

iterative_process = tff.learning.algorithms.build_unweighted_fed_prox(
    model_fn, 0.001,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.001),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
)

import nest_asyncio
nest_asyncio.apply()

state = iterative_process.initialize()

for round in range(3, 11):
    state = iterative_process.next(state.state, federated_train_data)
    print('round {:2d}, metrics={}'.format(round, state.metrics))

and the result of training is:

round 3, 'sparse_categorical_accuracy'= 0.6435834

round 4, 'sparse_categorical_accuracy'= 0.6955319

round 5, 'sparse_categorical_accuracy'= 0.74295634

round 6, 'sparse_categorical_accuracy'= 0.78176934

round 7, 'sparse_categorical_accuracy'= 0.80838746

round 8, 'sparse_categorical_accuracy'= 0.8300672

round 9, 'sparse_categorical_accuracy'= 0.8486338

round 10, 'sparse_categorical_accuracy', 0.86639416


but when I want to evaluate my model on test data I get error:

evaluation = tff.learning.build_federated_evaluation(model_fn)
test_metrics = evaluation(state.state, federated_test_data)

TypeError: Mismatched number of elements between type spec and value in `to_representation_for_type`. Type spec has 2 elements, value has 5.

How do I fix it?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

一袭白衣梦中忆 2025-01-24 06:00:47

您的evaluation方法需要tff.learning.ModelWeights,但您提供的是整个状态,这是一个更大的结构,包括global_model_weights下的模型权重> 属性。因此,这可以工作:

test_metrics = evaluation(state.state.global_model_weights, federated_test_data)


旁注,将 iterative_process.next 的返回值分配给 Python 变量 state 可能会变得非常混乱,因为它包含程序的状态和指标,这会导致您使用 state.state

Your evaluation method expects tff.learning.ModelWeights, but you are providing the entire state, which is a bigger structure, including the model weights under global_model_weights attribute. So, this could work:

test_metrics = evaluation(state.state.global_model_weights, federated_test_data)


Side note, assigning the return value of iterative_process.next to Python variable state can become very confusing, as it contains state of the program and metrics, which leads you to the use state.state

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文