TensorFlow自定义损耗功能中两个网络之间的相互作用
假设您有两个TensorFlow模型 model_a
和 model_b
,并且训练环看起来像这样,
with tf.GradientTape() as tape:
output_A = model_A(input)
output_B = model_B(input)
loss = loss_fn(output_A, output_B, true_output_A, true_output_B)
grads = tape.gradient(loss, model_A.trainable_variables)
optimizer.apply_gradients(zip(grads, model_A.trainable_variables))
并将损失函数定义为
def loss_fn(output_A, output_B, true_output_A, true_output_B)
loss = (output_A + output_B) - (true_output_A + true_output_B)
return loss
用于更新 model_a
具有另一个网络的输出( output_b
)。 TensorFlow如何处理这种情况?
计算梯度时,它是否使用 model_b
的权重?还是它将 output_b
作为常数而不是试图追踪其起源?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论
评论(1)
它不会使用
model_b
权重,只有model_a
将更新权重。例如:
损失函数
自定义模型中发生的情况。FIT
定义Model1和Model2并保存它们,因此您可以在训练后检查权
重以下输出:
然后加载保存的型号并检查训练的型号:
Realake Modal1 Model 1 Trainable_weights训练前后(他们都应该更改):
输出:
比较训练前后的Model2 trainable_weights(它们都应该是相同的):
输出:
It won't use
model_B
weights, onlymodel_A
weights will be updated.For example:
Loss function
Customize what happens in Model.fit
Define model1 and model2 and save them, so you can check the weights after training
Gives the following output:
Then load saved models and check the trainable_weights:
Compare model1 trainable_weights before and after training (they should all change):
Outputs:
Compare model2 trainable_weights before and after training (they should all be the same):
Outputs: