TFF :修改state的值

发布于 2025-01-14 15:01:58 字数 1091 浏览 6 评论 0原文

iterative_process.initialize() 返回的状态对象通常是一个包含 numpy 数组的 Python 容器(tuple、collections.OrderedDict 等)。我希望状态的值不是随机的,而是从加载的模型开始。 作为开始,我这样写:

def create_keras_model():   
  Model = tf.keras.models.load_model(path)
  return Model
def model_fn(): 
    keras_model = create_keras_model()   
    return tff.learning.from_keras_model(keras_model..)
iterative_process = tff.learning.build_federated_averaging_process(model_fn=model_fn..)
state = iterative_process.initialize()

但是与正常情况相比(如果我不加载外部模型),测试准确性结果根本没有改变。

这就是为什么,我尝试这个解决方案:

# initialize_fn() function
@tff.tf_computation
def server_init():
    model = model_fn()
    return model.trainable_variables

@tff.federated_computation
def initialize_fn():
    return tff.federated_value(server_init(), tff.SERVER) 

iterative_process = tff.templates.IterativeProcess(initialize_fn, next_fn)
state = iterative_process.initialize()
state['model'] = create_keras_model()

但我发现这个错误:

NameError: name 'next_fn' is not defined

所以就我而言,如何定义 next_fn ? 谢谢

The state object returned by iterative_process.initialize() is typically a Python container (tuple, collections.OrderedDict, etc) that contains numpy arrays. I would like that the value of state is not random, instead it begin from loaded model.
As the beginning, I write this :

def create_keras_model():   
  Model = tf.keras.models.load_model(path)
  return Model
def model_fn(): 
    keras_model = create_keras_model()   
    return tff.learning.from_keras_model(keras_model..)
iterative_process = tff.learning.build_federated_averaging_process(model_fn=model_fn..)
state = iterative_process.initialize()

But test accuracy result does not change at all comparing by the normal case(if I don't load an external model).

That's why, I try this solution:

# initialize_fn() function
@tff.tf_computation
def server_init():
    model = model_fn()
    return model.trainable_variables

@tff.federated_computation
def initialize_fn():
    return tff.federated_value(server_init(), tff.SERVER) 

iterative_process = tff.templates.IterativeProcess(initialize_fn, next_fn)
state = iterative_process.initialize()
state['model'] = create_keras_model()

But I find this error:

NameError: name 'next_fn' is not defined

So in my case, how can I define next_fn ?
Thanks

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文