diff --git a/packages/bigframes/bigframes/core/bytecode.py b/packages/bigframes/bigframes/core/bytecode.py index fb4d3eabd8b7..cfe7e7f05cb4 100644 --- a/packages/bigframes/bigframes/core/bytecode.py +++ b/packages/bigframes/bigframes/core/bytecode.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses import dis import operator import sys @@ -20,6 +21,7 @@ import bigframes.core.py_expressions as py_exprs from bigframes.core import expression +from bigframes.operations import generic_ops _BINARY_OP_MAP = { "+": operator.add, @@ -61,15 +63,176 @@ _NULL = py_exprs.PyObject(None) +_RETURN_OPNAMES = {"RETURN_VALUE", "RETURN_CONST"} + +_UNCONDITIONAL_JUMP_OPNAMES = { + "JUMP_FORWARD", + "JUMP_ABSOLUTE", + "JUMP_BACKWARD", + "JUMP_BACKWARD_NO_INTERRUPT", + "JUMP", + "JUMP_NO_INTERRUPT", +} + +_JUMP_IF_FALSE_OPNAMES = { + "POP_JUMP_IF_FALSE", + "POP_JUMP_FORWARD_IF_FALSE", + "POP_JUMP_BACKWARD_IF_FALSE", +} + +_JUMP_IF_TRUE_OPNAMES = { + "POP_JUMP_IF_TRUE", + "POP_JUMP_FORWARD_IF_TRUE", + "POP_JUMP_BACKWARD_IF_TRUE", +} + +_CONDITIONAL_JUMP_OPNAMES = ( + _JUMP_IF_FALSE_OPNAMES + | _JUMP_IF_TRUE_OPNAMES + | { + "JUMP_IF_FALSE_OR_POP", + "JUMP_IF_TRUE_OR_POP", + "POP_JUMP_IF_NONE", + "POP_JUMP_IF_NOT_NONE", + "POP_JUMP_FORWARD_IF_NONE", + "POP_JUMP_FORWARD_IF_NOT_NONE", + "POP_JUMP_BACKWARD_IF_NONE", + "POP_JUMP_BACKWARD_IF_NOT_NONE", + } +) + +_ALL_JUMP_OPNAMES = _UNCONDITIONAL_JUMP_OPNAMES | _CONDITIONAL_JUMP_OPNAMES + + +@dataclasses.dataclass +class BasicBlock: + start_offset: int + instructions: list[dis.Instruction] + successors: list[int] = dataclasses.field(default_factory=list) + predecessors: list[int] = dataclasses.field(default_factory=list) + + +def get_block_starts(instructions: list[dis.Instruction]) -> set[int]: + starts = {0} + for i, inst in enumerate(instructions): + opname = inst.opname + if opname in _ALL_JUMP_OPNAMES: + if isinstance(inst.argval, int): + starts.add(inst.argval) + if i + 1 < len(instructions): + starts.add(instructions[i + 1].offset) + elif opname in _RETURN_OPNAMES: + if i + 1 < len(instructions): + starts.add(instructions[i + 1].offset) + return starts + + +def get_block_successors(block: BasicBlock, next_offsets: dict[int, int]) -> list[int]: + if not block.instructions: + return [] + last_inst = block.instructions[-1] + opname = last_inst.opname + offset = last_inst.offset + + next_offset = next_offsets.get(offset) + + if opname in _RETURN_OPNAMES: + return [] + + if opname in _UNCONDITIONAL_JUMP_OPNAMES: + return [last_inst.argval] + + if opname in _CONDITIONAL_JUMP_OPNAMES: + successors = [last_inst.argval] + if next_offset is not None: + successors.append(next_offset) + return successors + + if next_offset is not None: + return [next_offset] + return [] + + +def build_cfg( + instructions: list[dis.Instruction], next_offsets: dict[int, int] +) -> dict[int, BasicBlock]: + starts = sorted(list(get_block_starts(instructions))) + + blocks: dict[int, BasicBlock] = {} + for i, start in enumerate(starts): + end = starts[i + 1] if i + 1 < len(starts) else None + block_insts = [ + inst + for inst in instructions + if start <= inst.offset and (end is None or inst.offset < end) + ] + blocks[start] = BasicBlock(start_offset=start, instructions=block_insts) + + for block in blocks.values(): + successors = get_block_successors(block, next_offsets) + block.successors = successors + for succ in successors: + blocks[succ].predecessors.append(block.start_offset) + + return blocks + + +def topological_sort(blocks: dict[int, BasicBlock]) -> list[int]: + in_degree = {offset: len(block.predecessors) for offset, block in blocks.items()} + queue = [offset for offset, deg in in_degree.items() if deg == 0] + order = [] + + while queue: + queue.sort() + curr = queue.pop(0) + order.append(curr) + for succ in blocks[curr].successors: + in_degree[succ] -= 1 + if in_degree[succ] == 0: + queue.append(succ) + + # TODO(b/521549179): Support limited loop analysis (eg unroll loops over a constant range). + if len(order) != len(blocks): + raise ValueError( + "Loops are not supported in the Python function for transpilation." + ) + + return order + + +def merge_values( + pairs: list[tuple[expression.Expression, expression.Expression]], +) -> expression.Expression: + if not pairs: + raise ValueError("Cannot merge empty list of values") + if len(pairs) == 1: + return pairs[0][0] + + val = pairs[-1][0] + for next_val, next_cond in reversed(pairs[:-1]): + val = py_exprs.Call( + py_exprs.PyObject(generic_ops.where_op), (next_val, next_cond, val) + ) + return val + + def _compile_bytecode_to_py_expr(func: Callable) -> expression.Expression: instructions = list(dis.get_instructions(func)) + next_offsets = { + inst.offset: next_inst.offset + for inst, next_inst in zip(instructions, instructions[1:]) + } + + blocks = build_cfg(instructions, next_offsets) + order = topological_sort(blocks) + + stack: list[expression.Expression] + local_vars: dict[str, expression.Expression] - stack: list[expression.Expression] = [] globals_dict = func.__globals__ import builtins builtins_dict = builtins.__dict__ - closure_dict = {} if func.__closure__: free_vars = func.__code__.co_freevars @@ -79,174 +242,486 @@ def _compile_bytecode_to_py_expr(func: Callable) -> expression.Expression: except ValueError: pass - for inst in instructions: - opname = inst.opname - - if opname in ("RESUME", "PRECALL"): - continue - - elif opname in ("LOAD_FAST_LOAD_FAST", "LOAD_FAST_BORROW_LOAD_FAST_BORROW"): - var1, var2 = inst.argval - stack.append(expression.UnboundVariableExpression(var1)) - stack.append(expression.UnboundVariableExpression(var2)) - - elif opname.startswith("LOAD_FAST"): - stack.append(expression.UnboundVariableExpression(inst.argval)) - - elif opname in ("LOAD_CONST", "LOAD_SMALL_INT"): - stack.append(py_exprs.PyObject(inst.argval)) - - elif opname == "LOAD_GLOBAL": - # In Python 3.11+, the lowest bit of inst.arg indicates that a NULL - # should be pushed before the global variable. - if sys.version_info >= (3, 11) and inst.arg is not None and (inst.arg & 1): - stack.append(_NULL) - name = inst.argval - found = False - val = None - if name in closure_dict: - val = closure_dict[name] - found = True - elif name in globals_dict: - val = globals_dict[name] - found = True - elif name in builtins_dict: - val = builtins_dict[name] - found = True - - if found: - if isinstance(val, ModuleType): - stack.append(py_exprs.Module(val)) - else: - stack.append(py_exprs.PyObject(val)) - else: - stack.append(expression.UnboundVariableExpression(name)) - - elif opname in ("LOAD_ATTR", "LOAD_METHOD"): - if not stack: - raise ValueError("Stack is empty") - target = stack.pop() - stack.append(py_exprs.GetAttr(target, inst.argval)) - if opname == "LOAD_METHOD": - if isinstance(target, py_exprs.Module): - stack.append(_NULL) - else: - stack.append(target) - - elif opname == "PUSH_NULL": - stack.append(_NULL) - - elif opname == "BINARY_OP": - if len(stack) < 2: - raise ValueError("Stack is empty") - right = stack.pop() - left = stack.pop() - op_symbol = inst.argrepr - if not op_symbol and isinstance(inst.argval, str): - op_symbol = inst.argval - if op_symbol and op_symbol.endswith("="): - op_symbol = op_symbol[:-1] - - if op_symbol not in _BINARY_OP_MAP: - raise ValueError(f"Unsupported binary operator: {op_symbol}") - stack.append( - py_exprs.Call( - py_exprs.PyObject(_BINARY_OP_MAP[op_symbol]), (left, right) - ) - ) - - # Support older Python versions compatibility - elif opname in _OLD_BINARY_OP_MAP: - if len(stack) < 2: - raise ValueError("Stack has < 2 elements") - right = stack.pop() - left = stack.pop() - stack.append( - py_exprs.Call( - py_exprs.PyObject(_OLD_BINARY_OP_MAP[opname]), (left, right) - ) - ) - - elif opname == "COMPARE_OP": - if len(stack) < 2: - raise ValueError("Stack has < 2 elements") - right = stack.pop() - left = stack.pop() - op_symbol = inst.argval - if op_symbol not in _COMPARE_OP_MAP: - raise ValueError(f"Unsupported compare operator: {op_symbol}") - stack.append( - py_exprs.Call( - py_exprs.PyObject(_COMPARE_OP_MAP[op_symbol]), (left, right) - ) - ) - - elif opname in ("UNARY_NEGATIVE", "UNARY_INVERT"): - if not stack: - raise ValueError("Stack is empty") - target = stack.pop() - stack.append( - py_exprs.Call( - py_exprs.PyObject( - operator.neg if opname == "UNARY_NEGATIVE" else operator.invert - ), - (target,), + block_outputs: dict[ + int, tuple[list[expression.Expression], dict[str, expression.Expression]] + ] = {} + block_reach_conditions: dict[int, expression.Expression] = { + 0: py_exprs.PyObject(True) + } + edge_conditions: dict[tuple[int, int], expression.Expression] = {} + edge_stacks: dict[tuple[int, int], list[expression.Expression]] = {} + returns: list[tuple[expression.Expression, expression.Expression]] = [] + + co = func.__code__ + param_names = list(co.co_varnames[: co.co_argcount]) + kwonly_argcount = co.co_kwonlyargcount + param_names.extend( + co.co_varnames[co.co_argcount : co.co_argcount + kwonly_argcount] + ) + + initial_local_vars: dict[str, expression.Expression] = { + name: expression.UnboundVariableExpression(name) for name in param_names + } + + for offset in order: + block = blocks[offset] + + reach_cond: expression.Expression + if offset == 0: + reach_cond = py_exprs.PyObject(True) + else: + incoming = [ + edge_conditions[(pred, offset)] + for pred in block.predecessors + if (pred, offset) in edge_conditions + ] + if not incoming: + continue + + reach_cond = incoming[0] + for cond in incoming[1:]: + reach_cond = py_exprs.Call( + py_exprs.PyObject(operator.or_), (reach_cond, cond) ) - ) - - elif opname == "UNARY_POSITIVE": - if not stack: - raise ValueError("Stack is empty") - target = stack.pop() - stack.append(py_exprs.Call(py_exprs.PyObject(operator.pos), (target,))) - - elif opname == "CALL_INTRINSIC_1": - if inst.argrepr == "INTRINSIC_UNARY_POSITIVE": - if not stack: - raise ValueError("Stack is empty") - target = stack.pop() - stack.append(py_exprs.Call(py_exprs.PyObject(operator.pos), (target,))) - else: - raise ValueError(f"Unsupported intrinsic: {inst.argrepr}") - - elif opname in ("CALL", "CALL_FUNCTION", "CALL_METHOD"): - num_args = inst.arg - assert num_args is not None - if len(stack) < num_args: - raise ValueError("Stack has < 2 elements") - args = [stack.pop() for _ in range(num_args)][::-1] - # In Python 3.11, LOAD_GLOBAL with NULL push puts NULL below the global. - # If NULL is below the callable on the stack, swap them to match - # the expected layout [callable, NULL]. - if len(stack) >= 2 and stack[-2] == _NULL: - stack[-1], stack[-2] = stack[-2], stack[-1] - if stack and stack[-1] == _NULL: - stack.pop() - elif ( - stack - and stack[-1] != _NULL - and isinstance(stack[-1], expression.Expression) - ): - self_arg = stack.pop() - args = [self_arg] + args - if not stack: - raise ValueError("Stack is empty") - callable_expr = stack.pop() - stack.append(py_exprs.Call(callable_expr, tuple(args))) - - elif opname == "RETURN_VALUE": - if not stack: - raise ValueError("Stack is empty") - return stack[-1] - - elif opname in ("STORE_FAST", "POP_TOP"): - if stack: - stack.pop() + block_reach_conditions[offset] = reach_cond + + if offset == 0: + stack = [] + local_vars = initial_local_vars.copy() else: - raise ValueError(f"Unsupported opcode: {opname}") + reachable_preds = [ + pred for pred in block.predecessors if (pred, offset) in edge_stacks + ] + if not reachable_preds: + continue + + h = len(edge_stacks[(reachable_preds[0], offset)]) + stack = [] + for i in range(h): + pairs = [ + (edge_stacks[(p, offset)][i], edge_conditions[(p, offset)]) + for p in reachable_preds + ] + stack.append(merge_values(pairs)) + + all_vars: set[str] = set() + for p in reachable_preds: + all_vars.update(block_outputs[p][1].keys()) + + local_vars = {} + for var in all_vars: + pairs = [ + ( + block_outputs[p][1].get( + var, expression.UnboundVariableExpression(var) + ), + edge_conditions[(p, offset)], + ) + for p in reachable_preds + ] + local_vars[var] = merge_values(pairs) + + jumped = False + for inst in block.instructions: + opname = inst.opname + + match opname: + case "RESUME" | "PRECALL" | "COPY_FREE_VARS" | "NOT_TAKEN" | "NOP": + continue + + case "LOAD_FAST_LOAD_FAST" | "LOAD_FAST_BORROW_LOAD_FAST_BORROW": + var1, var2 = inst.argval + stack.append( + local_vars.get(var1, expression.UnboundVariableExpression(var1)) + ) + stack.append( + local_vars.get(var2, expression.UnboundVariableExpression(var2)) + ) + + case ( + "LOAD_FAST" + | "LOAD_FAST_CHECK" + | "LOAD_FAST_AND_CLEAR" + | "LOAD_FAST_BORROW" + ): + stack.append( + local_vars.get( + inst.argval, + expression.UnboundVariableExpression(inst.argval), + ) + ) + + case "STORE_FAST": + if not stack: + raise ValueError("Stack is empty") + local_vars[inst.argval] = stack.pop() + + case "LOAD_CONST" | "LOAD_SMALL_INT": + stack.append(py_exprs.PyObject(inst.argval)) + + case "LOAD_DEREF" | "LOAD_FROM_DICT_OR_DEREF": + name = inst.argval + found = False + val = None + if name in closure_dict: + val = closure_dict[name] + found = True + elif name in globals_dict: + val = globals_dict[name] + found = True + elif name in builtins_dict: + val = builtins_dict[name] + found = True + + if found: + if isinstance(val, ModuleType): + stack.append(py_exprs.Module(val)) + else: + stack.append(py_exprs.PyObject(val)) + else: + stack.append(expression.UnboundVariableExpression(name)) + + case "LOAD_GLOBAL": + if ( + sys.version_info >= (3, 11) + and inst.arg is not None + and (inst.arg & 1) + ): + stack.append(_NULL) + name = inst.argval + found = False + val = None + if name in closure_dict: + val = closure_dict[name] + found = True + elif name in globals_dict: + val = globals_dict[name] + found = True + elif name in builtins_dict: + val = builtins_dict[name] + found = True + + if found: + if isinstance(val, ModuleType): + stack.append(py_exprs.Module(val)) + else: + stack.append(py_exprs.PyObject(val)) + else: + stack.append(expression.UnboundVariableExpression(name)) + + case "LOAD_ATTR" | "LOAD_METHOD": + if not stack: + raise ValueError("Stack is empty") + target = stack.pop() + stack.append(py_exprs.GetAttr(target, inst.argval)) + + is_method_lookup = (opname == "LOAD_METHOD") or ( + opname == "LOAD_ATTR" + and sys.version_info >= (3, 12) + and inst.arg is not None + and (inst.arg & 1) + ) + if is_method_lookup: + if isinstance(target, py_exprs.Module): + stack.append(_NULL) + else: + stack.append(target) + + case "PUSH_NULL": + stack.append(_NULL) - raise ValueError("No return value found") + case "TO_BOOL": + if not stack: + raise ValueError("Stack is empty") + val = stack.pop() + stack.append( + py_exprs.Call( + py_exprs.PyObject(generic_ops.coerce_to_bool_op), + (val,), + ) + ) + + case "COPY": + idx = inst.arg + if idx is None or idx < 1 or len(stack) < idx: + raise ValueError( + f"Invalid COPY index or stack too small: {idx}" + ) + stack.append(stack[-idx]) + + case "UNARY_NOT": + if not stack: + raise ValueError("Stack is empty") + val = stack.pop() + val_bool = py_exprs.Call( + py_exprs.PyObject(generic_ops.coerce_to_bool_op), + (val,), + ) + stack.append( + py_exprs.Call( + py_exprs.PyObject(operator.not_), + (val_bool,), + ) + ) + + case "SWAP": + idx = inst.arg + if idx is None or idx < 1 or len(stack) < idx: + raise ValueError( + f"Invalid SWAP index or stack too small: {idx}" + ) + stack[-1], stack[-idx] = stack[-idx], stack[-1] + + case "ROT_TWO": + if len(stack) < 2: + raise ValueError("Stack has < 2 elements") + stack[-1], stack[-2] = stack[-2], stack[-1] + + case "ROT_THREE": + if len(stack) < 3: + raise ValueError("Stack has < 3 elements") + stack[-1], stack[-2], stack[-3] = stack[-2], stack[-3], stack[-1] + + case "DUP_TOP": + if not stack: + raise ValueError("Stack is empty") + stack.append(stack[-1]) + + case "BINARY_OP": + if len(stack) < 2: + raise ValueError("Stack is empty") + right = stack.pop() + left = stack.pop() + op_symbol = inst.argrepr + if not op_symbol and isinstance(inst.argval, str): + op_symbol = inst.argval + if op_symbol and op_symbol.endswith("="): + op_symbol = op_symbol[:-1] + + if op_symbol not in _BINARY_OP_MAP: + raise ValueError(f"Unsupported binary operator: {op_symbol}") + stack.append( + py_exprs.Call( + py_exprs.PyObject(_BINARY_OP_MAP[op_symbol]), + (left, right), + ) + ) + + case name if name in _OLD_BINARY_OP_MAP: + if len(stack) < 2: + raise ValueError("Stack has < 2 elements") + right = stack.pop() + left = stack.pop() + stack.append( + py_exprs.Call( + py_exprs.PyObject(_OLD_BINARY_OP_MAP[opname]), + (left, right), + ) + ) + + case "COMPARE_OP": + if len(stack) < 2: + raise ValueError("Stack has < 2 elements") + right = stack.pop() + left = stack.pop() + op_symbol = inst.argval + if op_symbol not in _COMPARE_OP_MAP: + raise ValueError(f"Unsupported compare operator: {op_symbol}") + stack.append( + py_exprs.Call( + py_exprs.PyObject(_COMPARE_OP_MAP[op_symbol]), + (left, right), + ) + ) + + case "UNARY_NEGATIVE" | "UNARY_INVERT": + if not stack: + raise ValueError("Stack is empty") + target = stack.pop() + stack.append( + py_exprs.Call( + py_exprs.PyObject( + operator.neg + if opname == "UNARY_NEGATIVE" + else operator.invert + ), + (target,), + ) + ) + + case "UNARY_POSITIVE": + if not stack: + raise ValueError("Stack is empty") + target = stack.pop() + stack.append( + py_exprs.Call(py_exprs.PyObject(operator.pos), (target,)) + ) + + case "CALL_INTRINSIC_1": + if inst.argrepr == "INTRINSIC_UNARY_POSITIVE": + if not stack: + raise ValueError("Stack is empty") + target = stack.pop() + stack.append( + py_exprs.Call(py_exprs.PyObject(operator.pos), (target,)) + ) + else: + raise ValueError(f"Unsupported intrinsic: {inst.argrepr}") + + case "CALL" | "CALL_FUNCTION" | "CALL_METHOD": + num_args = inst.arg + assert num_args is not None + if len(stack) < num_args: + raise ValueError(f"Stack has fewer than {num_args} elements") + args = [stack.pop() for _ in range(num_args)][::-1] + if len(stack) >= 2 and stack[-2] == _NULL: + stack[-1], stack[-2] = stack[-2], stack[-1] + if stack and stack[-1] == _NULL: + stack.pop() + elif ( + stack + and stack[-1] != _NULL + and isinstance(stack[-1], expression.Expression) + ): + self_arg = stack.pop() + args = [self_arg] + args + if not stack: + raise ValueError("Stack is empty") + callable_expr = stack.pop() + stack.append(py_exprs.Call(callable_expr, tuple(args))) + + case "RETURN_VALUE": + if not stack: + raise ValueError("Stack is empty") + returns.append((stack[-1], reach_cond)) + jumped = True + break + + case "RETURN_CONST": + returns.append((py_exprs.PyObject(inst.argval), reach_cond)) + jumped = True + break + + case "POP_TOP": + if stack: + stack.pop() + + case name if name in _UNCONDITIONAL_JUMP_OPNAMES: + dest = inst.argval + edge_conditions[(offset, dest)] = reach_cond + edge_stacks[(offset, dest)] = stack.copy() + jumped = True + break + + case "JUMP_IF_FALSE_OR_POP" | "JUMP_IF_TRUE_OR_POP": + if not stack: + raise ValueError("Stack is empty") + cond_expr = stack[-1] + cond_bool = py_exprs.Call( + py_exprs.PyObject(generic_ops.coerce_to_bool_op), + (cond_expr,), + ) + dest = inst.argval + next_offset = next_offsets.get(inst.offset) + if opname == "JUMP_IF_FALSE_OR_POP": + not_cond_bool = py_exprs.Call( + py_exprs.PyObject(operator.not_), (cond_bool,) + ) + edge_conditions[(offset, dest)] = py_exprs.Call( + py_exprs.PyObject(operator.and_), + (reach_cond, not_cond_bool), + ) + edge_stacks[(offset, dest)] = stack.copy() + if next_offset is not None: + edge_conditions[(offset, next_offset)] = py_exprs.Call( + py_exprs.PyObject(operator.and_), + (reach_cond, cond_bool), + ) + edge_stacks[(offset, next_offset)] = stack[:-1] + else: # JUMP_IF_TRUE_OR_POP + edge_conditions[(offset, dest)] = py_exprs.Call( + py_exprs.PyObject(operator.and_), + (reach_cond, cond_bool), + ) + edge_stacks[(offset, dest)] = stack.copy() + if next_offset is not None: + not_cond_bool = py_exprs.Call( + py_exprs.PyObject(operator.not_), (cond_bool,) + ) + edge_conditions[(offset, next_offset)] = py_exprs.Call( + py_exprs.PyObject(operator.and_), + (reach_cond, not_cond_bool), + ) + edge_stacks[(offset, next_offset)] = stack[:-1] + jumped = True + break + + case name if ( + name in _JUMP_IF_FALSE_OPNAMES or name in _JUMP_IF_TRUE_OPNAMES + ): + if not stack: + raise ValueError("Stack is empty") + cond_expr = stack.pop() + cond_expr = py_exprs.Call( + py_exprs.PyObject(generic_ops.coerce_to_bool_op), + (cond_expr,), + ) + + dest = inst.argval + next_offset = next_offsets.get(inst.offset) + + if opname in _JUMP_IF_FALSE_OPNAMES: + not_cond_expr = py_exprs.Call( + py_exprs.PyObject(operator.not_), (cond_expr,) + ) + edge_conditions[(offset, dest)] = py_exprs.Call( + py_exprs.PyObject(operator.and_), + (reach_cond, not_cond_expr), + ) + edge_stacks[(offset, dest)] = stack.copy() + if next_offset is not None: + edge_conditions[(offset, next_offset)] = py_exprs.Call( + py_exprs.PyObject(operator.and_), + (reach_cond, cond_expr), + ) + edge_stacks[(offset, next_offset)] = stack.copy() + else: # opname in _JUMP_IF_TRUE_OPNAMES + not_cond_expr = py_exprs.Call( + py_exprs.PyObject(operator.not_), (cond_expr,) + ) + edge_conditions[(offset, dest)] = py_exprs.Call( + py_exprs.PyObject(operator.and_), + (reach_cond, cond_expr), + ) + edge_stacks[(offset, dest)] = stack.copy() + if next_offset is not None: + edge_conditions[(offset, next_offset)] = py_exprs.Call( + py_exprs.PyObject(operator.and_), + (reach_cond, not_cond_expr), + ) + edge_stacks[(offset, next_offset)] = stack.copy() + jumped = True + break + + case name if name in _ALL_JUMP_OPNAMES: + raise ValueError(f"Unsupported jump opcode: {opname}") + + case _: + raise ValueError(f"Unsupported opcode: {opname}") + + if not jumped: + next_offset = next_offsets.get(block.instructions[-1].offset) + if next_offset is not None: + edge_conditions[(offset, next_offset)] = reach_cond + edge_stacks[(offset, next_offset)] = stack.copy() + + block_outputs[offset] = (stack, local_vars) + + if not returns: + raise ValueError("No return value found") + + return merge_values(returns) def py_to_expression(func: Callable) -> expression.Expression: diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 3f9fcb5b75df..71767402b556 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -884,6 +884,25 @@ def numeric_to_datetime( ) +@scalar_op_compiler.register_unary_op(ops.coerce_to_bool_op) +def coerce_to_bool_op_impl(x: ibis_types.Value): + x_type = x.type() + if x_type.is_boolean(): + res = x + elif x_type.is_numeric(): + res = x != 0 # type: ignore + elif x_type.is_string(): + res = x.length() > 0 # type: ignore + elif x_type.is_binary(): + res = x.length() > 0 # type: ignore + elif isinstance(x_type, ibis_dtypes.Array): + res = x.length() > 0 # type: ignore + else: + res = x.notnull() + + return res.fill_null(False) # type: ignore + + @scalar_op_compiler.register_unary_op(ops.AsTypeOp, pass_op=True) def astype_op_impl(x: ibis_types.Value, op: ops.AsTypeOp): to_type = bigframes.core.compile.ibis_types.bigframes_dtype_to_ibis_dtype( diff --git a/packages/bigframes/bigframes/core/compile/polars/compiler.py b/packages/bigframes/bigframes/core/compile/polars/compiler.py index 2477f27b6432..0431c22a3403 100644 --- a/packages/bigframes/bigframes/core/compile/polars/compiler.py +++ b/packages/bigframes/bigframes/core/compile/polars/compiler.py @@ -370,6 +370,28 @@ def _( ) -> pl.Expr: return pl.when(condition).then(original).otherwise(otherwise) + @compile_op.register(gen_ops.CoerceToBoolOp) + def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: + assert isinstance(op, gen_ops.CoerceToBoolOp) + from_type = self._expr_types.get(id(input)) + if from_type is None: + return input.cast(pl.Boolean).fill_null(False) + + if from_type == bigframes.dtypes.BOOL_DTYPE: + res = input + elif bigframes.dtypes.is_numeric(from_type): + res = input != 0 + elif from_type == bigframes.dtypes.BYTES_DTYPE: + res = input.bin.size() > 0 + elif bigframes.dtypes.is_string_like(from_type): + res = input.str.len_chars() > 0 + elif bigframes.dtypes.is_array_like(from_type): + res = input.list.len() > 0 + else: + res = input.is_not_null() + + return res.fill_null(False) + @compile_op.register(gen_ops.AsTypeOp) def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr: assert isinstance(op, gen_ops.AsTypeOp) diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 2cc27cb8e5a2..90c8270ae1d6 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -148,6 +148,28 @@ def _(expr: TypedExpr) -> sge.Expression: ) +@register_unary_op(ops.coerce_to_bool_op) +def _(expr: TypedExpr) -> sge.Expression: + from_type = expr.dtype + sg_expr = expr.expr + + if from_type == dtypes.BOOL_DTYPE: + res = sg_expr + elif dtypes.is_numeric(from_type): + res = sge.NEQ(this=sg_expr, expression=sge.convert(0)) + elif dtypes.is_string_like(from_type): + res = sge.GT(this=sge.func("LENGTH", sg_expr), expression=sge.convert(0)) + elif dtypes.is_array_like(from_type): + res = sge.GT(this=sge.func("ARRAY_LENGTH", sg_expr), expression=sge.convert(0)) + else: + res = sge.Is( + this=sge.paren(sg_expr, copy=False), + expression=sg.not_(sge.Null(), copy=False), + ) + + return sge.Coalesce(this=res, expressions=[sge.convert(False)]) + + @register_ternary_op(ops.where_op) def _( original: TypedExpr, condition: TypedExpr, replacement: TypedExpr diff --git a/packages/bigframes/bigframes/core/py_expressions.py b/packages/bigframes/bigframes/core/py_expressions.py index e1885e13afea..ddd88131d092 100644 --- a/packages/bigframes/bigframes/core/py_expressions.py +++ b/packages/bigframes/bigframes/core/py_expressions.py @@ -29,7 +29,13 @@ const, deref, ) -from bigframes.operations import NUMPY_TO_BINOP, NUMPY_TO_OP, generic_ops, numeric_ops +from bigframes.operations import ( + NUMPY_TO_BINOP, + NUMPY_TO_OP, + ScalarOp, + generic_ops, + numeric_ops, +) _CALLABLE_TO_OP = { **NUMPY_TO_OP, @@ -365,6 +371,8 @@ def resolve_call(call: Call) -> Expression: op = _CALLABLE_TO_OP[fn] return OpExpression(op, call.inputs) elif isinstance(callable, PyObject): + if isinstance(callable.value, ScalarOp): + return OpExpression(callable.value, call.inputs) if callable.value in python_op_maps.PYTHON_TO_BIGFRAMES: op = python_op_maps.PYTHON_TO_BIGFRAMES[callable.value] # type: ignore return OpExpression(op, call.inputs) diff --git a/packages/bigframes/bigframes/dtypes.py b/packages/bigframes/bigframes/dtypes.py index 51ee96432390..3cc7e918aa0f 100644 --- a/packages/bigframes/bigframes/dtypes.py +++ b/packages/bigframes/bigframes/dtypes.py @@ -468,7 +468,12 @@ def is_clusterable(type_: ExpressionType) -> bool: def is_bool_coercable(type_: ExpressionType) -> bool: # TODO: Implement more bool coercions - return (type_ is None) or is_numeric(type_) or is_string_like(type_) + return ( + (type_ is None) + or is_numeric(type_) + or is_string_like(type_) + or is_array_like(type_) + ) BIGFRAMES_STRING_TO_BIGFRAMES: Dict[DtypeString, Dtype] = { diff --git a/packages/bigframes/bigframes/operations/__init__.py b/packages/bigframes/bigframes/operations/__init__.py index b63a150afaea..f02091ab3919 100644 --- a/packages/bigframes/bigframes/operations/__init__.py +++ b/packages/bigframes/bigframes/operations/__init__.py @@ -93,6 +93,7 @@ from bigframes.operations.generic_ops import ( AsTypeOp, CaseWhenOp, + CoerceToBoolOp, IsInOp, MapOp, RowKey, @@ -100,6 +101,7 @@ case_when_op, clip_op, coalesce_op, + coerce_to_bool_op, fillna_op, hash_op, invert_op, @@ -255,6 +257,8 @@ "maximum_op", "minimum_op", "notnull_op", + "CoerceToBoolOp", + "coerce_to_bool_op", "RowKey", "SqlScalarOp", "where_op", diff --git a/packages/bigframes/bigframes/operations/generic_ops.py b/packages/bigframes/bigframes/operations/generic_ops.py index 99cda5fc095f..e4a4af90a8f8 100644 --- a/packages/bigframes/bigframes/operations/generic_ops.py +++ b/packages/bigframes/bigframes/operations/generic_ops.py @@ -45,6 +45,21 @@ ) notnull_op = NotNullOp() + +# Semantics match Python's truth value testing (truthy and falsey objects). +# See https://docs.python.org/3/library/stdtypes.html#truth-value-testing +CoerceToBoolOp = base_ops.create_unary_op( + name="coerce_to_bool", + type_signature=op_typing.FixedOutputType( + dtypes.is_bool_coercable, dtypes.BOOL_DTYPE, description="coercable to bool" + ), +) +CoerceToBoolOp.__doc__ = ( + "Coerce a value to a boolean, matching Python's truth value testing semantics " + "(truthy/falsey). See https://docs.python.org/3/library/stdtypes.html#truth-value-testing" +) +coerce_to_bool_op = CoerceToBoolOp() + HashOp = base_ops.create_unary_op( name="hash", type_signature=op_typing.FixedOutputType( diff --git a/packages/bigframes/bigframes/operations/python_op_maps.py b/packages/bigframes/bigframes/operations/python_op_maps.py index 7efe7fc12626..37d6f0174484 100644 --- a/packages/bigframes/bigframes/operations/python_op_maps.py +++ b/packages/bigframes/bigframes/operations/python_op_maps.py @@ -22,6 +22,7 @@ array_ops, bool_ops, comparison_ops, + generic_ops, numeric_ops, string_ops, ) @@ -47,6 +48,8 @@ operator.and_: bool_ops.and_op, operator.or_: bool_ops.or_op, operator.xor: bool_ops.xor_op, + operator.invert: generic_ops.invert_op, + operator.not_: generic_ops.invert_op, ## math math.log: numeric_ops.ln_op, math.log10: numeric_ops.log10_op, diff --git a/packages/bigframes/tests/system/small/engines/test_generic_ops.py b/packages/bigframes/tests/system/small/engines/test_generic_ops.py index 05739a1c1b63..438ef5153983 100644 --- a/packages/bigframes/tests/system/small/engines/test_generic_ops.py +++ b/packages/bigframes/tests/system/small/engines/test_generic_ops.py @@ -405,6 +405,39 @@ def test_engines_notnull_op(scalars_array_value: array_value.ArrayValue, engine) assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +def test_engines_coerce_to_bool_op_scalars( + scalars_array_value: array_value.ArrayValue, engine +): + arr, _ = scalars_array_value.compute_values( + [ + ops.coerce_to_bool_op.as_expr(expression.deref("bool_col")), + ops.coerce_to_bool_op.as_expr(expression.deref("int64_col")), + ops.coerce_to_bool_op.as_expr(expression.deref("float64_col")), + ops.coerce_to_bool_op.as_expr(expression.deref("string_col")), + ops.coerce_to_bool_op.as_expr(expression.deref("bytes_col")), + ] + ) + + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +def test_engines_coerce_to_bool_op_arrays( + arrays_array_value: array_value.ArrayValue, engine +): + arr, _ = arrays_array_value.compute_values( + [ + ops.coerce_to_bool_op.as_expr(expression.deref("int_list_col")), + ops.coerce_to_bool_op.as_expr(expression.deref("bool_list_col")), + ops.coerce_to_bool_op.as_expr(expression.deref("float_list_col")), + ops.coerce_to_bool_op.as_expr(expression.deref("string_list_col")), + ] + ) + + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + @pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine): arr, _ = scalars_array_value.compute_values( diff --git a/packages/bigframes/tests/unit/test_py_udf.py b/packages/bigframes/tests/unit/test_py_udf.py index ad491f6a7393..d4865a1a5bbd 100644 --- a/packages/bigframes/tests/unit/test_py_udf.py +++ b/packages/bigframes/tests/unit/test_py_udf.py @@ -15,6 +15,7 @@ import pathlib from typing import Generator +import numpy as np import pandas as pd import pandas.testing import pytest @@ -225,14 +226,6 @@ def foo(x, y): def test_transpilation_unsupported_ops_raise( scalars_df_index, ): - def foo_with_if(x): - if x > 0: - return x - return -x - - with pytest.raises(ValueError): - scalars_df_index["int64_col"].apply(foo_with_if) - def foo_with_loop(x): total = 0 for i in range(x): @@ -241,3 +234,220 @@ def foo_with_loop(x): with pytest.raises(ValueError): scalars_df_index["int64_col"].apply(foo_with_loop) + + +def my_foo(x: int): + return x + 1 + + +def test_local_series_apply_simple(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index["int64_col"].apply(my_foo).to_pandas() + pd_result = scalars_pandas_df_index["int64_col"].apply(my_foo) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def my_numpy_foo(x: int): + return np.add(x, x) * (np.cos(x) - np.sin(3)) + + +def test_local_series_apply_w_numpy(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index["int64_col"].apply(my_numpy_foo).to_pandas() + pd_result = scalars_pandas_df_index["int64_col"].apply(my_numpy_foo) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_series_apply_w_simple_lamdba(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index["int64_col"].apply(lambda x: x + 3).to_pandas() + pd_result = scalars_pandas_df_index["int64_col"].apply(lambda x: x + 3) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_series_apply_w_ternary_lamdba(scalars_df_index, scalars_pandas_df_index): + bf_result = ( + scalars_df_index["int64_col"] + .apply(lambda x: "positive" if x > 0 else "negative") + .to_pandas() + ) + pd_result = scalars_pandas_df_index["int64_col"].apply( + lambda x: "positive" if x > 0 else "negative" + ) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_series_apply_w_nested_fizzbuzz(session): + # challenging: closure, multiple exits, mutating variables + foo_div = 3 + buzz_div = 5 + pd_series = pd.Series( + range(20), + dtype="Int64", + index=pd.Index(range(20), dtype="Int64"), + name="integers", + ) + bf_series = bpd.Series(pd_series, session=session) + + def fizzbuzz(x): + if (x % 3) and (x % 5): + return str(x) + val = "" + if (x % foo_div) == 0: + val += "fizz" + if (x % buzz_div) == 0: + val += "buzz" + return val + + bf_result = bf_series.apply(fizzbuzz).to_pandas() + pd_result = pd_series.apply(fizzbuzz).astype(pd.StringDtype(storage="pyarrow")) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_dataframe_apply_w_ternary_lamdba( + scalars_df_index, scalars_pandas_df_index +): + bf_result = scalars_df_index.apply( + lambda x: x.int64_col if x.rowindex_2 > 5 else x.float64_col, axis=1 + ).to_pandas() + pd_result = scalars_pandas_df_index.apply( + lambda x: x.int64_col if x.rowindex_2 > 5 else x.float64_col, axis=1 + ).astype("Float64") + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_series_apply_w_nested_ifs(scalars_df_index, scalars_pandas_df_index): + def nested_ifs(x): + if x > 0: + if x > 100: + return x * 10 + else: + return x * 2 + else: + if x < -100: + return x * 20 + return x * -1 + + bf_result = scalars_df_index["int64_col"].apply(nested_ifs).to_pandas() + pd_result = scalars_pandas_df_index["int64_col"].apply(nested_ifs) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_series_apply_w_elif(scalars_df_index, scalars_pandas_df_index): + def elif_fn(x): + if x > 100: + return 1 + elif x > 50: + return 2 + elif x > 0: + return 3 + else: + return 4 + + bf_result = scalars_df_index["int64_col"].apply(elif_fn).to_pandas() + pd_result = scalars_pandas_df_index["int64_col"].apply(elif_fn) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_series_apply_w_logical_not(scalars_df_index, scalars_pandas_df_index): + def logical_not_fn(x): + if not (x > 0): + return -x + return x + + bf_result = scalars_df_index["int64_col"].apply(logical_not_fn).to_pandas() + pd_result = scalars_pandas_df_index["int64_col"].apply(logical_not_fn) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_series_apply_w_short_circuit(scalars_df_index, scalars_pandas_df_index): + def short_circuit(x): + if (x > 0 and x < 100) or x == 55555: + return 1 + return 0 + + bf_result = scalars_df_index["int64_col"].apply(short_circuit).to_pandas() + pd_result = scalars_pandas_df_index["int64_col"].apply(short_circuit) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_series_apply_w_var_assignments( + scalars_df_index, scalars_pandas_df_index +): + def var_assign(x): + val = x + if x > 0: + val = val + 10 + if val > 100: + val = val * 2 + else: + val = val - 10 + return val + + bf_result = scalars_df_index["int64_col"].apply(var_assign).to_pandas() + pd_result = scalars_pandas_df_index["int64_col"].apply(var_assign) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_series_apply_w_logical_and_val( + scalars_df_index, scalars_pandas_df_index +): + def logical_and_val(x): + return (x % 3) and 100 + + bf_result = ( + scalars_df_index["int64_col"].dropna().apply(logical_and_val).to_pandas() + ) + pd_result = scalars_pandas_df_index["int64_col"].dropna().apply(logical_and_val) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_series_apply_w_logical_or_val(scalars_df_index, scalars_pandas_df_index): + def logical_or_val(x): + return (x % 3) or 200 + + bf_result = scalars_df_index["int64_col"].dropna().apply(logical_or_val).to_pandas() + pd_result = scalars_pandas_df_index["int64_col"].dropna().apply(logical_or_val) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_series_apply_w_logical_and_mixed( + scalars_df_index, +): + def logical_and_mixed(x): + return (x % 3) and "hello" + + with pytest.raises(TypeError, match="Cannot coerce"): + scalars_df_index["int64_col"].apply(logical_and_mixed) + + +def test_local_series_apply_w_logical_not_val( + scalars_df_index, scalars_pandas_df_index +): + def logical_not_val(x): + return not x + + bf_result = scalars_df_index["bool_col"].dropna().apply(logical_not_val).to_pandas() + pd_result = scalars_pandas_df_index["bool_col"].dropna().apply(logical_not_val) + + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_local_series_apply_w_compare_chain(scalars_df_index, scalars_pandas_df_index): + def compare_chain(x): + return 0 < x < 1000 + + bf_result = scalars_df_index["int64_col"].dropna().apply(compare_chain).to_pandas() + pd_result = scalars_pandas_df_index["int64_col"].dropna().apply(compare_chain) + + assert_series_equal(bf_result, pd_result, check_dtype=False) diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/strings.py b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/strings.py index aa6d070162b6..c2dc151ae07e 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/strings.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/strings.py @@ -361,9 +361,10 @@ class ExtractFragment(ExtractURLField): @public -class StringLength(StringUnary): - """Compute the length of a string.""" +class StringLength(Unary): + """Compute the length of a string or binary value.""" + arg: Value[dt.String | dt.Binary] dtype = dt.int64 diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/types/binary.py b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/types/binary.py index 093f4cd42125..b89eb6c1f1ab 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/types/binary.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/types/binary.py @@ -35,6 +35,16 @@ def hashbytes( def __invert__(self) -> BinaryValue: return ops.BitwiseNot(self).to_expr() + def length(self) -> ir.IntegerValue: + """Compute the length of a binary value. + + Returns + ------- + IntegerValue + The length of each binary value in the expression + """ + return ops.StringLength(self).to_expr() + @public class BinaryScalar(Scalar, BinaryValue):