import random from copy import deepcopy from time import time
import numpy as np from numpy.linalg import norm
from collections import Counter import time
class Node(object):
    """A single node of the kd-tree.

    Attributes:
        item: Feature vector (one sample) stored at this node.
        label: Label of the sample; stored as a length-1 array by ``KDTree``.
        dim: Index of the feature this node splits on.
        parent: Parent ``Node``, or ``None`` for the root.
        left_child: Root of the left subtree, or ``None``.
        right_child: Root of the right subtree, or ``None``.
    """

    def __init__(self, item=None, label=None, dim=None, parent=None,
                 left_child=None, right_child=None):
        self.item = item
        self.label = label
        self.dim = dim
        self.parent = parent
        self.left_child = left_child
        self.right_child = right_child


class KDTree(object):
    """kd-tree over labelled samples with k-nearest-neighbour classification.

    Every node splits on the feature with the largest variance among the
    samples of its subtree and stores the median sample along that feature.
    """

    def __init__(self, aList, labelList):
        """Build the tree.

        Args:
            aList: 2-D array-like, one row per sample, one column per feature.
            labelList: Array-like holding one label per sample.
        """
        self.__length = 0
        self.__root = self.__create(aList, labelList)

    def __create(self, aList, labelList, parentNode=None):
        """Recursively build a (sub)tree and return its root node.

        Args:
            aList: 2-D array-like of samples (rows) by features (columns).
            labelList: Labels of the samples, one per row.
            parentNode: Node the new subtree hangs from (``None`` for root).

        Returns:
            The root ``Node`` of the subtree, or ``None`` when there are no
            samples.

        Raises:
            ValueError: If the samples are not 2-D, or the number of labels
                does not match the number of samples.
        """
        # np.array copies, so nodes never alias the caller's data.
        dataArray = np.array(aList)
        if dataArray.size == 0:
            return None
        if dataArray.ndim != 2:
            raise ValueError(
                "aList must be 2-D: one row per sample, one column per feature")
        m = dataArray.shape[0]
        labelArray = np.array(labelList).reshape(m, 1)

        # Split on the feature with the largest spread; argmax returns the
        # first maximum, like list.index(max(...)) did.
        split_dim = int(np.argmax(np.var(dataArray, axis=0)))
        order = dataArray[:, split_dim].argsort()
        half = m // 2
        median = order[half]

        node = Node(
            item=dataArray[median],
            label=labelArray[median],
            dim=split_dim,
            parent=parentNode,
        )
        if half > 0:
            lower = order[:half]
            node.left_child = self.__create(
                dataArray[lower], labelArray[lower], node)
        if m > half + 1:
            upper = order[half + 1:]
            node.right_child = self.__create(
                dataArray[upper], labelArray[upper], node)
        self.__length += 1
        return node

    @property
    def length(self):
        """Number of nodes (samples) stored in the tree."""
        return self.__length

    @property
    def root(self):
        """Root node of the tree, or ``None`` when the tree is empty."""
        return self.__root

    def transfer_dict(self, node):
        """Return the subtree rooted at ``node`` as nested dictionaries.

        The key of each dictionary is the node's item as a tuple and the
        value describes the node, for example::

            {(1, 2, 3): {
                'label': 1,
                'dim': 0,
                'parent': None,
                'left_child': {(2, 3, 4): {...}},
                'right_child': {(4, 5, 6): {...}},
            }}

        Args:
            node: Root of the subtree to convert (usually ``self.root``).

        Returns:
            The nested dictionary, or ``None`` if ``node`` is ``None``.
        """
        if node is None:
            return None
        return {
            tuple(node.item): {
                'label': node.label[0],
                'dim': node.dim,
                'parent': tuple(node.parent.item) if node.parent else None,
                'left_child': self.transfer_dict(node.left_child),
                'right_child': self.transfer_dict(node.right_child),
            }
        }

    def transfer_list(self, node, kdList=None):
        """Flatten the subtree rooted at ``node`` into a list of dictionaries.

        Nodes are emitted in pre-order (node, left subtree, right subtree).

        Args:
            node: Root of the subtree to convert (usually ``self.root``).
            kdList: Optional list to append to. A fresh list is created when
                omitted, so results never leak between calls.

        Returns:
            The list of node dictionaries, or ``None`` if ``node`` is ``None``.
        """
        if node is None:
            return None
        if kdList is None:
            kdList = []
        kdList.append({
            'item': tuple(node.item),
            'label': node.label[0],
            'dim': node.dim,
            'parent': tuple(node.parent.item) if node.parent else None,
            'left_child': (tuple(node.left_child.item)
                           if node.left_child else None),
            'right_child': (tuple(node.right_child.item)
                            if node.right_child else None),
        })
        self.transfer_list(node.left_child, kdList)
        self.transfer_list(node.right_child, kdList)
        return kdList

    def _find_nearest_neighbour(self, item):
        """Walk down the tree towards ``item`` and return the node reached.

        The walk stops at the first node whose split coordinate equals the
        query's, or at a node that has no child on the query's side. The
        returned node is therefore only a starting guess and is NOT
        guaranteed to be the true nearest neighbour; use ``knn_algo`` for
        an exact search.

        Args:
            item: Query sample, indexable by feature.

        Returns:
            The node where the descent stopped, or ``None`` for an empty tree.
        """
        node = self.__root
        if node is None:
            return None
        while True:
            dim = node.dim
            if item[dim] == node.item[dim]:
                return node
            if item[dim] < node.item[dim]:
                child = node.left_child
            else:
                child = node.right_child
            if child is None:
                return node
            node = child

    def knn_algo(self, item, k=1):
        """Classify ``item`` by majority vote among its k nearest samples.

        Args:
            item: Query sample (array-like with one value per feature).
            k: Number of neighbours to consult; must be at least 1. When the
                tree holds fewer than ``k`` samples, all of them vote.

        Returns:
            A ``(label, neighbours)`` tuple. ``neighbours`` is a list of
            ``[distance, item_tuple, label]`` entries sorted by ascending
            Euclidean distance; ``label`` is the most frequent label among
            them (ties go to the label whose sample is nearest). For an
            empty tree the result is ``(None, [])``.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        self._validate_k(k)
        if self.__root is None:
            return None, []
        neighbours = []
        self._search_subtree(np.asarray(item), self.__root, neighbours, k)
        votes = Counter(entry[2] for entry in neighbours)
        return votes.most_common(1)[0][0], neighbours

    def left_search(self, item, node, nodeList, k):
        """Merge the k best candidates of the subtree at ``node`` into
        ``nodeList``.

        Kept for backward compatibility. The traversal order is chosen from
        the query position at every node (near side first), because a fixed
        left-to-right order combined with plane-distance pruning could skip
        the side that actually contains the query.

        Args:
            item: Query sample.
            node: Root of the subtree to search; ``None`` is a no-op.
            nodeList: List of ``[distance, item_tuple, label]`` candidates;
                updated in place, kept sorted by distance and trimmed to
                ``k`` entries whenever a candidate is inserted.
            k: Number of neighbours wanted; must be at least 1.

        Returns:
            ``nodeList``.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        return self._search_from(item, node, nodeList, k)

    def right_search(self, item, node, nodeList, k):
        """Merge the k best candidates of the subtree at ``node`` into
        ``nodeList``.

        Kept for backward compatibility; behaves exactly like
        ``left_search`` (see there for arguments, return value and errors).
        """
        return self._search_from(item, node, nodeList, k)

    @staticmethod
    def _validate_k(k):
        """Reject neighbour counts smaller than 1."""
        if k < 1:
            raise ValueError("k must be at least 1, got %r" % (k,))

    @staticmethod
    def _entry_key(entry):
        """Sort key for candidates: distance first, coordinates on ties."""
        return entry[0], entry[1]

    def _search_from(self, item, node, nodeList, k):
        """Normalise caller-supplied state, then run the subtree search."""
        self._validate_k(k)
        # The search relies on the candidate list being sorted by distance.
        nodeList.sort(key=self._entry_key)
        self._search_subtree(np.asarray(item), node, nodeList, k)
        return nodeList

    def _offer(self, nodeList, node, distance, k):
        """Insert ``node`` if it beats the current k-th best candidate."""
        if len(nodeList) >= k and not distance < nodeList[k - 1][0]:
            return
        nodeList.append([distance, tuple(node.item), node.label[0]])
        nodeList.sort(key=self._entry_key)
        # Anything past position k can never be part of the answer.
        del nodeList[k:]

    def _search_subtree(self, item, node, nodeList, k):
        """Exact k-nearest-neighbour search of the subtree rooted at ``node``.

        Visits the child on the query's side of the splitting plane first,
        then the node itself, and descends into the other child only when
        the plane is closer than the current k-th best distance (or fewer
        than ``k`` candidates have been collected).
        """
        if node is None:
            return
        offset = item[node.dim] - node.item[node.dim]
        if offset <= 0:
            near, far = node.left_child, node.right_child
        else:
            near, far = node.right_child, node.left_child

        self._search_subtree(item, near, nodeList, k)
        self._offer(nodeList, node, np.linalg.norm(item - node.item), k)
        # Every sample on the far side is at least |offset| away.
        if len(nodeList) < k or abs(offset) < nodeList[k - 1][0]:
            self._search_subtree(item, far, nodeList, k)
def main():
    """Demo: build a kd-tree over random 2-D points and classify one query.

    Tree construction and the neighbour search are timed separately; data
    generation and printing are kept outside the timed sections so the
    reported durations cover only the work they are labelled with.
    """
    dataArray = np.random.randint(0, 20, size=(10000, 2))
    labels = np.random.randint(0, 3, size=(10000, 1))
    query = [12, 7]

    build_start = time.perf_counter()
    kd_tree = KDTree(dataArray, labels)
    build_seconds = time.perf_counter() - build_start

    search_start = time.perf_counter()
    predicted, node_list = kd_tree.knn_algo(query, k=5)
    search_seconds = time.perf_counter() - search_start

    print('点%s的最接近的前k个点为:%s'%(query, node_list))
    print('点%s的标签:%s'%(query, predicted))
    print('创建树耗时:', build_seconds)
    print('搜索前k个最近邻点耗时:', search_seconds)


if __name__ == '__main__':
    main()
|