From 0e19d4570abc9f5f3ff364936ee33525cff413e2 Mon Sep 17 00:00:00 2001 From: guohelu <141622458+guohelu@users.noreply.github.com> Date: Fri, 13 Dec 2024 20:53:07 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E9=97=AE=E9=A2=98=20#7626=20(#7633)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: 修改模型问题 #7626 * fix: 修改模型方法问题 #7626 * fix: 修复语法错误 #7626 --- gcloud/contrib/template_market/admin.py | 4 +- gcloud/contrib/template_market/clients.py | 18 ++++- .../migrations/0001_initial.py | 6 +- gcloud/contrib/template_market/models.py | 77 +++++++++++-------- gcloud/contrib/template_market/viewsets.py | 44 +++++++++-- 5 files changed, 102 insertions(+), 47 deletions(-) diff --git a/gcloud/contrib/template_market/admin.py b/gcloud/contrib/template_market/admin.py index 8c415c540..aaa1a2979 100644 --- a/gcloud/contrib/template_market/admin.py +++ b/gcloud/contrib/template_market/admin.py @@ -18,6 +18,6 @@ @admin.register(models.TemplateSharedRecord) class TemplateSharedRecordAdmin(admin.ModelAdmin): - list_display = ["project_id", "template_id", "creator", "create_at", "update_at", "extra_info"] - list_filter = ["project_id", "template_id", "creator", "create_at", "update_at"] + list_display = ["project_id", "template_id", "creator", "extra_info"] + list_filter = ["project_id", "template_id", "creator"] search_fields = ["project_id", "creator"] diff --git a/gcloud/contrib/template_market/clients.py b/gcloud/contrib/template_market/clients.py index 8b448e327..fe1114806 100644 --- a/gcloud/contrib/template_market/clients.py +++ b/gcloud/contrib/template_market/clients.py @@ -23,22 +23,32 @@ def __init__(self): def _get_url(self, endpoint): return f"{self.base_url}{endpoint}" - def get_shared_detail(self, market_record_id): + def get_service_category(self): + url = self._get_url("/category/get_service_category/") + response = requests.get(url) + return response.json() + + def get_scene_label(self): + url = self._get_url("/sre_property/scene_label/") + response = requests.get(url) + return response.json() + + def get_template_scene_detail(self, market_record_id): url = self._get_url(f"/sre_scene/flow_template_scene/{market_record_id}/") response = requests.get(url) return response.json() - def get_shared_list(self): + def get_template_scene_list(self): url = self._get_url("/sre_scene/flow_template_scene/?is_all=true") response = requests.get(url) return response.json() - def create_shared_record(self, data): + def create_template_scene(self, data): url = self._get_url("/sre_scene/flow_template_scene/") response = requests.post(url, json=data) return response.json() - def patch_shared_record(self, data, market_record_id): + def patch_template_scene(self, data, market_record_id): url = self._get_url(f"/sre_scene/flow_template_scene/{market_record_id}/") response = requests.patch(url, json=data) return response.json() diff --git a/gcloud/contrib/template_market/migrations/0001_initial.py b/gcloud/contrib/template_market/migrations/0001_initial.py index 55f205cd3..edf2dcb1c 100644 --- a/gcloud/contrib/template_market/migrations/0001_initial.py +++ b/gcloud/contrib/template_market/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 3.2.15 on 2024-12-13 07:48 +# Generated by Django 3.2.15 on 2024-12-13 10:48 from django.db import migrations, models @@ -15,10 +15,8 @@ class Migration(migrations.Migration): fields=[ ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ("project_id", models.IntegerField(default=-1, help_text="项目 ID", verbose_name="项目 ID")), - ("template_id", models.IntegerField(db_index=True, help_text="模板 ID", verbose_name="模板 ID ")), + ("template_id", models.IntegerField(db_index=True, help_text="模板 ID", verbose_name="模板 ID")), ("creator", models.CharField(default="", max_length=32, verbose_name="创建者")), - ("create_at", models.DateTimeField(auto_now_add=True, verbose_name="创建时间")), - ("update_at", models.DateTimeField(auto_now=True, verbose_name="更新时间")), ("extra_info", models.JSONField(blank=True, null=True, verbose_name="额外信息")), ], options={ diff --git a/gcloud/contrib/template_market/models.py b/gcloud/contrib/template_market/models.py index 1b761dbe7..f70c58439 100644 --- a/gcloud/contrib/template_market/models.py +++ b/gcloud/contrib/template_market/models.py @@ -16,49 +16,66 @@ from gcloud import err_code +TEMPLATE_SHARED_RECORD_BATCH_OPERATION_COUNT = 50 + class TemplateSharedManager(models.Manager): - def update_shared_record(self, new_template_ids, market_record_id, project_id, creator, existing_template_ids=None): + def update_shared_record( + self, new_template_ids, market_record_id, project_id, creator, existing_market_template_ids=None + ): market_record_id = int(market_record_id) - if existing_template_ids: - templates_to_remove = existing_template_ids - set(new_template_ids) - if templates_to_remove: - for template_id in templates_to_remove: - current_template_record = TemplateSharedRecord.objects.get(template_id=template_id) - current_market_ids = current_template_record.extra_info.get("market_record_ids", []) - if market_record_id in current_market_ids: - current_market_ids.remove(market_record_id) - current_template_record.extra_info["market_record_ids"] = current_market_ids - current_template_record.save() - if not current_template_record.extra_info["market_record_ids"]: - current_template_record.delete() - else: - return { - "result": False, - "message": "template {} is not in record {}".format(template_id, market_record_id), - "code": err_code.REQUEST_PARAM_INVALID.code, - } - - templates_to_add = set(new_template_ids) - existing_template_ids + if existing_market_template_ids: + ids_to_delete = [] + templates_to_remove = existing_market_template_ids - set(new_template_ids) + current_template_records = TemplateSharedRecord.objects.filter( + project_id=project_id, template_id__in=templates_to_remove + ) + for current_template_record in current_template_records: + market_record_ids = current_template_record.extra_info.get("market_record_ids") + market_record_ids.remove(market_record_id) + if not market_record_ids: + ids_to_delete.append(current_template_record.id) + current_template_record.save() + + if ids_to_delete: + TemplateSharedRecord.objects.filter(id__in=ids_to_delete).delete() + + templates_to_add = set(new_template_ids) - existing_market_template_ids if templates_to_add: new_template_ids = list(templates_to_add) new_records = [] + records_to_update = [] + existing_records = TemplateSharedRecord.objects.filter(project_id=project_id, template_id__in=new_template_ids) + + existing_template_ids = {record.template_id: record for record in existing_records} + for template_id in new_template_ids: - existing_record, created = TemplateSharedRecord.objects.get_or_create( - project_id=project_id, - template_id=template_id, - defaults={"creator": creator, "extra_info": {"market_record_ids": [market_record_id]}}, - ) - if not created: + if template_id in existing_template_ids: + existing_record = existing_template_ids[template_id] market_ids = existing_record.extra_info.setdefault("market_record_ids", []) if market_record_id not in market_ids: market_ids.append(market_record_id) - new_records.append(existing_record) + records_to_update.append(existing_record) + else: + new_record = TemplateSharedRecord( + project_id=project_id, + template_id=template_id, + creator=creator, + extra_info={"market_record_ids": [market_record_id]}, + ) + new_records.append(new_record) if new_records: - TemplateSharedRecord.objects.bulk_update(new_records, ["extra_info"]) + TemplateSharedRecord.objects.bulk_create( + new_records, batch_size=TEMPLATE_SHARED_RECORD_BATCH_OPERATION_COUNT + ) + + if records_to_update: + TemplateSharedRecord.objects.bulk_update( + records_to_update, ["extra_info"], batch_size=TEMPLATE_SHARED_RECORD_BATCH_OPERATION_COUNT + ) return {"result": True, "message": "update shared record successfully", "code": err_code.SUCCESS.code} @@ -67,8 +84,6 @@ class TemplateSharedRecord(models.Model): project_id = models.IntegerField(_("项目 ID"), default=-1, help_text="项目 ID") template_id = models.IntegerField(_("模板 ID"), help_text="模板 ID", db_index=True) creator = models.CharField(_("创建者"), max_length=32, default="") - create_at = models.DateTimeField(_("创建时间"), auto_now_add=True) - update_at = models.DateTimeField(verbose_name=_("更新时间"), auto_now=True) extra_info = models.JSONField(_("额外信息"), blank=True, null=True) objects = TemplateSharedManager() diff --git a/gcloud/contrib/template_market/viewsets.py b/gcloud/contrib/template_market/viewsets.py index b536ef555..3d07ff172 100644 --- a/gcloud/contrib/template_market/viewsets.py +++ b/gcloud/contrib/template_market/viewsets.py @@ -16,6 +16,7 @@ from rest_framework import viewsets from rest_framework.views import APIView from rest_framework.response import Response +from rest_framework.decorators import action from rest_framework import permissions from gcloud import err_code @@ -77,8 +78,37 @@ def _build_template_data(self, serializer, **kwargs): data["id"] = market_record_id return data + @action(detail=False, methods=["get"]) + def get_service_category(self, request, *args, **kwargs): + response_data = self.market_client.get_service_category() + if not response_data["result"]: + logging.warning("Failed to obtain the market service category") + return Response( + { + "result": False, + "message": "Failed to obtain the market service category", + "code": err_code.OPERATION_FAIL.code, + } + ) + return Response({"result": True, "data": response_data["data"], "code": err_code.SUCCESS.code}) + + @action(detail=False, methods=["get"]) + def get_scene_label(self, request, *args, **kwargs): + response_data = self.market_client.get_scene_label() + + if not response_data["result"]: + logging.exception("Failed to obtain scene tag list") + return Response( + { + "result": False, + "message": "Failed to obtain scene tag list", + "code": err_code.OPERATION_FAIL.code, + } + ) + return Response({"result": True, "data": response_data["data"], "code": err_code.SUCCESS.code}) + def list(self, request, *args, **kwargs): - response_data = self.market_client.get_shared_list() + response_data = self.market_client.get_template_scene_list() if not response_data["result"]: logging.exception("Failed to obtain the market template list") @@ -97,7 +127,7 @@ def create(self, request, *args, **kwargs): serializer.is_valid(raise_exception=True) data = self._build_template_data(serializer) - response_data = self.market_client.create_shared_record(data) + response_data = self.market_client.create_template_scene(data) if not response_data.get("result"): return Response( { @@ -119,10 +149,12 @@ def partial_update(self, request, *args, **kwargs): market_record_id = kwargs["pk"] serializer = self.serializer_class(data=request.data, partial=True) serializer.is_valid(raise_exception=True) - existing_records = self.market_client.get_shared_detail(market_record_id) - existing_template_ids = set([template["id"] for template in json.loads(existing_records["data"]["templates"])]) + existing_records = self.market_client.get_template_scene_detail(market_record_id) + existing_market_template_ids = set( + [template["id"] for template in json.loads(existing_records["data"]["templates"])] + ) data = self._build_template_data(serializer, market_record_id=market_record_id) - response_data = self.market_client.patch_shared_record(data, market_record_id) + response_data = self.market_client.patch_template_scene(data, market_record_id) if not response_data.get("result"): return Response( { @@ -136,6 +168,6 @@ def partial_update(self, request, *args, **kwargs): new_template_ids=serializer.validated_data["template_ids"], market_record_id=market_record_id, creator=serializer.validated_data["creator"], - existing_template_ids=existing_template_ids, + existing_market_template_ids=existing_market_template_ids, ) return Response({"result": True, "data": response_data, "code": err_code.SUCCESS.code})