In the previous article, SPARK for "little ones", I wrote about Apache Spark's capabilities and features for data processing. We focused on the key functions of reading, processing, and storing data, with code examples to help beginners get started quickly.

In this article, we'll go deeper and look at optimization. We'll focus on basic concepts, query optimization, and joins. With examples, of course.


Why optimize Spark?

#1. Basic Optimization Concepts

Data schema
Caching and Persistence
Partition management

#2. Query optimization

Choosing the right operations
The principle of lazy evaluation
How to choose the right operations
Using map and flatMap
Using reduceByKey and groupByKey

The order of operations
The Importance of Performing Filtering and Aggregation Before Joins
Strategies for building queries from smaller tables to larger ones
Different strategies for reducing data before joining

Why optimize Spark?

There are several good reasons for this.

Reduced task completion time.

Optimization can significantly reduce the time required to perform various computational operations. This is especially important for processing large amounts of data, where even small improvements can lead to significant reductions in execution time.

Improving the efficiency of resource use.

Optimization helps to use computing resources such as CPU time and RAM more efficiently. This allows you to process more data with less effort and prevents system overload.

Improving application performance.

Optimized Spark applications run faster and with lower latency. This is especially important for real-time applications where data processing speed is critical.

Cost reduction.

Efficient use of resources also results in lower infrastructure costs, as less computing power and memory are required to process the same amounts of data.

Stability and reliability.

Optimization helps avoid performance issues such as freezes and crashes, making systems more stable and reliable in operation.


Scalability.

Optimized solutions are easier to scale because they use resources more efficiently and can handle increasing amounts of data without significantly increasing execution time.

Thus, Apache Spark optimization not only speeds up task execution and improves overall system performance, but also helps reduce operating costs and increase the reliability and scalability of solutions.

Let's move on to concepts.

#1. Basic Optimization Concepts

This chapter covers the data schema, caching and persistence, and partition management. Let's start with the data schema.

Data schema

Optimizing Apache Spark starts with effective data schema management. This is one of the key concepts that significantly impacts the performance of your applications.

Using schema instead of automatic type detection

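Spark can automatically detect data types when reading files such as CSV (with inferSchema enabled) and JSON; Parquet files already store their schema. This is convenient for rapid development. However, automatic data type detection (schema inference) can negatively impact performance.

Here's why: to infer the types, Spark has to read through the data an extra time before the actual load, and the types it guesses may not be the ones you expect.

Specifying the data schema explicitly avoids that extra pass and fixes the column names and types up front.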

Let's look at an example of reading a CSV file with an explicit data schema definition.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

# Create the SparkSession
spark = SparkSession.builder.appName("SchemaExample").getOrCreate()

# Define the data schema
schema = StructType([
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True),
    StructField("city", StringType(), True)
])

# Read the CSV file with the explicit schema
df = spark.read.schema(schema).csv("path/to/file.csv")

# Show the first 5 rows
df.show(5)

In this example, the schema declares three nullable columns (name, age, and city), and Spark applies it while reading the file instead of inferring the types.
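To make the cost of inference concrete, here is a plain-Python sketch (a simplified model, not Spark code). It guesses column types by scanning every row and then parses the data again, while the explicit-schema path parses once. The names infer_column_types, parse_with_schema, and SAMPLE_CSV are hypothetical and exist only for this illustration.

```python
import csv
import io

SAMPLE_CSV = "name,age,city\nAlice,34,Paris\nBob,45,Berlin\n"


def infer_column_types(text):
    """Scan every row to guess each column's type (int or str)."""
    reader = csv.DictReader(io.StringIO(text))
    types = {}
    rows_scanned = 0
    for row in reader:
        rows_scanned += 1
        for column, value in row.items():
            looks_int = value.lstrip("-").isdigit()
            # A column stays int only while every value seen so far is an integer.
            if types.get(column, int) is int and looks_int:
                types[column] = int
            else:
                types[column] = str
    return types, rows_scanned


def parse_with_schema(text, schema):
    """Parse rows in a single pass using a schema known up front."""
    reader = csv.DictReader(io.StringIO(text))
    return [{col: schema[col](val) for col, val in row.items()} for row in reader]


# With inference: one pass to guess the types, a second pass to read the data.
inferred, scanned = infer_column_types(SAMPLE_CSV)
rows = parse_with_schema(SAMPLE_CSV, inferred)

# With an explicit schema: a single pass.
explicit_schema = {"name": str, "age": int, "city": str}
rows_explicit = parse_with_schema(SAMPLE_CSV, explicit_schema)
```

Both paths produce the same rows here; the difference is that the inferred path had to touch every row once more before it could start parsing.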

Caching and Persistence

Caching and data persistence in Spark are important performance enhancements. They allow you to store the results of intermediate calculations in memory or on disk so that the calculations do not have to be repeated.

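When to use cache() and when persist(), and how?

Using cache(): call it when a DataFrame or RDD is reused by several actions and the default storage level is good enough.

Using persist(): call it when you want to choose the storage level yourself, for example to keep the data only on disk.

Difference between cache() and persist(): cache() is shorthand for persist() with the default storage level (memory with spill to disk for DataFrames, MEMORY_ONLY for RDDs), while persist() accepts an explicit StorageLevel.

Storage levels: the most commonly used ones are MEMORY_ONLY, MEMORY_AND_DISK, and DISK_ONLY.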

Example using cache().

from pyspark.sql import SparkSession

# Create the SparkSession
spark = SparkSession.builder.appName("CacheExample").getOrCreate()

# Read data from the CSV file
df = spark.read.csv("path/to/file.csv", header=True, inferSchema=True)

# Cache the DataFrame
df.cache()

# First action: count the rows
count = df.count()
print(f"Total count: {count}")

# Second action: filter the data
filtered_df = df.filter(df['age'] > 30)
filtered_count = filtered_df.count()
print(f"Filtered count: {filtered_count}")

# Release the cache
df.unpersist()

# Stop the SparkSession
spark.stop()

Example using persist().

from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel

# Create the SparkSession
spark = SparkSession.builder.appName("PersistExample").getOrCreate()

# Read data from the CSV file
df = spark.read.csv("path/to/file.csv", header=True, inferSchema=True)

# Persist the DataFrame using the MEMORY_AND_DISK storage level
df.persist(StorageLevel.MEMORY_AND_DISK)

# First action: count the rows
count = df.count()
print(f"Total count: {count}")

# Second action: filter the data
filtered_df = df.filter(df['age'] > 30)
filtered_count = filtered_df.count()
print(f"Filtered count: {filtered_count}")

# Release the persisted data
df.unpersist()

# Stop the SparkSession
spark.stop()

Examples showing the difference in performance.

Without caching.

from pyspark.sql import SparkSession

# Create the SparkSession
spark = SparkSession.builder.appName("NoCacheExample").getOrCreate()

# Read data from the CSV file
df = spark.read.csv("path/to/file.csv", header=True, inferSchema=True)

# First action: count the rows (the data is read from the file)
count = df.count()
print(f"Total count: {count}")

# Second action: filter the data and count the rows (the data is read again)
filtered_df = df.filter(df['age'] > 30)
filtered_count = filtered_df.count()
print(f"Filtered count: {filtered_count}")

# Stop the SparkSession
spark.stop()

With caching.

from pyspark.sql import SparkSession

# Create the SparkSession
spark = SparkSession.builder.appName("CacheExample").getOrCreate()

# Read data from the CSV file
df = spark.read.csv("path/to/file.csv", header=True, inferSchema=True)

# Cache the DataFrame
df.cache()

# First action: count the rows (the data is read and cached)
count = df.count()
print(f"Total count: {count}")

# Second action: filter the data and count the rows (the data comes from the cache)
filtered_df = df.filter(df['age'] > 30)
filtered_count = filtered_df.count()
print(f"Filtered count: {filtered_count}")

# Release the cache
df.unpersist()

# Stop the SparkSession
spark.stop()

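Difference in performance.

Without caching, each action re-reads and re-parses the CSV file. With caching, the first action reads the file and fills the cache, and the second action works from the cached data.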
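The effect can be sketched without a cluster. The plain-Python model below (not Spark code) counts how many times an expensive source is read by two actions, with and without a cache. CountingSource and LazyDataset are hypothetical names made up for this illustration.

```python
class CountingSource:
    """Stands in for an expensive input, such as a CSV file that must be parsed."""

    def __init__(self, rows):
        self._rows = rows
        self.reads = 0

    def load(self):
        self.reads += 1
        return list(self._rows)


class LazyDataset:
    """Recomputes from the source on every action unless cache() was called."""

    def __init__(self, source):
        self._source = source
        self._cache_enabled = False
        self._cached = None

    def cache(self):
        # As in Spark, cache() only marks the dataset; nothing is materialized yet.
        self._cache_enabled = True
        return self

    def _rows(self):
        if self._cache_enabled:
            if self._cached is None:
                self._cached = self._source.load()  # filled by the first action
            return self._cached
        return self._source.load()

    def count(self):
        return len(self._rows())

    def count_where(self, predicate):
        return sum(1 for row in self._rows() if predicate(row))


ages = [25, 34, 45, 29, 50]

uncached_source = CountingSource(ages)
uncached = LazyDataset(uncached_source)
uncached.count()
uncached.count_where(lambda age: age > 30)

cached_source = CountingSource(ages)
cached = LazyDataset(cached_source).cache()
cached.count()
cached.count_where(lambda age: age > 30)

print(f"reads without cache: {uncached_source.reads}, with cache: {cached_source.reads}")
```

In this model the uncached dataset reads its source once per action, while the cached one reads it only for the first action.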

Partition management

Managing data partitions in Apache Spark is an important part of performance optimization. Properly partitioning data allows for more efficient use of resources and faster task execution.

Splitting data into partitions

When data is loaded into Spark, it is automatically partitioned.

A partition is a logical chunk of the data that is processed by a single task, which makes partitions the unit of parallelism in a Spark cluster. The default number of partitions depends on the data source and Spark configuration.

How to set up the number of partitions correctly?

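Setting the number of partitions depends on the volume of data and available resources in the cluster. Basic recommendations include: aim for a few partitions per CPU core (the Spark tuning guide suggests 2-3 tasks per core), avoid a handful of very large partitions that leave cores idle, and avoid a huge number of tiny partitions whose scheduling overhead outweighs the work.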
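As a rough illustration of how data volume and cluster resources feed into that choice, here is a small arithmetic sketch. The helper suggest_partitions and its parameters are hypothetical; the default of 3 tasks per core follows the commonly quoted 2-3 tasks per core guideline, and the 128 MB target partition size is an assumption that you would tune for your own workload.

```python
import math


def suggest_partitions(total_cores, data_size_mb, tasks_per_core=3, target_partition_mb=128):
    """Return a partition count that satisfies both rules of thumb."""
    # Enough partitions to keep every core busy.
    by_cores = total_cores * tasks_per_core
    # Enough partitions that none is much larger than the target size.
    by_size = math.ceil(data_size_mb / target_partition_mb)
    return max(by_cores, by_size, 1)


small_job = suggest_partitions(total_cores=16, data_size_mb=1_000)
large_job = suggest_partitions(total_cores=16, data_size_mb=50_000)
print(small_job, large_job)
```

For the small input the core count dominates, and for the large input the data size does; the result is only a starting point to validate against real job metrics.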

Using repartition and coalesce

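Spark provides the repartition and coalesce methods to change the number of partitions. repartition can increase or decrease the number of partitions and performs a full shuffle of the data; coalesce only decreases the number of partitions and avoids a full shuffle by merging existing ones.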
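The difference between the two can be modelled with plain lists. This is a simplified sketch, not Spark's actual algorithm: the functions coalesce and repartition below are stand-ins written for this illustration, and real Spark also takes data locality into account when merging partitions.

```python
def coalesce(partitions, n):
    """Merge neighbouring partitions into n groups; rows inside a partition stay together."""
    if n >= len(partitions):
        return [list(p) for p in partitions]  # coalesce does not increase the count
    merged = [[] for _ in range(n)]
    for index, partition in enumerate(partitions):
        merged[index * n // len(partitions)].extend(partition)
    return merged


def repartition(partitions, n):
    """Redistribute every row round-robin across n partitions (a full shuffle)."""
    result = [[] for _ in range(n)]
    position = 0
    for partition in partitions:
        for row in partition:
            result[position % n].append(row)
            position += 1
    return result


# One large partition and three tiny ones.
skewed = [[1, 2, 3, 4, 5, 6], [7], [8], [9]]

coalesced = coalesce(skewed, 2)
rebalanced = repartition(skewed, 2)

print([len(p) for p in coalesced])
print([len(p) for p in rebalanced])
```

In the model, coalesce is cheap because whole partitions are glued together, but the result stays skewed; repartition moves every row and ends up balanced.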

An example of splitting data into partitions during loading.

from pyspark.sql import SparkSession

# Create the SparkSession
spark = SparkSession.builder.appName("PartitioningExample").getOrCreate()

# Read data from the CSV file (it is split into partitions by default)
df = spark.read.csv("path/to/file.csv", header=True, inferSchema=True)

# Check the number of partitions
print(f"Number of partitions: {df.rdd.getNumPartitions()}")

An example of setting the number of partitions when loading data.

# Load the data and set the number of partitions
df = spark.read.csv("path/to/file.csv", header=True, inferSchema=True).repartition(100)

# Check the number of partitions
print(f"Number of partitions after repartition: {df.rdd.getNumPartitions()}")

# Increase the number of partitions to 100
df_repartitioned = df.repartition(100)

# Check the number of partitions after repartition
print(f"Number of partitions after repartition: {df_repartitioned.rdd.getNumPartitions()}")

# Reduce the number of partitions to 10
df_coalesced = df.coalesce(10)

# Check the number of partitions after coalesce
print(f"Number of partitions after coalesce: {df_coalesced.rdd.getNumPartitions()}")

An example of using repartition to increase the number of partitions.

from pyspark.sql import SparkSession

# Create the SparkSession
spark = SparkSession.builder.appName("RepartitionExample").getOrCreate()

# Read data from the CSV file
df = spark.read.csv("path/to/file.csv", header=True, inferSchema=True)

# Initial number of partitions
print(f"Initial number of partitions: {df.rdd.getNumPartitions()}")

# Increase the number of partitions to 100
df_repartitioned = df.repartition(100)

# Check the number of partitions after repartition
print(f"Number of partitions after repartition: {df_repartitioned.rdd.getNumPartitions()}")

# Stop the SparkSession
spark.stop()

An example of using coalesce to reduce the number of partitions.

from pyspark.sql import SparkSession

# Create the SparkSession
spark = SparkSession.builder.appName("CoalesceExample").getOrCreate()

# Read data from the CSV file
df = spark.read.csv("path/to/file.csv", header=True, inferSchema=True)

# Initial number of partitions
print(f"Initial number of partitions: {df.rdd.getNumPartitions()}")

# Reduce the number of partitions to 10
df_coalesced = df.coalesce(10)

# Check the number of partitions after coalesce
print(f"Number of partitions after coalesce: {df_coalesced.rdd.getNumPartitions()}")

# Stop the SparkSession
spark.stop()

#2. Query optimization

Optimizing queries in Apache Spark involves choosing the right operations and understanding principles such as lazy evaluation. Let's look at the difference between transformations (e.g. map, filter) and actions (e.g. count, collect), and how to choose the right operations to improve performance.

Choosing the right operations

In Spark, all operations can be divided into two categories: transformations and actions.


Transformations.

These are operations that create a new distributed data set (RDD) from an existing one, but do not perform any computations immediately. Transformations are lazy, meaning they are not executed until an action is called.

Example of map transformation.

sc = spark.sparkContext  # SparkContext taken from an existing SparkSession
rdd = sc.parallelize([1, 2, 3, 4, 5])
squared_rdd = rdd.map(lambda x: x * x)

filter, flatMap, reduceByKey, and groupByKey are also transformations.

Actions.

These are operations that trigger computation and return a result. Calling an action causes all preceding transformations to be executed.

Examples of actions are count, collect, take, saveAsTextFile.

Count action example.

count = squared_rdd.count()
print(f"Number of elements: {count}")

The principle of lazy evaluation

Lazy evaluation is a key principle of Spark, which defers the execution of transformations until an action is called. This allows Spark to optimize the execution plan by combining transformations and minimizing the number of passes over the data.


Example of lazy evaluation.

from pyspark.sql import SparkSession

# Create the SparkSession
spark = SparkSession.builder.appName("LazyEvaluationExample").getOrCreate()

# Read data from the CSV file
df = spark.read.csv("path/to/file.csv", header=True, inferSchema=True)

# Transformations (lazy, nothing is filtered or selected yet)
filtered_df = df.filter(df['age'] > 30)
selected_df = filtered_df.select("name", "age")

# Action (triggers execution of all preceding transformations)
result = selected_df.collect()

# Print the results
for row in result:
    print(row)

# Stop the SparkSession
spark.stop()
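Python generators behave in a similar deferred way, which makes them a handy mental model for lazy evaluation. The sketch below is plain Python, not Spark; the log list and the read_rows function are hypothetical helpers used only to record when rows are actually read.

```python
log = []


def read_rows():
    """Yield rows one by one, recording each read in the log."""
    for row in [("Alice", 34), ("Bob", 45), ("Charlie", 29)]:
        log.append(f"read {row[0]}")
        yield row


# "Transformations": building the pipeline runs nothing yet.
older = (row for row in read_rows() if row[1] > 30)
names = (name for name, _age in older)
steps_before_action = len(log)

# "Action": consuming the pipeline pulls each row through all steps in one pass.
result = list(names)

print(steps_before_action, result, log)
```

No row is read while the pipeline is being defined; the filter and the projection are applied together, row by row, only when the final list is requested.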

How to choose the right operations?

Use transformations to reduce data before calling actions.

If possible, avoid calling actions that collect data on the driver (e.g. collect) on large data sets, as this can exhaust the driver's memory. Instead, use transformations to minimize the amount of data before calling the action.

Bad example: collect on big data.

large_rdd = sc.parallelize(range(1000000))
collected_data = large_rdd.collect()  # May exhaust the driver's memory

A good example: reducing data before collect.

filtered_rdd = large_rdd.filter(lambda x: x % 2 == 0)
small_collected_data = filtered_rdd.take(10)  # Safer, because only a small amount of data is collected

Using map and flatMap

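Use map and flatMap to transform data: map produces exactly one output element for each input element, while flatMap can produce zero or more output elements for each input element and flattens them into a single collection.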

Example of using map.

rdd = sc.parallelize(["apple", "banana", "cherry"])
length_rdd = rdd.map(lambda x: len(x))

Example of using flatMap.

words_rdd = rdd.flatMap(lambda x: x.split("a"))
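For comparison, the same two shapes can be written in plain Python (a sketch without Spark; the sample lines are made up for this illustration). A list comprehension plays the role of map, and itertools.chain.from_iterable plays the role of flatMap.

```python
from itertools import chain

lines = ["hello world", "apache spark", ""]

# map: exactly one output element (here, a list of words) per input element.
mapped = [line.split() for line in lines]

# flatMap: each input yields zero or more elements, flattened into one sequence.
flat_mapped = list(chain.from_iterable(line.split() for line in lines))

print(mapped)
print(flat_mapped)
```

The mapped result keeps three elements, one per input line, including an empty list for the empty line; the flat-mapped result contains only the four words.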

Using reduceByKey and groupByKey

For data aggregation, use reduceByKey instead of groupByKey, since reduceByKey aggregates data locally on each node before sending it over the network, which reduces the amount of data transferred.

Example of using reduceByKey.

pairs = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
reduced_pairs = pairs.reduceByKey(lambda x, y: x + y)

Example of using groupByKey (less efficient).

grouped_pairs = pairs.groupByKey()
print([(x, list(y)) for x, y in grouped_pairs.collect()])
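The saving from local aggregation can be counted in a small plain-Python model (not Spark code). Two lists stand in for partitions on two nodes, and the hypothetical functions group_by_key_shuffle and reduce_by_key_shuffle report how many records would cross the network in each style.

```python
from collections import defaultdict

# Two "nodes", each holding one partition of (key, value) pairs.
partitions = [
    [("a", 1), ("b", 1), ("a", 1), ("a", 1)],
    [("a", 1), ("b", 1), ("b", 1)],
]


def group_by_key_shuffle(parts):
    """groupByKey style: every record is shuffled, then values are summed."""
    shuffled = [pair for part in parts for pair in part]
    totals = defaultdict(int)
    for key, value in shuffled:
        totals[key] += value
    return dict(totals), len(shuffled)


def reduce_by_key_shuffle(parts):
    """reduceByKey style: combine inside each partition first, then shuffle partial sums."""
    shuffled = []
    for part in parts:
        local = defaultdict(int)
        for key, value in part:
            local[key] += value
        shuffled.extend(local.items())
    totals = defaultdict(int)
    for key, value in shuffled:
        totals[key] += value
    return dict(totals), len(shuffled)


grouped_totals, grouped_records = group_by_key_shuffle(partitions)
reduced_totals, reduced_records = reduce_by_key_shuffle(partitions)

print(grouped_totals, grouped_records)
print(reduced_totals, reduced_records)
```

Both styles give the same totals, but the pre-combined version shuffles at most one record per key per partition, so the gap grows with the number of repeated keys.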

Example of query optimization.

Let's look at an example where we want to calculate the average age of users over 30 years old from a CSV file.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Create the SparkSession
spark = SparkSession.builder.appName("QueryOptimizationExample").getOrCreate()

# Read data from the CSV file
df = spark.read.csv("path/to/file.csv", header=True, inferSchema=True)

# Query optimization
# Transformations: filter first, then compute the sum and the count in one aggregation
filtered_df = df.filter(df['age'] > 30)
totals_df = filtered_df.agg(F.sum("age").alias("sum_age"), F.count("age").alias("count_age"))

# Action: a single collect executes all the transformations and returns the result
totals = totals_df.collect()[0]

average_age = totals["sum_age"] / totals["count_age"]
print(f"Average age: {average_age}")

# Stop the SparkSession
spark.stop()

In this example, we minimize the data volume by applying filtering before aggregation and avoiding redundant actions. All transformations are combined and executed together when the action is called, allowing Spark to optimize query execution.

The order of operations

The correct order of operations in Spark has a significant impact on the performance of applications. In particular, performing filtering and aggregation before join operations and strategically building queries from smaller tables to larger ones can significantly improve the efficiency of data processing.

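Why is it important to perform filtering and aggregation before joins? A join usually has to shuffle rows between nodes to bring matching keys together, so every row removed beforehand is a row that does not have to be moved or compared.

An example of performing filtering before a join.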

from pyspark.sql import SparkSession

# Create the SparkSession
spark = SparkSession.builder.appName("OptimizationExample").getOrCreate()

# Sample data
data1 = [("Alice", 34, "HR"), ("Bob", 45, "IT"), ("Charlie", 29, "HR"), ("David", 40, "Finance")]
data2 = [("HR", "Human Resources"), ("IT", "Information Technology"), ("Finance", "Financial Department")]

df1 = spark.createDataFrame(data1, ["name", "age", "dept"])
df2 = spark.createDataFrame(data2, ["dept", "department_name"])

# Filter before the join
filtered_df1 = df1.filter(df1['age'] > 30)

# Join
joined_df = filtered_df1.join(df2, filtered_df1.dept == df2.dept)

# Show the results
joined_df.show()

# Stop the SparkSession
spark.stop()

In this example, the df1 data is first filtered to contain only records with an age greater than 30. Only then is the join with df2 performed. This sequence reduces the amount of data involved in the join and improves performance.
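The effect of the ordering can be counted with a toy hash join in plain Python (a sketch, not Spark's join implementation; hash_join is a hypothetical helper). Spark's optimizer can often push such filters below a join on its own for DataFrame queries, so treat this as an illustration of the principle rather than a measurement.

```python
def hash_join(left, right, key_left, key_right):
    """Inner-join two lists of tuples; also report how many left rows were probed."""
    index = {}
    for row in right:
        index.setdefault(row[key_right], []).append(row)
    joined, probed = [], 0
    for row in left:
        probed += 1
        for match in index.get(row[key_left], []):
            # Append the right-hand columns, skipping the duplicate join key.
            joined.append(row + tuple(v for i, v in enumerate(match) if i != key_right))
    return joined, probed


employees = [("Alice", 34, "HR"), ("Bob", 45, "IT"), ("Charlie", 29, "HR"), ("David", 40, "Finance")]
departments = [("HR", "Human Resources"), ("IT", "Information Technology"), ("Finance", "Financial Department")]

# Filter first, then join.
older = [e for e in employees if e[1] > 30]
joined_early, probed_early = hash_join(older, departments, 2, 0)

# Join first, then filter.
joined_all, probed_late = hash_join(employees, departments, 2, 0)
joined_late = [row for row in joined_all if row[1] > 30]

print(probed_early, probed_late)
```

Both orders return the same rows, but filtering first sends fewer rows into the join.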

Strategies for building queries from smaller tables to larger ones

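Why is it important? Joining the smaller tables first keeps the intermediate results small, so less data is carried into the later, more expensive joins with the large tables.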

An example of a query building strategy.

Let's look at an example where we need to perform multiple joins, starting with smaller tables.

from pyspark.sql import SparkSession

# Create the SparkSession
spark = SparkSession.builder.appName("JoinOptimizationExample").getOrCreate()

# Sample data
data_small = [("Alice", 34, "HR"), ("Bob", 45, "IT")]
data_medium = [("HR", "Human Resources"), ("IT", "Information Technology"), ("Finance", "Financial Department")]
data_large = [("HR", 1), ("IT", 2), ("Finance", 3), ("HR", 4), ("IT", 5), ("Finance", 6)]

df_small = spark.createDataFrame(data_small, ["name", "age", "dept"])
df_medium = spark.createDataFrame(data_medium, ["dept", "department_name"])
df_large = spark.createDataFrame(data_large, ["dept", "id"])

# Join the small table with the medium one
joined_small_medium = df_small.join(df_medium, "dept")

# Filter after the first join
filtered_join = joined_small_medium.filter(joined_small_medium['age'] > 30)

# Join with the large table
final_join = filtered_join.join(df_large, "dept")

# Show the results
final_join.show()

# Stop the SparkSession
spark.stop()

In this example, the small table is first joined with the medium one, the result is filtered by age, and only then is it joined with the large table.

Different strategies for reducing data before joining

An example of using aggregation before joining.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Create the SparkSession
spark = SparkSession.builder.appName("AggregationBeforeJoinExample").getOrCreate()

# Sample data
data1 = [("Alice", 34, "HR"), ("Bob", 45, "IT"), ("Charlie", 29, "HR"), ("David", 40, "Finance"), ("Eve", 50, "IT")]
data2 = [("HR", "Human Resources"), ("IT", "Information Technology"), ("Finance", "Financial Department")]

df1 = spark.createDataFrame(data1, ["name", "age", "dept"])
df2 = spark.createDataFrame(data2, ["dept", "department_name"])

# Aggregate before the join (average age per department);
# the alias is applied to the column so that it is named avg_age
agg_df1 = df1.groupBy("dept").agg(F.avg("age").alias("avg_age"))

# Join with the departments table
joined_df = agg_df1.join(df2, "dept")

# Show the results
joined_df.show()

# Stop the SparkSession
spark.stop()

In this example, we first perform an aggregation to calculate the average age by department, and then join the results to the departments table. This reduces the amount of data involved in the join and improves performance.
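The reduction in join input can be seen with the same sample data in plain Python (a sketch without Spark; the variable names are made up for this illustration). Five employee rows collapse into three aggregated rows before the department names are looked up.

```python
from collections import defaultdict

employees = [("Alice", 34, "HR"), ("Bob", 45, "IT"), ("Charlie", 29, "HR"),
             ("David", 40, "Finance"), ("Eve", 50, "IT")]
department_names = {"HR": "Human Resources", "IT": "Information Technology",
                    "Finance": "Financial Department"}

# Aggregate first: one row per department instead of one row per employee.
ages_by_dept = defaultdict(list)
for _name, age, dept in employees:
    ages_by_dept[dept].append(age)
avg_age_by_dept = {dept: sum(ages) / len(ages) for dept, ages in ages_by_dept.items()}

# Join second: only the aggregated rows are looked up in the departments table.
report = {department_names[dept]: avg for dept, avg in avg_age_by_dept.items()}

rows_into_join_without_agg = len(employees)
rows_into_join_with_agg = len(avg_age_by_dept)

print(report)
print(rows_into_join_without_agg, rows_into_join_with_agg)
```

On real data the ratio is usually far larger, since a table with millions of rows may aggregate down to a few hundred keys.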


I hope this article helped you understand the basics of Apache Spark performance optimization and answered some of your questions. Happy data mining!

