Merge pull request #3 from dvarelas/feature/dependent-bandit
Initial idea of how the dependent bandit might look
pm3310 authored May 10, 2020
2 parents 4373736 + b996743 commit 2a46cb9
Showing 3 changed files with 163 additions and 1 deletion.
94 changes: 94 additions & 0 deletions examples/jester/example_dts.py
@@ -0,0 +1,94 @@
import numpy as np
import random
from tqdm import tqdm
import codecs
import csv
import matplotlib.pyplot as plt
from urllib.request import urlopen

from mab_ranking.bandits.rank_bandits import IndependentBandits
from mab_ranking.bandits.bandits import DirichletThompsonSampling

N_RANKS = 10
URL = 'https://jester-jokes-data.s3.amazonaws.com/jesterfinal151cols.csv'


def get_jester_data(url):
data = []
file_stream = urlopen(url)
csv_reader = csv.reader(codecs.iterdecode(file_stream, 'utf-8'), delimiter=',')
for row in csv_reader:
data_row = []
for _item in row[1:]:
            if _item in {'99', ''} or float(_item) < 7.0:  # Ratings of 7.0 or above are considered positive, i.e. 1.0; otherwise 0.0
data_row.append(0.0)
else:
data_row.append(1.0)
data.append(data_row)

return np.asarray(data)


def main():
jester_data = get_jester_data(URL)

filtered_data = []
for _row in jester_data:
        if sum(_row) > 1:  # Keep only rows where the user rated more than one joke positively
filtered_data.append(_row.tolist())

data = np.asarray(filtered_data)

independent_bandits = IndependentBandits(
num_arms=data.shape[1],
num_ranks=N_RANKS, # Recommend the best 10 jokes
bandit_class=DirichletThompsonSampling
)

num_steps = 1000
sum_binary = 0.0
ctr_list = []

for i in tqdm(range(1, num_steps + 1)):
# Pick a user randomly
random_user_idx = random.randint(0, data.shape[0] - 1)

ground_truth = np.argwhere(data[random_user_idx] == 1).flatten().tolist()
n = len(ground_truth)

selected_items = independent_bandits.choose(context={'previous_action': 0})

hit_rate = len(set(ground_truth).intersection(set(selected_items))) / len(set(ground_truth))

feedback_list = [1.0 if _item in ground_truth else 0.0 for _item in selected_items]
independent_bandits.update(selected_items, feedback_list)

user_binary_relevancy = 1.0 if hit_rate > 0 else 0.0
sum_user_binary_relevancy = user_binary_relevancy
# update the parameters sequentially given the user's last rated joke
for j in range(1, n):
selected_items = independent_bandits.choose(context={'previous_action': ground_truth[j-1]})

hit_rate = len(set(ground_truth).intersection(set(selected_items))) / len(set(ground_truth))

feedback_list = [1.0 if _item in ground_truth else 0.0 for _item in selected_items]
independent_bandits.update(selected_items, feedback_list, context={'previous_action': ground_truth[j-1]})

user_binary_relevancy = 1.0 if hit_rate > 0 else 0.0
sum_user_binary_relevancy += user_binary_relevancy
        sum_binary += sum_user_binary_relevancy / n
ctr_list.append(sum_binary / i)

print('CTR at the last step: ' + str(ctr_list[-1]))

def plot_ctr(num_iterations, ctr):
plt.plot(range(1, num_iterations + 1), ctr)
plt.xlabel('num_iterations', fontsize=14)
plt.ylabel('ctr', fontsize=14)
return plt

plot_ctr(len(ctr_list), ctr_list).show()


if __name__ == '__main__':
main()
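The per-step quality measure in the loop above is the recall of the user's liked jokes within the recommended slate. A standalone sketch of that computation (the helper name is hypothetical):

def slate_recall(ground_truth, selected_items):
    # Fraction of the user's liked items that appear in the recommended slate
    liked = set(ground_truth)
    return len(liked & set(selected_items)) / len(liked)

# e.g. two of the three liked jokes were recommended -> recall 2/3
assert abs(slate_recall([3, 7, 11], [7, 3, 42]) - 2 / 3) < 1e-9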
26 changes: 26 additions & 0 deletions mab_ranking/bandits/bandits.py
@@ -67,3 +67,29 @@ def choose(self, context=None):
def update(self, arm_id, reward, context=None):
self.num_tries[arm_id] += 1
self.rewards[arm_id] += reward


class DirichletThompsonSampling(Bandit):
def __init__(self, num_arms):
super().__init__(num_arms)
        self.rewards = np.ones((self.num_arms, self.num_arms, 1))

@staticmethod
def _get_previous_action(context):
return context['previous_action'] if context else 0

def choose(self, context=None):
previous_action = self._get_previous_action(context)

# sample from the distribution given the previous action
probs = [np.random.gamma(self.rewards[previous_action][i][0], 1) for i in range(self.num_arms)]
        sorted_indices = np.argsort(probs)[::-1].tolist()  # arm ids, best first

return sorted_indices[0], sorted_indices

def update(self, arm_id, reward, context=None):
previous_action = self._get_previous_action(context)

# for the time being we assume that the transitions are symmetric
self.rewards[previous_action][arm_id] += reward
self.rewards[arm_id][previous_action] += reward
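A note on the sampling trick in choose above: drawing independent Gamma(alpha_i, 1) variates and normalizing them yields a Dirichlet(alpha) sample, and since the normalizing constant is positive, ranking the raw gamma draws gives the same ordering as ranking the Dirichlet probabilities. A minimal sketch of that equivalence:

import numpy as np

alphas = np.array([1.0, 2.0, 5.0])
gammas = np.random.gamma(alphas, 1.0)     # independent Gamma(alpha_i, 1) draws
dirichlet_sample = gammas / gammas.sum()  # a Dirichlet(alphas) sample

# Dividing by a positive constant does not change the ordering
assert (np.argsort(gammas) == np.argsort(dirichlet_sample)).all()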
44 changes: 43 additions & 1 deletion tests/bandits/test_bandits.py
@@ -1,7 +1,7 @@
from collections import Counter
from unittest import TestCase

from mab_ranking.bandits.bandits import BetaThompsonSampling
from mab_ranking.bandits.bandits import BetaThompsonSampling, DirichletThompsonSampling


class BetaThompsonSamplingTest(TestCase):
@@ -39,3 +39,45 @@ def test_choose_and_update(self):

chosen_arms_with_frequency = Counter(chosen_arms)
assert chosen_arms_with_frequency[0] > chosen_arms_with_frequency[1]


class DirichletThompsonSamplingTest(TestCase):
def test_choose(self):
num_arms = 3
bandit = DirichletThompsonSampling(num_arms=num_arms)
best_arm_id, sorted_arm_ids = bandit.choose(context={'previous_action': 1})

assert best_arm_id in range(num_arms)

assert len(sorted_arm_ids) == len(set(sorted_arm_ids))

for _arm_id in sorted_arm_ids:
            assert _arm_id in range(num_arms)

def test_update(self):
bandit = DirichletThompsonSampling(num_arms=4)

bandit.update(1, reward=1.0, context={'previous_action': 2})

assert (bandit.rewards == [
[[1.], [1.], [1.], [1.]],
[[1.], [1.], [2.], [1.]],
[[1.], [2.], [1.], [1.]],
[[1.], [1.], [1.], [1.]]]).all()

def test_choose_and_update(self):
bandit = DirichletThompsonSampling(num_arms=3)

chosen_arms = []
context = None
# arm id 0 provides much more reward than arm id 1
for _ in range(500):
best_arm_id, _ = bandit.choose(context=context)
chosen_arms.append(best_arm_id)

reward = 1.0 if best_arm_id == 0 else 0.0
bandit.update(best_arm_id, reward=reward, context=context)
context = {'previous_action': 0}

chosen_arms_with_frequency = Counter(chosen_arms)
assert chosen_arms_with_frequency[0] > chosen_arms_with_frequency[1]
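For reference, a minimal usage sketch of the new class outside the test harness, assuming the package is importable as in the diff above:

from mab_ranking.bandits.bandits import DirichletThompsonSampling

bandit = DirichletThompsonSampling(num_arms=3)

# Without a context, the class falls back to previous_action=0
best_arm, ranking = bandit.choose()

# Reward the chosen arm, conditioning on the user's previous action
bandit.update(best_arm, reward=1.0, context={'previous_action': best_arm})

# Later picks condition the per-row counts on the previous action
best_arm, ranking = bandit.choose(context={'previous_action': best_arm})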
