|
Tensor Comprehensions
|
Go to the documentation of this file.
/// Name of the TC function defined by the TC_BATCHNORM source text
/// (the identifier that follows `def` in that string).
static constexpr auto TC_BATCHNORM_NAME = "spatialBatchNorm";
/// TC source text for the `spatialBatchNorm` definition.
///
/// Declared inputs: scalars `momentum` and `eps`, a 4-D tensor `I` with
/// sizes (N, C, H, W), and the size-C vectors `rMeanIn` and `rVarIn`.
/// Declared outputs: O, rMeanOut, rVarOut, mean, centered, variance,
/// expectedVariance, normalizedOut.
///
/// Statements, in order:
///  - `mean(c)`: accumulates I over the indices nn, hh, ww, then is divided
///    by N * H * W.
///  - `rMeanOut(c)`: (1 - momentum) * rMeanIn(c) + momentum * mean(c).
///  - `centered`: I minus rMeanOut(c) -- i.e. the blended mean is
///    subtracted, not `mean(c)`.
///  - `variance`: element-wise square of `centered`.
///  - `expectedVariance(c)`: accumulates (variance + eps) / (N * H * W)
///    over n, h, w.
///  - `rVarOut(c)`: rsqrt of
///    (1 - momentum) * rVarIn(c) + momentum * expectedVariance(c), so it
///    holds a reciprocal square root rather than a variance.
///  - `O`: centered * rVarOut(c); `normalizedOut` is assigned from O.
///
/// NOTE(review): `+=!` is presumably TC's initialize-then-accumulate
/// reduction operator -- confirm against the TC language reference.
///
/// The text is a single raw string literal delimited by `R"TC(` and `)TC"`.
/// Everything between the delimiters is part of the string, so a C++ comment
/// placed inside it would become part of the TC text.
static constexpr auto TC_BATCHNORM = R"TC(
def spatialBatchNorm(
float momentum, float eps,
float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn)
-> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut)
{
mean(c) +=! I(nn, c, hh, ww)
mean(c) = mean(c) / (N * H * W)
rMeanOut(c) = (1 - momentum) * rMeanIn(c) + momentum * mean(c)
centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c)
variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w)
expectedVariance(c) +=! (variance(n, c, h, w) + eps) / (N * H * W)
rVarOut(c) = rsqrt(
(1 - momentum) * rVarIn(c) + momentum * expectedVariance(c))
O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c)
normalizedOut(n, c, h, w) = O(n, c, h, w)
})TC";