From 34ce2bbc87a9f7c8b4d86eccb7912728fc6a6cb9 Mon Sep 17 00:00:00 2001 From: LIU42 <3528865430@qq.com> Date: Sun, 29 Sep 2024 15:28:38 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=9D=83=E9=87=8D=E5=AD=98?= =?UTF-8?q?=E5=82=A8=E7=9B=AE=E5=BD=95=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 4 ++-- README.md | 4 +++- classify/predict.py | 2 +- detect/predict.py | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 7f8a2f6..9f7811b 100644 --- a/.gitignore +++ b/.gitignore @@ -165,5 +165,5 @@ examples/outputs/* */datasets/* */runs/* -*/weights/develop/* -*/weights/deploy/* +*/weights/* +*/checkpoints/* diff --git a/README.md b/README.md index 16c2ec2..7e215ca 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ pip install -r requirements.txt pip install onnxruntime-gpu ``` -待识别图像默认在 examples/sources/ 下 ,识别结果默认保存在 examples/outputs/ 下,如果不存在请先创建。将所有待识别的图像放入待识别图像目录下,要求图像尺寸为 640x480,可以在本项目 Releases 中下载我训练好的模型权重文件,解压到项目中相应的位置(位于 detect/weights/deploy/classify/weights/deploy/),运行 main.py 即可。 +待识别图像默认在 examples/sources/ 下 ,识别结果默认保存在 examples/outputs/ 下,如果不存在请先创建。将所有待识别的图像放入待识别图像目录下,要求图像尺寸为 640x480,可以在本项目 Releases 中下载我训练好的模型权重文件,解压到项目中相应的位置(位于 detect/weights/classify/weights/),运行 main.py 即可。 ```bash python main.py @@ -79,3 +79,5 @@ strategy: "conservative" # 通行规则识别策略,“conservative”( ```bash pip install ultralytics ``` + + diff --git a/classify/predict.py b/classify/predict.py index 5e4164d..7c17e83 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -12,7 +12,7 @@ def __init__(self, configs): providers = ['CPUExecutionProvider'] self.configs = configs - self.session = ort.InferenceSession(f'classify/weights/deploy/classify-{self.precision}.onnx', providers=providers) + self.session = ort.InferenceSession(f'classify/weights/classify-{self.precision}.onnx', providers=providers) def __call__(self, image, signals): for signal in signals: diff --git a/detect/predict.py b/detect/predict.py index de019ab..8abb7a6 100644 --- a/detect/predict.py +++ b/detect/predict.py @@ -12,7 +12,7 @@ def __init__(self, configs): providers = ['CPUExecutionProvider'] self.configs = configs - self.session = ort.InferenceSession(f'detect/weights/deploy/detect-{self.precision}.onnx', providers=providers) + self.session = ort.InferenceSession(f'detect/weights/detect-{self.precision}.onnx', providers=providers) def __call__(self, image): inputs = process.preprocess(image, size=640, padding_color=127, precision=self.precision)