Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions ann/src/main/python/dataflow/faiss_index_bq_dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import argparse
import logging
import os
import pkgutil
import sys
from urllib.parse import urlsplit

import apache_beam as beam
import faiss
from apache_beam.options.pipeline_options import PipelineOptions


Expand Down Expand Up @@ -94,8 +93,8 @@ def parse_metric(config):
raise Exception(f"Unknown metric: {metric_str}")


def run_pipeline(argv=[]):
config = parse_d6w_config(argv)
def run_pipeline(argv=[], log_level = logging.INFO):
config = parse_d6w_config(argv=None)
argv_with_extras = argv
if config["gpu"]:
argv_with_extras.extend(["--experiments", "use_runner_v2"])
Expand All @@ -108,7 +107,7 @@ def run_pipeline(argv=[]):
"gcr.io/twttr-recos-ml-prod/dataflow-gpu/beam2_39_0_py3_7",
]
)

logging.getLogger().setLevel(log_level)
options = PipelineOptions(argv_with_extras)
output_bucket_name = urlsplit(config["output_location"]).netloc

Expand Down Expand Up @@ -228,5 +227,10 @@ def extract_output(self, rows):


if __name__ == "__main__":
    # Peel off --log_level ourselves; every other CLI argument is forwarded
    # untouched to the Beam pipeline (parse_known_args keeps unknowns).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--log_level",
        dest="log_level",
        default="INFO",
        help="Logging level (e.g. DEBUG, INFO, WARNING)",
    )
    args, pipeline_args = parser.parse_known_args()

    # logging.Logger.setLevel accepts level names as strings ("INFO", ...).
    logging.getLogger().setLevel(args.log_level)
    run_pipeline(pipeline_args, log_level=args.log_level)