Skip to content

Commit

Permalink
ml_baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamer121121 committed Jan 7, 2019
0 parents commit 6d3de29
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 0 deletions.
Empty file added db.sqlite3
Empty file.
15 changes: 15 additions & 0 deletions manage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/usr/bin/env python
import os
import sys

if __name__ == "__main__":
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "server.settings")
try:
from django.core.management import execute_from_command_line
except ImportError as exc:
raise ImportError(
"Couldn't import Django. Are you sure it's installed and "
"available on your PYTHONPATH environment variable? Did you "
"forget to activate a virtual environment?"
) from exc
execute_from_command_line(sys.argv)
Empty file added server/__init__.py
Empty file.
120 changes: 120 additions & 0 deletions server/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""
Django settings for server project.
Generated by 'django-admin startproject' using Django 2.0.5.
For more information on this file, see
https://docs.djangoproject.com/en/2.0/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/2.0/ref/settings/
"""

import os

# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'q#i^3*pua-g9r^_7w0hs(dy25+x^ijsgau4%i$k*$3=t9x(#1a'

# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True

ALLOWED_HOSTS = []


# Application definition

INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
]

MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

ROOT_URLCONF = 'server.urls'

TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]

WSGI_APPLICATION = 'server.wsgi.application'


# Database
# https://docs.djangoproject.com/en/2.0/ref/settings/#databases

DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
}
}


# Password validation
# https://docs.djangoproject.com/en/2.0/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
]


# Internationalization
# https://docs.djangoproject.com/en/2.0/topics/i18n/

LANGUAGE_CODE = 'en-us'

TIME_ZONE = 'UTC'

USE_I18N = True

USE_L10N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/2.0/howto/static-files/

STATIC_URL = '/static/'
23 changes: 23 additions & 0 deletions server/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""server URL Configuration
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/2.0/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.contrib import admin
from django.conf.urls import include,url
from . import views
urlpatterns = [
url('admin/', admin.site.urls),
url('^api/data/Perter$', views.Perter)

]
117 changes: 117 additions & 0 deletions server/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from django.http import JsonResponse


def Perter(request):
import argparse
import scipy.io
import torch
import numpy as np
import os
from torchvision import datasets
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
#######################################################################
# Evaluate
parser = argparse.ArgumentParser(description='Demo')
parser.add_argument('--query_index', default=777, type=int, help='test_image_index')
parser.add_argument('--test_dir', default='../Market/pytorc', type=str, help='./test_data')
opts = parser.parse_args()

data_dir = opts.test_dir
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x)) for x in ['gallery', 'query']}

#####################################################################
# Show result
def imshow(path, title=None):
"""Imshow for Tensor."""
im = plt.imread(path)
plt.imshow(im)
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots are updated

######################################################################
result = scipy.io.loadmat('pytorch_result.mat')
query_feature = torch.FloatTensor(result['query_f'])
query_cam = result['query_cam'][0]
query_label = result['query_label'][0]
gallery_feature = torch.FloatTensor(result['gallery_f'])
gallery_cam = result['gallery_cam'][0]
gallery_label = result['gallery_label'][0]

multi = os.path.isfile('multi_query.mat')

if multi:
m_result = scipy.io.loadmat('multi_query.mat')
mquery_feature = torch.FloatTensor(m_result['mquery_f'])
mquery_cam = m_result['mquery_cam'][0]
mquery_label = m_result['mquery_label'][0]
mquery_feature = mquery_feature.cuda()

query_feature = query_feature.cuda()
gallery_feature = gallery_feature.cuda()

#######################################################################
# sort the images
def sort_img(qf, ql, qc, gf, gl, gc):
query = qf.view(-1, 1)
# print(query.shape)
score = torch.mm(gf, query)
score = score.squeeze(1).cpu()
score = score.numpy()
# predict index
index = np.argsort(score) # from small to large
index = index[::-1]
# index = index[0:2000]
# good index
query_index = np.argwhere(gl == ql)
# same camera
camera_index = np.argwhere(gc == qc)

# good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
junk_index1 = np.argwhere(gl == -1)
junk_index2 = np.intersect1d(query_index, camera_index)
junk_index = np.append(junk_index2, junk_index1)

mask = np.in1d(index, junk_index, invert=True)
index = index[mask]
return index

i = opts.query_index
index = sort_img(query_feature[i], query_label[i], query_cam[i], gallery_feature, gallery_label, gallery_cam)

########################################################################
# Visualize the rank result

query_path, _ = image_datasets['query'].imgs[i]
query_label = query_label[i]
print(query_path)
print('Top 10 images are as follow:')
try: # Visualize Ranking Result
# Graphical User Interface is needed
fig = plt.figure(figsize=(16, 4))
ax = plt.subplot(1, 11, 1)
ax.axis('off')
imshow(query_path, 'query')
for i in range(10):
ax = plt.subplot(1, 11, i + 2)
ax.axis('off')
img_path, _ = image_datasets['gallery'].imgs[index[i]]
label = gallery_label[index[i]]
imshow(img_path)
if label == query_label:
ax.set_title('%d' % (i + 1), color='green')
else:
ax.set_title('%d' % (i + 1), color='red')
print(img_path)
except RuntimeError:
for i in range(10):
img_path = image_datasets.imgs[index[i]]
print(img_path[0])
print('If you want to see the visualization of the ranking result, graphical user interface is needed.')

fig.savefig("show.png")

print("1")
return JsonResponse({'result': 'success'})
16 changes: 16 additions & 0 deletions server/wsgi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
WSGI config for server project.
It exposes the WSGI callable as a module-level variable named ``application``.
For more information on this file, see
https://docs.djangoproject.com/en/2.0/howto/deployment/wsgi/
"""

import os

from django.core.wsgi import get_wsgi_application

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "server.settings")

application = get_wsgi_application()

0 comments on commit 6d3de29

Please sign in to comment.