Coverage for pyodmongo/engines/utils.py: 100%
57 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 15:08 +0000
1from pydantic import BaseModel
2from bson import ObjectId
3from ..models.id_model import Id
4from ..models.db_model import MainBaseModel
5from ..models.db_field_info import DbField
6from ..services.reference_pipeline import resolve_reference_pipeline
7from ..services.verify_subclasses import is_subclass
def _reference_to_object_id(ref):
    """Convert one by-reference value (a model instance or a raw id) to an ObjectId."""
    try:
        ref_id = ref.id
    except AttributeError:
        # No ``id`` attribute: the value itself is taken to be the id.
        return ObjectId(ref)
    # NOTE(review): bson's ObjectId(None) generates a brand-new id, so a referenced
    # model whose ``id`` is still None is stored with a fresh, unrelated ObjectId.
    return ObjectId(ref_id)


def consolidate_dict(obj: MainBaseModel, dct: dict):
    """
    Recursively copy the fields of a PyODMongo model instance into ``dct``.

    Each field is written under its ``DbField.field_alias`` key:

    * nested models (``has_model_fields``) stored by reference become an
      ObjectId (or a list of ObjectIds); embedded ones are consolidated
      recursively into a nested dict (or a list of dicts);
    * ``Id``-typed fields are converted to ObjectId when ``ObjectId.is_valid``
      accepts the value;
    * every other value is copied unchanged.

    Args:
        obj (MainBaseModel): The PyODMongo model instance to read from.
        dct (dict): Destination dictionary; it is modified in place.

    Returns:
        dict: ``dct`` itself, after consolidation.

    Raises:
        TypeError: If a field has no ``DbField`` class attribute and the class is
            a plain Pydantic ``BaseModel`` subclass (i.e. not a ``MainBaseModel``).
        AttributeError: If a field has no ``DbField`` class attribute for any
            other reason.
    """
    model_cls = type(obj)
    # Read ``model_fields`` from the class: instance access is deprecated in pydantic.
    for field in model_cls.model_fields:
        value = getattr(obj, field)
        try:
            db_field_info: DbField = getattr(model_cls, field)
        except AttributeError as err:
            if is_subclass(class_to_verify=model_cls, subclass=BaseModel):
                raise TypeError(
                    f"The {obj.__class__.__name__} class inherits from Pydantic's BaseModel class. Try switching to PyODMongo's MainBaseModel class"
                ) from err
            # Never continue without this field's metadata: db_field_info would be
            # unbound, or stale from the previous field (writing under the wrong key).
            raise
        alias = db_field_info.field_alias
        has_model_fields = db_field_info.has_model_fields
        is_list = db_field_info.is_list
        by_reference = db_field_info.by_reference

        if has_model_fields:
            if value is None:
                dct[alias] = None
                continue
            if not is_list:
                if by_reference:
                    dct[alias] = _reference_to_object_id(value)
                else:
                    dct[alias] = {}
                    consolidate_dict(obj=value, dct=dct[alias])
            elif by_reference:
                # Per-element conversion so lists mixing models and raw ids work.
                dct[alias] = [_reference_to_object_id(o) for o in value]
            else:
                dct[alias] = [consolidate_dict(obj=v, dct={}) for v in value]
        elif db_field_info.field_type == Id and value is not None:
            if is_list:
                # NOTE(review): invalid ids are dropped from lists but kept as-is in
                # the scalar branch below; confirm this asymmetry is intended.
                dct[alias] = [ObjectId(o) for o in value if ObjectId.is_valid(o)]
            else:
                dct[alias] = ObjectId(value) if ObjectId.is_valid(value) else value
        else:
            dct[alias] = value
    return dct
def mount_base_pipeline(
    Model,
    query: dict,
    sort: dict,
    populate: bool,
    populate_db_fields: list[DbField] | None,
):
    """
    Build the base MongoDB aggregation pipeline for ``Model``.

    The pipeline always starts with a ``$match`` on ``query`` and ends with a
    ``$sort`` on ``sort`` when ``sort`` has keys. In between goes either the
    model's own ``_pipeline`` or, if it is empty, the reference-population
    stages (only when ``populate`` is true).

    Args:
        Model (type): PyODMongo model class; its ``_pipeline`` attribute is read.
        query (dict): Filter document for the ``$match`` stage.
        sort (dict): Sort specification; an empty or None value adds no ``$sort`` stage.
        populate (bool): Whether to add reference-population stages. Ignored
            when the model defines a non-empty ``_pipeline``.
        populate_db_fields (list[DbField] | None): Forwarded to
            ``resolve_reference_pipeline`` to restrict which references are populated.

    Returns:
        list: The assembled list of aggregation stages.
    """
    match_stage = [{"$match": query}]
    # MongoDB rejects a $sort stage without keys, so omit it for {} (and None).
    sort_stage = [{"$sort": sort}] if sort else []
    model_stage = Model._pipeline
    if model_stage:
        # A model-defined pipeline takes precedence over reference population.
        return match_stage + list(model_stage) + sort_stage
    if populate:
        reference_stage = resolve_reference_pipeline(
            cls=Model, pipeline=[], populate_db_fields=populate_db_fields
        )
        return match_stage + reference_stage + sort_stage
    return match_stage + sort_stage