Keras_Segmentation VGG-Unet causes AttributeError: 'Functional' object has no attribute 'output_width'
I created a model with the image-segmentation-keras library, initializing it as follows:
```python
import keras_segmentation
from keras_segmentation.models.unet import vgg_unet

model = vgg_unet(n_classes=21, input_height=256, input_width=448)
```
I then trained it like this:
```python
model.train(
    train_images="/content/drive/MyDrive/imgs_train/",
    train_annotations="/content/drive/MyDrive/masks_train/",
    val_images="/content/drive/MyDrive/mgs_validation/",
    val_annotations="/content/drive/MyDrive/masks_validation/",
    checkpoints_path="/content/drive/MyDrive/tmp/vgg_unet_1",
    epochs=28,
    validate=True,
    callbacks=[myCallback],
)
model.load_weights('checkpoint_filepath')
```
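Because `train()` was given `checkpoints_path`, the library can rebuild the model from those checkpoints itself. That avoids the `model.save()` / `load_model()` round trip. The project README shows `keras_segmentation.predict.predict` being called with `checkpoints_path=` and no model. In that case it reconstructs the model from `<checkpoints_path>_config.json` and the latest weights file.

The sketch below assumes your installed version behaves that way and that training actually wrote those files. `predict_from_checkpoints` is a hypothetical wrapper name. The import is deferred so the definition does not require the library.

```python
def predict_from_checkpoints(checkpoints_path, image_path, out_fname):
    """Predict a segmentation mask using a model rebuilt from training checkpoints.

    Assumes keras_segmentation.predict.predict accepts checkpoints_path and,
    when no model is passed, rebuilds it from <checkpoints_path>_config.json
    plus the latest saved weights.
    """
    # Deferred import: keras_segmentation is only needed when this is called.
    from keras_segmentation.predict import predict

    return predict(checkpoints_path=checkpoints_path, inp=image_path, out_fname=out_fname)
```

For example: `predict_from_checkpoints("/content/drive/MyDrive/tmp/vgg_unet_1", image_path, "/tmp/out.png")`. Here `image_path` is a placeholder for your own input image. The library's factory builds the model in this path, so attributes such as `output_width` are set as usual.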
I then saved it:

```python
model.save('/content/drive/MyDrive/vgg_unet_segmentation.h5')
```

and loaded it like so:

```python
from tensorflow.keras.models import load_model

model = load_model('/content/drive/MyDrive/vgg_unet_segmentation.h5')
```
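`load_model()` returns a plain Keras model. The extra Python attributes and methods that keras_segmentation attaches when it builds a model (`output_width`, `predict_segmentation`, and so on) are not part of what `model.save()` writes to the `.h5` file.

A minimal sketch of a workaround: rebuild the model with the same factory call, then load only the weights from the saved file. This assumes tf.keras 2.x, where `Model.load_weights()` can read the weights stored inside a full-model `.h5` written by `model.save()`. `rebuild_vgg_unet` is a hypothetical helper name. Its defaults copy the arguments used in the question.

```python
def rebuild_vgg_unet(weights_path, n_classes=21, input_height=256, input_width=448):
    """Recreate the model with the library factory, then load saved weights into it."""
    # Deferred import so this definition does not require keras_segmentation.
    from keras_segmentation.models.unet import vgg_unet

    # Use the same arguments as at training time so the layer shapes match the file.
    model = vgg_unet(n_classes=n_classes, input_height=input_height, input_width=input_width)
    # Assumption: load_weights() reads the weights group of a full-model .h5 file.
    model.load_weights(weights_path)
    return model
```

Usage would be `model = rebuild_vgg_unet('/content/drive/MyDrive/vgg_unet_segmentation.h5')` followed by `model.predict_segmentation(inp=inp, out_fname="/tmp/out.png")`. Rebuilding through the factory is what restores the attributes, since they are assigned in Python when the model object is constructed.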
However, when I try to make a prediction with `out = model.predict_segmentation(inp=inp, out_fname="/tmp/out.png")`, I get the following error:

```
AttributeError: 'Functional' object has no attribute 'predict_segmentation'
```
So, to work around this, I did the following:

```python
from types import MethodType
import keras_segmentation.predict
from tensorflow.keras.models import load_model

model = load_model('/content/drive/MyDrive/vgg_unet_segmentation.h5')
model.predict_segmentation = MethodType(keras_segmentation.predict.predict, model)
```
However, this led to another issue that I haven't been able to resolve:
```
<ipython-input-7-a4b7d02cd9a2> in <module>()
      4 out = model.predict_segmentation(
      5 inp=inp,
----> 6 out_fname="/tmp/out.png")

/content/image-segmentation-keras/keras_segmentation/predict.py in predict(model, inp, out_fname, checkpoints_path, overlay_img, class_names, show_legends, colors, prediction_width, prediction_height, read_image_type)
    148 assert (len(inp.shape) == 3 or len(inp.shape) == 1 or len(inp.shape) == 4), "Image should be h,w,3 "
    149
--> 150 output_width = model.output_width
    151 output_height = model.output_height
    152 input_width = model.input_width

AttributeError: 'Functional' object has no attribute 'output_width'
```
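The traceback fits the following explanation. `predict()` reads geometry attributes straight off the model object. Re-binding it with `MethodType` restores the method but not the attributes it reads, and those did not survive `model.save()` / `load_model()`.

The stdlib-only sketch below reproduces that mechanism without Keras. It uses hypothetical stand-ins (`FunctionalStandIn`, `predict_stand_in`, `try_predict`) in place of the real model class and library function.

```python
from types import MethodType


class FunctionalStandIn:
    """Hypothetical stand-in for the plain Keras model that load_model() returns."""


def predict_stand_in(model, inp=None, out_fname=None):
    # Like the first lines of keras_segmentation.predict.predict, this reads
    # geometry attributes straight off the model object.
    return (model.output_width, model.output_height, model.n_classes)


def try_predict(model):
    """Call model.predict_segmentation(), returning the AttributeError text on failure."""
    try:
        return model.predict_segmentation(inp=None, out_fname=None)
    except AttributeError as err:
        return str(err)


loaded = FunctionalStandIn()

# Re-binding the method, as in the question, restores predict_segmentation ...
loaded.predict_segmentation = MethodType(predict_stand_in, loaded)
first = try_predict(loaded)
print(first)  # an AttributeError message naming output_width

# ... but the call only succeeds once the attributes it reads exist too.
# These numbers are illustrative, not read from a real model.
loaded.output_width, loaded.output_height, loaded.n_classes = 224, 128, 21
second = try_predict(loaded)
print(second)  # (224, 128, 21)
```

Setting those attributes by hand on the real model would require knowing its actual output size. Rebuilding the model through the library is usually the less error-prone route.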
Any idea why this might be happening, and if so, how it can be resolved?
Any help is appreciated!
Thanks!
Reply:

Try calling `model.predict()` directly to get the output.
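`model.predict()` works on the reloaded model because it is a standard Keras method. However, it returns the raw softmax output rather than a mask. Models built by keras_segmentation (channels-last) appear to end with a reshape to `(batch, output_height * output_width, n_classes)`, so the result has to be reshaped and arg-maxed per pixel.

A minimal numpy sketch, assuming that layout. `decode_flat_prediction` and `infer_output_hw` are hypothetical helper names. The input image must also be preprocessed exactly as during training, which is not shown here.

```python
import numpy as np


def decode_flat_prediction(pr, output_height, output_width):
    """Turn one flattened softmax output into a (height, width) class-index mask.

    Assumes the channels-last keras_segmentation layout, where the model's final
    Reshape gives each image an (output_height * output_width, n_classes) output.
    """
    pr = np.asarray(pr)
    n_classes = pr.shape[-1]
    return pr.reshape((output_height, output_width, n_classes)).argmax(axis=2)


def infer_output_hw(model):
    """Hypothetical helper: (height, width) of the last 4-D layer output.

    Assumes the model ends in Conv2D -> Reshape -> Activation, as channels-last
    models built by keras_segmentation do, so the last 4-D output is the mask grid.
    """
    for layer in reversed(model.layers):
        shape = tuple(layer.output.shape)
        if len(shape) == 4:
            return shape[1], shape[2]
    raise ValueError("model has no 4-D layer output")


# Usage on the reloaded model (not executed here; x is one image preprocessed
# exactly as during training):
#   pr = model.predict(np.array([x]))[0]
#   mask = decode_flat_prediction(pr, *infer_output_hw(model))

# Synthetic check: a 2x3 grid with 4 classes.
flat = np.random.default_rng(0).random((2 * 3, 4))
mask = decode_flat_prediction(flat, 2, 3)
print(mask.shape)  # (2, 3)
```

Reading the height and width from the last 4-D layer avoids hard-coding the output size, which for a U-Net can differ from the input size.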