Autograd with TC
To create a torch.autograd function backed by TC, use the make_autograd() helper function. Note that backward computations must be provided explicitly as TC functions.
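The backward definitions in the example below follow from the chain rule applied to the forward definition. Writing the forward computation as (a sketch, using the same index names as the TC code)

\[
O(n,m,h,w) = \sum_{c,\,kh,\,kw} I(n,c,h{+}kh,w{+}kw)\, W1(m,c,kh,kw),
\]

differentiating a loss L with respect to an input element gives

\[
\frac{\partial L}{\partial I(n,c,h,w)} = \sum_{m,\,kh,\,kw} \frac{\partial L}{\partial O(n,m,h{-}kh,w{-}kw)}\, W1(m,c,kh,kw),
\]

which is exactly what convolution_igrad computes, with d_O standing for ∂L/∂O; convolution_wgrad is obtained the same way by differentiating with respect to W1.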
conv = """ def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) { O(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw) } def convolution_igrad(float(M,C,KH,KW) W1, float(N,M,H,W) d_O) -> (d_I) { d_I(n, c, h, w) +=! d_O( n, r_m, h - r_kh, w - r_kw) * W1(r_m, c, r_kh, r_kw) } def convolution_wgrad(float(N,C,H,W) I, float(N,M,H,W) d_O) -> (d_W1) { d_W1(m, c, kh, kw) +=! d_O(r_n, m, r_h - kh, r_w - kw) * I(r_n, c, r_h, r_w) } """ N, C, H, W, O, kH, kW = 32, 4, 56, 56, 16, 1, 1 T = tc.define( conv, tc.make_autotuned_options_factory( starting_options='naive', tuner_config=tuner_config)) I, W = ( torch.randn(N, C, H, W, device='cuda', requires_grad=True), torch.randn(O, C, kH, kW, device='cuda', requires_grad=True)) def convolution_backward(I, W, d_O): d_I = T.convolution_igrad(W, d_O) d_O = T.convolution_wgrad(I, d_O) return (d_I, d_O) convolution_function = tc.make_autograd( T.convolution, convolution_backward) # First occurrence triggers tuning out = convolution_function(I, W) out.sum().backward() # Subsequent occurrences do not out = convolution_function(I, W) out.sum().backward()