Skip to content
Merged
81 changes: 74 additions & 7 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import traceback
import types
import typing
from itertools import dropwhile

from apache_beam import coders
from apache_beam import pvalue
Expand Down Expand Up @@ -1387,6 +1388,59 @@ def partition_for(self, element, num_partitions, *args, **kwargs):
return self._fn(element, num_partitions, *args, **kwargs)


def _get_function_body_without_inners(func):
source_lines = inspect.getsourcelines(func)[0]
source_lines = dropwhile(lambda x: x.startswith("@"), source_lines)
def_line = next(source_lines).strip()
if def_line.startswith("def ") and def_line.endswith(":"):
first_line = next(source_lines)
indentation = len(first_line) - len(first_line.lstrip())
final_lines = [first_line[indentation:]]

skip_inner_def = False
if first_line[indentation:].startswith("def "):
skip_inner_def = True
for line in source_lines:
line_indentation = len(line) - len(line.lstrip())

if line[indentation:].startswith("def "):
skip_inner_def = True
continue

if skip_inner_def and line_indentation == indentation:
skip_inner_def = False

if skip_inner_def and line_indentation > indentation:
continue
final_lines.append(line[indentation:])

return "".join(final_lines)
else:
return def_line.rsplit(":")[-1].strip()


def _check_fn_use_yield_and_return(fn):
if isinstance(fn, types.BuiltinFunctionType):
return False
try:
source_code = _get_function_body_without_inners(fn)
has_yield = False
has_return = False
for line in source_code.split("\n"):
if line.lstrip().startswith("yield ") or line.lstrip().startswith(
"yield("):
has_yield = True
if line.lstrip().startswith("return ") or line.lstrip().startswith(
"return("):
has_return = True
if has_yield and has_return:
return True
return False
except Exception as e:
_LOGGER.debug(str(e))
return False


class ParDo(PTransformWithSideInputs):
"""A :class:`ParDo` transform.

Expand Down Expand Up @@ -1427,6 +1481,14 @@ def __init__(self, fn, *args, **kwargs):
if not isinstance(self.fn, DoFn):
raise TypeError('ParDo must be called with a DoFn instance.')

# DoFn.process cannot allow both return and yield
if _check_fn_use_yield_and_return(self.fn.process):
_LOGGER.warning(
'Using yield and return in the process method '
'of %s can lead to unexpected behavior, see:'
'https://github.com/apache/beam/issues/22969.',
self.fn.__class__)

# Validate the DoFn by creating a DoFnSignature
from apache_beam.runners.common import DoFnSignature
self._signature = DoFnSignature(self.fn)
Expand Down Expand Up @@ -2663,6 +2725,7 @@ def from_runner_api_parameter(unused_ptransform, combine_payload, context):

class CombineValuesDoFn(DoFn):
"""DoFn for performing per-key Combine transforms."""

def __init__(
self,
input_pcoll_type,
Expand Down Expand Up @@ -2725,6 +2788,7 @@ def default_type_hints(self):


class _CombinePerKeyWithHotKeyFanout(PTransform):

def __init__(
self,
combine_fn, # type: CombineFn
Expand Down Expand Up @@ -2939,11 +3003,12 @@ class GroupBy(PTransform):
The GroupBy operation can be made into an aggregating operation by invoking
its `aggregate_field` method.
"""

def __init__(
self,
*fields, # type: typing.Union[str, typing.Callable]
**kwargs # type: typing.Union[str, typing.Callable]
):
):
if len(fields) == 1 and not kwargs:
self._force_tuple_keys = False
name = fields[0] if isinstance(fields[0], str) else 'key'
Expand All @@ -2966,7 +3031,7 @@ def aggregate_field(
field, # type: typing.Union[str, typing.Callable]
combine_fn, # type: typing.Union[typing.Callable, CombineFn]
dest, # type: str
):
):
"""Returns a grouping operation that also aggregates grouped values.

Args:
Expand Down Expand Up @@ -3054,7 +3119,7 @@ def aggregate_field(
field, # type: typing.Union[str, typing.Callable]
combine_fn, # type: typing.Union[typing.Callable, CombineFn]
dest, # type: str
):
):
field = _expr_to_callable(field, 0)
return _GroupAndAggregate(
self._grouping, list(self._aggregations) + [(field, combine_fn, dest)])
Expand Down Expand Up @@ -3096,10 +3161,12 @@ class Select(PTransform):

pcoll | beam.Map(lambda x: beam.Row(a=x.a, b=foo(x)))
"""
def __init__(self,
*args, # type: typing.Union[str, typing.Callable]
**kwargs # type: typing.Union[str, typing.Callable]
):

def __init__(
self,
*args, # type: typing.Union[str, typing.Callable]
**kwargs # type: typing.Union[str, typing.Callable]
):
self._fields = [(
expr if isinstance(expr, str) else 'arg%02d' % ix,
_expr_to_callable(expr, ix)) for (ix, expr) in enumerate(args)
Expand Down
113 changes: 113 additions & 0 deletions sdks/python/apache_beam/transforms/core_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the core python file."""
# pytype: skip-file

import logging
import unittest

import pytest

import apache_beam as beam


class TestDoFn1(beam.DoFn):
def process(self, element):
yield element


class TestDoFn2(beam.DoFn):
def process(self, element):
def inner_func(x):
yield x

return inner_func(element)


class TestDoFn3(beam.DoFn):
"""mixing return and yield is not allowed"""
def process(self, element):
if not element:
return -1
yield element


class TestDoFn4(beam.DoFn):
"""test the variable name containing return"""
def process(self, element):
my_return = element
yield my_return


class TestDoFn5(beam.DoFn):
"""test the variable name containing yield"""
def process(self, element):
my_yield = element
return my_yield


class TestDoFn6(beam.DoFn):
"""test the variable name containing return"""
def process(self, element):
return_test = element
yield return_test


class TestDoFn7(beam.DoFn):
"""test the variable name containing yield"""
def process(self, element):
yield_test = element
return yield_test


class TestDoFn8(beam.DoFn):
"""test the code containing yield and yield from"""
def process(self, element):
if not element:
yield from [1, 2, 3]
else:
yield element


class CreateTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog

def test_dofn_with_yield_and_return(self):
warning_text = 'Using yield and return'

with self._caplog.at_level(logging.WARNING):
assert beam.ParDo(sum)
assert beam.ParDo(TestDoFn1())
assert beam.ParDo(TestDoFn2())
assert beam.ParDo(TestDoFn4())
assert beam.ParDo(TestDoFn5())
assert beam.ParDo(TestDoFn6())
assert beam.ParDo(TestDoFn7())
assert beam.ParDo(TestDoFn8())
assert warning_text not in self._caplog.text

with self._caplog.at_level(logging.WARNING):
beam.ParDo(TestDoFn3())
assert warning_text in self._caplog.text


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()