Is there a way to pickle a custom tensorflow.keras metric?

Posted on 2025-02-08 12:25:19


I defined the following custom metric to train my model in tensorflow:

import tensorflow as tf
from tensorflow import keras as ks
N_CLASSES = 15

class MulticlassMeanIoU(tf.keras.metrics.MeanIoU):
    def __init__(self,
                 y_true = None,
                 y_pred = None,
                 num_classes = None,
                 name = "Multi_MeanIoU",
                 dtype = None):
        super(MulticlassMeanIoU, self).__init__(num_classes = num_classes,
                                                name = name, dtype = dtype)
        self.__name__ = name

    def get_config(self):
        base_config = super().get_config()
        return {**base_config, "num_classes": self.num_classes}

    def update_state(self, y_true, y_pred, sample_weight = None):
        y_pred = tf.math.argmax(y_pred, axis = -1)
        return super().update_state(y_true, y_pred, sample_weight)

met = MulticlassMeanIoU(num_classes = N_CLASSES)
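
For context, here is a minimal sketch of how such a metric is typically attached during training. The tiny model, optimizer, loss and save path below are placeholders for illustration and are not part of the original question:

# Hypothetical training setup: metrics=[met] is the only part that matters here.
model = ks.Sequential([
    ks.Input(shape=(64, 64, 3)),
    ks.layers.Conv2D(N_CLASSES, 1, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=[met])
# model.fit(x_train, y_train, ...)  # training data not shown
model.save("/some/path/model.h5")   # hypothetical save path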

After training the model, I save the model and I also tried to save the custom object as follows:

with open("/some/path/custom_metrics.pkl", "wb") as f:
    pickle.dump(met, f)

However, when I try to load the metric like this:

with open(path_custom_metrics, "rb") as f:
    met = pickle.load(f)

I always get some errors, e.g. AttributeError: 'MulticlassMeanIoU' object has no attribute 'update_state_fn'.

Now I wonder whether it is possible to pickle a custom metric at all, and if so, how? It would come in handy if I could save custom metrics alongside the model, so that when I load the model in another Python session I always have the metric that is required to load the model in the first place. I could of course redefine the metric by copying the full class definition into the other script before loading the model, but I think that would be bad style, and it could cause problems if I changed something about the metric in the training script and forgot to copy the code to the other script.
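
For reference, the usual Keras way to make the custom class available in another session, without pickling the metric object, is to pass it through custom_objects when loading the model. A minimal sketch, assuming the class lives in a hypothetical shared module named my_metrics and the model was saved to a hypothetical path:

from tensorflow import keras as ks

# "my_metrics" is a hypothetical module shared by the training and loading
# scripts; it only needs to contain the MulticlassMeanIoU class definition.
from my_metrics import MulticlassMeanIoU

model = ks.models.load_model(
    "/some/path/model.h5",  # hypothetical save path
    custom_objects={"MulticlassMeanIoU": MulticlassMeanIoU},
)

Keeping the class in one small module that both the training script and the loading script import avoids the copy-paste drift described above.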


Comments (1)

黒涩兲箜 2025-02-15 12:25:20


If you need to pickle a metric, one possible solution is to use the __getstate__() and __setstate__() methods. During the (de)serialization process these two methods are called if they are available. Add them to your class and you will have what you need. I tried to keep this as general as possible, so that it works for any Metric:

    # These two methods go inside the metric class; Dict and Any come from the
    # typing module (from typing import Any, Dict).
    def __getstate__(self):
        variables = {v.name: v.numpy() for v in self.variables}
        state = {
            name: variables[var.name]
            for name, var in self._unconditional_dependency_names.items()
            if isinstance(var, tf.Variable)}
        state['name'] = self.name
        state['num_classes'] = self.num_classes
        return state

    def __setstate__(self, state: Dict[str, Any]):
        self.__init__(name=state.pop('name'), num_classes=state.pop('num_classes'))
        for name, value in state.items():
            self._unconditional_dependency_names[name].assign(value)
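
As a quick sanity check of this approach, here is a hedged round-trip sketch. It assumes the two methods above have been added to the MulticlassMeanIoU class from the question and that _unconditional_dependency_names behaves as described on your TensorFlow version:

import pickle
import numpy as np

met = MulticlassMeanIoU(num_classes=N_CLASSES)

# Accumulate some state so there is something worth restoring.
y_true = np.random.randint(0, N_CLASSES, size=(4, 32, 32))
y_pred = np.random.rand(4, 32, 32, N_CLASSES).astype("float32")
met.update_state(y_true, y_pred)
print("IoU before pickling:", float(met.result()))

# Round trip through pickle; __getstate__/__setstate__ carry the variables.
restored = pickle.loads(pickle.dumps(met))
print("IoU after unpickling:", float(restored.result()))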