Using Kotlin and WebFlux to Run ML Tasks in Apache Spark on GPU

This is the third article on implementing a scalable system for running distributed machine learning tasks on the GPU using Java, Kotlin, Spring, and Spark. List of all articles:

  1. Options for using Java ML libraries in conjunction with Spring, Docker, Spark, Rapids, CUDA

  2. Scalable Big Data system in Kubernetes using Spark and Cassandra

  3. Using Kotlin and WebFlux to Run ML Tasks in Apache Spark on GPU

What is this article about

In the previous article, we used the Spring servlet stack (Boot 2.7.11) and JDK 8 to create a Spark Driver application.

It is the second half of 2023: many people already run Boot 3+ (or even 3.1+) in production, a new LTS version of Java is due very soon, and, to put it mildly, Boot 2 and JDK 8 are outdated. They were used intentionally, because training machine learning models on the GPU in a Spark environment relies on the GPU computing accelerator NVIDIA Rapids as part of the system. Rapids gained JDK 17 support only in release v23.06.0 dated 06/27/23; with that release it became possible to move to the current LTS version of Java, and with it to Spring Boot 3+.

This article describes the migration from Boot 2 and JDK 8 to Boot 3 and JDK 17, from Spring Web to Spring WebFlux, and finally compares the Web and WebFlux versions in terms of hardware consumption and execution speed.

JDK8, Spring boot 2.7.11 → JDK17, Spring Boot 3.1.1

To migrate, it is enough to upgrade Rapids to 23.06.0, JDK to 17, Spring Boot to 3.1.1. There are not so many nuances:

  1. The Slf4j and Log4j loggers conflict when using Spark: exclude spring-boot-starter-logging from the spring-boot-starter-web dependency:

pom.xml
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
    <version>${spring.boot.version}</version>
    <exclusions>
        <exclusion>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-logging</artifactId>
        </exclusion>
        <exclusion>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-tomcat</artifactId>
        </exclusion>
    </exclusions>
</dependency>
  2. You need to run the Spark Driver on JDK 17 with the following parameters (shown here as set in the Dockerfile):

Application Dockerfile
ENV JAVA_OPTS='--add-opens=java.base/java.lang=ALL-UNNAMED \
               --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \
               --add-opens=java.base/java.lang.reflect=ALL-UNNAMED \
               --add-opens=java.base/java.io=ALL-UNNAMED \
               --add-opens=java.base/java.net=ALL-UNNAMED \
               --add-opens=java.base/java.nio=ALL-UNNAMED \
               --add-opens=java.base/java.util=ALL-UNNAMED \
               --add-opens=java.base/java.util.concurrent=ALL-UNNAMED \
               --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED \
               --add-opens=java.base/sun.nio.ch=ALL-UNNAMED \
               --add-opens=java.base/sun.nio.cs=ALL-UNNAMED \
               --add-opens=java.base/sun.security.action=ALL-UNNAMED \
               --add-opens=java.base/sun.util.calendar=ALL-UNNAMED \
               --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED'
  3. Due to the transition to Hibernate 6, entities that use JSONB and BYTEA fields in Postgres need to be slightly refactored (a minimal sketch of an updated mapping is shown after the dialect below):

At the same time, the previously used CustomPostgresDialect is no longer needed: it can be removed and replaced with org.hibernate.dialect.PostgreSQLDialect:

application.yml
spring:
  ...
  jpa:
    database-platform: com.mlwebservice.config.CustomPostgresDialect  # <== delete
    database-platform: org.hibernate.dialect.PostgreSQLDialect        # <== add
Previously used CustomPostgresDialect
package com.mlwebservice.config

import com.vladmihalcea.hibernate.type.array.IntArrayType
import com.vladmihalcea.hibernate.type.array.StringArrayType
import com.vladmihalcea.hibernate.type.json.JsonBinaryType
import com.vladmihalcea.hibernate.type.json.JsonNodeBinaryType
import com.vladmihalcea.hibernate.type.json.JsonNodeStringType
import com.vladmihalcea.hibernate.type.json.JsonStringType
import org.hibernate.dialect.PostgreSQL10Dialect
import java.sql.Types

class CustomPostgresDialect : PostgreSQL10Dialect() {
    init {
        registerHibernateType(Types.OTHER, StringArrayType::class.qualifiedName)
        registerHibernateType(Types.OTHER, IntArrayType::class.qualifiedName)
        registerHibernateType(Types.OTHER, JsonStringType::class.qualifiedName)
        registerHibernateType(Types.OTHER, JsonBinaryType::class.qualifiedName)
        registerHibernateType(Types.OTHER, JsonNodeBinaryType::class.qualifiedName)
        registerHibernateType(Types.OTHER, JsonNodeStringType::class.qualifiedName)
    }
}
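
For reference, below is a minimal sketch of what such a mapping can look like in Hibernate 6 terms. The table and column names match the ones used elsewhere in this article, but the annotation-based mapping itself is an assumption about one possible approach; the hibernate-types-60 dependency kept in the pom offers an equivalent @Type-based alternative.

Hypothetical Hibernate 6 mapping sketch
import com.fasterxml.jackson.databind.JsonNode
import jakarta.persistence.Column
import jakarta.persistence.Entity
import jakarta.persistence.GeneratedValue
import jakarta.persistence.GenerationType
import jakarta.persistence.Id
import jakarta.persistence.Table
import org.hibernate.annotations.JdbcTypeCode
import org.hibernate.type.SqlTypes

// Hypothetical JPA entity showing Hibernate 6 style mappings for BYTEA and JSONB columns
// (a real project also needs the kotlin-jpa/no-arg compiler plugin for the no-arg constructor).
@Entity
@Table(name = "models", schema = "instrument_data")
class ModelJpaEntity(
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    val id: Long? = null,

    // BYTEA column: Hibernate 6 maps ByteArray out of the box, no custom type registration needed
    @Column(name = "model")
    val model: ByteArray = ByteArray(0),

    // JSONB column: @JdbcTypeCode(SqlTypes.JSON) replaces the dialect-level registration used before
    @JdbcTypeCode(SqlTypes.JSON)
    @Column(name = "parameters", columnDefinition = "jsonb")
    val parameters: JsonNode? = null
)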

Apart from the Dockerfiles and the step of adding the new Rapids jar version to the jar directory (shipped to the Spark executors) and to the executor image, this is all that needs to be done. You can find the current version in the corresponding branch of the repository.

This could have been the end of it, but curiosity took over, and a question arose: will it work on the reactive stack, and will there be any benefit?

Let’s make ML reactive: Spring Web → Spring WebFlux

Dependencies

A transition like this naturally involves more changes, and there are nuances around dependency management. Netty, required by Project Reactor (WebFlux), is also used by Spark itself and by the Cassandra driver, so they initially conflicted. This was solved by declaring three dependencies at the very beginning of the dependency list:

pom.xml: Netty dependencies
<dependencies>
    <!-- Netty -->
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-all</artifactId>
        <version>4.1.74.Final</version>
    </dependency>
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-codec-http</artifactId>
        <version>4.1.74.Final</version>
    </dependency>
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-resolver-dns</artifactId>
        <version>4.1.74.Final</version>
    </dependency>

    <!-- Spring -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-webflux</artifactId>
        <version>${spring.boot.version}</version>
        <exclusions>
            <exclusion>
                <artifactId>log4j-to-slf4j</artifactId>
                <groupId>org.apache.logging.log4j</groupId>
            </exclusion>
        </exclusions>
    </dependency>
    ...
</dependencies>

Spring Data is also being replaced with a reactive version:

pom.xml: R2DBC and Spring Data Cassandra Reactive
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-cassandra-reactive</artifactId>
    <version>${spring.boot.version}</version>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-r2dbc</artifactId>
    <version>${spring.boot.version}</version>
</dependency>
<dependency>
    <groupId>io.r2dbc</groupId>
    <artifactId>r2dbc-postgresql</artifactId>
    <version>0.8.13.RELEASE</version>
</dependency>

And several libraries are added for Kotlin to work in the WebFlux environment:

pom.xml: Kotlin dependencies
<dependency>
    <groupId>org.jetbrains.kotlin</groupId>
    <artifactId>kotlin-stdlib</artifactId>
    <version>${kotlin.version}</version>
</dependency>
<dependency>
    <groupId>org.jetbrains.kotlin</groupId>
    <artifactId>kotlin-reflect</artifactId>
    <version>${kotlin.version}</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>org.jetbrains.kotlinx</groupId>
    <artifactId>kotlinx-coroutines-reactor</artifactId>
    <version>1.7.2</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>io.projectreactor.kotlin</groupId>
    <artifactId>reactor-kotlin-extensions</artifactId>
    <version>1.2.2</version>
    <scope>runtime</scope>
</dependency>

By the way, Kotlin itself was also upgraded from version 1.8.21 to 1.9.0.

To log HTTP requests and responses, add Zalando Logbook:

pom.xml: Zalando Logbook
<dependency>
    <groupId>org.zalando</groupId>
    <artifactId>logbook-spring-boot-autoconfigure</artifactId>
    <version>3.2.0</version>
</dependency>
<dependency>
    <groupId>org.zalando</groupId>
    <artifactId>logbook-netty</artifactId>
    <version>3.2.0</version>
</dependency>
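
One nuance worth keeping in mind: Logbook writes its entries at TRACE level, so nothing shows up in the logs until its logger level is raised. A minimal (assumed) addition to application.yml:

application.yml: Logbook logger level (assumed)
logging:
  level:
    org.zalando.logbook: TRACE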
pom.xml (full version for WebFlux)
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.mlwebservice</groupId>
    <artifactId>MLWebService</artifactId>
    <version>1.0.0-SNAPSHOT</version>

    <properties>
        <java.version>17</java.version>
        <spring.boot.version>3.1.1</spring.boot.version>
        <scala.version>2.12</scala.version>
        <spark.version>3.3.2</spark.version>
        <lombok.version>1.18.24</lombok.version>
        <org.mapstruct.version>1.4.2.Final</org.mapstruct.version>
        <kotlin.version>1.9.0</kotlin.version>
        <jackson.version>2.13.5</jackson.version>
    </properties>

    <distributionManagement>
        <repository>
            <id>XGBoost4J Snapshot Repo</id>
            <name>XGBoost4J Snapshot Repo</name>
            <url>https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/snapshot/</url>
        </repository>
    </distributionManagement>

    <dependencies>
        <!-- Netty -->
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-all</artifactId>
            <version>4.1.74.Final</version>
        </dependency>
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-codec-http</artifactId>
            <version>4.1.74.Final</version>
        </dependency>
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-resolver-dns</artifactId>
            <version>4.1.74.Final</version>
        </dependency>

        <!-- Spring -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-webflux</artifactId>
            <version>${spring.boot.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>org.springframework.boot</groupId>
                    <artifactId>spring-boot-starter-logging</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-core</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.module</groupId>
            <artifactId>jackson-module-kotlin</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-annotations</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-databind</artifactId>
            <version>${jackson.version}</version>
        </dependency>

        <!-- Spring Data -->
        <dependency>
            <groupId>org.springframework.data</groupId>
            <artifactId>spring-data-commons</artifactId>
            <version>${spring.boot.version}</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-cassandra-reactive</artifactId>
            <version>${spring.boot.version}</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-r2dbc</artifactId>
            <version>${spring.boot.version}</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-jpa</artifactId>
            <version>${spring.boot.version}</version>
        </dependency>
        <dependency>
            <groupId>org.postgresql</groupId>
            <artifactId>postgresql</artifactId>
            <scope>runtime</scope>
            <version>42.6.0</version>
        </dependency>
        <dependency>
            <groupId>io.r2dbc</groupId>
            <artifactId>r2dbc-postgresql</artifactId>
            <version>0.8.13.RELEASE</version>
        </dependency>
        <dependency>
            <groupId>com.vladmihalcea</groupId>
            <artifactId>hibernate-types-60</artifactId>
            <version>2.21.1</version>
        </dependency>

        <!-- Cassandra -->
        <dependency>
            <groupId>com.datastax.oss</groupId>
            <artifactId>java-driver-core</artifactId>
            <version>4.13.0</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>2.12.15</version>
        </dependency>
        <dependency>
            <groupId>com.datastax.spark</groupId>
            <artifactId>spark-cassandra-connector_2.12</artifactId>
            <version>3.3.0</version>
        </dependency>

        <dependency>
            <groupId>com.typesafe</groupId>
            <artifactId>config</artifactId>
            <version>1.4.2</version>
        </dependency>

        <!-- Spark -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.antlr</groupId>
            <artifactId>antlr4-runtime</artifactId>
            <version>4.8</version>
            <scope>runtime</scope>
        </dependency>

        <!-- GXBoost -->
        <dependency>
            <groupId>ml.dmlc</groupId>
            <artifactId>xgboost4j-spark-gpu_${scala.version}</artifactId>
            <version>1.7.5</version>
        </dependency>
        <dependency>
            <groupId>ml.dmlc</groupId>
            <artifactId>xgboost4j-gpu_${scala.version}</artifactId>
            <version>1.7.5</version>
        </dependency>

        <!-- Kubernetes -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-kubernetes_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.codehaus.janino</groupId>
            <artifactId>commons-compiler</artifactId>
            <version>3.0.16</version>
        </dependency>
        <dependency>
            <groupId>org.codehaus.janino</groupId>
            <artifactId>janino</artifactId>
            <version>3.0.16</version>
        </dependency>

        <!-- Rapids -->
        <dependency>
            <groupId>com.nvidia</groupId>
            <artifactId>rapids-4-spark_${scala.version}</artifactId>
            <version>23.06.0</version>
        </dependency>

        <!-- Lombok -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>${lombok.version}</version>
        </dependency>

        <!-- Logging -->
        <dependency>
            <groupId>org.zalando</groupId>
            <artifactId>logbook-spring-webflux</artifactId>
            <version>3.1.0</version>
        </dependency>
        <dependency>
            <groupId>org.zalando</groupId>
            <artifactId>logbook-spring-boot-autoconfigure</artifactId>
            <version>3.2.0</version>
        </dependency>
        <dependency>
            <groupId>org.zalando</groupId>
            <artifactId>logbook-netty</artifactId>
            <version>3.2.0</version>
        </dependency>

        <!-- Utils -->
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
            <version>3.12.0</version>
        </dependency>

        <!-- Kotlin -->
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-stdlib</artifactId>
            <version>${kotlin.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-reflect</artifactId>
            <version>${kotlin.version}</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlinx</groupId>
            <artifactId>kotlinx-coroutines-reactor</artifactId>
            <version>1.7.2</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>io.projectreactor.kotlin</groupId>
            <artifactId>reactor-kotlin-extensions</artifactId>
            <version>1.2.2</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlinx.spark</groupId>
            <artifactId>kotlin-spark-api_3.3.1_${scala.version}</artifactId>
            <version>1.2.3</version>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-test</artifactId>
            <version>${kotlin.version}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <finalName>service</finalName>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <version>3.0.6</version>
                <configuration>
                    <mainClass>com.mlwebservice.MLWebServiceApplication</mainClass>
                </configuration>
                <executions>
                    <execution>
                        <goals>
                            <goal>repackage</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>

            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.11.0</version>
                <executions>
                    <execution>
                        <id>compile</id>
                        <phase>compile</phase>
                        <goals>
                            <goal>compile</goal>
                        </goals>
                    </execution>
                    <execution>
                        <id>testCompile</id>
                        <phase>test-compile</phase>
                        <goals>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                    <annotationProcessorPaths>
                        <path>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                            <version>${lombok.version}</version>
                        </path>
                    </annotationProcessorPaths>
                </configuration>
            </plugin>

            <plugin>
                <groupId>org.jetbrains.kotlin</groupId>
                <artifactId>kotlin-maven-plugin</artifactId>
                <version>${kotlin.version}</version>
                <executions>
                    <execution>
                        <id>compile</id>
                        <phase>process-sources</phase>
                        <goals>
                            <goal>compile</goal>
                        </goals>
                        <configuration>
                            <jvmTarget>${java.version}</jvmTarget>
                            <sourceDirs>
                                <source>src/main/java</source>
                                <source>src/main/kotlin</source>
                                <source>target/generated-sources/annotations</source>
                            </sourceDirs>
                        </configuration>
                    </execution>
                    <execution>
                        <id>test-compile</id>
                        <phase>test-compile</phase>
                        <goals>
                            <goal>test-compile</goal>
                        </goals>
                        <configuration>
                            <jvmTarget>${java.version}</jvmTarget>
                            <sourceDirs>
                                <source>src/main/java</source>
                                <source>src/main/kotlin</source>
                                <source>target/generated-sources/annotations</source>
                            </sourceDirs>
                        </configuration>
                    </execution>
                </executions>
                <configuration>
                    <jvmTarget>${java.version}</jvmTarget>
                    <sourceDirs>
                        <source>src/main/java</source>
                        <source>src/main/kotlin</source>
                        <source>target/generated-sources/annotations</source>
                    </sourceDirs>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

main class

We modify the main class of the application: add the @EnableWebFlux and @EnableR2dbcRepositories annotations and specify the REACTIVE application type:

main class
package com.mlwebservice;

import org.springframework.boot.WebApplicationType;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
import org.springframework.boot.autoconfigure.gson.GsonAutoConfiguration;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.data.r2dbc.repository.config.EnableR2dbcRepositories;
import org.springframework.web.reactive.config.EnableWebFlux;

import java.net.InetAddress;
import java.net.UnknownHostException;

@EnableWebFlux
@EnableR2dbcRepositories
@SpringBootApplication(exclude = {
        GsonAutoConfiguration.class,
        CassandraAutoConfiguration.class
})
public class MLWebServiceApplication {
    public static void main(String[] args) {
        new SpringApplicationBuilder(MLWebServiceApplication.class)
                .web(WebApplicationType.REACTIVE)
                .run(args);
    }
}

Spring Data → R2DBC

Since the database entity uses a JSONB field (mapped in the application as a JsonNode), an R2DBC configuration with custom converters is required:

jsonb converters
package com.mlwebservice.config

import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import io.r2dbc.postgresql.codec.Json
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.core.convert.converter.Converter
import org.springframework.data.convert.ReadingConverter
import org.springframework.data.convert.WritingConverter
import org.springframework.data.r2dbc.convert.R2dbcCustomConversions
import org.springframework.data.r2dbc.dialect.PostgresDialect

@Configuration
open class R2dbcConfiguration(private val objectMapper: ObjectMapper) {

    @Bean
    open fun customConversions() : R2dbcCustomConversions {
        val converters = listOf<Converter<*, *>>(
            JsonNodeWritingConverter(objectMapper),
            JsonNodeReadingConverter(objectMapper)
        )
        return R2dbcCustomConversions.of(PostgresDialect.INSTANCE, converters);
    }
}

@WritingConverter
class JsonNodeWritingConverter(private val objectMapper: ObjectMapper) : Converter<JsonNode, Json> {
    override fun convert(source: JsonNode): Json {
        return Json.of(objectMapper.writeValueAsString(source));
    }
}

@ReadingConverter
class JsonNodeReadingConverter(private val objectMapper: ObjectMapper) : Converter<Json, JsonNode> {
    override fun convert(source: Json): JsonNode? {
        return objectMapper.readTree(source.asString());
    }
}

Next, the extra JPA annotations should be removed from the ModelEntity mentioned above (its mapping now uses the Spring Data relational annotations); the result should look like this:

ModelEntity
package com.mlwebservice.persist.entity

import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.node.ObjectNode
import org.springframework.data.annotation.CreatedDate
import org.springframework.data.annotation.Id
import org.springframework.data.annotation.LastModifiedDate
import org.springframework.data.relational.core.mapping.Column
import org.springframework.data.relational.core.mapping.Table
import java.time.LocalDateTime
import java.util.*

@Table(name = "models", schema = "instrument_data")
data class ModelEntity constructor(
    @Id
    val id: Long? = null,
  
    @Column("model")
    val model: ByteArray,

    @Column("created_at")
    val createdAt: LocalDateTime,

    @Column("last_trained_at")
    val lastTrainedAt: LocalDateTime,

    @Column("task_id")
    val taskId: UUID,

    @Column("parameters")
    val parameters: JsonNode
)
// constructors and other required members

The entity repository itself now extends R2dbcRepository:

@Repository
interface ModelRepository : R2dbcRepository<ModelEntity, Long>
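
If extra queries are needed later, the same interface can be extended with reactive signatures; a sketch with hypothetical methods (not part of the current repository):

// Hypothetical extension of the repository; methods return Mono/Flux instead of Optional/List
import org.springframework.data.r2dbc.repository.Query
import org.springframework.data.r2dbc.repository.R2dbcRepository
import org.springframework.stereotype.Repository
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono
import java.time.LocalDateTime
import java.util.UUID

@Repository
interface ModelRepository : R2dbcRepository<ModelEntity, Long> {

    // Hypothetical derived query: resolved against the task_id column, returns a Mono
    fun findByTaskId(taskId: UUID): Mono<ModelEntity>

    // Hypothetical explicit query returning a reactive stream of entities
    @Query("SELECT * FROM instrument_data.models WHERE created_at >= :from")
    fun findCreatedAfter(from: LocalDateTime): Flux<ModelEntity>
}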

The model save and load methods are reworked to fit WebFlux:

methods for working with data models

method of loading model from database

internal inline fun <reified T> loadModel(modelId: Long): T {
    val optional = modelRepository.findById(modelId)

    val entity = optional.get()
    val modelByteArray = entity.model

    val byteArrayInputStream = ByteArrayInputStream(modelByteArray)
    val modelObject = ObjectInputStream(byteArrayInputStream).use { it.readObject() }

    if (modelObject is T) {
        return modelObject
    } else {
        throw ServiceException.withMessage("Model id $modelId has incorrect format")
    }
}

modified to:

internal inline fun <reified T> loadModel(modelId: Long): Mono<T> =
        modelRepository.findById(modelId)
            .map { modelEntity: ModelEntity ->
                ByteArrayInputStream(modelEntity.model)
            }
            .publishOn(Schedulers.boundedElastic())
            .map { byteArrayInputStream: ByteArrayInputStream ->
                ObjectInputStream(byteArrayInputStream).use { it.readObject() }
            }
            .flatMap { modelObject ->
                if (modelObject is T) {
                    Mono.just(modelObject)
                } else {
                    Mono.error(ServiceException.withMessage("Model id $modelId has incorrect format"))
                }
            }

and the save method:

fun saveModel(
    model : PredictionModel<Vector, XGBoostRegressionModel>,
    taskId : UUID,
    modelParameters : AnalyticsRequest.ModelParameters
) {
    val byteArrayOutputStream = ByteArrayOutputStream()
    ObjectOutputStream(byteArrayOutputStream).use { it.writeObject(model) }
    val modelByteArray: ByteArray = byteArrayOutputStream.toByteArray()
    val jsonParams : JsonNode = objectMapper.convertValue(modelParameters, JsonNode::class.java)

    val entity = ModelEntity(modelByteArray, taskId, jsonParams)
    modelRepository.save(entity)
    log.info("Model for task id {} saved. Parameters map: {}, jsonNode: {}",
        taskId, modelParameters, jsonParams)
}

modified to:

fun saveModel(
        model: PredictionModel<Vector, XGBoostRegressionModel>,
        taskId: UUID,
        modelParameters: AnalyticsRequest.ModelParameters
    ): Mono<Void> =
        Mono.fromCallable {
            val jsonParams: JsonNode = objectMapper.convertValue(modelParameters, JsonNode::class.java)

            val byteArrayOutputStream = ByteArrayOutputStream()
            ObjectOutputStream(byteArrayOutputStream).use { objectOutputStream ->
                objectOutputStream.writeObject(model)
            }
            val modelByteArray: ByteArray = byteArrayOutputStream.toByteArray()

            ModelEntity(modelByteArray, taskId, jsonParams)
        }
            .subscribeOn(Schedulers.boundedElastic())
            .flatMap { entity ->
                modelRepository.save(entity)
                    .doOnSuccess {
                        log.info(
                            "Model for task id {} saved. Parameters map: {}, jsonNode: {}",
                            taskId, modelParameters, entity.parameters.toString()
                        )
                    }
                    .then()
            }

Cassandra

The Cassandra repositories are built around interaction with the Spark session, and reworking their methods is fairly easy. For example, the method that obtains a dataset in the base abstract repository:

cassandraDataset web
fun cassandraDataset(keyspace: String, table: String): Dataset<Row> {
    val cassandraDataset: Dataset<Row> = sparkSession.read()
        .format("org.apache.spark.sql.cassandra")
        .option("keyspace", keyspace)
        .option("table", table)
        .load()

    cassandraDataset.createOrReplaceTempView(table)
    return cassandraDataset
}

modified to:

cassandradataset webflux
fun cassandraDataset(keyspace: String, table: String): Mono<Dataset<Row>> =
    Mono.fromCallable {
        val cassandraDataset: Dataset<Row> = sparkSession.read()
            .format("org.apache.spark.sql.cassandra")
            .option("keyspace", keyspace)
            .option("table", table)
            .load()

        cassandraDataset.createOrReplaceTempView(table)
        cassandraDataset
    }
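
One caveat worth noting: sparkSession.read()…load() is a blocking call, so it may make sense to shift it off the Netty event loop, just as the model save/load methods above do. A possible variant (my assumption, not the code from the repository):

fun cassandraDataset(keyspace: String, table: String): Mono<Dataset<Row>> =
    Mono.fromCallable {
        val cassandraDataset: Dataset<Row> = sparkSession.read()
            .format("org.apache.spark.sql.cassandra")
            .option("keyspace", keyspace)
            .option("table", table)
            .load()

        cassandraDataset.createOrReplaceTempView(table)
        cassandraDataset
    }
        // shift the blocking Spark call onto the bounded-elastic scheduler (assumed, not the original code)
        .subscribeOn(Schedulers.boundedElastic())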

dataset saving method:

saveDataSet web
open fun saveDataSet(dataset: Dataset<Row>) {
    dataset.write()
        .format("org.apache.spark.sql.cassandra")
        .mode("append")
        .option("confirm.truncate", "false")
        .option("keyspace", keyspace)
        .option("table", table)
        .save();
}

modified to:

saveDataSet webflux
open fun saveDataSet(dataset: Dataset<Row>): Mono<Void> =
    Mono.fromRunnable {
        dataset.write()
            .format("org.apache.spark.sql.cassandra")
            .mode("append")
            .option("confirm.truncate", "false")
            .option("keyspace", keyspace)
            .option("table", table)
            .save()
    }

method for obtaining a base dataset with certain offsets:

getBaseDataSet web
fun getBaseDataSet(
    ticker: String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate,
    currentOffset : Int,
    batchSize : Int
): Dataset<Row> {
    val filteredDataset = cassandraDataset(table)
        .filter(
            functions.col("ticker").equalTo(ticker)
                .and(functions.col("task_number").equalTo(taskNumber.toString()))
                .and(functions.col("datetime").between(dateStart, dateEnd))
        )

    val offsetDataset = filteredDataset.withColumn(
        "row_number",
        functions.row_number().over(orderBy("datetime"))
    )

    return offsetDataset
        .filter(functions.col("row_number")
            .between(currentOffset + 1, currentOffset + batchSize))
        .drop("row_number")
}

modified to:

getBaseDataSet webflux
fun getBaseDataSet(
    ticker: String,
    taskNumber: UUID,
    dateStart: LocalDate,
    dateEnd: LocalDate,
    currentOffset: Int,
    batchSize: Int
): Mono<Dataset<Row>> =
    cassandraDataset(table)
        .map { dataset ->
            dataset
                .filter(
                    functions.col("ticker").equalTo(ticker)
                        .and(functions.col("task_number").equalTo(taskNumber.toString()))
                        .and(functions.col("datetime").between(dateStart, dateEnd))
                ).withColumn(
                    "row_number",
                    functions.row_number().over(orderBy("datetime"))
                )
                .filter(
                    functions.col("row_number")
                        .between(currentOffset + 1, currentOffset + batchSize)
                )
                .drop("row_number")
        }

The remaining repositories of specific tables are rewritten according to the same principle.
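
For illustration, one such table-specific repository might look roughly like this; the base class name, keyspace and table are assumptions, while getEmaDataSet and its parameters match what is used later in the article:

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions
import org.springframework.stereotype.Repository
import reactor.core.publisher.Mono
import java.time.LocalDate

// Hypothetical table-specific repository built on the same base class as getBaseDataSet above
@Repository
class EmaRepository(sparkSession: SparkSession) : BaseCassandraRepository(sparkSession) {

    fun getEmaDataSet(ticker: String, dateStart: LocalDate, dateEnd: LocalDate): Mono<Dataset<Row>> =
        cassandraDataset("instrument_data", "ema")   // assumed keyspace and table names
            .map { dataset ->
                dataset.filter(
                    functions.col("ticker").equalTo(ticker)
                        .and(functions.col("datetime").between(dateStart, dateEnd))
                )
            }
}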

In the data service, the method that combines the datasets also deserves a mention (the repositories now return a reactive Mono<Dataset<Row>>):

getMainDataset web
fun getMainDataset(
    ticker : String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate
) : Dataset<Row> {
    val timeSeries = timeSeriesRepository.getDataset(ticker, taskNumber, dateStart, dateEnd).`as`("ts")
    val emaDataSet = emaRepository.getEmaDataSet(ticker, dateStart, dateEnd).`as`("ema")
    val stochasticDataset = stochasticRepository.getStochasticDataSet(ticker, dateStart, dateEnd).`as`("stoch")
    val bBandsDataset = bBandIndicatorRepository.getBBandsDataSet(ticker, dateStart, dateEnd).`as`("bb")
    val macdDataset = macdRepository.getMacdDataSet(ticker, dateStart, dateEnd).`as`("macd")
    val rsiDataset = rsiRepository.getRsiDataSet(ticker, dateStart, dateEnd).`as`("rsi")
    val smaDataset = smaRepository.getSmaDataSet(ticker, dateStart, dateEnd).`as`("sma")
    val willrDataset = willrRepository.getWillrDataSet(ticker, dateStart, dateEnd).`as`("willr")

    return combineDatasets(
        timeSeries, emaDataSet, stochasticDataset, bBandsDataset, macdDataset, rsiDataset, smaDataset, willrDataset
    )
}

modified to:

getMainDataset webflux
fun getMainDataset(
    ticker: String,
    taskNumber: UUID,
    dateStart: LocalDate,
    dateEnd: LocalDate
): Mono<Dataset<Row>> {
    val timeSeriesMono = timeSeriesRepository.getDataset(ticker, taskNumber, dateStart, dateEnd)
        .map { dataset -> dataset.alias("ts") }
    val emaDataSetMono = emaRepository.getEmaDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("ema") }
    val stochasticDatasetMono = stochasticRepository.getStochasticDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("stoch") }
    val bBandsDatasetMono = bBandIndicatorRepository.getBBandsDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("bb") }
    val macdDatasetMono = macdRepository.getMacdDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("macd") }
    val rsiDatasetMono = rsiRepository.getRsiDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("rsi") }
    val smaDatasetMono = smaRepository.getSmaDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("sma") }
    val willrDatasetMono = willrRepository.getWillrDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("willr") }

    return Mono.zip(
        timeSeriesMono, emaDataSetMono, stochasticDatasetMono, bBandsDatasetMono,
        macdDatasetMono, rsiDatasetMono, smaDatasetMono, willrDatasetMono
    ).map { tuple ->
        combineDatasets(tuple.t1, tuple.t2, tuple.t3, tuple.t4, tuple.t5, tuple.t6, tuple.t7, tuple.t8)
    }
}

Here eight datasets are obtained as Mono wrappers, combined into a single Mono with Mono.zip(), and passed on to the dataset-combining method, which itself has not changed.

StockAnalyticsService

The prediction method that uses a stored model:

predictWithExistingModel web
fun predictWithExistingModel(
    ticker : String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate,
    modelId : Long
): StockPredictDto {
    val model: PredictionModel<Vector, XGBoostRegressionModel> = modelService.loadModel(modelId)
    val data = dataReaderService.getMainDataset(ticker, taskNumber, dateStart, dateEnd)

    var predictions = model.transform(data)
    predictions = predictions.select("dateTime", "prediction")
    return StockPredictDto.fromDataset(predictions)
}

modified to:

predictWithExistingModel webflux
fun predictWithExistingModel(
    ticker: String,
    taskNumber: UUID,
    dateStart: LocalDate,
    dateEnd: LocalDate,
    modelId: Long
): Mono<StockPredictDto> =
    modelService.loadModel<PredictionModel<Vector, XGBoostRegressionModel>>(modelId)
        .flatMap { model ->
            dataReaderService.getMainDataset(ticker, taskNumber, dateStart, dateEnd)
                .map { data ->
                    val predictions = model.transform(data)
                        .select("dateTime", "prediction")
                    StockPredictDto.fromDataset(predictions)
                }
        }

Model training method:

trainModel web
fun trainModel(
    ticker : String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate,
    evalPivotPoint : Long,
    offset : Long,
    modelParameters : AnalyticsRequest.ModelParameters
) : ModelTrainResultResponse {
    val pivot = dateEnd.minusDays(evalPivotPoint)

    val tdf = dataReaderService.getDatasetWithLabel(ticker, taskNumber, dateStart, pivot, offset)
    val edf = dataReaderService.getDatasetWithLabel(ticker, taskNumber, pivot, dateEnd, offset)
        .selectExpr(*allColumns)

    val modelParams = createModelParams(modelParameters)
    val regressor = xgBoostRegressor(modelParams)

    val model: PredictionModel<Vector, XGBoostRegressionModel> = regressor.fit(tdf)
    val predictions = model.transform(edf)

    combinedDataRepository.saveData(tdf.selectExpr(*allColumns).unionAll(edf), ticker, taskNumber)
    modelService.saveModel(model, taskNumber, modelParameters)

    val result = predictions.withColumn("error", col("prediction").minus(col(labelName)))
    return ModelTrainResultResponse(ModelTrainResult.listFromDataset(result.selectExpr(*resultExp)))
}

modified to:

trainModel webflux
fun trainModel(
        ticker: String,
        taskNumber: UUID,
        dateStart: LocalDate,
        dateEnd: LocalDate,
        evalPivotPoint: Long,
        offset: Long,
        modelParameters: AnalyticsRequest.ModelParameters
    ): Mono<ModelTrainResultResponse> =
        Mono.just(dateEnd.minusDays(evalPivotPoint))
            .flatMap { pivot: LocalDate ->
                dataReaderService.getDatasetWithLabel(ticker, taskNumber, dateStart, pivot, offset)
                    .zipWith(dataReaderService.getDatasetWithLabel(ticker, taskNumber, pivot, dateEnd, offset))
            }
            .flatMap { tuple: Tuple2<Dataset<Row>, Dataset<Row>> ->
                val tdf = tuple.t1
                val edf = tuple.t2

                val modelParams = createModelParams(modelParameters)
                val regressor = xgBoostRegressor(modelParams)

                Mono.fromCallable { regressor.fit(tdf) }
                    .flatMap { model: XGBoostRegressionModel ->
                        val predictions = model.transform(edf)

                        val saveDataMono = combinedDataRepository.saveData(
                            tdf.selectExpr(*allColumns).unionAll(edf),
                            ticker,
                            taskNumber
                        )

                        modelService.saveModel(model, taskNumber, modelParameters)
                            .then(saveDataMono)
                            .thenReturn(predictions)
                    }
            }
            .map { predictions: Dataset<Row> ->
                val result = predictions.withColumn("error", col("prediction").minus(col(labelName)))
                ModelTrainResultResponse(ModelTrainResult.listFromDataset(result.selectExpr(*resultExp)))
            }

Here tdf and edf are wrapped in Mono and zipped into a Mono<Tuple2>. The call regressor.fit(tdf) is then wrapped in Mono.fromCallable, so it runs asynchronously and yields the model: XGBoostRegressionModel. Inside flatMap the model is applied to the evaluation dataset to get predictions, the combined data and the model are saved to the database (the latter via the saveModel method described above), and the predictions are passed on. The rest of the logic is straightforward.

The incremental training method caused the most difficulty (yes, incremental training does not actually work with this model, and XGBoost would have to be replaced with a different one, but the goal here was to move the logic to the reactive stack and get a working example that could later be reused for real incremental training).

Source method:

incrementTrainModel web
fun incrementTrainModel(
    ticker : String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate,
    evalPivotPoint : Long,
    offset : Long,
    batchSize : Int,
    modelParameters : AnalyticsRequest.ModelParameters
) : ModelTrainResultResponse {
    val pivot = dateEnd.minusDays(evalPivotPoint)
    var currentBatchOffset = 0
    var i = 0

    val modelParams = createModelParams(modelParameters)
    val regressor = xgBoostRegressor(modelParams)

    var model: PredictionModel<Vector, XGBoostRegressionModel>? = null
    var predictions: Dataset<Row>? = null

    var tdf: Dataset<Row>?
    do {
        log.info("Iteration {}: currentOffset {}", i, currentBatchOffset)
        tdf = dataReaderService.getDatasetWithLabel(
            ticker, taskNumber, dateStart, pivot, offset, currentBatchOffset, batchSize
        )
        if (tdf.isEmpty) break

        model = regressor.fit(tdf)
        combinedDataRepository.saveData(tdf.selectExpr(*allColumns), ticker, taskNumber)

        currentBatchOffset += batchSize
        i++
    } while (tdf?.isEmpty == false)

    val edf = dataReaderService.getDatasetWithLabel(
        ticker, taskNumber, pivot, dateEnd, offset, 0, 100).selectExpr(*allColumns)
    if (model != null) {
        predictions = model.transform(edf)
    }
    combinedDataRepository.saveData(edf.selectExpr(*allColumns), ticker, taskNumber)
    modelService.saveModel(model!!, taskNumber, modelParameters)

    val result = predictions!!.withColumn("error", col("prediction").minus(col(labelName)))
    return ModelTrainResultResponse(ModelTrainResult.listFromDataset(result.selectExpr(*resultExp)))
}

modified to:

incrementTrainModel webflux
fun incrementTrainModel(
        ticker: String,
        taskNumber: UUID,
        dateStart: LocalDate,
        dateEnd: LocalDate,
        evalPivotPoint: Long,
        offset: Long,
        batchSize: Int,
        modelParameters: AnalyticsRequest.ModelParameters
    ): Mono<ModelTrainResultResponse> {
        val pivot = dateEnd.minusDays(evalPivotPoint)
        var currentBatchOffset = 0
        var i = 0

        val modelParams = createModelParams(modelParameters)
        val regressor = xgBoostRegressor(modelParams)

        var model: PredictionModel<Vector, XGBoostRegressionModel>? = null
        var tdf: Dataset<Row>? = null

        return Mono.defer {
            dataReaderService.getDatasetWithLabel(
                ticker, taskNumber, dateStart, pivot, offset, currentBatchOffset, batchSize
            )
        }
            .map { dataset ->
                tdf = dataset
                log.info("Iteration {}: currentOffset {}", i, currentBatchOffset)
                if (tdf?.isEmpty == true) {
                    log.warn(
                        "tdf is empty, no more data for learning, Iteration {}: currentOffset {}",
                        i, currentBatchOffset
                    )
                    Mono.empty()
                } else {
                    model = regressor.fit(tdf)
                    log.info("model trained, Iteration {}: currentOffset {}", i, currentBatchOffset)
                    currentBatchOffset += batchSize
                    i++
                    combinedDataRepository.saveData(tdf!!.selectExpr(*allColumns), ticker, taskNumber)
                        .thenReturn(currentBatchOffset + batchSize)
                }
            }
            .repeat { tdf?.isEmpty == false }
            .then(dataReaderService
                .getDatasetWithLabel(ticker, taskNumber, pivot, dateEnd, offset, 0, 100)
                .flatMap { edf ->
                    log.info("Got edf")
                    combinedDataRepository.saveData(edf.selectExpr(*allColumns), ticker, taskNumber)
                        .then(modelService.saveModel(model!!, taskNumber, modelParameters))
                        .thenReturn(model!!.transform(edf))
                }.map { predictions ->
                    log.info("Predictions stage")
                    val result = predictions?.withColumn(
                        "error", col("prediction")
                            .minus(col(labelName))
                    )
                    ModelTrainResultResponse(ModelTrainResult.listFromDataset(result!!.selectExpr(*resultExp)))
                })
            .doOnError { exception ->
                log.error("Error while increment learning; taskNumber = {}", taskNumber, exception)
                ModelTrainResultResponse()
            }
    }

Unlike Java, lambda expressions in Kotlin do not require captured variables to be effectively final, so the variables currentBatchOffset, i, model and tdf can be mutated during the main flow.

Here the function that fetches the dataset is wrapped in Mono.defer(). The point of this approach is that the function is not executed until the Mono is subscribed to, and .repeat() resubscribes as long as the condition tdf?.isEmpty == false holds.
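
A tiny standalone sketch of that defer + repeat mechanic (illustrative names only, unrelated to the project code) may make it easier to see why the state advances between subscriptions:

import reactor.core.publisher.Mono
import java.util.concurrent.atomic.AtomicInteger

// Standalone illustration of the Mono.defer + repeat pattern: the lambda inside defer
// is re-evaluated on every (re)subscription, so mutable state advances batch by batch.
fun main() {
    val counter = AtomicInteger(0)

    Mono.defer { Mono.just(counter.incrementAndGet()) }   // evaluated anew on each subscription
        .doOnNext { batch -> println("processing batch $batch") }
        .repeat { counter.get() < 3 }                      // resubscribe while the predicate is true
        .blockLast()                                       // blocking only for this demo
}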

When the next tdf comes back empty, the logic in then() takes over: the edf dataset is fetched from Cassandra and stored in the combined-data table, predictions are obtained from the model, and the model itself is saved. The method result is then built from those predictions. In case of an error, an empty result is returned.

This is not an ideal implementation of the method, but it works as an example.

Detailed implementation can be seen in a separate repository branch.

Comparison of two implementations

As is well known, the reactive stack differs from the servlet stack in that it often needs fewer resources to execute the same logic, and in some cases execution speed can even improve.

Testing took place according to the following methodology:

  1. The service runs in a Docker container with 4 CPUs and 4 GB of memory and uses a Spark Executor (v3.3.2, JDK 17), also in a Docker container, which connects to a Spark Standalone master in a virtual machine. Everything runs on a single machine under Windows 10 Pro; model training tasks run on an NVIDIA 4090 GPU.

  2. Within 10 minutes, requests are made to three endpoints: training a new model (POST /analytics, "request 1" for short), obtaining predictions with a saved model (GET /analytics, "request 2"), and incremental training (POST /analytics/increment, "request 3") with batch_size = 50 records, during which 12 iterations run over roughly six hundred records in the Cassandra tables. The first cycle runs on a "cold" driver (the first requests always take longer), followed by two identical cycles with one request of each type on a warmed-up driver; in the fourth cycle requests 1, 2, and 3 are launched simultaneously.

  3. The driver works in Spark cluster mode; a single Spark session is used for the entire lifetime of the application.

  4. The initial JVM startup options are the same: initial heap size 512 MB, maximum size not specified, default GC (G1).

Resource Consumption Results:

Stack            Max CPU   Avg CPU   Max memory, GB   Avg memory, GB   Stop-the-world pauses in 10 min
Spring Web       3.4       1.5       4                2                4
Spring WebFlux   3.4       1.1       1                0.5              0

With these parameters, the servlet stack showed four stop-the-world pauses from the G1 GC, and in one case the prediction request against a stored model ended with a server error.

The graph shows that memory consumption grows linearly until there is no free heap space left and a collection has to run.

Web application resource consumption graph with JVM parameters -Xms512m

The picture is different for the reactive stack: after the first requests, memory usage stabilizes at around 0.5 GB. In terms of CPU consumption the difference is not that large.

Resource consumption graph of a WebFlux application with JVM options -Xms512m

Request execution speed:

Comparison table of the Web and WebFlux application versions with JVM parameters -Xms512m

Top 5 classes by memory consumption:

Considering that the entire resulting dataset takes about 55 MB, this amount of allocated memory raises questions. Stack-trace analysis showed that in most cases the source is Spark and Rapids, which build the query plan, exchange data between the database, the executors and the driver, prepare data arrays for loading into the GPU and read the results back from it. After spending some time studying memory-usage optimization, I concluded that this is the normal behavior of the system in this configuration, and with the servlet stack you simply have to learn to live with it.

The first attempt to live with it was to change the JVM startup parameters for the servlet stack to the following: -Xms512m -Xmx3g -XX:GCTimeRatio=19 (a hard hint that the system may spend up to 5% of its time on garbage collection, i.e. 1 / (1 + 19)) -XX:+UseZGC. Considering that the reactive stack needs about 512 MB of memory on average, and that the Z GC consumes slightly more memory than the G1 GC, the maximum heap size was capped at 3 GB.

Web application resource consumption graph with JVM parameters -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC

CPU consumption decreased slightly and memory consumption stayed about the same, but no stop-the-world pauses were recorded anymore. Judging by the graph, the heap is cleared after the POST /analytics and GET /analytics methods complete, whereas with POST /analytics/increment the heap is cleared only as it approaches its maximum size. There is no logic that should lead to a memory leak, so the reason for such high memory consumption remains unclear.

Resource consumption results after the GC switch:

Stack            Max CPU   Avg CPU   Max memory, GB   Avg memory, GB   Stop-the-world pauses
Spring Web       3.4       1.5       3                1.5              0
Spring WebFlux   3.4       1.1       1                0.5              0

and query execution speed:

Comparison table of Web application request execution speed with JVM parameters -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC versus the WebFlux application with G1 GC and JVM parameters -Xms512m

It became interesting what would happen if the G1 GC were given the same maximum heap size as the Z GC and the same hard limit on garbage-collection time. It turned out that memory fills up as before, but there are more stop-the-world pauses: with less available memory the heap simply fills up faster. Resource consumption stayed at roughly the same level:

Web application resource consumption graph with JVM parameters -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseG1GC

Stack            Max CPU   Avg CPU   Max memory, GB   Avg memory, GB   Stop-the-world pauses
Spring Web       3.4       1.5       3                1.5              6
Spring WebFlux   3.4       1.1       1                0.5              0

The speed of query execution has increased, but not significantly.

Comparison table of query execution speed of the Web application with JVM parameters -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseG1GC versus the WebFlux application with G1 GC and JVM parameters -Xms512m

I also tried the Parallel GC on the servlet stack with the parameters -Xms512m -Xmx4g -XX:GCTimeRatio=19 -XX:+UseParallelGC. The results were the worst: in 10 minutes only 2 cycles could be completed. The first two methods ran in roughly the same time without deviations, but the incremental training method took 3 min 32 s on its first run, which is about 1.5 minutes worse than the servlet stack's average, and the second request hung and ran for 8 min 10 s. These results were not recorded in the tables.

Graph of resource consumption of the Web version of the application with Parallel GC

Finally, I applied the JVM settings -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC to the WebFlux version, which so far had been the most optimal in terms of resource consumption and request processing speed. Comparison tables of the version with default parameters and the G1 GC versus the version with custom JVM parameters and the Z GC are below.

Resource consumption graph:

Graph of resource consumption of the WebFlux application with JVM parameters -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC

Resource consumption table:

Stack                   Max CPU   Avg CPU   Max memory, GB   Avg memory, GB   Stop-the-world pauses
Spring WebFlux, G1 GC   3.4       1.1       1                0.5              0
Spring WebFlux, Z GC    3.4       1.4       2.84             1.5              0

Query execution speed table:

Comparison table of query execution speed of the WebFlux application with JVM parameters -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC versus the WebFlux application with G1 GC and JVM parameters -Xms512m

Summary tables with all versions and the different JVM configuration parameters are presented below, with the WebFlux version on the G1 GC and a single JVM setting (minimum heap size 512 MB) as the baseline.

Summary tables for resource consumption

Query runtime summary tables

Conclusion

Summing up after this third article on building a distributed machine learning system in Java and Kotlin, the main conclusion is that building such a system is hard: there are many unknowns and a lot of research is needed, but a working solution is quite achievable if the desire is there.

If it so happens that you need to run ML tasks on the JVM technology stack (rather than following the usual advice to "just learn Python") and have to pitch an alternative system to management, then Kotlin and Spring WebFlux (or, as an alternative, Spring Web with the Z GC), with Apache Spark as the foundation, would be a great choice. At the end of work on any application it is worth checking it with a profiler, since with default JVM parameters the application will most likely not run optimally.

Another question is whether this system is efficient in terms of performance and resource consumption. Without tests against an alternative system (for example, Python + Dask) I find it difficult to answer objectively. Perhaps in the future I will try to build such a system and write the alternative logic in Python; then there will be something to compare against, and a topic for another article.
