博客

  • 快速入门指南:使用timm进行模型训练的开端

    标题:快速入门指南:使用timm进行模型训练的开端

    导读:
    本文旨在帮助开发人员快速了解如何将timm集成到模型训练流程中。首先,您需要安装timm。接下来,我们将通过示例代码演示如何加载预训练模型、列出具有预训练权重的模型、微调预训练模型、以及如何使用预训练模型进行特征提取、图像增强和推理。让我们一起开始这个令人兴奋的旅程吧!

    正文:
    快速入门
    本快速入门指南旨在帮助开发人员快速了解如何将timm集成到他们的模型训练流程中。在开始之前,您需要先安装timm。有关安装的详细信息,请参阅安装指南。

    加载预训练模型
    通过create_model()函数可以加载预训练模型。下面的示例展示了如何加载预训练的mobilenetv3_large_100模型。

    import timm
    
    model = timm.create_model('mobilenetv3_large_100', pretrained=True)
    model.eval()

    需要注意的是,默认情况下返回的PyTorch模型处于训练模式,如果要进行推理,需要调用.eval()方法将其设置为评估模式。

    列出具有预训练权重的模型
    要列出timm中打包的具有预训练权重的模型,可以使用list_models()函数。如果指定pretrained=True,该函数将只返回具有预训练权重的模型名称。

    import timm
    from pprint import pprint
    
    model_names = timm.list_models(pretrained=True)
    pprint(model_names)

    您还可以使用特定模式来列出名称中包含特定字符串的模型。

    微调预训练模型
    要微调任何预训练模型,只需更改分类器(即最后一层)即可。

    model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=num_finetune_classes)

    如果要在自己的数据集上进行微调,请编写一个PyTorch训练循环或调整timm的训练脚本以使用自己的数据集。

    使用预训练模型进行特征提取
    在不修改网络结构的情况下,可以使用model.forward_features(input)来替代通常的model(input)方法,对任何模型进行特征提取。这将跳过头部分类器和全局池化操作。有关更详细的使用timm进行特征提取的指南,请参阅特征提取部分。

    图像增强
    为了将图像转换为模型接受的有效输入,可以使用timm.data.create_transform()函数,并提供模型期望的输入尺寸。这将返回一个通用的变换对象,其中包含了一些合理的默认设置。

    timm.data.create_transform((3, 224, 224))

    预训练模型在训练时应用了特定的数据转换。如果您在图像上使用错误的转换,模型将无法理解所见的图像!要了解给定预训练模型使用了哪些转换,可以查看其预训练配置。

    model.pretrained_cfg

    您可以使用timm.data.resolve_data_config()函数来解析与数据相关的配置。

    data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
    transform = timm.data.create_transform(**data_cfg)

    使用预训练模型进行推理
    下面,我们将结合前面的内容,使用预训练模型进行推理。首先,我们需要一张图片作为输入。我们从网络上加载一张叶子标题:快速入门指南:timm模型训练的开端

    导读:
    本文将带您快速了解如何将timm集成到模型训练流程中。从安装timm开始,到加载预训练模型、列出可用的预训练模型、微调模型、使用模型进行特征提取、图像增强和推理等,我们将逐步介绍这些内容。让我们开始这个令人兴奋的旅程吧!

    正文:
    快速入门
    本快速入门指南旨在帮助开发人员快速了解如何将timm集成到模型训练流程中。在开始之前,您需要先安装timm,具体安装方法请参考官方的安装指南。

    加载预训练模型
    要加载预训练模型,可以使用create_model()函数。以下示例展示了如何加载预训练的mobilenetv3_large_100模型。

    import timm
    
    model = timm.create_model('mobilenetv3_large_100', pretrained=True)
    model.eval()

    需要注意的是,默认情况下返回的PyTorch模型处于训练模式,如果要进行推理,需要调用.eval()方法将其设置为评估模式。

    列出具有预训练权重的模型
    要列出timm中已打包的具有预训练权重的模型,可以使用list_models()函数。如果指定pretrained=True,该函数将只返回具有预训练权重的模型名称。

    import timm
    from pprint import pprint
    
    model_names = timm.list_models(pretrained=True)
    pprint(model_names)

    您还可以使用特定模式来列出名称中包含特定字符串的模型。

    微调预训练模型
    要微调任何预训练模型,只需更改分类器(即最后一层)即可。

    model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=num_finetune_classes)

    如果要在自己的数据集上进行微调,请编写一个PyTorch训练循环或调整timm的训练脚本以使用自己的数据集。

    使用预训练模型进行特征提取
    在不修改网络结构的情况下,可以使用model.forward_features(input)来替代通常的model(input)方法,对任何模型进行特征提取。这将跳过头部分类器和全局池化操作。有关更详细的使用timm进行特征提取的指南,请参考官方文档中的特征提取部分。

    图像增强
    要将图像转换为模型可接受的有效输入,可以使用timm.data.create_transform()函数,并提供模型所期望的输入尺寸。这将返回一个通用的转换对象,其中包含了一些合理的默认设置。

    timm.data.create_transform((3, 224, 224))

    需要注意的是,预训练模型在训练时应用了特定的数据转换。如果您在图像上使用错误的转换,模型将无法理解所见的图像!要了解给定预训练模型使用了哪些转换,可以查看其预训练配置。

    model.pretrained_cfg

    您可以使用timm.data.resolve_data_config()函数来解析与数据相关的配置。

    data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
    transform = timm.data.create_transform(**data_cfg)

    使用预训练模型进行推理
    下面,我们将结合前面的内容,使用预训练模型进行推理。首先,我们需要准备一张图片作为输入。我们从网络上加载了一张叶子的图片。


    import requests
    from PIL import Image
    from io import BytesIO
    
    url = 'https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg'
    image = Image.open(requests.get(url, stream=True).raw)
    image

    接下来,我们再次创建模型和转换。这次,我们确保将模型设置为评估模式。

    model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
    transform = timm.data.create_transform(
        **timm.data.resolve_data_config(model.pretrained_cfg)
    )

    我们可以将图片通过转换函数进行处理,以便与模型进行推理。

    image_tensor = transform(image)
    image_tensor.shape

    现在,我们可以将处理后的图片传递给模型进行预测。这里我们使用unsqueeze(0)将其添加一个批次维度,因为模型需要一个批次的输入。

    output = model(image_tensor.unsqueeze(0))
    output.shape

    为了得到预测的概率,我们对输出应用softmax函数。这会得到一个形状为(num_classes,)的张量。

    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    probabilities.shape

    接下来,我们使用torch.topk找到前5个预测类别的索引和概率值。

    values, indices = torch.topk(probabilities, 5)
    indices

    通过检查对应于最高概率的索引值,我们可以查看模型的预测结果。

    IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
    IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
    [{'label': IMAGENET_1k_LABELS[idx], 'value': val.item()} for val, idx in zip(values, indices)]

    至此,我们完成了使用预训练模型进行推理的过程。我们加载了一张图片,将其通过转换函数处理后输入模型,并得到了模型的预测结果。通过这个示例,我们展示了如何使用timm进行快速入门。

    结语:
    本文通过一个快速入门的示例,介绍了如何使用timm进行模型训练。我们展示了加载预训练模型、列出具有预训练权重的模型、微调预训练模型、使用预训练模型进行特征提取、图像增强和推理的方法。希望本文对您了解和使用timm有所帮助,让您能够更轻松地进行模型训练和推理!

  • ChatGLM3:大模型工具调用实现原理揭秘

    ChatGLM3终于迎来了与ChatGPT相似的工具调用能力。通过研究其源码和样本数据,我们可以了解到让大型语言模型学会使用工具的方法原理。本文将深入探讨ChatGLM3的工具调用实现原理,带您揭开这个令人兴奋的功能背后的奥秘。


    ChatGLM3是一种大型语言模型,具备了与ChatGPT相似的工具调用功能。通过这一功能,模型可以调用特定的工具来执行特定的任务。在本文中,我们将深入研究ChatGLM3的工具调用实现原理,揭示其背后的工作原理。

    首先,让我们来看一个官方例子,以帮助我们理解工具调用的过程。在这个例子中,我们准备了两个工具调用的描述信息:追踪指定股票的实时价格和将文本转换为语音。这些工具都具有特定的参数要求,比如需要提供股票代码或要转换的文本等。

    在进行工具调用之前,我们需要构建一个系统提示(System Prompt),告诉模型我们可以使用哪些工具。然后,我们将用户的查询传递给模型,模型将根据上下文和系统提示中提供的工具信息,确定需要调用的工具和相应的参数。模型的输出将告诉我们应该使用哪个工具以及传递给工具的参数。

    在实际的工具调用中,我们需要实现调用工具的逻辑。例如,在追踪股票价格的例子中,我们可以定义一个名为”track”的工具,并提供股票代码作为参数。模型将根据返回的结果生成回复。对于复杂的问题,模型可能需要进行多次工具调用。

    在ChatGLM3中,工具调用的原理是通过特殊处理模型输出来实现的。输出结果中的字典对象表示模型需要调用特定的工具,并传递相应的参数。这样的结果是由一个特殊的处理过程生成的,确保模型可以理解和使用。

    在工具调用的样本数据中,我们可以看到模型预测的输出中包含了工具调用的相关信息,如函数名称和代码。这些信息被解析和处理,最终生成用于调用工具的参数。这种设计的目的是将模型的预测结果转化为可执行的工具调用。

    在ChatGLM3中,工具调用只支持通过chat方法进行,不支持stream_chat方法。这是因为在stream_chat方法中,无法对模型输出进行相应的处理和转换。

    通过深入研究ChatGLM3的源码和样本数据,我们揭示了其工具调用的实现原理。我们了解到,工具调用是通过特殊处理模型输出和相关信息来实现的。这种功能为模型的应用提供了更大的灵活性和扩展性,使模型可以执行更多的任务。


    ChatGLM3的工具调用功能为大型语言模型带来了更多的实用性和功能扩展性。通过深入了解其实现原理,我们更好地理解了模型如何使用工具,并可以在实践中灵活应用。期待着ChatGLM3在未来的发展中带来更多令人惊喜的功能和应用!


人生梦想 - 关注前沿的计算机技术 acejoy.com 🐾 步子哥の博客 🐾 背多分论坛 🐾 借一步网 沪ICP备2024052574号-1