TFF :修改state的值
iterative_process.initialize() 返回的状态对象通常是一个包含 numpy 数组的 Python 容器(tuple、collections.OrderedDict 等)。我希望状态的值不是随机的,而是从加载的模型开始。 作为开始,我这样写:
def create_keras_model():
Model = tf.keras.models.load_model(path)
return Model
def model_fn():
keras_model = create_keras_model()
return tff.learning.from_keras_model(keras_model..)
iterative_process = tff.learning.build_federated_averaging_process(model_fn=model_fn..)
state = iterative_process.initialize()
但是与正常情况相比(如果我不加载外部模型),测试准确性结果根本没有改变。
这就是为什么,我尝试这个解决方案:
# initialize_fn() function
@tff.tf_computation
def server_init():
model = model_fn()
return model.trainable_variables
@tff.federated_computation
def initialize_fn():
return tff.federated_value(server_init(), tff.SERVER)
iterative_process = tff.templates.IterativeProcess(initialize_fn, next_fn)
state = iterative_process.initialize()
state['model'] = create_keras_model()
但我发现这个错误:
NameError: name 'next_fn' is not defined
所以就我而言,如何定义 next_fn ? 谢谢
The state object returned by iterative_process.initialize() is typically a Python container (tuple, collections.OrderedDict, etc) that contains numpy arrays. I would like that the value of state is not random, instead it begin from loaded model.
As the beginning, I write this :
def create_keras_model():
Model = tf.keras.models.load_model(path)
return Model
def model_fn():
keras_model = create_keras_model()
return tff.learning.from_keras_model(keras_model..)
iterative_process = tff.learning.build_federated_averaging_process(model_fn=model_fn..)
state = iterative_process.initialize()
But test accuracy result does not change at all comparing by the normal case(if I don't load an external model).
That's why, I try this solution:
# initialize_fn() function
@tff.tf_computation
def server_init():
model = model_fn()
return model.trainable_variables
@tff.federated_computation
def initialize_fn():
return tff.federated_value(server_init(), tff.SERVER)
iterative_process = tff.templates.IterativeProcess(initialize_fn, next_fn)
state = iterative_process.initialize()
state['model'] = create_keras_model()
But I find this error:
NameError: name 'next_fn' is not defined
So in my case, how can I define next_fn ?
Thanks
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论