[S-TIR] Fix cache_read/cache_write region when inner block has T.whe… #19406
tlopex merged 4 commits into apache:main
…re predicate
When the actual buffer access is gated by T.where on a nested (inner) sblock,
the outer block's own predicate is trivially true. Both cache_write and cache_read
were computing cache regions based only on that outer predicate, producing allocations
as large as the full loop extent instead of the guarded region.
Fix:
- Add CollectNestedBlockPredicates(), a single helper parameterised by
BufferIndexType (kRead / kWrite) that walks the outer block's body,
finds nested sblocks accessing the target buffer, and AND-combines
their predicates after substituting iter-var bindings into the outer
scope.
- Add extra_predicate parameter to RelaxBufferRegion() and AND it with
the block's own predicate before region relaxation.
- cache_write: pass the collected nested-write predicate to
RelaxBufferRegion so the cache allocation is tightened.
- cache_read (Case 2 — input buffer): when a non-trivial nested-read
predicate exists, relax the consumer block's declared read region
under that predicate; otherwise fall back to the original
scope_block->reads path (preserves non-int32 dtypes in extents).
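To make the size difference concrete, here is a minimal standalone sketch. It does not use TVM; `ToyRegion` and `TouchedRegion` are hypothetical names, and brute-force enumeration stands in for the symbolic relaxation that `RelaxBufferRegion` performs. It only shows how the touched index range shrinks once the inner guard is taken into account.

```cpp
#include <algorithm>
#include <functional>

// Half-open index range [lo, hi); `empty` is true when nothing was touched.
// Toy type for illustration only, not a TVM data structure.
struct ToyRegion {
  int lo;
  int hi;
  bool empty;
};

// Enumerate i in [0, extent) and return the smallest range covering every
// index for which `pred(i)` holds.
ToyRegion TouchedRegion(int extent, const std::function<bool(int)>& pred) {
  ToyRegion region{0, 0, true};
  for (int i = 0; i < extent; ++i) {
    if (!pred(i)) continue;
    if (region.empty) {
      region = ToyRegion{i, i + 1, false};
    } else {
      region.lo = std::min(region.lo, i);
      region.hi = std::max(region.hi, i + 1);
    }
  }
  return region;
}
```

With a loop extent of 128 and an inner guard `i < 100` (both numbers chosen only for illustration), the trivially true outer predicate gives the range [0, 128), while the guard gives [0, 100).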
Code Review
This pull request introduces a mechanism to refine cache regions during cache_read and cache_write operations by collecting and applying predicates from nested blocks. However, the current implementation incorrectly uses an AND operation to combine predicates from nested blocks, which fails to account for sibling blocks where an OR (union) would be required. Additionally, there are potential null pointer dereferences when accessing the parent of a root block in both CacheRead and CacheWrite primitives.
`CollectNestedBlockPredicates` previously AND-ed the predicates of every nested sblock that accesses the target buffer. This is correct only for a single chain of nesting where all predicates must hold simultaneously. For sibling sblocks (sequential in a SeqStmt), each is an independent access path, so the cache region must cover the *union* of all their access sets — requiring OR, not AND. Fix: accumulate with `||` (OR-identity initial value `Bool(false)` + `found_` flag to skip the identity on the first match). Return `Bool(true)` when no nested block accesses the buffer, preserving the existing `scope_block->reads` / `FullRegion` fallback.
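The AND-versus-OR difference for sibling blocks can be sketched outside TVM as follows. `Pred`, `CombineOr`, `CombineAnd`, and `CountCovered` are hypothetical helpers over plain integer indices, not part of the patch; the empty-input case of `CombineOr` is written to mirror the `Bool(true)` fallback described above.

```cpp
#include <functional>
#include <vector>

// A predicate over a single loop index (toy stand-in for a PrimExpr).
using Pred = std::function<bool(int)>;

// Union of access paths: true if any sibling's guard holds.
// With no accessing block, fall back to "always true".
Pred CombineOr(const std::vector<Pred>& preds) {
  if (preds.empty()) return [](int) -> bool { return true; };
  return [preds](int i) -> bool {
    for (const Pred& p : preds) {
      if (p(i)) return true;
    }
    return false;
  };
}

// Intersection: true only if every sibling's guard holds.
Pred CombineAnd(const std::vector<Pred>& preds) {
  return [preds](int i) -> bool {
    for (const Pred& p : preds) {
      if (!p(i)) return false;
    }
    return true;
  };
}

// Number of indices in [0, extent) admitted by `pred`.
int CountCovered(int extent, const Pred& pred) {
  int n = 0;
  for (int i = 0; i < extent; ++i) {
    if (pred(i)) ++n;
  }
  return n;
}
```

For two siblings guarded by `i < 4` and `i >= 12` over an extent of 16 (illustrative values), the OR combination admits 8 indices, whereas the AND combination admits none, so an AND-derived cache region would miss every access.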
`ffi::GetRef<StmtSRef>` on a null pointer crashes if `block_sref` is the root block of a function (i.e. `block_sref->parent == nullptr`). Two call sites in CacheRead and CacheWrite were unguarded.
- CacheRead Case 2: add `block_sref->parent != nullptr` to the condition that gates `RelaxBufferRegion`. A root block has no enclosing loops, so the relaxation path is meaningless; the existing fallback to `scope_block->reads` / `FullRegion` is correct and safe.
- CacheWrite: move `parent_sref` construction inside an explicit null check. When the block is a root block, assign the write region directly as the cache region. This is again correct because there are no loop variables to relax over.
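The shape of the CacheRead Case 2 gate can be sketched with a toy type. `ToySRef`, `RegionSource`, and `ChooseRegionSource` are hypothetical names used only for illustration; the real code works on `StmtSRef` and calls `RelaxBufferRegion`.

```cpp
// Which path supplies the cache region (toy enum, not a TVM type).
enum class RegionSource { kRelaxOverParentScope, kUseDeclaredRegion };

// Toy stand-in for a schedule sref: only the parent link matters here.
struct ToySRef {
  const ToySRef* parent;  // nullptr for the root block of a function
};

// Take the relaxation path only when there is a parent scope to relax over
// AND a non-trivial nested predicate was collected; otherwise fall back to
// the declared region.
RegionSource ChooseRegionSource(const ToySRef& block, bool has_nontrivial_nested_pred) {
  if (block.parent == nullptr || !has_nontrivial_nested_pred) {
    return RegionSource::kUseDeclaredRegion;
  }
  return RegionSource::kRelaxOverParentScope;
}
```

Checking the parent link before any dereference means a root block always takes the fallback path, regardless of what predicate was collected.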
/gemini review
Code Review
This pull request introduces a new CollectNestedBlockPredicates function and modifies RelaxBufferRegion, CacheRead, and CacheWrite to accurately determine buffer regions for cache allocation. The change ensures that when nested blocks have T.where conditions, the cache region is computed as the union (OR-combination) of their access predicates, fixing a bug where an AND-combination could lead to under-allocated cache. New test cases were added to validate this behavior, especially for sibling nested blocks. A review comment points out that the documentation for CollectNestedBlockPredicates incorrectly describes its behavior as "AND-combined" instead of "OR-combined" predicates, which should be updated.
- Changed "AND-combined" to "OR-combined" in the brief and \return tag
- Added a sentence explaining why OR is correct: sibling blocks each
guarded by a different predicate must be unioned, not intersected
/gemini review
Code Review
This pull request improves the precision of cache region analysis for cache_read and cache_write by accounting for predicates in nested blocks. It introduces CollectNestedBlockPredicates to aggregate these predicates using an OR-combination and updates RelaxBufferRegion to utilize them during relaxation. Feedback was provided regarding the traversal logic in CollectNestedBlockPredicates, noting that recursion should stop once a block accessing the target buffer is encountered to prevent redundant or incorrect predicate combinations from child blocks.
if (accesses_buffer) {
  // Build substitution: nested block iter vars -> their binding values
  // (which are already expressed in terms of the outer scope).
  ffi::Map<Var, PrimExpr> subst;
  for (size_t i = 0; i < block->iter_vars.size(); ++i) {
    subst.Set(block->iter_vars[i]->var, realize->iter_values[i]);
  }
  PrimExpr pred =
      subst.empty() ? realize->predicate : Substitute(realize->predicate, subst);
  // OR the predicates across all accessing nested blocks: each such block is an
  // independent alternative access path (sibling blocks in a SeqStmt), so the
  // cache must cover the *union* of their access regions, not the intersection.
  // Using AND (the previous behaviour) underestimates the required region when
  // sibling blocks have non-overlapping predicates.
  result_ = found_ ? (result_ || pred) : pred;
  found_ = true;
}
// Continue recursing into deeper nested blocks.
StmtVisitor::VisitStmt_(realize);
The current implementation of CollectNestedBlockPredicates may produce an incorrect combined predicate for deeply nested blocks. When a nested block that accesses the target buffer is found, its predicate is OR-ed with the result. However, the visitor then continues to recurse into that block's children.
If a child block also accesses the buffer, its predicate will also be OR-ed. This is incorrect because the child block's execution is already gated by the parent's predicate. The union of access regions is governed by the parent's predicate alone in this case. For example, if a parent block has predicate P1 and a child has P2, the effective predicate for an access in the child is P1 && P2. The union of accesses in the parent and child is covered by P1, not P1 || P2.
To fix this, the recursion should only proceed into blocks that do not access the target buffer. If a block accessing the buffer is found, its predicate should be collected, and the visitor should not recurse into its children.
if (accesses_buffer) {
  // Build substitution: nested block iter vars -> their binding values
  // (which are already expressed in terms of the outer scope).
  ffi::Map<Var, PrimExpr> subst;
  for (size_t i = 0; i < block->iter_vars.size(); ++i) {
    subst.Set(block->iter_vars[i]->var, realize->iter_values[i]);
  }
  PrimExpr pred =
      subst.empty() ? realize->predicate : Substitute(realize->predicate, subst);
  // OR the predicates across all accessing nested blocks: each such block is an
  // independent alternative access path (sibling blocks in a SeqStmt), so the
  // cache must cover the *union* of their access regions, not the intersection.
  // Using AND (the previous behaviour) underestimates the required region when
  // sibling blocks have non-overlapping predicates.
  result_ = found_ ? (result_ || pred) : pred;
  found_ = true;
} else {
  // Continue recursing into deeper nested blocks if this one doesn't access the buffer.
  StmtVisitor::VisitStmt_(realize);
}
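The effect of stopping the recursion can be checked with a small standalone model. `ToyBlock`, `Collect`, and `CountCoveredByAny` are hypothetical names; predicates are plain functions over one index and are assumed to be already expressed in outer-scope terms, so this is only a sketch of the traversal rule, not the actual visitor.

```cpp
#include <functional>
#include <vector>

using Pred = std::function<bool(int)>;

// Toy nested block: whether it touches the target buffer, its guard, and
// the blocks nested inside it.
struct ToyBlock {
  bool accesses_buffer;
  Pred predicate;
  std::vector<ToyBlock> children;
};

// Collect guards of accessing blocks. When `stop_at_access` is true, do not
// descend below a block that already accesses the buffer.
void Collect(const ToyBlock& block, bool stop_at_access, std::vector<Pred>* out) {
  if (block.accesses_buffer) {
    out->push_back(block.predicate);
    if (stop_at_access) return;
  }
  for (const ToyBlock& child : block.children) {
    Collect(child, stop_at_access, out);
  }
}

// Number of indices in [0, extent) admitted by the OR of all `preds`.
int CountCoveredByAny(int extent, const std::vector<Pred>& preds) {
  int n = 0;
  for (int i = 0; i < extent; ++i) {
    for (const Pred& p : preds) {
      if (p(i)) {
        ++n;
        break;
      }
    }
  }
  return n;
}
```

Take a parent guarded by `i < 8` containing a child guarded by `i % 2 == 0`, over an extent of 16 (illustrative values). Stopping at the parent admits 8 indices, which matches what can actually execute. Continuing into the child ORs in its guard without the parent's and admits 12 indices, so the region is looser than necessary.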
tlopex left a comment
LGTM! Thanks for the contribution