1"""Code generator for Code Completion Model Inference.
3Tool runs on the Decision Forest model defined in {model} directory.
4It generates two files: {output_dir}/{filename}.h and {output_dir}/{filename}.cpp
The generated files define the Example class named {cpp_class} having all the features as class members.
6The generated runtime provides an `Evaluate` function which can be used to score a code completion candidate.
class CppClass:
    """Holds class name and names of the enclosing namespaces."""

    def __init__(self, cpp_class):
        """Splits a (possibly namespace-qualified) C++ class name.

        Args:
            cpp_class: Name such as "ns1::ns2::Example". Everything before
                the last "::" is treated as enclosing namespaces.

        Raises:
            ValueError: If the component after the last "::" is empty.
        """
        ns_and_class = cpp_class.split("::")
        # Empty components (e.g. produced by a leading "::") are dropped.
        self.ns = [ns for ns in ns_and_class[:-1] if ns]
        self.name = ns_and_class[-1]
        if not self.name:
            raise ValueError("Empty class name.")

    def ns_begin(self):
        """Returns snippet for opening namespace declarations."""
        return "\n".join("namespace %s {" % ns for ns in self.ns)

    def ns_end(self):
        """Returns snippet for closing namespace declarations."""
        # Namespaces are closed innermost-first, hence the reversal.
        return "\n".join("} // namespace %s" % ns for ns in reversed(self.ns))
def header_guard(filename):
    """Returns the header guard for the generated header.

    Args:
        filename: Base name of the generated header (without extension).

    Returns:
        A macro name of the form GENERATED_DECISION_FOREST_MODEL_<NAME>_H.
    """
    import re

    # A preprocessor macro name may only contain letters, digits and
    # underscores, so characters such as '-' or '.' are mapped to '_'.
    macro_suffix = re.sub(r"[^A-Z0-9_]", "_", filename.upper())
    return "GENERATED_DECISION_FOREST_MODEL_%s_H" % macro_suffix
def boost_node(n, label, next_label):
    """Returns code snippet for a leaf/boost node.

    Args:
        n: Node dictionary; its "score" entry is the value returned by the
            generated statement.
        label: Label placed in front of the generated statement.
        next_label: Unused; accepted so that every node generator can be
            called with the same arguments.

    Returns:
        A single line of the form "<label>: return <score>f;".
    """
    # float() guarantees the text has a decimal point or an exponent, so an
    # integral score such as 2 is emitted as "2.0f" rather than the invalid
    # literal "2f".
    return "%s: return %sf;" % (label, float(n["score"]))
46 """Returns code snippet for a if_greater node.
47 Jumps to true_label if the Example feature (NUMBER)
is greater than the threshold.
48 Comparing integers
is much faster than comparing floats. Assuming floating points
49 are represented
as IEEE 754, it order-encodes the floats to integers before comparing them.
50 Control falls through
if condition
is evaluated to false.
"""
51 threshold = n["threshold"]
52 return "%s: if (E.get%s() >= %s /*%s*/) goto %s;" % (
62 """Returns code snippet for a if_member node.
63 Jumps to true_label if the Example feature (ENUM)
is present
in the set of enum values
64 described
in the node.
65 Control falls through
if condition
is evaluated to false.
"""
67 [
"BIT(%s_type::%s)" % (n[
"feature"], member)
for member
in n[
"set"]]
69 return "%s: if (E.get%s() & (%s)) goto %s;" % (
def node(n, label, next_label):
    """Returns code snippet for the node.

    Dispatches on n["operation"] to the matching node generator and forwards
    all arguments unchanged.

    Args:
        n: Node dictionary with an "operation" entry.
        label: Label of the statement generated for this node.
        next_label: Label handed to the generator as its jump target.

    Raises:
        KeyError: If n["operation"] is not a known operation.
    """
    generators = {
        "boost": boost_node,
        "if_greater": if_greater_node,
        "if_member": if_member_node,
    }
    return generators[n["operation"]](n, label, next_label)
def tree(t, tree_num, node_num):
    """Returns code for inferencing a Decision Tree.

    Also returns the size of the decision tree.

    Every node of the tree is emitted under the label
    `t{tree_num}_n{node_num}`.

    The tree contains two types of node: Conditional node and Leaf node.
    - Conditional node evaluates a condition. If true, it jumps to the true
      node/child. Code is generated using pre-order traversal of the tree
      considering the false node ("else") as the first child. Therefore the
      false node is always the immediately next label, and the true node
      ("then") is numbered after the whole false subtree.
    - Leaf node ("boost" operation) is generated with the label of the next
      tree, `t{tree_num + 1}`, as its next label.

    Args:
        t: Node dictionary of the (sub)tree root.
        tree_num: Index of the tree inside the forest.
        node_num: Number assigned to the root of this (sub)tree.

    Returns:
        A tuple (code, size): the list of generated snippets, one entry per
        node in pre-order, and the number of nodes in the (sub)tree.
    """
    label = "t%d_n%d" % (tree_num, node_num)

    if t["operation"] == "boost":
        leaf = node(t, label=label, next_label="t%d" % (tree_num + 1))
        return [leaf], 1

    false_code, false_size = tree(
        t["else"], tree_num=tree_num, node_num=node_num + 1
    )

    # The true subtree is numbered right after all nodes of the false subtree.
    true_node_num = node_num + false_size + 1
    true_label = "t%d_n%d" % (tree_num, true_node_num)

    true_code, true_size = tree(
        t["then"], tree_num=tree_num, node_num=true_node_num
    )

    code = [node(t, label=label, next_label=true_label)]

    return code + false_code + true_code, 1 + false_size + true_size
120 """Returns code for header declaring the inference runtime.
122 Declares the Example class named {cpp_class} inside relevant namespaces.
123 The Example
class contains all the features as class members. This
124 class can be used to represent a code completion candidate.
125 Provides `float Evaluate()` function which can be used to score the Example.
129 for f
in features_json:
132 if f[
"kind"] ==
"NUMBER":
135 "void set%s(float V) { %s = OrderEncode(V); }" % (feature, feature)
137 elif f[
"kind"] ==
"ENUM":
139 "void set%s(unsigned V) { %s = 1LL << V; }" % (feature, feature)
142 raise ValueError(
"Unhandled feature type.", f[
"kind"])
146 "uint%d_t %s = 0;" % (64
if f[
"kind"] ==
"ENUM" else 32, f[
"name"])
147 for f
in features_json
150 "LLVM_ATTRIBUTE_ALWAYS_INLINE uint%d_t get%s() const { return %s; }"
151 % (64
if f[
"kind"] ==
"ENUM" else 32, f[
"name"], f[
"name"])
152 for f
in features_json
159#include "llvm/Support/Compiler.h"
173 // Produces an integer that sorts in the same order
as F.
174 // That
is: a < b <==> orderEncode(a) < orderEncode(b).
175 static uint32_t OrderEncode(float F);
178float Evaluate(const %s&);
184 cpp_class.ns_begin(),
188 nline.join(class_members),
196 i = struct.unpack(
"<I", struct.pack(
"<f", v))[0]
205 """Generates evaluation functions for each tree and combines them in
206 `float Evaluate(const {Example}&)` function. This function can be
207 used to score an Example."""
212 code +=
"namespace {\n"
214 for tree_json
in forest_json:
215 code +=
"LLVM_ATTRIBUTE_NOINLINE float EvaluateTree%d(const %s& E) {\n" % (
220 " " +
"\n ".join(
tree(tree_json, tree_num=tree_num, node_num=0)[0]) +
"\n"
224 code +=
"} // namespace\n\n"
228 code +=
"float Evaluate(const %s& E) {\n" % cpp_class.name
229 code +=
" float Score = 0;\n"
230 for tree_num
in range(len(forest_json)):
231 code +=
" Score += EvaluateTree%d(E);\n" % tree_num
232 code +=
" return Score;\n"
239 """Generates code for the .cpp file."""
242 angled_include = [
"#include <%s>" % h
for h
in [
"cstring",
"limits"]]
245 qouted_headers = {filename +
".h",
"llvm/ADT/bit.h"}
247 qouted_headers |= {f[
"header"]
for f
in features_json
if f[
"kind"] ==
"ENUM"}
248 quoted_include = [
'#include "%s"' % h
for h
in sorted(qouted_headers)]
251 using_decls =
"\n".join(
252 "using %s_type = %s;" % (feature[
"name"], feature[
"type"])
253 for feature
in features_json
254 if feature[
"kind"] ==
"ENUM"
261#define BIT(X) (1LL << X)
267uint32_t %s::OrderEncode(float F) {
268 static_assert(std::numeric_limits<float>::is_iec559, "");
269 constexpr uint32_t TopBit = ~(~uint32_t{0} >> 1);
271 // Get the bits of the float. Endianness
is the same
as for integers.
272 uint32_t U = llvm::bit_cast<uint32_t>(F);
273 std::memcpy(&U, &F, sizeof(U));
274 // IEEE 754 floats compare like sign-magnitude integers.
275 if (U & TopBit) // Negative float.
276 return 0 - U; // Map onto the low half of integers, order reversed.
277 return U + TopBit; // Positive floats map onto the high half of integers.
283 nl.join(angled_include),
284 nl.join(quoted_include),
285 cpp_class.ns_begin(),
294 parser = argparse.ArgumentParser(
"DecisionForestCodegen")
295 parser.add_argument(
"--filename", help=
"output file name.")
296 parser.add_argument(
"--output_dir", help=
"output directory.")
297 parser.add_argument(
"--model", help=
"path to model directory.")
300 help=
"The name of the class (which may be a namespace-qualified) created in generated header.",
302 ns = parser.parse_args()
304 output_dir = ns.output_dir
305 filename = ns.filename
306 header_file =
"%s/%s.h" % (output_dir, filename)
307 cpp_file =
"%s/%s.cpp" % (output_dir, filename)
308 cpp_class =
CppClass(cpp_class=ns.cpp_class)
310 model_file =
"%s/forest.json" % ns.model
311 features_file =
"%s/features.json" % ns.model
313 with open(features_file)
as f:
314 features_json = json.load(f)
316 with open(model_file)
as m:
317 forest_json = json.load(m)
319 with open(cpp_file,
"w+t")
as output_cc:
322 forest_json=forest_json,
323 features_json=features_json,
329 with open(header_file,
"w+t")
as output_h:
332 features_json=features_json, cpp_class=cpp_class, filename=filename
337if __name__ ==
"__main__":
def __init__(self, cpp_class)
def header_guard(filename)
def evaluate_func(forest_json, cpp_class)
def gen_header_code(features_json, cpp_class, filename)
def gen_cpp_code(forest_json, features_json, filename, cpp_class)
def boost_node(n, label, next_label)
def tree(t, tree_num, node_num)
def node(n, label, next_label)
def if_member_node(n, label, next_label)
def if_greater_node(n, label, next_label)