E-commerce Multimodal Image-Text Retrieval - A Full Chinese-CLIP Baseline Walkthrough
Chinese-CLIP is a Chinese version of OpenAI's CLIP model. Initialized from the vision-side parameters of OpenAI CLIP, it continues pretraining on large-scale native Chinese image-text data (~200 million Chinese image-text pairs) using a two-stage pretraining scheme. It is designed to help users quickly build Chinese-domain image-text feature extraction & similarity computation, cross-modal retrieval, and zero-shot image classification, and it performs strongly on the E-commerce Multimodal Image-Text Retrieval Challenge (i.e., the MUGE retrieval dataset). Here we provide a full baseline walkthrough based on the base-scale model, making it easy to get started; it reaches a Mean Recall of roughly 75 on the validation set. If you want to dig deeper into the model code, improve it, and raise your competition score, see the open-source GitHub repo for more details.
Code (GitHub repo): https://github.com/OFA-Sys/Chinese-CLIP
Technical report: https://arxiv.org/abs/2211.01335
Competition homepage: https://tianchi.aliyun.com/competition/entrance/532031/introduction
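As a quick taste of what the model does before running the full pipeline, the snippet below scores an image against several candidate captions. It is adapted from the quickstart in the repo README and assumes cn_clip is importable (pip install cn_clip, or the cloned repo on PYTHONPATH) and that the repo's example image is available after the clone step below:
# Score one image against candidate captions with a pretrained Chinese-CLIP (adapted from the repo README quickstart)
import torch
from PIL import Image
import cn_clip.clip as clip
from cn_clip.clip import load_from_name

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load_from_name("ViT-B-16", device=device, download_root="./")
model.eval()
image = preprocess(Image.open("Chinese-CLIP/examples/pokemon.jpeg")).unsqueeze(0).to(device)
text = clip.tokenize(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]).to(device)
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # normalize the features before using them for retrieval
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    logits_per_image, _ = model.get_similarity(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Label probs:", probs)  # the highest probability should fall on the matching caption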
Notes
- With this notebook, you can easily use the Chinese-CLIP pretrained two-tower image-text representation model to finetune on the E-commerce Multimodal Image-Text Retrieval Challenge and produce validation/test-set predictions
- Because contrastive learning places a basic requirement on the training batch size, running this notebook needs at least 8 GB of GPU memory. Please download and run this notebook (download link) in an environment whose GPU memory meets this requirement to try the baseline pipeline. The Tianchi Lab Notebook environment currently does not offer GPUs of this size; the DAMO Academy ModelScope platform provides some free GPU hours (after opening the link, click "Open in Notebook" in the upper-right corner), which is worth considering if you are short on resources
- Please make sure your environment has PyTorch 1.8.0 or above (see the quick check below)
- The pipeline in this notebook is deliberately simple; it mainly helps participants get a decent base-scale baseline (vision tower ViT-B/16, text tower RoBERTa-base) up and running. To tune or improve the model further, please refer to the GitHub source code in detail (especially the cross-modal retrieval section of the readme) and the Chinese-CLIP technical report
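Before starting, a quick environment sanity check with plain PyTorch calls can save time:
# Quick environment check: PyTorch version and available GPU memory
import torch
print("PyTorch:", torch.__version__)  # should be >= 1.8.0
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name}, {props.total_memory / 1024**3:.1f} GB")  # >= 8 GB recommended
else:
    print("No GPU visible; finetuning is impractical on CPU")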
Preparation
1. Clone the repo and prepare the data directories
# clone the Chinese-CLIP repo into the user directory (if git clone hangs due to network issues, interrupt the step and retry a few times):
!git clone https://github.com/OFA-Sys/Chinese-CLIP.git
# create a datapath folder to hold the dataset and the Chinese-CLIP weights
!mkdir -p datapath
!mkdir -p datapath/pretrained_weights # holds the pretrained weights
!mkdir -p datapath/datasets # holds the datasets
!mkdir -p datapath/experiments # holds finetune checkpoints and logs
!tree .
Cloning into 'Chinese-CLIP'...
remote: Enumerating objects: 990, done.
remote: Counting objects: 100% (364/364), done.
remote: Compressing objects: 100% (131/131), done.
remote: Total 990 (delta 328), reused 233 (delta 233), pack-reused 626
Receiving objects: 100% (990/990), 402.11 KiB | 593.00 KiB/s, done.
Resolving deltas: 100% (633/633), done.
Updating files: 100% (54/54), done.
.
├── Chinese-CLIP
│   ├── assets
│   │   └── Chinese_CLIP_logo_tp_path.svg
│   ├── cn_clip
│   │   ├── clip
│   │   │   ├── bert_tokenizer.py
│   │   │   ├── configuration_bert.py
│   │   │   ├── __init__.py
│   │   │   ├── model_configs
│   │   │   │   ├── RBT3-chinese.json
│   │   │   │   ├── RN50.json
│   │   │   │   ├── RoBERTa-wwm-ext-base-chinese.json
│   │   │   │   ├── RoBERTa-wwm-ext-large-chinese.json
│   │   │   │   ├── ViT-B-16.json
│   │   │   │   ├── ViT-B-32.json
│   │   │   │   ├── ViT-H-14.json
│   │   │   │   ├── ViT-L-14-336.json
│   │   │   │   └── ViT-L-14.json
│   │   │   ├── modeling_bert.py
│   │   │   ├── model.py
│   │   │   ├── utils.py
│   │   │   └── vocab.txt
│   │   ├── eval
│   │   │   ├── cvinw_zeroshot_templates.py
│   │   │   ├── data.py
│   │   │   ├── evaluation.py
│   │   │   ├── evaluation_tr.py
│   │   │   ├── extract_features.py
│   │   │   ├── imagenet_zeroshot_templates.py
│   │   │   ├── __init__.py
│   │   │   ├── make_topk_predictions.py
│   │   │   ├── make_topk_predictions_tr.py
│   │   │   ├── transform_ir_annotation_to_tr.py
│   │   │   └── zeroshot_evaluation.py
│   │   ├── __init__.py
│   │   ├── preprocess
│   │   │   ├── build_lmdb_dataset.py
│   │   │   ├── __init__.py
│   │   │   └── transform_openai_pretrain_weights.py
│   │   └── training
│   │       ├── data.py
│   │       ├── __init__.py
│   │       ├── logger.py
│   │       ├── main.py
│   │       ├── params.py
│   │       ├── scheduler.py
│   │       └── train.py
│   ├── examples
│   │   └── pokemon.jpeg
│   ├── MIT-LICENSE.txt
│   ├── README_En.md
│   ├── README.md
│   ├── requirements.txt
│   ├── Results.md
│   ├── run_scripts
│   │   ├── flickr30k_finetune_vit-b-16_rbt-base_flip.sh
│   │   ├── flickr30k_finetune_vit-b-16_rbt-base.sh
│   │   ├── muge_finetune_vit-b-16_rbt-base_flip.sh
│   │   ├── muge_finetune_vit-b-16_rbt-base.sh
│   │   └── zeroshot_eval.sh
│   ├── setup.py
│   ├── zeroshot_dataset_en.md
│   └── zeroshot_dataset.md
├── Chinese-CLIP-on-Retrieval.ipynb
└── datapath
    ├── datasets
    ├── experiments
    └── pretrained_weights
14 directories, 54 files
2. Install the Chinese-CLIP dependencies
# install the Chinese-CLIP dependency packages
!pip install -r Chinese-CLIP/requirements.txt
Looking in indexes: https://mirrors.aliyun.com/pypi/simple/
Requirement already satisfied: numpy in /home/pai/lib/python3.6/site-packages (from -r Chinese-CLIP/requirements.txt (line 1)) (1.19.5)
Requirement already satisfied: tqdm in /home/pai/lib/python3.6/site-packages (from -r Chinese-CLIP/requirements.txt (line 2)) (4.64.0)
Requirement already satisfied: six in /home/pai/lib/python3.6/site-packages (from -r Chinese-CLIP/requirements.txt (line 3)) (1.16.0)
Requirement already satisfied: timm in /home/pai/lib/python3.6/site-packages (from -r Chinese-CLIP/requirements.txt (line 4)) (0.6.12)
Requirement already satisfied: lmdb==1.3.0 in /home/pai/lib/python3.6/site-packages (from -r Chinese-CLIP/requirements.txt (line 5)) (1.3.0)
Requirement already satisfied: torch>=1.7.1 in /home/pai/lib/python3.6/site-packages (from -r Chinese-CLIP/requirements.txt (line 6)) (1.10.2+cu113)
Requirement already satisfied: torchvision in /home/pai/lib/python3.6/site-packages (from -r Chinese-CLIP/requirements.txt (line 7)) (0.11.3+cu113)
Requirement already satisfied: importlib-resources in /home/pai/lib/python3.6/site-packages (from tqdm->-r Chinese-CLIP/requirements.txt (line 2)) (5.4.0)
Requirement already satisfied: huggingface-hub in /home/pai/lib/python3.6/site-packages (from timm->-r Chinese-CLIP/requirements.txt (line 4)) (0.4.0)
Requirement already satisfied: pyyaml in /home/pai/lib/python3.6/site-packages (from timm->-r Chinese-CLIP/requirements.txt (line 4)) (5.4.1)
Requirement already satisfied: dataclasses in /home/pai/lib/python3.6/site-packages (from torch>=1.7.1->-r Chinese-CLIP/requirements.txt (line 6)) (0.8)
Requirement already satisfied: typing-extensions in /home/pai/lib/python3.6/site-packages (from torch>=1.7.1->-r Chinese-CLIP/requirements.txt (line 6)) (3.7.4.3)
Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /home/pai/lib/python3.6/site-packages (from torchvision->-r Chinese-CLIP/requirements.txt (line 7)) (8.4.0)
Requirement already satisfied: requests in /home/pai/lib/python3.6/site-packages (from huggingface-hub->timm->-r Chinese-CLIP/requirements.txt (line 4)) (2.18.4)
Requirement already satisfied: importlib-metadata in /home/pai/lib/python3.6/site-packages (from huggingface-hub->timm->-r Chinese-CLIP/requirements.txt (line 4)) (4.8.1)
Requirement already satisfied: packaging>=20.9 in /home/pai/lib/python3.6/site-packages (from huggingface-hub->timm->-r Chinese-CLIP/requirements.txt (line 4)) (21.3)
Requirement already satisfied: filelock in /home/pai/lib/python3.6/site-packages (from huggingface-hub->timm->-r Chinese-CLIP/requirements.txt (line 4)) (3.4.1)
Requirement already satisfied: zipp>=3.1.0 in /home/pai/lib/python3.6/site-packages (from importlib-resources->tqdm->-r Chinese-CLIP/requirements.txt (line 2)) (3.6.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home/pai/lib/python3.6/site-packages (from packaging>=20.9->huggingface-hub->timm->-r Chinese-CLIP/requirements.txt (line 4)) (3.0.9)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /home/pai/lib/python3.6/site-packages (from requests->huggingface-hub->timm->-r Chinese-CLIP/requirements.txt (line 4)) (3.0.4)
Requirement already satisfied: idna<2.7,>=2.5 in /home/pai/lib/python3.6/site-packages (from requests->huggingface-hub->timm->-r Chinese-CLIP/requirements.txt (line 4)) (2.6)
Requirement already satisfied: urllib3<1.23,>=1.21.1 in /home/pai/lib/python3.6/site-packages (from requests->huggingface-hub->timm->-r Chinese-CLIP/requirements.txt (line 4)) (1.22)
Requirement already satisfied: certifi>=2017.4.17 in /home/pai/lib/python3.6/site-packages (from requests->huggingface-hub->timm->-r Chinese-CLIP/requirements.txt (line 4)) (2021.5.30)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
# list the packages installed in the current kernel
!pip list --format=columns
Package Version
---------------------------- ------------
absl-py 0.15.0
aiohttp 3.8.1
aiosignal 1.2.0
alibabacloud-credentials 0.2.0
alibabacloud-endpoint-util 0.0.3
alibabacloud-gateway-spi 0.0.1
alibabacloud-openapi-util 0.1.6
alibabacloud-pai-dlc20201203 1.0.0
alibabacloud-tea 0.2.9
alibabacloud-tea-openapi 0.3.3
alibabacloud-tea-util 0.3.5
alibabacloud-tea-xml 0.0.2
alipai 0.1.7
aliyun-log-python-sdk 0.7.9
aliyun-python-sdk-core 2.13.36
aliyun-python-sdk-kms 2.15.0
aliyun-python-sdk-sts 3.1.0
apache-beam 2.15.0
apache-flink 1.10.1
argcomplete 2.0.0
argon2-cffi 21.3.0
argon2-cffi-bindings 21.2.0
asn1crypto 0.24.0
astor 0.8.1
astroid 2.11.5
astunparse 1.6.3
async-generator 1.10
async-timeout 4.0.2
asynctest 0.13.0
attrs 21.4.0
autopep8 1.6.0
avro-python3 1.9.1
backcall 0.2.0
backports.zoneinfo 0.2.1
bleach 4.1.0
brotlipy 0.7.0
cached-property 1.5.2
cachetools 4.2.4
certifi 2021.5.30
cffi 1.15.0
chardet 3.0.4
charset-normalizer 2.0.12
clang 5.0
cloudpickle 1.2.2
colorama 0.4.4
common-io 0.3.0
conda 4.10.3
conda-package-handling 1.7.3
configparser 5.2.0
contextlib2 21.6.0
crcmod 1.7
cryptography 37.0.2
cvxopt 1.3.0
cycler 0.11.0
Cython 0.29.24
dataclasses 0.8
dateparser 1.1.1
decorator 4.4.2
defusedxml 0.7.1
deprecation 2.1.0
dill 0.2.9
docopt 0.6.2
eas-prediction 0.13
easy-rec 0.1.6
elastic-transport 8.1.2
elasticsearch 8.2.0
entrypoints 0.4
fastavro 0.21.24
filelock 3.4.1
flake8 4.0.1
flatbuffers 1.12
frozenlist 1.2.0
future 0.18.2
gast 0.4.0
google-auth 1.35.0
google-auth-oauthlib 0.4.6
google-pasta 0.2.0
graphviz 0.19.1
grpcio 1.46.3
h5py 3.1.0
hdfs 2.7.0
htmlmin 0.1.12
httplib2 0.12.0
huggingface-hub 0.4.0
hyperopt 0.1.2
idna 2.6
idna-ssl 1.1.0
ImageHash 4.2.1
importlib-metadata 4.8.1
importlib-resources 5.4.0
ipykernel 5.5.6
ipython 7.16.3
ipython-genutils 0.2.0
ipywidgets 7.7.0
isort 5.10.1
jedi 0.17.2
Jinja2 3.0.3
jmespath 0.10.0
joblib 1.1.0
json-tricks 3.15.5
jsonschema 4.0.0
jupyter-client 7.1.2
jupyter-core 4.9.2
jupyterlab-pygments 0.1.2
jupyterlab-widgets 1.1.0
keras 2.6.0
Keras-Preprocessing 1.1.2
kiwisolver 1.3.1
lazy-object-proxy 1.7.1
llvmlite 0.36.0
lmdb 1.3.0
mamba 0.15.3
Markdown 3.3.7
MarkupSafe 2.0.1
matplotlib 3.3.4
mccabe 0.6.1
minio 7.1.8
missingno 0.4.2
mistune 0.8.4
mock 2.0.0
multidict 5.2.0
multimethod 1.4
nbclient 0.5.9
nbconvert 6.0.7
nbformat 5.1.3
nest-asyncio 1.5.5
networkx 2.7.1
notebook 6.4.10
numba 0.53.1
numpy 1.19.5
oauth2client 3.0.0
oauthlib 3.2.0
olefile 0.46
opencv-contrib-python 4.5.5.64
opencv-python 4.5.5.64
opt-einsum 3.3.0
oss2 2.15.0
packaging 21.3
pai-nni 2.6
pandas 1.1.5
pandas-profiling 3.1.0
pandocfilters 1.5.0
parso 0.7.1
patsy 0.5.2
pbr 5.9.0
pexpect 4.8.0
phik 0.11.2
pickleshare 0.7.5
Pillow 8.4.0
pip 21.3.1
platformdirs 2.4.0
plotly 5.8.0
prettytable 2.5.0
prometheus-client 0.14.1
prompt-toolkit 3.0.29
protobuf 3.19.4
psutil 5.9.1
ptyprocess 0.7.0
py4j 0.10.8.1
pyalink-public 1.5.1
pyarrow 0.13.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pybind11 2.7.0
pycodestyle 2.8.0
pycosat 0.6.3
pycparser 2.21
pycryptodome 3.14.1
pydantic 1.8.2
pydot 1.4.2
pyflakes 2.4.0
Pygments 2.12.0
pylint 2.13.9
pymongo 3.12.3
pyodps 0.11.0
pyOpenSSL 18.0.0
pyparsing 3.0.9
pyrsistent 0.18.0
PySocks 1.6.8
python-dateutil 2.8.2
PythonWebHDFS 0.2.3
pytz 2022.1
pytz-deprecation-shim 0.1.0.post0
PyWavelets 1.1.1
PyYAML 5.4.1
pyzmq 23.0.0
regex 2022.3.2
requests 2.18.4
requests-oauthlib 1.3.1
responses 0.17.0
rsa 4.8
ruamel_yaml 0.15.37
ruamel-yaml-conda 0.15.100
Rx 3.2.0
schema 0.7.5
scikit-learn 0.24.2
scipy 1.5.4
seaborn 0.11.2
Send2Trash 1.8.0
setuptools 58.0.4
simplejson 3.17.6
six 1.16.0
statsmodels 0.12.2
tangled-up-in-unicode 0.1.0
tenacity 8.0.1
tensorboard 2.6.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorflow 2.6.2
tensorflow-estimator 2.6.0
tensorflow-gpu 2.6.0
tensorflow-io 0.21.0
tensorflow-io-gcs-filesystem 0.21.0
termcolor 1.1.0
terminado 0.13.0
testpath 0.6.0
threadpoolctl 3.1.0
timm 0.6.12
toml 0.10.2
tomli 1.2.3
torch 1.10.2+cu113
torchaudio 0.10.2+cu113
torchvision 0.11.3+cu113
tornado 6.1
tqdm 4.64.0
traitlets 4.3.3
typed-ast 1.5.4
typing-extensions 3.7.4.3
tzdata 2022.1
tzlocal 4.2
urllib3 1.22
visions 0.7.4
wcwidth 0.2.5
webencodings 0.5.1
websockets 9.1
Werkzeug 2.0.3
wheel 0.37.1
widgetsnbextension 3.6.0
wrapt 1.12.1
xgboost 1.5.2
xlrd 2.0.1
xmltodict 0.13.0
yarl 1.7.2
yq 2.13.0
zipp 3.6.0
3. Download the Chinese-CLIP pretrained weights
# download the base-scale Chinese-CLIP pretrained weights:
!wget https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-b-16.pt
!mv clip_cn_vit-b-16.pt datapath/pretrained_weights/
--2023-01-13 15:38:37-- https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-b-16.pt
Resolving clip-cn-beijing.oss-cn-beijing.aliyuncs.com... 59.110.190.128
Connecting to clip-cn-beijing.oss-cn-beijing.aliyuncs.com|59.110.190.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 753196934 (718M) [application/octet-stream]
Saving to: ‘clip_cn_vit-b-16.pt’
clip_cn_vit-b-16.pt 100%[===================>] 718.30M 20.7MB/s in 40s
2023-01-13 15:39:17 (18.1 MB/s) - ‘clip_cn_vit-b-16.pt’ saved [753196934/753196934]
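Optionally, sanity-check the downloaded file before moving on. This sketch only loads the checkpoint on CPU and prints its top-level keys (the exact key layout is not documented here, so treat the 'state_dict' expectation as an assumption and check the printed keys):
# Load the checkpoint on CPU and look at its top-level structure
import torch
ckpt = torch.load("datapath/pretrained_weights/clip_cn_vit-b-16.pt", map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))  # a 'state_dict' entry is expected here, but verify against the printed keys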
4. Preprocess the data
# for the full data-preprocessing pipeline, see the github readme: https://github.com/OFA-Sys/Chinese-CLIP#数据集格式预处理
# for convenience, we directly download the e-commerce retrieval dataset already preprocessed into LMDB format
!wget https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/MUGE.zip
!mv MUGE.zip datapath/datasets/
%cd datapath/datasets/
!unzip MUGE.zip
%cd ../..
!tree datapath/
--2023-01-13 15:41:28-- https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/datasets/MUGE.zip
Resolving clip-cn-beijing.oss-cn-beijing.aliyuncs.com... 59.110.190.128
Connecting to clip-cn-beijing.oss-cn-beijing.aliyuncs.com|59.110.190.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2166554702 (2.0G) [application/zip]
Saving to: ‘MUGE.zip’
MUGE.zip 100%[===================>] 2.02G 26.7MB/s in 89s
2023-01-13 15:42:58 (23.1 MB/s) - ‘MUGE.zip’ saved [2166554702/2166554702]
/mnt/workspace/tianchi/datapath/datasets
Archive: MUGE.zip
creating: MUGE/
creating: MUGE/lmdb/
creating: MUGE/lmdb/train/
creating: MUGE/lmdb/train/imgs/
inflating: MUGE/lmdb/train/imgs/data.mdb
inflating: MUGE/lmdb/train/imgs/lock.mdb
creating: MUGE/lmdb/train/pairs/
inflating: MUGE/lmdb/train/pairs/data.mdb
inflating: MUGE/lmdb/train/pairs/lock.mdb
creating: MUGE/lmdb/valid/
creating: MUGE/lmdb/valid/imgs/
inflating: MUGE/lmdb/valid/imgs/data.mdb
inflating: MUGE/lmdb/valid/imgs/lock.mdb
creating: MUGE/lmdb/valid/pairs/
inflating: MUGE/lmdb/valid/pairs/data.mdb
inflating: MUGE/lmdb/valid/pairs/lock.mdb
inflating: MUGE/train_imgs.tsv
inflating: MUGE/train_texts.jsonl
inflating: MUGE/valid_imgs.tsv
inflating: MUGE/valid_texts.jsonl
/mnt/workspace/tianchi
datapath/
├── datasets
│   ├── MUGE
│   │   ├── lmdb
│   │   │   ├── train
│   │   │   │   ├── imgs
│   │   │   │   │   ├── data.mdb
│   │   │   │   │   └── lock.mdb
│   │   │   │   └── pairs
│   │   │   │       ├── data.mdb
│   │   │   │       └── lock.mdb
│   │   │   └── valid
│   │   │       ├── imgs
│   │   │       │   ├── data.mdb
│   │   │       │   └── lock.mdb
│   │   │       └── pairs
│   │   │           ├── data.mdb
│   │   │           └── lock.mdb
│   │   ├── train_imgs.tsv
│   │   ├── train_texts.jsonl
│   │   ├── valid_imgs.tsv
│   │   └── valid_texts.jsonl
│   └── MUGE.zip
├── experiments
└── pretrained_weights
    └── clip_cn_vit-b-16.pt
11 directories, 14 files
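If you instead start from the raw files (each *_imgs.tsv row is an image id, a tab, and a base64-encoded image; *_texts.jsonl holds the queries), the repo ships a conversion script. A sketch of the invocation per the repo README (verify the exact flags with --help):
# Convert raw tsv/jsonl files into the LMDB layout used above (flags per the repo README)
!python Chinese-CLIP/cn_clip/preprocess/build_lmdb_dataset.py \
    --data_dir datapath/datasets/MUGE \
    --splits train,valid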
# take a look at a data sample
!head -n 1 datapath/datasets/MUGE/valid_texts.jsonl # the first sample of the validation set
import lmdb
import base64
from io import BytesIO
from PIL import Image
image_ids = [286314, 141999, 183846]
lmdb_imgs = "datapath/datasets/MUGE/lmdb/valid/imgs"
env_imgs = lmdb.open(lmdb_imgs, readonly=True, create=False, lock=False, readahead=False, meminit=False)
txn_imgs = env_imgs.begin(buffers=True)
for image_id in image_ids:
    # images are stored in LMDB as urlsafe-base64 strings, keyed by the string form of image_id
    image_b64 = txn_imgs.get("{}".format(image_id).encode('utf-8')).tobytes()
    img = Image.open(BytesIO(base64.urlsafe_b64decode(image_b64)))
    img.show()
{"text_id": 248816, "text": "圣诞 抱枕", "image_ids": [1006938, 561749, 936929, 286314, 141999, 183846]}
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224 at 0x7FD2B7E261D0>
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224 at 0x7FD29B384198>
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224 at 0x7FD29B384208>
Run finetuning
# enter the code working directory
%cd Chinese-CLIP/
/mnt/workspace/tianchi/Chinese-CLIP
Now we run the finetune. A few notes:
- The code below is adapted from the shell script in the repo: Chinese-CLIP/run_scripts/muge_finetune_vit-b-16_rbt-base.sh
- For details of each finetune option, see: https://github.com/OFA-Sys/Chinese-CLIP#模型finetune
- For the Python code that runs training, see: Chinese-CLIP/cn_clip/training/main.py
- For the Python code implementing the model, see: Chinese-CLIP/cn_clip/clip/model.py, as well as the competition's official tutorial video on the model implementation
- With the hyperparameters below, finetuning requires at least 8 GB of GPU memory and takes roughly 40 minutes (measured on a V100)
- Note that the validation accuracy printed during finetuning is the in-batch Recall@1 between image-text pairs within each validation batch; it is not the competition metric, which is Recall@1 over the whole validation/test image pool. It is only useful for watching the training trend (see the sketch after this list); to actually evaluate the model, run the full feature-extraction, KNN-retrieval, and Recall pipeline below
- Convergence and stability of contrastive learning depend on the total batch size. If you use a small batch size (such as our 48 per-GPU * 1 GPU here), we suggest a smaller learning rate (such as our 3e-6). We recommend more GPUs and a larger batch size for better results.
- If an EOFError is printed after training finishes, it can safely be ignored
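To make the in-batch metric above concrete, here is an illustrative sketch of how such an accuracy is computed (not taken verbatim from cn_clip/training/train.py): within a batch of B aligned pairs, each image's positive is the text at the same index, and vice versa.
# Illustrative in-batch accuracy, as reported in the training logs (random features here, so ~1/48 by chance)
import torch
import torch.nn.functional as F
B, D = 48, 512
image_features = F.normalize(torch.randn(B, D), dim=-1)
text_features = F.normalize(torch.randn(B, D), dim=-1)
logits = image_features @ text_features.t()  # (B, B) cosine similarities; pair i matches pair i
labels = torch.arange(B)
i2t_acc = (logits.argmax(dim=1) == labels).float().mean().item()      # Image2Text Acc
t2i_acc = (logits.t().argmax(dim=1) == labels).float().mean().item()  # Text2Image Acc
print(f"in-batch acc: i2t={i2t_acc:.2%}, t2i={t2i_acc:.2%}")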
# Prepare the finetune configuration; see https://github.com/OFA-Sys/Chinese-CLIP#模型finetune for details
# specify the number of machines & GPUs
GPUS_PER_NODE=1 # number of GPUs
WORKER_CNT=1 # number of machines
MASTER_ADDR="localhost"
MASTER_PORT=8514 # when launching several jobs on the same machine, give each a different port
RANK=0
# the directory created above, holding the pretrained weights and the preprocessed dataset
DATAPATH="../datapath/"
# paths of the LMDB-format training and validation sets (holding LMDB-format images and image-text pairs)
train_data=f"{DATAPATH}/datasets/MUGE/lmdb/train"
val_data=f"{DATAPATH}/datasets/MUGE/lmdb/valid"
num_workers=4 # number of PyTorch dataloader workers for the training set; set >0 to reduce data-loading overhead during training
valid_num_workers=4 # number of PyTorch dataloader workers for the validation set; set >0 to reduce data-loading overhead during validation
# path of the Chinese-CLIP pretrained checkpoint we just downloaded
resume=f"{DATAPATH}/pretrained_weights/clip_cn_vit-b-16.pt"
reset_data_offset="--reset-data-offset" # read the training data from the beginning
reset_optimizer="--reset-optimizer" # re-initialize the AdamW optimizer
# output-related configuration
output_base_dir=f"{DATAPATH}/experiments/"
name="muge_finetune_vit-b-16_roberta-base_bs48_1gpu" # finetune hyperparameters, logs and ckpts will be saved under ../datapath/experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/
save_step_frequency=999999 # effectively disables step-based checkpoint saving
save_epoch_frequency=1 # save a finetune ckpt every epoch
log_interval=10 # logging interval in steps
report_training_batch_acc="--report-training-batch-acc" # during training, report the in-batch accuracy of training batches
# training hyperparameters
context_length=52 # sequence length; set to the Chinese-CLIP default of 52
warmup=100 # number of warmup steps
batch_size=48 # per-GPU training batch size
valid_batch_size=48 # per-GPU validation batch size
lr=3e-6 # learning rate; kept low because the contrastive-learning batch size here is small
wd=0.001 # weight decay
max_epochs=1 # number of training epochs; alternatively use --max-steps to specify the number of training steps
valid_step_interval=1000 # validation interval in steps
valid_epoch_interval=1 # validation interval in epochs
vision_model="ViT-B-16" # vision tower: ViT-B/16
text_model="RoBERTa-wwm-ext-base-chinese" # text tower: RoBERTa-base
use_augment="--use-augment" # apply data augmentation to the images
grad_checkpointing="--grad-checkpointing" # enable gradient checkpointing, trading extra training time for lower GPU memory usage
run_command = "export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip;" + \
f"""
python3 -m torch.distributed.launch --nproc_per_node={GPUS_PER_NODE} --nnodes={WORKER_CNT} --node_rank={RANK} \
--master_addr={MASTER_ADDR} --master_port={MASTER_PORT} cn_clip/training/main.py \
--train-data={train_data} \
--val-data={val_data} \
--num-workers={num_workers} \
--valid-num-workers={valid_num_workers} \
--resume={resume} \
{reset_data_offset} \
{reset_optimizer} \
--logs={output_base_dir} \
--name={name} \
--save-step-frequency={save_step_frequency} \
--save-epoch-frequency={save_epoch_frequency} \
--log-interval={log_interval} \
{report_training_batch_acc} \
--context-length={context_length} \
--warmup={warmup} \
--batch-size={batch_size} \
--valid-batch-size={valid_batch_size} \
--valid-step-interval={valid_step_interval} \
--valid-epoch-interval={valid_epoch_interval} \
--lr={lr} \
--wd={wd} \
--max-epochs={max_epochs} \
--vision-model={vision_model} \
{use_augment} \
{grad_checkpointing} \
--text-model={text_model}
""".lstrip()
print(run_command)
export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip;python3 -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=8514 cn_clip/training/main.py --train-data=../datapath//datasets/MUGE/lmdb/train --val-data=../datapath//datasets/MUGE/lmdb/valid --num-workers=4 --valid-num-workers=4 --resume=../datapath//pretrained_weights/clip_cn_vit-b-16.pt --reset-data-offset --reset-optimizer --logs=../datapath//experiments/ --name=muge_finetune_vit-b-16_roberta-base_bs48_1gpu --save-step-frequency=999999 --save-epoch-frequency=1 --log-interval=10 --report-training-batch-acc --context-length=52 --warmup=100 --batch-size=48 --valid-batch-size=48 --valid-step-interval=1000 --valid-epoch-interval=1 --lr=3e-06 --wd=0.001 --max-epochs=1 --vision-model=ViT-B-16 --use-augment --grad-checkpointing --text-model=RoBERTa-wwm-ext-base-chinese
# run the finetune process
!{run_command}
/home/pai/lib/python3.6/site-packages/torch/distributed/launch.py:186: FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torchrun.
Note that --use_env is set by default in torchrun.
If your script expects `--local_rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See
https://pytorch.org/docs/stable/distributed.html#launch-utility for
further instructions
FutureWarning,
/home/pai/lib/python3.6/site-packages/OpenSSL/crypto.py:12: CryptographyDeprecationWarning: Python 3.6 is no longer supported by the Python core team. Therefore, support for it is deprecated in cryptography and will be removed in a future release.
from cryptography import x509
Loading vision model config from cn_clip/clip/model_configs/ViT-B-16.json
Loading text model config from cn_clip/clip/model_configs/RoBERTa-wwm-ext-base-chinese.json
…
2023-01-14,20:32:11 | INFO | Rank 0 | Grad-checkpointing activated.
2023-01-14,21:16:01 | INFO | Rank 0 | Global Steps: 5210/5215 | Train Epoch: 1 [250080/250320 (100%)] | Loss: 0.286732 | Image2Text Acc: 93.75 | Text2Image Acc: 89.58 | Data Time: 0.021s | Batch Time: 0.367s | LR: 0.000000 | logit_scale: 4.605 | Global Batch Size: 48
2023-01-14,21:16:03 | INFO | Rank 0 | Begin to eval on validation set (epoch 1 @ 5215 steps)...
2023-01-14,21:16:27 | INFO | Rank 0 | Evaluated 100/638 batches...
2023-01-14,21:16:49 | INFO | Rank 0 | Evaluated 200/638 batches...
2023-01-14,21:17:11 | INFO | Rank 0 | Evaluated 300/638 batches...
2023-01-14,21:17:34 | INFO | Rank 0 | Evaluated 400/638 batches...
2023-01-14,21:17:56 | INFO | Rank 0 | Evaluated 500/638 batches...
2023-01-14,21:18:19 | INFO | Rank 0 | Evaluated 600/638 batches...
2023-01-14,21:18:27 | INFO | Rank 0 | Validation Result (epoch 1 @ 5215 steps) | Valid Loss: 0.238827 | Image2Text Acc: 92.72 | Text2Image Acc: 92.00 | logit_scale: 4.605 | Valid Batch Size: 48
2023-01-14,21:18:33 | INFO | Rank 0 | Saved checkpoint ../datapath//experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/checkpoints/epoch1.pt (epoch 1 @ 5215 steps) (writing took 5.584172010421753 seconds)
2023-01-14,21:18:39 | INFO | Rank 0 | Saved checkpoint ../datapath//experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/checkpoints/epoch_latest.pt (epoch 1 @ 5215 steps) (writing took 5.848175048828125 seconds)
Exception in thread Thread-1:
Traceback (most recent call last):
File "/home/pai/lib/python3.6/threading.py", line 916, in _bootstrap_inner
self.run()
File "/home/pai/lib/python3.6/threading.py", line 864, in run
self._target(*self._args, **self._kwargs)
File "/home/pai/lib/python3.6/logging/handlers.py", line 1476, in _monitor
record = self.dequeue(True)
File "/home/pai/lib/python3.6/logging/handlers.py", line 1425, in dequeue
return self.queue.get(block)
File "/home/pai/lib/python3.6/multiprocessing/queues.py", line 94, in get
res = self._recv_bytes()
File "/home/pai/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
buf = self._recv_bytes(maxlength)
File "/home/pai/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
buf = self._recv(4)
File "/home/pai/lib/python3.6/multiprocessing/connection.py", line 383, in _recv
raise EOFError
EOFError
Evaluating on the validation set
To evaluate the model, we provide a pipeline for feature extraction and image-text retrieval evaluation. For a more detailed description, see the relevant part of the GitHub readme: https://github.com/OFA-Sys/Chinese-CLIP/blob/master/README.md#预测及评估
1. Compute image and text features
# compute features for the validation-set image pool and the query texts
dataset_name="MUGE"
split="valid" # compute features for the valid or the test split
# path of the ckpt we just finetuned; you can also point this at the pretrained ckpt to test zero-shot performance
resume=f"{DATAPATH}/experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/checkpoints/epoch_latest.pt"
run_command = "export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip;" + \
f"""
python -u cn_clip/eval/extract_features.py \
--extract-image-feats \
--extract-text-feats \
--image-data="{DATAPATH}/datasets/{dataset_name}/lmdb/{split}/imgs" \
--text-data="{DATAPATH}/datasets/{dataset_name}/{split}_texts.jsonl" \
--img-batch-size=32 \
--text-batch-size=32 \
--context-length=52 \
--resume={resume} \
--vision-model=ViT-B-16 \
--text-model=RoBERTa-wwm-ext-base-chinese
"""
print(run_command.lstrip())
!{run_command}
export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip;
python -u cn_clip/eval/extract_features.py --extract-image-feats --extract-text-feats --image-data="../datapath//datasets/MUGE/lmdb/valid/imgs" --text-data="../datapath//datasets/MUGE/valid_texts.jsonl" --img-batch-size=32 --text-batch-size=32 --context-length=52 --resume=../datapath//experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/checkpoints/epoch_latest.pt --vision-model=ViT-B-16 --text-model=RoBERTa-wwm-ext-base-chinese
/home/pai/lib/python3.6/site-packages/OpenSSL/crypto.py:12: CryptographyDeprecationWarning: Python 3.6 is no longer supported by the Python core team. Therefore, support for it is deprecated in cryptography and will be removed in a future release.
from cryptography import x509
Params:
context_length: 52
debug: False
extract_image_feats: True
extract_text_feats: True
image_data: ../datapath//datasets/MUGE/lmdb/valid/imgs
image_feat_output_path: None
img_batch_size: 32
precision: amp
resume: ../datapath//experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/checkpoints/epoch_latest.pt
text_batch_size: 32
text_data: ../datapath//datasets/MUGE/valid_texts.jsonl
text_feat_output_path: None
text_model: RoBERTa-wwm-ext-base-chinese
vision_model: ViT-B-16
Loading vision model config from cn_clip/clip/model_configs/ViT-B-16.json
Loading text model config from cn_clip/clip/model_configs/RoBERTa-wwm-ext-base-chinese.json
Preparing image inference dataset.
Preparing text inference dataset.
Begin to load model checkpoint from ../datapath//experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/checkpoints/epoch_latest.pt.
=> loaded checkpoint '../datapath//experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/checkpoints/epoch_latest.pt' (epoch 1 @ 5215 steps)
Make inference for texts...
100%|█████████████████████████████████████████| 157/157 [00:07<00:00, 21.66it/s]
5008 text features are stored in ../datapath//datasets/MUGE/valid_texts.txt_feat.jsonl
Make inference for images...
100%|█████████████████████████████████████████| 932/932 [02:56<00:00, 5.28it/s]
29806 image features are stored in ../datapath//datasets/MUGE/valid_imgs.img_feat.jsonl
Done!
The produced features are saved under the ../datapath/datasets/MUGE directory. Image features are stored in valid_imgs.img_feat.jsonl, one JSON object per line per image, in the following format:
{"image_id": 1000002, "feature": [0.0198, …, -0.017, 0.0248]}
Text features are stored in valid_texts.txt_feat.jsonl, in the following format:
{"text_id": 248816, "feature": [0.1314, …, 0.0018, -0.0002]}
2. KNN retrieval
For small academic-scale retrieval datasets, we provide a simple KNN retrieval implementation that computes, for text-to-image retrieval, the top-k recall over the validation set's 30k-image pool; a condensed sketch of the underlying computation follows.
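For intuition, the core computation reduces to a normalized matrix product plus top-k. This is an illustrative sketch, not the script itself: the real make_topk_predictions.py streams image features in batches to bound memory, whereas this version loads everything at once, which is workable for the ~30k-image pool:
# Condensed text-to-image KNN over the extracted features (illustrative; see make_topk_predictions.py)
import json
import torch

def load_feats(path, id_key):
    ids, feats = [], []
    with open(path) as f:
        for line in f:
            rec = json.loads(line)
            ids.append(rec[id_key])
            feats.append(rec["feature"])
    return ids, torch.tensor(feats)

img_ids, img_feats = load_feats("../datapath/datasets/MUGE/valid_imgs.img_feat.jsonl", "image_id")
txt_ids, txt_feats = load_feats("../datapath/datasets/MUGE/valid_texts.txt_feat.jsonl", "text_id")
img_feats = torch.nn.functional.normalize(img_feats, dim=-1)
txt_feats = torch.nn.functional.normalize(txt_feats, dim=-1)
scores = txt_feats @ img_feats.t()      # (num_texts, num_images) cosine similarities
topk = scores.topk(10, dim=-1).indices  # indices of each query's 10 best images
preds = [{"text_id": tid, "image_ids": [img_ids[j] for j in row.tolist()]}
         for tid, row in zip(txt_ids, topk)]
print(preds[0])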
# run KNN retrieval: for each validation query, find the top-10 product images with the highest feature cosine similarity
split="valid" # compute for the valid or the test split
run_command = "export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip;" + \
f"""
python -u cn_clip/eval/make_topk_predictions.py \
--image-feats="{DATAPATH}/datasets/{dataset_name}/{split}_imgs.img_feat.jsonl" \
--text-feats="{DATAPATH}/datasets/{dataset_name}/{split}_texts.txt_feat.jsonl" \
--top-k=10 \
--eval-batch-size=32768 \
--output="{DATAPATH}/datasets/{dataset_name}/{split}_predictions.jsonl"
"""
print(run_command.lstrip())
!{run_command}
export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip;
python -u cn_clip/eval/make_topk_predictions.py --image-feats="../datapath//datasets/MUGE/valid_imgs.img_feat.jsonl" --text-feats="../datapath//datasets/MUGE/valid_texts.txt_feat.jsonl" --top-k=10 --eval-batch-size=32768 --output="../datapath//datasets/MUGE/valid_predictions.jsonl"
Params:
eval_batch_size: 32768
image_feats: ../datapath//datasets/MUGE/valid_imgs.img_feat.jsonl
output: ../datapath//datasets/MUGE/valid_predictions.jsonl
text_feats: ../datapath//datasets/MUGE/valid_texts.txt_feat.jsonl
top_k: 10
Begin to load image features...
29806it [00:07, 3768.32it/s]
Finished loading image features.
Begin to compute top-10 predictions for texts...
5008it [02:40, 31.17it/s]
Top-10 predictions are saved in ../datapath//datasets/MUGE/valid_predictions.jsonl
Done!
The output is saved in the specified jsonl file; each line holds the top-k image ids recalled for one text (sorted by the model's predicted relevance score in descending order), in the following format, which is identical to the submission format required by the competition:
!head -n 5 ../datapath/datasets/MUGE/valid_predictions.jsonl
{"text_id": 248816, "image_ids": [286314, 183846, 141999, 936929, 1006938, 162268, 412925, 384823, 877108, 103269]}
{"text_id": 248859, "image_ids": [574241, 175548, 678269, 1059854, 708768, 1003303, 854191, 913653, 495295, 153682]}
{"text_id": 248871, "image_ids": [60015, 807177, 160459, 706996, 161417, 666622, 637011, 77996, 690344, 783255]}
{"text_id": 248898, "image_ids": [946372, 397154, 642386, 777624, 27450, 829271, 222468, 420283, 323919, 937205]}
{"text_id": 248931, "image_ids": [225106, 139500, 349934, 941660, 515959, 818516, 646440, 979324, 94740, 1024375]}
Let's look more closely at the predictions for one of these cases:
# take a look at a data sample
!sed -n "5,5p" ../datapath/datasets/MUGE/valid_texts.jsonl # the fifth sample of the validation set, matching line 5 of the predictions above
import lmdb
import base64
from io import BytesIO
from PIL import Image
image_ids = [225106, 139500, 349934, 941660, 515959] # the model's predicted top-5 relevant images
lmdb_imgs = "../datapath/datasets/MUGE/lmdb/valid/imgs"
env_imgs = lmdb.open(lmdb_imgs, readonly=True, create=False, lock=False, readahead=False, meminit=False)
txn_imgs = env_imgs.begin(buffers=True)
for image_id in image_ids:
    # decode each predicted image from the LMDB image pool
    image_b64 = txn_imgs.get("{}".format(image_id).encode('utf-8')).tobytes()
    img = Image.open(BytesIO(base64.urlsafe_b64decode(image_b64)))
    img.show()
{"text_id": 248931, "text": "32 不粘锅", "image_ids": [227953, 349934, 646440, 204288, 941660, 425873]}
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224 at 0x7FD2B298B240>
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224 at 0x7FD29B3C36D8>
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224 at 0x7FD29B3C3E80>
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224 at 0x7FD29B3C3FD0>
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224 at 0x7FD29B3C36D8>
3. Computing Recall
We provide an evaluation script that computes the retrieval task's Recall@1/5/10, along with the mean recall (the average of Recall@1/5/10). A sketch of the metric definition is shown below, followed by the command that produces the scores.
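For reference, here is an illustrative sketch of the metric definitions (the repo's evaluation.py is authoritative): a query counts as a hit at rank k if any of its ground-truth images appears among its top-k predictions.
# Illustrative R@k / mean recall computation
def recall_at_k(predictions, ground_truth, k):
    # predictions: {text_id: [image_id, ...]} ranked best-first
    # ground_truth: {text_id: set of relevant image_ids}
    hits = sum(1 for tid, preds in predictions.items() if set(preds[:k]) & ground_truth[tid])
    return 100.0 * hits / len(predictions)

# toy example using the validation sample shown earlier
preds = {248816: [286314, 183846, 141999, 936929, 1006938]}
gold = {248816: {1006938, 561749, 936929, 286314, 141999, 183846}}
mean_recall = sum(recall_at_k(preds, gold, k) for k in (1, 5, 10)) / 3
print(mean_recall)  # 100.0: the top-1 prediction already hits a relevant image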
# from the top-10 predictions, compute the validation set's Recall@1/5/10 and the mean recall (average of Recall@1/5/10)
split="valid" # compute for the valid or the test split
run_command = "export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip;" + \
f"""
python cn_clip/eval/evaluation.py \
{DATAPATH}/datasets/{dataset_name}/{split}_texts.jsonl \
{DATAPATH}/datasets/{dataset_name}/{split}_predictions.jsonl \
output.json;
cat output.json
"""
print(run_command.lstrip())
!{run_command}
export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip;
python cn_clip/eval/evaluation.py ../datapath//datasets/MUGE/valid_texts.jsonl ../datapath//datasets/MUGE/valid_predictions.jsonl output.json;
cat output.json
Read standard from ../datapath//datasets/MUGE/valid_texts.jsonl
Read user submit file from ../datapath//datasets/MUGE/valid_predictions.jsonl
The evaluation finished successfully.
{"success": true, "score": 75.33280085197018, "scoreJson": {"score": 75.33280085197018, "mean_recall": 75.33280085197018, "r1": 56.86900958466453, "r5": 81.21006389776358, "r10": 87.91932907348243}}
We obtain a Mean Recall of ~75 on the validation set, along with the detailed Recall@1/5/10. This score is already higher than the 71.1 obtained by predicting directly with the same-scale pretrained ckpt (zero-shot), showing that finetuning adapts the model to the domain data of the e-commerce retrieval task. With different hyperparameters, there is still plenty of headroom above this finetune score.
In our experiments, the validation-set scores of Chinese-CLIP at various scales, zero-shot and finetuned, are roughly as follows (for more Chinese-CLIP results on other tasks, see the GitHub repo: https://github.com/OFA-Sys/Chinese-CLIP/blob/master/Results.md):
MUGE Text-to-Image Retrieval (Official Validation Set):
| Setup | Zero-shot R@1 | Zero-shot R@5 | Zero-shot R@10 | Zero-shot MR | Finetune R@1 | Finetune R@5 | Finetune R@10 | Finetune MR |
|---|---|---|---|---|---|---|---|---|
| CN-CLIP RN50 | 42.6 | 68.6 | 77.9 | 63.0 | 48.6 | 75.1 | 84.0 | 69.2 |
| CN-CLIP ViT-B/16 | 52.1 | 76.7 | 84.4 | 71.1 | 58.4 | 83.6 | 90.0 | 77.4 |
| CN-CLIP ViT-L/14 | 56.3 | 79.8 | 86.2 | 74.1 | 63.3 | 85.6 | 91.3 | 80.1 |
| CN-CLIP ViT-L/14@336px | 59.0 | 81.4 | 87.8 | 76.1 | 65.3 | 86.7 | 92.1 | 81.3 |
| CN-CLIP ViT-H/14 | 63.0 | 84.1 | 89.2 | 78.8 | 68.9 | 88.7 | 93.1 | 83.6 |
Preparing the test-set results
Next we run the same feature-extraction and KNN-retrieval pipeline on the test set, simply swapping in the test data, to produce a test-set prediction file for official submission.
# compute features for the test-set image pool and the query texts
dataset_name="MUGE"
split="test" # compute features for the valid or the test split
resume=f"{DATAPATH}/experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/checkpoints/epoch_latest.pt"
run_command = "export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip;" + \
f"""
python -u cn_clip/eval/extract_features.py \
--extract-image-feats \
--extract-text-feats \
--image-data="{DATAPATH}/datasets/{dataset_name}/lmdb/{split}/imgs" \
--text-data="{DATAPATH}/datasets/{dataset_name}/{split}_texts.jsonl" \
--img-batch-size=32 \
--text-batch-size=32 \
--context-length=52 \
--resume={resume} \
--vision-model=ViT-B-16 \
--text-model=RoBERTa-wwm-ext-base-chinese
"""
print(run_command.lstrip())
!{run_command}
export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip;
python -u cn_clip/eval/extract_features.py --extract-image-feats --extract-text-feats --image-data="../datapath//datasets/MUGE/lmdb/test/imgs" --text-data="../datapath//datasets/MUGE/test_texts.jsonl" --img-batch-size=32 --text-batch-size=32 --context-length=52 --resume=../datapath//experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/checkpoints/epoch_latest.pt --vision-model=ViT-B-16 --text-model=RoBERTa-wwm-ext-base-chinese
/home/pai/lib/python3.6/site-packages/OpenSSL/crypto.py:12: CryptographyDeprecationWarning: Python 3.6 is no longer supported by the Python core team. Therefore, support for it is deprecated in cryptography and will be removed in a future release.
from cryptography import x509
Params:
context_length: 52
debug: False
extract_image_feats: True
extract_text_feats: True
image_data: ../datapath//datasets/MUGE/lmdb/test/imgs
image_feat_output_path: None
img_batch_size: 32
precision: amp
resume: ../datapath//experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/checkpoints/epoch_latest.pt
text_batch_size: 32
text_data: ../datapath//datasets/MUGE/test_texts.jsonl
text_feat_output_path: None
text_model: RoBERTa-wwm-ext-base-chinese
vision_model: ViT-B-16
Loading vision model config from cn_clip/clip/model_configs/ViT-B-16.json
Loading text model config from cn_clip/clip/model_configs/RoBERTa-wwm-ext-base-chinese.json
Preparing image inference dataset.
Preparing text inference dataset.
Begin to load model checkpoint from ../datapath//experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/checkpoints/epoch_latest.pt.
=> loaded checkpoint '../datapath//experiments/muge_finetune_vit-b-16_roberta-base_bs48_1gpu/checkpoints/epoch_latest.pt' (epoch 1 @ 5215 steps)
Make inference for texts...
100%|█████████████████████████████████████████| 157/157 [00:07<00:00, 21.72it/s]
5004 text features are stored in ../datapath//datasets/MUGE/test_texts.txt_feat.jsonl
Make inference for images...
100%|█████████████████████████████████████████| 950/950 [06:13<00:00, 2.54it/s]
30399 image features are stored in ../datapath//datasets/MUGE/test_imgs.img_feat.jsonl
Done!
# run KNN retrieval: for each test query, find the top-10 product images with the highest feature cosine similarity
split="test" # compute for the valid or the test split
run_command = "export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip;" + \
f"""
python -u cn_clip/eval/make_topk_predictions.py \
--image-feats="{DATAPATH}/datasets/{dataset_name}/{split}_imgs.img_feat.jsonl" \
--text-feats="{DATAPATH}/datasets/{dataset_name}/{split}_texts.txt_feat.jsonl" \
--top-k=10 \
--eval-batch-size=32768 \
--output="{DATAPATH}/datasets/{dataset_name}/{split}_predictions.jsonl"
"""
print(run_command.lstrip())
!{run_command}
export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip;
python -u cn_clip/eval/make_topk_predictions.py --image-feats="../datapath//datasets/MUGE/test_imgs.img_feat.jsonl" --text-feats="../datapath//datasets/MUGE/test_texts.txt_feat.jsonl" --top-k=10 --eval-batch-size=32768 --output="../datapath//datasets/MUGE/test_predictions.jsonl"
Params:
eval_batch_size: 32768
image_feats: ../datapath//datasets/MUGE/test_imgs.img_feat.jsonl
output: ../datapath//datasets/MUGE/test_predictions.jsonl
text_feats: ../datapath//datasets/MUGE/test_texts.txt_feat.jsonl
top_k: 10
Begin to load image features...
30399it [00:07, 4257.35it/s]
Finished loading image features.
Begin to compute top-10 predictions for texts...
5004it [02:40, 31.19it/s]
Top-10 predictions are saved in ../datapath//datasets/MUGE/test_predictions.jsonl
Done!
!head -n 5 ../datapath/datasets/MUGE/test_predictions.jsonl
{"text_id": 342160, "image_ids": [599057, 224560, 282239, 908966, 155774, 511583, 368883, 611578, 808493, 598944]}
{"text_id": 342169, "image_ids": [686662, 191142, 344670, 1122644, 986797, 406083, 969455, 424464, 2971, 251105]}
{"text_id": 342177, "image_ids": [75011, 173661, 108900, 454071, 814416, 1114918, 331928, 474673, 98571, 1069979]}
{"text_id": 342198, "image_ids": [526032, 277415, 1015004, 443782, 612288, 862027, 944337, 780085, 1122730, 442797]}
{"text_id": 342202, "image_ids": [258830, 270450, 426942, 807698, 529722, 465381, 777309, 937790, 650457, 23176]}
After downloading this file and submitting it to Tianchi, you should obtain roughly 56.5/80.6/87.5 Recall@1/5/10, i.e., a Mean Recall of about 74.9, with some small fluctuation due to the randomness of finetuning.