文章来源于网络收集而来,版权归原创者所有,如有侵权请及时联系!
6.6 配置文件
在模型定义、数据处理和训练等过程都有很多变量,这些变量应提供默认值,并统一放置在配置文件中,这样在后期调试、修改代码或迁移程序时会比较方便,在这里我们将所有可配置项放在 config.py
中。
class DefaultConfig(object):
env = 'default' # visdom 环境
model = 'AlexNet' # 使用的模型,名字必须与 models/__init__.py 中的名字一致
train_data_root = './data/train/' # 训练集存放路径
test_data_root = './data/test1' # 测试集存放路径
load_model_path = 'checkpoints/model.pth' # 加载预训练的模型的路径,为 None 代表不加载
batch_size = 128 # batch size
use_gpu = True # use GPU or not
num_workers = 4 # how many workers for loading data
print_freq = 20 # print info every N batch
debug_file = '/tmp/debug' # if os.path.exists(debug_file): enter ipdb
result_file = 'result.csv'
max_epoch = 10
lr = 0.1 # initial learning rate
lr_decay = 0.95 # when val_loss increase, lr = lr*lr_decay
weight_decay = 1e-4 # 损失函数
可配置的参数主要包括:
- 数据集参数(文件路径、batch_size 等)
- 训练参数(学习率、训练 epoch 等)
- 模型参数
这样我们在程序中就可以这样使用:
import models
from config import DefaultConfig
opt = DefaultConfig()
lr = opt.lr
model = getattr(models, opt.model)
dataset = DogCat(opt.train_data_root)
这些都只是默认参数,在这里还提供了更新函数,根据字典更新配置参数。
def parse(self, kwargs):
'''
根据字典 kwargs 更新 config 参数
'''
# 更新配置参数
for k, v in kwargs.items():
if not hasattr(self, k):
# 警告还是报错,取决于你个人的喜好
warnings.warn("Warning: opt has not attribut %s" %k)
setattr(self, k, v)
# 打印配置信息
print('user config:')
for k, v in self.__class__.__dict__.items():
if not k.startswith('__'):
print(k, getattr(self, k))
这样我们在实际使用时,并不需要每次都修改 config.py
,只需要通过命令行传入所需参数,覆盖默认配置即可。
例如:
opt = DefaultConfig()
new_config = {'lr':0.1,'use_gpu':False}
opt.parse(new_config)
opt.lr == 0.1
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论