Skip to content

Commit

Permalink
fix: 修改模型问题 #7626 (#7633)
Browse files Browse the repository at this point in the history
* fix: 修改模型问题 #7626

* fix: 修改模型方法问题 #7626

* fix: 修复语法错误 #7626
  • Loading branch information
guohelu authored Dec 13, 2024
1 parent b64a70b commit 0e19d45
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 47 deletions.
4 changes: 2 additions & 2 deletions gcloud/contrib/template_market/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@

@admin.register(models.TemplateSharedRecord)
class TemplateSharedRecordAdmin(admin.ModelAdmin):
list_display = ["project_id", "template_id", "creator", "create_at", "update_at", "extra_info"]
list_filter = ["project_id", "template_id", "creator", "create_at", "update_at"]
list_display = ["project_id", "template_id", "creator", "extra_info"]
list_filter = ["project_id", "template_id", "creator"]
search_fields = ["project_id", "creator"]
18 changes: 14 additions & 4 deletions gcloud/contrib/template_market/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,32 @@ def __init__(self):
def _get_url(self, endpoint):
return f"{self.base_url}{endpoint}"

def get_shared_detail(self, market_record_id):
def get_service_category(self):
url = self._get_url("/category/get_service_category/")
response = requests.get(url)
return response.json()

def get_scene_label(self):
url = self._get_url("/sre_property/scene_label/")
response = requests.get(url)
return response.json()

def get_template_scene_detail(self, market_record_id):
url = self._get_url(f"/sre_scene/flow_template_scene/{market_record_id}/")
response = requests.get(url)
return response.json()

def get_shared_list(self):
def get_template_scene_list(self):
url = self._get_url("/sre_scene/flow_template_scene/?is_all=true")
response = requests.get(url)
return response.json()

def create_shared_record(self, data):
def create_template_scene(self, data):
url = self._get_url("/sre_scene/flow_template_scene/")
response = requests.post(url, json=data)
return response.json()

def patch_shared_record(self, data, market_record_id):
def patch_template_scene(self, data, market_record_id):
url = self._get_url(f"/sre_scene/flow_template_scene/{market_record_id}/")
response = requests.patch(url, json=data)
return response.json()
6 changes: 2 additions & 4 deletions gcloud/contrib/template_market/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 3.2.15 on 2024-12-13 07:48
# Generated by Django 3.2.15 on 2024-12-13 10:48

from django.db import migrations, models

Expand All @@ -15,10 +15,8 @@ class Migration(migrations.Migration):
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("project_id", models.IntegerField(default=-1, help_text="项目 ID", verbose_name="项目 ID")),
("template_id", models.IntegerField(db_index=True, help_text="模板 ID", verbose_name="模板 ID ")),
("template_id", models.IntegerField(db_index=True, help_text="模板 ID", verbose_name="模板 ID")),
("creator", models.CharField(default="", max_length=32, verbose_name="创建者")),
("create_at", models.DateTimeField(auto_now_add=True, verbose_name="创建时间")),
("update_at", models.DateTimeField(auto_now=True, verbose_name="更新时间")),
("extra_info", models.JSONField(blank=True, null=True, verbose_name="额外信息")),
],
options={
Expand Down
77 changes: 46 additions & 31 deletions gcloud/contrib/template_market/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,49 +16,66 @@

from gcloud import err_code

TEMPLATE_SHARED_RECORD_BATCH_OPERATION_COUNT = 50


class TemplateSharedManager(models.Manager):
def update_shared_record(self, new_template_ids, market_record_id, project_id, creator, existing_template_ids=None):
def update_shared_record(
self, new_template_ids, market_record_id, project_id, creator, existing_market_template_ids=None
):
market_record_id = int(market_record_id)

if existing_template_ids:
templates_to_remove = existing_template_ids - set(new_template_ids)
if templates_to_remove:
for template_id in templates_to_remove:
current_template_record = TemplateSharedRecord.objects.get(template_id=template_id)
current_market_ids = current_template_record.extra_info.get("market_record_ids", [])
if market_record_id in current_market_ids:
current_market_ids.remove(market_record_id)
current_template_record.extra_info["market_record_ids"] = current_market_ids
current_template_record.save()
if not current_template_record.extra_info["market_record_ids"]:
current_template_record.delete()
else:
return {
"result": False,
"message": "template {} is not in record {}".format(template_id, market_record_id),
"code": err_code.REQUEST_PARAM_INVALID.code,
}

templates_to_add = set(new_template_ids) - existing_template_ids
if existing_market_template_ids:
ids_to_delete = []
templates_to_remove = existing_market_template_ids - set(new_template_ids)
current_template_records = TemplateSharedRecord.objects.filter(
project_id=project_id, template_id__in=templates_to_remove
)
for current_template_record in current_template_records:
market_record_ids = current_template_record.extra_info.get("market_record_ids")
market_record_ids.remove(market_record_id)
if not market_record_ids:
ids_to_delete.append(current_template_record.id)
current_template_record.save()

if ids_to_delete:
TemplateSharedRecord.objects.filter(id__in=ids_to_delete).delete()

templates_to_add = set(new_template_ids) - existing_market_template_ids
if templates_to_add:
new_template_ids = list(templates_to_add)

new_records = []
records_to_update = []
existing_records = TemplateSharedRecord.objects.filter(project_id=project_id, template_id__in=new_template_ids)

existing_template_ids = {record.template_id: record for record in existing_records}

for template_id in new_template_ids:
existing_record, created = TemplateSharedRecord.objects.get_or_create(
project_id=project_id,
template_id=template_id,
defaults={"creator": creator, "extra_info": {"market_record_ids": [market_record_id]}},
)
if not created:
if template_id in existing_template_ids:
existing_record = existing_template_ids[template_id]
market_ids = existing_record.extra_info.setdefault("market_record_ids", [])
if market_record_id not in market_ids:
market_ids.append(market_record_id)
new_records.append(existing_record)
records_to_update.append(existing_record)
else:
new_record = TemplateSharedRecord(
project_id=project_id,
template_id=template_id,
creator=creator,
extra_info={"market_record_ids": [market_record_id]},
)
new_records.append(new_record)

if new_records:
TemplateSharedRecord.objects.bulk_update(new_records, ["extra_info"])
TemplateSharedRecord.objects.bulk_create(
new_records, batch_size=TEMPLATE_SHARED_RECORD_BATCH_OPERATION_COUNT
)

if records_to_update:
TemplateSharedRecord.objects.bulk_update(
records_to_update, ["extra_info"], batch_size=TEMPLATE_SHARED_RECORD_BATCH_OPERATION_COUNT
)

return {"result": True, "message": "update shared record successfully", "code": err_code.SUCCESS.code}

Expand All @@ -67,8 +84,6 @@ class TemplateSharedRecord(models.Model):
project_id = models.IntegerField(_("项目 ID"), default=-1, help_text="项目 ID")
template_id = models.IntegerField(_("模板 ID"), help_text="模板 ID", db_index=True)
creator = models.CharField(_("创建者"), max_length=32, default="")
create_at = models.DateTimeField(_("创建时间"), auto_now_add=True)
update_at = models.DateTimeField(verbose_name=_("更新时间"), auto_now=True)
extra_info = models.JSONField(_("额外信息"), blank=True, null=True)

objects = TemplateSharedManager()
Expand Down
44 changes: 38 additions & 6 deletions gcloud/contrib/template_market/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from rest_framework import viewsets
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.decorators import action
from rest_framework import permissions

from gcloud import err_code
Expand Down Expand Up @@ -77,8 +78,37 @@ def _build_template_data(self, serializer, **kwargs):
data["id"] = market_record_id
return data

@action(detail=False, methods=["get"])
def get_service_category(self, request, *args, **kwargs):
response_data = self.market_client.get_service_category()
if not response_data["result"]:
logging.warning("Failed to obtain the market service category")
return Response(
{
"result": False,
"message": "Failed to obtain the market service category",
"code": err_code.OPERATION_FAIL.code,
}
)
return Response({"result": True, "data": response_data["data"], "code": err_code.SUCCESS.code})

@action(detail=False, methods=["get"])
def get_scene_label(self, request, *args, **kwargs):
response_data = self.market_client.get_scene_label()

if not response_data["result"]:
logging.exception("Failed to obtain scene tag list")
return Response(
{
"result": False,
"message": "Failed to obtain scene tag list",
"code": err_code.OPERATION_FAIL.code,
}
)
return Response({"result": True, "data": response_data["data"], "code": err_code.SUCCESS.code})

def list(self, request, *args, **kwargs):
response_data = self.market_client.get_shared_list()
response_data = self.market_client.get_template_scene_list()

if not response_data["result"]:
logging.exception("Failed to obtain the market template list")
Expand All @@ -97,7 +127,7 @@ def create(self, request, *args, **kwargs):
serializer.is_valid(raise_exception=True)

data = self._build_template_data(serializer)
response_data = self.market_client.create_shared_record(data)
response_data = self.market_client.create_template_scene(data)
if not response_data.get("result"):
return Response(
{
Expand All @@ -119,10 +149,12 @@ def partial_update(self, request, *args, **kwargs):
market_record_id = kwargs["pk"]
serializer = self.serializer_class(data=request.data, partial=True)
serializer.is_valid(raise_exception=True)
existing_records = self.market_client.get_shared_detail(market_record_id)
existing_template_ids = set([template["id"] for template in json.loads(existing_records["data"]["templates"])])
existing_records = self.market_client.get_template_scene_detail(market_record_id)
existing_market_template_ids = set(
[template["id"] for template in json.loads(existing_records["data"]["templates"])]
)
data = self._build_template_data(serializer, market_record_id=market_record_id)
response_data = self.market_client.patch_shared_record(data, market_record_id)
response_data = self.market_client.patch_template_scene(data, market_record_id)
if not response_data.get("result"):
return Response(
{
Expand All @@ -136,6 +168,6 @@ def partial_update(self, request, *args, **kwargs):
new_template_ids=serializer.validated_data["template_ids"],
market_record_id=market_record_id,
creator=serializer.validated_data["creator"],
existing_template_ids=existing_template_ids,
existing_market_template_ids=existing_market_template_ids,
)
return Response({"result": True, "data": response_data, "code": err_code.SUCCESS.code})

0 comments on commit 0e19d45

Please sign in to comment.