diff --git a/backend/apps/core/management/commands/fetch_metabase.py b/backend/apps/core/management/commands/fetch_metabase.py index 8057a148..c687822d 100644 --- a/backend/apps/core/management/commands/fetch_metabase.py +++ b/backend/apps/core/management/commands/fetch_metabase.py @@ -76,16 +76,14 @@ def get_tables(self, token: str, database_id: int): return tables - def get_table_data(self, token: str, database_id: int, table: Table): - headers = self.get_headers(token) - fields = [f'"{field}"' for field in table.fields] - formated_field = ", ".join(fields) - query = f'SELECT {formated_field} FROM "{table.name}"' + def def_get_data_paginated(self, headers, database_id, query, page=0): + limit = 2000 + new_query = query + f" LIMIT {limit} OFFSET {page * limit}" payload = { "database": database_id, "native": { - "query": query, + "query": new_query, }, "type": "native", } @@ -93,12 +91,32 @@ def get_table_data(self, token: str, database_id: int, table: Table): response = requests.post(BASE_URL + "/api/dataset", headers=headers, json=payload) if response.status_code != 202: - return + self.stderr.write(f"Error fetching data: {response.text}") + return [] response_json = response.json() - rows = [] + return response_json["data"]["rows"] + + def get_table_data(self, token: str, database_id: int, table: Table): + headers = self.get_headers(token) + fields = [f'"{field}"' for field in table.fields] + formated_field = ", ".join(fields) + query = f'SELECT {formated_field} FROM "{table.name}"' - for row in response_json["data"]["rows"]: + raw_rows = [] + page = 0 + while True: + data = self.def_get_data_paginated(headers, database_id, query, page) + if len(data) == 0: + break + + raw_rows += data + page += 1 + + self.stdout.write(self.style.SUCCESS(f"Fetched {len(raw_rows)} rows from {str(table)}")) + + rows = [] + for row in raw_rows: instance = {} for i, field in enumerate(table.fields): instance[field] = row[i] @@ -109,7 +127,6 @@ def get_table_data(self, token: str, database_id: int, table: Table): self.save_data(table.name, json.dumps(rows, ensure_ascii=False, indent=4)) else: self.stdout.write(self.style.WARNING(f"No data found for {str(table)}")) - self.stdout.write(self.style.WARNING(query)) def clean_data(self): directory = os.path.join(os.getcwd(), "metabase_data") diff --git a/backend/apps/core/management/commands/populate.py b/backend/apps/core/management/commands/populate.py index 3fd1ea46..f0ad4247 100644 --- a/backend/apps/core/management/commands/populate.py +++ b/backend/apps/core/management/commands/populate.py @@ -67,7 +67,9 @@ def print(self, context): for model in self.models: context.stdout.write(context.style.SUCCESS(f"{'-' * self.depth * 2} {model.__name__}")) for field in model._meta.get_fields(): - if isinstance(field, models.ForeignKey): + if isinstance(field, models.ForeignKey) or isinstance( + field, models.ManyToManyField + ): name = f"{field.name} -> {field.related_model.__name__}" if field.null: @@ -93,6 +95,24 @@ def load_table_data(self, table_name): return data + def get_m2m_data(self, table_name, current_table_name, field_name, id): + cache_context = f"m2m_cache_{table_name}" + + if not hasattr(self, cache_context): + data = self.load_table_data(table_name) + cache = {} + + for item in data: + related_id = item[current_table_name] + if related_id not in cache: + cache[related_id] = [] + + cache[related_id].append(item[field_name]) + + setattr(self, cache_context, cache) + + return getattr(self, cache_context).get(id, []) + def model_has_data(self, model_name): if f"{model_name}.json" in self.files: return True @@ -139,7 +159,9 @@ def sort_models_by_depedencies(self, models_to_populate, other_models): has_all_dependencies = True for field in model._meta.get_fields(): - if isinstance(field, models.ForeignKey): + if isinstance(field, models.ForeignKey) or isinstance( + field, models.ManyToManyField + ): if ( field.related_model not in other_models and field.related_model not in sorted_models @@ -178,8 +200,9 @@ def create_instance(self, model, item): payload = {} retry = None table_name = model._meta.db_table + m2m_payload = {} - for field in model._meta.local_fields: + for field in model._meta.get_fields(): if isinstance(field, models.ForeignKey): field_name = f"{field.name}_id" current_value = item[field_name] @@ -188,6 +211,7 @@ def create_instance(self, model, item): continue reference = self.references.get(field.related_model._meta.db_table, current_value) + if reference: payload[field_name] = reference else: @@ -200,13 +224,48 @@ def create_instance(self, model, item): "table_name": field.related_model._meta.db_table, "field_name": field_name, } + elif isinstance(field, models.ManyToManyField): + field_name = field.name + m2m_table_name = field.m2m_db_table() + current_model_name = f"{model.__name__.lower()}_id" + field_model_name = field.related_model.__name__.lower() + "_id" + + m2m_related_data = self.get_m2m_data( + m2m_table_name, current_model_name, field_model_name, item["id"] + ) + + instances = [ + self.references.get(field.related_model._meta.db_table, current_value) + for current_value in m2m_related_data + ] + + if instances: + m2m_payload[field_name] = instances else: - payload[field.name] = item[field.name] + current_value = item.get(field.name) + + if current_value is None: + continue + + payload[field.name] = current_value instance = model(**payload) instance.save() + # Set many to many relationships + if m2m_payload: + for field_name, related_data in m2m_payload.items(): + field = getattr(instance, field_name) + + try: + field.set(related_data) + except Exception as e: + print(e) + print(field_name) + print(related_data) + raise e + if retry: retry["instance"] = instance self.retry_instances.append(retry)