from logging import log
import os
from typing import *

from tracegnn.models.gvae_tf.constants import *
from tracegnn.models.gvae_tf.dataset import *
from tracegnn.data import *
from tracegnn.models.gvae_tf.GVAE_woRL import Model

import torch.utils.data as Data
import pickle
import shutil
import click
import json

# NOTE(review): presumably set so that loading two OpenMP runtimes in one
# process (e.g. via torch and MKL) does not abort — confirm it is still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# Root directory for the converted datasets; convert() writes one
# sub-directory per dataset beneath it.
output_dir = 'tracegnn/models/trace_anomaly/dataset/'

def make_bytes_db(input_dir: str, names: List[str]) -> 'BytesDB':
    """Open the byte databases ``names`` located under ``input_dir``.

    A single name yields a plain ``BytesSqliteDB``; several names are
    opened in the given order and combined into one ``BytesMultiDB``.
    A bare string is treated as a single database name.

    Raises:
        ValueError: if ``names`` is empty.
    """
    if isinstance(names, str):
        # Guard against iterating a string character by character.
        names = [names]
    if not names:
        raise ValueError('make_bytes_db: `names` must contain at least one database name')

    dbs = [BytesSqliteDB(os.path.join(input_dir, name)) for name in names]
    if len(dbs) == 1:
        return dbs[0]
    return BytesMultiDB(*dbs)

def get_path_list(g: 'TraceGraph', id_manager: 'TraceGraphIDManager') -> List[Tuple[str, float]]:
    """Enumerate the call paths of ``g`` together with each node's latency.

    Walks the graph depth-first (pre-order) from ``g.root``; children are
    visited in ascending ``operation_id`` order. For every node one entry
    ``(path, avg_latency)`` is emitted, where ``path`` is the '#'-joined
    operation names on the route from the root down to that node.

    Args:
        g: the trace graph to walk.
        id_manager: maps operation ids back to operation names.

    Returns:
        ``(path, latency)`` tuples in visit order. Sibling nodes with the
        same operation name produce identical paths.
    """
    result: List[Tuple[str, float]] = []
    # Each entry carries the node plus its parent's path (None for the root).
    # Keeping the prefix per entry stops sibling sub-trees from leaking into
    # each other's paths.
    stack: List[Tuple['TraceGraphNode', Optional[str]]] = [(g.root, None)]

    while stack:
        node, parent_path = stack.pop()
        name = id_manager.operation_id.reverse_map(node.operation_id)
        path = name if parent_path is None else f'{parent_path}#{name}'
        result.append((path, node.features.avg_latency))

        # Push in descending id order so the smallest id is popped first.
        for child in sorted(node.children, key=lambda x: x.operation_id, reverse=True):
            stack.append((child, path))

    return result

def _trace_key(tr_graph: 'TraceGraph') -> str:
    """Return the row key written for ``tr_graph``: its two trace-id parts concatenated."""
    return f'{tr_graph.trace_id[0]}{tr_graph.trace_id[1]}'


def _anomaly_label(tr_graph: 'TraceGraph') -> int:
    """Return 0 for a normal trace, 1 for a 'drop' anomaly, 2 for any other anomaly type."""
    if not tr_graph.data.get('is_anomaly'):
        return 0
    return 1 if tr_graph.data['anomaly_type'] == 'drop' else 2


def _iter_stv(db: 'BytesDB', id_manager: 'TraceGraphIDManager',
              path_cnt: Optional[Dict[str, int]] = None) -> Iterator[Tuple['TraceGraph', Dict[str, float]]]:
    """Decode every graph in ``db`` and yield ``(graph, {path: latency})``.

    When ``path_cnt`` is given, every path occurrence is tallied into it in
    place. If a path occurs more than once in one graph, the dict keeps the
    latency of the last occurrence.
    """
    for i in tqdm(range(db.data_count())):
        tr_graph = TraceGraph.from_bytes(db.get(i))
        path_list = get_path_list(tr_graph, id_manager)
        if path_cnt is not None:
            for path, _ in path_list:
                path_cnt[path] = path_cnt.get(path, 0) + 1
        yield tr_graph, dict(path_list)


def _format_stv(stv_dict: Dict[str, float], path_dict: Dict[str, int], stv_len: int) -> str:
    """Render ``stv_dict`` as a dense comma-separated vector of length ``stv_len``.

    Columns for paths absent from the trace are written as '0'; paths not in
    ``path_dict`` are dropped.
    """
    row = ['0'] * stv_len
    for path, latency in stv_dict.items():
        idx = path_dict.get(path)
        if idx is not None:
            row[idx] = str(latency)
    return ','.join(row)


def convert(dataset, data_names):
    """Convert a processed tracegnn dataset into TraceAnomaly STV text files.

    Reads the ``train`` and ``val`` databases, plus the test databases listed
    in ``data_names``, from ``/srv/data/tracegnn/<dataset>/processed``. Each
    trace becomes a path->latency vector. The following files are written
    under ``<output_dir>/<dataset>`` (any existing directory is deleted first):

    * ``train`` / ``val``: one ``<trace_id>:<vector>`` line per trace;
    * ``test``: one ``<trace_id>:<label>:<vector>`` line per trace, where
      label is 0 (normal), 1 ('drop' anomaly) or 2 (other anomaly);
    * ``idx.pkl``: the pickled path->column mapping.

    Vector columns are the paths seen in train and val; test-only paths are
    dropped. The vector length is also stored under ``dataset`` in
    ``paper-data/trace_anomaly_stv_len.json``.
    """
    print('Loading data...')
    dataset_path = f'/srv/data/tracegnn/{dataset}/processed'
    id_manager = TraceGraphIDManager(dataset_path)
    train_db = make_bytes_db(dataset_path, ['train'])
    val_db = make_bytes_db(dataset_path, ['val'])
    test_db = make_bytes_db(dataset_path, data_names)

    print('Generating dict...')
    # Occurrence counts of each path over train + val; insertion order
    # (first time a path is seen) determines its column index.
    path_cnt: Dict[str, int] = {}
    train_stv_list: List[Tuple[str, Dict[str, float]]] = [
        (_trace_key(g), stv) for g, stv in _iter_stv(train_db, id_manager, path_cnt)
    ]
    val_stv_list: List[Tuple[str, Dict[str, float]]] = [
        (_trace_key(g), stv) for g, stv in _iter_stv(val_db, id_manager, path_cnt)
    ]
    # Test paths are not counted, so paths seen only in test get no column.
    test_stv_list: List[Tuple[str, Dict[str, float], int]] = [
        (_trace_key(g), stv, _anomaly_label(g)) for g, stv in _iter_stv(test_db, id_manager)
    ]

    # Every counted path has a count >= 1, so each one becomes a column.
    path_dict: Dict[str, int] = {path: idx for idx, path in enumerate(path_cnt)}
    stv_len = len(path_dict)
    print(f'Finished. stv_len={stv_len}')

    print('Writing to file...')
    output_path = os.path.join(output_dir, dataset)
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    os.makedirs(output_path)

    with open(os.path.join(output_path, 'train'), 'wt') as f:
        for trace_id, stv_dict in tqdm(train_stv_list):
            f.write(f"{trace_id}:{_format_stv(stv_dict, path_dict, stv_len)}\n")

    with open(os.path.join(output_path, 'val'), 'wt') as f:
        for trace_id, stv_dict in tqdm(val_stv_list):
            f.write(f"{trace_id}:{_format_stv(stv_dict, path_dict, stv_len)}\n")

    with open(os.path.join(output_path, 'test'), 'wt') as f:
        for trace_id, stv_dict, label in tqdm(test_stv_list):
            f.write(f"{trace_id}:{label}:{_format_stv(stv_dict, path_dict, stv_len)}\n")

    # Read-modify-write the shared JSON that records stv_len per dataset.
    stv_len_path = os.path.join('paper-data', 'trace_anomaly_stv_len.json')
    stv_len_json = {}
    if os.path.exists(stv_len_path):
        with open(stv_len_path, 'rt') as f:
            stv_len_json = json.load(f)
    stv_len_json[dataset] = stv_len
    os.makedirs('paper-data', exist_ok=True)
    with open(stv_len_path, 'wt') as f:
        json.dump(stv_len_json, f)

    with open(os.path.join(output_path, 'idx.pkl'), 'wb') as f:
        pickle.dump(path_dict, f)
