From 34ce2bbc87a9f7c8b4d86eccb7912728fc6a6cb9 Mon Sep 17 00:00:00 2001
From: LIU42 <3528865430@qq.com>
Date: Sun, 29 Sep 2024 15:28:38 +0800
Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=9D=83=E9=87=8D=E5=AD=98?=
=?UTF-8?q?=E5=82=A8=E7=9B=AE=E5=BD=95=E3=80=82?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.gitignore | 4 ++--
README.md | 4 +++-
classify/predict.py | 2 +-
detect/predict.py | 2 +-
4 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/.gitignore b/.gitignore
index 7f8a2f6..9f7811b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,5 +165,5 @@ examples/outputs/*
*/datasets/*
*/runs/*
-*/weights/develop/*
-*/weights/deploy/*
+*/weights/*
+*/checkpoints/*
diff --git a/README.md b/README.md
index 16c2ec2..7e215ca 100644
--- a/README.md
+++ b/README.md
@@ -51,7 +51,7 @@ pip install -r requirements.txt
pip install onnxruntime-gpu
```
-待识别图像默认在 examples/sources/ 下 ,识别结果默认保存在 examples/outputs/ 下,如果不存在请先创建。将所有待识别的图像放入待识别图像目录下,要求图像尺寸为 640x480,可以在本项目 Releases 中下载我训练好的模型权重文件,解压到项目中相应的位置(位于 detect/weights/deploy/ 和 classify/weights/deploy/),运行 main.py 即可。
+待识别图像默认在 examples/sources/ 下 ,识别结果默认保存在 examples/outputs/ 下,如果不存在请先创建。将所有待识别的图像放入待识别图像目录下,要求图像尺寸为 640x480,可以在本项目 Releases 中下载我训练好的模型权重文件,解压到项目中相应的位置(位于 detect/weights/ 和 classify/weights/),运行 main.py 即可。
```bash
python main.py
@@ -79,3 +79,5 @@ strategy: "conservative" # 通行规则识别策略,“conservative”(
```bash
pip install ultralytics
```
+
+
diff --git a/classify/predict.py b/classify/predict.py
index 5e4164d..7c17e83 100644
--- a/classify/predict.py
+++ b/classify/predict.py
@@ -12,7 +12,7 @@ def __init__(self, configs):
providers = ['CPUExecutionProvider']
self.configs = configs
- self.session = ort.InferenceSession(f'classify/weights/deploy/classify-{self.precision}.onnx', providers=providers)
+ self.session = ort.InferenceSession(f'classify/weights/classify-{self.precision}.onnx', providers=providers)
def __call__(self, image, signals):
for signal in signals:
diff --git a/detect/predict.py b/detect/predict.py
index de019ab..8abb7a6 100644
--- a/detect/predict.py
+++ b/detect/predict.py
@@ -12,7 +12,7 @@ def __init__(self, configs):
providers = ['CPUExecutionProvider']
self.configs = configs
- self.session = ort.InferenceSession(f'detect/weights/deploy/detect-{self.precision}.onnx', providers=providers)
+ self.session = ort.InferenceSession(f'detect/weights/detect-{self.precision}.onnx', providers=providers)
def __call__(self, image):
inputs = process.preprocess(image, size=640, padding_color=127, precision=self.precision)