Transformers 保存并加载模型 | 八

做者|huggingface
编译|VK
来源|Githubhtml

本节说明如何保存和从新加载微调模型(BERT,GPT,GPT-2和Transformer-XL)。你须要保存三种文件类型才能从新加载通过微调的模型:python

  • 模型自己应该是PyTorch序列化保存的模型(https://pytorch.org/docs/stab...
  • 模型的配置文件是保存为JSON文件
  • 词汇表(以及基于GPT和GPT-2合并的BPE的模型)。

这些文件的默认文件名以下:json

  • 模型权重文件:pytorch_model.bin
  • 配置文件:config.json
  • 词汇文件:vocab.txt表明BERT和Transformer-XL,vocab.json表明GPT/GPT-2(BPE词汇),
  • 表明GPT/GPT-2(BPE词汇)额外的合并文件:merges.txt

若是使用这些默认文件名保存模型,则能够使用from_pretrained()方法从新加载模型和tokenizer。 分布式

这是保存模型,配置和配置文件的推荐方法。词汇到output_dir目录,而后从新加载模型和tokenizer:.net

from transformers import WEIGHTS_NAME, CONFIG_NAME

output_dir = "./models/"

# 步骤1:保存一个通过微调的模型、配置和词汇表

#若是咱们有一个分布式模型,只保存封装的模型
#它包装在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model
#若是使用预约义的名称保存,则能够使用`from_pretrained`加载
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)

torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_dir)

# 步骤2: 从新加载保存的模型

#Bert模型示例
model = BertForQuestionAnswering.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case)  # Add specific options if needed
#GPT模型示例
model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)

若是要为每种类型的文件使用特定路径,则能够使用另外一种方法保存和从新加载模型:code

output_model_file = "./models/my_own_model_file.bin"
output_config_file = "./models/my_own_config_file.bin"
output_vocab_file = "./models/my_own_vocab_file.bin"

# 步骤1:保存一个通过微调的模型、配置和词汇表

#若是咱们有一个分布式模型,只保存封装的模型
#它包装在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model

torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_vocab_file)

# 步骤2: 从新加载保存的模型

# 咱们没有使用预约义权重名称、配置名称进行保存,没法使用`from_pretrained`进行加载。
# 下面是在这种状况下的操做方法:

#Bert模型示例
config = BertConfig.from_json_file(output_config_file)
model = BertForQuestionAnswering(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)

#GPT模型示例
config = OpenAIGPTConfig.from_json_file(output_config_file)
model = OpenAIGPTDoubleHeadsModel(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = OpenAIGPTTokenizer(output_vocab_file)

原文连接:https://huggingface.co/transf...orm

欢迎关注磐创AI博客站:
http://panchuang.net/htm

OpenCV中文官方文档:
http://woshicver.com/token

欢迎关注磐创博客资源汇总站:
http://docs.panchuang.net/ci

相关文章
相关标签/搜索