# Multithreading and Thread Safety

## Overview

### Topics Covered

* Threads
* Python modules
  * `threading`
  * `concurrent.futures`
  * `queue`
* Learn "the hard way" then "the easy way"
* Build a thread-safe message queue system

### Terminology

* **CPU (Central Processing Unit)**: The hardware component of a computer that executes machine instructions.
* **OS (Operating System)**: Software that, among other things, schedules when programs can use the CPU.
* **Process:** A program that is being executed.
* **Thread:** A unit of execution within a process. Threads in the same process share that process's memory.

## Motivation

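* **"Blocking"** is where a thread is stuck, waiting for something to finish before it can continue.
* When single-threaded apps get **blocked**, the result is a poor user experience and slower overall execution.
* Multi-threaded apps can do more than one thing "at the same time". In standard CPython builds this is not truly simultaneous, because the global interpreter lock (GIL) lets only one thread execute Python bytecode at a time, but the threads take turns quickly enough that it appears that way.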
* While one thread is blocked, other threads can continue their execution.

## The Problem with Single Thread

Consider the following code snippet:

```python
import time

def myfunc():
    print('hello')
    # time.sleep(10) simulates that the thread gets blocked, such as heavy I/O.
    # The process will be stuck here for 10 seconds.
    # During this time, we can't do anything.
    time.sleep(10)
    return True

if __name__ == '__main__':
    myfunc()
```

During `time.sleep(10)`, we should let the CPU do some other work. This introduces the idea of **multithreading**.

## Two Threads

Consider the following code snippet:

```python
import time
import threading 

def myfunc(name):
    print(f'myfunc started with {name}')
    time.sleep(10)
    print('myfunc ended')

if __name__ == '__main__':
    print('main started')
    t = threading.Thread(target=myfunc, args=['ret2basic'])
    t.start()
    print('main ended')
```

Output:

![Two threads](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2F3447GhgCdahL7iAPtCix%2Fimage.png?alt=media\&token=62137f76-4103-4d4f-8bdb-d19219293687)

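This ensures that the `main` thread continues without waiting for the `myfunc` thread. Note that the program itself does not exit until `myfunc` finishes, because `t` is a regular (non-daemon) thread.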

## Daemon Thread

A daemon thread is like a background process. The main difference between a regular thread and a daemon thread is that **the program will not wait for daemon threads to complete before exiting**: once the main thread and all regular threads have finished, any daemon threads still running are abandoned. Consider the following code snippet:

```python
import time
import threading 

def myfunc(name):
    print(f'myfunc started with {name}')
    time.sleep(10)
    print('myfunc ended')

if __name__ == '__main__':
    print('main started')
    # Pay attention to the `daemon=True` option.
    # The main thread will not wait for a daemon thread.
    t = threading.Thread(target=myfunc, args=['ret2basic'], daemon=True)
    t.start()
    print('main ended')
```

Output:

![Daemon thread](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2FdjD9vVmExzCFeE7t9ybl%2Fimage.png?alt=media\&token=4b3c4a77-1e26-4f51-a1a1-f963ccf641db)

Using a daemon thread is a bad choice in this case, since the program exits before the `myfunc` thread completes its work.

## Joining Threads

The `join()` method makes the calling thread wait until the joined thread finishes, which lets you bring all your threads together before the main thread exits. From the Python documentation:

![.join()](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2FCp39OFtHXkjboTpGFLGK%2Fimage.png?alt=media\&token=df3a9548-c4f2-4215-8c6a-bd83900a2fb5)

Consider the following code snippet:

```python
import time
import threading 

def myfunc(name):
    print(f'myfunc started with {name}')
    time.sleep(10)
    print('myfunc ended')

if __name__ == '__main__':
    print('main started')
    t = threading.Thread(target=myfunc, args=['ret2basic'])
    t.start()
    # Pay attention to this .join() method.
    # At t.join(), the main thread is going to wait for thread `t` to finish.
    t.join()
    print('main ended')
```

Output:

![Joining threads](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2F25n2OKzh2BMv3OjrX0xv%2Fimage.png?alt=media\&token=7a22627c-3e15-4540-96b9-5ea4bfec300c)
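`join()` also accepts an optional `timeout` argument. The sketch below is illustrative only: the helper name `join_with_timeout` and the short delays are assumptions made for the example. Since `join()` always returns `None`, the code checks `is_alive()` afterwards to find out whether the thread actually finished within the timeout.

```python
import threading
import time

def worker(delay):
    """Simulate a blocking task that takes `delay` seconds."""
    time.sleep(delay)

def join_with_timeout(delay, timeout):
    """Start a worker, wait at most `timeout` seconds, report whether it finished."""
    t = threading.Thread(target=worker, args=[delay])
    t.start()
    # join() returns None whether or not the thread finished in time,
    # so is_alive() is the way to tell the two cases apart.
    t.join(timeout=timeout)
    finished = not t.is_alive()
    # Wait for the worker without a timeout so nothing is left running.
    t.join()
    return finished

if __name__ == '__main__':
    print(join_with_timeout(delay=0.3, timeout=0.05))  # worker still running at the timeout
    print(join_with_timeout(delay=0.01, timeout=2.0))  # worker done well before the timeout
```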

## Multiple Threads

Consider the following code snippet:

```python
import time
import threading 

def myfunc1(name):
    print(f'myfunc1 started with {name}')
    time.sleep(10)
    print('myfunc1 ended')

def myfunc2(name):
    print(f'myfunc2 started with {name}')
    time.sleep(10)
    print('myfunc2 ended')

def myfunc3(name):
    print(f'myfunc3 started with {name}')
    time.sleep(10)
    print('myfunc3 ended')

if __name__ == '__main__':
    print('main started')
    t1 = threading.Thread(target=myfunc1, args=['ret2basic'])
    t1.start()
    t2 = threading.Thread(target=myfunc2, args=['foo'])
    t2.start()
    t3 = threading.Thread(target=myfunc3, args=['bar'])
    t3.start()
    t1.join()
    t2.join()
    t3.join()
    print('main ended')
```

Output:

![Multiple threads](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2F2Ni5LArZg8TmNbmaeUc7%2Fimage.png?alt=media\&token=3472ae83-b6c0-4998-855d-7305e04e7147)
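Because the three functions above differ only in name, the same pattern can be written with a single function and a list of threads. This is a minimal sketch, not the original program: `run_all` and the `delay` parameter are hypothetical names, and the delay is shortened so the example finishes quickly. It also measures the elapsed time to show that the sleeps overlap instead of adding up.

```python
import threading
import time

def myfunc(name, delay):
    print(f'myfunc started with {name}')
    time.sleep(delay)
    print(f'myfunc ended with {name}')

def run_all(names, delay=0.3):
    """Run myfunc once per name, each in its own thread; return elapsed seconds."""
    threads = [threading.Thread(target=myfunc, args=[name, delay]) for name in names]
    start = time.perf_counter()
    # Start every thread first, then join them, so they all run concurrently.
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return time.perf_counter() - start

if __name__ == '__main__':
    elapsed = run_all(['ret2basic', 'foo', 'bar'])
    # Roughly one delay in total, not three, because the sleeps overlap.
    print(f'all threads joined after {elapsed:.2f}s')
```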

## Thread Pool

The code from the "Multiple Threads" section can be refactored using `concurrent.futures.ThreadPoolExecutor()`. From Python documentation:

![.map()](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2FyiaXffT4KKnG5ZYxTqOA%2Fimage.png?alt=media\&token=0309c5f6-c38d-45d3-bbd9-dbb770d4990a)

Consider the following code snippet:

```python
import concurrent.futures
import time

def myfunc(name):
    print(f'myfunc started with {name}')
    time.sleep(10)
    print(f'myfunc ended with {name}')

if __name__ == '__main__':
    print('main begins')
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as e:
        # This is equivalent to:
        # myfunc('ret2basic'), and
        # myfunc('foo'), and
        # myfunc('bar')
        e.map(myfunc, ['ret2basic', 'foo', 'bar'])
    print('main ended')
```

Output:

![Thread Pool Executor](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2FHrtqodzbKauO7shd7NT6%2Fimage.png?alt=media\&token=72636c4c-f438-4899-87aa-f0a9b10094e3)
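`map()` also hands back the return values of the calls, in the same order as the inputs. The following sketch assumes a hypothetical `square` function in place of `myfunc`, purely to have something to return:

```python
import concurrent.futures
import time

def square(n):
    # Simulate a short blocking operation before returning a result.
    time.sleep(0.05)
    return n * n

def square_all(numbers):
    """Square every number using a pool of three worker threads."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as e:
        # map() yields results in input order, regardless of which
        # worker thread happens to finish first.
        return list(e.map(square, numbers))

if __name__ == '__main__':
    print(square_all([1, 2, 3, 4]))
```

Iterating over the results (here via `list()`) is also what surfaces an exception raised inside a worker; if the iterator is never consumed, such an exception goes unnoticed.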

## Race Conditions

A **race condition** happens when more than one thread accesses a shared piece of data at the same time, at least one of them modifies it, and the result depends on the order in which the threads happen to run. Learn more:

{% content-ref url="../../../pwn/linux-exploitation/race-conditions" %}
[race-conditions](https://ret2basic.gitbook.io/ctfnote/pwn/linux-exploitation/race-conditions)
{% endcontent-ref %}

## Example: Bank Account Program

From Python documentation:

![.submit()](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2FgcXox5FP6wRRYaZEBqMz%2Fimage.png?alt=media\&token=9a02b391-b2a0-4d17-9921-1faf4a97beb1)

Consider the following code snippet:

```python
import concurrent.futures
import time

class Account:
    def __init__(self):
        # Shared data
        self.balance = 100
    def update(self, transaction, amount):
        print(f'{transaction} thread updating...')

        local_copy = self.balance
        local_copy += amount
        time.sleep(1)
        self.balance = local_copy

        print(f'{transaction} thread finishing...')

if __name__ == '__main__':
    account = Account()
    print(f'starting with balance of {account.balance}')
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as ex:
        for transaction, amount in [('deposit', 50), ('withdrawal', -150)]:
            # This is equivalent to:
            # account.update('deposit', 50), and
            # account.update('withdrawal', -150)
            ex.submit(account.update, transaction, amount)
    print(f'ending balance of {account.balance}')
```

Output:

![Bank account program](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2FmDjd8bftXtZDgGllbXog%2Fimage.png?alt=media\&token=56a39cf6-f653-41bc-8c68-d5acc82097b6)

Here the `deposit` thread made a copy of `self.balance` and the `withdrawal` thread made another copy, both reading the original value of 100. We want the result to be 0 (100 + 50 - 150), but the actual result is either -50 or 150, depending on which thread writes `self.balance` last. To prevent this, we need a **lock** to protect our shared data.

## Lock

Suppose we have a lock object called `self.lock`. Then:

* `self.lock.acquire()`: Lock
* `self.lock.release()`: Unlock
* Or just use `with self.lock`

The code between `acquire()` and `release()` is executed by only one thread at a time, so it appears **atomic** to every other thread that uses the same lock. No thread can read a stale value while another thread is in the middle of an update.

Consider the following code snippet:

```python
import concurrent.futures
import time
import threading

class Account:
    def __init__(self):
        self.balance = 100
        self.lock = threading.Lock()

    def update(self, transaction, amount):
        print(f"{transaction} thread updating...")
        
        # Using `with` statement is easier than using acquire() and release()
        with self.lock:
            local_copy = self.balance
            local_copy += amount
            time.sleep(1)
            self.balance = local_copy

        print(f"{transaction} thread finishing...")

if __name__ == '__main__':
    account = Account()

    print(f"starting with balance of {account.balance}")

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as ex:
        for transaction, amount in [('deposit', 50), ('withdrawal', -150)]:
            ex.submit(account.update, transaction, amount)

    print(f"ending balance of {account.balance}")
```

Output:

![Lock objects](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2FjA7AdnRzEMn2zklOKISU%2Fimage.png?alt=media\&token=ea0671cf-5916-4985-84d3-cdd482b78b96)

This result is just what we want.
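For comparison, the `with self.lock` form is roughly shorthand for an explicit `acquire()` / `release()` pair wrapped in `try`/`finally`. The sketch below is a simplified variant of the account example, not the original program: `run_transactions` is a hypothetical helper, the `transaction` label is dropped, and the sleep is shortened.

```python
import concurrent.futures
import threading
import time

class Account:
    def __init__(self, balance=100):
        self.balance = balance
        self.lock = threading.Lock()

    def update(self, amount):
        self.lock.acquire()
        try:
            local_copy = self.balance
            local_copy += amount
            time.sleep(0.05)
            self.balance = local_copy
        finally:
            # Always release, even if the code above raises,
            # otherwise every later update() would block forever.
            self.lock.release()

def run_transactions(amounts):
    """Apply each amount in its own worker thread and return the final balance."""
    account = Account()
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as ex:
        for amount in amounts:
            ex.submit(account.update, amount)
    return account.balance

if __name__ == '__main__':
    print(run_transactions([50, -150]))
```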

## Deadlock and RLock

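A **deadlock** occurs when a thread waits forever for a lock that will never be released. For example, if you call `lock.acquire()` and forget to call `lock.release()`, the next `acquire()` blocks indefinitely. The simplest case is acquiring the same lock twice from one thread: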

```python
import threading

lock = threading.Lock()
lock.acquire()
lock.acquire() # Deadlock!
```
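One way to observe this without hanging the interpreter is to pass a `timeout` to the second `acquire()`, which then returns `False` instead of blocking forever. This is only a sketch for experimenting; the function name `second_acquire_succeeds` is made up for the example.

```python
import threading

def second_acquire_succeeds(timeout=0.1):
    """Try to acquire an already-held Lock again from the same thread."""
    lock = threading.Lock()
    lock.acquire()
    # A plain Lock does not track which thread holds it, so this second
    # call waits just like any other thread would. With a timeout it
    # gives up and returns False instead of deadlocking.
    acquired_again = lock.acquire(timeout=timeout)
    lock.release()
    return acquired_again

if __name__ == '__main__':
    print(second_acquire_succeeds())
```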

One solution to this problem is to use an `RLock`. A **Reentrant Lock (RLock)** is a synchronization primitive that may be acquired multiple times by the same thread, as long as each `acquire()` is eventually matched by a `release()`. Internally, it uses the concepts of "owning thread" and "recursion level" in addition to the locked/unlocked state used by primitive locks. In the locked state, some thread owns the lock; in the unlocked state, no thread owns it. For example:

```python
import threading

rlock = threading.RLock()
rlock.acquire()
rlock.acquire() # No deadlock!

rlock.release() # Locked, count=1
print(rlock)
print(threading.current_thread())

print('------------------------------')

rlock.release() # Unlocked
print(rlock)
print(threading.current_thread())
```

Output:

![RLock](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2FBKUUGaljRcm1dbn5HPnT%2Fimage.png?alt=media\&token=6b694206-6f7b-4643-82c4-1a9b7386158a)
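A common place where re-entrancy matters is a method that holds the lock and calls another method that takes the same lock. The sketch below uses a hypothetical `Counter` class (not from the original notes) to illustrate this; with a plain `Lock`, `increment_twice()` would block on its first inner call.

```python
import threading

class Counter:
    def __init__(self):
        self.value = 0
        self.lock = threading.RLock()

    def increment(self):
        with self.lock:
            self.value += 1

    def increment_twice(self):
        # The lock is already held here when increment() acquires it again.
        # An RLock allows that because the same thread owns it.
        with self.lock:
            self.increment()
            self.increment()

def count_with_threads(num_threads=4, calls_per_thread=500):
    """Call increment_twice() from several threads and return the final count."""
    counter = Counter()

    def work():
        for _ in range(calls_per_thread):
            counter.increment_twice()

    threads = [threading.Thread(target=work) for _ in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return counter.value

if __name__ == '__main__':
    # Expected total: num_threads * calls_per_thread * 2 increments.
    print(count_with_threads())
```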

## The Producer-Consumer Pipeline

```python
import random
import concurrent.futures
import time
import threading

# EOF tag
FINISH = 'THE END'

class Pipeline:
    def __init__(self, capacity):
        # Number of messages the producer will generate
        # (the pipeline itself holds only one message at a time)
        self.capacity = capacity
        # Shared data
        self.message = None

        self.producer_pipeline = []
        self.consumer_pipeline = []

        self.producer_lock = threading.Lock()
        self.consumer_lock = threading.Lock()
        self.consumer_lock.acquire()

    def set_message(self, message):
        """Put a message on the pipeline."""

        print(f"producing message '{message}'")

        self.producer_pipeline.append(message)

        self.producer_lock.acquire()
        self.message = message # Shared data
        self.consumer_lock.release()

    def get_message(self):
        """Grab a message from the pipeline."""

        # Wait until the producer has published a message
        self.consumer_lock.acquire()
        message = self.message # Shared data
        # Print only after acquiring, so we report the message actually consumed
        print(f"consuming message '{message}'")
        self.producer_lock.release()

        self.consumer_pipeline.append(message)

        return message

def producer(pipeline):
    """Put messages on the pipeline."""

    for _ in range(pipeline.capacity):
        message = random.randint(1, 100)
        pipeline.set_message(message)

    # EOF tag
    pipeline.set_message(FINISH)

def consumer(pipeline):
    """Grab messages from the pipeline."""

    message = None
    
    while message != FINISH:
        message = pipeline.get_message()
        if message != FINISH:
            # Simulate an I/O or network operation
            # random.random() returns a random number between 0 and 1
            time.sleep(random.random())

if __name__ == '__main__':
    # Create a pipeline whose producer will send 10 messages
    pipeline = Pipeline(10)

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as ex:
        ex.submit(producer, pipeline)
        ex.submit(consumer, pipeline)

    print(f"producer: {pipeline.producer_pipeline}")
    print(f"consumer: {pipeline.consumer_pipeline}")
```

Output:

![Producer-consumer pipeline with Lock](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2FEFTLRxoOGE8ptYw0DPrS%2Fimage.png?alt=media\&token=1782d47e-43d6-4a99-b246-15e922838e2d)

## The `queue` Module

In this section, we are going to refactor the producer-consumer pipeline using the `queue` module and threading events.

A **queue** can be declared using `queue.Queue(maxsize=0)`, where `maxsize` is an integer that sets the upper bound on the number of items that can be placed in the queue. Insertion will block once this size has been reached, until queue items are consumed. If `maxsize <= 0`, the queue size is infinite. A queue supports the following operations:

* `Queue.put(item, block=True, timeout=None)`
  * **Put** `item` **into the queue.**
  * If optional args `block` is true and `timeout` is `None` (the default), block if necessary until a free slot is available.
  * If `timeout` is a positive number, it blocks at most `timeout` seconds and raises the `Full` exception if no free slot was available within that time.
  * Otherwise (`block` is false), put an item on the queue if a free slot is immediately available, else raise the `Full` exception (`timeout` is ignored in that case).
* `Queue.get(block=True, timeout=None)`
  * **Remove and return an item from the queue.**
  * If optional args `block` is true and `timeout` is `None` (the default), block if necessary until an item is available.
  * If `timeout` is a positive number, it blocks at most `timeout` seconds and raises the `Empty` exception if no item was available within that time.
  * Otherwise (`block` is false), return an item if one is immediately available, else raise the `Empty` exception (`timeout` is ignored in that case).
* `Queue.qsize()`
  * **Return the approximate size of the queue.**
  * Note, `qsize() > 0` doesn't guarantee that a subsequent `get()` will not block, nor will `qsize() < maxsize` guarantee that `put()` will not block.
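The blocking and non-blocking behaviors listed above can be tried out on a small bounded queue. This is a minimal sketch; the function name `demo_bounded_queue` and the `events` list are just for illustration.

```python
import queue

def demo_bounded_queue():
    """Exercise put()/get() on a queue that holds at most two items."""
    q = queue.Queue(maxsize=2)
    q.put('a')
    q.put('b')

    events = []
    try:
        # The queue is full, and block=False means "do not wait for a slot".
        q.put('c', block=False)
    except queue.Full:
        events.append('full')

    # Items come out in FIFO order.
    events.append(q.get())
    events.append(q.get())

    try:
        # The queue is now empty: wait at most 0.05 seconds, then give up.
        q.get(timeout=0.05)
    except queue.Empty:
        events.append('empty')

    return events

if __name__ == '__main__':
    print(demo_bounded_queue())
```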

The queue handles the locking internally, and a **threading event** is used to signal between threads. From the Python documentation:

![Threading event](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2FeW800AHxJG2IG2xT0UPP%2Fimage.png?alt=media\&token=ddadf0de-d817-4b9f-a9e5-9abeb009ceff)

Calling `set()` plays the same role as the ad hoc `FINISH = 'THE END'` flag we invented in the producer-consumer pipeline program: it tells the threads that it is time to stop. Here is the refactored program:

```python
import random
import concurrent.futures
import time
import threading
import queue

class Pipeline(queue.Queue):
    def __init__(self):
        # Initialize the underlying `queue.Queue` with room for 20 messages.
        super().__init__(maxsize=20)

        self.producer_pipeline = []
        self.consumer_pipeline = []

    def set_message(self, message):
        """Put a message on the pipeline."""

        print(f"producing message '{message}'")
        self.producer_pipeline.append(message)
        self.put(message)

    def get_message(self):
        """Grab a message from the pipeline."""
        
        message = self.get()
        print(f"consuming message '{message}'")
        self.consumer_pipeline.append(message)

        return message

def producer(pipeline, event):
    """Put messages on the pipeline."""

    while not event.is_set():
        message = random.randint(1, 100)
        pipeline.set_message(message)

def consumer(pipeline, event):
    """Grab messages from the pipeline."""

    while not pipeline.empty() or not event.is_set():
        print(f"queue size is {pipeline.qsize()}")
        message = pipeline.get_message()
        time.sleep(random.random())

if __name__ == '__main__':

    pipeline = Pipeline()
    event = threading.Event()

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        executor.submit(producer, pipeline, event)
        executor.submit(consumer, pipeline, event)
        time.sleep(0.5)
        # This sets the internal flag to True
        event.set()

    print(f"producer: {pipeline.producer_pipeline}")
    print(f"consumer: {pipeline.consumer_pipeline}")
```
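Besides polling `is_set()` in a loop, a thread can also block on an event with `wait()` until another thread calls `set()`. The sketch below is a standalone illustration with hypothetical names (`waiter`, `run_event_demo`, `log`); it is not part of the pipeline program.

```python
import threading
import time

def waiter(event, log):
    # Blocks here until another thread calls event.set().
    event.wait()
    log.append('released')

def run_event_demo():
    """Start a waiting thread, then release it by setting the event."""
    event = threading.Event()
    log = []
    t = threading.Thread(target=waiter, args=[event, log])
    t.start()

    time.sleep(0.05)
    # The waiter cannot have logged anything yet: the flag is still False.
    log.append('setting')
    event.set()

    t.join()
    return log

if __name__ == '__main__':
    print(run_event_demo())
```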

## Semaphore Objects

`Lock` and `RLock` allow only one thread to work at a time, but sometimes we want several threads to work at once, up to a limit. For example, allow 10 members to access the database but only 4 members to access the network connection. In such cases, we need a semaphore.

A **semaphore** can be used to limit access to shared resources with limited capacity. From the Python documentation:

![Semaphore](https://3988450783-files.gitbook.io/~/files/v0/b/gitbook-x-prod.appspot.com/o/spaces%2F-MWVjG_njKgBtvmnKaJh%2Fuploads%2F0DIEyg9nNPQkaKwBCK1Q%2Fimage.png?alt=media\&token=2ba0f74a-2d09-4d50-bebf-8aa368774c61)

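The following code demonstrates using a semaphore as a counter. Note that `_value` is a private attribute of CPython's implementation and is read here only for demonstration: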

```python
import threading

sem = threading.Semaphore(value=50)
print(sem._value) # 50

sem.acquire()
sem.acquire()
sem.acquire()
print(sem._value) # 47

sem.release()
print(sem._value) # 48
```

Here is a sample program. Both functions loop forever, so the program keeps running until it is killed:

```python
import concurrent.futures
import random
import threading
import time

def welcome(semaphore, reached_max_users):
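    # Initialize the counter once, so numbering does not restart
    # every time the monitor pauses the welcome loop.
    visitor_number = 0
    while True: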
        while not reached_max_users.is_set():
            print(f"welcome visitor #{visitor_number}")
            semaphore.acquire()
            visitor_number += 1
            time.sleep(random.random())

def monitor(semaphore, reached_max_users):
    while True:
        print(f"[monitor] semaphore={semaphore._value}")
        time.sleep(3)
        if semaphore._value == 0:
            reached_max_users.set()
            print('[monitor] reached max users!')
            print('[monitor] kicking a user out...')
            semaphore.release()
            time.sleep(0.05)
            reached_max_users.clear()

if __name__ == '__main__':
    number_of_users = 50
    reached_max_users = threading.Event()
    semaphore = threading.Semaphore(value=number_of_users)

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        executor.submit(welcome, semaphore, reached_max_users)
        executor.submit(monitor, semaphore, reached_max_users)

```
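A semaphore also works as a context manager, which makes it easy to cap how many threads run a section of code at once. The following is a minimal sketch under assumed names (`run_limited`, `state`, `task`): it starts more tasks than the semaphore allows and records the highest number that were inside the guarded section at the same time.

```python
import concurrent.futures
import threading
import time

def run_limited(num_tasks=8, limit=3):
    """Run num_tasks tasks, letting at most `limit` into the guarded section."""
    semaphore = threading.Semaphore(value=limit)
    state_lock = threading.Lock()
    state = {'active': 0, 'peak': 0}

    def task():
        # `with semaphore` acquires on entry and releases on exit.
        with semaphore:
            with state_lock:
                state['active'] += 1
                state['peak'] = max(state['peak'], state['active'])
            time.sleep(0.05)  # simulate using the limited resource
            with state_lock:
                state['active'] -= 1

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_tasks) as executor:
        for _ in range(num_tasks):
            executor.submit(task)

    return state['peak']

if __name__ == '__main__':
    # The peak never exceeds the semaphore's initial value.
    print(run_limited())
```

Unlike the counter demonstration, this version never reads `_value`; it tracks concurrency with its own lock-protected counter.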

## Reference

{% embed url="<https://realpython.com/courses/threading-python>" %}
Threading in Python - Real Python
{% endembed %}
