Python API

High-level API

We provide a high-level API which allows one to easily experiment with Tensor Comprehensions.

tensor_comprehensions.define(tc, mapping_options_factory)[source]

Create a helper class with methods that implement multiple TC defs.

Parses a TC string containing multiple defs and returns a helper object whose method names match each TC def. JIT compilation occurs on demand the first time such a method is called with PyTorch Tensors of new sizes. The returned TC helper class is backed by a compilation cache which memoizes the results of compilation and avoids spurious recompilations. To determine the MappingOptions used for JIT compiling a particular TC def on inputs of particular sizes, the mapping_options_factory function is called. We provide the factory builder functions make_naive_options_factory(), make_load_from_cache_options_factory() and make_autotuned_options_factory().

Further user-defined factory functions can be easily written to extend the behavior.

Warning

If you choose to benchmark TC using this high-level API, be sure to understand how compilation, tuning and memoization interact. More generally, the low-level API should be used for benchmarking purposes.

Parameters:
  • tc (str) – a string containing one or more TC defs.
  • mapping_options_factory (Callable[[str, str, Iterable[Tensor]], MappingOptions]) – a function that takes a string with multiple TC defs, an entry_point and input PyTorch Tensors and produces a MappingOptions.
Return type:

a Callable helper object with methods corresponding to the TC def names and backed by a compilation cache.

Examples

One can define TC functions compiled with naive options for the purpose of correctness check debugging:

>>> T = tc.define(
... '''
... def add(float(N) A, float(N) B) -> (C) { C(i) = A(i) + B(i) }
... def sub(float(N) A, float(N) B) -> (C) { C(i) = A(i) - B(i) }
... ''',
... tc.make_naive_options_factory())
... A, B = torch.randn(100, device='cuda'), torch.randn(100, device='cuda')
... C = T.add(A, B)
... tc.assert_almost_equal(C, torch.add(A, B), A, B)
... D = T.sub(A, B)
... tc.assert_almost_equal(D, (A - B), A, B)

One can also obtain reinforced tuning behavior by tuning over multiple runs backed by the same cache file:

>>> tuner_config = tc.TunerConfig().threads(5).generations(3).pop_size(5)
... with tempfile.NamedTemporaryFile() as cache_file:
...     group_normalization = '''
...     def moments(float(N, K) I) -> (mean, var) {
...         # var = E(x^2) - mean^2.
...         mean(n) +=! I(n, r_k)
...          var(n) +=! I(n, r_k) * I(n, r_k)
...         mean(n)  = mean(n) / (K)
...          var(n)  =  var(n) / (K) - mean(n) * mean(n)
...     }
...
...     def group_normalization(
...         float(N, G, D, H, W) I, float(G, D) gamma, float(G, D) beta,
...         float(N, G) mean, float(N, G) var) -> (O)
...     {
...         O(n, g, d, h, w) = gamma(g, d)
...             * ( I(n, g, d, h, w) - mean(n, g) )
...             * rsqrt( var(n, g) + 1e-5 )
...             + beta(g, d)
...     }
...     '''
...
...     N, G, D, H, W = 32, 32, 4, 56, 56
...     I, gamma, beta = (
...         torch.randn(N, G, D, H, W, device='cuda'),
...         torch.randn(G, D, device='cuda'),
...         torch.randn(G, D, device='cuda'))
...
...     T = tc.define(
...         group_normalization,
...         tc.make_autotuned_options_factory(
...             starting_options='naive',
...             tuner_config=tuner_config,
...             cache_filename=cache_file.name,
...             store_to_cache=True))
...     # First occurrence triggers tuning from naive options and
...     # stores to cache.
...     mean, var = T.moments(I.view((N * G, -1)))
...     out = T.group_normalization(
...         I, gamma, beta, mean.view((N, G)), var.view((N, G)))
...
...     # Create a new TC object to retrigger tuning, this time
...     # starting from MappingOptions loaded from cache.
...     T = tc.define(
...         group_normalization,
...         tc.make_autotuned_options_factory(
...             tuner_config=tuner_config,
...             cache_filename=cache_file.name,
...             load_from_cache=True,
...             store_to_cache=True))
...     mean, var = T.moments(I.view((N * G, -1)))
...     out = T.group_normalization(
...         I, gamma, beta, mean.view((N, G)), var.view((N, G)))

tensor_comprehensions.make_autograd(forward_fun, backward_fun)[source]

Create a Callable helper object with torch.autograd support.

Parameters:
  • forward_fun (Callable[[Iterable[Tensor]], Iterable[Tensor]]) – a function that takes PyTorch Tensors and implements the forward operation. Returns PyTorch Tensors.
  • backward_fun (Callable[[Iterable[Tensor]], Iterable[Tensor]]) – a function that takes PyTorch Tensors and implements the backward operation. Returns PyTorch Tensors.
Return type:

a Callable helper object with torch.autograd support.

Warning

If you choose to benchmark TC using this high-level API, be sure to understand how autograd, compilation, tuning and memoization interact. More generally, the low-level API should be used for benchmarking purposes.

Example

>>> conv = '''
... def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
...     O(n, m, h, w) +=!
...         I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
... }
... def convolution_igrad(float(M,C,KH,KW) W1, float(N,M,H,W) d_O)
...     -> (d_I)
... {
...     d_I(n, c, h, w) +=!
...         d_O(  n, r_m, h - r_kh, w - r_kw) * W1(r_m, c, r_kh, r_kw)
... }
... def convolution_wgrad(float(N,C,H,W) I, float(N,M,H,W) d_O) -> (d_W1)
... {
...     d_W1(m, c, kh, kw) +=!
...         d_O(r_n,   m, r_h - kh, r_w - kw) *  I(r_n, c,  r_h,  r_w)
... }
... '''
...
... N, C, H, W, O, kH, kW = 32, 4, 56, 56, 16, 1, 1
... tuner_config = tc.TunerConfig().threads(5).generations(3).pop_size(5)
... T = tc.define(
...     conv,
...     tc.make_autotuned_options_factory(
...         starting_options='naive',
...         tuner_config=tuner_config))
... I, W = (
...     torch.randn(N, C, H, W, device='cuda', requires_grad=True),
...     torch.randn(O, C, kH, kW, device='cuda', requires_grad=True))
...
... def convolution_backward(I, W, d_O):
...     d_I = T.convolution_igrad(W, d_O)
...     d_W1 = T.convolution_wgrad(I, d_O)
...     return (d_I, d_W1)
...
... convolution_function = tc.make_autograd(
...     T.convolution, convolution_backward)
...
... # First occurrence triggers tuning
... out = convolution_function(I, W)
... out.sum().backward()
...
... # Subsequent occurrences do not
... out = convolution_function(I, W)
... out.sum().backward()

The define() function provides implicit compilation caching, which alleviates the need to implement a caching mechanism at the user-facing level. The question remains of which MappingOptions to use for compilation. Since this is still an open problem, we provide support for user-defined functions to specify this behavior. A user of the define() function must provide a MappingOptions generator function whose sole purpose is to determine the options with which to compile a particular TC def for particular input sizes.

To facilitate usage we provide the following generators:

tensor_comprehensions.make_naive_options_factory()[source]

Return a factory that always generates naive MappingOptions.

For easily getting started with TC and debugging purposes only.

Return type: a function that takes a string with multiple TC defs, an entry_point and input PyTorch Tensors and produces a MappingOptions.
tensor_comprehensions.make_load_from_cache_options_factory(cache_filename)[source]

Return a factory that loads MappingOptions from a cache file.

Parameters: cache_filename (str) – the name of the cache file from which MappingOptions are loaded.
Return type: a function that takes a string with multiple TC defs, an entry_point and input PyTorch Tensors and produces a MappingOptions.
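
Example

One can reuse options previously stored by an autotuning run. A minimal sketch, assuming the cache file (filename illustrative) already contains an entry for this TC def and these input sizes:

>>> # Cache filename is illustrative; it must already contain tuned options.
... T = tc.define(
...     'def add(float(N) A, float(N) B) -> (C) { C(i) = A(i) + B(i) }',
...     tc.make_load_from_cache_options_factory('/tmp/add_options_cache'))
... A, B = torch.randn(100, device='cuda'), torch.randn(100, device='cuda')
... C = T.add(A, B)
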
tensor_comprehensions.make_autotuned_options_factory(starting_options=None, tuner_config=<tensor_comprehensions.tclib.TunerConfig object>, cache_filename=None, load_from_cache=False, store_to_cache=False)[source]

Return a factory that runs autotuning to determine the best MappingOptions.

The returned factory simply calls the autotune() function; see its documentation for more information.

Return type: a function that takes a string with multiple TC defs, an entry_point and input PyTorch Tensors and produces a MappingOptions.

Custom behavior to select MappingOptions may be implemented in addition to the provided defaults. The signature of custom generators must match:

def some_generator(tc: str, entry_point: str,
                   *inputs: torch.Tensor) -> MappingOptions:
    ...
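
As a minimal, hypothetical sketch (the generator name and the size threshold below are illustrative, not part of the library), such a generator could pick options based on the input sizes:

import torch
import tensor_comprehensions as tc

def size_dependent_options(tc_str: str, entry_point: str,
                           *inputs: torch.Tensor) -> tc.MappingOptions:
    # Start from naive options; for larger problems, also request
    # shared memory promotion (threshold and tweak are illustrative).
    options = tc.MappingOptions('naive')
    if sum(t.numel() for t in inputs) > 2 ** 20:
        options = options.useSharedMemory(True)
    return options

T = tc.define(
    'def add(float(N) A, float(N) B) -> (C) { C(i) = A(i) + B(i) }',
    size_dependent_options)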

Low-level API

We also provide a low-overhead API which avoids implicit behavior and is generally useful for benchmarking.

class tensor_comprehensions.Executor(executor)[source]

Callable helper class to hold the result of compiling a TC def with fixed input sizes.

__call__(*inputs, outputs=None, unchecked=False)[source]

Run the compiled TC kernel.

Parameters:
  • inputs (Tensor) – PyTorch Tensors for which the compiled kernel has been specialized. You must use tensors of the same sizes as those you specialized for; otherwise illegal memory accesses will occur.
  • outputs (Optional[Tuple[Tensor]]) – PyTorch Tensors into which the TC kernel will write. If left unspecified, new tensors will be allocated (which will have a noticeable performance impact until the caching allocator kicks in).
  • unchecked (Optional[bool]) – Disable shape checks (at your own risk), which reduces overhead in the case of low-latency kernels.
Returns:

A PyTorch Tensor, or a tuple of PyTorch Tensors in the case of multiple return values.

Example

>>> A, B = (
...     torch.randn(100, device='cuda').fill_(1),
...     torch.randn(100, device='cuda').fill_(1))
... add = tc.compile(
...     'def add(float(N) A, float(N) B) -> (C) { C(i) = A(i) + B(i) }',
...     'add',
...     'naive',
...     A, B,
... )
... C = add(A, B)
>>> print(C.min(), C.max())
tensor(2., device='cuda:0') tensor(2., device='cuda:0')
Return type: Union[Tensor, Tuple[Tensor]]
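
Preallocating the outputs and disabling shape checks can further reduce per-call overhead. A minimal sketch, continuing the example above (the output size must match the sizes the executor was compiled for):

>>> # Reuse a preallocated output and skip shape checks; only safe when the
... # input and output sizes match those the executor was specialized for.
... C = torch.empty(100, device='cuda')
... add(A, B, outputs=(C,), unchecked=True)
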
tensor_comprehensions.compile(tc, entry_point, mapping_options, *inputs)[source]

Returns a compiled, callable, low-overhead Executor.

An example of usage is provided in Executor.

Parameters:
  • tc (str) – a string containing one or more TC defs.
  • entry_point (str) – the name of the TC def to compile and execute.
  • mapping_options (Union[str, MappingOptions]) – the options to use for compilation.
  • inputs (Tensor) – PyTorch Tensors for which the compiled kernel is specialized.
Return type:

Executor, a low-overhead callable class to launch the kernel compiled from the entry_point.

tensor_comprehensions.autotune(tc, entry_point, *inputs, starting_options=None, tuner_config=<tensor_comprehensions.tclib.TunerConfig object>, cache_filename=None, load_from_cache=False, store_to_cache=False)[source]

Tunes the defined TC function for given inputs.

The MappingOptions from which tuning starts is either passed explicitly via starting_options or loaded from a cache file (when both cache_filename and load_from_cache are properly specified). Exactly one of starting_options and load_from_cache must be specified.

It is possible to obtain reinforced tuning behavior by tuning over multiple executions and specifying both load_from_cache and store_to_cache. It is recommended to use a single cache file for all TC defs and reinforce it over time.

An example of usage is provided with autotune_and_compile().

Parameters:
  • tc (str) – a string containing one or more TC defs.
  • entry_point (str) – the name of the TC def to compile and execute.
  • inputs (Tensor) – PyTorch Tensors that TC should tune for. The inputs must be passed in the same order as in the corresponding TC def.
  • starting_options (Union[str, MappingOptions, None]) – MappingOptions from which tuning should start.
  • tuner_config (Optional[TunerConfig]) – TunerConfig to control the behavior of the autotuner.
  • cache_filename (Optional[str]) – the name of the backing cache file used for loading and storing MappingOptions.
  • load_from_cache (Optional[bool]) – Get the starting MappingOptions by loading from cache_filename. If loading fails to recover an entry from the cache file for the given input sizes, an assertion error is triggered.
  • store_to_cache (Optional[bool]) – Optionally store the best result by appending it to the backing cache file.
Returns:

The best options found during this tuning run.

Return type:

MappingOptions
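
A minimal sketch (sizes and tuner settings illustrative): the returned MappingOptions can be passed directly to compile():

>>> A, B = torch.randn(100, device='cuda'), torch.randn(100, device='cuda')
... add_str = 'def add(float(N) A, float(N) B) -> (C) { C(i) = A(i) + B(i) }'
... best_options = tc.autotune(
...     add_str,
...     'add',
...     A, B,
...     starting_options='naive',
...     tuner_config=tc.TunerConfig().threads(5).generations(2).pop_size(5))
... add = tc.compile(add_str, 'add', best_options, A, B)
... C = add(A, B)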

tensor_comprehensions.autotune_and_compile(tc, entry_point, *inputs, starting_options=None, tuner_config=<tensor_comprehensions.tclib.TunerConfig object>, cache_filename=None, load_from_cache=False, store_to_cache=False)[source]

Calls autotune(), compiles with the best options found, then returns an Executor.

Takes the same arguments as the autotune() function.

Example

>>> A, B = (
...     torch.randn(10 ** 5, device='cuda').fill_(1.0),
...     torch.randn(10 ** 5, device='cuda').fill_(1.0))
... add = tc.autotune_and_compile(
...     "def add(float(N) A, float(N) B) -> (C) { C(i) = A(i) + B(i) }",
...     "add",
...     A, B,
...     starting_options='naive',
...     tuner_config=tc.TunerConfig().threads(5).generations(3).pop_size(5)
... )
... C = add(A, B)
>>> print(C.min(), C.max())
tensor(2., device='cuda:0') tensor(2., device='cuda:0')
Return type: Executor

Additionally, the assert_almost_equal() helper function is useful for performing numerical checks.

tensor_comprehensions.assert_almost_equal(actual, expected, *inputs, operations=1, precision=1e-07)[source]

Asserts numerical precision requirements.

Parameters:
  • actual (Tensor) – the PyTorch Tensor to check.
  • expected (Tensor) – the expected PyTorch Tensor.
  • inputs (Tensor) – PyTorch Tensors passed as inputs to the TC that produced the actual Tensor.
  • operations (Optional[int]) – maximum number of iterated operations per produced value. This is used to compute the required absolute precision.
  • precision (Optional[float]) – relative precision at which to check.
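
Example

A minimal sketch (sizes illustrative): for a reduction that accumulates K values per output element, passing operations=K scales the tolerance accordingly:

>>> M, K = 32, 512
... A = torch.randn(M, K, device='cuda')
... row_sum = tc.compile(
...     'def row_sum(float(M, K) A) -> (S) { S(m) +=! A(m, r_k) }',
...     'row_sum', 'naive', A)
... S = row_sum(A)
... tc.assert_almost_equal(S, A.sum(1), A, operations=K)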

Caching and Configuration

Finally, we also document a subset of the commonly used helper types for caching and configuration.

Python bindings for Tensor Comprehensions

class tensor_comprehensions.tclib.MappingOptionsCache

Helper class to manipulate cache files containing serialized MappingOptions

load(self: tensor_comprehensions.tclib.MappingOptionsCache, arg0: str, arg1: str, arg2: tuple, arg3: int) → List[tc::CudaMappingOptions]

Load the best entries from cache.

Parameters:
  • tc – a string containing one or more TC defs
  • entry_point – the TC def to compile and execute
  • inputs – PyTorch Tensors whose sizes determine which cache entries are loaded
  • num_candidates – number of candidates to return

Example

>>> import tensor_comprehensions as tc
... import tensor_comprehensions.tclib as tclib
... cache = tc.MappingOptionsCache(cache_file.name)
... # Load the single best entry stored for these input sizes.
... best_options, = cache.load(
...     tensordot_str, entry_point, (I0, I1), 1)
... # Compile and run with the loaded options.
... executor = tclib.compile(
...     tensordot_str, entry_point, (I0, I1), best_options)
... O = executor.run((I0, I1), ())
Returns: A list of MappingOptions
class tensor_comprehensions.tclib.MappingOptions

MappingOptions to drive the polyhedral compiler

fixParametersBeforeScheduling(self: tensor_comprehensions.tclib.MappingOptions, arg0: bool) → tensor_comprehensions.tclib.MappingOptions

Perform automatic loop scheduling taking into account specific tensor sizes. May produce faster kernels but significantly increases compilation time. Note that the mapping will be performed for specific tensor sizes anyway

intraTileScheduleFusionStrategy(self: tensor_comprehensions.tclib.MappingOptions, arg0: str) → tensor_comprehensions.tclib.MappingOptions

Require TC to try and execute different TC expressions interleaved (Max), separately (Min) or interleaved as long as sufficient parallelism is exploited (Preserve3Coincident) by performing loop fusion and fission. Applies to inner loops created by tiling

mapToBlocks(self: tensor_comprehensions.tclib.MappingOptions, arg0: List[int]) → tensor_comprehensions.tclib.MappingOptions

The configuration of CUDA grid, i.e. the number of CUDA blocks along three dimensions. Must be within the range allowed by CUDA (maximum 2^31-1 for the first value and 65535 for the second and third)

mapToThreads(self: tensor_comprehensions.tclib.MappingOptions, arg0: List[int]) → tensor_comprehensions.tclib.MappingOptions

The configuration of CUDA block, i.e. the number of CUDA threads in each block along three dimensions. Must be within the range allowed by CUDA (maximum 1024 for the first and second value, 32 for the third, product below 1024)

matchLibraryCalls(self: tensor_comprehensions.tclib.MappingOptions, arg0: bool) → tensor_comprehensions.tclib.MappingOptions

Replace computation patterns with calls to highly optimized libraries (such as CUB, CUTLASS, …) when possible

maxSharedMemory(self: tensor_comprehensions.tclib.MappingOptions, arg0: int) → tensor_comprehensions.tclib.MappingOptions

The amount of shared memory to use, in bytes. If not provided, TC will query the active GPU and use all available shared memory.

outerScheduleFusionStrategy(self: tensor_comprehensions.tclib.MappingOptions, arg0: str) → tensor_comprehensions.tclib.MappingOptions

Require TC to try and execute different TC expressions interleaved (Max), separately (Min) or interleaved as long as sufficient parallelism is exploited (Preserve3Coincident) by performing loop fusion and fission. Applies before tiling

scheduleFusionStrategy(self: tensor_comprehensions.tclib.MappingOptions, arg0: str) → tensor_comprehensions.tclib.MappingOptions

Set both outerScheduleFusionStrategy and intraTileScheduleFusionStrategy to the given value

serialize(self: tensor_comprehensions.tclib.MappingOptions) → bytes

Serialize the options to a protobuf string

tile(self: tensor_comprehensions.tclib.MappingOptions, arg0: List[int]) → tensor_comprehensions.tclib.MappingOptions

Perform loop tiling on the generated code with the given sizes. Independent of mapping to a grid of thread blocks

unroll(self: tensor_comprehensions.tclib.MappingOptions, arg0: int) → tensor_comprehensions.tclib.MappingOptions

Perform loop unrolling on the generated code and produce at most the given number of statements

unrollCopyShared(self: tensor_comprehensions.tclib.MappingOptions, arg0: bool) → tensor_comprehensions.tclib.MappingOptions

Also unroll the copies to and from shared memory. If an unroll value is not provided, has no effect

usePrivateMemory(self: tensor_comprehensions.tclib.MappingOptions, arg0: bool) → tensor_comprehensions.tclib.MappingOptions

Create thread-local copies of data in private memory

useReadOnlyCache(self: tensor_comprehensions.tclib.MappingOptions, arg0: bool) → tensor_comprehensions.tclib.MappingOptions

Use the readonly cache (i.e. emit __ldg loads)

useSharedMemory(self: tensor_comprehensions.tclib.MappingOptions, arg0: bool) → tensor_comprehensions.tclib.MappingOptions

Create block-local copies of data in shared memory when this can leverage data reuse or global memory access coalescing
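
Each setter returns the modified MappingOptions, so calls can be chained. A minimal sketch (the values below are illustrative, not tuned for any particular kernel):

>>> options = (
...     tc.MappingOptions('naive')
...     .useSharedMemory(True)
...     .usePrivateMemory(True)
...     .mapToBlocks([256, 256])
...     .mapToThreads([32, 8, 1])
...     .tile([4, 8])
...     .unroll(16))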

class tensor_comprehensions.tclib.TunerConfig

Helper class to manage the behavior of the autotuner

crossover_rate(self: tensor_comprehensions.tclib.TunerConfig, arg0: int) → tensor_comprehensions.tclib.TunerConfig
-tuner_gen_crossover_rate (Crossover rate for genetic autotuning)
type: uint32 default: 80
devices(self: tensor_comprehensions.tclib.TunerConfig, arg0: str) → tensor_comprehensions.tclib.TunerConfig
-tuner_devices (Comma separated list of GPUs to use for autotuning)
type: string default: "0"
generations(self: tensor_comprehensions.tclib.TunerConfig, arg0: int) → tensor_comprehensions.tclib.TunerConfig
-tuner_gen_generations (How many generations to run genetic tuning for)
type: uint32 default: 25
logtostderr(self: tensor_comprehensions.tclib.TunerConfig, arg0: bool) → tensor_comprehensions.tclib.TunerConfig
-logtostderr (log messages go to stderr instead of logfiles) type: bool
default: false
mutation_rate(self: tensor_comprehensions.tclib.TunerConfig, arg0: int) → tensor_comprehensions.tclib.TunerConfig
-tuner_gen_mutation_rate (Mutation rate for genetic autotuning)
type: uint32 default: 7
number_elites(self: tensor_comprehensions.tclib.TunerConfig, arg0: int) → tensor_comprehensions.tclib.TunerConfig
-tuner_gen_number_elites (The number of best candidates that are preserved
intact between generations) type: uint32 default: 10
pop_size(self: tensor_comprehensions.tclib.TunerConfig, arg0: int) → tensor_comprehensions.tclib.TunerConfig
-tuner_gen_pop_size (Population size for genetic autotuning) type: uint32
default: 100
stderrthreshold(self: tensor_comprehensions.tclib.TunerConfig, arg0: int) → tensor_comprehensions.tclib.TunerConfig
-stderrthreshold (log messages at or above this level are copied to stderr
in addition to logfiles. This flag obsoletes --alsologtostderr.) type: int32 default: 2
threads(self: tensor_comprehensions.tclib.TunerConfig, arg0: int) → tensor_comprehensions.tclib.TunerConfig
-tuner_threads (Number of CPU threads to use when autotuning) type: uint32
default: 8
tuner_min_launch_total_threads(self: tensor_comprehensions.tclib.TunerConfig, arg0: int) → tensor_comprehensions.tclib.TunerConfig
-tuner_min_launch_total_threads (Prune out kernels mapped to fewer than
this many total threads per launch) type: uint64 default: 64
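
These setters also return the TunerConfig and can be chained. A minimal sketch of a short tuning configuration (values illustrative):

>>> tuner_config = (
...     tc.TunerConfig()
...     .threads(8)
...     .pop_size(10)
...     .generations(3)
...     .number_elites(2)
...     .devices('0')
...     .logtostderr(True))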