NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - AI模型在线查看 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 - 3D道路快速建模
在这篇博文中,我将展示如何制作 Doodle Dash,这是一款完全在浏览器中运行的实时 ML 驱动的网页游戏(感谢 Transformers.js)。 本教程的目标是向你展示制作自己的 ML 支持的网页游戏是多么容易!
在开始之前,让我们先讨论一下我们将要创建的内容。 该游戏的灵感来自于 Google 的 Quick, Draw! 游戏中,你会得到一个单词,神经网络有 20 秒的时间来猜测你在画什么(重复 6 次)。 事实上,我们将使用他们的训练数据来训练我们自己的草图检测模型! 你不是喜欢开源吗? 😍
在我们的版本中,你将有一分钟的时间绘制尽可能多的项目,一次一个提示。 如果模型预测正确的标签,画布将被清除,并且你将获得一个新单词。 继续这样做,直到计时器用完! 由于游戏在你的浏览器本地运行,因此我们根本不必担心服务器延迟。 该模型能够在你绘画时进行实时预测,每秒可预测超过 60 个......🤯 哇!
1、训练神经网络
本部分介绍游戏使用的神经网络模型的微调。
1.1 训练数据
我们将使用 Google Quick, Draw! 的子集来训练我们的模型。 数据集,包含 345 个类别的超过 500 万张图画。 以下是数据集中的一些示例:
1.2 模型架构
我们将微调 apple/mobilevit-small,这是一个轻量级且适合移动设备的 Vision Transformer,已在 ImageNet-1k 上进行了预训练。 它只有 5.6M 参数(~20 MB 文件大小),非常适合在浏览器中运行! 有关更多信息,请查看 MobileViT 论文和下面的模型架构。
1.3 微调
为了使博客文章(相对)简短,我们准备了一个 Colab 笔记本,它将向你展示我们在数据集上微调 apple/mobilevit-small 所采取的确切步骤。 在较高层面上,这涉及:
- 加载 Quick Draw! 数据集。
- 使用 MobileViTImageProcessor 转换数据集。
- 定义我们的整理函数和评估指标。
- 使用 MobileViTForImageClassification.from_pretrained 加载预训练的 MobileVIT 模型。
- 使用 Trainer 和 TrainingArguments 帮助程序类训练模型。
- 使用 🤗 Evaluate 评估模型。
注意:你可以在 Hugging Face Hub 上找到我们经过微调的模型。
2、在浏览器中执行推理
Transformers.js 是一个 JavaScript 库,可让你直接在浏览器中运行 🤗 Transformers(无需服务器)! 它的设计在功能上与 Python 库相同,这意味着你可以使用非常相似的 API 运行相同的预训练模型。
在幕后,Transformers.js 使用 ONNX 运行时,因此我们需要将微调后的 PyTorch 模型转换为 ONNX。
2.1 将模型转换为 ONNX
幸运的是,🤗 Optimum 库使将微调模型转换为 ONNX 变得超级简单! 最简单(也是推荐的方法)是:
首先克隆 Transformers.js 存储库并安装必要的依赖项:
git clone https://github.com/xenova/transformers.js.git
cd transformers.js
pip install -r scripts/requirements.txt
运行转换脚本(它在底层使用 Optimum):
python -m scripts.convert --model_id <model_id>
其中 <model_id>
是要转换的模型的名称(例如 Xenova/quickdraw-mobilevit-small
)。
2.2 设置我们的项目
让我们首先使用 Vite 搭建一个简单的 React 应用程序:
npm create vite@latest doodle-dash -- --template react
接下来,进入项目目录并安装必要的依赖项:
cd doodle-dash
npm install
npm install @xenova/transformers
然后,你可以通过运行以下命令来启动开发服务器:
npm run dev
2.3 在浏览器中运行模型
运行机器学习模型需要大量计算,因此在单独的线程中执行推理非常重要。 这样我们就不会阻塞主线程,该线程用于渲染 UI 并对您的绘图手势做出反应。 Web Workers API 使这一切变得超级简单!
在 src 目录中创建一个新文件(例如,worker.js)并添加以下代码:
import { pipeline, RawImage } from "@xenova/transformers";
const classifier = await pipeline("image-classification", 'Xenova/quickdraw-mobilevit-small', { quantized: false });
const image = await RawImage.read('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png');
const output = await classifier(image.grayscale());
console.log(output);
现在,我们可以通过将以下代码添加到 App 组件来在 App.jsx 文件中使用此worker程序:
import { useState, useEffect, useRef } from 'react'
// ... rest of the imports
function App() {
// Create a reference to the worker object.
const worker = useRef(null);
// We use the `useEffect` hook to set up the worker as soon as the `App` component is mounted.
useEffect(() => {
if (!worker.current) {
// Create the worker if it does not yet exist.
worker.current = new Worker(new URL('./worker.js', import.meta.url), {
type: 'module'
});
}
// Create a callback function for messages from the worker thread.
const onMessageReceived = (e) => { /* See code */ };
// Attach the callback function as an event listener.
worker.current.addEventListener('message', onMessageReceived);
// Define a cleanup function for when the component is unmounted.
return () => worker.current.removeEventListener('message', onMessageReceived);
});
// ... rest of the component
}
你可以通过运行开发服务器(使用 npm run dev
)、访问本地网站(通常为 http://localhost:5173/
)并打开浏览器控制台来测试一切是否正常。 应该可以看到模型的输出被记录到控制台。
[{ label: "skateboard", score: 0.9980043172836304 }]
WOW! 🥳 虽然上面的代码只是最终产品的一小部分,但它显示了它的机器学习方面是多么简单! 剩下的只是让它看起来漂亮并添加一些游戏逻辑。
3、游戏设计
在本节中,我将简要讨论游戏设计过程。 提醒一下,你可以在 GitHub 上找到该项目的完整源代码,因此我不会详细介绍代码本身。
3.1 利用实时性能
执行浏览器内推理的主要优点之一是我们可以实时进行预测(每秒超过 60 次)。 在原版《快速画画!》中 游戏中,模型每隔几秒才会做出新的预测。 我们可以在游戏中做同样的事情,但这样我们就无法利用它的实时性能! 所以,我决定重新设计主游戏循环:
- 我们的版本不是六轮 20 秒的回合(每一轮对应一个新单词),而是要求玩家在 60 秒内正确绘制尽可能多的涂鸦(一次一个提示)。
- 如果遇到无法画出的单词,你可以跳过它(但这会花费 3 秒的剩余时间)。
- 在最初的游戏中,由于模型每隔几秒就会进行一次猜测,因此它可以慢慢地从列表中划掉标签,直到最终猜测正确。 在我们的版本中,我们会降低前 n 个不正确标签的模型分数,随着用户继续绘图,n 会随着时间的推移而增加。
3.2 质量改善
原始数据集包含 345 个不同的类,由于我们的模型相对较小(约 20MB),因此有时无法正确猜测某些类。 为了解决这个问题,我们删除了一些单词:
- 与其他标签太相似(例如“谷仓”与“房屋”)
- 太难理解(例如“动物迁徙”)
- 太难绘制足够的细节(例如“大脑”)
- 不明确(例如“蝙蝠”)
经过筛选后,我们仍然剩下 300 多个不同的类别!
4、彩蛋:起个好名字
本着开源开发的精神,我决定向 Hugging Chat 询问一些游戏名称的想法……不用说,它并没有让人失望!
我喜欢“Doodle Dash”的头韵(建议#4),所以我决定采用它!
原文链接:Making ML-powered web games with Transformers.js
BimAnt翻译整理,转载请标明出处