A simple task queue in Django: connecting Kandinsky 2.1

Sooner or later most developers face the need to implement a queue for running complex, computationally heavy processes. There is an abundance of ready-made solutions, so you can pick exactly what your current task requires. My task was quite simple and did not need a sophisticated process-control algorithm: a single GPU runs a resource-intensive function sequentially (in this article the example is Kandinsky 2.1), with no error handling or process restarts, and the function is triggered by the user from Django. For a Python developer, a search phrased like that leads straight to Celery. After skimming the documentation you subconsciously close the task: everything is ready, I'll wire it up in 15 minutes.

But during implementation I ran into a problem that significantly affects performance: every time Celery executes the function, the process loads the model weights from scratch, and in the case of Kandinsky 2.1 the weights are large. A similar issue has been discussed on Stack Overflow. Attempts to work around this barrier in Celery only led to new errors. One of the solutions leads to the article Scaling AllenNLP/PyTorch in Production, where the author suggests ZeroMQ instead of Celery; I will build on that implementation to solve the problem. In that article the computation is done on the CPU; in general I recommend reading it. ZeroMQ will serve as the exchange protocol between the resource-intensive function and Django, and I will use Django Channels (a worker and background tasks) to launch the background process from the queue.
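
To make the Celery problem concrete, here is a minimal hypothetical sketch of the naive approach (tasks.py and all names in it are my assumption, not code from this project): because the model is loaded inside the task, the weights are read again on every call.

# tasks.py - hypothetical naive Celery variant that runs into the problem:
# the heavy Kandinsky weights are loaded on every task invocation
from celery import Celery
from kandinsky2 import get_kandinsky2

app = Celery('tasks', broker='redis://localhost:6379/0')

@app.task
def generate_image(prompt: str) -> str:
    # loading the model here means re-reading the weights each time the task runs
    model = get_kandinsky2('cuda', task_type='text2img', model_version='2.1')
    images = model.generate_text2img(prompt, num_steps=70, batch_size=1)
    path = '/tmp/result.jpg'
    images[0].save(path, format='JPEG')
    return path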

Sample algorithm:

  1. The user sends a request to Django from the web interface.

  2. From the main Django channel, forward the received request to the “other” fixed Django channel.

  3. From the “other” channel, send the request over TCP via ZeroMQ to the resource-intensive function.

  4. After the function finishes, return the response over TCP via ZeroMQ to the “other” Django channel.

  5. Return the response to the user via the main channel’s group.

The JavaScript client side is very simple and needs no extra commentary. In my example the client-server interaction goes over WebSocket, but thanks to the flexibility of Django Channels you are not limited in your choice of protocol for triggering background tasks.

var ws_wall = new WebSocket("ws://"+ IP_ADDR +":"+PORT+"/");

ws_wall.onmessage = function(event) {
  // do something with the received response
}

// send data to Django
function send_wall() {
    var body = document.getElementById('id_body');
    if (body.value == "") {
        return false;
    }
    if (ws_wall.readyState != WebSocket.OPEN) {
        return false;
    }
    var data = JSON.stringify({body: body.value,
                               event: "wallpost"});
    ws_wall.send(data);
}

Below is the server-side code of the main Django Channels consumer. Developers who already work with Django will not see anything new here: these are standard snippets from the documentation. There is no point in describing the Django ORM models in detail; the code makes it clear that everything is very simple.
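
For orientation, here is a rough sketch of the models the consumer relies on; it is an assumption reconstructed from the fields used in the code below, not the project's actual models.py.

# models.py - assumed sketch; the real project models may differ
from django.db import models
from django.contrib.auth.models import AbstractUser

class User(AbstractUser):
    path_data = models.CharField(max_length=64)              # per-user media subdirectory

class Post(models.Model):
    body = models.TextField()                                 # prompt text from the user
    path_data = models.CharField(max_length=64)               # copied from the user
    user_post = models.ForeignKey(User, on_delete=models.CASCADE)
    image = models.CharField(max_length=255, blank=True)      # generated file name
    date_post = models.DateTimeField(auto_now_add=True)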

import json
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from myapp.models import User, Post
from asgiref.sync import sync_to_async
from django.utils import dateformat

class WallHandler(AsyncJsonWebsocketConsumer):
    async def connect(self):
        """
        initialize the connection
        """
        self.room_group_name = "wall"
        self.sender_id = self.scope['user'].id
        self.sender_name = self.scope['user']
        if str(self.scope['user']) != 'AnonymousUser':
            self.path_data = self.scope['user'].path_data
        await self.channel_layer.group_add(
            self.room_group_name,
            self.channel_name
        )
        await self.accept()
        
    async def disconnect(self, close_code):
        """
        handle disconnection:
        leave the group
        """
        print("error code: ", close_code)
        await self.channel_layer.group_discard(
            self.room_group_name,
            self.channel_name
        )      
        
    async def receive(self, text_data):
        """
        handle data received from the client (WebSocket):
        get the event and dispatch the corresponding action
        """
        response = json.loads(text_data)
        event = response.get("event", None)
        if self.scope['user'].is_authenticated: 
            if event == "wallpost":
                post = Post()
                post.body = response["body"]
                post.path_data = self.path_data
                post.user_post = self.scope['user']
                post_async = sync_to_async(post.save)  # interact with the synchronous Django ORM
                await post_async()

                """
                from the main Django channel,
                forward the request received from the user
                to the “other” fixed Django channel
                """
                _temp_dict = {}
                _temp_dict["body"] = response["body"]
                _temp_dict["path_data"] = self.path_data
                _temp_dict["post"] = str(post.id)
                _temp_dict["type"] = "triggerWorker"
                _temp_dict["room_group_name"] = self.room_group_name
                await self.channel_layer.send('nnapp', _temp_dict)  # await here does not block the main thread
                """
                the server keeps working:
                send the data that is already available to the client
                without waiting for the resource-intensive function to finish
                """
                _data = {"type": "wallpost",
                         "timestamp": dateformat.format(post.date_post, 'U'),
                         "text":response["body"],
                         "user_post": str(self.sender_name),
                         "user_id": str(self.sender_id),
                         "id": str(post.id),
                         "status" : "wallpost"
                        }
                await self.channel_layer.group_send(self.room_group_name, _data)
    async def wallpost(self, res):
        """  
        send a message to the client (WebSocket)
        """
        await self.send(text_data=json.dumps(res))                
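
The WebSocket URL that the JavaScript client opens has to be routed to this consumer. The project's routing.py is not reproduced in this article; a minimal version, under the assumption that the consumer lives in myapp/consumers.py, could look like this:

# routing.py - minimal assumed sketch; module and path names are hypothetical
from django.urls import re_path
from myapp.consumers import WallHandler

websocket_urlpatterns = [
    re_path(r'^$', WallHandler.as_asgi()),   # the client connects to ws://HOST:PORT/
]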

Important: additional settings are required in the Django project's asgi.py configuration file.

"""
ASGI config for app project.

It exposes the ASGI callable as a module-level variable named ``application``.

For more information on this file, see
https://docs.djangoproject.com/en/3.2/howto/deployment/asgi/
"""

import os
import django
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'app.settings')
django.setup()

from django.core.asgi import get_asgi_application
from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter, ChannelNameRouter
import app.routing
from wall import nnapp


application = ProtocolTypeRouter({
    "http": get_asgi_application(),
    "websocket": AuthMiddlewareStack(
        URLRouter(
            app.routing.websocket_urlpatterns,
        )
    ),
    # add the “other” fixed Django channel
    "channel": ChannelNameRouter({
        "nnapp": nnapp.NNHandler.as_asgi(),
    }),
})
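
Note that channel_layer.send and group_send only reach the worker process if a channel layer backend is configured in settings.py; that configuration is not shown in this article. A typical Redis-backed setup (assuming the channels_redis package is installed) looks like this:

# settings.py - assumed fragment, not shown in the original project;
# a shared channel layer is what lets runworker and the main ASGI
# process exchange messages
CHANNEL_LAYERS = {
    "default": {
        "BACKEND": "channels_redis.core.RedisChannelLayer",
        "CONFIG": {"hosts": [("127.0.0.1", 6379)]},
    },
}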

Now I’ll create the “other” Django channel that handles the background task. Into a new file, nnapp.py, I copy the client part from the article Scaling AllenNLP/PyTorch in Production and adapt it to my needs.

from myapp.models import User, Post
from asgiref.sync import async_to_sync
from channels.consumer import SyncConsumer
from channels.layers import get_channel_layer
import uuid
import os
import json
import zlib
import pickle
import zmq

# global variables from the article
work_publisher = None
result_subscriber = None
TOPIC = 'snaptravel'

RECEIVE_PORT = 5555
SEND_PORT = 5556 

# get access to the channel layer
channel_layer = get_channel_layer()

# the low-level protocol works with bytes
def compress(obj):
    p = pickle.dumps(obj)
    return zlib.compress(p)

def decompress(pickled):
    p = zlib.decompress(pickled)
    return pickle.loads(p)
    
def start():
    global work_publisher, result_subscriber
    context = zmq.Context()
    work_publisher = context.socket(zmq.PUB)
    work_publisher.connect(f'tcp://127.0.0.1:{SEND_PORT}') 

def _parse_recv_for_json(result, topic=TOPIC):
    compressed_json = result[len(topic) + 1:]
    return decompress(compressed_json)

def send(args, model=None, topic=TOPIC):
    id = str(uuid.uuid4())
    message = {'body': args["title"], 'model': model, 'id': id}
    compressed_message = compress(message)
    work_publisher.send(f'{topic} '.encode('utf8') + compressed_message)
    return id

def get(id, topic=TOPIC):
    context = zmq.Context()
    result_subscriber = context.socket(zmq.SUB)
    result_subscriber.setsockopt(zmq.SUBSCRIBE, topic.encode('utf8'))
    result_subscriber.connect(f'tcp://127.0.0.1:{RECEIVE_PORT}')
    result = _parse_recv_for_json(result_subscriber.recv())
    while result['id'] != id:
        result = _parse_recv_for_json(result_subscriber.recv())
    result_subscriber.close()
    if result.get('error'):
        raise Exception(result['error_msg'])
    return result
  
# this function differs slightly from the original in the article
def send_and_get(args, model=None):
    id = send(args, model=model)
    res = get(id)
    namefile = f'{id}.jpg'
    res['prediction'][0].save(f'media/data_image/{args["path_data"]}/{namefile}', format="JPEG")
    post = Post.objects.get(id=args["post"]) 
    post.image = namefile
    post.save()
    _data = {"type": "wallpost", "status":"Kandinsky-2.1", "path_data": args["path_data"],
             "data": f'{namefile}', "post":args["post"]}
    # return the result to the main Django Channels group
    async_to_sync(channel_layer.group_send)(args["room_group_name"], _data)

# start the fixed channel with the ZeroMQ client;
# start() in the class body runs once, when the worker imports this module,
# so the publisher socket stays open between tasks
class NNHandler(SyncConsumer):
    start()
    def triggerWorker(self, message):
        print ("data for a background task: ", message)
        send_and_get(message, model="Kandinsky-2.1")

The ZeroMQ server with the resource-intensive function, in my case Kandinsky 2.1, for the most part coincides with the code from the article, with minor additions. To handle Russian-language prompts I use the pretrained translation model Helsinki-NLP/opus-mt-ru-en.

import os, time
from types import SimpleNamespace
import zmq
import zlib
import pickle
import torch.multiprocessing as mp
import threading
import cv2
from kandinsky2 import get_kandinsky2
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import uuid

QUEUE_SIZE = mp.Value('i', 0)

def compress(obj):
    p = pickle.dumps(obj)
    return zlib.compress(p)

def decompress(pickled):
    p = zlib.decompress(pickled)
    return pickle.loads(p)

TOPIC = 'snaptravel'
prediction_functions = {}

RECEIVE_PORT = 5556
SEND_PORT = 5555

# the image-generation model
model = get_kandinsky2('cuda', task_type="text2img", model_version='2.1', use_flash_attention=False)

# the translation model
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
model_translater = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ru-en")


def _parse_recv_for_json(result, topic=TOPIC):
    compressed_json = result[len(topic) + 1:]
    return decompress(compressed_json)

def _decrease_queue():
    with QUEUE_SIZE.get_lock():
        QUEUE_SIZE.value -= 1

def _increase_queue():
    with QUEUE_SIZE.get_lock():
        QUEUE_SIZE.value += 1
    
def send_prediction(message, result_publisher, topic=TOPIC):
    _increase_queue()
    model_name = message['model']
    body = message['body']
    id = message['id']
    
    # prepare the input for the pretrained translation model
    tokenized_text = tokenizer([str(body).lower()], return_tensors="pt")

    # translate the prompt into English
    translation = model_translater.generate(**tokenized_text)
    body = tokenizer.batch_decode(translation, skip_special_tokens=True)[0]
    print(body)

    # generate the image
    images = model.generate_text2img(
        str(body).lower(), 
        num_steps=70,
        batch_size=1, 
        guidance_scale=4,
        h=768, w=768,
        sampler="p_sampler", 
        prior_cf_scale=4,
        prior_steps="5"
    )
    result = {"result": images}

    if result.get('result') is None:
        time.sleep(1)
        compressed_message = compress({'error': True, 'error_msg': 'No result was given: ' + str(result), 'id': id})
        result_publisher.send(f'{topic} '.encode('utf8') + compressed_message)
        _decrease_queue()
        return
      
    prediction = result['result']
    compressed_message = compress({'prediction': prediction, 'id': id})
    result_publisher.send(f'{topic} '.encode('utf8') + compressed_message)
    _decrease_queue()
    print ("SERVER", message, f'{topic} '.encode('utf8'))

def queue_size():
    return QUEUE_SIZE.value

def load_models():
    models = SimpleNamespace()
    return models

def start():
    global prediction_functions

    models = load_models()
    prediction_functions = {
    'queue': queue_size
    }

    print(f'Connecting to {RECEIVE_PORT} in server', TOPIC.encode('utf8'))
    context = zmq.Context()
    work_subscriber = context.socket(zmq.SUB)
    work_subscriber.setsockopt(zmq.SUBSCRIBE, TOPIC.encode('utf8'))
    work_subscriber.bind(f'tcp://127.0.0.1:{RECEIVE_PORT}')

    # send work
    print(f'Connecting to {SEND_PORT} in server')
    result_publisher = context.socket(zmq.PUB)
    result_publisher.bind(f'tcp://127.0.0.1:{SEND_PORT}')

    print('Server started')
    while True:
        message = _parse_recv_for_json(work_subscriber.recv())
        threading.Thread(target=send_prediction, args=(message, result_publisher), kwargs={'topic': TOPIC}).start()

if __name__ == '__main__':
  start()

Run the Django Channels worker for background tasks in a separate process:
python manage.py runworker nnapp

These code fragments come from the working project of my previous article, which was so popular and well commented that it hardly needs mentioning. A good example is Build a Pytorch Server with celery and RabbitMQ, but it scales poorly and is slightly harder to implement than the approach discussed in this article. I had a few more ideas, but decided to take the path of least resistance. To work with a GPU cluster, a more complex algorithm for tracking the load of each individual GPU would be needed, so this is a simple example built on someone else's work. Working code: https://github.com/naturalkind/social-network.
