Applications of Continuation-passing style in Go

In this article, we'll look at continuation-passing style (CPS) and examples of its use, explore how it can improve the readability and maintainability of Go code, and discuss potential pitfalls and limitations, to give a full picture of how and when to use it in practice.

Introduction

In an ordinary (direct-style) call, we pass arguments to a function and expect a value back. For example, the addition function:

func add(x, y int) int {
	return x + y
}

// Using the add function
res := add(1, 2)
fmt.Println(res)

In continuation-passing style (CPS), an extra parameter is added to the parameter list: the continuation function k. Instead of returning its result, the function passes it to k:

func addCps(x, y int, k func(res int)) {
	k(x + y)
}

// Using the addCps function
addCps(1, 2, func(res int) { fmt.Println(res) })
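CPS functions compose by nesting continuations: each intermediate result is handed to the next step instead of being returned. A minimal sketch computing (1 + 2) * 3; mulCps is a hypothetical companion to addCps, written the same way:

```go
package main

import "fmt"

// addCps passes the sum to the continuation k instead of returning it.
func addCps(x, y int, k func(res int)) {
	k(x + y)
}

// mulCps is a hypothetical helper in the same style as addCps.
func mulCps(x, y int, k func(res int)) {
	k(x * y)
}

func main() {
	// Direct style would be: fmt.Println((1 + 2) * 3)
	// In CPS, the sum flows into the next continuation, which multiplies it.
	addCps(1, 2, func(sum int) {
		mulCps(sum, 3, func(product int) {
			fmt.Println(product) // 9
		})
	})
}
```

The nesting depth mirrors the order of evaluation, which is exactly what gives the callee control over when (and whether) the rest of the computation runs.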

With CPS we gain capabilities that direct-style programming lacks: the function now controls the flow of execution. Depending on its internal logic, it can call the continuation twice, save it and call it later, or not call it at all:

func addCps(x, y int, k func(res int)) {
	if x == 0 {
		// Return without calling the continuation
		return
	}
	if x >= y {
		// Call the continuation twice: first with the sum, then with zero
		k(x + y)
		k(0)
	} else {
		// Call the continuation once
		k(x + y)
	}
}
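The example above covers calling the continuation twice or not at all. The remaining case, saving the continuation and running it later, can be sketched with a simple queue. The names addLater, pending and runPending are hypothetical, chosen only for this illustration:

```go
package main

import "fmt"

// pending holds saved continuations (a hypothetical queue for illustration).
var pending []func()

// addLater computes x + y but does not call k right away:
// it stores a closure that will pass the result to k later.
func addLater(x, y int, k func(res int)) {
	sum := x + y
	pending = append(pending, func() { k(sum) })
}

// runPending executes the saved continuations in FIFO order and empties the queue.
func runPending() {
	for len(pending) > 0 {
		next := pending[0]
		pending = pending[1:]
		next()
	}
}

func main() {
	addLater(1, 2, func(res int) { fmt.Println("first:", res) })
	addLater(10, 20, func(res int) { fmt.Println("second:", res) })
	fmt.Println("nothing printed yet")
	runPending() // prints "first: 3", then "second: 30"
}
```

This is the same mechanism that schedulers and event loops rely on: the caller hands over "the rest of the work", and the callee decides when it happens.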

Separation of business logic and service code

Suppose we have a linked-list data structure, LinkedList:

type LinkedList struct {
	Head int
	Tail *LinkedList
}

Let's add a method that prints its contents to the console:

func (list *LinkedList) Print() {
	// Service code
	for cur := list; cur != nil; cur = cur.Tail {
		// Business logic
		fmt.Printf("%d ", cur.Head)
	}
}

Now suppose we also need to compute, say, the sum of the list's elements:

func (list *LinkedList) Sum() int {
	sum := 0
	// Service code
	for cur := list; cur != nil; cur = cur.Tail {
		// Business logic
		sum += cur.Head
	}
	return sum
}

This smells like copy-paste. And if LinkedList is a library data structure, every user will have to copy the same boilerplate too. It is better to encapsulate the traversal (service) code and put the business logic behind a continuation:

func (list *LinkedList) Traverse(k func(val int)) {
	for cur := list; cur != nil; cur = cur.Tail {
		k(cur.Head)
	}
}

Let's refactor the Print and Sum methods:

func (list *LinkedList) Print() {
	list.Traverse(func(val int) { fmt.Printf("%d ", val) })
}

func (list *LinkedList) Sum() int {
	sum := 0
	list.Traverse(func(val int) { sum += val })
	return sum
}
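Any new operation now reuses the same loop. As a sketch, here is a ToSlice method built on Traverse; ToSlice and the FromSlice helper are hypothetical additions, not part of the original code:

```go
package main

import "fmt"

type LinkedList struct {
	Head int
	Tail *LinkedList
}

// Traverse visits every element and hands it to the continuation k.
func (list *LinkedList) Traverse(k func(val int)) {
	for cur := list; cur != nil; cur = cur.Tail {
		k(cur.Head)
	}
}

// ToSlice (hypothetical) reuses Traverse, so the loop is written only once.
func (list *LinkedList) ToSlice() []int {
	var out []int
	list.Traverse(func(val int) { out = append(out, val) })
	return out
}

// FromSlice (hypothetical) builds a list from a slice, preserving order.
func FromSlice(xs []int) *LinkedList {
	var list *LinkedList
	for i := len(xs) - 1; i >= 0; i-- {
		list = &LinkedList{Head: xs[i], Tail: list}
	}
	return list
}

func main() {
	list := FromSlice([]int{3, 1, 4, 1, 5})
	fmt.Println(list.ToSlice()) // [3 1 4 1 5]
}
```

Because Traverse has a pointer receiver and only loops while cur != nil, calling it on a nil (empty) list is safe and simply visits nothing.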

The Traverse method always walks the list from start to finish. That is not always needed: sometimes the iteration should stop early once some condition is met. Let the caller decide when to stop; to do this, we change the continuation's return type from nothing (void) to bool.

Let true mean "keep iterating" and false mean "stop".

func (list *LinkedList) Traverse2(k func(val int) bool) {
	for cur := list; cur != nil; cur = cur.Tail {
		// Now check what the continuation returned
		keepGoing := k(cur.Head)

		if !keepGoing {
			// keepGoing == false signals that we should stop
			break
		}
	}
}

Based on Traverse2, we can write a search method Contains(x) that stops iterating as soon as it finds the first occurrence:

func (list *LinkedList) Contains(x int) bool {
	found := false
	list.Traverse2(func(val int) bool {
		if val == x {
			found = true
			return false
		}
		return true
	})
	return found
}
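The same early-exit protocol supports other searches. A sketch of a hypothetical IndexOf method on top of Traverse2, plus a visit counter that shows the traversal really stops at the first false:

```go
package main

import "fmt"

type LinkedList struct {
	Head int
	Tail *LinkedList
}

func (list *LinkedList) Traverse2(k func(val int) bool) {
	for cur := list; cur != nil; cur = cur.Tail {
		if !k(cur.Head) {
			break
		}
	}
}

// IndexOf (hypothetical) returns the position of the first occurrence of x,
// or -1 if x is absent; it stops at the first match.
func (list *LinkedList) IndexOf(x int) int {
	idx, found := 0, false
	list.Traverse2(func(val int) bool {
		if val == x {
			found = true
			return false // stop iterating
		}
		idx++
		return true // keep going
	})
	if !found {
		return -1
	}
	return idx
}

func main() {
	list := &LinkedList{7, &LinkedList{8, &LinkedList{9, nil}}}
	fmt.Println(list.IndexOf(8))  // 1
	fmt.Println(list.IndexOf(42)) // -1

	// Count how many elements the continuation sees before it returns false.
	visited := 0
	list.Traverse2(func(val int) bool {
		visited++
		return val != 7
	})
	fmt.Println(visited) // 1
}
```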

Continuations in Go stdlib

A similar approach is used in the recently introduced range-over-func feature [1], which appeared as an experiment in Go 1.22 and became part of the language in Go 1.23.

An iterator for range differs only slightly from our handwritten one:

func (list *LinkedList) Iter() func(func(int) bool) {
	iterator := func(yield func(int) bool) {
		for cur := list; cur != nil; cur = cur.Tail {
			if !yield(cur.Head) {
				return
			}
		}
	}
	return iterator
}

The main difference is that Iter() does not take the continuation as an argument; instead, it is written in curried form: Iter() returns the iterator, and the iterator takes the continuation. Iter() can be thought of as a constructor or “builder” for the iterator, a place for initialization code and other settings.

The iterator body itself matches our Traverse2, up to renaming k to yield. Note that the iterator expects a continuation, whereas the body of a range loop is not a function. The compiler takes care of this by automatically converting the loop body into a function of the form func(...) bool, in which a break becomes return false.

Example of using Iter:

// A function that prints to the console
printIt := func(x int) bool { println(x); return true }

// The iterator can be used in a range loop
for x := range list.Iter() {
	println(x)
}

// It can also be called directly
iterator := list.Iter()
// The iterator is a function that must be applied to a continuation
iterator(printIt)

// A shorter form
list.Iter()(printIt)

Trampolines

Many algorithms are elegantly solved with recursion: processing tree-shaped data (JSON, XML, directory structures), graphs, and so on. Recursion works well, with one caveat: most programming languages limit the depth of the call stack, and exceeding it causes a stack overflow error. Go is no exception: goroutine stacks grow dynamically, but only up to a maximum (1 GB by default on 64-bit platforms), after which the program crashes.

There are several ways around this: rewrite the algorithm without recursion, or convert it to tail-recursive form so that the compiler can apply tail-call optimization (TCO).

The Go compiler does not perform TCO, but we can get a similar effect manually with a technique called a trampoline. The pattern works like this: instead of returning a result, a function returns the next step of the computation as a continuation; the trampoline runs that continuation in a loop, receives the next continuation from it, and so on, until no steps remain.
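Before applying the pattern to lists, here is a minimal sketch on a classic toy problem, mutually recursive isEven/isOdd. The names run, isEven and isOdd are hypothetical; the final answer is delivered through a continuation k because a thunk only returns the next thunk:

```go
package main

import "fmt"

// Thunk is a suspended step that returns the next step, or nil when done.
type Thunk func() Thunk

// run is a minimal trampoline: it calls thunks in a loop,
// so stack depth does not grow with the number of steps.
func run(t Thunk) {
	for t != nil {
		t = t()
	}
}

// isEven and isOdd would normally call each other directly;
// here they return a thunk instead, and the trampoline makes the next call.
func isEven(n int, k func(bool)) Thunk {
	if n == 0 {
		k(true)
		return nil
	}
	return func() Thunk { return isOdd(n-1, k) }
}

func isOdd(n int, k func(bool)) Thunk {
	if n == 0 {
		k(false)
		return nil
	}
	return func() Thunk { return isEven(n-1, k) }
}

func main() {
	var result bool
	run(isEven(1_000_001, func(b bool) { result = b }))
	fmt.Println(result) // false
}
```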

Let's apply this pattern to our linked list. First, write a recursive iterator:

func IterRec(list *LinkedList, k func(v int)) {
	if list == nil {
		// Reached the end of the list: stop iterating
		return
	}

	// Otherwise, call the continuation on the current head
	k(list.Head)

	// Recursive call for the tail of the list
	IterRec(list.Tail, k)
}

On a sufficiently large list, this function will crash with a stack overflow. However, we can evaluate each next step lazily. To do this, first declare a type representing a lazy recursive computation:

// Declare a recursive function type: a thunk returns the next thunk
type Thunk func() Thunk

Now let's wrap the recursive call in a lazily evaluated wrapper (thunk):

func IterRec(list *LinkedList, k func(v int)) Thunk {
	if list == nil {
		// Reached the end of the list: return an empty (nil) continuation
		return nil
	}

	// Otherwise, call the continuation k on the current head
	k(list.Head)

	// Lazy continuation for the tail of the list
	return func() Thunk { return IterRec(list.Tail, k) }
}

Now each pending step lives on the heap as a closure rather than on the stack: IterRec returns right after handling one element, so the depth of the call stack no longer grows with the length of the list.

Lazy computations do not run by themselves; something has to drive them. In our case, that is a trampoline:

// Run the deferred computations.
// A nil thunk means no work is left, so an empty list is handled too.
func RunTrampoline(thunk Thunk) {
	for thunk != nil {
		thunk = thunk()
	}
}

Let's run a computation over the list, for example finding its maximum element:

maxVal := list.Head // assumes the list is non-empty
findMax := func(v int) {
	if v > maxVal {
		maxVal = v
	}
}
RunTrampoline(IterRec(list, findMax))
println(maxVal)
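Put together, the pieces above form a complete program. A sketch that sums a one-million-element list; the buildList helper is hypothetical and builds the list iteratively, since building it recursively would defeat the purpose:

```go
package main

import "fmt"

type LinkedList struct {
	Head int
	Tail *LinkedList
}

type Thunk func() Thunk

// IterRec handles the head, then returns the rest of the work as a thunk.
func IterRec(list *LinkedList, k func(v int)) Thunk {
	if list == nil {
		return nil
	}
	k(list.Head)
	return func() Thunk { return IterRec(list.Tail, k) }
}

// RunTrampoline drives the thunks until none is left.
func RunTrampoline(thunk Thunk) {
	for thunk != nil {
		thunk = thunk()
	}
}

// buildList (hypothetical) creates the list 1, 2, ..., n.
func buildList(n int) *LinkedList {
	var list *LinkedList
	for i := n; i >= 1; i-- {
		list = &LinkedList{Head: i, Tail: list}
	}
	return list
}

func main() {
	list := buildList(1_000_000)

	sum := 0
	RunTrampoline(IterRec(list, func(v int) { sum += v }))
	fmt.Println(sum) // 500000500000
}
```

The price of constant stack depth is one closure allocation per element, so this is a tool for deep recursion rather than a free speedup.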

Resource Management

A typical workflow with a resource is:

  • request a resource (e.g. open a file)

  • perform actions on a resource (read/write)

  • release resource (close file)

An error may occur at any step.

func writeFile(path, content string) error {
	file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600)
	if err != nil {
		return err
	}
	defer file.Close()

	_, err = file.Write([]byte(content))
	if err != nil {
		return err
	}

	return nil
}

Not only is this code full of low-level details, but the file API itself does not protect us from misuse: we can easily forget to call Close() and leak the resource, and the compiler will not report any error or warning. Linear types would help here, but unfortunately Go does not have them.

We have already seen how to separate code into system and user parts, and the same pattern applies to file resources:

// A file resource
type FileResource = func(cont FileContinuation) error

// A continuation that encapsulates the business logic
type FileContinuation = func(fd *os.File) error

func WorkWithFile(path string, flags int, perm os.FileMode) FileResource {
	// Curried resource initializer
	return func(cont FileContinuation) error {
		// System calls
		file, err := os.OpenFile(path, flags, perm)
		if err != nil {
			return err
		}
		defer file.Close()

		// User business logic
		return cont(file)
	}
}

Now it is enough to test WorkWithFile once and then reuse it throughout the project, without having to think about opening and closing files correctly:

func main() {
	// The file is not opened here:
	// a FileResource is waiting for a continuation
	fileRes := WorkWithFile("./file.txt", os.O_CREATE|os.O_WRONLY, 0600)
	// ...

	// The file is opened only here...
	if err := fileRes(myBusinessLogic); err != nil {
		log.Fatal(err)
	}
	// ...and here it is already closed

	// Reuse the same resource with other business logic:
	// the file is opened again
	if err := fileRes(otherBusinessLogic); err != nil {
		log.Fatal(err)
	}
}

func myBusinessLogic(fd *os.File) error {
	// Work with the file descriptor fd
	return nil
}

func otherBusinessLogic(fd *os.File) error {
	// Other work with the same file
	return nil
}

Automatic commit/rollback of transactions

Transactions are another resource that needs to be handled carefully.

// A transactional resource
type TxResource = func(TxContinuation) error

// A continuation that encapsulates the business logic
type TxContinuation = func(tx *sql.Tx) error

// Resource constructor
func Transaction(db *sql.DB) TxResource {
	return func(cont TxContinuation) error {
		// Start the transaction
		tx, err := db.Begin()
		if err != nil {
			return err
		}

		// Run the transactional code; roll back on error
		if err := cont(tx); err != nil {
			_ = tx.Rollback()
			return err
		}

		// Commit on success
		return tx.Commit()
	}
}

Example of use:

func execInTransaction(db *sql.DB) (string, error) {
	var result string
	err := Transaction(db)(func(tx *sql.Tx) error {
		rows, err := tx.Query("some query")
		if err != nil {
			return err
		}
		defer rows.Close()

		// ... read rows and compute the result
		result = "some calculated result"
		return rows.Err()
	})
	return result, err
}

Inside the transactional code it is enough to return the error: the rollback happens automatically. If the continuation returns nil, the transaction is committed.
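One gap remains: if the continuation panics, the transaction is neither committed nor rolled back. A sketch of a panic-safe variant, with the transaction abstracted behind a hypothetical Txn interface so it runs without a database driver. *sql.Tx has Commit() error and Rollback() error, so it satisfies Txn; with a real database, begin could be func() (Txn, error) { return db.Begin() }. WithTx and fakeTx are illustrative names:

```go
package main

import (
	"errors"
	"fmt"
)

// Txn is a hypothetical minimal interface that *sql.Tx also satisfies.
type Txn interface {
	Commit() error
	Rollback() error
}

// WithTx begins a transaction, runs cont and commits on success.
// It rolls back if cont returns an error or panics (the panic is re-raised).
func WithTx(begin func() (Txn, error), cont func(Txn) error) error {
	tx, err := begin()
	if err != nil {
		return err
	}
	defer func() {
		if p := recover(); p != nil {
			_ = tx.Rollback()
			panic(p)
		}
	}()
	if err := cont(tx); err != nil {
		_ = tx.Rollback()
		return err
	}
	return tx.Commit()
}

// fakeTx records what happened, standing in for a real transaction.
type fakeTx struct{ committed, rolledBack bool }

func (f *fakeTx) Commit() error   { f.committed = true; return nil }
func (f *fakeTx) Rollback() error { f.rolledBack = true; return nil }

func main() {
	tx := &fakeTx{}
	begin := func() (Txn, error) { return tx, nil }

	err := WithTx(begin, func(Txn) error { return errors.New("business error") })
	fmt.Println(err, tx.committed, tx.rolledBack) // business error false true
}
```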

No more deadlocks in sync.WaitGroup

You may have seen the following example of using synchronization:

func worker(id int, wg *sync.WaitGroup) {
    fmt.Printf("Worker %d starting\n", id)
    time.Sleep(time.Second)
    fmt.Printf("Worker %d done\n", id)
    wg.Done()
}

func main() {
    var wg sync.WaitGroup
    for i := 1; i <= 5; i++ {
        wg.Add(1)
        go worker(i, &wg)
    }
    wg.Wait()
}

Acquiring and releasing the resource (the wg counter) is scattered throughout the code. We can accidentally forget wg.Done, and wg.Wait then blocks forever (a deadlock), or pass the wrong value to wg.Add. Why should business logic know about wg at all? We also have to remember to pass wg by pointer, since a copied WaitGroup has its own counter. It is better to put all the system code in a SafeWaitGroup wrapper:

type Spawner interface {
	Run(task func())
}

type SafeWaitGroup interface {
	Spawner
	Wait()
}

type safeWaitGroup struct {
	wg *sync.WaitGroup
}

func NewSafeWaitGroup() SafeWaitGroup {
	return &safeWaitGroup{new(sync.WaitGroup)}
}

func (swg *safeWaitGroup) Run(task func()) {
	swg.wg.Add(1)
	go func() {
		// Done is deferred so the counter is decremented
		// exactly once, however task returns
		defer swg.wg.Done()
		task()
	}()
}

func (swg *safeWaitGroup) Wait() {
	swg.wg.Wait()
}

func RunGroup(taskRunner func(Spawner)) {
	swg := NewSafeWaitGroup()
	taskRunner(swg)
	swg.Wait()
}

Now the example can be rewritten in a safer form:

func worker(id int) {
	fmt.Printf("Worker %d starting\n", id)
	time.Sleep(time.Second)
	fmt.Printf("Worker %d done\n", id)
}

func main() {
	RunGroup(func(spawner Spawner) {
		for i := 1; i <= 5; i++ {
			i := i // copying the loop variable is needed before Go 1.22
			spawner.Run(func() { worker(i) })
		}
	})
}

The Run() method expects a suspended (lazy) computation, and building one from a function and its argument can also be abstracted:

func Suspended[A any](arg A, k func(arg A)) func() {
	return func() {
		k(arg)
	}
}

func main() {
	RunGroup(func(spawner Spawner) {
		for i := 1; i <= 5; i++ {
			// i is passed to Suspended by value, so no loop-variable copy is needed
			spawner.Run(Suspended(i, worker))
		}
	})
}

Conclusion

We figured out how to use CPS to invert control flow, hide system implementation details and safely manage resources, thus obtaining reliable and readable code.

Sources

  1. Go Wiki: Rangefunc Experiment

  2. Functional programming in Golang
