Table of contents
Open Table of contents
Description
Implement a thread-safe counter with the following methods.
`ThreadSafeCounter()`: constructor that initializes the variable `i`.
`.incr()`: computes the result of adding one to `number` (`Main.safeCounter` in Java) by calling `increase(number)` (`Main.incr()` in Java), and assigns it to the variable `i`.
`.decr()`: computes the result of subtracting one from `number` (`Main.safeCounter` in Java) by calling `decrease(number)` (`Main.decr()` in Java), and assigns it to the variable `i`.
`.get_count()` (`getCount()` in Java/C++): returns the value of the variable `i`.
`increase(number)` (`Main.incr()` in Java) and `decrease(number)` (`Main.decr()` in Java) are provided functions. They respectively add one to and subtract one from the passed argument and return the result.
The operations are run from multiple threads. After each test case completes, we check that addition and subtraction were implemented correctly by calling the `get_count` method (`getCount()` in Java/C++).
Example
We will run main.py (Main.java in Java, Main.cpp in C++) to read and run your operations.
ThreadSafeCounter()
decr()
incr()
decr()
incr()
incr()
Output:
1
Explanation:
Starting from i = 0, we performed 3 incr() and 2 decr() operations on i:
0 - 1 + 1 - 1 + 1 + 1 = 1.
Idea
We can use a lock. This mirrors Rust's borrowing rule that a mutable reference is exclusive: it cannot coexist with any other reference to the same data, whether that reference is mutable or not. Enforcing the same exclusivity with a lock prevents data races and ensures each thread sees the latest value.
We can use either a simple lock or a re-entrant lock. A re-entrant lock is more useful when the same thread may need to acquire the same lock multiple times, for example when recursively exploring a graph.
Python
# Number of increment/decrement calls each worker thread performs.
OP_CNT: int = 100_000
class Counter(ABC):
    """Abstract base class for the counter implementations.

    The current count lives in ``self.value``; concrete subclasses must
    provide ``increment`` and ``decrement``.
    """

    def __init__(self, value=0):
        # Initial count; defaults to 0.
        self.value = value

    @abstractmethod
    def increment(self):
        """Increase the count by one (implemented by subclasses)."""

    @abstractmethod
    def decrement(self):
        """Decrease the count by one (implemented by subclasses)."""
class NCounter(Counter):
    """Counter with no synchronization at all -- not thread-safe.

    Used by the tests to show that concurrent updates can be lost.
    """

    def __init__(self, value: int = 0):
        super().__init__(value)

    # https://stackoverflow.com/questions/77096404/cant-create-race-condition-in-python-3-11-using-multiple-threads
    # NOTE(review): the int(1) call looks deliberate. Per the linked question,
    # recent CPython rarely switches threads inside a plain `x += 1`, and the
    # extra call is presumably there to make the race observable. Keep it.
    def increment(self):
        """Add one to ``value`` without taking any lock."""
        self.value = self.value + int(1)

    def decrement(self):
        """Subtract one from ``value`` without taking any lock."""
        self.value = self.value - int(1)
class LCounter(Counter):
    """Thread-safe counter that guards every update with a ``threading.Lock``."""

    def __init__(self, value: int = 0):
        super().__init__(value)
        self.lock = threading.Lock()

    def increment(self):
        """Add one to ``value`` while holding ``self.lock``."""
        # `with` releases the lock on every exit path, including an exception
        # raised by the first log call. A bare acquire() followed later by
        # `try:` would leave the lock held forever in that case.
        with self.lock:
            logging.debug('Acquired a lock')
            try:
                logging.debug('incrementing')
                self.value = self.value + int(1)
            finally:
                # Logged while the lock is still held, just before `with` releases it.
                logging.debug('Releasing a lock')

    def decrement(self):
        """Subtract one from ``value`` while holding ``self.lock``."""
        with self.lock:
            logging.debug('Acquired lock, decrementing')
            self.value -= int(1)
def inc_worker(c: Counter):
    """Thread target: call ``c.increment()`` OP_CNT times in a row."""
    for _step in range(OP_CNT):
        c.increment()
def dec_worker(c: Counter):
    """Thread target: call ``c.decrement()`` OP_CNT times in a row."""
    for _step in range(OP_CNT):
        c.decrement()
class CCounter(Counter):
    """Thread-safe counter guarded by a ``threading.Condition``.

    A ``Condition`` created without an explicit lock builds its own
    ``threading.RLock``, so this variant runs under a re-entrant lock.
    """

    def __init__(self, value=0):
        super().__init__(value)
        self.condition = threading.Condition()

    def increment(self):
        """Add one to ``value`` while holding the condition's lock."""
        # int(1) kept so the arithmetic matches LCounter in the timing test.
        with self.condition:
            self.value += int(1)

    def decrement(self):
        """Subtract one from ``value`` while holding the condition's lock."""
        with self.condition:
            self.value -= int(1)
Unit Test
# Number of worker threads started per counter in each test.
THREAD_CNT: int = 10
class TestCounter(TestCase):
    """Exercise the unsynchronized, Lock-based and Condition-based counters."""

    def setUp(self):
        self.tbt1 = NCounter()
        self.tbt2 = LCounter()
        self.tbt3 = CCounter()

    @staticmethod
    def _run(worker, *counters):
        """Start THREAD_CNT threads running ``worker`` per counter, then join them all."""
        threads = [Thread(target=worker, args=(c,))
                   for _ in range(THREAD_CNT) for c in counters]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

    def test_unsafe_counter(self):
        """NCounter is expected to lose updates under contention.

        NOTE: this is probabilistic. It only passes if the interpreter actually
        switches threads in the middle of a read-modify-write, which can vary
        across Python versions and builds.
        """
        self._run(inc_worker, self.tbt1)
        self.assertNotEqual(self.tbt1.value, THREAD_CNT * OP_CNT)

        # The decrement phase starts from whatever the racy increment phase
        # left behind. Lost decrements can only leave the value higher than
        # baseline - THREAD_CNT * OP_CNT, so compare against that. Comparing
        # against -THREAD_CNT * OP_CNT could never fail once the baseline is
        # positive.
        baseline = self.tbt1.value
        self._run(dec_worker, self.tbt1)
        self.assertNotEqual(self.tbt1.value, baseline - THREAD_CNT * OP_CNT)

    def test_safe_counters(self):
        """Both locked counters must count every increment exactly once."""
        self._run(inc_worker, self.tbt2, self.tbt3)
        self.assertEqual(self.tbt2.value, THREAD_CNT * OP_CNT)
        self.assertEqual(self.tbt3.value, THREAD_CNT * OP_CNT)

    def test_rlock_vs_lock_timing(self):
        """Time decrements under Lock (LCounter) vs Condition/RLock (CCounter).

        Observed on Python 3.12: RLock ~0.3 s vs Lock ~0.39 s. NOTE: LCounter
        also calls logging.debug inside its lock, so this is not a pure
        Lock-vs-RLock comparison.
        """
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        self._run(dec_worker, self.tbt2)
        lock_elapsed = time.perf_counter() - start
        self.assertEqual(self.tbt2.value, -THREAD_CNT * OP_CNT)
        print(lock_elapsed)

        start = time.perf_counter()
        self._run(dec_worker, self.tbt3)
        rlock_elapsed = time.perf_counter() - start
        self.assertEqual(self.tbt3.value, -THREAD_CNT * OP_CNT)
        print(rlock_elapsed)
References
- lock vs mutex vs semaphore stackoverflow
- re-entrant lock stackoverflow
- re-entrant lock java 23 doc
- python 3.12 RLock doc
- python re-entrant lock vs lock stackoverflow