"""Reader for training log.
See lib/Analysis/TrainingLogger.cpp for a description of the format.
"""
import ctypes
import dataclasses
import json
import math
import sys
import typing

_element_types = {
    'float': ctypes.c_float,
    'double': ctypes.c_double,
    'int8_t': ctypes.c_int8,
    'uint8_t': ctypes.c_uint8,
    'int16_t': ctypes.c_int16,
    'uint16_t': ctypes.c_uint16,
    'int32_t': ctypes.c_int32,
    'uint32_t': ctypes.c_uint32,
    'int64_t': ctypes.c_int64,
    'uint64_t': ctypes.c_uint64
}


@dataclasses.dataclass(frozen=True)
class TensorSpec:
    name: str
    port: int
    shape: list[int]
    element_type: type

    @staticmethod
    def from_dict(d: dict):
        name = d['name']
        port = d['port']
        shape = [int(e) for e in d['shape']]
        element_type_str = d['type']
        if element_type_str not in _element_types:
            raise ValueError(f'unknown type: {element_type_str}')
        return TensorSpec(
            name=name,
            port=port,
            shape=shape,
            element_type=_element_types[element_type_str])
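

# A hedged sketch of constructing a spec from a header entry; the concrete
# values below are invented for illustration, but the keys are exactly what
# from_dict() reads:
#
#   spec = TensorSpec.from_dict(
#       {'name': 'reward', 'port': 0, 'shape': [1], 'type': 'float'})
#   assert spec.element_type is ctypes.c_float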
class TensorValue:

    def __init__(self, spec: TensorSpec, buffer: bytes):
        self._spec = spec
        self._buffer = buffer
        self._view = ctypes.cast(self._buffer,
                                 ctypes.POINTER(self._spec.element_type))
        self._len = math.prod(self._spec.shape)

    def spec(self) -> TensorSpec:
        return self._spec

    def __len__(self) -> int:
        return self._len

    def __getitem__(self, index):
        if index < 0 or index >= self._len:
            raise IndexError(f'Index {index} out of range [0..{self._len})')
        return self._view[index]
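

# Minimal usage sketch for TensorValue (struct is used only in this example and
# is not imported by the module itself):
#
#   import struct
#   spec = TensorSpec(name='x', port=0, shape=[3], element_type=ctypes.c_float)
#   tv = TensorValue(spec, struct.pack('3f', 1.0, 2.0, 3.0))
#   assert len(tv) == 3 and tv[1] == 2.0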


def read_tensor(fs: typing.BinaryIO, ts: TensorSpec) -> TensorValue:
    size = math.prod(ts.shape) * ctypes.sizeof(ts.element_type)
    data = fs.read(size)
    return TensorValue(ts, data)


def pretty_print_tensor_value(tv: TensorValue):
    print(f'{tv.spec().name}: {",".join([str(v) for v in tv])}')
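

# The parsing in read_stream() below implies the following layout; this is an
# inference from the reader, not an authoritative spec (see
# lib/Analysis/TrainingLogger.cpp for the definitive format):
#
#   * line 1: a JSON header with a 'features' list of tensor specs and an
#     optional 'score' spec.
#   * then, per event, one JSON line that is either {'context': <name>} or
#     {'observation': <id>}. An observation line is followed by the raw bytes
#     of every feature tensor and a terminating newline; if a score spec is
#     present, a JSON {'outcome': <id>} line, the raw score tensor bytes, and
#     another newline follow.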
def read_stream(fname: str):
    with open(fname, 'rb') as f:
        header = json.loads(f.readline())
        tensor_specs = [TensorSpec.from_dict(ts) for ts in header['features']]
        score_spec = TensorSpec.from_dict(
            header['score']) if 'score' in header else None
        context = None
        while event_str := f.readline():
            event = json.loads(event_str)
            if 'context' in event:
                context = event['context']
                continue
            observation_id = int(event['observation'])
            features = []
            for ts in tensor_specs:
                features.append(read_tensor(f, ts))
            # Skip the newline that terminates the raw feature tensor data.
            f.readline()
            score = None
            if score_spec is not None:
                score_header = json.loads(f.readline())
                assert int(score_header['outcome']) == observation_id
                score = read_tensor(f, score_spec)
                # Skip the newline that terminates the raw score tensor data.
                f.readline()
            yield context, observation_id, features, score
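

# A hedged consumption sketch (distinct from main() below): collect the score
# elements per context. Nothing here is part of the module's API; it only uses
# read_stream() as defined above, and the file name is made up:
#
#   scores_by_context = {}
#   for ctx, _obs_id, _features, score in read_stream('training.log'):
#       if score is not None:
#           scores_by_context.setdefault(ctx, []).append(list(score))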


def main(args):
    last_context = None
    for ctx, obs_id, features, score in read_stream(args[1]):
        if last_context != ctx:
            print(f'context: {ctx}')
            last_context = ctx
        print(f'observation: {obs_id}')
        for fv in features:
            pretty_print_tensor_value(fv)
        if score:
            pretty_print_tensor_value(score)


if __name__ == '__main__':
    main(sys.argv)