Coverage for src / mesh / models / orm / journal_models.py: 88%
90 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-04 12:41 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-04 12:41 +0000
1from __future__ import annotations
3from itertools import chain
4from typing import TypeVar
6from django.db import models
7from django.db.models import QuerySet
8from django.utils.translation import gettext_lazy as _
10from mesh.models.orm.base_models import BaseChangeTrackingModel
11from mesh.models.orm.submission_models import Submission
# Generic type variable whose upper bound is Django's base ``Model`` class.
_T = TypeVar(
    "_T",
    bound=models.Model,
)
class JournalSectionManager(models.Manager["JournalSection"]):
    """
    Manager for `JournalSection` that memoizes the section tree.

    The full section queryset and two lookup tables derived from it (children
    grouped by parent pk, parent keyed by section pk) are computed lazily on
    first access and kept until `clean_cache()` is called.
    """

    # Memoized values; `None` means "not computed yet". They must be reset
    # (see `clean_cache`) on every change on a JournalSection: creation,
    # deletion, update.
    _all_journal_sections: QuerySet[JournalSection] | None = None
    _all_journal_sections_children: dict[int | None, list[JournalSection]] | None = None
    _all_journal_sections_parents: dict[int, JournalSection | None] | None = None

    def get_queryset(self):
        """
        Return the default queryset with the author, last editor and parent
        relations joined in through `select_related`.
        """
        base_queryset = super().get_queryset()
        return base_queryset.select_related("created_by", "last_modified_by", "parent")

    def all_journal_sections(self) -> QuerySet[JournalSection]:
        """
        Return the (memoized) queryset of every registered `JournalSection`.
        """
        cached = self._all_journal_sections
        if cached is None:
            cached = self.get_queryset().all()
            self._all_journal_sections = cached

        return cached

    def all_journal_sections_parents(self):
        """
        Return the (memoized) mapping `{journal_section.pk: parent}` covering
        every registered `JournalSection`. The value is `None` for sections
        without a parent.
        """
        if self._all_journal_sections_parents is None:
            parent_by_pk = {}
            for section in self.all_journal_sections():
                parent_by_pk[section.pk] = section.parent
            self._all_journal_sections_parents = parent_by_pk

        return self._all_journal_sections_parents

    def all_journal_sections_children(self) -> dict[int | None, list[JournalSection]]:
        """
        Return the (memoized) mapping `{journal_section.pk: list[children]}`
        covering every registered `JournalSection`.

        The extra `None` key lists the sections that have no parent
        (top-level journal_sections).
        """
        if self._all_journal_sections_children is None:
            sections = self.all_journal_sections()

            # First pass: one empty bucket per section, then the top-level bucket.
            children_by_parent = {}
            for section in sections:
                children_by_parent[section.pk] = []
            children_by_parent[None] = []

            # Second pass: file each section under its parent's pk.
            for section in sections:
                parent: JournalSection | None = section.parent
                bucket_key: int | None = parent.pk if parent is not None else None
                children_by_parent[bucket_key].append(section)

            self._all_journal_sections_children = children_by_parent

        return self._all_journal_sections_children

    # To optimize this we'll need to do recursive queries
    # native Django cannot do recursive CTEs
    # https://github.com/dimagi/django-cte
    def get_children_recursive(
        self, journal_section: JournalSection | None
    ) -> list[JournalSection]:
        """
        Return every descendant of the given journal_section as a flat list.

        Each child is immediately followed by its own descendants (depth-first,
        pre-order). Passing `None` starts from the top-level sections.
        """
        children_by_parent = self.all_journal_sections_children()
        lookup_key = journal_section.pk if journal_section is not None else None

        descendants: list[JournalSection] = []
        for child in children_by_parent[lookup_key]:
            # The child itself, then its whole subtree.
            descendants.extend(chain((child,), self.get_children_recursive(child)))
        return descendants

    def get_parents_recursive(
        self, journal_section: JournalSection | None
    ) -> list[JournalSection]:
        """
        Return every ancestor of the given journal_section as a flat list,
        nearest parent first. The list is empty for `None` and for sections
        that have no (known) parent.
        """
        parent = (
            None
            if journal_section is None
            else self.all_journal_sections_parents().get(journal_section.pk)
        )
        if parent is None:
            return []

        ancestors = [parent]
        ancestors.extend(self.get_parents_recursive(parent))
        return ancestors

    def clean_cache(self):
        """
        Forget all memoized data so it is rebuilt on next access.
        """
        self._all_journal_sections = (
            self._all_journal_sections_children
        ) = self._all_journal_sections_parents = None
class JournalSection(BaseChangeTrackingModel):
    """
    Represents a journal section. Sections can be nested infinitely.

    Sections are mainly used to give editor rights over a whole section.

    `save()` and `delete()` both reset the caches held by the model's manager
    (`objects.clean_cache()`), so the cached section tree is rebuilt on the
    next access.
    """

    name = models.CharField(verbose_name=_("Name"), max_length=128, unique=True)
    parent = models.ForeignKey["JournalSection"](
        "self", on_delete=models.SET_NULL, null=True, related_name="children"
    )
    # Reverse accessor of `parent` (see `related_name`); annotation only.
    children: models.manager.RelatedManager[JournalSection]

    objects: JournalSectionManager = JournalSectionManager()  # type: ignore

    def __str__(self) -> str:
        """Return the section name."""
        return self.name

    def save(self, *args, **kwargs) -> None:
        """
        Save the section after checking that the selected parent is valid.

        Raises:
            ValueError: if the parent is the section itself or one of its
                (recursive) children.
        """
        parent = self.parent
        if parent is not None and (parent == self or parent in self.all_children()):
            raise ValueError("The selected parent is invalid (self or child).")
        super().save(*args, **kwargs)
        # The cached tree no longer matches the database.
        self.__class__.objects.clean_cache()

    def delete(self, *args, **kwargs):
        """
        Delete the section and additionally move all submissions and sub-sections
        to the parent section.

        Returns whatever the parent class `delete()` returns.
        """
        parent = self.parent
        # Collect the affected primary keys *before* deleting: once the row is
        # gone the relations can no longer be used to find them.
        children = list(self.children.all().values_list("pk", flat=True))
        submissions_to_update = list(
            Submission.objects.filter(journal_section=self).values_list("pk", flat=True)
        )

        res = super().delete(*args, **kwargs)

        # NOTE(review): re-attaching after the delete presumably relies on the
        # related rows surviving it (`parent` uses SET_NULL; the behaviour of
        # `Submission.journal_section` is not visible here) - confirm.
        if children:
            JournalSection.objects.filter(pk__in=children).update(parent=parent)
        if submissions_to_update:
            Submission.objects.filter(pk__in=submissions_to_update).update(journal_section=parent)

        # `update()` bypasses `save()`, so the cache must be reset explicitly.
        self.__class__.objects.clean_cache()
        return res

    def top_level_journal_section(self) -> JournalSection:
        """
        Return the top level parent journal_section (journal).

        Returns the section itself when it has no parent.

        Raises:
            ValueError: if an ancestor is missing from the manager's cached
                `{pk: parent}` mapping.
        """
        # The journal_section is a journal
        if self.parent is None:
            return self

        # Browse up the journal_section arborescence until finding the top level journal_section
        parent_journal_sections = self.__class__.objects.all_journal_sections_parents()

        top_level_journal_section = self.parent

        parent_journal = top_level_journal_section.parent
        while parent_journal is not None:
            # Move up exactly one level. The mapping is `{pk: parent}`, so the
            # lookup must use the pk of the section we just moved to; looking up
            # the parent's pk and storing the result as the current section
            # skips a level and wrongly fails on the root.
            top_level_journal_section = parent_journal
            if top_level_journal_section.pk not in parent_journal_sections:
                raise ValueError(
                    f"Invalid cache: {top_level_journal_section.pk} pk not found in parent_journal_sections"
                )
            parent_journal = parent_journal_sections[top_level_journal_section.pk]

        return top_level_journal_section

    def all_children(self) -> list[JournalSection]:
        """
        Get all the `JournalSection` children (recursively, as a flat list).
        """
        # An instance that has not been saved yet cannot have children.
        if self._state.adding:
            return []

        return self.__class__.objects.get_children_recursive(self)