-
Notifications
You must be signed in to change notification settings - Fork 1
/
m2cgen.py
159 lines (130 loc) · 5.15 KB
/
m2cgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright 2022 Cartesi Pte. Ltd.
#
# SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use
# this file except in compliance with the License. You may obtain a copy of the
# License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
from os import environ
import model
import json
import traceback
import logging
import requests
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
rollup_server = environ["ROLLUP_HTTP_SERVER_URL"]
logger.info(f"HTTP rollup_server url is {rollup_server}")
def hex2str(hex):
"""
Decodes a hex string into a regular string
"""
return bytes.fromhex(hex[2:]).decode("utf-8")
def str2hex(str):
"""
Encodes a string as a hex string
"""
return "0x" + str.encode("utf-8").hex()
def classify(input):
"""
Predicts a given input's classification using the model generated with m2cgen
"""
# computes the score from the input
score = model.score(input)
# interprets the score to retrieve the predicted class index
class_index = None
if isinstance(score, list):
class_index = score.index(max(score))
else:
if (score > 0):
class_index = 1
else:
class_index = 0
# returns the class specified by the predicted index
return model.classes[class_index]
def format(input):
"""
Transforms a given input so that it is in the format expected by the m2cgen model
"""
formatted_input = {}
for key in input.keys():
if key in model.columns:
# key in model: just copy the value
formatted_input[key] = input[key]
else:
# key not in model: it may need to be transformed due to One Hot Encoding
# - in OHE, there is a column for each possible key/value combination
# - a OHE column has value 1 if the entry contains the key/value combination
# - for each key, there is an extra column <key>_nan for unknown values
ohe_key = key + "_" + str(input[key])
ohe_key_unknown = key + "_nan"
if ohe_key in model.columns:
formatted_input[ohe_key] = 1
else:
formatted_input[ohe_key_unknown] = 1
# builds output as a list/array with one entry for each column in the model
output = []
for column in model.columns:
if column in formatted_input:
# uses known value for columns present in input
output.append(formatted_input[column])
else:
# uses value 0 for columns not present in the input (all other OHE columns)
output.append(0)
return output
def handle_advance(data):
logger.info(f"Received advance request data {data}")
status = "accept"
try:
# retrieves input as string
input = hex2str(data["payload"])
logger.info(f"Received input: '{input}'")
# converts input to the format expected by the m2cgen model
input_json = json.loads(input)
input_formatted = format(input_json)
# computes predicted classification for input
predicted = classify(input_formatted)
logger.info(f"Data={input}, Predicted: {predicted}")
# emits output notice with predicted class name
output = str2hex(str(predicted))
logger.info(f"Adding notice with payload: {predicted}")
response = requests.post(
rollup_server + "/notice", json={"payload": output})
logger.info(
f"Received notice status {response.status_code} body {response.content}")
except Exception as e:
status = "reject"
msg = f"Error processing data {data}\n{traceback.format_exc()}"
logger.error(msg)
response = requests.post(
rollup_server + "/report", json={"payload": str2hex(msg)})
logger.info(
f"Received report status {response.status_code} body {response.content}")
return status
def handle_inspect(data):
logger.info(f"Received inspect request data {data}")
logger.info("Adding report")
response = requests.post(rollup_server + "/report",
json={"payload": data["payload"]})
logger.info(f"Received report status {response.status_code}")
return "accept"
handlers = {
"advance_state": handle_advance,
"inspect_state": handle_inspect,
}
finish = {"status": "accept"}
while True:
logger.info("Sending finish")
response = requests.post(rollup_server + "/finish", json=finish)
logger.info(f"Received finish status {response.status_code}")
if response.status_code == 202:
logger.info("No pending rollup request, trying again")
else:
rollup_request = response.json()
data = rollup_request["data"]
handler = handlers[rollup_request["request_type"]]
finish["status"] = handler(rollup_request["data"])