
Commit c81387f

add ranking flag
1 parent eccb957 commit c81387f

2 files changed: +32, -15 lines


text_dedup/bigcode/intra_dedup.py

Lines changed: 27 additions & 11 deletions
@@ -341,11 +341,16 @@ def area(s):
 # endregion

 # region: Quality Control
-def process_cluster(cluster: List[Any]) -> List[Any]:
+def process_cluster(cluster: List[Any], enabled: bool = False) -> List[Any]:
+    if not enabled:
+        np.random.shuffle(cluster)
+        return cluster[:1]
+
     cluster.sort(
         key=lambda x: (
-            -x[-1] if x[-1] is not None else 0.0,
-            -x[-2] if x[-2] is not None else 0.0,
+            -x[-1] if x[-1] is not None else 0.0,  # star_events_count
+            -x[-2] if x[-2] is not None else 0.0,  # fork_events_count
+            -np.datetime64(x[-3]).astype(np.uint64) if x[-3] is not None else 0.0,  # visit_date
         )
     )
     return cluster[:1]
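
For orientation, a minimal standalone sketch of the ranking branch above, using made-up rows. It assumes each cluster entry is a tuple whose last three fields are (visit_date, fork_events_count, star_events_count), matching the inline comments in the diff; the ids, repos, and counts are illustrative only:

    import numpy as np

    # Hypothetical rows: (__id__, __component__, repo, visit_date, fork_events_count, star_events_count)
    cluster = [
        (1, 7, "org/a", "2022-03-01", 4, 120),
        (2, 7, "org/b", None, None, None),
        (3, 7, "org/c", "2023-01-15", 9, 350),
    ]

    cluster.sort(
        key=lambda x: (
            -x[-1] if x[-1] is not None else 0.0,  # more stars first
            -x[-2] if x[-2] is not None else 0.0,  # then more forks
            -np.datetime64(x[-3]).astype(np.uint64) if x[-3] is not None else 0.0,  # then newer visits
        )
    )
    print(cluster[:1])  # keeps only the highest-ranked duplicate: (3, 7, 'org/c', ...)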
@@ -421,6 +426,7 @@ def save_partition(df: pd.DataFrame) -> pd.DataFrame:  # type: ignore
     parser.add_argument("--output", "-o", type=str, required=True, help="GCS output directory of parquet files")
     parser.add_argument("--output_index", "-oi", type=str, help="GCS output directory of index parquet files")
     parser.add_argument("--index_only", action="store_true", help="Only output the index, skip deduplication")
+    parser.add_argument("--rank", action="store_true", help="Rank the duplicates by quality indicators")
     parser.add_argument("--debug", action="store_true", help="Enable debug mode")
     args = parser.parse_args()
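
The new option is a standard argparse store_true switch, so args.rank defaults to False. A self-contained check of just the flag, outside the script's full CLI:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--rank", action="store_true", help="Rank the duplicates by quality indicators")

    print(parser.parse_args([]).rank)          # False when the flag is absent
    print(parser.parse_args(["--rank"]).rank)  # True when --rank is passed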

@@ -588,17 +594,27 @@ def save_partition(df: pd.DataFrame) -> pd.DataFrame:  # type: ignore
     # detected_licenses    object
     # license_type         object

-    duplicates: pyspark.RDD = (
-        df.filter(F.col("__component__").isNotNull())
-        .select(
+    rank_columns = (
+        [
             "__id__",
             "__component__",
             args.repo_column,
+            "visit_date",
             "star_events_count",
-            "fork_events_count",
-        )
-        .rdd
-    ).cache()
+            "fork_events_count"
+            # "max_stars_repo_stars_event_min_datetime",
+            # "max_stars_count",
+            # "max_forks_count",
+        ]
+        if args.rank
+        else [
+            "__id__",
+            "__component__",
+            args.repo_column,
+        ]
+    )
+
+    duplicates: pyspark.RDD = (df.filter(F.col("__component__").isNotNull()).select(*rank_columns).rdd).cache()

     if args.debug:
         NUM_DUPLICATE = duplicates.count()
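
A note on why visit_date can participate in the numeric sort key: casting np.datetime64 to an unsigned integer yields the count of time units since the Unix epoch. A small illustration with an example date, not project data:

    import numpy as np

    d = np.datetime64("2023-01-15")  # day-resolution datetime
    print(d.astype(np.uint64))       # 19372, days since 1970-01-01
    # Unary minus wraps an unsigned integer modulo 2**64, but the wrapped value
    # still decreases as the date grows, so newer dates sort first; a None date
    # falls back to 0.0, which sorts ahead of any wrapped value.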
@@ -620,7 +636,7 @@ def save_partition(df: pd.DataFrame) -> pd.DataFrame:  # type: ignore
     # region: Remove Low Quality Duplicates
     df = df.join(
         spark.createDataFrame(
-            cliques.mapValues(lambda x: process_cluster(cluster=list(x))).flatMap(
+            cliques.mapValues(lambda x: process_cluster(cluster=list(x), enabled=args.rank)).flatMap(
                 lambda x: [(ele[0], True) for ele in x[1]]
             ),
             schema=["__id__", "__keep__"],
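
Downstream, each connected component is collapsed to its one surviving row and re-emitted as (__id__, __keep__) pairs for the join. A plain-Python analogue of the mapValues/flatMap step, with process_cluster assumed to be defined as in the first hunk and rows shaped like the earlier sketch (no Spark involved):

    # Hypothetical component -> rows mapping; each row starts with its __id__.
    cliques = {
        7: [(1, 7, "org/a", "2022-03-01", 4, 120),
            (2, 7, "org/b", "2023-01-15", 9, 350)],
        8: [(3, 8, "org/c", None, None, None)],
    }

    keep_flags = [
        (row[0], True)  # (__id__, __keep__)
        for _, rows in cliques.items()
        for row in process_cluster(cluster=list(rows), enabled=True)
    ]
    print(keep_flags)  # [(2, True), (3, True)] -- one kept id per component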

text_dedup/bigcode/run.sh

Lines changed: 5 additions & 4 deletions
@@ -6,12 +6,12 @@
 CLUSTER_NAME="chenghao-temp"
 PROJECT_ID="huggingface-science-codeparrot"
 REGION="us-central1"
-CONTAINER="gs://the_stack_v2"
-DIRECTORY="licensed_files_sample"
-NUM_WORKERS=5
+CONTAINER=""
+DIRECTORY=""
+NUM_WORKERS=18
 MASTER_MACHINE_TYPE="c2d-standard-16"
 MASTER_BOOT_DISK_SIZE=1024
-WORKER_MACHINE_TYPE="c2-standard-60"
+WORKER_MACHINE_TYPE="c2-standard-16"
 WORKER_BOOT_DISK_SIZE=1024
 IMAGE_VERSION="2.0-debian10"
 SPARK_JARS="gs://spark-lib/bigquery/spark-3.3-bigquery-0.32.2.jar"
@@ -84,6 +84,7 @@ for DIR in $DIRS; do
     --threshold $THRESHOLD \
     --output_index "$OUTPUT_INDEX_GCS_PATH" \
     --repo_column $REPO_COLUMN \
+    --rank \
     --debug
 done
