Saving accuracy and loss with a callback on Colab

Posted 2025-01-09 14:34:05

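I'm trying to train a model on Colab, and it will take roughly 70-72 hours of continuous running. I have a free account, so I get kicked off fairly often for over-use or inactivity, which means I can't just dump the history to a pickle file once training finishes.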

history = model.fit_generator(custom_generator(train_csv_list,batch_size), steps_per_epoch=len(train_csv_list[:13400])//(batch_size), epochs=1000, verbose=1,  callbacks=[stop_training], validation_data=(x_valid,y_valid))

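I found CSVLogger among the Keras callbacks and added it to my own callback as shown below. For some reason it doesn't create model_history_log.csv, and I don't get any error or warning. What am I doing wrong? My goal is just to save accuracy and loss throughout the training process.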

class stop_(Callback): 
    def on_epoch_end(self, epoch, logs={}):
        model.save(Path("/content/drive/MyDrive/.../model" +str(int(epoch))))
        CSVLogger("/content/drive/MyDrive/.../model_history_log.csv", append=True)
        if(logs.get('accuracy') > ACCURACY_THRESHOLD):
                print("\nReached %2.2f%% accuracy, so stopping training!!" %(ACCURACY_THRESHOLD*100))   
                self.model.stop_training = True
stop_training = stop_()     

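Also, since I'm saving the model at every epoch, does the saved model contain this information? So far I haven't found anything, and I doubt it saves accuracy, loss, validation accuracy, etc.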



Comments (1)

夢归不見 2025-01-16 14:34:05

I think you want to write your callback as follows:

import os
import pandas as pd
import tensorflow as tf

class STOP(tf.keras.callbacks.Callback):
    def __init__(self, model, csv_path, model_save_dir, epochs, acc_thld):  # initialization of the callback
        # model is your compiled model; Keras attaches it as self.model when fit starts, so it is not stored here
        # csv_path is the path where the csv file will be stored
        # model_save_dir is the path to the directory where model files will be saved
        # epochs is the number of epochs you set in model.fit
        # acc_thld is the accuracy threshold at which training stops
        super().__init__()
        self.csv_path = csv_path
        self.model_save_dir = model_save_dir
        self.epochs = epochs
        self.acc_thld = acc_thld
        self.acc_list = []  # create empty list to store accuracy
        self.loss_list = []  # create empty list to store loss
        self.epoch_list = []  # create empty list to store the epoch

    def on_epoch_end(self, epoch, logs=None):  # method runs at the end of each epoch
        logs = logs or {}
        savestr = '_' + str(epoch + 1) + '.h5'  # model will be saved as an .h5 file named _<epoch>.h5
        save_path = os.path.join(self.model_save_dir, savestr)
        acc = logs.get('accuracy')  # get the accuracy for this epoch
        loss = logs.get('loss')  # get the loss for this epoch
        self.model.save(save_path)  # save the model
        self.acc_list.append(acc)
        self.loss_list.append(loss)
        self.epoch_list.append(epoch + 1)
        reached = acc is not None and acc > self.acc_thld  # guard against a missing 'accuracy' metric
        if reached or epoch + 1 == self.epochs:  # see if acc > threshold or if this was the last epoch
            self.model.stop_training = True  # stop training
            Eseries = pd.Series(self.epoch_list, name='Epoch')
            Accseries = pd.Series(self.acc_list, name='accuracy')
            Lseries = pd.Series(self.loss_list, name='loss')
            df = pd.concat([Eseries, Lseries, Accseries], axis=1)  # create a dataframe with columns Epoch, loss, accuracy
            df.to_csv(self.csv_path, index=False)  # convert the dataframe to a csv file and save it
            if reached:
                print('\nTraining halted on epoch ', epoch + 1, ' when accuracy exceeded the threshold')

Then, before you run model.fit, use this code:

epochs=20 # set number of epoch for model.fit and the callback
sdir=r'C:\Temp\stooges' # set directory where save model files and the csv file will be stored
acc_thld=.98 # set accuracy threshold
csv_path=os.path.join(sdir, 'traindata.csv') # name your csv file to be saved in sdir
callbacks=STOP(model, csv_path, sdir, epochs, acc_thld) # instantiate the callback

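Remember to pass the callback to model.fit as callbacks=[callbacks] (Keras expects a list of callbacks). I tested this on a simple dataset. It ran for only 3 epochs before the accuracy exceeded the 0.98 threshold, so it created 3 saved model files in sdir, named: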

_1.h5
_2.h5
_3.h5

It also created the csv file, named traindata.csv. Its contents were:

Epoch    loss        accuracy
   1     8.086007    .817778
   2     6.911876    .974444
   3     6.129871    .987778
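
One caveat for the Colab case in the question: the callback above writes the csv only when training stops, so a session that is dropped mid-run would leave no log behind. A minimal sketch of appending one row per finished epoch, using only the standard library, might look like this. The helper name append_epoch_row and its fields parameter are hypothetical, not part of Keras, and logs is assumed to be the dict Keras passes to on_epoch_end.

```python
import csv
import os


def append_epoch_row(csv_path, epoch, logs, fields=("loss", "accuracy")):
    """Append one row of metrics for a finished epoch to csv_path.

    Hypothetical helper: `logs` is assumed to be the dict Keras passes to
    on_epoch_end. Metrics missing from it are written as empty cells.
    """
    logs = logs or {}
    # Write the header only when the file is new or still empty.
    write_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["Epoch", *fields])
        writer.writerow([epoch, *(logs.get(name, "") for name in fields)])
```

Calling append_epoch_row(self.csv_path, epoch + 1, logs) from on_epoch_end would leave every completed epoch on disk even if the session is killed. Keras's built-in CSVLogger(filename, append=True) does much the same, provided the instance is passed in the callbacks list of model.fit rather than constructed inside on_epoch_end, where it is never registered and so never writes anything.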
