#!/usr/bin/env python3

import sys
import re

def entry_is_device(entry):
    """Tell whether a parsed PFN entry belongs to the device-level loader.

    ``entry`` is a ``(name, params)`` pair, where ``params`` is the raw
    parenthesised parameter list captured from the header, for example
    ``'(VkDevice device, const char* pName)'``.  The entry is device-level
    when the first whitespace-delimited word after the opening parenthesis
    is VkDevice, VkCommandBuffer or VkQueue.  vkGetDeviceProcAddr is always
    excluded: the generated device-symbol loader calls it, so it has to be
    resolved through the instance instead.
    """
    name, params = entry[0], entry[1]
    # Skip the leading '(' and keep the text up to the first space.
    leading_type = params[1:].split(' ', 1)[0]
    if leading_type not in ('VkDevice', 'VkCommandBuffer', 'VkQueue'):
        return False
    return name != 'vkGetDeviceProcAddr'

# Entry points that the generated vulkan_symbol_wrapper_load_global_symbols()
# resolves with a NULL instance.
_PURE_ENTRYPOINTS = frozenset([
    'vkCreateInstance',
    'vkEnumerateInstanceExtensionProperties',
    'vkEnumerateInstanceLayerProperties',
])

# Extension entry points whose names contain any of these substrings get no
# wrapper at all.  Presumably this is because their PFN typedefs only exist
# when the matching VK_USE_PLATFORM_* macro is defined -- TODO confirm.
_PLATFORM_TAGS = ('Android', 'Xlib', 'Xcb', 'Win32', 'Wayland', 'Mir')

# Intended to match lines shaped like
#   typedef VkResult (VKAPI_PTR *PFN_vkCreateInstance)(...);
# Group 1 is the command name; group 2 is the parenthesised parameter list.
# This is a raw string because '\S' and '\)' are invalid escape sequences in
# a plain literal (SyntaxWarning since Python 3.12).
_PFN_TYPEDEF_RE = re.compile(r'typedef \S+.*PFN_([^\)]+)\)(.*);$')

# A trailing run of two or more capitals after a lowercase letter or digit is
# treated as a vendor/extension tag (KHR, EXT, AMD, NV, NVX, KHX, GOOGLE...).
# Per Vulkan naming conventions, core command names are not expected to end
# this way.
_VENDOR_SUFFIX_RE = re.compile(r'[a-z0-9][A-Z]{2,}$')


def _is_extension(name):
    """Return True when *name* ends in a vendor/extension tag."""
    return _VENDOR_SUFFIX_RE.search(name) is not None


def _parse_header(lines):
    """Split the PFN_* typedefs found in *lines* into three lists.

    Returns ``(pure_entrypoints, entrypoints, extensions)``.  Each list holds
    ``(name, params)`` tuples in header order:

    - ``pure_entrypoints`` holds the names in _PURE_ENTRYPOINTS.
    - ``entrypoints`` holds every other core command except
      vkGetInstanceProcAddr, which the caller supplies to
      vulkan_symbol_wrapper_init().
    - ``extensions`` holds vendor-suffixed commands, minus the
      platform-specific ones.

    Commands taking ``(void)`` are ignored entirely.
    """
    pure_entrypoints = []
    entrypoints = []
    extensions = []
    for line in lines:
        m = _PFN_TYPEDEF_RE.search(line)
        if not m:
            continue
        name, params = m.group(1), m.group(2)
        if params == '(void)':
            continue
        if _is_extension(name):
            # Previously only KHR/EXT were recognised here.  Other vendor
            # tags were misfiled as core, which made the core loaders fail
            # and bypassed the platform filter below.
            if not any(tag in name for tag in _PLATFORM_TAGS):
                extensions.append((name, params))
        elif name in _PURE_ENTRYPOINTS:
            pure_entrypoints.append((name, params))
        elif name != 'vkGetInstanceProcAddr':
            entrypoints.append((name, params))
    return pure_entrypoints, entrypoints, extensions


def _write_header(path, pure_entrypoints, entrypoints, extensions):
    """Write the wrapper header: extern pointers, redirect macros, loader API."""
    with open(path, 'w') as f:
        print('''
/* This header is autogenerated by vulkan_loader_generator.py */
#ifndef VULKAN_SYMBOL_WRAPPER_H
#define VULKAN_SYMBOL_WRAPPER_H
#define VK_NO_PROTOTYPES
#include <vulkan/vulkan.h>

#ifdef __cplusplus
extern "C" {
#endif
''', file=f)

        # For every entry point (pure, then core, then extension), emit one
        # extern pointer plus a #define that redirects the plain Vulkan name
        # to it.
        for entry in pure_entrypoints + entrypoints + extensions:
            s = entry[0]
            print('extern PFN_{} vulkan_symbol_wrapper_{};'.format(s, s), file=f)
            print('#define {} vulkan_symbol_wrapper_{}'.format(s, s), file=f)

        print('''
void vulkan_symbol_wrapper_init(PFN_vkGetInstanceProcAddr get_instance_proc_addr);
PFN_vkGetInstanceProcAddr vulkan_symbol_wrapper_instance_proc_addr(void);
VkBool32 vulkan_symbol_wrapper_load_global_symbols(void);
VkBool32 vulkan_symbol_wrapper_load_core_instance_symbols(VkInstance instance);
VkBool32 vulkan_symbol_wrapper_load_core_symbols(VkInstance instance);
VkBool32 vulkan_symbol_wrapper_load_core_device_symbols(VkDevice device);
VkBool32 vulkan_symbol_wrapper_load_instance_symbol(VkInstance instance, const char *name, PFN_vkVoidFunction *ppSymbol);
VkBool32 vulkan_symbol_wrapper_load_device_symbol(VkDevice device, const char *name, PFN_vkVoidFunction *ppSymbol);

#define VULKAN_SYMBOL_WRAPPER_LOAD_INSTANCE_SYMBOL(instance, name, pfn) vulkan_symbol_wrapper_load_instance_symbol(instance, name, (PFN_vkVoidFunction*) &(pfn))
#define VULKAN_SYMBOL_WRAPPER_LOAD_INSTANCE_EXTENSION_SYMBOL(instance, name) vulkan_symbol_wrapper_load_instance_symbol(instance, #name, (PFN_vkVoidFunction*) & name)
#define VULKAN_SYMBOL_WRAPPER_LOAD_DEVICE_SYMBOL(device, name, pfn) vulkan_symbol_wrapper_load_device_symbol(device, name, (PFN_vkVoidFunction*) &(pfn))
#define VULKAN_SYMBOL_WRAPPER_LOAD_DEVICE_EXTENSION_SYMBOL(device, name) vulkan_symbol_wrapper_load_device_symbol(device, #name, (PFN_vkVoidFunction*) & name)
''', file=f)

        print('''
#ifdef __cplusplus
}
#endif
#endif
''', file=f)


def _write_source(path, pure_entrypoints, entrypoints, extensions):
    """Write the wrapper source: pointer definitions and loader functions."""
    with open(path, 'w') as f:
        print('''
/* This header is autogenerated by vulkan_loader_generator.py */
#include "vulkan_symbol_wrapper.h"
''', file=f)

        for entry in pure_entrypoints + entrypoints + extensions:
            s = entry[0]
            print('PFN_{} vulkan_symbol_wrapper_{};'.format(s, s), file=f)

        print('''
static PFN_vkGetInstanceProcAddr GetInstanceProcAddr;
void vulkan_symbol_wrapper_init(PFN_vkGetInstanceProcAddr get_instance_proc_addr)
{
    GetInstanceProcAddr = get_instance_proc_addr;
}

PFN_vkGetInstanceProcAddr vulkan_symbol_wrapper_instance_proc_addr(void)
{
    return GetInstanceProcAddr;
}
''', file=f)

        print('''
VkBool32 vulkan_symbol_wrapper_load_instance_symbol(VkInstance instance, const char *name, PFN_vkVoidFunction *ppSymbol)
{
    *ppSymbol = GetInstanceProcAddr(instance, name);
    return *ppSymbol != NULL;
}''', file=f)

        print('''
VkBool32 vulkan_symbol_wrapper_load_device_symbol(VkDevice device, const char *name, PFN_vkVoidFunction *ppSymbol)
{
    *ppSymbol = vkGetDeviceProcAddr(device, name);
    return *ppSymbol != NULL;
}''', file=f)

        # Global entry points are resolved with a NULL instance.
        print('''
VkBool32 vulkan_symbol_wrapper_load_global_symbols(void)
{''', file=f)
        for pure in pure_entrypoints:
            print('    if (!VULKAN_SYMBOL_WRAPPER_LOAD_INSTANCE_SYMBOL(NULL, "{}", {})) return VK_FALSE;'.format(pure[0], pure[0]), file=f)
        print('    return VK_TRUE;', file=f)
        print('}', file=f)

        # All core entry points, each resolved through the instance.
        print('''
VkBool32 vulkan_symbol_wrapper_load_core_symbols(VkInstance instance)
{''', file=f)
        for entry in entrypoints:
            print('    if (!VULKAN_SYMBOL_WRAPPER_LOAD_INSTANCE_SYMBOL(instance, "{}", {})) return VK_FALSE;'.format(entry[0], entry[0]), file=f)
        print('    return VK_TRUE;', file=f)
        print('}', file=f)

        # Instance-level subset (vkGetDeviceProcAddr lands here).
        print('''
VkBool32 vulkan_symbol_wrapper_load_core_instance_symbols(VkInstance instance)
{''', file=f)
        for entry in entrypoints:
            if not entry_is_device(entry):
                print('    if (!VULKAN_SYMBOL_WRAPPER_LOAD_INSTANCE_SYMBOL(instance, "{}", {})) return VK_FALSE;'.format(entry[0], entry[0]), file=f)
        print('    return VK_TRUE;', file=f)
        print('}', file=f)

        # Device-level subset, resolved via vkGetDeviceProcAddr.
        print('''
VkBool32 vulkan_symbol_wrapper_load_core_device_symbols(VkDevice device)
{''', file=f)
        for entry in entrypoints:
            if entry_is_device(entry):
                print('    if (!VULKAN_SYMBOL_WRAPPER_LOAD_DEVICE_SYMBOL(device, "{}", {})) return VK_FALSE;'.format(entry[0], entry[0]), file=f)
        print('    return VK_TRUE;', file=f)
        print('}', file=f)


def main():
    """Generate the Vulkan symbol-wrapper header and source.

    Usage: vulkan_loader_generator.py VULKAN_H OUT_HEADER OUT_SOURCE

    Reads the PFN_* typedefs from VULKAN_H (argv[1]).  Writes the wrapper
    declarations to OUT_HEADER (argv[2]), and the pointer definitions plus
    loader functions to OUT_SOURCE (argv[3]).  Exits with a usage message
    when fewer than three arguments are given.
    """
    if len(sys.argv) < 4:
        sys.exit('usage: vulkan_loader_generator.py VULKAN_H OUT_HEADER OUT_SOURCE')

    # Decode explicitly rather than with the locale's default encoding.
    with open(sys.argv[1], 'r', encoding='utf-8') as f:
        pure_entrypoints, entrypoints, extensions = _parse_header(f)

    _write_header(sys.argv[2], pure_entrypoints, entrypoints, extensions)
    _write_source(sys.argv[3], pure_entrypoints, entrypoints, extensions)

if __name__ == '__main__':
    # Run the generator only when executed as a script, not when imported.
    main()