Writing PyTorch layers with TC

To write a new layer with TC, follow these steps:

  1. Define your TC language and pass it to tc.define
  2. Create input torch tensors
  3. Run the layer and get the output

In the third step, when the layer is run on a given set of inputs, the TC backend first compiles the definition for the given tensor sizes, then runs the layer and returns the output. On later runs with inputs of the same sizes, the backend skips compilation and runs the compiled layer directly.
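The compile-once-then-reuse behavior can be pictured with a small pure-Python sketch. This is not TC's implementation: CompileCache and the compile_fn callback are hypothetical names, and keying the cache on the tuple of input shapes is an assumption consistent with the rule that changing input sizes triggers recompilation.

```python
class CompileCache:
    """Hypothetical sketch: compile once per distinct set of input shapes."""

    def __init__(self, compile_fn):
        # compile_fn is a stand-in for the backend compiler: shapes -> kernel
        self._compile_fn = compile_fn
        self._kernels = {}
        self.num_compilations = 0

    def get_kernel(self, *shapes):
        # The cache key is the tuple of input shapes, e.g. ((3, 4), (4, 5)).
        key = tuple(tuple(shape) for shape in shapes)
        if key not in self._kernels:
            # First run for these sizes: compile and remember the kernel.
            self._kernels[key] = self._compile_fn(key)
            self.num_compilations += 1
        # Later runs with the same sizes skip compilation entirely.
        return self._kernels[key]


cache = CompileCache(lambda key: "kernel specialized for %r" % (key,))
cache.get_kernel((3, 4), (4, 5))          # compiles
cache.get_kernel((3, 4), (4, 5))          # cache hit, no compilation
cache.get_kernel((100, 400), (400, 500))  # new sizes, compiles again
print(cache.num_compilations)             # prints 2
```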

Example

The following example demonstrates each step:

import tensor_comprehensions as tc
import torch
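# Step 1: define the TC language and pass it to tc.define
MATMUL_LANG = """
def matmul(float(M, K) A, float(K, N) B) -> (C) {
    C(m, n) +=! A(m, r_k) * B(r_k, n)
}
"""
# the `name` should match the definition name in the `lang`
matmul = tc.define(MATMUL_LANG, name="matmul")
# Step 2: create input torch tensors on the GPU
mat1, mat2 = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda()
# Step 3: run the layer; the first call compiles it for these sizes
out = matmul(mat1, mat2)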

Below is the complete documentation of each API call:

tc.define

Specifying CudaMappingOptions

A TC is transformed into a CUDA kernel using mapping options (Options). These options are used to run the layer and therefore determine the performance of the generated kernel, so it is important to use good options when running a layer. You can read more about mapping options in Mapping Options.

There are two ways to set the Options:

  • Autotuning: You can autotune the kernel on certain input tensor sizes, cache the options and use them to run the layer. See Autotuning layers for how to autotune kernels.
  • Default Mapping: TC provides several default options; choose the one that most closely matches your kernel. The defaults provided are:
    • pointwise: if the kernel resembles a pointwise operation
    • mlp: if the kernel resembles a Linear layer operation
    • conv: if the kernel resembles a convolution operation
    • group_conv: if the kernel resembles a group convolution operation
    • naive: if none of the above apply, choose the naive default

An example of how to pass options:

import tensor_comprehensions as tc
import torch
lang = """
def matmul(float(M, K) A, float(K, N) B) -> (C) {
    C(m, n) +=! A(m, r_k) * B(r_k, n)
}
"""
matmul = tc.define(lang, name="matmul")
mat1, mat2 = torch.randn(100, 400).cuda(), torch.randn(400, 500).cuda()
# matmul resembles a Linear layer, so use the "mlp" defaults
out = matmul(mat1, mat2, options=tc.CudaMappingOptions("mlp"))

Note

If the user does not pass mapping options, the naive mapping options are used by default and kernel performance may be very poor. We therefore strongly recommend using one of the two ways above to specify kernel mapping options.

Reduction Operators

Reduction operators may be suffixed with ! (for example +=!) to indicate that the tensor into which values are accumulated should first be initialized with the identity of the reduction operator (e.g., 0 for +). Otherwise, values are accumulated directly into the output or temporary tensor passed to the kernel.
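To make the difference concrete, the loop nest below is a pure-Python reference for both forms of the matmul TC used on this page. It only illustrates what the kernel computes, not how TC generates code; matmul_reduce is a hypothetical name.

```python
def matmul_reduce(A, B, C, init):
    """Reference loop nest for C(m, n) +=! A(m, r_k) * B(r_k, n) (init=True)
    or C(m, n) += A(m, r_k) * B(r_k, n) (init=False). C is updated in place."""
    M, K, N = len(A), len(B), len(B[0])
    if init:
        # `+=!`: first set every output element to 0, the identity of +.
        for m in range(M):
            for n in range(N):
                C[m][n] = 0.0
    # m and n index the output; r_k appears only on the right-hand side,
    # so it is the reduction index that is summed over.
    for m in range(M):
        for n in range(N):
            for r_k in range(K):
                C[m][n] += A[m][r_k] * B[r_k][n]
    return C


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
print(matmul_reduce(A, B, [[1.0, 1.0], [1.0, 1.0]], init=True))   # [[19.0, 22.0], [43.0, 50.0]]
print(matmul_reduce(A, B, [[1.0, 1.0], [1.0, 1.0]], init=False))  # [[20.0, 23.0], [44.0, 51.0]]
```

With `+=`, whatever the output tensor held before the call (here all ones) leaks into the result, which is why `+=!` is the usual choice for a fresh output.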

Different input sizes for same TC

If you want to run a TC definition on different combinations of input sizes, you only need to define the TC once. An example:

import tensor_comprehensions as tc
import torch
lang = """
def matmul(float(M, K) A, float(K, N) B) -> (C) {
    C(m, n) +=! A(m, r_k) * B(r_k, n)
}
"""
matmul = tc.define(lang, name="matmul")
mat1, mat2 = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda()
out1 = matmul(mat1, mat2)

# different input sizes
mat3, mat4 = torch.randn(100, 400).cuda(), torch.randn(400, 500).cuda()
out2 = matmul(mat3, mat4)

Whenever the input tensor sizes change, the TC backend recompiles the definition for the new sizes. If the input tensor sizes do not change, compilation happens only once and the layer can then be run repeatedly.
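Here "input sizes" means the values bound to the symbolic sizes in the signature (M, K and N for matmul). The pure-Python helper below is a hypothetical illustration of that binding, not part of the TC API, and it assumes every input is declared as float(...) like the examples on this page. It shows why the two calls in the example above correspond to two different sets of sizes.

```python
import re


def bind_sizes(signature, *shapes):
    """Hypothetical helper: bind the symbolic sizes of a TC signature to concrete shapes."""
    # "float(M, K) A, float(K, N) B" -> [["M", "K"], ["K", "N"]]
    declared = [[name.strip() for name in group.split(",")]
                for group in re.findall(r"float\(([^)]*)\)", signature)]
    if len(declared) != len(shapes):
        raise ValueError("expected %d inputs, got %d" % (len(declared), len(shapes)))
    sizes = {}
    for names, shape in zip(declared, shapes):
        if len(names) != len(shape):
            raise ValueError("rank mismatch: %s vs %s" % (names, shape))
        for name, size in zip(names, shape):
            # A size symbol shared by several inputs (K here) must agree everywhere.
            if sizes.setdefault(name, size) != size:
                raise ValueError("%s is both %d and %d" % (name, sizes[name], size))
    return sizes


SIG = "float(M, K) A, float(K, N) B"
print(bind_sizes(SIG, (3, 4), (4, 5)))          # {'M': 3, 'K': 4, 'N': 5}
print(bind_sizes(SIG, (100, 400), (400, 500)))  # {'M': 100, 'K': 400, 'N': 500}
```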

Multiple TC definitions in language

Suppose you want to define all of your TCs in one string and later use that string to run the different operations it defines. You can do so easily: define a lang variable that holds the TC definitions for all your operations. Every time you want to run a different operation, call tc.define on the lang variable, specify the name of the corresponding definition and get the TC layer for it. Below is an example of how to do this:

import tensor_comprehensions as tc
import torch
lang = """
def matmul(float(M, K) A, float(K, N) B) -> (C) {
    C(m, n) +=! A(m, r_k) * B(r_k, n)
}
def abs(float(M, N) A) -> (O1) {
    O1(m, n) = fabs(A(m, n))
}
"""
matmul = tc.define(lang, name="matmul")
mat1, mat2 = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda()
out = matmul(mat1, mat2)

# use a variable name that does not shadow Python's built-in abs()
abs_layer = tc.define(lang, name="abs")
A = torch.randn(3, 4).cuda()
out = abs_layer(A)

Note

We are working on better ways to use multiple TCs in one language string. The current behavior will likely change in the near future.
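Because each tc.define call selects one definition by name, it can help to list the names a string contains. The regex helper below is a hypothetical sketch, not part of TC; it assumes each definition starts with `def <name>(` as in the example above.

```python
import re


def tc_names(lang):
    """Hypothetical helper: list the names of all `def` blocks in a TC string."""
    return re.findall(r"\bdef\s+([A-Za-z_]\w*)\s*\(", lang)


lang = """
def matmul(float(M, K) A, float(K, N) B) -> (C) {
    C(m, n) +=! A(m, r_k) * B(r_k, n)
}
def abs(float(M, N) A) -> (O1) {
    O1(m, n) = fabs(A(m, n))
}
"""
print(tc_names(lang))  # ['matmul', 'abs']

# With TC installed, every operation could then be defined in one pass, e.g.:
# layers = {name: tc.define(lang, name=name) for name in tc_names(lang)}
```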

Writing layers with scalars

If an operation requires a constant scalar value for bounds inference (for example, the kernel size or stride of a convolution), you need to substitute the scalar value into the TC before passing it, because scalars are not yet supported for bounds inference. The substitution can be done in two ways; use whichever is more convenient.

  • Option 1: Pass a constants dictionary to the tc.define call, as in the example below:

Warning

This way of using scalars is a stop-gap solution while we work on a better way of handling scalars for bounds inference. It will likely change within about a month.

import tensor_comprehensions as tc
import torch
lang = """
def avgpool(float(B, C, H, W) input) -> (output) {{
    output(b, c, h, w) +=! input(b, c, h * {sH} + r_kh, w * {sW} + r_kw) / ({kH} * {kW})
        where r_kh in 0:{kH}, r_kw in 0:{kW}
}}
"""
avgpool = tc.define(lang, name="avgpool", constants={"sH":1, "sW":1, "kH":2, "kW":2})
inp = torch.ones(32, 3, 10, 10).cuda()
out = avgpool(inp)

Note

In Python string formatting, {...} marks a placeholder, so literal braces must be written as {{...}}. This is why the example above doubles the braces around the TC body. You only need to do this if your TC contains scalar placeholders.

  • Option 2: Substitute the scalar values into the string with Python regular expressions (re.sub), as in the example below:
import tensor_comprehensions as tc
import torch
import re
LANG="""
def avgpool(float(B, C, H, W) input) -> (output) {
    output(b, c, h, w) +=! input(b, c, h * <sH> + r_kh, w * <sW> + r_kw) / (<kH> * <kW>)
        where r_kh in 0:<kH>, r_kw in 0:<kW>
}
"""
sH, sW, kH, kW = 1, 1, 2, 2
# placeholders are case-sensitive: the pattern must match '<sH>' exactly
LANG = re.sub('<sH>', str(sH), LANG)
LANG = re.sub('<sW>', str(sW), LANG)
LANG = re.sub('<kH>', str(kH), LANG)
LANG = re.sub('<kW>', str(kW), LANG)
avgpool = tc.define(LANG, name="avgpool")
inp = torch.ones(1, 1, 4, 4).cuda()
out = avgpool(inp)
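As a quick check that both options produce the same TC text, the snippet below performs both substitutions in plain Python, without TC. It assumes that tc.define applies str.format(**constants) to the string when constants is given, which is what the doubled braces of Option 1 suggest; the template names BRACES and ANGLES are illustrative.

```python
import re

# Option 1 template: literal braces doubled for str.format.
BRACES = """
def avgpool(float(B, C, H, W) input) -> (output) {{
    output(b, c, h, w) +=! input(b, c, h * {sH} + r_kh, w * {sW} + r_kw) / ({kH} * {kW})
        where r_kh in 0:{kH}, r_kw in 0:{kW}
}}
"""

# Option 2 template: <name> placeholders, braces left as-is.
ANGLES = """
def avgpool(float(B, C, H, W) input) -> (output) {
    output(b, c, h, w) +=! input(b, c, h * <sH> + r_kh, w * <sW> + r_kw) / (<kH> * <kW>)
        where r_kh in 0:<kH>, r_kw in 0:<kW>
}
"""

constants = {"sH": 1, "sW": 1, "kH": 2, "kW": 2}

via_format = BRACES.format(**constants)
via_regex = ANGLES
for key, value in constants.items():
    # re.sub is case-sensitive, so the key must be spelled exactly as in the template.
    via_regex = re.sub("<%s>" % key, str(value), via_regex)

print(via_format == via_regex)  # True
```

Since the placeholders contain no regex metacharacters, str.replace would work equally well for Option 2.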

Built-in Functions

TC also allows using some CUDA built-in functions when defining the TC language. During execution, the corresponding CUDA functions are called. For example, suppose we want to use the CUDA fmax function in our TC language. This can be done as follows:

import tensor_comprehensions as tc
import torch
LANG = """
def relu(float(B,M) I) -> (O1){
  O1(b, m) = fmax(I(b, m), 0)
}
"""
relu = tc.define(LANG, name="relu")
inp = torch.randn(100, 128).cuda()
out = relu(inp)
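One consequence of using fmax is its NaN handling: in C and CUDA, fmax returns the other argument when exactly one argument is NaN, so this relu maps NaN inputs to 0. The pure-Python sketch below mirrors those semantics as a CPU reference for checking results; fmax and relu_reference are hypothetical helper names.

```python
import math


def fmax(x, y):
    """fmax with C/CUDA NaN handling: if one argument is NaN, return the other."""
    if math.isnan(x):
        return y
    if math.isnan(y):
        return x
    return max(x, y)


def relu_reference(I):
    """CPU reference for O1(b, m) = fmax(I(b, m), 0) on a list of rows."""
    return [[fmax(v, 0.0) for v in row] for row in I]


print(relu_reference([[-1.0, 2.5], [float("nan"), 0.0]]))  # [[0.0, 2.5], [0.0, 0.0]]
```

Python's built-in max is not a drop-in replacement here, because its result when one argument is NaN depends on the argument order.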

TC supports only a subset of the CUDA built-in functions. You can find the documentation for these functions in the official CUDA documentation here. The functions supported in TC are:

acos, acosh, asin, asinh, atan2, atan, atanh, cbrt, ceil, copysign, cos, cosh, cospi, cyl_bessel_i0, cyl_bessel_i1, erfc, erfcinv, erfcx, erf, erfinv, exp10, exp2, exp, expm1, fabs, fdim, fdivide, floor, fma, fmax, fmin, fmod, hypot, j0, j1, lgamma, log10, log1p, log2, logb, log, nextafter, normf, norm3d, norm4d, normcdf, normcdfinv, pow, rcbrt, remainder, rhypot, rnorm3d, rnorm4d, round, rsqrt, sin, sinh, sinpi, sqrt, tan, tanh, tgamma, trunc, y0, y1
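Before passing a definition to tc.define, a rough check can flag calls to functions outside the list above. The sketch below is a heuristic written for this page, not TC's parser: it handles a single def, treats identifiers directly followed by "(" as calls, and assumes tensors are declared as in the examples here (float(...) Name inputs and a -> (...) output list). unsupported_calls and SUPPORTED_BUILTINS are hypothetical names.

```python
import re

# The built-in functions listed above as supported by TC.
SUPPORTED_BUILTINS = frozenset("""
acos acosh asin asinh atan2 atan atanh cbrt ceil copysign cos cosh cospi
cyl_bessel_i0 cyl_bessel_i1 erfc erfcinv erfcx erf erfinv exp10 exp2 exp expm1
fabs fdim fdivide floor fma fmax fmin fmod hypot j0 j1 lgamma log10 log1p log2
logb log nextafter normf norm3d norm4d normcdf normcdfinv pow rcbrt remainder
rhypot rnorm3d rnorm4d round rsqrt sin sinh sinpi sqrt tan tanh tgamma trunc y0 y1
""".split())


def unsupported_calls(tc_def):
    """Heuristic: names called in one TC def that are neither tensors nor supported built-ins."""
    header, body = tc_def.split("{", 1)
    # Input tensor names follow their type, e.g. "float(B,M) I" -> "I".
    tensors = set(re.findall(r"\)\s*([A-Za-z_]\w*)", header))
    # Output tensor names are listed after "->", e.g. "-> (O1)".
    outputs = re.search(r"->\s*\(([^)]*)\)", header)
    if outputs:
        tensors.update(name.strip() for name in outputs.group(1).split(","))
    called = re.findall(r"\b([A-Za-z_]\w*)\s*\(", body)
    return sorted({name for name in called
                   if name not in tensors and name not in SUPPORTED_BUILTINS})


RELU = """
def relu(float(B,M) I) -> (O1){
  O1(b, m) = fmax(I(b, m), 0)
}
"""
BAD = """
def sigmoid_tc(float(N) X) -> (Y) {
    Y(n) = sigmoid(X(n))
}
"""
print(unsupported_calls(RELU))  # []
print(unsupported_calls(BAD))   # ['sigmoid']
```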