Strawberry GraphQL and FastAPI. So it turns out that Pydantic is still not needed?
The essence of these functions is to generate the SQL query so that we fetch from the database only what the client requested. This way we save on the amount of data read from the database and get the coveted optimization that a fixed-shape REST API endpoint does not give us. I wrote these functions so that they can load database objects and their relationships to arbitrary depth thanks to recursion. Here are the same functions, slightly extended:
from sqlalchemy import select
from sqlalchemy.orm import joinedload, load_only
from strawberry.types import Info
from strawberry.types.nodes import SelectedField

# convert_camel_case (GraphQL camelCase name -> model snake_case attribute)
# is a project helper that is not shown in this article.


def flatten(items):
    # Recursively merge nested lists into one flat list.
    if not items:
        return items
    if isinstance(items[0], list):
        return flatten(items[0]) + flatten(items[1:])
    return items[:1] + flatten(items[1:])


def get_relation_options(relation: dict, prev_sql=None):
    # relation looks like {ModelAttribute: {'fields': [...], 'relations': [...]}}
    key, val = next(iter(relation.items()))
    fields = val['fields']
    relations = val['relations']
    if prev_sql is not None:
        # Continue the loader chain started by the parent relationship.
        sql = prev_sql.joinedload(key).load_only(*fields)
    else:
        sql = joinedload(key).load_only(*fields)
    if len(relations) == 0:
        return sql
    if len(relations) == 1:
        return get_relation_options(relations[0], sql)
    # Several nested relationships: build one chain per branch.
    result = []
    for i in relations:
        rels = get_relation_options(i, sql)
        if isinstance(rels, list):
            result.extend(rels)
        else:
            result.append(rels)
    return result


def get_only_selected_fields(
    db_baseclass_name,  # the root SQLAlchemy model the query starts from
    info: Info,
):
    def process_items(items: list[SelectedField], db_baseclass):
        # Split the selected items into plain columns ("fields") and relationships ("relations").
        fields, relations = [], []
        for item in items:
            if item.name == '__typename':  # item.name is the field name from the GraphQL query
                continue
            try:
                model_attr = getattr(db_baseclass, convert_camel_case(item.name))
            except AttributeError:
                continue  # the field has no counterpart on the model
            if not item.selections:
                fields.append(model_attr)  # no sub-selection: a plain column
                continue
            # A sub-selection means a relationship: recurse into the related model.
            related_class = model_attr.property.mapper.class_
            relations.append({model_attr: process_items(item.selections, related_class)})
        return dict(fields=fields, relations=relations)

    # selected_fields[0] is the field being resolved; its selections are the requested sub-fields.
    selections = info.selected_fields[0].selections
    options = process_items(selections, db_baseclass_name)
    fields = [load_only(*options['fields'])] if options['fields'] else []
    query_options = [
        *fields,
        # Relationship loader options, flattened into a single list
        *flatten([get_relation_options(i) for i in options['relations']]),
    ]
    return select(db_baseclass_name).options(*query_options)
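To see the shape of the structure that process_items builds without a database, here is a stdlib-only sketch of the same split. It is an approximation with a few assumptions. The GraphQL selection is modeled as a plain nested dict instead of Strawberry's SelectedField objects. split_selection is a hypothetical name, and it returns field names where the real code returns SQLAlchemy model attributes. It also skips the camelCase conversion and the check against the model.

```python
from typing import Any

# Hypothetical stand-in for a GraphQL selection set: a nested dict where
# None marks a scalar field and a dict marks a field with its own sub-selection.
Selection = dict[str, Any]


def split_selection(selection: Selection) -> dict[str, list]:
    """Stdlib mirror of process_items: separate columns from relationships."""
    fields: list[str] = []
    relations: list[dict[str, Any]] = []
    for name, sub in selection.items():
        if name == "__typename":  # introspection field, never a column
            continue
        if sub is None:  # no sub-selection -> plain column
            fields.append(name)
        else:  # sub-selection -> relationship, recurse one level deeper
            relations.append({name: split_selection(sub)})
    return {"fields": fields, "relations": relations}


# A users query with nested groups and category, as a nested dict.
query: Selection = {
    "__typename": None,
    "id": None,
    "name": None,
    "email": None,
    "groups": {"id": None, "name": None, "category": {"id": None, "name": None}},
}
print(split_selection(query))
```

Each relationship becomes a one-key dict holding its own fields and relations. That is the input get_relation_options unpacks with next(iter(relation.items())).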
get_only_selected_fields and its helpers are quite hard to read, but you can figure them out in about five minutes. I recommend playing with this code yourself and looking at the generated SQL; that is the quickest way to understand what it does. Note that your SQLAlchemy models must declare all relationships, because the code discovers them through the model attributes. If you are too lazy to do all this, here is a simple example. Say we have a query like:
{
  users {
    id
    name
    username
    email
    groups {
      id
      name
      category {
        id
        name
      }
    }
  }
}
From it we automatically get an SQLAlchemy query like:
select(User).options(
    load_only(User.id, User.name, User.username, User.email),
    joinedload(User.groups).load_only(
        Group.id, Group.name
    ).joinedload(Group.category).load_only(
        Category.id, Category.name
    ),
)
On top of this query we can then add whatever filtering we need (for example with .where()). Next, we move on to the functions that turn SQLAlchemy models into Strawberry types.

The first function, which also looks rather scary, turns SQLAlchemy models into plain dict objects, including their loaded relationships.
from sqlalchemy.orm import class_mapper


def get_dict_object(model):
    # Lists: convert every element.
    if isinstance(model, list):
        return [get_dict_object(i) for i in model]
    # Plain dicts: convert every list-valued entry and keep the other values as they are.
    if isinstance(model, dict):
        return {
            k: [get_dict_object(i) for i in v] if isinstance(v, list) else v
            for k, v in model.items()
        }
    mapper = class_mapper(model.__class__)
    # Only loaded attributes live in model.__dict__, so this skips
    # columns excluded by load_only instead of lazy-loading them.
    out = {
        col.key: getattr(model, col.key)
        for col in mapper.columns
        if col.key in model.__dict__
    }
    for name, relation in mapper.relationships.items():
        if name not in model.__dict__:
            continue  # relationship was not loaded (not joined): skip it
        try:
            related_obj = getattr(model, name)
        except AttributeError:
            continue
        if related_obj is None:
            out[name] = None
        elif relation.uselist:
            out[name] = [get_dict_object(child) for child in related_obj]
        else:
            out[name] = get_dict_object(related_obj)
    return out
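The part of get_dict_object that does the real work is the check against model.__dict__. SQLAlchemy keeps the values of loaded attributes in the instance's __dict__. A column skipped by load_only, or a relationship that was not joined, is simply absent there, so skipping it avoids an extra lazy-load query. The stdlib-only sketch below imitates that behavior with plain classes. COLUMNS, RELATIONS, make and to_dict_loaded are hypothetical stand-ins for the mapper metadata, not SQLAlchemy APIs.

```python
from typing import Any


# Hypothetical stand-ins for mapped classes: COLUMNS plays the role of
# mapper.columns, RELATIONS maps relationship name -> uselist.
class Category:
    COLUMNS = ("id", "name")
    RELATIONS: dict[str, bool] = {}


class Group:
    COLUMNS = ("id", "name")
    RELATIONS = {"category": False}


class User:
    COLUMNS = ("id", "name", "email")
    RELATIONS = {"groups": True}


def make(cls: type, **loaded: Any) -> Any:
    """Create an instance whose __dict__ holds only the 'loaded' attributes."""
    obj = cls()
    obj.__dict__.update(loaded)
    return obj


def to_dict_loaded(obj: Any) -> Any:
    if isinstance(obj, list):
        return [to_dict_loaded(o) for o in obj]
    state = obj.__dict__  # only what was actually loaded lives here
    out = {c: state[c] for c in type(obj).COLUMNS if c in state}
    for name, uselist in type(obj).RELATIONS.items():
        if name not in state:
            continue  # relationship never loaded: leave it out entirely
        value = state[name]
        if value is None:
            out[name] = None
        elif uselist:
            out[name] = [to_dict_loaded(child) for child in value]
        else:
            out[name] = to_dict_loaded(value)
    return out


category = make(Category, id=1, name="Admins")
group = make(Group, id=10, name="Staff", category=category)
user = make(User, id=7, name="Ann", groups=[group])  # email was not selected
print(to_dict_loaded(user))
```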
Next comes the scariest part of the article: this code looks truly frightening (God forgive me for posting it on Habr):
import enum
from datetime import datetime
from types import GenericAlias, UnionType

import strawberry
# Recent Strawberry releases define these in strawberry.types.base;
# older releases exposed them from strawberry.type.
from strawberry.types.base import StrawberryList, StrawberryOptional


def orm_to_strawberry_step(item: dict, current_strawberry_type):
    annots = current_strawberry_type.__annotations__
    temp = {}
    for k, v in item.items():
        if k not in annots:
            continue  # the Strawberry type does not expose this column
        current_type = annots[k]
        if isinstance(v, (str, int, float, datetime)):
            temp[k] = v
            continue
        if isinstance(v, enum.Enum):
            # Look the member up by value; Enum[...] would look it up by name.
            temp[k] = strawberry.enum(v.__class__)(v.value)
            continue
        # Unwrap Optional / X | None and list[...] to reach the nested Strawberry type.
        if isinstance(current_type, StrawberryOptional):
            current_type = current_type.of_type
        if isinstance(current_type, UnionType):
            current_type = current_type.__args__[0]
        if isinstance(current_type, StrawberryList):
            current_type = current_type.of_type
        if isinstance(current_type, GenericAlias):
            current_type = current_type.__args__[0]
        if isinstance(v, list):
            temp[k] = [orm_to_strawberry_step(i, current_type) for i in v]
        elif isinstance(v, dict):
            temp[k] = orm_to_strawberry_step(v, current_type)
        else:
            temp[k] = v  # None and any other scalar (date, Decimal, UUID, ...)
    return current_strawberry_type(**temp)


def orm_to_strawberry(input_data, strawberry_type):
    # Accepts a single ORM object or a list of them (get_dict_object is defined above).
    if isinstance(input_data, list):
        return [orm_to_strawberry_step(get_dict_object(item), strawberry_type) for item in input_data]
    return orm_to_strawberry_step(get_dict_object(input_data), strawberry_type)
What is worth saying about orm_to_strawberry is that we pass it our ORM object (or a list of them), which get_dict_object turns into a dict, together with the Strawberry type it should be converted into. It walks all nested models (relationships) and fills the matching nested Strawberry type with each dict. Again, to understand the code you have to sit with it for a bit. As a result we get deeply nested types with sub-types that are fully valid for Strawberry and ready to be returned to the user.

Finally, here is a simple function that partially replicates what Pydantic gives us with the .dict() method:
import copy


def _to_dict(obj):
    # Note: this rewrites obj.__dict__ in place, which is why
    # strawberry_to_dict works on a deep copy.
    if isinstance(obj, (list, tuple)):
        return [_to_dict(i) for i in obj]
    if not hasattr(obj, '__dict__'):
        return obj  # plain value: str, int, datetime, ...
    temp = obj.__dict__
    for key, value in temp.items():
        if hasattr(value, '_enum_definition') or isinstance(value, bytes):
            continue  # keep Strawberry enum members and bytes as they are
        elif hasattr(value, '__dict__'):
            temp[key] = _to_dict(value)  # nested Strawberry object
        elif isinstance(value, list):
            temp[key] = [_to_dict(i) for i in value]
    return temp


def strawberry_to_dict(
    strawberry_model,
    exclude_none: bool = False,
    exclude: set | None = None,
):
    deep_copy = copy.deepcopy(strawberry_model)
    dict_obj = _to_dict(deep_copy)
    result_dict = {**dict_obj}
    # exclude and exclude_none only apply to top-level keys.
    for k, v in dict_obj.items():
        if exclude and k in exclude:
            result_dict.pop(k, None)
        if exclude_none and v is None:
            result_dict.pop(k, None)
    return result_dict
I hope someone reads the article to the end, because I tried to pack as much value into it as possible. The code is not on GitHub and probably never will be; it is literally 7 simple functions :)