diff --git a/main.py b/main.py index 4c0192d..c49ffb7 100644 --- a/main.py +++ b/main.py @@ -50,7 +50,7 @@ def parse_register_fields(): fields = [ r'name: \w+', r'bits: \d+(:\d+)?', - r'type: \w+', + r'access: \w+', f'reset: {ADDRESS_PATTERN}', r'description: .+' ] @@ -129,9 +129,9 @@ def interpret_field(field): if m.group(3) is not None: low = int(m.group(3)) result = {} - result['name'] = field['name'][0] + result['name'] = field['name'][0].strip() result['bits'] = [low, high] - result['type'] = field['type'][0] + result['access'] = field['access'][0].strip() return result def interpret(reg): @@ -140,11 +140,11 @@ def interpret(reg): result = {} name_pattern = None expected = [ - ['Name', 'name', lambda v: ''.join(v)], + ['Name', 'name', lambda v: ''.join(v).strip()], ['Relative Address', 'rel', lambda v: int(v[0], 16)], ['Absolute Address', 'abs', lambda v: int(v[0], 16)], ['Width', 'width', lambda v: int(width_pattern.fullmatch(v[0]).group(1))], - ['Access Type', 'access', lambda v: v[0]], + ['Access Type', 'access', lambda v: v[0].strip()], ['Reset Value', 'reset', lambda v: v[0]], ['Description', 'description', lambda v: ' '.join(v)] ] @@ -180,6 +180,67 @@ def snake_to_camel(name: str): result.append(c.lower()) return ''.join(result) +def access_to_type(access: str): + access = access.upper() + if access in ['RO', 'RW', 'WO']: + return access + elif access in ['MIXED', 'WTC']: + return 'RW' + raise ValueError(access) + +def fields_to_rust(reg): + fields = reg['fields'] + name_pattern = re.compile(r'(.+\w)\d+') + access = access_to_type(reg['access']) + if fields == []: + return (f'{access}', []) + if len(fields) == 1: + bits = fields[0]['bits'] + if bits[1] - bits[0] + 1 == reg['width']: + return (f'{access_to_type(reg["access"])}', []) + namespace = reg['name'].lower() + name = snake_to_camel(reg['name']) + if 'similar' in reg: + # remove the trailing digits + name = name_pattern.fullmatch(name).group(1) + namespace = name_pattern.fullmatch(namespace).group(1) + bitmask = 0 + has_wtc = False + lines = [] + for f in fields: + field_name = f['name'].lower() + field_access = f['access'].upper() + [low, high] = f['bits'] + assert low <= high + if field_access == 'WTC': + has_wtc = True + else: + for i in range(high - low + 1): + bitmask |= 1 << (i + low) + if low == high: + wtc = ', WTC' if field_access == 'WTC' else '' + lines.append(f'register_bit!({namespace}, {field_name}, {low}{wtc});') + else: + value_range = '' + if high - low < 8: + value_range = 'u8' + elif high - low < 32: + value_range = 'u32' + else: + raise ValueError([low, high]) + if field_access == 'WTC': + # we did not implement WTC for multiple bits, but could be done + raise ValueError() + lines.append(f'register_bits!({namespace}, {field_name},' + f'{value_range}, {low}, {high});') + if has_wtc: + lines.insert(0, f'register!({namespace}, {name}, {access},' + f' u{reg["width"]}, {bitmask});') + else: + lines.insert(0, f'register!({namespace}, {name}, {access},' + f' u{reg["width"]});') + return (name, lines) + def emit_rust(base_addr, ending_addr, registers): current_addr = base_addr reserved_id = 0 @@ -225,4 +286,8 @@ for line in sys.stdin: v = end_iterator(parser) for reg in v: reg = interpret(reg) + (name, lines) = fields_to_rust(reg) print(reg) + print(name) + print(lines) + print('----')