Coverage for pyodmongo/engines/utils.py: 100%
76 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 15:48 +0000
1from pydantic import BaseModel
2from bson import ObjectId
3from ..models.id_model import Id
4from ..models.db_model import MainBaseModel
5from ..models.db_field_info import DbField
6from ..models.db_decimal import DbDecimal
7from ..services.reference_pipeline import resolve_reference_pipeline
8from ..services.verify_subclasses import is_subclass
9from decimal import Decimal
10from bson import Decimal128
def _reference_to_object_id(ref):
    """Return the ``ObjectId`` for a reference given as a model instance or a raw id."""
    # Objects exposing ``.id`` (model instances) contribute that attribute;
    # anything else is passed to ObjectId as the id itself.
    try:
        ref_id = ref.id
    except AttributeError:
        ref_id = ref
    # NOTE(review): if ``ref.id`` is None this calls ObjectId(None), which in
    # bson generates a brand-new id rather than failing — confirm intended.
    return ObjectId(ref_id)


def consolidate_dict(obj: MainBaseModel, dct: dict, populate: bool):
    """
    Recursively consolidate the fields of a PyODMongo model into a dictionary.

    Each field is written under its ``DbField.field_alias``. Nested models and
    lists of models are consolidated recursively, reference fields are
    collapsed to ``ObjectId`` values (unless ``populate`` is True), ``Id``
    fields are converted to ``ObjectId`` when valid, and ``Decimal`` /
    ``DbDecimal`` fields are converted to ``Decimal128``.

    Args:
        obj (MainBaseModel): The PyODMongo model instance to read field values from.
        dct (dict): The dictionary that receives the consolidated data. It is
            modified in place.
        populate (bool): When True, fields declared ``by_reference`` are written
            as nested (recursively consolidated) documents instead of being
            collapsed to ``ObjectId`` values.

    Returns:
        dict: ``dct`` itself, after consolidation.

    Raises:
        TypeError: If a field has no ``DbField`` class attribute and
            ``is_subclass`` reports the class as a Pydantic ``BaseModel``
            subclass (the message suggests using ``MainBaseModel`` instead).
        AttributeError: If a field has no ``DbField`` class attribute and the
            ``BaseModel`` check above does not apply.
    """
    model_cls = obj.__class__
    for field in model_cls.model_fields:
        value = getattr(obj, field)
        try:
            db_field_info: DbField = getattr(model_cls, field)
        except AttributeError:
            if is_subclass(class_to_verify=model_cls, subclass=BaseModel):
                raise TypeError(
                    f"The {model_cls.__name__} class inherits from Pydantic's BaseModel class. Try switching to PyODMongo's MainBaseModel class"
                )
            # Without this re-raise the loop would continue with an unbound
            # (first field) or stale (previous field's) ``db_field_info``.
            raise

        alias = db_field_info.field_alias
        is_list = db_field_info.is_list
        # When populating, references are embedded as documents, not ids.
        by_reference = db_field_info.by_reference and not populate

        if db_field_info.has_model_fields:
            if value is None:
                dct[alias] = None
            elif by_reference:
                # Converted per element so lists mixing model instances and
                # raw ids are handled.
                dct[alias] = (
                    [_reference_to_object_id(ref) for ref in value]
                    if is_list
                    else _reference_to_object_id(value)
                )
            elif is_list:
                dct[alias] = [
                    consolidate_dict(obj=item, dct={}, populate=populate)
                    for item in value
                ]
            else:
                dct[alias] = consolidate_dict(obj=value, dct={}, populate=populate)
            continue

        field_type = db_field_info.field_type
        if value is None:
            dct[alias] = value
        elif field_type == Id:
            if is_list:
                # NOTE(review): invalid ids are silently dropped from lists,
                # whereas an invalid scalar id is stored unchanged — confirm
                # this asymmetry is intended.
                dct[alias] = [ObjectId(o) for o in value if ObjectId.is_valid(o)]
            else:
                dct[alias] = ObjectId(value) if ObjectId.is_valid(value) else value
        elif field_type in (Decimal, DbDecimal):
            # Decimal128 is built from each value's string representation.
            if is_list:
                dct[alias] = [Decimal128(str(o)) for o in value]
            else:
                dct[alias] = Decimal128(str(value))
        else:
            dct[alias] = value
    return dct
92def _skip_and_limit_stages(current_page: int, docs_per_page: int):
93 """
94 Add pagination stages to the aggregation pipeline.
96 Args:
97 current_page (int): The current page number.
98 docs_per_page (int): The number of documents per page.
99 """
100 max_docs_per_page = 1000
101 current_page = 1 if current_page <= 0 else current_page
102 docs_per_page = (
103 max_docs_per_page if docs_per_page > max_docs_per_page else docs_per_page
104 )
105 skip = (docs_per_page * current_page) - docs_per_page
106 skip_stage = [{"$skip": skip}]
107 limit_stage = [{"$limit": docs_per_page}]
108 return skip_stage, limit_stage
def mount_base_pipeline(
    Model,
    query: dict,
    sort: dict,
    populate: bool,
    pipeline: list | None,
    populate_db_fields: list[DbField] | None,
    paginate: int,
    current_page: int,
    docs_per_page: int,
    no_paginate_limit: int | None,
):
    """
    Mount the base MongoDB aggregation pipeline used by find operations.

    The stages are always assembled in this order::

        $match -> (custom pipeline | reference stages | nothing)
               -> $sort -> $skip -> $limit

    Args:
        Model: The PyODMongo model class being queried.
        query: The MongoDB query document for the ``$match`` stage.
        sort: The sort specification for the ``$sort`` stage. An empty (or
            None) value adds no ``$sort`` stage.
        populate: If True, and no custom/default pipeline is in effect, adds the
            stages produced by ``resolve_reference_pipeline`` to populate
            referenced documents.
        pipeline: Optional aggregation stages inserted right after ``$match``.
            If falsy, ``Model._pipeline`` is used instead. When either is
            non-empty, ``populate`` and ``populate_db_fields`` are ignored.
        populate_db_fields: Specific ``DbField``s to populate; only used when
            the reference stages are generated.
        paginate: If truthy, appends ``$skip`` and ``$limit`` stages for
            ``current_page`` / ``docs_per_page``.
        current_page: The page number to retrieve when paginating.
        docs_per_page: The number of documents per page when paginating.
        no_paginate_limit: The ``$limit`` applied when not paginating. None
            or 0 adds no ``$limit`` stage.

    Returns:
        list[dict]: A new list holding the complete aggregation pipeline.
    """
    match_stage = [{"$match": query}]
    sort_stage = [{"$sort": sort}] if sort else []

    if paginate:
        skip_stage, limit_stage = _skip_and_limit_stages(
            current_page=current_page, docs_per_page=docs_per_page
        )
    else:
        skip_stage = []
        limit_stage = [{"$limit": no_paginate_limit}] if no_paginate_limit else []

    # An explicit pipeline takes precedence over the model's default one; if
    # either yields stages, reference population is not generated.
    custom_stages = pipeline or Model._pipeline
    if custom_stages:
        middle_stages = custom_stages
    elif populate:
        middle_stages = resolve_reference_pipeline(
            cls=Model, pipeline=[], populate_db_fields=populate_db_fields
        )
    else:
        middle_stages = []

    return match_stage + middle_stages + sort_stage + skip_stage + limit_stage