diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1a5615d..c07dc28 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,3 +1,5 @@
+default_language_version:
+ python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
@@ -15,3 +17,14 @@ repos:
rev: v16.0.0
hooks:
- id: clang-format
+ files: \.(cpp|h)$
+ args: ["-i"]
+ - repo: https://github.com/google/yapf/
+ rev: v0.40.2
+ hooks:
+ - id: yapf
+ name: yapf
+ language: python
+ entry: yapf
+ args: ["--style=.style.yapf", "-i"]
+ types: [python]
diff --git a/.style.yapf b/.style.yapf
new file mode 100644
index 0000000..ecc2642
--- /dev/null
+++ b/.style.yapf
@@ -0,0 +1,397 @@
+[style]
+# Align closing bracket with visual indentation.
+align_closing_bracket_with_visual_indent=False
+
+# Allow dictionary keys to exist on multiple lines. For example:
+#
+# x = {
+# ('this is the first element of a tuple',
+# 'this is the second element of a tuple'):
+# value,
+# }
+allow_multiline_dictionary_keys=False
+
+# Allow lambdas to be formatted on more than one line.
+allow_multiline_lambdas=False
+
+# Allow splitting before a default / named assignment in an argument list.
+allow_split_before_default_or_named_assigns=True
+
+# Allow splits before the dictionary value.
+allow_split_before_dict_value=True
+
+# Let spacing indicate operator precedence. For example:
+#
+# a = 1 * 2 + 3 / 4
+# b = 1 / 2 - 3 * 4
+# c = (1 + 2) * (3 - 4)
+# d = (1 - 2) / (3 + 4)
+# e = 1 * 2 - 3
+# f = 1 + 2 + 3 + 4
+#
+# will be formatted as follows to indicate precedence:
+#
+# a = 1*2 + 3/4
+# b = 1/2 - 3*4
+# c = (1+2) * (3-4)
+# d = (1-2) / (3+4)
+# e = 1*2 - 3
+# f = 1 + 2 + 3 + 4
+#
+arithmetic_precedence_indication=False
+
+# Number of blank lines surrounding top-level function and class
+# definitions.
+blank_lines_around_top_level_definition=2
+
+# Number of blank lines between top-level imports and variable
+# definitions.
+blank_lines_between_top_level_imports_and_variables=1
+
+# Insert a blank line before a class-level docstring.
+blank_line_before_class_docstring=False
+
+# Insert a blank line before a module docstring.
+blank_line_before_module_docstring=False
+
+# Insert a blank line before a 'def' or 'class' immediately nested
+# within another 'def' or 'class'. For example:
+#
+# class Foo:
+# # <------ this blank line
+# def method():
+# ...
+blank_line_before_nested_class_or_def=True
+
+# Do not split consecutive brackets. Only relevant when
+# dedent_closing_brackets is set. For example:
+#
+# call_func_that_takes_a_dict(
+# {
+# 'key1': 'value1',
+# 'key2': 'value2',
+# }
+# )
+#
+# would reformat to:
+#
+# call_func_that_takes_a_dict({
+# 'key1': 'value1',
+# 'key2': 'value2',
+# })
+coalesce_brackets=False
+
+# The column limit.
+column_limit=120
+
+# The style for continuation alignment. Possible values are:
+#
+# - SPACE: Use spaces for continuation alignment. This is default behavior.
+# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns
+# (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or
+# CONTINUATION_INDENT_WIDTH spaces) for continuation alignment.
+# - VALIGN-RIGHT: Vertically align continuation lines to multiple of
+# INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if
+# cannot vertically align continuation lines with indent characters.
+continuation_align_style=SPACE
+
+# Indent width used for line continuations.
+continuation_indent_width=4
+
+# Put closing brackets on a separate line, dedented, if the bracketed
+# expression can't fit in a single line. Applies to all kinds of brackets,
+# including function definitions and calls. For example:
+#
+# config = {
+# 'key1': 'value1',
+# 'key2': 'value2',
+# } # <--- this bracket is dedented and on a separate line
+#
+# time_series = self.remote_client.query_entity_counters(
+# entity='dev3246.region1',
+# key='dns.query_latency_tcp',
+# transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
+# start_ts=now()-timedelta(days=3),
+# end_ts=now(),
+# ) # <--- this bracket is dedented and on a separate line
+dedent_closing_brackets=False
+
+# Disable the heuristic which places each list element on a separate line
+# if the list is comma-terminated.
+disable_ending_comma_heuristic=False
+
+# Place each dictionary entry onto its own line.
+each_dict_entry_on_separate_line=True
+
+# Require multiline dictionary even if it would normally fit on one line.
+# For example:
+#
+# config = {
+# 'key1': 'value1'
+# }
+force_multiline_dict=False
+
+# The regex for an i18n comment. The presence of this comment stops
+# reformatting of that line, because the comments are required to be
+# next to the string they translate.
+i18n_comment=#\..*
+
+# The i18n function call names. The presence of this function stops
+# reformatting on that line, because the string it has cannot be moved
+# away from the i18n comment.
+i18n_function_call=N_, _
+
+# Indent blank lines.
+indent_blank_lines=False
+
+# Put closing brackets on a separate line, indented, if the bracketed
+# expression can't fit in a single line. Applies to all kinds of brackets,
+# including function definitions and calls. For example:
+#
+# config = {
+# 'key1': 'value1',
+# 'key2': 'value2',
+# } # <--- this bracket is indented and on a separate line
+#
+# time_series = self.remote_client.query_entity_counters(
+# entity='dev3246.region1',
+# key='dns.query_latency_tcp',
+# transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
+# start_ts=now()-timedelta(days=3),
+# end_ts=now(),
+# ) # <--- this bracket is indented and on a separate line
+indent_closing_brackets=False
+
+# Indent the dictionary value if it cannot fit on the same line as the
+# dictionary key. For example:
+#
+# config = {
+# 'key1':
+# 'value1',
+# 'key2': value1 +
+# value2,
+# }
+indent_dictionary_value=True
+
+# The number of columns to use for indentation.
+indent_width=4
+
+# Join short lines into one line. E.g., single line 'if' statements.
+join_multiple_lines=False
+
+# Do not include spaces around selected binary operators. For example:
+#
+# 1 + 2 * 3 - 4 / 5
+#
+# will be formatted as follows when configured with "*,/":
+#
+# 1 + 2*3 - 4/5
+no_spaces_around_selected_binary_operators=
+
+# Use spaces around default or named assigns.
+spaces_around_default_or_named_assign=False
+
+# Adds a space after the opening '{' and before the ending '}' dict delimiters.
+#
+# {1: 2}
+#
+# will be formatted as:
+#
+# { 1: 2 }
+spaces_around_dict_delimiters=False
+
+# Adds a space after the opening '[' and before the ending ']' list delimiters.
+#
+# [1, 2]
+#
+# will be formatted as:
+#
+# [ 1, 2 ]
+spaces_around_list_delimiters=False
+
+# Use spaces around the power operator.
+spaces_around_power_operator=False
+
+# Use spaces around the subscript / slice operator. For example:
+#
+# my_list[1 : 10 : 2]
+spaces_around_subscript_colon=False
+
+# Adds a space after the opening '(' and before the ending ')' tuple delimiters.
+#
+# (1, 2, 3)
+#
+# will be formatted as:
+#
+# ( 1, 2, 3 )
+spaces_around_tuple_delimiters=False
+
+# The number of spaces required before a trailing comment.
+# This can be a single value (representing the number of spaces
+# before each trailing comment) or list of values (representing
+# alignment column values; trailing comments within a block will
+# be aligned to the first column value that is greater than the maximum
+# line length within the block). For example:
+#
+# With spaces_before_comment=5:
+#
+# 1 + 1 # Adding values
+#
+# will be formatted as:
+#
+# 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment
+#
+# With spaces_before_comment=15, 20:
+#
+# 1 + 1 # Adding values
+# two + two # More adding
+#
+# longer_statement # This is a longer statement
+# short # This is a shorter statement
+#
+# a_very_long_statement_that_extends_beyond_the_final_column # Comment
+# short # This is a shorter statement
+#
+# will be formatted as:
+#
+# 1 + 1 # Adding values <-- end of line comments in block aligned to col 15
+# two + two # More adding
+#
+# longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20
+# short # This is a shorter statement
+#
+# a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length
+# short # This is a shorter statement
+#
+spaces_before_comment=2
+
+# Insert a space between the ending comma and closing bracket of a list,
+# etc.
+space_between_ending_comma_and_closing_bracket=False
+
+# Use spaces inside brackets, braces, and parentheses. For example:
+#
+# method_call( 1 )
+# my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ]
+# my_set = { 1, 2, 3 }
+space_inside_brackets=False
+
+# Split before arguments
+split_all_comma_separated_values=False
+
+# Split before arguments, but do not split all subexpressions recursively
+# (unless needed).
+split_all_top_level_comma_separated_values=False
+
+# Split before arguments if the argument list is terminated by a
+# comma.
+split_arguments_when_comma_terminated=False
+
+# Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@'
+# rather than after.
+split_before_arithmetic_operator=False
+
+# Set to True to prefer splitting before '&', '|' or '^' rather than
+# after.
+split_before_bitwise_operator=False
+
+# Split before the closing bracket if a list or dict literal doesn't fit on
+# a single line.
+split_before_closing_bracket=True
+
+# Split before a dictionary or set generator (comp_for). For example, note
+# the split before the 'for':
+#
+# foo = {
+# variable: 'Hello world, have a nice day!'
+# for variable in bar if variable != 42
+# }
+split_before_dict_set_generator=False
+
+# Split before the '.' if we need to split a longer expression:
+#
+# foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d))
+#
+# would reformat to something like:
+#
+# foo = ('This is a really long string: {}, {}, {}, {}'
+# .format(a, b, c, d))
+split_before_dot=False
+
+# Split after the opening paren which surrounds an expression if it doesn't
+# fit on a single line.
+split_before_expression_after_opening_paren=False
+
+# If an argument / parameter list is going to be split, then split before
+# the first argument.
+split_before_first_argument=False
+
+# Set to True to prefer splitting before 'and' or 'or' rather than
+# after.
+split_before_logical_operator=False
+
+# Split named assignments onto individual lines.
+split_before_named_assigns=True
+
+# Set to True to split list comprehensions and generators that have
+# non-trivial expressions and multiple clauses before each of these
+# clauses. For example:
+#
+# result = [
+# a_long_var + 100 for a_long_var in xrange(1000)
+# if a_long_var % 10]
+#
+# would reformat to something like:
+#
+# result = [
+# a_long_var + 100
+# for a_long_var in xrange(1000)
+# if a_long_var % 10]
+split_complex_comprehension=True
+
+# The penalty for splitting right after the opening bracket.
+split_penalty_after_opening_bracket=300
+
+# The penalty for splitting the line after a unary operator.
+split_penalty_after_unary_operator=10000
+
+# The penalty of splitting the line around the '+', '-', '*', '/', '//',
+# ``%``, and '@' operators.
+split_penalty_arithmetic_operator=300
+
+# The penalty for splitting right before an if expression.
+split_penalty_before_if_expr=0
+
+# The penalty of splitting the line around the '&', '|', and '^'
+# operators.
+split_penalty_bitwise_operator=300
+
+# The penalty for splitting a list comprehension or generator
+# expression.
+split_penalty_comprehension=2100
+
+# The penalty for characters over the column limit.
+split_penalty_excess_character=7000
+
+# The penalty incurred by adding a line split to the unwrapped line. The
+# more line splits added the higher the penalty.
+split_penalty_for_added_line_split=30
+
+# The penalty of splitting a list of "import as" names. For example:
+#
+# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1,
+# long_argument_2,
+# long_argument_3)
+#
+# would reformat to something like:
+#
+# from a_very_long_or_indented_module_name_yada_yad import (
+# long_argument_1, long_argument_2, long_argument_3)
+split_penalty_import_names=0
+
+# The penalty of splitting the line around the 'and' and 'or'
+# operators.
+split_penalty_logical_operator=300
+
+# Use the Tab character for indentation.
+use_tabs=False
diff --git a/README.md b/README.md
index 2f7999c..127e256 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,9 @@ Secure multiparty computation (MPC) is a desired tool to provide privacy to the
MPC-DualDP is a distributed protocol for generating shared differential privacy noise in a two-server setting. MPC-DualDP leverages MPC to sample random noise according to specific distributions, and outputs the noise in the form of secret sharing. For more detailed information, please refer to [this folder](./mpc-dualdp).
+## AnonPSI
+The widely used ECDH-PSI, while keeping all data encrypted, discloses the size of the intersection set during protocol execution. We refer to such protocols as size-revealing PSI. AnonPSI offers a framework for systematically assessing the privacy of intersection-size-revealing PSI protocols by employing carefully designed set membership inference attacks. It enables an adversary to infer whether a targeted individual is in the intersection, which is also known as membership information. For more detailed information, please refer to [this folder](./anonpsi). AnonPSI was recently accepted for NDSS24, and we look forward to engaging in discussions during the offline sessions at NDSS.
+
## Contribution
Please check [Contributing](CONTRIBUTING.md) for more details.
diff --git a/anonpsi/README.md b/anonpsi/README.md
new file mode 100644
index 0000000..34fdb65
--- /dev/null
+++ b/anonpsi/README.md
@@ -0,0 +1,101 @@
+# AnonPSI: An Anonymity Assessment Framework for PSI
+
+## Size-revealing Private Join and Compute
+
+Private Set Intersection (PSI) is a cryptographic method that allows two parties to compare encrypted versions of their sets to compute the intersection without disclosing any input information. In many applications of PSI [5,6], the objective isn't necessarily computing the intersection itself but rather functions based on the items within the intersection. Specifically, following the encrypted PSI results, downstream computation can be performed using general-purpose (MPC) protocols for statistical analysis. This process is commonly known as PSI and Compute or, alternatively, Private Join and Compute. A variety of Private Join and Compute protocols have been developed in academia and deployed in the industry. In particular, ECDH-PSI protocols are the most common in the industry due to their ease of use and capacity to handle large datasets with low communication overhead.
+
+It is worth noting that ECDH-PSI, while keeping all data encrypted, discloses the size of the intersection set during protocol execution. We refer to such protocols as **size-revealing PSI**, which also include PSI protocols constructed through other cryptographic primitives such as homomorphic encryption and oblivious pseudorandom function. The disclosure of the intersection size can introduce vulnerabilities. It enables an adversary to infer whether a targeted individual is in the intersection or not, which is also known as the membership information. Notably, individuals who intersect with the other party's set are classified as positive members, while those outside the intersection are classified as negative members. Depending on the application context, either or both of these positive/negative memberships can be sensitive from a privacy-preserving perspective [7].
+
+## [AnonPSI](https://arxiv.org/pdf/2311.18118.pdf): Assessing PSI Privacy through Set Membership Inference Attacks
+
+### What is AnonPSI?
+
+AnonPSI [1] efficiently evaluates privacy in intersection-size-revealing PSI protocols through a robust implementation of set membership inference attacks. Considering the limitation of PSI protocol runs, AnonPSI strategically selects a subset of the target dataset as PSI protocol input. It infers user membership from observed results, employing a depth-first search principle on a hierarchical structure. Each iteration's input is a subset of the previous one. AnonPSI is designed to determine the optimal input size for the current iteration, considering factors like previous input sizes, iteration outcomes, and protocol run restrictions.
+
+#### DyPathBlazer: A bottom-up dynamic programming solution
+
+DyPathBlazer optimizes the input size by maximizing the expected number of individuals' membership inferred within the total protocol run limitations. To achieve this, DyPathBlazer calculates an offline optimal input-size lookup table with three parameters: previous input size, previous positive membership number, and current protocol run budget. For more details about the algorithm, please refer to our research paper in [1].
+
+
+#### TreeSumExplorer: An N-SUM Guided Approach for PSI-SUM
+
+Another algorithm included in AnonPSI, designed for evaluating PSI-SUM [5], is TreeSumExplorer. Due to the specific architecture that releases the summation of intersected inputs, PSI-SUM introduces more vulnerabilities, as the adversary can combine both the summation and the intersection size for more effective attacks. TreeSumExplorer provides a feasible offline N-SUM solution that aids in better determining each individual's membership in the subset. Note that the optimal algorithm addressing the N-SUM problem incurs exponentially growing computational complexity with the size of the input. Effectively selecting the input size involves a trade-off between computation cost and the number of memberships inferred in each round. For more details about the algorithm, please refer to our research paper in [1].
+
+## AnonPSI's Performance
+
+We showcase the performance of AnonPSI with two publicly available datasets corresponding to the offline setting and online setting respectively: Taobao dataset [3] and the COVID-19 dataset [4]. We compare performance with the existing attack algorithms proposed by Guo [2]. For more details of the data sets and experiment settings, please refer to our research paper in [1].
+
+The offline attack evaluations are shown in Fig. 1, where the membership leakage percentages using different attacks are compared according to different protocols' call limitations. Further, we consider two subcases when the adversary is only interested in the identity of individuals in the intersection or those who are not in the intersection, respectively.
+
+
+
+
+
+
+
+
+
Fig. 1. Membership information leakages in PSI-SUM under different attacks. Column 1 displays the overall membership leakage, while Column 2 represents positive membership leakage, and Column 3 represents negative membership leakage. The first row corresponds to cases where the product company targets the advertising company, and the second row corresponds to cases where the advertising company targets the product company.
+
+
+
+For the COVID-19 dataset, it is assumed that Party A in PSI protocol is a local community that provides COVID testing services. Party B is the Lab that processes the testing results. It is assumed that Party B directly publishes testing results to each individual, so Party A is unclear which individual tests positive for COVID. On the other hand, Party A keeps monitoring the trend of total positive individuals in the community and is also interested in inferring the identity of each positive individual for targeted control. Party A calls for a PSI protocol with Party B daily. We evaluate different attacks for PSI protocol by comparing the total number of positive individuals they infer according to the testing timeline. The evaluations are shown in Fig. 2. Note that we slightly modified the DyPathBlazer to estimate and maximize only the expected number of positive members. The dataset is assumed to be updated daily but stays the same for one day. Party A's maximal protocol call limit is 10 times per day. Observe that our DyPathBlazer achieves significantly higher efficiency compared to the Guo et al.
+
+
+
+
+
+
Fig. 2. Membership information leakages in PSI under different attacks. Subcase (a) showing total membership leakage (positive member + negative member); (b) showing positive membership leakage only.
+
+
+## How to Use AnonPSI
+
+### Repository Structure
+
+The `src` folder contains the following files:
+- `utils.py`: defines basic functions used throughout the project.
+- `baseline.py`: baseline algorithm from USENIX22.
+- `improved_baseline.py`: improved baseline algorithm, which leverages both positive and negative memberships.
+- **`dynpathblizer.py`**: the dypathblizer algorithm based on dynamic programming.
+- `dynpathblizer_params.py`: the memo generating functions in the dypathblizer algorithm.
+- `actbayesian.py`: statistical attack actbayesian algorithm.
+- **`treesumexpolorer.py`**: the treesumexplorer algorithm which is based on an offline N sum problem.
+- `sum_only.py`: improved baseline algorithm based on sum information only.
+- `test.py`: runs all algorithms in this repo and compares their performance under given parameters.
+
+### Requirements
+- Python3
+- Install all dependencies via `python3 -m pip install -r requirements.txt`
+
+### Compare among different algorithms
+
+Here we give a simple example to compare all algorithms' performance with synthetic datasets. You can run it via `python3 test.py`.
+If you want to perform attacks on a real dataset, please replace `victim_x` and `target_y` with your real inputs.
+
+```
+ [victim_x, target_y, _] = utils.gen_dummy_set(
+ victim_set_size, target_set_size, dummy_set_size, dense, interserction_cardinality)
+```
+
+## License
+
+AnonPSI is Apache-2.0 License licensed, as found in the [LICENSE](../LICENSE) file.
+
+## Disclaimers
+
+This software is not an officially supported product of TikTok. It is provided as-is, without any guarantees or warranties, whether express or implied.
+
+## Reference
+
+[1] Jiang, Bo et al. “AnonPSI: An Anonymity Assessment Framework for PSI.” ArXiv abs/2311.18118 (2023): n. pag. url-https://arxiv.org/pdf/2311.18118.pdf
+
+[2] Guo, Xiaojie et al. “Birds of a Feather Flock Together: How Set Bias Helps to Deanonymize You via Revealed Intersection Sizes.” USENIX Security Symposium (2022).
+
+[3] Tianchi. Taobao display advertisement click-through rate prediction dataset, from https://tianchi.aliyun.com/dataset/datadetail?dataid=56, 2018.
+
+[4] Machine learning-based prediction of covid-19 diagnosis based on symptoms, from https://github.com/nshomron/covidpred.
+
+[5] Mihaela Ion, Ben Kreuter, Ahmet Erhan Nergiz, Sarvar Patel, Shobhit Saxena, Karn Seth, Mariana Raykova, David Shanahan, and Moti Yung. On deploying secure computing: Private intersection-sum-with- cardinality. In 2020 IEEE European Symposium on Security and Privacy (EuroS& P), pages 370–389, 2020.
+
+[6] Prasad Buddhavarapu, Andrew Knox, Payman Mohassel, Shubho Sen- gupta, Erik Taubeneck, and Vlad Vlaskin. Private matching for compute. Cryptology ePrint Archive, Paper 2020/599, 2020. url- https://eprint.iacr.org/2020/599.
+
+[7] https://github.com/tiktok-privacy-innovation/PrivacyGo/tree/main/dpca-psi
diff --git a/anonpsi/images/Ads_neg.png b/anonpsi/images/Ads_neg.png
new file mode 100644
index 0000000..03a6db3
Binary files /dev/null and b/anonpsi/images/Ads_neg.png differ
diff --git a/anonpsi/images/Ads_pos.png b/anonpsi/images/Ads_pos.png
new file mode 100644
index 0000000..6ef369d
Binary files /dev/null and b/anonpsi/images/Ads_pos.png differ
diff --git a/anonpsi/images/Ads_tot.png b/anonpsi/images/Ads_tot.png
new file mode 100644
index 0000000..6f077e8
Binary files /dev/null and b/anonpsi/images/Ads_tot.png differ
diff --git a/anonpsi/images/Neg_Taobao.png b/anonpsi/images/Neg_Taobao.png
new file mode 100644
index 0000000..46f234b
Binary files /dev/null and b/anonpsi/images/Neg_Taobao.png differ
diff --git a/anonpsi/images/Online_PSI_CA.png b/anonpsi/images/Online_PSI_CA.png
new file mode 100644
index 0000000..1ab7fa4
Binary files /dev/null and b/anonpsi/images/Online_PSI_CA.png differ
diff --git a/anonpsi/images/Pos_Online_PSI_CA.png b/anonpsi/images/Pos_Online_PSI_CA.png
new file mode 100644
index 0000000..9e26480
Binary files /dev/null and b/anonpsi/images/Pos_Online_PSI_CA.png differ
diff --git a/anonpsi/images/Pos_Taobao20.png b/anonpsi/images/Pos_Taobao20.png
new file mode 100644
index 0000000..a8f473d
Binary files /dev/null and b/anonpsi/images/Pos_Taobao20.png differ
diff --git a/anonpsi/images/Tot_Taobao20.png b/anonpsi/images/Tot_Taobao20.png
new file mode 100644
index 0000000..54a6fde
Binary files /dev/null and b/anonpsi/images/Tot_Taobao20.png differ
diff --git a/anonpsi/images/model.png b/anonpsi/images/model.png
new file mode 100644
index 0000000..a907579
Binary files /dev/null and b/anonpsi/images/model.png differ
diff --git a/anonpsi/requirements.txt b/anonpsi/requirements.txt
new file mode 100644
index 0000000..4fd953d
--- /dev/null
+++ b/anonpsi/requirements.txt
@@ -0,0 +1,2 @@
+numpy
+tqdm
diff --git a/anonpsi/src/actbayesian.py b/anonpsi/src/actbayesian.py
new file mode 100644
index 0000000..cd23d61
--- /dev/null
+++ b/anonpsi/src/actbayesian.py
@@ -0,0 +1,149 @@
+# Copyright 2023 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This file includes actbayesian algorithm."""
+
+import random
+
+import numpy as np
+
+
+def actbayesian(tau, lower_bound, upper_bound, dataset, tolerance, laplacian_scale, sample_rate):
+ """ActBaysian algorithm.
+ Args:
+ tau: A int indicates the remaining protocol call times.
+ lower_bound: A float indicates the lower bound of the stopping criterion.
+ upper_bound: A float indicates the upper bound of the stopping criterion.
+ dataset: input data set with binary label (0,1).
+ laplacian_scale: A float indicates the noise factor of the Laplacian mechanism.
+ sample_rate: A float indicates the sampling rate.
+
+ Returns:
+        true_pos_leak: predicted positive member set.
+ true_neg_leak: predicted negative member set.
+ pos_err: error in predicting positive member set.
+ neg_err: error in predicting negative member set.
+ """
+ laplacian_location = 0 # noise parameter
+ laplacian_noise = np.random.laplace(laplacian_location, laplacian_scale) # add laplacian noise if needed
+
+    tau = tau - 1  # use one protocol call to find the initial prior vector
+
+    # preprocess the data set to be a 0,1 vector so sum (n) = PSI (dataset)
+ n = len(dataset)
+ summation = np.sum(dataset)
+ summation = summation + laplacian_noise * summation * 1
+ summation = max(summation, 0)
+ summation = min(summation, len(dataset))
+ prior = summation / n
+
+    posterior = [prior for _ in range(n)]  # initial prior vector.
+    random.shuffle(posterior)
+    tau -= 1  # protocol budget used
+ while 0 <= tau:
+ if max(posterior) < 1:
+ # upper_bound threshold for selecting input
+ positive_bar = max(posterior) - tolerance
+ else:
+ positive_bar = list(set(posterior[:]))
+ if len(positive_bar) < 2:
+ break
+ else:
+ positive_bar.sort(reverse=True)
+ positive_bar = positive_bar[1] - tolerance
+
+ positive_set = []
+ negtive_set = []
+ expected_psi_cardinality = 0
+ for j in range(len(posterior)):
+ laplacian_noise = np.random.laplace(laplacian_location, laplacian_scale)
+ cur_posterior = posterior[j]
+ if cur_posterior >= positive_bar:
+ r = random.choices([0, 1], [1 - sample_rate, sample_rate])[0]
+ if r:
+ positive_set.append(j)
+ else:
+ negtive_set.append(j)
+ expected_psi_cardinality += cur_posterior
+
+ real_psi_cardinality_returned = 0
+
+ if not positive_set:
+ for j in negtive_set:
+ if posterior[j] >= 0.5:
+ posterior[j] = 1
+ else:
+ posterior[j] = 0
+ continue
+
+ if not negtive_set:
+ for j in positive_set:
+ if posterior[j] >= 0.5:
+ posterior[j] = 1
+ else:
+ posterior[j] = 0
+ continue
+
+ for j in range(len(dataset)):
+ if j in positive_set and dataset[j] == 1:
+ real_psi_cardinality_returned += 1
+
+ real_psi_cardinality_returned += laplacian_noise * 0.1 * real_psi_cardinality_returned
+ real_psi_cardinality_returned = max(real_psi_cardinality_returned, 0)
+ real_psi_cardinality_returned = min(real_psi_cardinality_returned, len(positive_set))
+ in_setprior = real_psi_cardinality_returned / len(positive_set)
+ out_setprior = (expected_psi_cardinality - real_psi_cardinality_returned) / len(negtive_set)
+
+ for j in range(len(posterior)):
+ if j in positive_set:
+ posterior[j] = in_setprior
+ if j in negtive_set:
+ posterior[j] = out_setprior
+
+ for j in range(len(posterior)):
+ if posterior[j] >= upper_bound:
+ posterior[j] = 1
+ if posterior[j] <= lower_bound:
+ posterior[j] = 0
+
+ if list(set(posterior)) == [0, 1] or list(set(posterior)) == [1, 0]:
+ break
+ tau -= 1
+
+ error_pos = 0
+ error_neg = 0
+ true_pos_leak = 0
+ true_neg_leak = 0
+ for j in range(len(posterior)):
+ if posterior[j] == 1:
+ true_pos_leak += 1
+ if dataset[j] == 0:
+ error_pos += 1
+ true_pos_leak -= 1
+ elif posterior[j] == 0:
+ true_neg_leak += 1
+ if dataset[j] == 1:
+ error_neg += 1
+ true_neg_leak -= 1
+
+ if true_pos_leak == 0:
+ pos_err = 0
+ else:
+ pos_err = error_pos
+
+ if true_neg_leak == 0:
+ neg_err = 0
+ else:
+ neg_err = error_neg
+
+ return true_pos_leak, true_neg_leak, pos_err, neg_err
diff --git a/anonpsi/src/baseline.py b/anonpsi/src/baseline.py
new file mode 100644
index 0000000..ea50e69
--- /dev/null
+++ b/anonpsi/src/baseline.py
@@ -0,0 +1,95 @@
+# Copyright 2023 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This file includes baseline algorithm from USENIX22."""
+
+import heapq
+import random
+
+import utils
+
+
+def baseline_attack_with_limited_call(victim_x, target_y, tau):
+ """Baseline MIA(Membership inference attack) for MPC.
+
+ Args:
+ victim_x: A set of victim elements X.
+ target_y: A set of target elements Y.
+ tau: A int indicates the number of protocol call limitation.
+
+ Returns:
+ Predicted set Z_pos and Z_neg.
+ """
+ # Initialize prediction set.
+ z_pos = set()
+ z_neg = set()
+
+ # Initialize the priority Queue
+ forest = []
+ idx = 0 # Indexing for debug.
+ heapq.heapify(forest)
+ initial_node_cardinality = utils.psi_cardinality(victim_x, target_y)
+ protocol_call_num = 1 # Tracking the protocol calling times
+ victim_x = list(victim_x)
+ random.shuffle(victim_x)
+ victim_x = set(victim_x)
+ node = utils.TreeNode()
+ tree = node.form_tree(victim_x) # Generate tree based on X
+ heapq.heappush(forest, (-1 * initial_node_cardinality / len(victim_x), initial_node_cardinality, idx, tree))
+ idx += 1
+
+ while forest and protocol_call_num <= tau:
+ _, current_node_cardinality, _, node = heapq.heappop(forest)
+ # Check the number of elements for calculating priority
+ while 0 < current_node_cardinality < len(node.val) and protocol_call_num <= tau:
+ left_node = node.left
+ right_node = node.right
+
+ # Call protocol once and get density for both children node
+ if len(right_node.val) > len(left_node.val):
+ left_cardinality = utils.psi_cardinality(left_node.val, target_y)
+ protocol_call_num += 1
+ right_cardinality = current_node_cardinality - left_cardinality
+ else:
+ right_cardinality = utils.psi_cardinality(right_node.val, target_y)
+ protocol_call_num += 1
+ left_cardinality = current_node_cardinality - right_cardinality
+
+ # Check the density as the priority for moving to a child node and
+ # push the other child node into the forest
+ right_priority = right_cardinality / len(right_node.val)
+ left_priority = left_cardinality / len(left_node.val)
+ if right_priority > left_priority:
+ if left_cardinality != 0:
+ pushed_priority = -1 * left_cardinality / len(left_node.val)
+ heapq.heappush(forest, (pushed_priority, left_cardinality, idx, left_node))
+ idx += 1
+ else:
+ z_neg.update(left_node.val)
+
+ node = right_node
+ current_node_cardinality = right_cardinality
+ else:
+ if right_cardinality != 0:
+ pushed_priority = -1 * right_cardinality / len(right_node.val)
+ heapq.heappush(forest, (pushed_priority, right_cardinality, idx, right_node))
+ idx += 1
+ else:
+ z_neg.update(right_node.val)
+
+ node = left_node
+ current_node_cardinality = left_cardinality
+
+ if current_node_cardinality > 0 and protocol_call_num <= tau:
+ z_pos.update(node.val)
+ return [z_pos, z_neg]
diff --git a/anonpsi/src/dynpathblizer.py b/anonpsi/src/dynpathblizer.py
new file mode 100644
index 0000000..41759e8
--- /dev/null
+++ b/anonpsi/src/dynpathblizer.py
@@ -0,0 +1,111 @@
+# Copyright 2023 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This file includes the dynpathblizer algorithm based on dynamic programming."""
+
+import heapq
+import random
+
+import utils
+
+
+def dynpathblizer(victim_x, target_y, tau, gamma, phi):
+ """Dynamic programming for MIA.
+
+ Args:
+ victim_x: A set of victim elements X.
+ target_y: A set of target elements Y.
+ tau: A int indicates the number of protocol call limitation.
+ gamma: A hash map represents the optimal partition K for current state(N, C_N, tau).
+ phi: A hash map represents average call required for current state(N, C_N).
+
+ Returns:
+ Predicted set Z_pos and Z_neg.
+ """
+ # Initialize prediction set
+ z_pos = set()
+ z_neg = set()
+
+ # Initialize the priority Queue
+ forest = []
+ idx = 0
+ heapq.heapify(forest)
+ initial_node_cardinality = utils.psi_cardinality(victim_x, target_y)
+ protocol_call_num = 1 # Tracking the protocol calling times
+ victim_x = list(victim_x)
+ random.shuffle(victim_x)
+ victim_x = set(victim_x)
+ heapq.heappush(forest,
+ (min(-1 * initial_node_cardinality / len(victim_x),
+ -1 + initial_node_cardinality / len(victim_x)), initial_node_cardinality, idx, victim_x))
+ idx += 1
+
+ while forest and protocol_call_num <= tau:
+ _, current_node_cardinality, _, node = heapq.heappop(forest)
+ while 0 < current_node_cardinality < len(node) and protocol_call_num < tau:
+ max_call_num = phi["(" + str(len(node)) + "," + str(current_node_cardinality) + ")"]
+ if max_call_num < tau - protocol_call_num:
+ partition_factor_k = len(node) // 2
+ else:
+ partition_factor_k = gamma["S(" + str(len(node)) + "," + str(current_node_cardinality) + "," +
+ str(tau - protocol_call_num) + ")"][1]
+
+ cur_node = list(node)
+ if partition_factor_k == 0 or partition_factor_k == len(node):
+ break
+ random.shuffle(cur_node)
+ left_node = set(cur_node[:partition_factor_k])
+ right_node = set(cur_node[partition_factor_k:])
+
+ # Call protocol once and get density for both children node
+ if len(right_node) > len(left_node):
+ left_cardinality = utils.psi_cardinality(left_node, target_y)
+ protocol_call_num += 1
+ right_cardinality = current_node_cardinality - left_cardinality
+ else:
+ right_cardinality = utils.psi_cardinality(right_node, target_y)
+ protocol_call_num += 1
+ left_cardinality = current_node_cardinality - right_cardinality
+
+ # Check the density as the priority for moving to a child node and
+ # push the other child node into the forest
+ right_priority = max(right_cardinality / len(right_node), 1 - right_cardinality / len(right_node))
+ left_priority = max(left_cardinality / len(left_node), 1 - left_cardinality / len(left_node))
+ if right_priority > left_priority:
+ if left_cardinality != 0:
+ pushed_priority = min(-1 * left_cardinality / len(left_node),
+ -1 + left_cardinality / len(left_node))
+ heapq.heappush(forest, (pushed_priority, left_cardinality, idx, left_node))
+ idx += 1
+ else:
+ z_neg.update(left_node)
+
+ node = right_node
+ current_node_cardinality = right_cardinality
+ else:
+ if right_cardinality != 0:
+ pushed_priority = min(-1 * right_cardinality / len(right_node),
+ -1 + right_cardinality / len(right_node))
+ heapq.heappush(forest, (pushed_priority, right_cardinality, idx, right_node))
+ idx += 1
+ else:
+ z_neg.update(right_node)
+
+ node = left_node
+ current_node_cardinality = left_cardinality
+
+ if current_node_cardinality == len(node) and protocol_call_num <= tau:
+ z_pos.update(node)
+ elif current_node_cardinality == 0 and protocol_call_num <= tau:
+ z_neg.update(node)
+ return [z_pos, z_neg]
diff --git a/anonpsi/src/dynpathblizer_params.py b/anonpsi/src/dynpathblizer_params.py
new file mode 100644
index 0000000..e222511
--- /dev/null
+++ b/anonpsi/src/dynpathblizer_params.py
@@ -0,0 +1,121 @@
+# Copyright 2023 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This file includes the memo generating functions in the dynpathblizer algorithm."""
+
+import math
+
+from tqdm import tqdm
+
+
+def gen_gamma_and_phi(size):
+ """Generates gamma, phi, which are the memo space for the dynamic
+ programming algorithm and the minimized protocol call.
+
+ Args:
+ size: A int indicates the number of victim set elements.
+
+ Returns:
+ Memorized space gamma and phi.
+ """
+ gamma = {}
+ gamma["S(1,0,0)"] = [1, 0]
+ gamma["S(1,0,1)"] = [1, 0]
+ gamma["S(1,1,0)"] = [1, 0]
+ gamma["S(1,1,1)"] = [1, 0]
+ phi = {}
+ phi["(1,0)"] = 0
+ phi["(1,1)"] = 0
+ for n in tqdm(range(2, size + 1)):
+ for c_n in range((n + 1) // 2 + 1):
+ for tau in range(n + 1):
+ cur_state = "S(" + str(n) + "," + \
+ str(c_n) + "," + str(tau) + ")"
+ max_call = "(" + str(n) + "," + str(c_n) + ")"
+ cur_state_reverse = "S(" + str(n) + "," + \
+ str(n - c_n) + "," + str(tau) + ")"
+ max_call_reverse = "(" + str(n) + "," + str(n - c_n) + ")"
+ if c_n == 0 or c_n == n:
+ phi[max_call] = 0
+ phi[max_call_reverse] = 0
+ gamma[cur_state] = (n, 0)
+ gamma[cur_state_reverse] = (n, 0)
+ break
+ elif tau >= n:
+ gamma[cur_state] = (n, n // 2)
+ gamma[cur_state_reverse] = (n, n // 2)
+ elif tau == 0:
+ gamma[cur_state] = (0, 0)
+ gamma[cur_state_reverse] = (0, 0)
+ else:
+ expected_leakage = 0
+ max_k = 0
+ for k in range(1, n // 2 + 1):
+ left_leakage = 0
+ right_leakage = 0
+ for c in range(max(0, c_n + k - n), min(k, c_n) + 1):
+ prob = math.comb(c_n, c) * math.comb(n - c_n, k - c) / math.comb(n, k)
+ left_call_need = phi["(" + str(k) + "," + str(c) + ")"]
+ right_call_need = phi["(" + str(n - k) + "," + str(c_n - c) + ")"]
+
+ if tau - 1 < left_call_need:
+ left_leakage += prob * \
+ gamma["S(" + str(k) + "," + str(c) +
+ "," + str(tau - 1) + ")"][0]
+ if c_n == c or c_n - c == n - k:
+ left_leakage += prob * (n - k)
+ else:
+ left_leakage += prob * k
+ if tau - 1 - left_call_need < right_call_need:
+ left_leakage += prob * \
+ gamma["S(" + str(n - k) + "," + str(c_n - c) +
+ "," + str(tau - 1 - left_call_need) + ")"][0]
+ else:
+ left_leakage += prob * (n - k)
+
+ if tau - 1 < right_call_need:
+ right_leakage += prob * \
+ gamma["S(" + str(n-k) + "," +
+ str(c_n-c) + "," + str(tau-1) + ")"][0]
+ if c == 0 or c == k:
+ right_leakage += prob * k
+ else:
+ right_leakage += prob * (n - k)
+ if tau - 1 - right_call_need < left_call_need:
+ right_leakage += prob * \
+ gamma["S(" + str(k) + "," + str(c) +
+ "," + str(tau - 1 - right_call_need) + ")"][0]
+ else:
+ right_leakage += prob * k
+
+ if left_leakage >= expected_leakage:
+ expected_leakage = left_leakage
+ max_k = max(max_k, k)
+ if right_leakage >= expected_leakage:
+ expected_leakage = right_leakage
+ max_k = max(max_k, k)
+
+ gamma[cur_state] = (expected_leakage, max_k)
+ gamma[cur_state_reverse] = (expected_leakage, max_k)
+
+ if expected_leakage >= n:
+ if "(" + str(n) + "," + str(c_n) + ")" in phi:
+ phi["(" + str(n) + "," + str(c_n) + ")"] = min(phi["(" + str(n) + "," + str(c_n) + ")"],
+ tau)
+ phi["(" + str(n) + "," + str(n - c_n) + ")"] = min(
+ phi["(" + str(n) + "," + str(n - c_n) + ")"], tau)
+ else:
+ phi["(" + str(n) + "," + str(c_n) + ")"] = tau
+ phi["(" + str(n) + "," + str(n - c_n) + ")"] = tau
+ break
+ return [gamma, phi]
diff --git a/anonpsi/src/improved_baseline.py b/anonpsi/src/improved_baseline.py
new file mode 100644
index 0000000..a75c39e
--- /dev/null
+++ b/anonpsi/src/improved_baseline.py
@@ -0,0 +1,103 @@
+# Copyright 2023 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" This file includes improved baseline algorithm.
+
+Improved algorithm leverages both positive and negative memberships.
+"""
+
+import random
+import heapq
+
+import utils
+
+
+def improved_baseline_attack_with_limited_call(victim_x, target_y, tau):
+ """Improved MIA for MPC.
+
+ Args:
+ victim_x: A set of victim elements X.
+ target_y: A set of target elements Y.
+ tau: A int indicates the number of protocol call limitation.
+
+ Returns:
+ Predicted set Z_pos and Z_neg.
+ """
+ # Initialize prediction set
+ z_pos = set()
+ z_neg = set()
+
+ # Initialize the priority Queue
+ forest = []
+ idx = 0
+ heapq.heapify(forest)
+ initial_node_cardinality = utils.psi_cardinality(victim_x, target_y)
+ protocol_call_num = 1 # Tracking the protocol calling times
+ victim_x = list(victim_x)
+ random.shuffle(victim_x)
+ victim_x = set(victim_x)
+ node = utils.TreeNode()
+ tree = node.form_tree(victim_x) # Generate tree based on X
+ heapq.heappush(forest, (min(-1 * initial_node_cardinality / len(victim_x),
+ -1 + initial_node_cardinality / len(victim_x)), initial_node_cardinality, idx, tree))
+ idx += 1
+
+ while forest and protocol_call_num <= tau:
+ _, current_node_cardinality, _, node = heapq.heappop(forest)
+ # Check the number of elements for calculating priority
+ while 0 < current_node_cardinality < len(node.val) and protocol_call_num <= tau:
+ left_node = node.left
+ right_node = node.right
+
+ # Call protocol once and get density for both children node
+ if len(right_node.val) > len(left_node.val):
+ left_cardinality = utils.psi_cardinality(left_node.val, target_y)
+ protocol_call_num += 1
+ right_cardinality = current_node_cardinality - left_cardinality
+ else:
+ right_cardinality = utils.psi_cardinality(right_node.val, target_y)
+ protocol_call_num += 1
+ left_cardinality = current_node_cardinality - right_cardinality
+
+ # Check the density as the priority for moving to a child node and
+ # push the other child node into the forest
+ right_priority = max(right_cardinality / len(right_node.val), 1 - right_cardinality / len(right_node.val))
+ left_priority = max(left_cardinality / len(left_node.val), 1 - left_cardinality / len(left_node.val))
+ if right_priority > left_priority:
+ if left_cardinality != 0:
+ pushed_priority = min(-1 * left_cardinality / len(left_node.val),
+ -1 + left_cardinality / len(left_node.val))
+ heapq.heappush(forest, (pushed_priority, left_cardinality, idx, left_node))
+ idx += 1
+ else:
+ z_neg.update(left_node.val)
+
+ node = right_node
+ current_node_cardinality = right_cardinality
+ else:
+ if right_cardinality != 0:
+ pushed_priority = min(-1 * right_cardinality / len(right_node.val),
+ -1 + right_cardinality / len(right_node.val))
+ heapq.heappush(forest, (pushed_priority, right_cardinality, idx, right_node))
+ idx += 1
+ else:
+ z_neg.update(right_node.val)
+
+ node = left_node
+ current_node_cardinality = left_cardinality
+
+ if current_node_cardinality > 0 and protocol_call_num <= tau:
+ z_pos.update(node.val)
+ elif current_node_cardinality == 0 and protocol_call_num <= tau:
+ z_neg.update(node.val)
+ return [z_pos, z_neg]
diff --git a/anonpsi/src/sum_only.py b/anonpsi/src/sum_only.py
new file mode 100644
index 0000000..ab57f8d
--- /dev/null
+++ b/anonpsi/src/sum_only.py
@@ -0,0 +1,107 @@
+# Copyright 2023 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" This file includes MIA attack based on SUM information only."""
+
+import heapq
+
+import utils
+
+
+def improved_attack_with_sum_only(victim_x, target_y, tau):
+ """Improved MIA attack for PSI-CA that considers sum information only.
+
+ Args:
+ victim_x: A set of victim elements X.
+ target_y: A set of target elements Y.
+ tau: A int indicates the number of protocol call limitation.
+
+ Returns:
+ Predicted set Z_pos and Z_neg.
+ """
+ # Initialize prediction set
+ z_pos = set()
+ z_neg = set()
+
+ # Initialize the priority Queue
+ forest = []
+ idx = 0
+ heapq.heapify(forest)
+ _, current_sum = utils.psi_sum(victim_x, target_y)
+ total_sum = sum(pair[1] for pair in list(victim_x))
+ protocol_call_num = 1 # Tracking the protocol calling times
+ node = utils.TreeNode()
+ tree = node.form_tree(victim_x) # Generate Tree based on X
+ heapq.heappush(forest, (min(-1 * current_sum / total_sum, -1 + current_sum / total_sum), current_sum, idx, tree))
+ idx += 1
+
+ while forest and protocol_call_num <= tau:
+ _, current_node_sum, _, node = heapq.heappop(forest)
+ total_sum = sum(pair[1] for pair in list(node.val))
+ while 0 < current_node_sum < total_sum and protocol_call_num <= tau:
+ left_node = node.left
+ if left_node:
+ left_node_total = sum(pair[1] for pair in list(left_node.val))
+ else:
+ left_node_total = 0
+
+ right_node = node.right
+ if right_node:
+ right_node_total = sum(pair[1] for pair in list(right_node.val))
+ else:
+ right_node_total = 0
+
+ if not right_node or not left_node or left_node_total == 0 or right_node_total == 0:
+ break
+
+ # Call protocol once and get density for both children node
+ if right_node_total > left_node_total:
+ _, left_node_sum = utils.psi_sum(left_node.val, target_y)
+ protocol_call_num += 1
+ right_node_sum = current_node_sum - left_node_sum
+ else:
+ _, right_node_sum = utils.psi_sum(right_node.val, target_y)
+ protocol_call_num += 1
+ left_node_sum = current_node_sum - right_node_sum
+
+ # Check the density as the priority for moving to a child node and
+ # push the other child node into the forest
+ right_priority = max(right_node_sum / right_node_total, 1 - right_node_sum / right_node_total)
+ left_priority = max(left_node_sum / left_node_total, 1 - left_node_sum / left_node_total)
+ if right_priority > left_priority:
+ if left_node_sum != 0:
+ pushed_priority = min(-1 * left_node_sum / left_node_total, -1 + left_node_sum / left_node_total)
+ heapq.heappush(forest, (pushed_priority, left_node_sum, idx, left_node))
+ idx += 1
+ else:
+ z_neg.update(left_node.val)
+
+ node = right_node
+ current_node_sum = right_node_sum
+ else:
+ if right_node_sum != 0:
+ pushed_priority = min(-1 * right_node_sum / right_node_total,
+ -1 + right_node_sum / right_node_total)
+ heapq.heappush(forest, (pushed_priority, right_node_sum, idx, right_node))
+ idx += 1
+ else:
+ z_neg.update(right_node.val)
+
+ node = left_node
+ current_node_sum = left_node_sum
+
+ if current_node_sum > 0 and protocol_call_num <= tau:
+ z_pos.update(node.val)
+ elif current_node_sum == 0 and protocol_call_num <= tau:
+ z_neg.update(node.val)
+ return [z_pos, z_neg]
diff --git a/anonpsi/src/test.py b/anonpsi/src/test.py
new file mode 100644
index 0000000..e6ab9e4
--- /dev/null
+++ b/anonpsi/src/test.py
@@ -0,0 +1,154 @@
+# Copyright 2023 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This file includes a testing example."""
+
+import matplotlib.pyplot as plt
+import numpy as np
+from tqdm import tqdm
+
+import actbayesian
+import baseline
+import dynpathblizer
+import dynpathblizer_params
+import improved_baseline
+import treesumexplorer
+import utils
+
+
+def algorithms_compare(tau, victim_x, target_y, gamma, phi, step, iteration, lower_bound, upper_bound, tolerance,
+ noise_scale, sampling_rate):
+ """Runs all algorithms in this repo to compare their performance under given parameters.
+
+ Args:
+ tau: A int indicates the PSI call budget.
+ victim_x: A set of victim elements X.
+ target_y: A set of target elements Y.
+ gamma: A hash map represents the optimal partition K for current state(N, C_N, tau).
+ phi: A hash map represents average call required for current state(N, C_N).
+ step: A int indicates the step of tau in testing.
+        iteration: A int indicates the number of testing times.
+ lower_bound: A float indicates the lower bound of the stopping criterion in actbayesian.
+ upper_bound: A float indicates the upper bound of the stopping criterion in actbayesian.
+        tolerance: A float indicates the tolerance factor in actbayesian.
+ noise_scale: A float indicates the noise factor of the Laplacian mechanism in actbayesian.
+        sampling_rate: A float indicates the sampling rate in actbayesian.
+
+ Returns:
+ tau_list: A list of PSI call budget.
+ y_base: leakage of the baseline attack.
+        y_impr_base: leakage of the improved baseline attack.
+        y_dypathblizer: leakage of the dynpathblizer.
+ y_treesumexplorer: leakage of the treesumexplorer.
+ y_actbayesian: leakage of the actbayesian.
+ """
+ leakage_base = []
+ leakage_impr_base = []
+ leakage_treesumexplorer = []
+ leakage_dynpathblizer = []
+ leakage_actbayesian = []
+
+ inpute_dataset_actbayesian = []
+ for x in victim_x:
+ if x in target_y:
+ inpute_dataset_actbayesian.append(1)
+ else:
+ inpute_dataset_actbayesian.append(0)
+
+ for k in tqdm(range(0, tau, step)):
+ cur_z_base = 0
+ cur_z_impr_base = 0
+ cur_z_dynpathblizer = 0
+ cur_z_treesumexplorer = 0
+ cur_z_actbayesian = 0
+ for _ in range(iteration):
+ [z_pos_base, z_neg_base] = baseline.baseline_attack_with_limited_call(victim_x, target_y, k)
+ [z_pos_impr_base,
+ z_neg_impr_base] = improved_baseline.improved_baseline_attack_with_limited_call(victim_x, target_y, k)
+ [z_pos_dynpathblizer, z_neg_dynpathblizer] = dynpathblizer.dynpathblizer(victim_x, target_y, k, gamma, phi)
+ [z_pos_treesumexplorer, z_neg_treesumexplorer] = treesumexplorer.psi_sum_attack(victim_x, target_y, k, 1e10)
+ [z_pos_actbayesian, z_neg_actbayesian, _,
+ _] = actbayesian.actbayesian(k, lower_bound, upper_bound, inpute_dataset_actbayesian, tolerance,
+ noise_scale, sampling_rate)
+
+ if len(z_pos_base) + len(z_neg_base) == utils.psi_cardinality(victim_x, target_y):
+ cur_z_base += len(victim_x)
+ else:
+ cur_z_base += len(z_pos_base) + len(z_neg_base)
+
+ cur_z_impr_base += len(z_pos_impr_base) + len(z_neg_impr_base)
+ cur_z_dynpathblizer += len(z_pos_dynpathblizer) + len(z_neg_dynpathblizer)
+ cur_z_treesumexplorer += len(z_pos_treesumexplorer) + len(z_neg_treesumexplorer)
+ cur_z_actbayesian += z_pos_actbayesian + z_neg_actbayesian
+
+ if len(leakage_base) > 0 and cur_z_base / iteration < leakage_base[-1]:
+ leakage_base[-1] = cur_z_base / iteration
+
+ leakage_base.append(cur_z_base / iteration)
+ leakage_impr_base.append(cur_z_impr_base / iteration)
+ leakage_dynpathblizer.append(cur_z_dynpathblizer / iteration)
+ leakage_treesumexplorer.append(cur_z_treesumexplorer / iteration)
+ leakage_actbayesian.append(cur_z_actbayesian / iteration)
+
+ tau_list = np.array(list(range(1, tau + 1, step)))
+ y_base = np.array(leakage_base)
+ y_impr_base = np.array(leakage_impr_base)
+ y_dypathblizer = np.array(leakage_dynpathblizer)
+ y_treesumexplorer = np.array(leakage_treesumexplorer)
+ y_actbayesian = np.array(leakage_actbayesian)
+
+ return [tau_list, y_base, y_impr_base, y_dypathblizer, y_treesumexplorer, y_actbayesian]
+
+
+def test():
+ """Plot utils with default parameters"""
+ victim_set_size = 20
+ target_set_size = 20
+ interserction_cardinality = 10
+ dense = 2
+ dummy_set_size = 0
+ lower_bound = 0.25
+ upper_bound = 0.8
+ tolerance = 0.1
+ noise_scale = 0
+ sampling_rate = 0.5
+ [victim_x, target_y, _] = utils.gen_dummy_set(victim_set_size, target_set_size, dummy_set_size, dense,
+ interserction_cardinality)
+ [gamma, phi] = dynpathblizer_params.gen_gamma_and_phi(victim_set_size)
+
+ tau = 5
+ step = 1
+ iteration = 200
+ [tau_list, y_base, y_impr_base, y_dypathblizer, y_treesumexplorer,
+ y_actbayesian] = algorithms_compare(tau, victim_x, target_y, gamma, phi, step, iteration, lower_bound, upper_bound,
+ tolerance, noise_scale, sampling_rate)
+
+ # plot
+ plt.xlabel('# of PSI Runs')
+ plt.ylabel('# of User Leakage')
+ plt.grid(True)
+
+ plt.plot(tau_list, y_base, label="USENIX22 attacking", marker='s', linewidth=2.0)
+ plt.plot(tau_list, y_impr_base, label="Improved USENIX22", marker='<', linewidth=2.0)
+ plt.plot(tau_list, y_dypathblizer, label="dynapathblizer_DP", marker='o', linewidth=2.0)
+ plt.plot(tau_list, y_treesumexplorer, label="treesumexplorer", marker='d', linewidth=2.0)
+ plt.plot(tau_list, y_actbayesian, label="actbayesian", marker='>', linewidth=2.0)
+ plt.rcParams.update({'font.size': 9})
+ plt.rc('axes', titlesize=12)
+ plt.legend()
+ plt.savefig('test', dpi=600)
+ plt.show()
+
+
+if __name__ == "__main__":
+ test()
diff --git a/anonpsi/src/treesumexplorer.py b/anonpsi/src/treesumexplorer.py
new file mode 100644
index 0000000..af41272
--- /dev/null
+++ b/anonpsi/src/treesumexplorer.py
@@ -0,0 +1,228 @@
+# Copyright 2023 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This file includes treesumexplorer algorithm for SMIA.
+
+Treesumexplorer combines an offline N sum problem and the baseline attack.
+"""
+
+import heapq
+import math
+import random
+
+import utils
+
+
+def n_sum(nums, target, n):
+ """Calculates all unique combinations of n pairs of values in nums
+ that the second values add up to the target value.
+
+ Args:
+ nums: A list of input sequence.
+ target: A int indicates the target value.
+ n: A int indicates the length of the subsequence that add up to target.
+
+ Returns:
+        A list of result, each matches the condition.
+ """
+ res = []
+ nums = sorted(nums) # sort the input list for int vector
+ # backtrack(nums, target, n, [], res)
+ backtrack_int(nums, target, n, [], res)
+ return res
+
+
+def backtrack(nums, target, n, paths, res):
+    """A backtracking solution to traverse all possible combinations.
+
+ Args:
+ nums: A list of input sequence.
+ target: A int indicates the target value.
+ n: A int indicates the length of the subsequence that add up to target.
+ paths: A list saves the current paths of traversal.
+ res: A list saves the previous result.
+ """
+ if target < 0 or len(paths) > n:
+ return
+ if target == 0 and len(paths) == n:
+ res.append(paths)
+ return
+ for i in range(len(nums)):
+ if i > 0 and nums[i][1] == nums[i - 1][1]:
+ continue
+ backtrack(nums[i + 1:], target - nums[i][1], n, paths + [nums[i]], res)
+
+
+def backtrack_int(nums, target, n, paths, res):
+    """A backtracking solution to traverse all possible combinations for int vector input.
+
+ Args:
+ nums: A list of input sequence.
+ target: A int indicates the target value.
+ n: A int indicates the length of the subsequence that add up to target.
+ paths: A list saves the current paths of traversal.
+ res: A list saves the previous result.
+ """
+ if target < 0 or len(paths) > n:
+ return
+ if target == 0 and len(paths) == n:
+ res.append(paths)
+ return
+ for i in range(len(nums)):
+ if i > 0 and nums[i] == nums[i - 1]:
+ continue
+ backtrack_int(nums[i + 1:], target - nums[i], n, paths + [nums[i]], res)
+
+
+def computation_complexity(length, dense):
+    """Calculates the computation complexity given the current membership density and length of a sequence
+
+ Args:
+ length: A int indicates the input length.
+ dense: A int indicates the membership density.
+
+ Returns:
+ A float value indicates the computation complexity.
+ """
+ return math.log10(length**(length // dense - 1))
+
+
+def solve_equation(lower_bound, upper_bound, tau, computation_budget, dense, tol=1e-5):
+ """Calculates the length of input given current computation cost.
+
+ Args:
+ lower_bound: A int indicates the lower bound of the output range.
+ upper_bound: A int indicates the upper bound of the output range.
+ tau: A int indicates the protocol invocation budget.
+ computation_budget: A int indicates the computation budget.
+ dense: A int indicates the membership density.
+ tol: A double indicates the error tolerance.
+
+ Returns:
+ A int indicates the Length of the input.
+ """
+ val = math.log10(computation_budget / tau)
+ out = 0
+ while lower_bound < upper_bound:
+ middle = (upper_bound + lower_bound) // 2
+ cur = computation_complexity(middle, dense)
+ if abs(cur - val) < tol:
+ return middle
+ elif cur > val:
+ upper_bound = middle - 1
+ else:
+ out = middle
+ lower_bound = middle + 1
+ return out
+
+
+def psi_sum_attack(victim_x, target_y, tau, computation_budget):
+    """TreeSumExplorer Attack algorithm.
+ Args:
+ victim_x: A set of victim elements X.
+ target_y: A set of target elements Y.
+ tau: A int indicates the protocol invocation budget.
+ computation_budget: A int indicates the computation budget.
+
+ Returns:
+        Predicted positive(z_pos) and negative(z_neg) membership sets.
+ """
+ z_pos = set()
+ z_neg = set()
+ while tau > 0:
+ victim_x = victim_x.difference(z_pos)
+ victim_x = victim_x.difference(z_neg)
+ if len(victim_x) > 0:
+ z_pos, z_neg, tau = psi_sum_helper(victim_x, target_y, tau, computation_budget, z_pos, z_neg)
+ else:
+ break
+ return [z_pos, z_neg]
+
+
+def psi_sum_helper(victim_x, target_y, tau, computation_budget, z_pos, z_neg):
+    """ TreeSumExplorer Attack algorithm helper function.
+
+ Args:
+ victim_x: A set of victim elements X.
+ target_y: A set of target elements Y.
+ tau: A int indicates the protocol invocation budget.
+ computation_budget: A int indicates the computation budget.
+        z_pos: previous predicted positive membership sets.
+ z_neg: previous predicted negative membership sets.
+
+ Returns:
+        Predicted positive(z_pos) and negative(z_neg) membership sets and updated tau.
+ """
+ computation_budget = int(computation_budget)
+ forest = []
+ idx = 0
+ heapq.heapify(forest)
+ [current_cardinality, current_sum] = utils.psi_sum_int(victim_x, target_y)
+ tau -= 1
+
+ if tau == 0:
+ return (z_pos, z_neg, tau)
+
+ candidate_size = len(victim_x)
+ if len(victim_x)**(current_cardinality - 1) >= computation_budget / (tau):
+ candidate_size = solve_equation(0, 500, tau, computation_budget, len(victim_x) // current_cardinality)
+ victim_x = list(victim_x)
+ random.shuffle(victim_x)
+ victim_x = set(victim_x[:candidate_size])
+ [current_cardinality, current_sum] = utils.psi_sum_int(victim_x, target_y)
+ tau -= 1
+
+ candidate_list = n_sum(victim_x, current_sum, current_cardinality)
+ priority = len(victim_x) if len(candidate_list) <= tau else len(victim_x) * (1 -
+ (1 - 1 / len(candidate_list))**(tau))
+ heapq.heappush(forest, (-1 * priority, current_cardinality, current_sum, victim_x, candidate_list, idx))
+ idx += 1
+ while forest:
+ priority, current_cardinality, _, victim_x, candidate_list, idx = heapq.heappop(forest)
+ while 1 < len(candidate_list) and 0 < tau:
+ left_node = candidate_list[0]
+ right_node = victim_x.difference(set(left_node))
+ [left_cardinality, left_sum] = utils.psi_sum_int(set(left_node), target_y)
+ tau -= 1
+ right_cardinality = current_cardinality - left_cardinality
+ right_sum = current_sum - left_sum
+
+ left_candidate_list = n_sum(left_node, left_sum, left_cardinality)
+ right_candidate_list = n_sum(right_node, right_sum, right_cardinality)
+ left_priority = len(left_node) if len(
+ left_candidate_list) <= tau else len(left_node) * (1 - (1 - 1 / len(left_candidate_list))**(tau))
+ right_priority = len(right_node) if len(
+ right_candidate_list) <= tau else len(right_node) * (1 - (1 - 1 / len(right_candidate_list))**(tau))
+
+ if left_priority > right_priority:
+ heapq.heappush(
+ forest,
+ (-1 * right_priority, right_cardinality, right_sum, set(right_node), right_candidate_list, idx))
+ idx += 1
+ victim_x = set(left_node)
+ current_cardinality = left_cardinality
+ candidate_list = left_candidate_list
+ current_sum = left_sum
+ else:
+ heapq.heappush(
+ forest, (-1 * left_priority, left_cardinality, left_sum, set(left_node), left_candidate_list, idx))
+ idx += 1
+ victim_x = set(right_node)
+ current_cardinality = right_cardinality
+ candidate_list = right_candidate_list
+ current_sum = right_sum
+
+ if len(candidate_list) == 1 and 0 <= tau:
+ z_pos.update(set(candidate_list[0]))
+ z_neg.update(victim_x.difference(set(candidate_list[0])))
+ return (z_pos, z_neg, tau)
diff --git a/anonpsi/src/utils.py b/anonpsi/src/utils.py
new file mode 100644
index 0000000..a5d21fb
--- /dev/null
+++ b/anonpsi/src/utils.py
@@ -0,0 +1,180 @@
+# Copyright 2023 TikTok Pte. Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This file defines basic functions used throughout the projects."""
+
+import random
+
+
def psi_cardinality(victim, target):
    """Runs the private set intersection cardinality (PSI-CA) protocol.

    Args:
        victim: A set of victim elements.
        target: A set of target elements.

    Returns:
        The number of elements present in both sets.
    """
    # Count shared members directly instead of materializing the intersection.
    return sum(1 for element in victim if element in target)
+
+
def psi(victim, target):
    """Runs the private set intersection (PSI) protocol.

    Args:
        victim: A set of victim elements.
        target: A set of target elements.

    Returns:
        The set of elements shared by both inputs.
    """
    # Build the intersection by filtering victim members through target.
    return {element for element in victim if element in target}
+
+
def psi_sum(victim, target):
    """Runs the private set intersection - summation (PSI-SUM) protocol.

    Args:
        victim: A set of (element, value) pairs for the victim.
        target: A set of (element, value) pairs for the target.

    Returns:
        A two-element list: [intersection cardinality, sum of the values
        (second entries) of the intersected pairs].
    """
    shared = victim.intersection(target)
    total = 0
    for pair in shared:
        total += pair[1]
    return [len(shared), total]
+
+
def psi_sum_int(victim, target):
    """Runs the PSI-SUM protocol for plain integer elements.

    Args:
        victim: A set of victim integers.
        target: A set of target integers.

    Returns:
        A two-element list: [intersection cardinality, sum of the
        intersected integers].
    """
    shared = victim.intersection(target)
    # The elements themselves are the values, so sum them directly.
    return [len(shared), sum(shared)]
+
+
class TreeNode:
    """Binary tree node used to structure the target set.

    Attributes:
        val: Payload of this node — a set of elements when built via
            form_tree (defaults to 0 for a bare node).
        left: Left child TreeNode, or None for a leaf.
        right: Right child TreeNode, or None for a leaf.
    """

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

    def form_tree(self, input_set):
        """Forms a binary tree over the input set.

        Each node holds a subset of elements; its children hold a random,
        roughly even partition of that subset. Leaves hold one element.

        Args:
            input_set: A set of target elements.

        Returns:
            Root TreeNode of the target set tree.
        """
        cur_node = TreeNode(input_set)
        # Fix: also stop at len == 0. The original only terminated at
        # len == 1, so an empty set split into two empty halves and
        # recursed forever (RecursionError).
        if len(input_set) <= 1:
            return cur_node
        cur_node_list = list(cur_node.val)
        random.shuffle(cur_node_list)
        half = len(cur_node_list) // 2
        cur_node.left = self.form_tree(set(cur_node_list[:half]))
        cur_node.right = self.form_tree(set(cur_node_list[half:]))
        return cur_node

    def print_tree(self, root):
        """Prints the formed tree in pre-order (node, left, right)."""
        if not root:
            return
        print(root.val)
        self.print_tree(root.left)
        self.print_tree(root.right)
+
+
def gen_simple_set(victim_set_num, target_set_num, upper_bound):
    """Generates a pair of random test sets.

    Args:
        victim_set_num: Number of victim set elements.
        target_set_num: Number of target set elements.
        upper_bound: Exclusive upper bound of the sampling range.

    Returns:
        A list [victim_x, target_y] of two element sets.
    """
    population = range(0, upper_bound)
    # Sampling without replacement guarantees the requested set sizes.
    victim_x = set(random.sample(population, victim_set_num))
    target_y = set(random.sample(population, target_set_num))
    return [victim_x, target_y]
+
+
def gen_set_with_intersection(victim_set_num, target_set_num, dense, intersection_num):
    """Generates two sets whose intersection has a requested size.

    Args:
        victim_set_num: Number of victim set elements.
        target_set_num: Number of target set elements.
        dense: Sampling density multiplier (> 1).
        intersection_num: Required size of the intersection.

    Returns:
        A list [victim_x, target_y] of sets sharing exactly
        intersection_num elements.
    """
    pool = range(0, max(victim_set_num, target_set_num) * dense)
    draw_count = victim_set_num + target_set_num - intersection_num
    drawn = random.sample(pool, draw_count)

    # Disjoint slices of one sample guarantee the only overlap is `shared`.
    shared = drawn[:intersection_num]
    victim_x = drawn[intersection_num:victim_set_num]
    target_y = drawn[victim_set_num:]
    victim_x += shared
    target_y += shared
    random.shuffle(victim_x)
    random.shuffle(target_y)
    return [set(victim_x), set(target_y)]
+
+
def gen_dummy_set(victim_set_num, target_set_num, dummy_set_num, dense, intersection_num):
    """Generates victim/target test sets plus a disjoint dummy set.

    Args:
        victim_set_num: Number of victim set elements.
        target_set_num: Number of target set elements.
        dummy_set_num: Number of dummy set elements.
        dense: Sampling density multiplier (> 1).
        intersection_num: Required size of the victim/target intersection.

    Returns:
        A list [victim_x, target_y, dummy_set]; victim_x and target_y share
        exactly intersection_num elements, and dummy_set is disjoint from both.
    """
    pool = range(0, (victim_set_num + target_set_num + dummy_set_num) * dense)
    draw_count = victim_set_num + target_set_num + dummy_set_num - intersection_num
    drawn = random.sample(pool, draw_count)

    # Disjoint slices of one sample: overlap exists only where we add `shared`.
    shared = drawn[:intersection_num]
    victim_x = drawn[intersection_num:victim_set_num]
    target_end = victim_set_num + target_set_num - intersection_num
    target_y = drawn[victim_set_num:target_end]
    dummy_set = drawn[target_end:]
    victim_x += shared
    target_y += shared
    random.shuffle(victim_x)
    random.shuffle(target_y)
    random.shuffle(dummy_set)
    return [set(victim_x), set(target_y), set(dummy_set)]