KerasRL: ValueError: Tensor must be from the same graph as Tensor
I am trying to build an RL model to play the Atari Pinball game while following Nicholas Renotte's video. However, when I try to build the final KerasRL model, I get the following error:
ValueError: Tensor("dense/kernel/Read/ReadVariableOp:0", shape=(256, 9), dtype=float32) must be from the same graph as Tensor("dense_4/Relu:0", shape=(None, 256), dtype=float32) (graphs are <tensorflow.python.framework.ops.Graph object at 0x000001DA9F3E0A90> and FuncGraph(name=keras_graph, id=2038356824176)).
The code:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Convolution2D, Dense, Flatten
from tensorflow.keras.optimizers import Adam

def build_model(height, width, channels, actions):
    model = Sequential()
    model.add(Convolution2D(32, (8,8), strides=(4,4), activation='relu', input_shape=(3, height, width, channels)))
    model.add(Convolution2D(64, (4,4), strides=(2,2), activation='relu'))
    model.add(Convolution2D(64, (3,3), activation='relu'))
    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(actions, activation='linear'))
    return model

# env is the Gym Atari environment created earlier in the notebook
height, width, channels = env.observation_space.shape
actions = env.action_space.n
model = build_model(height, width, channels, actions)

from rl.agents import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy

def build_agent(model, actions):
    policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1, value_test=.2, nb_steps=10000)
    memory = SequentialMemory(limit=1000, window_length=3)
    dqn = DQNAgent(model=model, memory=memory, policy=policy,
                   enable_dueling_network=True, dueling_type='avg',
                   nb_actions=actions, nb_steps_warmup=1000)
    return dqn

dqn = build_agent(model, actions)
dqn.compile(Adam(lr=1e-4))
The error pops up when I call the build_agent function.
I tried using tf.keras.backend.clear_session(), but that didn't help.
Comments (2)
You can perhaps try adding this before your build_model: I was also trying to follow Nicholas Renotte's video. For a more in-depth answer, check this link. There will be a warning, but at least it runs.
He literally addresses this exact error about 3 minutes later in that same video, if you care to watch it. Basically, deleting the model (del model) and rebuilding it does the trick.
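For illustration, a minimal sketch of that workaround, reusing the build_model and build_agent functions from the question (the exact cell layout in the video may differ):

# Sketch of the workaround: throw away the previously built model, rebuild it,
# and construct/compile the agent immediately afterwards, so the agent's copies
# of the network end up in the same graph as the model.
del model
model = build_model(height, width, channels, actions)
dqn = build_agent(model, actions)
dqn.compile(Adam(lr=1e-4))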