#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import torch
from torch import nn
from torch.hub import load_state_dict_from_url
# __all__ 在python中的作用是什么?
# 是用来定义模块中的公共接口的,也就是说,当你使用from xxx import *时,只有__all__中的接口会被导入。
# 如果__all__为空,那么使用from xxx import *时,只有以单个下划线开头的接口不会被导入。
# 如果__all__不存在,那么使用from xxx import *时,所有接口都会被导入。
# 一般来说,__all__的作用是用来限制from xxx import *时导入的接口,以防止不必要的接口被导入。
# 但是,__all__只对from xxx import *有效,对from xxx import yyy无效。
# 也就是说,如果你想限制from xxx import yyy时导入的接口,那么你必须在yyy前面加上单个下划线。
# 但是,如果你想限制from xxx import *时导入的接口,那么你必须在__all__中定义。
__all__ = [
"create_yolox_model",
"yolox_nano",
"yolox_tiny",
"yolox_s",
"yolox_m",
"yolox_l",
"yolox_x",
"yolov3",
"yolox_custom"
]
# _CKPT_ROOT_URL表示的是 预训练模型的地址
_CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
_CKPT_FULL_PATH = {
"yolox-nano": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_nano.pth",
"yolox-tiny": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_tiny.pth",
"yolox-s": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_s.pth",
"yolox-m": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_m.pth",
"yolox-l": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_l.pth",
"yolox-x": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_x.pth",
"yolov3": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_darknet.pth",
}
# yolox_nano()函数的作用是:创建YOLOX_Nano模型
# yolox_nano()函数的参数:
# pretrained:是否加载预训练模型,num_classes:类别数,device:设备,exp_path:实验路径,ckpt_path:检查点路径
def create_yolox_model(name: str, pretrained: bool = True, num_classes: int = 80, device=None,
exp_path: str = None, ckpt_path: str = None) -> nn.Module:
"""creates and loads a YOLOX model
Args:
name (str): name of model. for example, "yolox-s", "yolox-tiny" or "yolox_custom"
if you want to load your own model.
pretrained (bool): load pretrained weights into the model. Default to True.
device (str): default device to for model. Default to None.
num_classes (int): number of model classes. Default to 80.
exp_path (str): path to your own experiment file. Required if name="yolox_custom"
ckpt_path (str): path to your own ckpt. Required if name="yolox_custom" and you want to
load a pretrained model
Returns:
YOLOX model (nn.Module)
"""
from yolox.exp import get_exp, Exp # get_exp()函数在yolox/exp.py中定义
# get_exp()函数的作用是:根据exp_file参数的值,返回一个Exp类的实例
# Exp类包含了YOLOX模型的所有参数,包括:模型的backbone、neck、head、loss等
if device is None:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# cuda:0表示使用第0块GPU
device = torch.device(device)
assert name in _CKPT_FULL_PATH or name == "yolox_custom", \
f"user should use one of value in {_CKPT_FULL_PATH.keys()} or \"yolox_custom\""
# 如果name不在_CKPT_FULL_PATH中,那么就报错
if name in _CKPT_FULL_PATH:
exp: Exp = get_exp(exp_name=name)
# exp: Exp表示exp是Exp类的一个实例
exp.num_classes = num_classes
yolox_model = exp.get_model()
# Exp类中的get_model()函数的作用是:根据Exp类的参数,构建YOLOX模型
if pretrained and num_classes == 80:
# 如果pretrained=True,那么就加载预训练模型
weights_url = _CKPT_FULL_PATH[name]
# weights_url表示预训练模型的地址
ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
# load_state_dict_from_url表示从weights_url地址下载预训练模型,map_location="cpu"表示将预训练模型加载到CPU上
# ckpt表示预训练模型的参数,全称是checkpoint,表示检查点,ckpt是一个字典
# ckpt中的key有:model、optimizer、lr_scheduler、epoch、best_ap50_95、best_ap50
# 他们的作用分别是:模型参数、优化器参数、学习率调整器参数、训练的epoch数、best_ap50_95、best_ap50
if "model" in ckpt:
ckpt = ckpt["model"]
yolox_model.load_state_dict(ckpt)
# yolox_model.load_state_dict()函数的作用是:将ckpt中的参数加载到yolox_model中
else:
assert exp_path is not None, "for a \"yolox_custom\" model exp_path must be provided"
# 如果name="yolox_custom",那么就必须提供exp_path参数
exp: Exp = get_exp(exp_file=exp_path)
yolox_model = exp.get_model()
if ckpt_path:
ckpt = torch.load(ckpt_path, map_location="cpu")
if "model" in ckpt:
ckpt = ckpt["model"]
yolox_model.load_state_dict(ckpt)
yolox_model.to(device)
# yolox_model这个类包含了YOLOX模型的所有参数,包括:模型的backbone、neck、head、loss等
# yolox_model.to(device)表示将YOLOX模型加载到device上
return yolox_model
# 总体来说,create_yolox_model()函数的作用是:根据name参数的值,创建YOLOX模型
# 先加载exp文件,然后根据exp文件中的参数,构建YOLOX模型
# 再根据ckpt_path参数的值,加载预训练模型,因为ckpt_path参数的值是预训练模型的地址
# -> nn.Module表示该函数返回一个nn.Module类的实例
def yolox_nano(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
return create_yolox_model("yolox-nano", pretrained, num_classes, device)
def yolox_tiny(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
def yolox_s(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
return create_yolox_model("yolox-s", pretrained, num_classes, device)
def yolox_m(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
return create_yolox_model("yolox-m", pretrained, num_classes, device)
def yolox_l(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
return create_yolox_model("yolox-l", pretrained, num_classes, device)
def yolox_x(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
return create_yolox_model("yolox-x", pretrained, num_classes, device)
def yolov3(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
return create_yolox_model("yolov3", pretrained, num_classes, device)
# ckpt_path: str = None 表示ckpt_path是一个字符串,且默认值为None
def yolox_custom(ckpt_path: str = None, exp_path: str = None, device: str = None) -> nn.Module:
return create_yolox_model("yolox_custom", ckpt_path=ckpt_path, exp_path=exp_path, device=device)
YOLOX-build.py create_yolox_model yolox_nano
来自
标签:
发表回复