Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
batchnorm.h
Go to the documentation of this file.
1 
16 #pragma once
17 
namespace tc {

/// Name of the TC entry point defined in TC_BATCHNORM; it matches the
/// `def spatialBatchNorm(...)` header at the start of that program text.
static constexpr const char* TC_BATCHNORM_NAME = "spatialBatchNorm";

/// Tensor Comprehensions source text for a spatial batch-normalization op.
///
/// Inputs: scalars `momentum` and `eps`, a 4-D tensor `I` with sizes
/// (N, C, H, W), and per-channel (size C) running statistics `rMeanIn` and
/// `rVarIn`.
///
/// Per channel c, the program computes:
///  - `mean`: `+=!` sum-reduction of `I` over n, h, w, then divided by N*H*W.
///  - `rMeanOut`: (1 - momentum) * rMeanIn + momentum * mean.
///  - `centered`: I minus `rMeanOut`. This is the blended running mean, not the
///    batch `mean`.
///  - `variance`: elementwise square of `centered`.
///  - `expectedVariance`: sum over n, h, w of (variance + eps) / (N*H*W).
///    Mathematically this is the mean of `variance` plus eps.
///  - `rVarOut`: rsqrt((1 - momentum) * rVarIn + momentum * expectedVariance).
///    It therefore holds an inverse standard deviation, not a variance.
///  - `O` and `normalizedOut`: centered * rVarOut (the two are identical).
///
/// NOTE(review): normalization uses the blended running statistics rather
/// than the batch statistics, and no scale/bias is applied. Confirm this is the
/// intended variant before comparing against a reference batch-norm.
///
/// The text is built from adjacent string literals purely for readability.
/// Each piece begins with a single space, so the concatenation is one line
/// whose statements are separated by single spaces.
static constexpr const char* TC_BATCHNORM =
    " def spatialBatchNorm("
    " float momentum, float eps,"
    " float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn)"
    " -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut)"
    " {"
    " mean(c) +=! I(nn, c, hh, ww)"
    " mean(c) = mean(c) / (N * H * W)"
    " rMeanOut(c) = (1 - momentum) * rMeanIn(c) + momentum * mean(c)"
    " centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c)"
    " variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w)"
    " expectedVariance(c) +=! (variance(n, c, h, w) + eps) / (N * H * W)"
    " rVarOut(c) = rsqrt("
    " (1 - momentum) * rVarIn(c) + momentum * expectedVariance(c))"
    " O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c)"
    " normalizedOut(n, c, h, w) = O(n, c, h, w)"
    " }";

} // namespace tc
25