import numpy as np
import random
from tqdm import tqdm

def read_train_vector(input_file, vc=None, shuffle=True, sample=False, sample_size=50000):  # flows, vectors, valid_column
    """Read an unlabeled ``<flow>:<v1>,<v2>,...`` file into a feature matrix.

    Each non-blank line is split on ``':'``. Field 0 is kept as the flow
    identifier and field 1 is parsed as a comma-separated list of floats.

    Args:
        input_file: Path of the text file to read.
        vc: Optional column selector (e.g. the boolean mask returned by a
            previous call). When None, the mask is computed from this file
            as "columns that are non-zero in at least one row".
        shuffle: When exactly ``True``, rows are shuffled (flows and vectors
            are permuted together).
        sample: When exactly ``True``, keep a random subset of rows without
            replacement.
        sample_size: Maximum number of rows kept when ``sample`` is True.
            If the file has fewer rows, all rows are kept.

    Returns:
        ``(flows, vectors, valid_column)``: the list of flow identifiers
        aligned with the rows of ``vectors``, the 2-D array restricted to
        ``valid_column``, and the column selector that was applied.

    Raises:
        ValueError: if the file contains no non-blank lines.
    """
    flows = []
    vectors = []

    with open(input_file, 'r') as fin:
        for line in tqdm(fin.readlines()):
            line = line.strip()
            # Skip blank lines (the old check compared the bound method
            # `line.strip` to "" and never matched).
            if not line:
                continue
            fields = line.split(':')
            flows.append(fields[0])
            vectors.append(np.array([float(x) for x in fields[1].split(',')]))

    if not vectors:
        raise ValueError(f"no vectors found in {input_file}")
    vectors = np.stack(vectors)

    if shuffle is True:
        order = np.arange(len(vectors))
        np.random.shuffle(order)
        vectors = vectors[order]
        # Keep flow identifiers aligned with their (shuffled) vectors.
        flows = [flows[i] for i in order]

    if sample is True:
        # random.sample cannot take an ndarray, so sample row indices instead
        # and clamp to the number of rows available.
        n_keep = min(sample_size, len(vectors))
        picked = random.sample(range(len(vectors)), n_keep)
        vectors = vectors[picked]
        flows = [flows[i] for i in picked]

    if vc is None:
        # A column is "valid" if any row has a non-zero value in it.
        valid_column = np.any(vectors != 0, axis=0)
    else:
        valid_column = vc

    vectors = vectors[:, valid_column]
    return flows, vectors, valid_column


def read_test_vector(input_file, vc=None):  # flows, vectors, labels
    """Read a labeled ``<flow>:<label>:<v1>,<v2>,...`` file.

    Each non-blank line is split on ``':'``. Field 0 is the flow identifier,
    field 1 is parsed with ``int`` as the label, and field 2 is parsed as a
    comma-separated list of floats.

    Args:
        input_file: Path of the text file to read.
        vc: Column selector applied to the feature matrix, typically the
            ``valid_column`` returned by ``read_train_vector``. When None,
            all columns are kept. Previously ``vectors[:, None]`` silently
            added an extra axis instead.

    Returns:
        ``(flows, vectors, labels)``: the list of flow identifiers, the 2-D
        feature array (restricted to ``vc`` if given), and an integer
        label array, all in file order.
    """
    flows = []
    vectors = []
    labels = []

    with open(input_file, 'r') as fin:
        for line in tqdm(fin.readlines()):
            line = line.strip()
            if not line:
                continue
            fields = line.split(':')
            flows.append(fields[0])
            vectors.append(np.array([float(x) for x in fields[2].split(',')]))
            labels.append(int(fields[1]))

    vectors = np.stack(vectors)
    labels = np.array(labels)

    if vc is not None:
        vectors = vectors[:, vc]
    return flows, vectors, labels

def get_mean_std(matrix, eps=0.00001):
    """Compute per-column mean and clamped std over values greater than ``eps``.

    Only entries strictly greater than ``eps`` contribute to the statistics.
    The std is clamped to at least 1 so that the later division in
    ``normalization`` cannot blow up on near-constant columns.

    A column with no entry above ``eps`` used to yield ``np.mean([])``,
    which is NaN. That NaN then turned any positive value in the same column
    of other data into NaN. Such columns now get mean 0.0 and std 1.

    Args:
        matrix: 2-D array-like; statistics are computed per column.
        eps: Threshold below/at which values are ignored.

    Returns:
        ``(mean, std)``: two lists with one entry per column.
    """
    mean = []
    std = []
    for column in np.transpose(matrix):
        positive = column[column > eps]
        if positive.size == 0:
            mean.append(0.0)
            std.append(1)
            continue
        mean.append(np.mean(positive))
        std.append(max(1, np.std(positive)))

    return mean, std

def normalization(matrix, mean, std):
    """Standardize ``matrix`` column-wise, mapping near-zero entries to -1.

    The input is first converted to float32. Every entry below 1e-5 becomes
    -1; every other entry becomes ``(x - mean) / std``, with ``mean`` and
    ``std`` broadcast per column. The result dtype follows NumPy's type
    promotion with ``mean``/``std``.
    """
    data = np.asarray(matrix, dtype=np.float32)
    below_threshold = data < 0.00001
    standardized = (data - mean) / std
    return np.where(below_threshold, -1, standardized)

def get_data_vae(train_file, val_file, test_file):
    """Load and normalize the train/validation/test splits.

    The valid-column mask and the normalization statistics are both derived
    from the training file and reused for the validation and test files.
    Train and validation targets are all-zero int32 arrays; test targets are
    the labels read from ``test_file``.

    Returns:
        ``((train_x, train_y), (val_x, val_y), (test_x, test_y), flows)``
        where ``flows`` are the test-set flow identifiers.
    """
    _, train_raw, valid_columns = read_train_vector(train_file)
    _, val_raw, _ = read_train_vector(val_file, valid_columns)
    flows, test_raw, labels = read_test_vector(test_file, valid_columns)

    train_mean, train_std = get_mean_std(train_raw)
    train_x = normalization(train_raw, train_mean, train_std)
    val_x = normalization(val_raw, train_mean, train_std)
    test_x = normalization(test_raw, train_mean, train_std)

    train_y = np.zeros(len(train_x), dtype=np.int32)
    # Size must match the validation set, not the training set.
    val_y = np.zeros(len(val_x), dtype=np.int32)
    test_y = labels

    return (train_x, train_y), (val_x, val_y), (test_x, test_y), flows


def get_z_dim(x_dim):
    tmp = x_dim
    z_dim = 5
    while tmp > 20:
        z_dim *= 2
        tmp = tmp // 20
    return z_dim
