Coverage for pyodmongo/engines/utils.py: 100%
78 statements
« prev ^ index » next coverage.py v7.7.1, created at 2025-03-27 14:31 +0000
1from pydantic import BaseModel
2from bson import ObjectId
3from ..models.id_model import Id
4from ..models.db_model import MainBaseModel
5from ..models.db_field_info import DbField
6from ..models.db_decimal import DbDecimal
7from ..services.reference_pipeline import resolve_reference_pipeline
8from ..services.verify_subclasses import is_subclass
9from decimal import Decimal
10from bson import Decimal128
def _to_object_id(ref):
    """
    Convert a single reference value to an ObjectId.

    Args:
        ref: Either an object exposing an ``id`` attribute (e.g. a referenced
            model instance) or a raw id value accepted by ``ObjectId``.

    Returns:
        ObjectId: ``ObjectId(ref.id)`` when ``ref`` has an ``id`` attribute,
        otherwise ``ObjectId(ref)``.
    """
    try:
        ref_id = ref.id
    except AttributeError:
        # Not a model instance: treat the value itself as the id.
        return ObjectId(ref)
    return ObjectId(ref_id)


def consolidate_dict(obj: MainBaseModel, dct: dict, populate: bool):
    """
    Recursively consolidate the fields of a PyODMongo model into a MongoDB-ready dict.

    Args:
        obj (MainBaseModel): The model instance whose field values are read.
        dct (dict): Target dictionary, modified in place. Each field is written
            under its ``DbField.field_alias``.
        populate (bool): When True, fields declared ``by_reference`` are written
            as nested documents. When False they are reduced to ObjectIds.

    Returns:
        dict: ``dct`` itself, after consolidation.

    Raises:
        TypeError: If a field has no ``DbField`` class attribute and
            ``is_subclass(obj.__class__, BaseModel)`` holds. The message
            suggests switching from Pydantic's BaseModel to MainBaseModel.
        AttributeError: If a field has no ``DbField`` class attribute in any
            other situation.

    Conversion rules per field:
        * ``None`` values are stored as ``None``.
        * Model-valued fields (``has_model_fields``) are stored as ObjectIds when
          referenced and not populated. Otherwise they are converted recursively.
        * ``Id`` fields become ObjectIds, and ``Decimal``/``DbDecimal`` fields
          become ``Decimal128``. List fields are converted element-wise.
        * All other values are stored unchanged.
    """
    model_class = obj.__class__
    # Read model_fields from the class: instance access is deprecated in recent
    # Pydantic releases, and only the field names are needed here.
    for field in model_class.model_fields:
        value = getattr(obj, field)
        try:
            db_field_info: DbField = getattr(model_class, field)
        except AttributeError as err:
            if is_subclass(class_to_verify=model_class, subclass=BaseModel):
                raise TypeError(
                    f"The {obj.__class__.__name__} class inherits from Pydantic's BaseModel class. Try switching to PyODMongo's MainBaseModel class"
                ) from err
            # Without its DbField descriptor the field cannot be mapped. Re-raise
            # rather than continue with an unbound db_field_info or one left
            # over from the previous field.
            raise

        alias = db_field_info.field_alias
        is_list = db_field_info.is_list

        if value is None:
            dct[alias] = None
            continue

        if db_field_info.has_model_fields:
            if db_field_info.by_reference and not populate:
                # Convert each element independently so that a list mixing
                # model instances and raw ids is handled correctly.
                dct[alias] = (
                    [_to_object_id(ref) for ref in value]
                    if is_list
                    else _to_object_id(value)
                )
            elif is_list:
                dct[alias] = [
                    consolidate_dict(obj=item, dct={}, populate=populate)
                    for item in value
                ]
            else:
                dct[alias] = consolidate_dict(obj=value, dct={}, populate=populate)
            continue

        field_type = db_field_info.field_type
        if field_type == Id:
            if is_list:
                # NOTE(review): invalid ids are silently dropped from lists, while
                # an invalid scalar id is stored unchanged. Confirm this asymmetry.
                dct[alias] = [ObjectId(o) for o in value if ObjectId.is_valid(o)]
            else:
                dct[alias] = ObjectId(value) if ObjectId.is_valid(value) else value
        elif field_type in (Decimal, DbDecimal):
            # Round-trip through str so the exact decimal digits are preserved.
            dct[alias] = (
                [Decimal128(str(o)) for o in value]
                if is_list
                else Decimal128(str(value))
            )
        else:
            dct[alias] = value
    return dct
92def _skip_and_limit_stages(current_page: int, docs_per_page: int):
93 """
94 Add pagination stages to the aggregation pipeline.
96 Args:
97 current_page (int): The current page number.
98 docs_per_page (int): The number of documents per page.
99 """
100 max_docs_per_page = 1000
101 current_page = 1 if current_page <= 0 else current_page
102 docs_per_page = (
103 max_docs_per_page if docs_per_page > max_docs_per_page else docs_per_page
104 )
105 skip = (docs_per_page * current_page) - docs_per_page
106 skip_stage = [{"$skip": skip}]
107 limit_stage = [{"$limit": docs_per_page}]
108 return skip_stage, limit_stage
def mount_base_pipeline(
    Model,
    query: dict,
    sort: dict,
    populate: bool,
    pipeline: list | None,
    populate_db_fields: list[DbField] | None,
    paginate: int,
    current_page: int,
    docs_per_page: int,
    is_find_one: bool,
):
    """
    Build the base aggregation pipeline for a read on ``Model``'s collection.

    Args:
        Model: The model class. Its ``_pipeline`` attribute supplies model-level stages.
        query (dict): Filter used by the leading ``$match`` stage.
        sort (dict): Sort specification. An empty or ``None`` value adds no
            ``$sort`` stage.
        populate (bool): When True, the stages returned by
            ``resolve_reference_pipeline`` are appended.
        pipeline (list | None): Caller-supplied stages.
        populate_db_fields (list[DbField] | None): Forwarded to
            ``resolve_reference_pipeline``.
        paginate (int): Truthy to add ``$skip``/``$limit`` stages built by
            ``_skip_and_limit_stages``.
        current_page (int): Page number forwarded to ``_skip_and_limit_stages``.
        docs_per_page (int): Page size forwarded to ``_skip_and_limit_stages``.
        is_find_one (bool): When True and not paginating, add ``{"$limit": 1}``.

    Returns:
        list[dict]: The assembled pipeline, in the first matching form:
            1. ``pipeline`` is non-empty: ``$match`` + ``pipeline`` + ``$sort``.
            2. ``Model._pipeline`` is non-empty: ``$match`` + model stages + ``$sort``.
            3. Otherwise: ``$match`` + ``$sort`` + skip/limit stages, followed by
               the reference-population stages when ``populate`` is True.
    """
    match_stage = [{"$match": query}]
    # Truthiness (rather than ``!= {}``) also skips the stage for sort=None,
    # which would otherwise produce an invalid {"$sort": None} stage.
    sort_stage = [{"$sort": sort}] if sort else []

    # NOTE(review): in the two custom-pipeline cases below, pagination, the
    # find-one limit and reference population are not applied. Confirm this
    # is intended.
    if pipeline:
        return match_stage + list(pipeline) + sort_stage
    model_stage = Model._pipeline
    if model_stage:
        return match_stage + list(model_stage) + sort_stage

    if paginate:
        # Pagination supersedes the single-document limit of find_one.
        skip_stage, limit_stage = _skip_and_limit_stages(
            current_page=current_page, docs_per_page=docs_per_page
        )
    else:
        skip_stage = []
        limit_stage = [{"$limit": 1}] if is_find_one else []

    stages = match_stage + sort_stage + skip_stage + limit_stage
    if populate:
        stages += resolve_reference_pipeline(
            cls=Model, pipeline=[], populate_db_fields=populate_db_fields
        )
    return stages