如何使用 TensorflowFederated 中的聚合客户端指标更新集中式服务器模型

发布于 2025-01-17 22:45:58 字数 627 浏览 1 评论 0原文

我使用 TensorFlow Federated 框架设计了联邦学习模型。迭代过程定义如下,

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.9))

我有 2 个远程工作人员运行 tffruntime 远程执行程序服务,运行计算的上下文定义为 tff.backends.native.set_remote_python_execution_context(channels) 。当使用 iterative_process.next(state, train_data) 将模型广播到客户端时,我们如何识别客户端指标已聚合并应用于服务器模型。单个 API build_federated_averaging_process 是否足以从客户端获取指标、聚合然后更新服务器模型? if 意味着我们如何识别服务器模型已更新?谁能帮助我理解这一点。

I have designed the Federated Learning model with TensorFlow Federated framework. Defined the iterative process as below,

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.9))

I have 2 remote workers running the tffruntime remote executor service and the context for running computation is defined as tff.backends.native.set_remote_python_execution_context(channels). When the model is broadcasted to the client with iterative_process.next(state, train_data), how can we identify that the client metrics is aggregated and applied to the server model. Is the single api build_federated_averaging_process is enough to get the metrics from clients, aggregate and then update the server model? If means how can we identify that the server model is updated? Can anyone please help me to understand this.

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

小猫一只 2025-01-24 22:45:58

build_federated_averaging_process API 构建完整联邦学习步骤的迭代过程。如果您想验证服务器模型是否已更新,可以在每次 iterative_process.next(state, train_data) 之后打印 state.model

The build_federated_averaging_process API builds an iterative process of the full federated learning steps. If you want to verify that the server model is updated, you can print state.model after each iterative_process.next(state, train_data).

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文