How to update the centralized server model using aggregated client metrics in TensorFlow Federated
I have designed a federated learning model with the TensorFlow Federated framework. The iterative process is defined as below:
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.9))
I have 2 remote workers running the tffruntime remote executor service, and the context for running computations is set with tff.backends.native.set_remote_python_execution_context(channels). When the model is broadcast to the clients with iterative_process.next(state, train_data), how can we tell that the client metrics have been aggregated and applied to the server model? Is the single API build_federated_averaging_process enough to get the metrics from the clients, aggregate them, and then update the server model? If so, how can we tell that the server model has been updated? Can anyone please help me understand this?
Comments (1)
The build_federated_averaging_process API builds an iterative process that covers the full federated learning steps. If you want to verify that the server model has been updated, you can print state.model after each call to iterative_process.next(state, train_data).
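To make the "full federated learning steps" a bit more concrete, here is a minimal pure-Python sketch of what one round of federated averaging does conceptually: each client trains locally, the weight deltas are averaged (weighted by example count), and the server applies the average with its own learning rate. This is not TFF's implementation. The function names, the toy linear model y = w*x + b, and the client data are all hypothetical and chosen only for illustration.

```python
# Toy federated averaging round (illustrative only, not the TFF code).

def client_update(server_weights, examples, client_lr):
    """Run one pass of SGD locally; return (weight delta, number of examples)."""
    w, b = server_weights
    for x, y in examples:
        err = (w * x + b) - y              # prediction error for y = w*x + b
        w -= client_lr * 2.0 * err * x     # gradient of squared error w.r.t. w
        b -= client_lr * 2.0 * err         # gradient of squared error w.r.t. b
    delta = [w - server_weights[0], b - server_weights[1]]
    return delta, len(examples)


def federated_averaging_round(server_weights, client_datasets,
                              client_lr=0.05, server_lr=0.9):
    """Broadcast weights, collect client deltas, average them, update the server."""
    results = [client_update(server_weights, data, client_lr)
               for data in client_datasets]
    total_examples = sum(n for _, n in results)
    # Average of client deltas, weighted by how many examples each client saw.
    avg_delta = [
        sum(delta[i] * n for delta, n in results) / total_examples
        for i in range(len(server_weights))
    ]
    # Server SGD step: move the server weights along the averaged delta.
    return [w + server_lr * d for w, d in zip(server_weights, avg_delta)]


# Hypothetical data for two clients, both drawn from y = 2x.
client_datasets = [
    [(1.0, 2.0), (2.0, 4.0)],
    [(3.0, 6.0)],
]

weights = [0.0, 0.0]
for round_num in range(1, 4):
    new_weights = federated_averaging_round(weights, client_datasets)
    # Comparing before and after is the same check as printing state.model.
    print(f"round {round_num}: weights={new_weights} changed={new_weights != weights}")
    weights = new_weights
```

With TFF itself, the analogous check is to keep a copy of state.model before the call, run state, metrics = iterative_process.next(state, train_data), and compare the two. The returned metrics are the aggregated client training metrics for that round.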