Range Inference

In Tensor Comprehensions, loops are implicit and output tensor sizes are inferred. Concretely, when a user writes something like the following stencil:

A(i) = B(i) + C(i-2)

Tensor Comprehensions must deduce the range of values that i iterates over, and from this range derive the size of the tensor A. We must also infer the extent of loops that reduce, for example the implicit loop over r_k in a matrix multiply:

def matmul(float(M, K) A, float(K, N) B) -> (C) {
    C(m, n) +=! A(m, r_k) * B(r_k, n)
}

If this range inference procedure fails to match the user’s intent, then in the first case the output will not be the size they expect, and in the second case the output values will be incorrect, as either too few or too many terms were included in the summation.

To program productively, one must be able to mentally emulate the written code on the abstract machine defined by the language semantics. Regardless of how well-defined it is, if your source code doesn’t do what you think it does, you have a bug. Thus it’s critical that users build a mental model of how we infer ranges, and are able to do range inference in their heads as they write code. If this requires more thought than writing explicit loops would, we have failed.

With this in mind, we eschew heavy-duty mathematical tools, and take a more straightforward approach for the sake of usability. We infer ranges only in cases where we feel they are obvious, and require explicit annotation elsewhere. We intend to fine-tune this boundary in the future depending on what users find surprising.

The Range Inference Algorithm

To a good first approximation, we infer rectangular ranges that are as large as possible without reading out of bounds on the inputs. If there is more than one way to do this, we throw an error and require explicit annotation using a where clause.

This rule is sufficient to understand the matrix multiply case above. Maximizing the range of m gives it the range of the rows of A. Similarly, maximizing the range of n gives it the range of the columns of B. r_k is used twice, so making r_k as large as possible gives it the lesser of the number of columns of A and the number of rows of B. These in turn are constrained to be equal by the type signature of the function (they are both K).
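As a rough illustration, the inferred ranges for the matrix multiply can be read as the bounds of explicit loops. The sketch below is plain Python over nested lists, not Tensor Comprehensions code, and the name matmul_explicit is hypothetical.

```python
def matmul_explicit(A, B):
    """Explicit-loop reading of C(m, n) +=! A(m, r_k) * B(r_k, n)."""
    M, K_a = len(A), len(A[0])   # m is bounded by the rows of A
    K_b, N = len(B), len(B[0])   # n is bounded by the columns of B
    # The type signature requires both reduction extents to be K.
    assert K_a == K_b, "columns of A must equal rows of B"
    K = min(K_a, K_b)            # r_k is used twice: take the smaller extent
    # The "!" in "+=!" means the output is initialized before accumulating.
    C = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            for r_k in range(K):
                C[m][n] += A[m][r_k] * B[r_k][n]
    return C
```

If the reduction range were inferred too small or too large, the innermost loop would drop terms or read out of bounds, which is the failure mode described earlier.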

Now consider a stencil:

A(i) += B(i + k) * K(k)

There are multiple ways in which we could maximize the ranges of i and k. If we first maximize i, we might say that it ranges over the entirety of B. This forces k to take on a single value only, which will not result in the output one expects from a convolution (it ignores most of the kernel!). If we first maximize k, so that it ranges over the entirety of K, then in order to not read out of bounds the range of i must be smaller, and we get an output that is slightly smaller than the input. This is the behavior we prefer.

In order to make this unambiguous without requiring explicit annotation in this simple case, range inference proceeds in rounds. We maintain a set of unresolved variables. Initially it contains all variables not constrained with an explicit where clause. In each round, we consider all the tensor argument expressions that contain a single unresolved variable, and construct a boolean expression that states the access is not out-of-bounds. We then use tools from Halide (solve_for_inner_interval) to find the maximal range for the variable that satisfies this condition, given the ranges of variables already resolved in previous rounds. If the variable was already constrained this round by some other tensor access, we take the intersection of the inferred ranges.

For the stencil above, in the first round we ignore the expression B(i + k) because it contains multiple unresolved variables. We use the expression K(k) to deduce a range for k. In the second round, B(i + k) now contains a single unresolved variable, and we use the already-inferred range of k to deduce a maximal range for i.

Preconditions

While this procedure produces easy-to-justify ranges for each variable, it is not sufficient to ensure that no out-of-bounds reads occur. For example consider:

A(i, j) = B(i) * C(i + j) * D(j)

In the first round, we can resolve both i and j, using B(i) and D(j) respectively. This guarantees that there are no out-of-bounds reads on B and D, and defines the size of A. The size of C could be used to further constrain the range of i or j to avoid out-of-bounds reads on C. But there is no second round, since the ranges for all variables have been resolved; and even if there were, the expression i + j would not translate into a unique rectangular shape for i and j. So we are left with the requirement that C is large enough to cover all reads. In some cases we can statically prove this condition (for example if the sizes of B, C, and D are known constants). In general we emit a compile-time warning.
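When the sizes are known, the leftover precondition on C reduces to a single comparison. The following Python sketch shows the check one would do by hand; the function name c_is_large_enough is hypothetical and this is not how the compiler reports the condition.

```python
def c_is_large_enough(size_b, size_d, size_c):
    """Precondition for A(i, j) = B(i) * C(i + j) * D(j)."""
    # i ranges over [0, size_b) and j over [0, size_d), both from round one.
    largest_index = (size_b - 1) + (size_d - 1)
    # The smallest index read is 0 + 0, so only the upper bound can fail.
    return largest_index < size_c
```

For example, with B of size 4 and D of size 3, C needs at least 6 elements.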

We intend to add runtime checking of these conditions in the future. However, for some preconditions, it is never desirable to check them at runtime. Consider a lookup table:

def lut(float(J) B, float(I) C) -> A {
  A(i) = B(C(i))
}

The range of i is constrained by its use in C, but we are left with the additional precondition that the values in C over that range are valid indices into B, that is, 0 <= C(i) < J. Checking this at runtime would require an expensive bounds check in the inner loop. As with the previous case, we currently just emit a compile-time warning that this unchecked precondition exists. The user can suppress it and make this code unconditionally safe by explicitly clamping the expression C(i) to be within the bounds of B like so:

def lut(float(J) B, float(I) C) -> A {
  A(i) = B(max(min(C(i), J-1), 0))
}

Of course, the clamp itself also has a performance cost.
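The effect of the clamp can be seen in a small Python model of the lookup. This is an illustrative sketch only: it assumes C already holds integer indices, and the name lut_clamped is hypothetical.

```python
def lut_clamped(B, C):
    """Model of A(i) = B(max(min(C(i), J-1), 0)) over lists."""
    J = len(B)
    # Every index is forced into [0, J - 1], so no read can leave B.
    return [B[max(min(c, J - 1), 0)] for c in C]
```

Indices above J - 1 read the last element of B, and negative indices read the first, so the result is always defined even when the precondition on C does not hold.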

Worked Examples

We now describe how range inference reasons about several more complex examples. If you find a confusing case, feel free to request that we add it to this section.

Inverted indexing

def reverted(float(I) B) -> A {
  A(i) = B(10 - i)
}

From the use in B, range inference constructs the condition:

0 <= 10 - i < I

This is rearranged by Halide’s solver to give the following range:

11 - I <= i < 11
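A solved range like this is easy to check by brute force. The Python sketch below enumerates every i in a search window that satisfies the original condition; the function name and the window bounds are arbitrary choices for illustration.

```python
def valid_inverted_indices(I, lo=-100, hi=100):
    """All i in [lo, hi) for which B(10 - i) stays inside a B of size I."""
    return [i for i in range(lo, hi) if 0 <= 10 - i < I]
```

For I = 5 this returns [6, 7, 8, 9, 10], that is, 6 <= i < 11.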

Strided indexing with constant stride

def subsample_2(float(I) B) -> A {
  A(i) = B(2*i)
}

From the use, range inference constructs the condition:

0 <= 2*i < I

This is rearranged into:

0 <= i < (I+1)/2

Note that the division is integer division, which rounds towards negative infinity in Tensor Comprehensions and Halide.
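Python's // operator also rounds towards negative infinity, so the bound can be checked directly against a brute-force count. This is a sanity check written for illustration, with hypothetical function names.

```python
def subsample_extent(I):
    """Upper bound on i from the solved range 0 <= i < (I+1)/2."""
    return (I + 1) // 2

def brute_force_extent(I):
    """Number of i >= 0 that satisfy 0 <= 2*i < I."""
    return sum(1 for i in range(I + 1) if 0 <= 2 * i < I)
```

The two agree for every size, including odd ones: an input of size 5 gives an output of size 3, reading B(0), B(2), and B(4).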

Strided indexing with offsets

def average_pool_2(float(I) B) -> A {
  A(i) = B(2*i) + B(2*i + 1)
}

From the uses, range inference constructs the conditions:

0 <= 2*i < I
0 <= 2*i + 1 < I

These are rearranged into:

0 <= i < (I+1)/2
0 <= i < I/2

The intersection of these two ranges is:

0 <= i < I/2

One could write the equivalent code:

def average_pool_2(float(I) B) -> A {
  A(i) +=! B(2*i + k) where k in 0:2
}

The syntax where k in lb:ub is inclusive of the lower bound and exclusive of the upper bound: it constrains the range of k to the integers from lb to ub-1. Here, k may only take the values 0 and 1.

The variable k is already resolved by the where clause, so only i remains unresolved. From the use of i, range inference constructs the condition:

0 <= 2*i + k < I

We eliminate k by taking the conjunction of the expression over all values of k, using Halide’s and_condition_over_domain. For the lower bound, k == 0 dominates. For the upper bound, k == 1 dominates.

0 <= 2*i && 2*i + 1 < I

This is equivalent to the intersection of the conditions in the unrolled case, and so we get the same result:

0 <= i < I/2
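The equivalence of the two formulations can be checked numerically. In this Python sketch, the first function intersects the two unrolled conditions and the second takes the conjunction over all k in 0:2; both function names and the search window are hypothetical.

```python
def valid_i_unrolled(I):
    """Intersect the ranges from B(2*i) and B(2*i + 1)."""
    window = range(-I - 1, I + 1)
    first = {i for i in window if 0 <= 2 * i < I}
    second = {i for i in window if 0 <= 2 * i + 1 < I}
    return sorted(first & second)

def valid_i_where(I):
    """Require 0 <= 2*i + k < I for every k in 0:2 (k = 0 and k = 1)."""
    window = range(-I - 1, I + 1)
    return [i for i in window if all(0 <= 2 * i + k < I for k in range(0, 2))]
```

Both return the integers in 0 <= i < I/2, with the division rounding down.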

Strided indexing with dynamic stride

def subsample_2(float(I) B, int(1) S) -> A {
  A(i) = B(S(0)*i)
}

The value of S(0) is not fixed until runtime, so we can’t resolve the size of A or the range of the loop. This case throws a compile-time error. A where clause that defines the range of i is required.

Constant fill using an exists clause

def constant_fill(float(N) A, float c) -> B {
  B(i) = c where exists A(i)
}

An exists clause allows you to add additional expressions to the range inference process without having the expressions affect the actual computation. In this example, it says that B should have the same size as A, but be filled with the constant value c. In other words, the range of i is inferred so that B(i) exists exactly where A(i) exists. It is equivalent to writing the expression true ? c : A(i), but with clearer intentions.
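In explicit-loop terms, the exists clause contributes only a loop bound. A minimal Python model, with lists standing in for tensors:

```python
def constant_fill(A, c):
    """Model of B(i) = c where exists A(i)."""
    # A is consulted only for its extent; its values are never read.
    return [c for _ in range(len(A))]
```

The contents of A have no influence on the result, only its size does.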