Coverage for pyodmongo/services/model_init.py: 100%
112 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 15:08 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 15:08 +0000
1from pydantic import BaseModel
2from pymongo import IndexModel, ASCENDING, TEXT
3from typing import Any, Union, get_origin, get_args
4from types import UnionType
5from ..models.id_model import Id
6from ..models.db_field_info import DbField
7from ..services.verify_subclasses import is_subclass
def __ordinary_index_and_text_keys(
    cls: BaseModel,
    indexes: list,
    text_keys: list,
    indexes_path: list,
    text_indexes_path: list,
):
    """
    Recursively collect ordinary and text index keys declared on a model's fields.

    Args:
        cls (BaseModel): Model class whose fields are inspected. Each field is
            read back as a ``DbField`` class attribute of the same name.
        indexes (list): Accumulator that receives one ascending ``IndexModel`` per
            field whose ``json_schema_extra`` has a truthy ``"index"`` entry.
        text_keys (list): Accumulator that receives a ``(dotted_name, TEXT)`` pair
            per field whose ``json_schema_extra`` has a truthy ``"text_index"`` entry.
        indexes_path (list): Stack of field aliases leading to ``cls``; joined
            with ``"."`` to name indexes on nested fields.
        text_indexes_path (list): Same as ``indexes_path``, used for text keys.

    Returns:
        tuple[list, list]: The ``indexes`` and ``text_keys`` lists passed in,
        mutated in place.

    Raises:
        TypeError: If a field has no ``DbField`` class attribute and ``cls`` passes
            the ``is_subclass(..., BaseModel)`` check (i.e. it is a plain Pydantic
            model rather than one built on PyODMongo's MainBaseModel).
        AttributeError: If a field has no ``DbField`` class attribute and the
            ``BaseModel`` check fails.
    """
    for key, model_field in cls.model_fields.items():
        try:
            db_field_info: DbField = getattr(cls, key)
        except AttributeError as err:
            if is_subclass(class_to_verify=cls, subclass=BaseModel):
                raise TypeError(
                    f"The {cls.__name__} class inherits from Pydantic's BaseModel class. Try switching to PyODMongo's MainBaseModel class"
                ) from err
            # Without a DbField there is nothing to inspect; continuing would use an
            # unbound (or the previous field's stale) ``db_field_info``.
            raise

        # Embedded sub-models (not stored by reference, not lists): descend with this
        # field's alias pushed so nested keys become dotted paths like "parent.child".
        if (
            db_field_info.has_model_fields
            and not db_field_info.by_reference
            and not db_field_info.is_list
        ):
            indexes_path.append(db_field_info.field_alias)
            text_indexes_path.append(db_field_info.field_alias)
            __ordinary_index_and_text_keys(
                cls=db_field_info.field_type,
                indexes=indexes,
                text_keys=text_keys,
                indexes_path=indexes_path,
                text_indexes_path=text_indexes_path,
            )
            indexes_path.pop()
            text_indexes_path.pop()

        extra = model_field.json_schema_extra
        # Pydantic also accepts a callable for json_schema_extra; only a non-empty
        # dict can carry the index flags read below.
        if not isinstance(extra, dict) or not extra:
            continue
        is_index = extra.get("index") or False
        is_unique = extra.get("unique") or False
        is_text_index = extra.get("text_index") or False

        if is_index:
            indexes_path.append(db_field_info.field_alias)
            index_name = ".".join(indexes_path)
            # Imported lazily and only when needed; presumably this avoids a
            # circular import with the models package (TODO confirm).
            from ..models.db_model import DbModel

            # TODO: This section requires a more efficient and robust solution.
            # A unique flag on a nested field (path longer than one alias) is
            # dropped when the class owning that field is a DbModel subclass.
            if (
                is_unique
                and len(indexes_path) > 1
                and is_subclass(class_to_verify=cls, subclass=DbModel)
            ):
                is_unique = False
            indexes.append(
                IndexModel([(index_name, ASCENDING)], name=index_name, unique=is_unique)
            )
            indexes_path.pop()

        if is_text_index:
            text_indexes_path.append(db_field_info.field_alias)
            text_keys.append((".".join(text_indexes_path), TEXT))
            text_indexes_path.pop()
    return indexes, text_keys
def resolve_indexes(cls: BaseModel):
    """
    Build the MongoDB index models declared on a model class's fields.

    Args:
        cls (BaseModel): The model class whose indexes are to be resolved.

    Returns:
        list[IndexModel]: The ordinary (ascending) index models collected from the
        class fields, followed by a single text index named ``"texts"`` when at
        least one field is flagged as a text index.

    Description:
        Field flags (index, unique, text_index) are gathered by
        ``__ordinary_index_and_text_keys``. When text keys exist, the class's
        ``_default_language`` is passed to the text index as ``default_language``
        only if it is truthy.
    """
    indexes, text_keys = __ordinary_index_and_text_keys(
        cls=cls,
        indexes=[],
        text_keys=[],
        indexes_path=[],
        text_indexes_path=[],
    )

    if not text_keys:
        return indexes

    # All text keys go into one compound text index.
    text_index_options = {"name": "texts"}
    default_language = cls._default_language
    if default_language:
        text_index_options["default_language"] = default_language
    indexes.append(IndexModel(text_keys, **text_index_options))
    return indexes
113def _is_union(field_type: Any):
114 return get_origin(field_type) is UnionType or get_origin(field_type) is Union
117def _has_a_list_in_union(field_type: Any):
118 for ft in get_args(field_type):
119 if get_origin(ft) is list:
120 return ft
def _union_collector_info(field, args):
    """
    Pick the concrete type and the by-reference flag for a union annotation.

    Args:
        field (str): Field name; a field named ``"id"`` is never treated as a reference.
        args: The union type (e.g. ``Id | Model``) whose members are inspected.

    Returns:
        tuple: ``(field_type, by_reference)``. ``by_reference`` is True only when
        ``Id`` is one of the union members, ``field != "id"``, and at least one
        member has ``model_fields``; in that case ``field_type`` is the first such
        model member. Otherwise ``field_type`` is the union's first member.
    """
    members = get_args(args)
    model_members = [member for member in members if hasattr(member, "model_fields")]
    by_reference = (Id in members) and (field != "id") and bool(model_members)
    field_type = model_members[0] if by_reference else members[0]
    return field_type, by_reference
def field_annotation_infos(field, field_info) -> DbField:
    """
    Build a DbField describing how a model field is stored in the database.

    Args:
        field (str): The name of the field in the model.
        field_info (Field): Pydantic field information; its ``annotation`` and
            ``alias`` attributes are read.

    Returns:
        DbField: Metadata for the field: name, alias (``"_id"`` for the ``id``
        field, otherwise the Pydantic alias or the field name), resolved type,
        whether it is stored by reference, whether it is a list, and whether the
        resolved type has ``model_fields``.

    Description:
        A union that contains a ``list[...]`` member is treated as that list. For
        lists, the item type is resolved (unions inside the list go through
        ``_union_collector_info``); for non-list unions the union itself is
        resolved the same way. Any other annotation is used as the type directly.
    """
    annotation = field_info.annotation
    field_type = annotation
    by_reference = False

    # e.g. ``list[Model] | None`` is handled as ``list[Model]``.
    if _is_union(field_type=annotation):
        list_member = _has_a_list_in_union(field_type=annotation)
        if list_member is not None:
            annotation = list_member

    is_list = get_origin(annotation) is list
    if is_list:
        item_type = get_args(annotation)[0]
        if _is_union(item_type):
            field_type, by_reference = _union_collector_info(
                field=field, args=item_type
            )
        else:
            field_type = item_type
    elif _is_union(annotation):
        field_type, by_reference = _union_collector_info(field=field, args=annotation)

    # MongoDB stores the primary key under "_id".
    field_alias = "_id" if field == "id" else (field_info.alias or field)
    return DbField(
        field_name=field,
        field_alias=field_alias,
        field_type=field_type,
        by_reference=by_reference,
        is_list=is_list,
        has_model_fields=hasattr(field_type, "model_fields"),
    )
def _recursice_db_fields_info(db_field_info: DbField, path: list) -> DbField:
    """
    Attach nested DbField entries for every sub-field of a model-typed field.

    Args:
        db_field_info (DbField): The field to expand; it is only expanded when its
            ``has_model_fields`` flag is set.
        path (list): Stack of aliases leading to ``db_field_info``. It is used as a
            scratch stack and is restored to its original contents on return.

    Returns:
        DbField: ``db_field_info`` itself, mutated in place. Each sub-field's DbField
        is set on it as an attribute named after the sub-field, with ``path_str``
        set to the dotted alias path (e.g. ``"address.city"``).

    NOTE(review): there is no guard against self-referencing models; a model that
    contains its own type would presumably recurse until RecursionError. Verify.
    """
    if not db_field_info.has_model_fields:
        return db_field_info

    for sub_field, sub_field_info in db_field_info.field_type.model_fields.items():
        child = field_annotation_infos(field=sub_field, field_info=sub_field_info)
        path.append(child.field_alias)
        child.path_str = ".".join(path)
        expanded_child = _recursice_db_fields_info(db_field_info=child, path=path)
        setattr(db_field_info, sub_field, expanded_child)
        path.pop()
    return db_field_info
def resolve_class_fields_db_info(cls: BaseModel):
    """
    Resolve database field information for every field of a class and store it on the class.

    Args:
        cls (BaseModel): The class whose fields are to be resolved.

    Description:
        For each field ``<name>`` in ``cls.model_fields``, builds its DbField (with
        ``path_str`` set to the field's alias), expands nested DbFields for
        model-typed fields, and stores the result on the class as the attribute
        ``<name>__pyodmongo``.
    """
    for field_name, model_field in cls.model_fields.items():
        root_info = field_annotation_infos(field=field_name, field_info=model_field)
        alias = root_info.field_alias
        root_info.path_str = alias
        expanded = _recursice_db_fields_info(db_field_info=root_info, path=[alias])
        setattr(cls, f"{field_name}__pyodmongo", expanded)
231# TODO finish resolve_single_db_field
232# def resolve_single_db_field(value: Any) -> DbField:
233# by_reference = False
234# return value
237# def resolve_db_fields(bases: tuple[Any], db_fields: dict):
238# for base in bases:
239# if base is object:
240# continue
241# base_annotations = base.__dict__.get("__annotations__")
242# for base_field, value in base_annotations.items():
243# if base_field not in db_fields.keys() and not base_field.startswith("_"):
244# db_fields[base_field] = resolve_single_db_field(value=value)
245# resolve_db_fields(bases=base.__bases__, db_fields=db_fields)
246# return db_fields