Strawberry GraphQL and FastAPI: so Pydantic is not needed after all?

One interesting file on GitHub prompted me to try to go further.

The idea behind these functions is to build the SQL query so that the database returns only the fields the client actually requested. This reduces the amount of data fetched from the database and gives us the coveted optimization that a typical REST API lacks. I extended these functions so that, thanks to recursion, they can load database objects and their relationships to arbitrary depth. Here are the same functions, slightly expanded:

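from sqlalchemy import select
from sqlalchemy.orm import joinedload, load_only
from strawberry.types import Info
from strawberry.types.nodes import SelectedField
# Assumption: convert_camel_case turns GraphQL camelCase names into snake_case attribute names
from strawberry.utils.str_converters import to_snake_case as convert_camel_case


def flatten(items):
    # Recursively flatten nested lists of loader options into one flat list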
    if not items:
        return items
    if isinstance(items[0], list):
        return flatten(items[0]) + flatten(items[1:])
    return items[:1] + flatten(items[1:])


def get_relation_options(relation: dict, prev_sql=None):
    # Build joinedload(...).load_only(...) chains; returns a single option or a list of them
    key, val = next(iter(relation.items()))
    fields = val['fields']
    relations = val['relations']
    if prev_sql:
        sql = prev_sql.joinedload(key).load_only(*fields)
    else:
        sql = joinedload(key).load_only(*fields)
    if len(relations) == 0:
        return sql
    if len(relations) == 1:
        return get_relation_options(relations[0], sql)
    result = []
    for i in relations:
        rels = get_relation_options(i, sql)
        if isinstance(rels, list):  # several sibling relations returned a list of chains
            for r in rels:
                result.append(r)
        else:
            result.append(rels)
    return result


def get_only_selected_fields(
    db_baseclass_name,  # our root SQLAlchemy model, the one the query is built from
    info: Info,
):
    def process_items(items: list[SelectedField], db_baseclass):  # split the selected items into plain fields and relations for further processing
        fields, relations = [], []
        for item in items:
            if item.name == '__typename':  # item.name is the field name from the GraphQL query
                continue
            try:
                relation_name = getattr(db_baseclass, convert_camel_case(item.name))
            except AttributeError:
                continue
            if not len(item.selections):
                fields.append(relation_name)
                continue
            related_class = relation_name.property.mapper.class_
            relations.append({relation_name: process_items(item.selections, related_class)})
        return dict(fields=fields, relations=relations)

    selections = info.selected_fields[0].selections
    options = process_items(selections, db_baseclass_name)

    fields = [load_only(*options['fields'])] if len(options['fields']) else []

    query_options = [
        *fields,
        *flatten([get_relation_options(i) for i in options['relations']])  # relations have already been separated from plain fields here
    ]

    return select(db_baseclass_name).options(*query_options)
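To see what process_items does without a database or a GraphQL server, here is a stdlib-only sketch of the same split. It is an illustration under assumptions, not the code above: Field is a hypothetical stand-in for Strawberry's SelectedField, and the model is described by a hypothetical nested dict (attribute name mapped to None for a column, or to the related model's description for a relationship) instead of a real SQLAlchemy mapper.

```python
import re
from dataclasses import dataclass, field
from typing import List


@dataclass
class Field:
    """Hypothetical stand-in for Strawberry's SelectedField: a name plus sub-selections."""
    name: str
    selections: List["Field"] = field(default_factory=list)


def to_snake_case(name: str) -> str:
    # "createdAt" -> "created_at"
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


def split_selection(items, model):
    """Split a GraphQL selection into scalar fields and nested relations.

    `model` is a hypothetical description of a mapped class: attribute name ->
    None for a column, or the related model's description for a relationship.
    """
    fields, relations = [], []
    for item in items:
        if item.name == "__typename":
            continue
        attr = to_snake_case(item.name)
        if attr not in model:
            continue  # not a model attribute, e.g. a field served by a custom resolver
        if not item.selections:
            fields.append(attr)
        else:
            relations.append({attr: split_selection(item.selections, model[attr])})
    return {"fields": fields, "relations": relations}


CATEGORY = {"id": None, "name": None}
GROUP = {"id": None, "name": None, "category": CATEGORY}
USER = {"id": None, "name": None, "username": None, "email": None, "groups": GROUP}

query = [
    Field("id"), Field("name"), Field("username"), Field("email"), Field("__typename"),
    Field("groups", [Field("id"), Field("name"), Field("category", [Field("id"), Field("name")])]),
]
tree = split_selection(query, USER)
```

The resulting tree has the same shape as the dict process_items returns, except that it holds attribute names as strings rather than SQLAlchemy attributes.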

The code of get_only_selected_fields and its helpers is not very readable, but you can figure it out in about five minutes. I recommend playing around with it and looking at the generated SQL; that is the quickest way to understand what it does. One important requirement: every relationship must be declared in your SQLAlchemy models, because the code discovers related classes through them. If you don't feel like doing all this, here is a simple example. Suppose we have a query like:

{
  users {
    id
    name
    username
    email
    groups {
      id
      name
      category {
        id
        name
      }
    }
  }
}

From it we automatically get a SQLAlchemy query like:

select(User).options(
  load_only(User.id, User.name, User.username, User.email),
  joinedload(User.groups).load_only(
    Group.id, Group.name
  ).joinedload(Group.category).load_only(
    Category.id, Category.name
  )
)
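The mapping from the nested fields/relations dict to this option chain can be checked without SQLAlchemy. The sketch below is a hypothetical, string-rendering mirror of get_relation_options: it walks the same kind of tree (hard-coded here to match the query above, with attribute names as strings), and `targets`, a hand-written map from relationship name to related class name, replaces what SQLAlchemy would read from the mapper. It emits one joinedload chain per leaf relationship, which is what flatten collects into a flat options list.

```python
def render_relation(rel, owner, targets, prefix=""):
    """Extend one joinedload(...).load_only(...) chain per leaf relationship."""
    [(name, sub)] = rel.items()
    target = targets[name]
    cols = ", ".join(f"{target}.{f}" for f in sub["fields"])
    chain = f"{prefix}joinedload({owner}.{name}).load_only({cols})"
    if not sub["relations"]:
        return [chain]
    chains = []
    for child in sub["relations"]:
        # each child continues the parent's chain, like prev_sql.joinedload(...)
        chains.extend(render_relation(child, target, targets, chain + "."))
    return chains


def render_query(model, tree, targets):
    options = []
    if tree["fields"]:
        options.append("load_only(" + ", ".join(f"{model}.{f}" for f in tree["fields"]) + ")")
    for rel in tree["relations"]:
        options.extend(render_relation(rel, model, targets))
    return f"select({model}).options(" + ", ".join(options) + ")"


tree = {
    "fields": ["id", "name", "username", "email"],
    "relations": [{"groups": {
        "fields": ["id", "name"],
        "relations": [{"category": {"fields": ["id", "name"], "relations": []}}],
    }}],
}
print(render_query("User", tree, {"groups": "Group", "category": "Category"}))
```

The printed string is the SQLAlchemy query shown above, written on one line. Sibling relationships produce separate chains that share the same prefix, which is why get_relation_options returns a list in that case.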

On top of the generated select() we can then add whatever filters and other clauses we need. Next, let's move on to the functions that turn SQLAlchemy models into Strawberry types.

The first function, which also looks rather scary, turns SQLAlchemy models into plain nested dicts.

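from sqlalchemy.orm import class_mapper


def get_dict_object(model):
    # Recursively convert loaded SQLAlchemy objects (or lists/dicts of them) into plain dicts
    if isinstance(model, list):
        return [get_dict_object(i) for i in model]
    if isinstance(model, dict):
        # Convert every list value, not only the first one found
        return {
            k: [get_dict_object(i) for i in v] if isinstance(v, list) else v
            for k, v in model.items()
        }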
    mapper = class_mapper(model.__class__)
    out = {
        col.key: getattr(model, col.key)
        for col in mapper.columns
        if col.key in model.__dict__  # only columns that were actually loaded
    }
    for name, relation in mapper.relationships.items():
        if name not in model.__dict__:  # not loaded: skip it instead of triggering a lazy load
            continue
        try:
            related_obj = getattr(model, name)
        except AttributeError:
            continue
        if related_obj is not None:
            if relation.uselist:
                out[name] = [get_dict_object(child) for child in related_obj]
            else:
                out[name] = get_dict_object(related_obj)
        else:
            out[name] = None
    return out
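The key trick in get_dict_object is checking model.__dict__: attributes that load_only or joinedload did not load are absent from the instance dict, so they are skipped instead of triggering a lazy load. Here is a stdlib-only sketch of that idea; Model, columns and relationships are hypothetical stand-ins for what SQLAlchemy's class_mapper provides.

```python
class Model:
    """Hypothetical stand-in for a mapped class: declares its columns and relationships."""
    columns = ()
    relationships = {}  # relationship name -> uselist (True for one-to-many)

    def __init__(self, **loaded):
        # only the attributes passed in count as "loaded"
        self.__dict__.update(loaded)


def to_plain_dict(obj):
    if isinstance(obj, list):
        return [to_plain_dict(o) for o in obj]
    state = vars(obj)  # instance dict: holds loaded attributes only
    out = {c: state[c] for c in type(obj).columns if c in state}
    for name, uselist in type(obj).relationships.items():
        if name not in state:
            continue  # not loaded: skip rather than fetch it
        value = state[name]
        if value is None:
            out[name] = None
        elif uselist:
            out[name] = [to_plain_dict(child) for child in value]
        else:
            out[name] = to_plain_dict(value)
    return out


class Category(Model):
    columns = ("id", "name")


class Group(Model):
    columns = ("id", "name")
    relationships = {"category": False}


class User(Model):
    columns = ("id", "name", "email")
    relationships = {"groups": True}


user = User(id=1, name="Ann", groups=[Group(id=10, name="Admins", category=Category(id=5, name="Staff"))])
print(to_plain_dict(user))
```

Here email was never "loaded", so it is missing from the result rather than fetched, which is exactly the behavior we want after load_only.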

Next comes the scariest part of the article, because this code looks truly frightening (God forgive me for posting this on Habr):

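import enum
from types import GenericAlias, UnionType

import strawberry

# Recent Strawberry releases define these in strawberry.types.base;
# older releases expose them from strawberry.type
try:
    from strawberry.types.base import StrawberryList, StrawberryOptional
except ImportError:
    from strawberry.type import StrawberryList, StrawberryOptional


def orm_to_strawberry_step(item: dict, current_strawberry_type):
    # Build one Strawberry object from a dict, recursing into nested types
    annots = current_strawberry_type.__annotations__
    temp = {}
    for k, v in item.items():
        if k not in annots:
            continue
        current_type = annots.get(k)
        if isinstance(v, enum.Enum):
            # Enum[...] looks members up by name, not by value
            temp[k] = strawberry.enum(v.__class__)[v.name]
            continue
        if not isinstance(v, (list, dict)):
            # Scalars (str, int, float, datetime, None, ...) are passed through as-is
            temp[k] = v
            continue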
        if isinstance(current_type, StrawberryOptional):
            current_type = current_type.of_type
        if isinstance(current_type, UnionType):
            current_type = current_type.__args__[0]
        if isinstance(current_type, StrawberryList):
            current_type = current_type.of_type
        if isinstance(current_type, GenericAlias):
            current_type = current_type.__args__[0]
        if isinstance(v, list):
            temp[k] = [orm_to_strawberry_step(i, current_type) for i in item[k]]
        elif isinstance(v, dict):
            temp[k] = orm_to_strawberry_step(item[k], current_type)
    return current_strawberry_type(**temp)


def orm_to_strawberry(input_data, strawberry_type):
    if isinstance(input_data, list):
        return [orm_to_strawberry_step(get_dict_object(item), strawberry_type) for item in input_data]
    return orm_to_strawberry_step(get_dict_object(input_data), strawberry_type)
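orm_to_strawberry_step works because Strawberry object types are built on standard dataclasses, so the field types can be read from the class and used to decide where to recurse. Below is a stdlib-only sketch of the same technique, with plain dataclasses standing in for Strawberry types and typing.get_origin/get_args replacing the StrawberryOptional/StrawberryList checks; dict_to_type, unwrap and the example classes are hypothetical names. It assumes the annotations are real objects rather than strings (no `from __future__ import annotations`).

```python
import dataclasses
import enum
import types
import typing
from typing import Optional

# Both spellings of an optional type: Optional[X] and (on Python 3.10+) X | None
_UNION_ORIGINS = (typing.Union, getattr(types, "UnionType", typing.Union))


def unwrap(tp):
    """Peel Optional[...] / X | None and list[...] until the element type remains."""
    while True:
        origin = typing.get_origin(tp)
        if origin in _UNION_ORIGINS:
            tp = next(a for a in typing.get_args(tp) if a is not type(None))
        elif origin is list:
            tp = typing.get_args(tp)[0]
        else:
            return tp


def dict_to_type(data: dict, cls):
    """Build `cls` (a dataclass) from a nested dict, recursing into nested dataclass fields."""
    field_types = {f.name: f.type for f in dataclasses.fields(cls)}
    kwargs = {}
    for name, value in data.items():
        if name not in field_types:
            continue  # keys the target type does not declare are dropped
        target = unwrap(field_types[name])
        if dataclasses.is_dataclass(target) and isinstance(value, list):
            kwargs[name] = [dict_to_type(v, target) for v in value]
        elif dataclasses.is_dataclass(target) and isinstance(value, dict):
            kwargs[name] = dict_to_type(value, target)
        else:
            kwargs[name] = value  # scalars, enum members and None pass through unchanged
    return cls(**kwargs)


class Role(enum.Enum):
    ADMIN = "admin"
    USER = "user"


@dataclasses.dataclass
class Category:
    id: int
    name: str


@dataclasses.dataclass
class Group:
    id: int
    name: str
    category: Optional[Category] = None


@dataclasses.dataclass
class User:
    id: int
    name: str
    role: Role = Role.USER
    groups: list[Group] = dataclasses.field(default_factory=list)


data = {
    "id": 1, "name": "Ann", "role": Role.ADMIN, "password_hash": "x",
    "groups": [{"id": 10, "name": "Admins", "category": {"id": 5, "name": "Staff"}}],
}
user = dict_to_type(data, User)
```

Keys that the target type does not declare (password_hash here) are silently dropped, the same way orm_to_strawberry_step skips keys missing from the type's annotations.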

The key point here is that orm_to_strawberry_step receives our dict and the Strawberry type it should be converted into. It walks all nested models (relationships) and builds each nested Strawberry type from the corresponding dict. Again, it takes a little time to understand this code. The result is deeply nested Strawberry objects with all their sub-objects, built to match the schema and ready to be returned to the client. Finally, here is a simple function that partially replicates what Pydantic's .dict() method does:

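import copy


def _to_dict(obj):
    # Recursively convert an object (e.g. a Strawberry type instance) into a dict
    if isinstance(obj, (list, tuple)):
        return [_to_dict(i) for i in obj]
    if not hasattr(obj, '__dict__'):
        return obj
    temp = dict(obj.__dict__)  # shallow copy so the original object is not mutated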
    for key, value in temp.items():
        if hasattr(value, '_enum_definition') or isinstance(value, bytes):
            continue
        elif hasattr(value, '__dict__'):
            temp[key] = _to_dict(value)
        elif isinstance(value, list):
            temp[key] = [_to_dict(i) for i in value]
    return temp

def strawberry_to_dict(
    strawberry_model,
    exclude_none: bool = False,
    exclude: set | None = None,
):
    deep_copy = copy.deepcopy(strawberry_model)
    dict_obj = _to_dict(deep_copy)
    result_dict = {**dict_obj}
    for k, v in dict_obj.items():
        if exclude:
            if k in exclude:
                result_dict.pop(k, None)
        if exclude_none and v is None:
            result_dict.pop(k, None)
    return result_dict
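Because Strawberry object types are dataclasses, the standard library's dataclasses.asdict can do the recursive part of strawberry_to_dict, and it already returns copies, so no explicit deepcopy is needed. A sketch, assuming types with plain data fields only (fields backed by resolvers may not behave the same way); dataclass_to_dict is a hypothetical name, and, like strawberry_to_dict, it applies exclude and exclude_none to top-level keys only.

```python
import dataclasses
from typing import Optional


def dataclass_to_dict(obj, exclude_none: bool = False, exclude: Optional[set] = None) -> dict:
    # asdict recurses into nested dataclasses, lists and dicts, deep-copying other values
    data = dataclasses.asdict(obj)
    exclude = exclude or set()
    return {
        key: value
        for key, value in data.items()
        if key not in exclude and not (exclude_none and value is None)
    }


@dataclasses.dataclass
class Category:
    id: int
    name: str


@dataclasses.dataclass
class Group:
    id: int
    name: Optional[str]
    category: Optional[Category] = None


group = Group(id=1, name=None, category=Category(id=5, name="Staff"))
print(dataclass_to_dict(group, exclude_none=True, exclude={"id"}))
```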

I hope someone reads this article to the end; I tried to pack as much useful material into it as possible. The code is not on GitHub and probably never will be: it is literally seven simple functions :)
