Tensor Comprehensions
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
copy.h
Go to the documentation of this file.
1 
16 #pragma once
17 
18 #include "tc/library/common.h"
19 
20 namespace tc {
21 
/// User-facing description of the generic (any-rank) copy TC.
static constexpr const char* COPY_DOC = R"DOC( def copy(float(...) I) -> (O) { O(...) = I(...) } )DOC";

/// Name of the forward TC function; matches the `def copy` in COPY_TC.
static constexpr const char* COPY_TC_NAME = "copy";

/// Name of the gradient TC function; matches the `def copyGrad` in
/// COPY_GRAD_TC.
static constexpr const char* COPY_GRAD_TC_NAME = "copyGrad";
27 
namespace {
// Forward copy TC source template. The `${dimParams}` and `${dimIndices}`
// placeholders are filled in by setInputDims() (see makeCopyTc()).
// The anonymous namespace alone gives internal linkage; `static` would be
// redundant here.
constexpr const char* COPY_TC = R"TC( def copy(float(${dimParams}) I) -> (O) { O(${dimIndices}) = I(${dimIndices}) } )TC";

// Gradient of copy: the incoming gradient is copied through unchanged.
// Uses the same placeholders as COPY_TC (see makeCopyGradTc()).
constexpr const char* COPY_GRAD_TC = R"TC( def copyGrad(float(${dimParams}) O_grad) -> (I_grad) { I_grad(${dimIndices}) = O_grad(${dimIndices}) } )TC";
} // namespace
33 
34 std::string
35 setInputDims(std::string tcStr, int numDims, std::string paramPrefix) {
36  std::string dimParams, dimIndices;
37  for (int i = 0; i < numDims; i++) {
38  dimParams += paramPrefix + std::to_string(i);
39  dimIndices += "i" + std::to_string(i);
40  if (i < numDims - 1) {
41  dimParams += ",";
42  dimIndices += ",";
43  }
44  }
45  tcStr = replaceString(tcStr, "${dimParams}", dimParams);
46  tcStr = replaceString(tcStr, "${dimIndices}", dimIndices);
47  return tcStr;
48 }
49 
50 std::string makeCopyTc(int numDims) {
51  return setInputDims(COPY_TC, numDims, "P");
52 }
53 
54 std::string makeCopyGradTc(int numDims) {
55  return setInputDims(COPY_GRAD_TC, numDims, "PG");
56 }
57 
58 } // namespace tc
59 
std::string makeCopyTc(int numDims)
Definition: copy.h:58
std::string setInputDims(std::string tcStr, int numDims, std::string paramPrefix)
Definition: copy.h:43
std::string makeCopyGradTc(int numDims)
Definition: copy.h:62
std::string replaceString(std::string str, const std::string &search, const std::string &replace)
Definition: common.h:20