Skip to content

[S-TIR] Fix cache_read/cache_write region when inner block has T.whe…#19406

Merged
tlopex merged 4 commits intoapache:mainfrom
elvin-n:amalyshe/cache_write_predicate
Apr 16, 2026
Merged

[S-TIR] Fix cache_read/cache_write region when inner block has T.whe…#19406
tlopex merged 4 commits intoapache:mainfrom
elvin-n:amalyshe/cache_write_predicate

Conversation

@elvin-n
Copy link
Copy Markdown
Contributor

@elvin-n elvin-n commented Apr 15, 2026

…re predicate

When the actual buffer access is gated by T.where on a nested (inner) sblock, the outer block's own predicate is trivially true. Both cache_write and cache_read were computing cache regions based only on that outer predicate, producing allocations as large as the full loop extent instead of the guarded region

Fix:

  • Add CollectNestedBlockPredicates(), a single helper parameterised by BufferIndexType (kRead / kWrite) that walks the outer block's body, finds nested sblocks accessing the target buffer, and AND-combines their predicates after substituting iter-var bindings into the outer scope.
  • Add extra_predicate parameter to RelaxBufferRegion() and AND it with the block's own predicate before region relaxation.
  • cache_write: pass the collected nested-write predicate to RelaxBufferRegion so the cache allocation is tightened.
  • cache_read (Case 2 — input buffer): when a non-trivial nested-read predicate exists, relax the consumer block's declared read region under that predicate; otherwise fall back to the original scope_block->reads path (preserves non-int32 dtypes in extents).

…re predicate

  When the actual buffer access is gated by T.where on a nested (inner) sblock,
the outer block's own predicate is trivially true. Both cache_write and cache_read
were computing cache regions based only on that outer predicate, producing allocations
as large as the full loop extent instead of the guarded region

  Fix:
  - Add CollectNestedBlockPredicates(), a single helper parameterised by
    BufferIndexType (kRead / kWrite) that walks the outer block's body,
    finds nested sblocks accessing the target buffer, and AND-combines
    their predicates after substituting iter-var bindings into the outer
    scope.
  - Add extra_predicate parameter to RelaxBufferRegion() and AND it with
    the block's own predicate before region relaxation.
  - cache_write: pass the collected nested-write predicate to
    RelaxBufferRegion so the cache allocation is tightened.
  - cache_read (Case 2 — input buffer): when a non-trivial nested-read
    predicate exists, relax the consumer block's declared read region
    under that predicate; otherwise fall back to the original
    scope_block->reads path (preserves non-int32 dtypes in extents).
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a mechanism to refine cache regions during cache_read and cache_write operations by collecting and applying predicates from nested blocks. However, the current implementation incorrectly uses an AND operation to combine predicates from nested blocks, which fails to account for sibling blocks where an OR (union) would be required. Additionally, there are potential null pointer dereferences when accessing the parent of a root block in both CacheRead and CacheWrite primitives.

Comment thread src/s_tir/schedule/primitive/cache_read_write.cc Outdated
Comment thread src/s_tir/schedule/primitive/cache_read_write.cc
Comment thread src/s_tir/schedule/primitive/cache_read_write.cc Outdated
elvin-n added 2 commits April 15, 2026 16:10
`CollectNestedBlockPredicates` previously AND-ed the predicates of every
nested sblock that accesses the target buffer.  This is correct only for
a single chain of nesting where all predicates must hold simultaneously.
For sibling sblocks (sequential in a SeqStmt), each is an independent
access path, so the cache region must cover the *union* of all their
access sets — requiring OR, not AND.

Fix: accumulate with `||` (OR-identity initial value `Bool(false)` +
  `found_` flag to skip the identity on the first match).  Return
  `Bool(true)` when no nested block accesses the buffer, preserving the
  existing `scope_block->reads` / `FullRegion` fallback.
`ffi::GetRef<StmtSRef>` on a null pointer crashes if `block_sref` is
the root block of a function (i.e. `block_sref->parent == nullptr`).
Two call sites in CacheRead and CacheWrite were unguarded.

CacheRead Case 2: add `block_sref->parent != nullptr` to the condition
that gates `RelaxBufferRegion`.  A root block has no enclosing loops,
so   the relaxation path is meaningless; the existing fallback to
`scope_block->reads` / `FullRegion` is correct and safe.

CacheWrite: move `parent_sref` construction inside an explicit null
check.  When the block is a root block, assign the write region directly
as the cache region — again correct because there are no loop variables
to relax over.
@elvin-n
Copy link
Copy Markdown
Contributor Author

elvin-n commented Apr 15, 2026

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new CollectNestedBlockPredicates function and modifies RelaxBufferRegion, CacheRead, and CacheWrite to accurately determine buffer regions for cache allocation. The change ensures that when nested blocks have T.where conditions, the cache region is computed as the union (OR-combination) of their access predicates, fixing a bug where an AND-combination could lead to under-allocated cache. New test cases were added to validate this behavior, especially for sibling nested blocks. A review comment points out that the documentation for CollectNestedBlockPredicates incorrectly describes its behavior as "AND-combined" instead of "OR-combined" predicates, which should be updated.

Comment thread src/s_tir/schedule/primitive/cache_read_write.cc Outdated
  - Changed "AND-combined" to "OR-combined" in the brief and \return tag
  - Added a sentence explaining why OR is correct: sibling blocks each
    guarded by a different predicate must be unioned, not intersected
@elvin-n
Copy link
Copy Markdown
Contributor Author

elvin-n commented Apr 15, 2026

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request improves the precision of cache region analysis for cache_read and cache_write by accounting for predicates in nested blocks. It introduces CollectNestedBlockPredicates to aggregate these predicates using an OR-combination and updates RelaxBufferRegion to utilize them during relaxation. Feedback was provided regarding the traversal logic in CollectNestedBlockPredicates, noting that recursion should stop once a block accessing the target buffer is encountered to prevent redundant or incorrect predicate combinations from child blocks.

Comment on lines +576 to +594
if (accesses_buffer) {
// Build substitution: nested block iter vars -> their binding values
// (which are already expressed in terms of the outer scope).
ffi::Map<Var, PrimExpr> subst;
for (size_t i = 0; i < block->iter_vars.size(); ++i) {
subst.Set(block->iter_vars[i]->var, realize->iter_values[i]);
}
PrimExpr pred =
subst.empty() ? realize->predicate : Substitute(realize->predicate, subst);
// OR the predicates across all accessing nested blocks: each such block is an
// independent alternative access path (sibling blocks in a SeqStmt), so the
// cache must cover the *union* of their access regions, not the intersection.
// Using AND (the previous behaviour) underestimates the required region when
// sibling blocks have non-overlapping predicates.
result_ = found_ ? (result_ || pred) : pred;
found_ = true;
}
// Continue recursing into deeper nested blocks.
StmtVisitor::VisitStmt_(realize);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of CollectNestedBlockPredicates may produce an incorrect combined predicate for deeply nested blocks. When a nested block that accesses the target buffer is found, its predicate is OR-ed with the result. However, the visitor then continues to recurse into that block's children.

If a child block also accesses the buffer, its predicate will also be OR-ed. This is incorrect because the child block's execution is already gated by the parent's predicate. The union of access regions is governed by the parent's predicate alone in this case. For example, if a parent block has predicate P1 and a child has P2, the effective predicate for an access in the child is P1 && P2. The union of accesses in the parent and child is covered by P1, not P1 || P2.

To fix this, the recursion should only proceed into blocks that do not access the target buffer. If a block accessing the buffer is found, its predicate should be collected, and the visitor should not recurse into its children.

      if (accesses_buffer) {
        // Build substitution: nested block iter vars -> their binding values
        // (which are already expressed in terms of the outer scope).
        ffi::Map<Var, PrimExpr> subst;
        for (size_t i = 0; i < block->iter_vars.size(); ++i) {
          subst.Set(block->iter_vars[i]->var, realize->iter_values[i]);
        }
        PrimExpr pred =
            subst.empty() ? realize->predicate : Substitute(realize->predicate, subst);
        // OR the predicates across all accessing nested blocks: each such block is an
        // independent alternative access path (sibling blocks in a SeqStmt), so the
        // cache must cover the *union* of their access regions, not the intersection.
        // Using AND (the previous behaviour) underestimates the required region when
        // sibling blocks have non-overlapping predicates.
        result_ = found_ ? (result_ || pred) : pred;
        found_ = true;
      } else {
        // Continue recursing into deeper nested blocks if this one doesn't access the buffer.
        StmtVisitor::VisitStmt_(realize);
      }

Copy link
Copy Markdown
Member

@tlopex tlopex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the contribution

@tlopex tlopex merged commit 0a79095 into apache:main Apr 16, 2026
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants