Coverage for pyodmongo/services/reference_pipeline.py: 100%
49 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 15:08 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 15:08 +0000
1from pydantic import BaseModel
2from ..models.db_field_info import DbField
3from .verify_subclasses import is_subclass
4from .aggregate_stages import (
5 unwind,
6 group_set_replace_root,
7 unset,
8 lookup,
9 set_,
10)
11import copy
def _paths_to_ref_ids(
    cls: BaseModel,
    paths: list,
    db_field_path: list,
    populate_db_fields: list[DbField] | None,
):
    """
    Collect the field paths that lead to by-reference fields of a model class.

    Walks ``cls.model_fields`` depth-first. For every field whose class-level
    attribute has ``by_reference`` set, a copy of the chain of field
    descriptors from the root model down to that field is appended to
    ``paths``. Fields whose ``has_model_fields`` is set are descended into via
    their ``field_type``.

    Args:
        cls (BaseModel): The model class whose fields are inspected.
        paths (list): Accumulator for the discovered paths; mutated in place.
        db_field_path (list): Chain of field descriptors from the root model to
            the field currently being visited. Used as a stack: entries are
            pushed before a field is inspected and popped afterwards.
        populate_db_fields (list[DbField] | None): When None, every by-reference
            field is collected; otherwise only by-reference fields contained in
            this list are collected.

    Returns:
        list: The same ``paths`` object that was passed in, each entry being a
        list of field descriptors ending at a by-reference field.

    Raises:
        TypeError: If a field listed in ``model_fields`` is not available as a
            class attribute and ``is_subclass`` reports ``cls`` as a subclass
            of Pydantic's BaseModel.
        AttributeError: If the class attribute is missing and ``is_subclass``
            reports otherwise; the original error is propagated.
    """
    for field in cls.model_fields:
        try:
            db_field: DbField = getattr(cls, field)
        except AttributeError:
            if is_subclass(class_to_verify=cls, subclass=BaseModel):
                raise TypeError(
                    f"The {cls.__name__} class inherits from Pydantic's BaseModel class. Try switching to PyODMongo's MainBaseModel class"
                )
            # Without a descriptor there is nothing valid to inspect. Falling
            # through would either hit an unbound local or silently reuse the
            # descriptor of the previous field, so propagate the failure.
            raise
        db_field_path.append(db_field)
        if db_field.by_reference:
            # None means "no filter": every reference is collected. Checking
            # it explicitly (instead of catching the TypeError raised by
            # `x in None`) keeps unrelated TypeErrors from being mistaken for
            # "collect this field".
            if populate_db_fields is None or db_field in populate_db_fields:
                # Snapshot the stack; it keeps being mutated by the traversal.
                paths.append(copy.deepcopy(db_field_path))
        elif db_field.has_model_fields:
            _paths_to_ref_ids(
                cls=db_field.field_type,
                paths=paths,
                db_field_path=db_field_path,
                populate_db_fields=populate_db_fields,
            )
        db_field_path.pop()
    return paths
def resolve_reference_pipeline(
    cls: BaseModel, pipeline: list, populate_db_fields: list[DbField] | None
):
    """
    Append to ``pipeline`` the aggregation stages that resolve the by-reference
    fields of a model class.

    Args:
        cls (BaseModel): The model class whose references are resolved.
        pipeline (list): Existing pipeline stages. The new stages are added to
            this object with ``+=`` and the same object is returned.
        populate_db_fields (list[DbField] | None): Forwarded to the path
            discovery helper and to every nested call; restricts which
            by-reference fields are resolved (None means no restriction).

    Returns:
        list: ``pipeline`` with the reference-resolution stages appended.

    Description:
        For each path returned by ``_paths_to_ref_ids`` the function:
          1. builds the dotted alias path, adding the stages from ``unwind``
             for every list field on the way that is not itself a reference;
          2. adds the stages from ``lookup`` against the collection of the
             referenced type, using a nested pipeline built by calling this
             function recursively on that type;
          3. adds the stages from ``set_`` when the reference field is not a
             list;
          4. adds the stages from ``group_set_replace_root`` for every unwound
             path, innermost first, and finally the stages from ``unset`` for
             the temporary index fields.

        NOTE(review): no guard against cyclic references is visible in this
        function; confirm whether models that reference each other can recurse
        indefinitely.
    """
    reference_paths = _paths_to_ref_ids(
        cls=cls, paths=[], db_field_path=[], populate_db_fields=populate_db_fields
    )
    for field_chain in reference_paths:
        dotted_path = ""
        # "_id" is always the outermost grouping key; one "__unwind_<n>" index
        # field is added per unwound list on the way down.
        unwind_keys = ["_id"]
        unwound_paths = []
        ref_field: DbField = None
        for position, ref_field in enumerate(field_chain):
            alias = ref_field.field_alias
            dotted_path = f"{dotted_path}.{alias}" if dotted_path else alias

            if ref_field.is_list and not ref_field.by_reference:
                index_key = f"__unwind_{position}"
                unwind_keys.append(index_key)
                unwound_paths.append(dotted_path)
                pipeline += unwind(
                    path=dotted_path,
                    array_index=index_key,
                    preserve_empty=True,
                )

        # After the loop ref_field is the last element of the chain, i.e. the
        # by-reference field itself.
        target_collection = ref_field.field_type._collection
        nested_pipeline = resolve_reference_pipeline(
            cls=ref_field.field_type,
            pipeline=[],
            populate_db_fields=populate_db_fields,
        )
        pipeline += lookup(
            from_=target_collection,
            local_field=dotted_path,
            foreign_field="_id",
            as_=dotted_path,
            pipeline=nested_pipeline,
        )
        if not ref_field.is_list:
            pipeline += set_(local_field=dotted_path, as_=dotted_path)

        # Undo the unwinds from the innermost list outwards.
        keys_to_unset = []
        for depth, group_path in enumerate(reversed(unwound_paths)):
            cutoff = len(unwind_keys) - depth - 1
            group_id = [f"${key}" for key in unwind_keys[:cutoff]]
            active_key = unwind_keys[cutoff]
            keys_to_unset.append(active_key)
            sort_spec = dict.fromkeys(unwind_keys[1:], 1)
            pipeline += group_set_replace_root(
                to_sort=sort_spec,
                id_=group_id,
                array_index=active_key,
                field=group_path.split(".")[-1],
                path_str=group_path,
            )
        if keys_to_unset:
            pipeline += unset(fields=keys_to_unset)

    return pipeline