-
Notifications
You must be signed in to change notification settings - Fork 0
/
test-tfjs.html
162 lines (137 loc) · 5.74 KB
/
test-tfjs.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>TensorFlow.js 推論とバウンディングボックスの描画</title>
<style>
canvas {
border: 1px solid black;
}
</style>
</head>
<body>
<!-- ローカルファイルを選択するためのファイル入力 -->
<input type="file" id="file-input" accept="image/png">
<!-- 選択した画像を表示するための img 要素 -->
<img id="selected-image" src="" alt="Selected Image" style="display:none; max-width: 100%; height: auto;">
<!-- 推論結果を描画するための canvas -->
<canvas id="output-canvas"></canvas>
<!-- 推論を実行するボタン -->
<button id="run-inference">Run Inference</button>
<!-- 推論時間を表示するための要素 -->
<div id="inference-time" style="margin-top: 10px;"></div>
<!-- TensorFlow.js を CDN から読み込む -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<script>
const MODEL_INPUT_WIDTH = 640; // モデルの入力幅
const MODEL_INPUT_HEIGHT = 480; // モデルの入力高
const classColors = [
'#FF6347', '#4682B4', '#32CD32', '#FFD700', '#8A2BE2', '#FF4500',
'#1E90FF', '#3CB371', '#DAA520', '#9400D3', '#FF1493', '#00CED1',
'#ADFF2F', '#FFDAB9', '#B22222'
];
// 画像をロードして表示する関数
function loadImage(file) {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = (e) => {
const imgElement = document.getElementById('selected-image');
imgElement.src = e.target.result;
imgElement.style.display = 'block';
imgElement.onload = () => resolve(imgElement);
};
reader.onerror = reject;
reader.readAsDataURL(file);
});
}
// TensorFlow.js モデルのロード
async function loadModel() {
try {
const model = await tf.loadGraphModel('tfjs_model/model.json'); // 適切なパスに変更
return model;
} catch (err) {
console.error('モデルのロードに失敗しました:', err);
alert('モデルのロードに失敗しました。');
}
}
// 画像を TensorFlow.js 用に前処理する関数
function preprocessImage(imageElement) {
let tensor = tf.browser.fromPixels(imageElement).expandDims(0).toFloat();
// RGB を BGR に変換
tensor = tf.reverse(tensor, axis=[-1]); // 最後の次元 (チャネル) を反転
return tensor;
}
// 推論の実行と時間計測
async function runInference(model, inputTensor) {
try {
const startTime = Date.now();
const output = await model.executeAsync(inputTensor);
const endTime = Date.now();
const inferenceTime = endTime - startTime;
document.getElementById('inference-time').textContent = `推論時間: ${inferenceTime} ミリ秒`;
console.info(output);
return output;
} catch (err) {
console.error('推論に失敗しました:', err);
alert('推論に失敗しました。');
}
}
// 推論結果に基づいてバウンディングボックスを描画
function renderBoundingBoxes(output, imageElement) {
const canvas = document.getElementById('output-canvas');
const width = imageElement.width;
const height = imageElement.height;
canvas.width = width;
canvas.height = height;
const ctx = canvas.getContext('2d');
ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.drawImage(imageElement, 0, 0, canvas.width, canvas.height);
const boxesData = output.arraySync(); // テンソルのデータを取得
const threshold = 0.5;
output.dispose()
for (let i = 0; i < boxesData.length; i++) {
const [batchno, classId, score, x1, y1, x2, y2] = boxesData[i];
// クラスIDが12または13の場合はスキップ
if (classId === 12 || classId === 13) {
continue;
}
if (score > threshold) {
// 画像のスケールに基づいてバウンディングボックスの座標を計算
const scaleX = canvas.width / MODEL_INPUT_WIDTH;
const scaleY = canvas.height / MODEL_INPUT_HEIGHT;
const boxColor = classColors[classId % classColors.length];
// バウンディングボックスを描画
ctx.strokeStyle = boxColor;
ctx.lineWidth = 2;
ctx.strokeRect(x1 * scaleX, y1 * scaleY, (x2 - x1) * scaleX, (y2 - y1) * scaleY);
// クラスIDとスコアを表示
ctx.font = '16px Arial';
ctx.fillStyle = boxColor;
ctx.fillText(`Class: ${classId} Score: ${score.toFixed(2)}`, x1 * scaleX, y1 * scaleY - 5);
}
}
}
// ファイルが選択されたときに発生するイベント
document.getElementById('file-input').addEventListener('change', async (event) => {
const file = event.target.files[0];
if (file) {
const imageElement = await loadImage(file);
const canvas = document.getElementById('output-canvas');
canvas.width = imageElement.width;
canvas.height = imageElement.height;
}
});
// 推論を実行するボタンに対するイベント
document.getElementById('run-inference').addEventListener('click', async () => {
const imageElement = document.getElementById('selected-image');
if (imageElement.src) {
const inputTensor = preprocessImage(imageElement);
const model = await loadModel();
const output = await runInference(model, inputTensor);
renderBoundingBoxes(output, imageElement);
}
});
</script>
</body>
</html>