######################################################
#
# Guanhua Yan (ghyan@binghamton.edu)
#
# Support module used by exploit-meter.py.
# PCA not used, so don't get confused by the filename.
#
######################################################

import os, sys, copy, sklearn
import numpy as np
from sklearn import svm
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.decomposition import PCA
from sklearn.tree import DecisionTreeClassifier

# combine all the features into a list
def combine_feature_names(feature_set_list):
    feature_list = []
    for fs in feature_set_list:        
        for f in fs:
            feature_list.append(f)
            
    return feature_list

# collect feature names
def collect_feature_names(mybinpath, feature_type_list, feature_set_list):
    if len(feature_type_list) != len(feature_set_list):
        print "Something is wrong. Sizes don't match!"
        system.exit()
    
    filelist = os.listdir(mybinpath)    
    for f in filelist:
        myfile = os.path.join(mybinpath, f)
        for i in range(len(feature_type_list)):
            ft = feature_type_list[i]
            if myfile.find(ft) != -1:
                # this is a hexdump 1-gram file
                fn = open(myfile, 'r')
                fn.readline()
                for line in fn:
                    fields = line.rstrip().split('\t')
                    key = ft + '.' + fields[0].strip()
                    feature_set_list[i].add(key)
                fn.close()
    return
                
# collect feature values
def collect_feature_values(mybinpath, feature_type_list, feature_map, verify):
    filelist = os.listdir(mybinpath)

    for f in filelist:
        myfile = os.path.join(mybinpath, f)
        cnt = 0
        for ft in feature_type_list:
            if myfile.find(ft) != -1:
                fn = open(myfile, 'r')
                fn.readline()
                
                for line in fn:
                    fields = line.rstrip().split('\t')
                    #print fields
                    #print "fields =", fields
                    key = ft + '.' + fields[0].strip()
                    value = float(fields[1].strip())
                    if (not verify) or (key in feature_map.keys()):
                        feature_map[key] = value
                        cnt += 1
                fn.close()            

        print myfile, cnt
                
    # create an array from the feature map
    feature_vec = []
    for k in feature_map.keys():
        feature_vec.append( feature_map[k] )

    #print "feature_vec =", feature_vec
    return feature_vec

# the two lists must have the same size and match well against each
def train_classification_model(trainpathlist, target_vuln_map, clf_name, feature_type_list):
    feature_set_list = []
    for i in range(len(feature_type_list)):
        feature_set_list.append(set([]))
    
    for mybinpath in trainpathlist:
        collect_feature_names(mybinpath, feature_type_list, feature_set_list)      
    feature_list = combine_feature_names(feature_set_list)

    print "len of feature_list is", len(feature_list)
    #print "check point 2..."
    empty_feature_map = {}
    for f in feature_list:
        empty_feature_map[f] = 0.0

    feature_matrix = []
    for mybinpath in trainpathlist:
        print "collecting feature values for", mybinpath
        feature_vec = collect_feature_values(mybinpath, feature_type_list, copy.copy(empty_feature_map), True)
        feature_matrix.append(feature_vec)
    print "number of binary programs is", len(trainpathlist)

    #print "check point 3..."
    feature_array = np.array(feature_matrix)
    print feature_array

    #new_training_feature_array = pca.transform(feature_array)
    new_training_feature_array = feature_array
    
    cur_clf_map = {}
    for vuln in target_vuln_map.keys():
        # check whether there are two-class data
        first = target_vuln_map[vuln][0]
        two_class = False
        for i in range(len(target_vuln_map[vuln]) - 1):
            if target_vuln_map[vuln][i+1] != first:
                two_class = True
                break
        if two_class:
            if clf_name == 'svm':
                clf = svm.SVC()
            elif clf_name == 'knn':
                clf = KNeighborsClassifier(n_neighbors=2)
            elif clf_name == 'naive_bayes':
                clf = GaussianNB()
            elif clf_name == 'decision_tree':
                clf = DecisionTreeClassifier(random_state=0)
            clf.fit(new_training_feature_array, target_vuln_map[vuln])
            cur_clf_map[vuln] = clf
        else:
            cur_clf_map[vuln] = None
    return (cur_clf_map, empty_feature_map)


def train_pca_model(srcpath, nc, feature_type_list):
    """Fit a PCA model with *nc* components on the feature matrix built
    from every binary's dump directory under *srcpath*.

    Returns (pca, empty_feature_map): the fitted PCA object and the
    all-zero feature-name template map used to build the matrix.
    """
    feature_set_list = [set() for _ in feature_type_list]

    binlist = os.listdir(srcpath)
    for mybinpath in binlist:
        collect_feature_names(os.path.join(srcpath, mybinpath), feature_type_list, feature_set_list)
    feature_list = combine_feature_names(feature_set_list)

    print("len of feature_list is %s" % len(feature_list))
    # template map: every known feature defaults to 0.0
    empty_feature_map = {}
    for f in feature_list:
        empty_feature_map[f] = 0.0

    feature_matrix = []
    for mybinpath in binlist:
        print("collecting feature values for %s" % mybinpath)
        # verify=False: keep every feature seen, we are building the vocabulary
        feature_vec = collect_feature_values(os.path.join(srcpath, mybinpath), feature_type_list, copy.copy(empty_feature_map), False)
        feature_matrix.append(feature_vec)
    print("number of binary programs is %s" % len(binlist))

    feature_array = np.array(feature_matrix)
    print(feature_array)

    pca = PCA(n_components = nc)
    # fit_transform both fits the model and projects the training data;
    # only the fitted model is returned, the projection is discarded
    new_feature_array = pca.fit_transform(feature_array)
    # BUG FIX: under Python 2 the original 'print(a, b)' printed a tuple
    # repr; format explicitly so the output is a plain message
    print("Explained variance ratio %s" % (pca.explained_variance_ratio_,))
    print("new feature array is...")

    return (pca, empty_feature_map)


def predict_classification_result(clf, feature_vec):
    """Run classifier *clf* on one feature vector and return its label.

    NOTE(review): the vector is passed to predict() as a 1-D array; the
    PCA projection step was disabled along with the rest of the PCA
    pipeline, so the raw features go straight to the classifier.
    """
    sample = np.array(feature_vec)
    prediction = clf.predict(sample)
    return prediction[0]
    

