Skip to content

Fix torch.split fails in to_edge with alias annotations#18700

Open
Lidang-Jiang wants to merge 2 commits intopytorch:mainfrom
Lidang-Jiang:fix/split-to-edge
Open

Fix torch.split fails in to_edge with alias annotations#18700
Lidang-Jiang wants to merge 2 commits intopytorch:mainfrom
Lidang-Jiang:fix/split-to-edge

Conversation

@Lidang-Jiang
Copy link
Copy Markdown

Fixes #11723

Summary

torch.split fails with RuntimeError: Found a custom (non-ATen) operator whose output has alias annotations when used with to_edge_transform_and_lower and a partitioner that requests op preservation.

Root cause: _remove_invalid_ops_for_not_decompose relies on torchgen's aliased_return_names() to detect ops with aliased returns. However, for ops returning lists of aliased tensors (e.g., split.Tensor returns Tensor(a)[]), aliased_return_names() returns [None], failing to detect the alias annotation. This lets split.Tensor pass through into the EDGE_DO_NOT_DECOMP namespace, where functionalization fails.

Fix: Add a fallback check using op._schema.returns directly, which correctly reports alias_info on list return types. This also fixes the same latent issue for chunk.default and tensor_split.sections.

Test plan

  • Added test_remove_invalid_ops_filters_aliased_list_returns regression test
  • Run: pytest exir/tests/test_passes.py::TestPasses::test_remove_invalid_ops_filters_aliased_list_returns -xvs
  • Verified existing split-related test still passes: test_to_out_variant_singleon_tensor_list
  • Verified existing broken ops test still passes: test_compile_fix_broken_ops
Before fix
==================== BEFORE FIX ====================
RESULT: FAILED
RuntimeError: Found a custom (non-ATen) operator whose output has alias annotations: EDGE_DO_NOT_DECOMP::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]. We only support functionalizing operators whose outputs do not have alias annotations (e.g. 'Tensor(a)' is a Tensor with an alias annotation whereas 'Tensor' is a Tensor without. The '(a)' is the alias annotation). The alias annotation specifies that the output Tensor shares storage with an input that has the same annotation. Please check if (1) the output needs to be an output (if not, don't return it), (2) if the output doesn't share storage with any inputs, then delete the alias annotation. (3) if the output indeed shares storage with an input, then add a .clone() before returning it to prevent storage sharing and then delete the alias annotation. Otherwise, please file an issue on GitHub.

While executing %split : [num_users=3] = call_function[target=torch.ops.EDGE_DO_NOT_DECOMP.split.Tensor](args = (%x, 2), kwargs = {})
Original traceback:
None
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
After fix
==================== AFTER FIX ====================
WARNING:root:Op aten.split.Tensor was requested for preservation by partitioner.  This request is ignored because it aliases output.

Test 1: to_edge (no partitioner)
RESULT: SUCCESS - outputs match

Test 2: to_edge_transform_and_lower with split.Tensor preservation
RESULT: SUCCESS - split.Tensor correctly filtered from EDGE_DO_NOT_DECOMP
         (AttributeError from dummy partitioner partition(), not from split bug)

Test 3: _remove_invalid_ops_for_not_decompose filter check
  aten::split.Tensor                            -> FILTERED (correct)
  aten::chunk                                   -> FILTERED (correct)
  aten::tensor_split.sections                   -> FILTERED (correct)
Unit test output
$ pytest exir/tests/test_passes.py::TestPasses::test_remove_invalid_ops_filters_aliased_list_returns -xvs

============================= test session starts ==============================
platform linux -- Python 3.12.12, pytest-8.4.2
collected 1 item

exir/tests/test_passes.py::TestPasses::test_remove_invalid_ops_filters_aliased_list_returns PASSED

============================== 1 passed in 6.83s ===============================

$ pytest exir/tests/test_passes.py::TestPasses::test_to_out_variant_singleon_tensor_list -xvs
PASSED

$ pytest exir/tests/test_passes.py::TestPasses::test_compile_fix_broken_ops -xvs
PASSED

This PR was authored with the assistance of Claude.

Fixes pytorch#11723

_remove_invalid_ops_for_not_decompose relied on torchgen's
aliased_return_names() to detect ops with aliased returns, but it
returns [None] for ops returning lists of aliased tensors (e.g.,
split.Tensor returns Tensor(a)[]). This let split.Tensor through
into the EDGE_DO_NOT_DECOMP namespace where functionalization failed.

Add a fallback check using op._schema.returns directly, which
correctly reports alias_info on list return types. This also
fixes the same latent issue for chunk and tensor_split.

Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Apr 4, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18700

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 11 Awaiting Approval

As of commit c79aca7 with merge base 6020c29 (image):

AWAITING APPROVAL - The following workflows need approval before CI can run:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 4, 2026
- Change 'may fail' to 'does not detect' (torchgen structurally cannot
  handle ListType alias annotations)
- Add split_with_sizes.default to test to document overlap with blocklist

Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
@Lidang-Jiang
Copy link
Copy Markdown
Author

@pytorchbot label "release notes: exir"

@pytorch-bot pytorch-bot bot added the release notes: exir Changes to any dialects and passes on these dialects, such as memory planning label Apr 6, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. release notes: exir Changes to any dialects and passes on these dialects, such as memory planning

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.split fails in to_edge

2 participants